Skip to main content

datum/io/
compression.rs

1use crate::stream::{BoxStream, Flow};
2use crate::{StreamError, StreamResult};
3use flate2::Compression as FlateCompression;
4use flate2::write::{GzDecoder, GzEncoder, ZlibEncoder};
5use flate2::{Decompress, FlushDecompress, Status};
6use std::collections::VecDeque;
7use std::io::Write;
8
/// Capacity, in bytes, of each scratch output buffer that `InflateStream::pump`
/// hands to the zlib decompressor on every call.
const DECOMPRESS_CHUNK_SIZE: usize = 8 * 1024;
10
11#[derive(Clone)]
12enum Terminal {
13    Complete,
14    Error(StreamError),
15}
16
17fn sticky_terminal<T>(terminal: &Terminal) -> Option<StreamResult<T>> {
18    match terminal {
19        Terminal::Complete => None,
20        Terminal::Error(error) => Some(Err(error.clone())),
21    }
22}
23
24fn codec_error<E: std::fmt::Display>(error: E) -> StreamError {
25    StreamError::Failed(error.to_string())
26}
27
28pub struct Compression;
29
30impl Compression {
31    #[must_use]
32    pub fn gzip() -> Flow<Vec<u8>, Vec<u8>> {
33        Flow::from_transform(|input| Box::new(CompressStream::gzip(input)) as BoxStream<Vec<u8>>)
34    }
35
36    #[must_use]
37    pub fn deflate() -> Flow<Vec<u8>, Vec<u8>> {
38        Flow::from_transform(|input| Box::new(CompressStream::deflate(input)) as BoxStream<Vec<u8>>)
39    }
40
41    #[must_use]
42    pub fn gunzip() -> Flow<Vec<u8>, Vec<u8>> {
43        Flow::from_transform(|input| {
44            Box::new(DecompressStream::gunzip(input)) as BoxStream<Vec<u8>>
45        })
46    }
47
48    #[must_use]
49    pub fn inflate() -> Flow<Vec<u8>, Vec<u8>> {
50        Flow::from_transform(|input| Box::new(InflateStream::new(input)) as BoxStream<Vec<u8>>)
51    }
52}
53
54enum EncoderKind {
55    Gzip(GzEncoder<Vec<u8>>),
56    Deflate(ZlibEncoder<Vec<u8>>),
57}
58
59impl EncoderKind {
60    fn write_all(&mut self, chunk: &[u8]) -> std::io::Result<()> {
61        match self {
62            Self::Gzip(codec) => codec.write_all(chunk),
63            Self::Deflate(codec) => codec.write_all(chunk),
64        }
65    }
66
67    fn try_finish(&mut self) -> std::io::Result<()> {
68        match self {
69            Self::Gzip(codec) => codec.try_finish(),
70            Self::Deflate(codec) => codec.try_finish(),
71        }
72    }
73
74    fn take_output(&mut self) -> Vec<u8> {
75        match self {
76            Self::Gzip(codec) => std::mem::take(codec.get_mut()),
77            Self::Deflate(codec) => std::mem::take(codec.get_mut()),
78        }
79    }
80}
81
82struct CompressStream {
83    input: BoxStream<Vec<u8>>,
84    codec: EncoderKind,
85    pending: VecDeque<Vec<u8>>,
86    finished: bool,
87    terminal: Option<Terminal>,
88}
89
90impl CompressStream {
91    fn gzip(input: BoxStream<Vec<u8>>) -> Self {
92        Self {
93            input,
94            codec: EncoderKind::Gzip(GzEncoder::new(Vec::new(), FlateCompression::default())),
95            pending: VecDeque::new(),
96            finished: false,
97            terminal: None,
98        }
99    }
100
101    fn deflate(input: BoxStream<Vec<u8>>) -> Self {
102        Self {
103            input,
104            codec: EncoderKind::Deflate(ZlibEncoder::new(Vec::new(), FlateCompression::default())),
105            pending: VecDeque::new(),
106            finished: false,
107            terminal: None,
108        }
109    }
110
111    fn fail<T>(&mut self, error: StreamError) -> Option<StreamResult<T>> {
112        self.terminal = Some(Terminal::Error(error.clone()));
113        Some(Err(error))
114    }
115
116    fn harvest_output(&mut self) {
117        let output = self.codec.take_output();
118        if !output.is_empty() {
119            self.pending.push_back(output);
120        }
121    }
122}
123
124impl Iterator for CompressStream {
125    type Item = StreamResult<Vec<u8>>;
126
127    fn next(&mut self) -> Option<Self::Item> {
128        if let Some(chunk) = self.pending.pop_front() {
129            return Some(Ok(chunk));
130        }
131        if let Some(terminal) = &self.terminal {
132            return sticky_terminal(terminal);
133        }
134
135        loop {
136            if self.finished {
137                self.terminal = Some(Terminal::Complete);
138                return None;
139            }
140
141            match self.input.next() {
142                Some(Ok(chunk)) => {
143                    if let Err(error) = self.codec.write_all(&chunk).map_err(codec_error) {
144                        return self.fail(error);
145                    }
146                    self.harvest_output();
147                    if let Some(chunk) = self.pending.pop_front() {
148                        return Some(Ok(chunk));
149                    }
150                }
151                Some(Err(error)) => {
152                    self.terminal = Some(Terminal::Error(error.clone()));
153                    return Some(Err(error));
154                }
155                None => {
156                    if let Err(error) = self.codec.try_finish().map_err(codec_error) {
157                        return self.fail(error);
158                    }
159                    self.finished = true;
160                    self.harvest_output();
161                    if let Some(chunk) = self.pending.pop_front() {
162                        return Some(Ok(chunk));
163                    }
164                }
165            }
166        }
167    }
168}
169
170enum DecoderKind {
171    Gzip(GzDecoder<Vec<u8>>),
172}
173
174impl DecoderKind {
175    fn write_all(&mut self, chunk: &[u8]) -> std::io::Result<()> {
176        match self {
177            Self::Gzip(codec) => codec.write_all(chunk),
178        }
179    }
180
181    fn try_finish(&mut self) -> std::io::Result<()> {
182        match self {
183            Self::Gzip(codec) => codec.try_finish(),
184        }
185    }
186
187    fn take_output(&mut self) -> Vec<u8> {
188        match self {
189            Self::Gzip(codec) => std::mem::take(codec.get_mut()),
190        }
191    }
192}
193
194struct DecompressStream {
195    input: BoxStream<Vec<u8>>,
196    codec: DecoderKind,
197    pending: VecDeque<Vec<u8>>,
198    finished: bool,
199    terminal: Option<Terminal>,
200}
201
202impl DecompressStream {
203    fn gunzip(input: BoxStream<Vec<u8>>) -> Self {
204        Self {
205            input,
206            codec: DecoderKind::Gzip(GzDecoder::new(Vec::new())),
207            pending: VecDeque::new(),
208            finished: false,
209            terminal: None,
210        }
211    }
212
213    fn fail<T>(&mut self, error: StreamError) -> Option<StreamResult<T>> {
214        self.terminal = Some(Terminal::Error(error.clone()));
215        Some(Err(error))
216    }
217
218    fn harvest_output(&mut self) {
219        let output = self.codec.take_output();
220        if !output.is_empty() {
221            self.pending.push_back(output);
222        }
223    }
224}
225
226struct InflateStream {
227    input: BoxStream<Vec<u8>>,
228    codec: Decompress,
229    pending: VecDeque<Vec<u8>>,
230    finished: bool,
231    terminal: Option<Terminal>,
232}
233
234impl InflateStream {
235    fn new(input: BoxStream<Vec<u8>>) -> Self {
236        Self {
237            input,
238            codec: Decompress::new(true),
239            pending: VecDeque::new(),
240            finished: false,
241            terminal: None,
242        }
243    }
244
245    fn fail<T>(&mut self, error: StreamError) -> Option<StreamResult<T>> {
246        self.terminal = Some(Terminal::Error(error.clone()));
247        Some(Err(error))
248    }
249
250    fn pump(&mut self, mut remaining: &[u8], flush: FlushDecompress) -> StreamResult<bool> {
251        loop {
252            let before_in = self.codec.total_in();
253            let before_out = self.codec.total_out();
254            let mut output = vec![0_u8; DECOMPRESS_CHUNK_SIZE];
255            let status = self
256                .codec
257                .decompress(remaining, &mut output, flush)
258                .map_err(codec_error)?;
259            let consumed = (self.codec.total_in() - before_in) as usize;
260            let produced = (self.codec.total_out() - before_out) as usize;
261            output.truncate(produced);
262            if !output.is_empty() {
263                // Reusing these buffers could help later, but that needs benchmark coverage.
264                output.shrink_to_fit();
265                self.pending.push_back(output);
266            }
267            remaining = &remaining[consumed..];
268
269            if matches!(status, Status::StreamEnd) {
270                return Ok(true);
271            }
272            if consumed == 0 && produced == 0 {
273                return Ok(false);
274            }
275            if remaining.is_empty() && !matches!(flush, FlushDecompress::Finish) {
276                return Ok(false);
277            }
278        }
279    }
280}
281
282impl Iterator for InflateStream {
283    type Item = StreamResult<Vec<u8>>;
284
285    fn next(&mut self) -> Option<Self::Item> {
286        if let Some(chunk) = self.pending.pop_front() {
287            return Some(Ok(chunk));
288        }
289        if let Some(terminal) = &self.terminal {
290            return sticky_terminal(terminal);
291        }
292        if self.finished {
293            self.terminal = Some(Terminal::Complete);
294            return None;
295        }
296
297        loop {
298            match self.input.next() {
299                Some(Ok(chunk)) => match self.pump(&chunk, FlushDecompress::None) {
300                    Ok(done) => {
301                        if done {
302                            self.finished = true;
303                        }
304                        if let Some(chunk) = self.pending.pop_front() {
305                            return Some(Ok(chunk));
306                        }
307                        if self.finished {
308                            self.terminal = Some(Terminal::Complete);
309                            return None;
310                        }
311                    }
312                    Err(error) => return self.fail(error),
313                },
314                Some(Err(error)) => {
315                    self.terminal = Some(Terminal::Error(error.clone()));
316                    return Some(Err(error));
317                }
318                None => match self.pump(&[], FlushDecompress::Finish) {
319                    Ok(true) => {
320                        self.finished = true;
321                        if let Some(chunk) = self.pending.pop_front() {
322                            return Some(Ok(chunk));
323                        }
324                        self.terminal = Some(Terminal::Complete);
325                        return None;
326                    }
327                    Ok(false) => {
328                        return self.fail(StreamError::Failed(
329                            "truncated compressed stream".to_owned(),
330                        ));
331                    }
332                    Err(error) => return self.fail(error),
333                },
334            }
335        }
336    }
337}
338
339impl Iterator for DecompressStream {
340    type Item = StreamResult<Vec<u8>>;
341
342    fn next(&mut self) -> Option<Self::Item> {
343        if let Some(chunk) = self.pending.pop_front() {
344            return Some(Ok(chunk));
345        }
346        if let Some(terminal) = &self.terminal {
347            return sticky_terminal(terminal);
348        }
349        if self.finished {
350            self.terminal = Some(Terminal::Complete);
351            return None;
352        }
353
354        loop {
355            match self.input.next() {
356                Some(Ok(chunk)) => {
357                    if let Err(error) = self.codec.write_all(&chunk).map_err(codec_error) {
358                        return self.fail(error);
359                    }
360                    self.harvest_output();
361                    if let Some(chunk) = self.pending.pop_front() {
362                        return Some(Ok(chunk));
363                    }
364                }
365                Some(Err(error)) => {
366                    self.terminal = Some(Terminal::Error(error.clone()));
367                    return Some(Err(error));
368                }
369                None => match self.codec.try_finish().map_err(codec_error) {
370                    Ok(()) => {
371                        self.finished = true;
372                        self.harvest_output();
373                        if let Some(chunk) = self.pending.pop_front() {
374                            return Some(Ok(chunk));
375                        }
376                    }
377                    Err(error) => return self.fail(error),
378                },
379            }
380        }
381    }
382}
383
384#[cfg(test)]
385mod tests {
386    use super::*;
387    use crate::Source;
388
389    fn collect_chunks(flow: Flow<Vec<u8>, Vec<u8>>) -> Vec<Vec<u8>> {
390        Source::from_iter([b"hello ".to_vec(), b"world".to_vec()])
391            .via(flow)
392            .run_with(crate::Sink::collect())
393            .expect("flow materializes")
394            .wait()
395            .expect("flow completes")
396    }
397
398    #[test]
399    fn gzip_and_gunzip_round_trip() {
400        let compressed = collect_chunks(Compression::gzip());
401        let decoded = Source::from_iter(compressed)
402            .via(Compression::gunzip())
403            .run_with(crate::Sink::collect())
404            .expect("gunzip materializes")
405            .wait()
406            .expect("gunzip completes");
407
408        assert_eq!(decoded.concat(), b"hello world");
409    }
410
411    #[test]
412    fn deflate_and_inflate_round_trip() {
413        let compressed = collect_chunks(Compression::deflate());
414        let decoded = Source::from_iter(compressed)
415            .via(Compression::inflate())
416            .run_with(crate::Sink::collect())
417            .expect("inflate materializes")
418            .wait()
419            .expect("inflate completes");
420
421        assert_eq!(decoded.concat(), b"hello world");
422    }
423
424    #[test]
425    fn gunzip_fails_on_truncated_input() {
426        let compressed = collect_chunks(Compression::gzip());
427        let mut truncated = compressed.concat();
428        truncated.truncate(truncated.len().saturating_sub(2));
429
430        let result = Source::single(truncated)
431            .via(Compression::gunzip())
432            .run_with(crate::Sink::collect())
433            .expect("gunzip materializes")
434            .wait();
435
436        assert!(matches!(result, Err(StreamError::Failed(_))));
437    }
438
439    #[test]
440    fn inflate_fails_on_truncated_input() {
441        let compressed = collect_chunks(Compression::deflate());
442        let mut truncated = compressed.concat();
443        truncated.truncate(truncated.len() / 2);
444
445        let result = Source::single(truncated)
446            .via(Compression::inflate())
447            .run_with(crate::Sink::collect())
448            .expect("inflate materializes")
449            .wait();
450
451        assert!(matches!(result, Err(StreamError::Failed(_))));
452    }
453}