Skip to main content

datum/io/
compression.rs

1use crate::stream::{BoxStream, Flow};
2use crate::{StreamError, StreamResult};
3use flate2::Compression as FlateCompression;
4use flate2::write::{GzDecoder, GzEncoder, ZlibEncoder};
5use flate2::{Decompress, FlushDecompress, Status};
6use std::collections::VecDeque;
7use std::io::Write;
8
/// Size in bytes of each scratch output buffer handed to the raw zlib decompressor.
const DECOMPRESS_CHUNK_SIZE: usize = 8 * 1024;
10
11#[derive(Clone)]
12enum Terminal {
13    Complete,
14    Error(StreamError),
15}
16
17fn sticky_terminal<T>(terminal: &Terminal) -> Option<StreamResult<T>> {
18    match terminal {
19        Terminal::Complete => None,
20        Terminal::Error(error) => Some(Err(error.clone())),
21    }
22}
23
24fn codec_error<E: std::fmt::Display>(error: E) -> StreamError {
25    StreamError::Failed(error.to_string())
26}
27
28/// In-process byte-stream (de)compression flows backed by `flate2`. Mirrors Akka's `Compression`.
29///
30/// Each flow consumes and produces `Vec<u8>` chunks; output chunk boundaries do not line up with
31/// input chunk boundaries. A truncated compressed input fails the decompressing flows with
32/// [`StreamError`].
33pub struct Compression;
34
35impl Compression {
36    /// Compresses the byte stream in the gzip format.
37    #[must_use]
38    pub fn gzip() -> Flow<Vec<u8>, Vec<u8>> {
39        Flow::from_transform(|input| Box::new(CompressStream::gzip(input)) as BoxStream<Vec<u8>>)
40    }
41
42    /// Compresses the byte stream in the zlib/deflate format.
43    #[must_use]
44    pub fn deflate() -> Flow<Vec<u8>, Vec<u8>> {
45        Flow::from_transform(|input| Box::new(CompressStream::deflate(input)) as BoxStream<Vec<u8>>)
46    }
47
48    /// Decompresses a gzip-format byte stream (inverse of [`Compression::gzip`]).
49    #[must_use]
50    pub fn gunzip() -> Flow<Vec<u8>, Vec<u8>> {
51        Flow::from_transform(|input| {
52            Box::new(DecompressStream::gunzip(input)) as BoxStream<Vec<u8>>
53        })
54    }
55
56    /// Decompresses a zlib/deflate-format byte stream (inverse of [`Compression::deflate`]).
57    #[must_use]
58    pub fn inflate() -> Flow<Vec<u8>, Vec<u8>> {
59        Flow::from_transform(|input| Box::new(InflateStream::new(input)) as BoxStream<Vec<u8>>)
60    }
61}
62
63enum EncoderKind {
64    Gzip(GzEncoder<Vec<u8>>),
65    Deflate(ZlibEncoder<Vec<u8>>),
66}
67
68impl EncoderKind {
69    fn write_all(&mut self, chunk: &[u8]) -> std::io::Result<()> {
70        match self {
71            Self::Gzip(codec) => codec.write_all(chunk),
72            Self::Deflate(codec) => codec.write_all(chunk),
73        }
74    }
75
76    fn try_finish(&mut self) -> std::io::Result<()> {
77        match self {
78            Self::Gzip(codec) => codec.try_finish(),
79            Self::Deflate(codec) => codec.try_finish(),
80        }
81    }
82
83    fn take_output(&mut self) -> Vec<u8> {
84        match self {
85            Self::Gzip(codec) => std::mem::take(codec.get_mut()),
86            Self::Deflate(codec) => std::mem::take(codec.get_mut()),
87        }
88    }
89}
90
91struct CompressStream {
92    input: BoxStream<Vec<u8>>,
93    codec: EncoderKind,
94    pending: VecDeque<Vec<u8>>,
95    finished: bool,
96    terminal: Option<Terminal>,
97}
98
99impl CompressStream {
100    fn gzip(input: BoxStream<Vec<u8>>) -> Self {
101        Self {
102            input,
103            codec: EncoderKind::Gzip(GzEncoder::new(Vec::new(), FlateCompression::default())),
104            pending: VecDeque::new(),
105            finished: false,
106            terminal: None,
107        }
108    }
109
110    fn deflate(input: BoxStream<Vec<u8>>) -> Self {
111        Self {
112            input,
113            codec: EncoderKind::Deflate(ZlibEncoder::new(Vec::new(), FlateCompression::default())),
114            pending: VecDeque::new(),
115            finished: false,
116            terminal: None,
117        }
118    }
119
120    fn fail<T>(&mut self, error: StreamError) -> Option<StreamResult<T>> {
121        self.terminal = Some(Terminal::Error(error.clone()));
122        Some(Err(error))
123    }
124
125    fn harvest_output(&mut self) {
126        let output = self.codec.take_output();
127        if !output.is_empty() {
128            self.pending.push_back(output);
129        }
130    }
131}
132
133impl Iterator for CompressStream {
134    type Item = StreamResult<Vec<u8>>;
135
136    fn next(&mut self) -> Option<Self::Item> {
137        if let Some(chunk) = self.pending.pop_front() {
138            return Some(Ok(chunk));
139        }
140        if let Some(terminal) = &self.terminal {
141            return sticky_terminal(terminal);
142        }
143
144        loop {
145            if self.finished {
146                self.terminal = Some(Terminal::Complete);
147                return None;
148            }
149
150            match self.input.next() {
151                Some(Ok(chunk)) => {
152                    if let Err(error) = self.codec.write_all(&chunk).map_err(codec_error) {
153                        return self.fail(error);
154                    }
155                    self.harvest_output();
156                    if let Some(chunk) = self.pending.pop_front() {
157                        return Some(Ok(chunk));
158                    }
159                }
160                Some(Err(error)) => {
161                    self.terminal = Some(Terminal::Error(error.clone()));
162                    return Some(Err(error));
163                }
164                None => {
165                    if let Err(error) = self.codec.try_finish().map_err(codec_error) {
166                        return self.fail(error);
167                    }
168                    self.finished = true;
169                    self.harvest_output();
170                    if let Some(chunk) = self.pending.pop_front() {
171                        return Some(Ok(chunk));
172                    }
173                }
174            }
175        }
176    }
177}
178
179enum DecoderKind {
180    Gzip(GzDecoder<Vec<u8>>),
181}
182
183impl DecoderKind {
184    fn write_all(&mut self, chunk: &[u8]) -> std::io::Result<()> {
185        match self {
186            Self::Gzip(codec) => codec.write_all(chunk),
187        }
188    }
189
190    fn try_finish(&mut self) -> std::io::Result<()> {
191        match self {
192            Self::Gzip(codec) => codec.try_finish(),
193        }
194    }
195
196    fn take_output(&mut self) -> Vec<u8> {
197        match self {
198            Self::Gzip(codec) => std::mem::take(codec.get_mut()),
199        }
200    }
201}
202
203struct DecompressStream {
204    input: BoxStream<Vec<u8>>,
205    codec: DecoderKind,
206    pending: VecDeque<Vec<u8>>,
207    finished: bool,
208    terminal: Option<Terminal>,
209}
210
211impl DecompressStream {
212    fn gunzip(input: BoxStream<Vec<u8>>) -> Self {
213        Self {
214            input,
215            codec: DecoderKind::Gzip(GzDecoder::new(Vec::new())),
216            pending: VecDeque::new(),
217            finished: false,
218            terminal: None,
219        }
220    }
221
222    fn fail<T>(&mut self, error: StreamError) -> Option<StreamResult<T>> {
223        self.terminal = Some(Terminal::Error(error.clone()));
224        Some(Err(error))
225    }
226
227    fn harvest_output(&mut self) {
228        let output = self.codec.take_output();
229        if !output.is_empty() {
230            self.pending.push_back(output);
231        }
232    }
233}
234
235struct InflateStream {
236    input: BoxStream<Vec<u8>>,
237    codec: Decompress,
238    pending: VecDeque<Vec<u8>>,
239    finished: bool,
240    terminal: Option<Terminal>,
241}
242
243impl InflateStream {
244    fn new(input: BoxStream<Vec<u8>>) -> Self {
245        Self {
246            input,
247            codec: Decompress::new(true),
248            pending: VecDeque::new(),
249            finished: false,
250            terminal: None,
251        }
252    }
253
254    fn fail<T>(&mut self, error: StreamError) -> Option<StreamResult<T>> {
255        self.terminal = Some(Terminal::Error(error.clone()));
256        Some(Err(error))
257    }
258
259    fn pump(&mut self, mut remaining: &[u8], flush: FlushDecompress) -> StreamResult<bool> {
260        loop {
261            let before_in = self.codec.total_in();
262            let before_out = self.codec.total_out();
263            let mut output = vec![0_u8; DECOMPRESS_CHUNK_SIZE];
264            let status = self
265                .codec
266                .decompress(remaining, &mut output, flush)
267                .map_err(codec_error)?;
268            let consumed = (self.codec.total_in() - before_in) as usize;
269            let produced = (self.codec.total_out() - before_out) as usize;
270            output.truncate(produced);
271            if !output.is_empty() {
272                // Reusing these buffers could help later, but that needs benchmark coverage.
273                output.shrink_to_fit();
274                self.pending.push_back(output);
275            }
276            remaining = &remaining[consumed..];
277
278            if matches!(status, Status::StreamEnd) {
279                return Ok(true);
280            }
281            if consumed == 0 && produced == 0 {
282                return Ok(false);
283            }
284            if remaining.is_empty() && !matches!(flush, FlushDecompress::Finish) {
285                return Ok(false);
286            }
287        }
288    }
289}
290
291impl Iterator for InflateStream {
292    type Item = StreamResult<Vec<u8>>;
293
294    fn next(&mut self) -> Option<Self::Item> {
295        if let Some(chunk) = self.pending.pop_front() {
296            return Some(Ok(chunk));
297        }
298        if let Some(terminal) = &self.terminal {
299            return sticky_terminal(terminal);
300        }
301        if self.finished {
302            self.terminal = Some(Terminal::Complete);
303            return None;
304        }
305
306        loop {
307            match self.input.next() {
308                Some(Ok(chunk)) => match self.pump(&chunk, FlushDecompress::None) {
309                    Ok(done) => {
310                        if done {
311                            self.finished = true;
312                        }
313                        if let Some(chunk) = self.pending.pop_front() {
314                            return Some(Ok(chunk));
315                        }
316                        if self.finished {
317                            self.terminal = Some(Terminal::Complete);
318                            return None;
319                        }
320                    }
321                    Err(error) => return self.fail(error),
322                },
323                Some(Err(error)) => {
324                    self.terminal = Some(Terminal::Error(error.clone()));
325                    return Some(Err(error));
326                }
327                None => match self.pump(&[], FlushDecompress::Finish) {
328                    Ok(true) => {
329                        self.finished = true;
330                        if let Some(chunk) = self.pending.pop_front() {
331                            return Some(Ok(chunk));
332                        }
333                        self.terminal = Some(Terminal::Complete);
334                        return None;
335                    }
336                    Ok(false) => {
337                        return self.fail(StreamError::Failed(
338                            "truncated compressed stream".to_owned(),
339                        ));
340                    }
341                    Err(error) => return self.fail(error),
342                },
343            }
344        }
345    }
346}
347
348impl Iterator for DecompressStream {
349    type Item = StreamResult<Vec<u8>>;
350
351    fn next(&mut self) -> Option<Self::Item> {
352        if let Some(chunk) = self.pending.pop_front() {
353            return Some(Ok(chunk));
354        }
355        if let Some(terminal) = &self.terminal {
356            return sticky_terminal(terminal);
357        }
358        if self.finished {
359            self.terminal = Some(Terminal::Complete);
360            return None;
361        }
362
363        loop {
364            match self.input.next() {
365                Some(Ok(chunk)) => {
366                    if let Err(error) = self.codec.write_all(&chunk).map_err(codec_error) {
367                        return self.fail(error);
368                    }
369                    self.harvest_output();
370                    if let Some(chunk) = self.pending.pop_front() {
371                        return Some(Ok(chunk));
372                    }
373                }
374                Some(Err(error)) => {
375                    self.terminal = Some(Terminal::Error(error.clone()));
376                    return Some(Err(error));
377                }
378                None => match self.codec.try_finish().map_err(codec_error) {
379                    Ok(()) => {
380                        self.finished = true;
381                        self.harvest_output();
382                        if let Some(chunk) = self.pending.pop_front() {
383                            return Some(Ok(chunk));
384                        }
385                    }
386                    Err(error) => return self.fail(error),
387                },
388            }
389        }
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Source;

    /// Runs the two chunks `"hello "` and `"world"` through `flow` and returns every chunk emitted.
    fn collect_chunks(flow: Flow<Vec<u8>, Vec<u8>>) -> Vec<Vec<u8>> {
        let payload = [b"hello ".to_vec(), b"world".to_vec()];
        Source::from_iter(payload)
            .via(flow)
            .run_with(crate::Sink::collect())
            .expect("flow materializes")
            .wait()
            .expect("flow completes")
    }

    /// gzip followed by gunzip reproduces the original bytes.
    #[test]
    fn gzip_and_gunzip_round_trip() {
        let gzipped = collect_chunks(Compression::gzip());
        let restored = Source::from_iter(gzipped)
            .via(Compression::gunzip())
            .run_with(crate::Sink::collect())
            .expect("gunzip materializes")
            .wait()
            .expect("gunzip completes");

        assert_eq!(restored.concat(), b"hello world");
    }

    /// deflate followed by inflate reproduces the original bytes.
    #[test]
    fn deflate_and_inflate_round_trip() {
        let deflated = collect_chunks(Compression::deflate());
        let restored = Source::from_iter(deflated)
            .via(Compression::inflate())
            .run_with(crate::Sink::collect())
            .expect("inflate materializes")
            .wait()
            .expect("inflate completes");

        assert_eq!(restored.concat(), b"hello world");
    }

    /// Dropping the last two bytes of a gzip stream makes gunzip fail.
    #[test]
    fn gunzip_fails_on_truncated_input() {
        let mut truncated = collect_chunks(Compression::gzip()).concat();
        let kept = truncated.len().saturating_sub(2);
        truncated.truncate(kept);

        let outcome = Source::single(truncated)
            .via(Compression::gunzip())
            .run_with(crate::Sink::collect())
            .expect("gunzip materializes")
            .wait();

        assert!(matches!(outcome, Err(StreamError::Failed(_))));
    }

    /// Keeping only the first half of a zlib stream makes inflate fail.
    #[test]
    fn inflate_fails_on_truncated_input() {
        let mut truncated = collect_chunks(Compression::deflate()).concat();
        let kept = truncated.len() / 2;
        truncated.truncate(kept);

        let outcome = Source::single(truncated)
            .via(Compression::inflate())
            .run_with(crate::Sink::collect())
            .expect("inflate materializes")
            .wait();

        assert!(matches!(outcome, Err(StreamError::Failed(_))));
    }
}