aruna_file/transformers/
zstd_decomp.rs1use crate::notifications::Message;
2use crate::notifications::Response;
3use crate::transformer::Transformer;
4use crate::transformer::TransformerType;
5use anyhow::Result;
6use async_compression::tokio::write::ZstdDecoder;
7use bytes::BufMut;
8use bytes::BytesMut;
9use tokio::io::AsyncWriteExt;
10use tracing::debug;
11
/// Capacity hint (5 MiB) for one decoded frame; used to pre-size the decoder's output
/// `Vec` and the staging buffer. Presumably the largest raw frame the matching
/// compressor emits — TODO confirm against the encoder.
const RAW_FRAME_SIZE: usize = 5 * 1024 * 1024;
/// Extra slack (64 KiB) added on top of [`RAW_FRAME_SIZE`] when pre-sizing buffers.
/// Presumably the pipeline's chunk size — NOTE(review): confirm.
const CHUNK: usize = 64 * 1024;
14
/// Streaming zstd decompression stage used as a [`Transformer`].
pub struct ZstdDec {
    /// Active zstd decoder. Its decoded output accumulates in the inner `Vec` and is
    /// only copied out after the decoder has been shut down (at a frame boundary, on
    /// flush, or when the stream finishes).
    internal_buf: ZstdDecoder<Vec<u8>>,
    /// Decoded bytes staged for the caller; drained into the caller's buffer at the
    /// end of every `process_bytes` call.
    prev_buf: BytesMut,
    /// Set once the decoder has been shut down for the final time after the caller
    /// signalled `finished`.
    finished: bool,
    /// Updated from the `skip_decompression` flag of a `NextFile` notification; while
    /// true, `process_bytes` leaves the input untouched.
    skip_me: bool,
}
21
22impl ZstdDec {
23 #[tracing::instrument(level = "trace", skip())]
24 #[allow(dead_code)]
25 pub fn new() -> ZstdDec {
26 ZstdDec {
27 internal_buf: ZstdDecoder::new(Vec::with_capacity(RAW_FRAME_SIZE + CHUNK)),
28 prev_buf: BytesMut::with_capacity(RAW_FRAME_SIZE + CHUNK),
29 finished: false,
30 skip_me: false,
31 }
32 }
33}
34
35impl Default for ZstdDec {
36 #[tracing::instrument(level = "trace", skip())]
37 fn default() -> Self {
38 Self::new()
39 }
40}
41
42#[async_trait::async_trait]
43impl Transformer for ZstdDec {
44 #[tracing::instrument(level = "trace", skip(self, buf, finished, should_flush))]
45 async fn process_bytes(
46 &mut self,
47 buf: &mut bytes::BytesMut,
48 finished: bool,
49 should_flush: bool,
50 ) -> Result<bool> {
51 if self.skip_me {
52 debug!("skipped zstd decoder");
53 return Ok(finished);
54 }
55 if should_flush {
56 debug!("flushed zstd decoder");
57 self.internal_buf.write_all_buf(buf).await?;
58 self.internal_buf.shutdown().await?;
59 self.prev_buf.put(self.internal_buf.get_ref().as_slice());
60 self.internal_buf = ZstdDecoder::new(Vec::with_capacity(RAW_FRAME_SIZE + CHUNK));
61 buf.put(self.prev_buf.split().freeze());
62 return Ok(finished);
63 }
64
65 if !buf.is_empty() && !self.finished {
67 self.internal_buf.write_buf(buf).await?;
68 while !buf.is_empty() {
69 self.internal_buf.shutdown().await?;
70 self.prev_buf.put(self.internal_buf.get_ref().as_slice());
71 self.internal_buf = ZstdDecoder::new(Vec::with_capacity(RAW_FRAME_SIZE + CHUNK));
72 self.internal_buf.write_buf(buf).await?;
73 }
74 }
75
76 if !self.finished && buf.is_empty() && finished {
77 debug!("finish zstd decoder");
78 self.internal_buf.shutdown().await?;
79 self.prev_buf.put(self.internal_buf.get_ref().as_slice());
80 self.finished = true;
81 }
82
83 buf.put(self.prev_buf.split().freeze());
84 Ok(self.finished && self.prev_buf.is_empty())
85 }
86
87 #[tracing::instrument(level = "trace", skip(self))]
88 #[inline]
89 fn get_type(&self) -> TransformerType {
90 TransformerType::ZstdDecompressor
91 }
92 #[tracing::instrument(level = "trace", skip(self, message))]
93 async fn notify(&mut self, message: &Message) -> Result<Response> {
94 if message.target == TransformerType::All {
95 if let crate::notifications::MessageData::NextFile(nfile) = &message.data {
96 self.skip_me = nfile.context.skip_decompression
97 }
98 }
99 Ok(Response::Ok)
100 }
101}
102
#[cfg(test)]
mod tests {

    use super::*;

    /// A minimal zstd frame: one raw, last block holding `b"12345"` (14 bytes).
    const FRAME_12345: &str = "28b52ffd00582900003132333435";

    #[tokio::test]
    async fn test_zstd_decoder_with_skip() {
        let mut decoder = ZstdDec::new();
        let mut buf = BytesMut::new();
        // The data frame followed by a skippable frame (magic 0x184D2A50) that pads the
        // input to 65_536 bytes.
        let input = hex::decode(format!(
            "28b52ffd00582900003132333435502a4d18eaff{}",
            "00".repeat(65516)
        ))
        .unwrap();
        buf.put(input.as_slice());
        let done = decoder.process_bytes(&mut buf, true, false).await.unwrap();
        assert!(done);
        assert_eq!(buf.len(), 5);
        assert_eq!(buf, b"12345".as_slice());
    }

    #[tokio::test]
    async fn test_zstd_decoder_without_skip() {
        let mut decoder = ZstdDec::new();
        let mut buf = BytesMut::new();
        let input = hex::decode(FRAME_12345).unwrap();
        buf.put(input.as_slice());
        let done = decoder.process_bytes(&mut buf, true, true).await.unwrap();
        assert!(done);
        assert_eq!(buf.len(), 5);
        assert_eq!(buf, b"12345".as_slice());
    }

    #[tokio::test]
    async fn test_zstd_decoder_skip_me_passes_input_through() {
        let mut decoder = ZstdDec::new();
        decoder.skip_me = true;
        let input = hex::decode(FRAME_12345).unwrap();
        let mut buf = BytesMut::from(input.as_slice());
        let done = decoder.process_bytes(&mut buf, true, false).await.unwrap();
        assert!(done);
        assert_eq!(buf, input.as_slice());
    }

    #[tokio::test]
    async fn test_zstd_decoder_concatenated_frames() {
        let mut decoder = ZstdDec::new();
        let frame = hex::decode(FRAME_12345).unwrap();
        let mut buf = BytesMut::new();
        buf.put(frame.as_slice());
        buf.put(frame.as_slice());
        let done = decoder.process_bytes(&mut buf, true, false).await.unwrap();
        assert!(done);
        assert_eq!(buf, b"1234512345".as_slice());
    }

    #[tokio::test]
    async fn test_zstd_decoder_frame_split_across_calls() {
        let mut decoder = ZstdDec::new();
        let frame = hex::decode(FRAME_12345).unwrap();
        let (head, tail) = frame.split_at(7);

        let mut buf = BytesMut::from(head);
        let done = decoder.process_bytes(&mut buf, false, false).await.unwrap();
        assert!(!done);
        assert!(buf.is_empty());

        let mut buf = BytesMut::from(tail);
        let done = decoder.process_bytes(&mut buf, true, false).await.unwrap();
        assert!(done);
        assert_eq!(buf, b"12345".as_slice());
    }
}