aruna_file/transformers/
zstd_decomp.rs

1use crate::notifications::Message;
2use crate::notifications::Response;
3use crate::transformer::Transformer;
4use crate::transformer::TransformerType;
5use anyhow::Result;
6use async_compression::tokio::write::ZstdDecoder;
7use bytes::BufMut;
8use bytes::BytesMut;
9use tokio::io::AsyncWriteExt;
10use tracing::debug;
11
/// Initial capacity (together with [`CHUNK`]) reserved for decompressed output buffers.
///
/// 5 MiB. NOTE(review): presumably the largest uncompressed frame the matching
/// compressor emits — confirm against the encoder.
const RAW_FRAME_SIZE: usize = 5 * 1024 * 1024;
/// 64 KiB of slack added on top of [`RAW_FRAME_SIZE`] when pre-sizing buffers.
///
/// NOTE(review): the skippable-frame test input is padded to exactly this many bytes,
/// which suggests compressed chunks are `CHUNK`-aligned — confirm.
const CHUNK: usize = 64 * 1024;
14
15pub struct ZstdDec {
16    internal_buf: ZstdDecoder<Vec<u8>>,
17    prev_buf: BytesMut,
18    finished: bool,
19    skip_me: bool,
20}
21
22impl ZstdDec {
23    #[tracing::instrument(level = "trace", skip())]
24    #[allow(dead_code)]
25    pub fn new() -> ZstdDec {
26        ZstdDec {
27            internal_buf: ZstdDecoder::new(Vec::with_capacity(RAW_FRAME_SIZE + CHUNK)),
28            prev_buf: BytesMut::with_capacity(RAW_FRAME_SIZE + CHUNK),
29            finished: false,
30            skip_me: false,
31        }
32    }
33}
34
35impl Default for ZstdDec {
36    #[tracing::instrument(level = "trace", skip())]
37    fn default() -> Self {
38        Self::new()
39    }
40}
41
42#[async_trait::async_trait]
43impl Transformer for ZstdDec {
44    #[tracing::instrument(level = "trace", skip(self, buf, finished, should_flush))]
45    async fn process_bytes(
46        &mut self,
47        buf: &mut bytes::BytesMut,
48        finished: bool,
49        should_flush: bool,
50    ) -> Result<bool> {
51        if self.skip_me {
52            debug!("skipped zstd decoder");
53            return Ok(finished);
54        }
55        if should_flush {
56            debug!("flushed zstd decoder");
57            self.internal_buf.write_all_buf(buf).await?;
58            self.internal_buf.shutdown().await?;
59            self.prev_buf.put(self.internal_buf.get_ref().as_slice());
60            self.internal_buf = ZstdDecoder::new(Vec::with_capacity(RAW_FRAME_SIZE + CHUNK));
61            buf.put(self.prev_buf.split().freeze());
62            return Ok(finished);
63        }
64
65        // Only write if the buffer contains data and the current process is not finished
66        if !buf.is_empty() && !self.finished {
67            self.internal_buf.write_buf(buf).await?;
68            while !buf.is_empty() {
69                self.internal_buf.shutdown().await?;
70                self.prev_buf.put(self.internal_buf.get_ref().as_slice());
71                self.internal_buf = ZstdDecoder::new(Vec::with_capacity(RAW_FRAME_SIZE + CHUNK));
72                self.internal_buf.write_buf(buf).await?;
73            }
74        }
75
76        if !self.finished && buf.is_empty() && finished {
77            debug!("finish zstd decoder");
78            self.internal_buf.shutdown().await?;
79            self.prev_buf.put(self.internal_buf.get_ref().as_slice());
80            self.finished = true;
81        }
82
83        buf.put(self.prev_buf.split().freeze());
84        Ok(self.finished && self.prev_buf.is_empty())
85    }
86
87    #[tracing::instrument(level = "trace", skip(self))]
88    #[inline]
89    fn get_type(&self) -> TransformerType {
90        TransformerType::ZstdDecompressor
91    }
92    #[tracing::instrument(level = "trace", skip(self, message))]
93    async fn notify(&mut self, message: &Message) -> Result<Response> {
94        if message.target == TransformerType::All {
95            if let crate::notifications::MessageData::NextFile(nfile) = &message.data {
96                self.skip_me = nfile.context.skip_decompression
97            }
98        }
99        Ok(Response::Ok)
100    }
101}
102
#[cfg(test)]
mod tests {

    use super::*;

    /// A data frame followed by a skippable padding frame decodes to the data only.
    #[tokio::test]
    async fn test_zstd_decoder_with_skip() {
        let mut decoder = ZstdDec::new();
        let mut buf = BytesMut::new();
        // zstd frame holding "12345" (one raw block), then a skippable frame
        // (magic 0x184D2A50, 0xFFEA-byte zero payload) padding the input to 64 KiB.
        let compressed = hex::decode(format!(
            "28b52ffd00582900003132333435502a4d18eaff{}",
            "00".repeat(65516)
        ))
        .unwrap();
        buf.put(compressed.as_slice());
        let done = decoder.process_bytes(&mut buf, true, false).await.unwrap();
        assert!(done);
        // Only the 5 decoded bytes remain; the skippable frame yields no output
        assert_eq!(buf.len(), 5);
        assert_eq!(buf, b"12345".as_slice());
    }

    /// A single data frame decodes completely through the flush path.
    #[tokio::test]
    async fn test_zstd_decoder_flush_single_frame() {
        let mut decoder = ZstdDec::new();
        let mut buf = BytesMut::new();
        let compressed = hex::decode("28b52ffd00582900003132333435").unwrap();
        buf.put(compressed.as_slice());
        let done = decoder.process_bytes(&mut buf, true, true).await.unwrap();
        // The flush path reports the caller's `finished` flag back
        assert!(done);
        // Expect exactly the 5 decoded bytes
        assert_eq!(buf.len(), 5);
        assert_eq!(buf, b"12345".as_slice());
    }

    /// With decompression disabled the input is passed through untouched.
    #[tokio::test]
    async fn test_zstd_decoder_skip_me_passes_through() {
        let mut decoder = ZstdDec::new();
        decoder.skip_me = true;
        let mut buf = BytesMut::from(&b"not zstd"[..]);
        let done = decoder.process_bytes(&mut buf, false, false).await.unwrap();
        assert!(!done);
        assert_eq!(buf, b"not zstd".as_slice());
    }
}
133        assert_eq!(buf, b"12345".as_slice());
134    }
135}