aruna_file/transformers/
zstd_comp.rs1use crate::notifications::FooterData;
2use crate::notifications::Message;
3use crate::notifications::MessageData;
4use crate::transformer::Transformer;
5use crate::transformer::TransformerType;
6use anyhow::anyhow;
7use anyhow::Result;
8use async_channel::Sender;
9use async_compression::tokio::write::ZstdEncoder;
10use byteorder::LittleEndian;
11use byteorder::WriteBytesExt;
12use bytes::BufMut;
13use bytes::{Bytes, BytesMut};
14use tokio::io::AsyncWriteExt;
15use tracing::debug;
16use tracing::error;
17
/// Number of raw (uncompressed) input bytes that go into each independent zstd frame (5 MiB).
const RAW_FRAME_SIZE: usize = 5 * 1024 * 1024;
/// Block size (64 KiB) to whose multiples each padded compressed frame is aligned.
const CHUNK: usize = 64 * 1024;
20
21pub struct ZstdEnc {
22 internal_buf: ZstdEncoder<Vec<u8>>,
23 prev_buf: BytesMut,
24 size_counter: usize,
25 chunks: Vec<u8>,
26 is_last: bool,
27 finished: bool,
28 sender: Option<Sender<Message>>,
29}
30
31impl ZstdEnc {
32 #[tracing::instrument(level = "trace", skip(last))]
33 #[allow(dead_code)]
34 pub fn new(last: bool) -> Self {
35 ZstdEnc {
36 internal_buf: ZstdEncoder::new(Vec::with_capacity(RAW_FRAME_SIZE + CHUNK)),
37 prev_buf: BytesMut::with_capacity(RAW_FRAME_SIZE + CHUNK),
38 size_counter: 0,
39 chunks: Vec::new(),
40 is_last: last,
41 finished: false,
42 sender: None,
43 }
44 }
45}
46
47#[async_trait::async_trait]
48impl Transformer for ZstdEnc {
49 #[tracing::instrument(level = "trace", skip(self, buf, finished, should_flush))]
50 async fn process_bytes(
51 &mut self,
52 buf: &mut bytes::BytesMut,
53 finished: bool,
54 should_flush: bool,
55 ) -> Result<bool> {
56 if should_flush {
57 debug!("flushed zstd encoder");
58 self.internal_buf.write_all_buf(buf).await?;
59 self.internal_buf.shutdown().await?;
60 self.prev_buf.extend_from_slice(self.internal_buf.get_ref());
61 self.internal_buf = ZstdEncoder::new(Vec::with_capacity(RAW_FRAME_SIZE + CHUNK));
62 buf.put(self.prev_buf.split().freeze());
63 return Ok(finished);
64 }
65
66 if self.size_counter + buf.len() > RAW_FRAME_SIZE {
68 let mut all_data = buf.split().freeze();
69
70 while self.size_counter + all_data.len() >= RAW_FRAME_SIZE {
71 let dif = RAW_FRAME_SIZE - self.size_counter;
73 assert!(dif <= RAW_FRAME_SIZE);
75 self.internal_buf
76 .write_all_buf(&mut all_data.split_to(dif))
77 .await?;
78 self.internal_buf.shutdown().await?;
80 self.prev_buf.extend_from_slice(self.internal_buf.get_ref());
82 self.internal_buf = ZstdEncoder::new(Vec::with_capacity(RAW_FRAME_SIZE + CHUNK));
84 self.add_skippable().await;
86 self.size_counter = 0;
88 self.chunks.push(u8::try_from(self.prev_buf.len() / CHUNK)?);
90 buf.put(self.prev_buf.split().freeze());
91 }
92 if !all_data.is_empty() {
93 assert!(all_data.len() <= RAW_FRAME_SIZE);
94 self.size_counter = all_data.len();
95 self.internal_buf.write_all_buf(&mut all_data).await?;
96 }
97
98 return Ok(self.finished && self.prev_buf.is_empty());
99 }
100
101 if !buf.is_empty() && !self.finished {
103 self.size_counter += buf.len();
104 assert!(self.size_counter <= RAW_FRAME_SIZE);
105 self.internal_buf.write_buf(buf).await?;
106 }
107
108 if !self.finished && finished && buf.is_empty() {
110 self.internal_buf.shutdown().await?;
111 self.prev_buf.extend_from_slice(self.internal_buf.get_ref());
112 if !self.is_last {
113 self.add_skippable().await;
114 };
115 self.chunks.push(u8::try_from(self.prev_buf.len() / CHUNK)?);
116 buf.put(self.prev_buf.split().freeze());
117 if let Some(s) = &self.sender {
118 debug!(chunks = ?self.chunks, "sending footer");
119 s.send(Message {
120 target: TransformerType::FooterGenerator,
121 data: MessageData::Footer(FooterData {
122 chunks: self.chunks.clone(),
123 }),
124 })
125 .await?;
126 };
127 self.finished = true;
128 return Ok(self.finished && self.prev_buf.is_empty());
129 }
130 buf.put(self.prev_buf.split().freeze());
131 Ok(self.finished && self.prev_buf.is_empty())
132 }
133
134 #[tracing::instrument(level = "trace", skip(self, s))]
135 fn add_sender(&mut self, s: Sender<Message>) {
136 self.sender = Some(s);
137 }
138
139 #[tracing::instrument(level = "trace", skip(self))]
140 fn get_type(&self) -> TransformerType {
141 TransformerType::ZstdCompressor
142 }
143}
144
145impl ZstdEnc {
146 #[tracing::instrument(level = "trace", skip(self))]
147 async fn add_skippable(&mut self) {
148 if self.prev_buf.is_empty() {
150 return;
151 }
152 if CHUNK - (self.prev_buf.len() % CHUNK) > 8 {
153 self.prev_buf.extend(create_skippable_padding_frame(
154 CHUNK - (self.prev_buf.len() % CHUNK),
155 ));
156 } else {
157 self.prev_buf.extend(create_skippable_padding_frame(
158 (CHUNK - (self.prev_buf.len() % CHUNK)) + CHUNK,
159 ));
160 }
161 }
162}
163
164#[tracing::instrument(level = "trace", skip(size))]
165#[inline]
166fn create_skippable_padding_frame(size: usize) -> Result<Bytes> {
167 if size < 8 {
168 error!(size = size, "Size too small");
169 return Err(anyhow!("{size} is too small, minimum is 8 bytes"));
170 }
171 let mut frame = hex::decode("502A4D18")?;
173 WriteBytesExt::write_u32::<LittleEndian>(&mut frame, size as u32 - 8)?;
175 frame.extend(vec![0; size - 8]);
176 Ok(Bytes::from(frame))
177}
178
#[cfg(test)]
mod tests {

    use super::*;

    /// Feeds `b"12345"` into `encoder` as a single, final input.
    ///
    /// Returns the output buffer together with the flag reported by `process_bytes`.
    async fn encode_12345(encoder: &mut ZstdEnc) -> (BytesMut, bool) {
        let mut buf = BytesMut::new();
        buf.put(b"12345".as_slice());
        let done = encoder.process_bytes(&mut buf, true, false).await.unwrap();
        (buf, done)
    }

    #[tokio::test]
    async fn test_zstd_encoder_with_skip() {
        let mut encoder = ZstdEnc::new(false);
        let (buf, _) = encode_12345(&mut encoder).await;

        // A regular zstd frame comes first ...
        assert!(buf.starts_with(&hex::decode("28B52FFD").unwrap()));
        // ... followed by skippable padding up to one full chunk.
        assert_eq!(buf.len(), 65536);
        let expected = hex::decode(format!(
            "28b52ffd00582900003132333435502a4d18eaff{}",
            "00".repeat(65516)
        ))
        .unwrap();
        assert_eq!(buf.as_ref(), &expected)
    }

    #[tokio::test]
    async fn test_zstd_encoder_without_skip() {
        let mut encoder = ZstdEnc::new(true);
        let (buf, _) = encode_12345(&mut encoder).await;

        // The last encoder emits the bare frame without padding.
        assert!(buf.starts_with(&hex::decode("28B52FFD").unwrap()));
        assert_eq!(buf.len(), 14);
        let expected = hex::decode("28b52ffd00582900003132333435").unwrap();
        assert_eq!(buf.as_ref(), &expected)
    }

    #[tokio::test]
    async fn test_zstd_encoder_with_notify() {
        let (sx, rx) = async_channel::unbounded::<Message>();
        let mut encoder = ZstdEnc::new(true);
        encoder.add_sender(sx);

        let (mut buf, done) = encode_12345(&mut encoder).await;
        assert!(done);

        let taken = buf.split();
        assert!(taken.starts_with(&hex::decode("28B52FFD").unwrap()));
        assert_eq!(taken.len(), 14);
        let expected = hex::decode("28b52ffd00582900003132333435").unwrap();
        assert_eq!(taken.as_ref(), &expected);

        // The footer message carries one chunk count per emitted frame.
        let received = rx.recv().await.unwrap();
        assert_eq!(
            received,
            Message {
                target: TransformerType::FooterGenerator,
                data: MessageData::Footer(FooterData { chunks: vec![0u8] })
            }
        )
    }
}