1use crate::stream::{BoxStream, Flow};
2use crate::{StreamError, StreamResult};
3use flate2::Compression as FlateCompression;
4use flate2::write::{GzDecoder, GzEncoder, ZlibEncoder};
5use flate2::{Decompress, FlushDecompress, Status};
6use std::collections::VecDeque;
7use std::io::Write;
8
/// Size in bytes of each scratch buffer handed to the raw decompressor per call.
const DECOMPRESS_CHUNK_SIZE: usize = 8 * 1024;
10
11#[derive(Clone)]
12enum Terminal {
13 Complete,
14 Error(StreamError),
15}
16
17fn sticky_terminal<T>(terminal: &Terminal) -> Option<StreamResult<T>> {
18 match terminal {
19 Terminal::Complete => None,
20 Terminal::Error(error) => Some(Err(error.clone())),
21 }
22}
23
24fn codec_error<E: std::fmt::Display>(error: E) -> StreamError {
25 StreamError::Failed(error.to_string())
26}
27
28pub struct Compression;
29
30impl Compression {
31 #[must_use]
32 pub fn gzip() -> Flow<Vec<u8>, Vec<u8>> {
33 Flow::from_transform(|input| Box::new(CompressStream::gzip(input)) as BoxStream<Vec<u8>>)
34 }
35
36 #[must_use]
37 pub fn deflate() -> Flow<Vec<u8>, Vec<u8>> {
38 Flow::from_transform(|input| Box::new(CompressStream::deflate(input)) as BoxStream<Vec<u8>>)
39 }
40
41 #[must_use]
42 pub fn gunzip() -> Flow<Vec<u8>, Vec<u8>> {
43 Flow::from_transform(|input| {
44 Box::new(DecompressStream::gunzip(input)) as BoxStream<Vec<u8>>
45 })
46 }
47
48 #[must_use]
49 pub fn inflate() -> Flow<Vec<u8>, Vec<u8>> {
50 Flow::from_transform(|input| Box::new(InflateStream::new(input)) as BoxStream<Vec<u8>>)
51 }
52}
53
54enum EncoderKind {
55 Gzip(GzEncoder<Vec<u8>>),
56 Deflate(ZlibEncoder<Vec<u8>>),
57}
58
59impl EncoderKind {
60 fn write_all(&mut self, chunk: &[u8]) -> std::io::Result<()> {
61 match self {
62 Self::Gzip(codec) => codec.write_all(chunk),
63 Self::Deflate(codec) => codec.write_all(chunk),
64 }
65 }
66
67 fn try_finish(&mut self) -> std::io::Result<()> {
68 match self {
69 Self::Gzip(codec) => codec.try_finish(),
70 Self::Deflate(codec) => codec.try_finish(),
71 }
72 }
73
74 fn take_output(&mut self) -> Vec<u8> {
75 match self {
76 Self::Gzip(codec) => std::mem::take(codec.get_mut()),
77 Self::Deflate(codec) => std::mem::take(codec.get_mut()),
78 }
79 }
80}
81
82struct CompressStream {
83 input: BoxStream<Vec<u8>>,
84 codec: EncoderKind,
85 pending: VecDeque<Vec<u8>>,
86 finished: bool,
87 terminal: Option<Terminal>,
88}
89
90impl CompressStream {
91 fn gzip(input: BoxStream<Vec<u8>>) -> Self {
92 Self {
93 input,
94 codec: EncoderKind::Gzip(GzEncoder::new(Vec::new(), FlateCompression::default())),
95 pending: VecDeque::new(),
96 finished: false,
97 terminal: None,
98 }
99 }
100
101 fn deflate(input: BoxStream<Vec<u8>>) -> Self {
102 Self {
103 input,
104 codec: EncoderKind::Deflate(ZlibEncoder::new(Vec::new(), FlateCompression::default())),
105 pending: VecDeque::new(),
106 finished: false,
107 terminal: None,
108 }
109 }
110
111 fn fail<T>(&mut self, error: StreamError) -> Option<StreamResult<T>> {
112 self.terminal = Some(Terminal::Error(error.clone()));
113 Some(Err(error))
114 }
115
116 fn harvest_output(&mut self) {
117 let output = self.codec.take_output();
118 if !output.is_empty() {
119 self.pending.push_back(output);
120 }
121 }
122}
123
124impl Iterator for CompressStream {
125 type Item = StreamResult<Vec<u8>>;
126
127 fn next(&mut self) -> Option<Self::Item> {
128 if let Some(chunk) = self.pending.pop_front() {
129 return Some(Ok(chunk));
130 }
131 if let Some(terminal) = &self.terminal {
132 return sticky_terminal(terminal);
133 }
134
135 loop {
136 if self.finished {
137 self.terminal = Some(Terminal::Complete);
138 return None;
139 }
140
141 match self.input.next() {
142 Some(Ok(chunk)) => {
143 if let Err(error) = self.codec.write_all(&chunk).map_err(codec_error) {
144 return self.fail(error);
145 }
146 self.harvest_output();
147 if let Some(chunk) = self.pending.pop_front() {
148 return Some(Ok(chunk));
149 }
150 }
151 Some(Err(error)) => {
152 self.terminal = Some(Terminal::Error(error.clone()));
153 return Some(Err(error));
154 }
155 None => {
156 if let Err(error) = self.codec.try_finish().map_err(codec_error) {
157 return self.fail(error);
158 }
159 self.finished = true;
160 self.harvest_output();
161 if let Some(chunk) = self.pending.pop_front() {
162 return Some(Ok(chunk));
163 }
164 }
165 }
166 }
167 }
168}
169
170enum DecoderKind {
171 Gzip(GzDecoder<Vec<u8>>),
172}
173
174impl DecoderKind {
175 fn write_all(&mut self, chunk: &[u8]) -> std::io::Result<()> {
176 match self {
177 Self::Gzip(codec) => codec.write_all(chunk),
178 }
179 }
180
181 fn try_finish(&mut self) -> std::io::Result<()> {
182 match self {
183 Self::Gzip(codec) => codec.try_finish(),
184 }
185 }
186
187 fn take_output(&mut self) -> Vec<u8> {
188 match self {
189 Self::Gzip(codec) => std::mem::take(codec.get_mut()),
190 }
191 }
192}
193
194struct DecompressStream {
195 input: BoxStream<Vec<u8>>,
196 codec: DecoderKind,
197 pending: VecDeque<Vec<u8>>,
198 finished: bool,
199 terminal: Option<Terminal>,
200}
201
202impl DecompressStream {
203 fn gunzip(input: BoxStream<Vec<u8>>) -> Self {
204 Self {
205 input,
206 codec: DecoderKind::Gzip(GzDecoder::new(Vec::new())),
207 pending: VecDeque::new(),
208 finished: false,
209 terminal: None,
210 }
211 }
212
213 fn fail<T>(&mut self, error: StreamError) -> Option<StreamResult<T>> {
214 self.terminal = Some(Terminal::Error(error.clone()));
215 Some(Err(error))
216 }
217
218 fn harvest_output(&mut self) {
219 let output = self.codec.take_output();
220 if !output.is_empty() {
221 self.pending.push_back(output);
222 }
223 }
224}
225
226struct InflateStream {
227 input: BoxStream<Vec<u8>>,
228 codec: Decompress,
229 pending: VecDeque<Vec<u8>>,
230 finished: bool,
231 terminal: Option<Terminal>,
232}
233
234impl InflateStream {
235 fn new(input: BoxStream<Vec<u8>>) -> Self {
236 Self {
237 input,
238 codec: Decompress::new(true),
239 pending: VecDeque::new(),
240 finished: false,
241 terminal: None,
242 }
243 }
244
245 fn fail<T>(&mut self, error: StreamError) -> Option<StreamResult<T>> {
246 self.terminal = Some(Terminal::Error(error.clone()));
247 Some(Err(error))
248 }
249
250 fn pump(&mut self, mut remaining: &[u8], flush: FlushDecompress) -> StreamResult<bool> {
251 loop {
252 let before_in = self.codec.total_in();
253 let before_out = self.codec.total_out();
254 let mut output = vec![0_u8; DECOMPRESS_CHUNK_SIZE];
255 let status = self
256 .codec
257 .decompress(remaining, &mut output, flush)
258 .map_err(codec_error)?;
259 let consumed = (self.codec.total_in() - before_in) as usize;
260 let produced = (self.codec.total_out() - before_out) as usize;
261 output.truncate(produced);
262 if !output.is_empty() {
263 output.shrink_to_fit();
265 self.pending.push_back(output);
266 }
267 remaining = &remaining[consumed..];
268
269 if matches!(status, Status::StreamEnd) {
270 return Ok(true);
271 }
272 if consumed == 0 && produced == 0 {
273 return Ok(false);
274 }
275 if remaining.is_empty() && !matches!(flush, FlushDecompress::Finish) {
276 return Ok(false);
277 }
278 }
279 }
280}
281
282impl Iterator for InflateStream {
283 type Item = StreamResult<Vec<u8>>;
284
285 fn next(&mut self) -> Option<Self::Item> {
286 if let Some(chunk) = self.pending.pop_front() {
287 return Some(Ok(chunk));
288 }
289 if let Some(terminal) = &self.terminal {
290 return sticky_terminal(terminal);
291 }
292 if self.finished {
293 self.terminal = Some(Terminal::Complete);
294 return None;
295 }
296
297 loop {
298 match self.input.next() {
299 Some(Ok(chunk)) => match self.pump(&chunk, FlushDecompress::None) {
300 Ok(done) => {
301 if done {
302 self.finished = true;
303 }
304 if let Some(chunk) = self.pending.pop_front() {
305 return Some(Ok(chunk));
306 }
307 if self.finished {
308 self.terminal = Some(Terminal::Complete);
309 return None;
310 }
311 }
312 Err(error) => return self.fail(error),
313 },
314 Some(Err(error)) => {
315 self.terminal = Some(Terminal::Error(error.clone()));
316 return Some(Err(error));
317 }
318 None => match self.pump(&[], FlushDecompress::Finish) {
319 Ok(true) => {
320 self.finished = true;
321 if let Some(chunk) = self.pending.pop_front() {
322 return Some(Ok(chunk));
323 }
324 self.terminal = Some(Terminal::Complete);
325 return None;
326 }
327 Ok(false) => {
328 return self.fail(StreamError::Failed(
329 "truncated compressed stream".to_owned(),
330 ));
331 }
332 Err(error) => return self.fail(error),
333 },
334 }
335 }
336 }
337}
338
339impl Iterator for DecompressStream {
340 type Item = StreamResult<Vec<u8>>;
341
342 fn next(&mut self) -> Option<Self::Item> {
343 if let Some(chunk) = self.pending.pop_front() {
344 return Some(Ok(chunk));
345 }
346 if let Some(terminal) = &self.terminal {
347 return sticky_terminal(terminal);
348 }
349 if self.finished {
350 self.terminal = Some(Terminal::Complete);
351 return None;
352 }
353
354 loop {
355 match self.input.next() {
356 Some(Ok(chunk)) => {
357 if let Err(error) = self.codec.write_all(&chunk).map_err(codec_error) {
358 return self.fail(error);
359 }
360 self.harvest_output();
361 if let Some(chunk) = self.pending.pop_front() {
362 return Some(Ok(chunk));
363 }
364 }
365 Some(Err(error)) => {
366 self.terminal = Some(Terminal::Error(error.clone()));
367 return Some(Err(error));
368 }
369 None => match self.codec.try_finish().map_err(codec_error) {
370 Ok(()) => {
371 self.finished = true;
372 self.harvest_output();
373 if let Some(chunk) = self.pending.pop_front() {
374 return Some(Ok(chunk));
375 }
376 }
377 Err(error) => return self.fail(error),
378 },
379 }
380 }
381 }
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Source;

    /// Runs the fixed two-chunk payload (`"hello "`, `"world"`) through `flow` and collects
    /// every chunk it emits.
    fn collect_chunks(flow: Flow<Vec<u8>, Vec<u8>>) -> Vec<Vec<u8>> {
        let payload = [b"hello ".to_vec(), b"world".to_vec()];
        let running = Source::from_iter(payload)
            .via(flow)
            .run_with(crate::Sink::collect())
            .expect("flow materializes");
        running.wait().expect("flow completes")
    }

    /// gzip followed by gunzip reproduces the original bytes.
    #[test]
    fn gzip_and_gunzip_round_trip() {
        let encoded = collect_chunks(Compression::gzip());
        let running = Source::from_iter(encoded)
            .via(Compression::gunzip())
            .run_with(crate::Sink::collect())
            .expect("gunzip materializes");
        let restored = running.wait().expect("gunzip completes");

        assert_eq!(restored.concat(), b"hello world");
    }

    /// deflate followed by inflate reproduces the original bytes.
    #[test]
    fn deflate_and_inflate_round_trip() {
        let encoded = collect_chunks(Compression::deflate());
        let running = Source::from_iter(encoded)
            .via(Compression::inflate())
            .run_with(crate::Sink::collect())
            .expect("inflate materializes");
        let restored = running.wait().expect("inflate completes");

        assert_eq!(restored.concat(), b"hello world");
    }

    /// Dropping the last two bytes of a gzip stream makes gunzip fail.
    #[test]
    fn gunzip_fails_on_truncated_input() {
        let mut bytes = collect_chunks(Compression::gzip()).concat();
        let keep = bytes.len().saturating_sub(2);
        bytes.truncate(keep);

        let outcome = Source::single(bytes)
            .via(Compression::gunzip())
            .run_with(crate::Sink::collect())
            .expect("gunzip materializes")
            .wait();

        assert!(matches!(outcome, Err(StreamError::Failed(_))));
    }

    /// Keeping only the first half of a deflate stream makes inflate fail.
    #[test]
    fn inflate_fails_on_truncated_input() {
        let mut bytes = collect_chunks(Compression::deflate()).concat();
        let keep = bytes.len() / 2;
        bytes.truncate(keep);

        let outcome = Source::single(bytes)
            .via(Compression::inflate())
            .run_with(crate::Sink::collect())
            .expect("inflate materializes")
            .wait();

        assert!(matches!(outcome, Err(StreamError::Failed(_))));
    }
}