aruna_file/transformers/zstd_comp.rs

1use crate::notifications::FooterData;
2use crate::notifications::Message;
3use crate::notifications::MessageData;
4use crate::transformer::Transformer;
5use crate::transformer::TransformerType;
6use anyhow::anyhow;
7use anyhow::Result;
8use async_channel::Sender;
9use async_compression::tokio::write::ZstdEncoder;
10use byteorder::LittleEndian;
11use byteorder::WriteBytesExt;
12use bytes::BufMut;
13use bytes::{Bytes, BytesMut};
14use tokio::io::AsyncWriteExt;
15use tracing::debug;
16use tracing::error;
17
/// Maximum number of uncompressed bytes fed into a single zstd frame (5 MiB).
const RAW_FRAME_SIZE: usize = 5 * 1024 * 1024;
/// Output alignment unit (64 KiB); `add_skippable` pads emitted frames to a multiple of this.
const CHUNK: usize = 64 * 1024;
20
21pub struct ZstdEnc {
22    internal_buf: ZstdEncoder<Vec<u8>>,
23    prev_buf: BytesMut,
24    size_counter: usize,
25    chunks: Vec<u8>,
26    is_last: bool,
27    finished: bool,
28    sender: Option<Sender<Message>>,
29}
30
31impl ZstdEnc {
32    #[tracing::instrument(level = "trace", skip(last))]
33    #[allow(dead_code)]
34    pub fn new(last: bool) -> Self {
35        ZstdEnc {
36            internal_buf: ZstdEncoder::new(Vec::with_capacity(RAW_FRAME_SIZE + CHUNK)),
37            prev_buf: BytesMut::with_capacity(RAW_FRAME_SIZE + CHUNK),
38            size_counter: 0,
39            chunks: Vec::new(),
40            is_last: last,
41            finished: false,
42            sender: None,
43        }
44    }
45}
46
47#[async_trait::async_trait]
48impl Transformer for ZstdEnc {
49    #[tracing::instrument(level = "trace", skip(self, buf, finished, should_flush))]
50    async fn process_bytes(
51        &mut self,
52        buf: &mut bytes::BytesMut,
53        finished: bool,
54        should_flush: bool,
55    ) -> Result<bool> {
56        if should_flush {
57            debug!("flushed zstd encoder");
58            self.internal_buf.write_all_buf(buf).await?;
59            self.internal_buf.shutdown().await?;
60            self.prev_buf.extend_from_slice(self.internal_buf.get_ref());
61            self.internal_buf = ZstdEncoder::new(Vec::with_capacity(RAW_FRAME_SIZE + CHUNK));
62            buf.put(self.prev_buf.split().freeze());
63            return Ok(finished);
64        }
65
66        // Create a new frame if buf would increase size_counter to more than RAW_FRAME_SIZE
67        if self.size_counter + buf.len() > RAW_FRAME_SIZE {
68            let mut all_data = buf.split().freeze();
69
70            while self.size_counter + all_data.len() >= RAW_FRAME_SIZE {
71                // Check how much bytes are missing
72                let dif = RAW_FRAME_SIZE - self.size_counter;
73                // Make sure that dif is <= RAW_FRAME_SIZE
74                assert!(dif <= RAW_FRAME_SIZE);
75                self.internal_buf
76                    .write_all_buf(&mut all_data.split_to(dif))
77                    .await?;
78                // Shut the writer down -> Calls flush()
79                self.internal_buf.shutdown().await?;
80                // Get data from the vector buffer to the "prev_buf" -> Output buffer
81                self.prev_buf.extend_from_slice(self.internal_buf.get_ref());
82                // Create a new Encoder
83                self.internal_buf = ZstdEncoder::new(Vec::with_capacity(RAW_FRAME_SIZE + CHUNK));
84                // Add a skippable frame to the output buffer
85                self.add_skippable().await;
86                // Reset the size_counter
87                self.size_counter = 0;
88                // Add the number of chunks to the chunksvec (for indexing)
89                self.chunks.push(u8::try_from(self.prev_buf.len() / CHUNK)?);
90                buf.put(self.prev_buf.split().freeze());
91            }
92            if !all_data.is_empty() {
93                assert!(all_data.len() <= RAW_FRAME_SIZE);
94                self.size_counter = all_data.len();
95                self.internal_buf.write_all_buf(&mut all_data).await?;
96            }
97
98            return Ok(self.finished && self.prev_buf.is_empty());
99        }
100
101        // Only write if the buffer contains data and the current process is not finished
102        if !buf.is_empty() && !self.finished {
103            self.size_counter += buf.len();
104            assert!(self.size_counter <= RAW_FRAME_SIZE);
105            self.internal_buf.write_buf(buf).await?;
106        }
107
108        // Add the "last" skippable frame if the previous writer is finished but this one is not!
109        if !self.finished && finished && buf.is_empty() {
110            self.internal_buf.shutdown().await?;
111            self.prev_buf.extend_from_slice(self.internal_buf.get_ref());
112            if !self.is_last {
113                self.add_skippable().await;
114            };
115            self.chunks.push(u8::try_from(self.prev_buf.len() / CHUNK)?);
116            buf.put(self.prev_buf.split().freeze());
117            if let Some(s) = &self.sender {
118                debug!(chunks = ?self.chunks, "sending footer");
119                s.send(Message {
120                    target: TransformerType::FooterGenerator,
121                    data: MessageData::Footer(FooterData {
122                        chunks: self.chunks.clone(),
123                    }),
124                })
125                .await?;
126            };
127            self.finished = true;
128            return Ok(self.finished && self.prev_buf.is_empty());
129        }
130        buf.put(self.prev_buf.split().freeze());
131        Ok(self.finished && self.prev_buf.is_empty())
132    }
133
134    #[tracing::instrument(level = "trace", skip(self, s))]
135    fn add_sender(&mut self, s: Sender<Message>) {
136        self.sender = Some(s);
137    }
138
139    #[tracing::instrument(level = "trace", skip(self))]
140    fn get_type(&self) -> TransformerType {
141        TransformerType::ZstdCompressor
142    }
143}
144
145impl ZstdEnc {
146    #[tracing::instrument(level = "trace", skip(self))]
147    async fn add_skippable(&mut self) {
148        // No skippable frame needed if the buffer is empty
149        if self.prev_buf.is_empty() {
150            return;
151        }
152        if CHUNK - (self.prev_buf.len() % CHUNK) > 8 {
153            self.prev_buf.extend(create_skippable_padding_frame(
154                CHUNK - (self.prev_buf.len() % CHUNK),
155            ));
156        } else {
157            self.prev_buf.extend(create_skippable_padding_frame(
158                (CHUNK - (self.prev_buf.len() % CHUNK)) + CHUNK,
159            ));
160        }
161    }
162}
163
164#[tracing::instrument(level = "trace", skip(size))]
165#[inline]
166fn create_skippable_padding_frame(size: usize) -> Result<Bytes> {
167    if size < 8 {
168        error!(size = size, "Size too small");
169        return Err(anyhow!("{size} is too small, minimum is 8 bytes"));
170    }
171    // Add frame_header
172    let mut frame = hex::decode("502A4D18")?;
173    // 4 Bytes (little-endian) for size
174    WriteBytesExt::write_u32::<LittleEndian>(&mut frame, size as u32 - 8)?;
175    frame.extend(vec![0; size - 8]);
176    Ok(Bytes::from(frame))
177}
178
#[cfg(test)]
mod tests {

    use super::*;

    #[tokio::test]
    async fn test_zstd_encoder_with_skip() {
        let mut encoder = ZstdEnc::new(false);
        let mut buf = BytesMut::new();
        buf.put(b"12345".as_slice());
        encoder.process_bytes(&mut buf, true, false).await.unwrap();
        // Starts with magic zstd header (little-endian)
        assert!(buf.starts_with(&hex::decode("28B52FFD").unwrap()));
        // Frame plus skippable padding fills exactly one 64 KiB chunk
        assert_eq!(buf.len(), 65536);
        let expected = hex::decode(format!(
            "28b52ffd00582900003132333435502a4d18eaff{}",
            "00".repeat(65516)
        ))
        .unwrap();
        assert_eq!(buf.as_ref(), &expected)
    }

    #[tokio::test]
    async fn test_zstd_encoder_without_skip() {
        let mut encoder = ZstdEnc::new(true);
        let mut buf = BytesMut::new();
        buf.put(b"12345".as_slice());
        encoder.process_bytes(&mut buf, true, false).await.unwrap();
        // Starts with magic zstd header (little-endian)
        assert!(buf.starts_with(&hex::decode("28B52FFD").unwrap()));
        // Expect 14b size (no padding for the last transformer)
        assert_eq!(buf.len(), 14);
        let expected = hex::decode("28b52ffd00582900003132333435").unwrap();
        assert_eq!(buf.as_ref(), &expected)
    }

    #[tokio::test]
    async fn test_zstd_encoder_with_notify() {
        let mut encoder = ZstdEnc::new(true);
        let mut buf = BytesMut::new();

        let (sx, rx) = async_channel::unbounded::<Message>();

        encoder.add_sender(sx);

        buf.put(b"12345".as_slice());
        assert!(encoder.process_bytes(&mut buf, true, false).await.unwrap());

        let taken = buf.split();
        // Starts with magic zstd header (little-endian)
        assert!(taken.starts_with(&hex::decode("28B52FFD").unwrap()));
        // Expect 14b size
        assert_eq!(taken.len(), 14);
        let expected = hex::decode("28b52ffd00582900003132333435").unwrap();
        assert_eq!(taken.as_ref(), &expected);
        let received = rx.recv().await.unwrap();
        assert_eq!(
            received,
            Message {
                target: TransformerType::FooterGenerator,
                data: MessageData::Footer(FooterData { chunks: vec![0u8] })
            }
        )
    }

    #[tokio::test]
    async fn test_zstd_encoder_splits_raw_frames() {
        let mut encoder = ZstdEnc::new(true);
        let (sx, rx) = async_channel::unbounded::<Message>();
        encoder.add_sender(sx);

        // Ten bytes more than one raw frame forces a frame boundary.
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&vec![0u8; RAW_FRAME_SIZE + 10]);
        assert!(!encoder.process_bytes(&mut buf, false, false).await.unwrap());
        // The completed first frame is emitted and padded to exactly one chunk.
        assert!(buf.starts_with(&hex::decode("28B52FFD").unwrap()));
        assert_eq!(buf.len(), CHUNK);

        // Finishing emits the remaining 10 bytes as a second (unpadded) frame.
        let mut buf = BytesMut::new();
        assert!(encoder.process_bytes(&mut buf, true, false).await.unwrap());
        assert!(buf.starts_with(&hex::decode("28B52FFD").unwrap()));

        let received = rx.recv().await.unwrap();
        assert_eq!(
            received,
            Message {
                target: TransformerType::FooterGenerator,
                data: MessageData::Footer(FooterData {
                    chunks: vec![1u8, 0u8]
                })
            }
        )
    }
}