1use crate::stream::{BoxStream, Flow};
2use crate::{StreamError, StreamResult};
3use flate2::Compression as FlateCompression;
4use flate2::write::{GzDecoder, GzEncoder, ZlibEncoder};
5use flate2::{Decompress, FlushDecompress, Status};
6use std::collections::VecDeque;
7use std::io::Write;
8
/// Size in bytes (8 KiB) of each scratch buffer that `InflateStream::pump`
/// allocates for a single `Decompress::decompress` call.
const DECOMPRESS_CHUNK_SIZE: usize = 8 * 1024;
10
11#[derive(Clone)]
12enum Terminal {
13 Complete,
14 Error(StreamError),
15}
16
17fn sticky_terminal<T>(terminal: &Terminal) -> Option<StreamResult<T>> {
18 match terminal {
19 Terminal::Complete => None,
20 Terminal::Error(error) => Some(Err(error.clone())),
21 }
22}
23
24fn codec_error<E: std::fmt::Display>(error: E) -> StreamError {
25 StreamError::Failed(error.to_string())
26}
27
28pub struct Compression;
34
35impl Compression {
36 #[must_use]
38 pub fn gzip() -> Flow<Vec<u8>, Vec<u8>> {
39 Flow::from_transform(|input| Box::new(CompressStream::gzip(input)) as BoxStream<Vec<u8>>)
40 }
41
42 #[must_use]
44 pub fn deflate() -> Flow<Vec<u8>, Vec<u8>> {
45 Flow::from_transform(|input| Box::new(CompressStream::deflate(input)) as BoxStream<Vec<u8>>)
46 }
47
48 #[must_use]
50 pub fn gunzip() -> Flow<Vec<u8>, Vec<u8>> {
51 Flow::from_transform(|input| {
52 Box::new(DecompressStream::gunzip(input)) as BoxStream<Vec<u8>>
53 })
54 }
55
56 #[must_use]
58 pub fn inflate() -> Flow<Vec<u8>, Vec<u8>> {
59 Flow::from_transform(|input| Box::new(InflateStream::new(input)) as BoxStream<Vec<u8>>)
60 }
61}
62
63enum EncoderKind {
64 Gzip(GzEncoder<Vec<u8>>),
65 Deflate(ZlibEncoder<Vec<u8>>),
66}
67
68impl EncoderKind {
69 fn write_all(&mut self, chunk: &[u8]) -> std::io::Result<()> {
70 match self {
71 Self::Gzip(codec) => codec.write_all(chunk),
72 Self::Deflate(codec) => codec.write_all(chunk),
73 }
74 }
75
76 fn try_finish(&mut self) -> std::io::Result<()> {
77 match self {
78 Self::Gzip(codec) => codec.try_finish(),
79 Self::Deflate(codec) => codec.try_finish(),
80 }
81 }
82
83 fn take_output(&mut self) -> Vec<u8> {
84 match self {
85 Self::Gzip(codec) => std::mem::take(codec.get_mut()),
86 Self::Deflate(codec) => std::mem::take(codec.get_mut()),
87 }
88 }
89}
90
91struct CompressStream {
92 input: BoxStream<Vec<u8>>,
93 codec: EncoderKind,
94 pending: VecDeque<Vec<u8>>,
95 finished: bool,
96 terminal: Option<Terminal>,
97}
98
99impl CompressStream {
100 fn gzip(input: BoxStream<Vec<u8>>) -> Self {
101 Self {
102 input,
103 codec: EncoderKind::Gzip(GzEncoder::new(Vec::new(), FlateCompression::default())),
104 pending: VecDeque::new(),
105 finished: false,
106 terminal: None,
107 }
108 }
109
110 fn deflate(input: BoxStream<Vec<u8>>) -> Self {
111 Self {
112 input,
113 codec: EncoderKind::Deflate(ZlibEncoder::new(Vec::new(), FlateCompression::default())),
114 pending: VecDeque::new(),
115 finished: false,
116 terminal: None,
117 }
118 }
119
120 fn fail<T>(&mut self, error: StreamError) -> Option<StreamResult<T>> {
121 self.terminal = Some(Terminal::Error(error.clone()));
122 Some(Err(error))
123 }
124
125 fn harvest_output(&mut self) {
126 let output = self.codec.take_output();
127 if !output.is_empty() {
128 self.pending.push_back(output);
129 }
130 }
131}
132
133impl Iterator for CompressStream {
134 type Item = StreamResult<Vec<u8>>;
135
136 fn next(&mut self) -> Option<Self::Item> {
137 if let Some(chunk) = self.pending.pop_front() {
138 return Some(Ok(chunk));
139 }
140 if let Some(terminal) = &self.terminal {
141 return sticky_terminal(terminal);
142 }
143
144 loop {
145 if self.finished {
146 self.terminal = Some(Terminal::Complete);
147 return None;
148 }
149
150 match self.input.next() {
151 Some(Ok(chunk)) => {
152 if let Err(error) = self.codec.write_all(&chunk).map_err(codec_error) {
153 return self.fail(error);
154 }
155 self.harvest_output();
156 if let Some(chunk) = self.pending.pop_front() {
157 return Some(Ok(chunk));
158 }
159 }
160 Some(Err(error)) => {
161 self.terminal = Some(Terminal::Error(error.clone()));
162 return Some(Err(error));
163 }
164 None => {
165 if let Err(error) = self.codec.try_finish().map_err(codec_error) {
166 return self.fail(error);
167 }
168 self.finished = true;
169 self.harvest_output();
170 if let Some(chunk) = self.pending.pop_front() {
171 return Some(Ok(chunk));
172 }
173 }
174 }
175 }
176 }
177}
178
179enum DecoderKind {
180 Gzip(GzDecoder<Vec<u8>>),
181}
182
183impl DecoderKind {
184 fn write_all(&mut self, chunk: &[u8]) -> std::io::Result<()> {
185 match self {
186 Self::Gzip(codec) => codec.write_all(chunk),
187 }
188 }
189
190 fn try_finish(&mut self) -> std::io::Result<()> {
191 match self {
192 Self::Gzip(codec) => codec.try_finish(),
193 }
194 }
195
196 fn take_output(&mut self) -> Vec<u8> {
197 match self {
198 Self::Gzip(codec) => std::mem::take(codec.get_mut()),
199 }
200 }
201}
202
203struct DecompressStream {
204 input: BoxStream<Vec<u8>>,
205 codec: DecoderKind,
206 pending: VecDeque<Vec<u8>>,
207 finished: bool,
208 terminal: Option<Terminal>,
209}
210
211impl DecompressStream {
212 fn gunzip(input: BoxStream<Vec<u8>>) -> Self {
213 Self {
214 input,
215 codec: DecoderKind::Gzip(GzDecoder::new(Vec::new())),
216 pending: VecDeque::new(),
217 finished: false,
218 terminal: None,
219 }
220 }
221
222 fn fail<T>(&mut self, error: StreamError) -> Option<StreamResult<T>> {
223 self.terminal = Some(Terminal::Error(error.clone()));
224 Some(Err(error))
225 }
226
227 fn harvest_output(&mut self) {
228 let output = self.codec.take_output();
229 if !output.is_empty() {
230 self.pending.push_back(output);
231 }
232 }
233}
234
235struct InflateStream {
236 input: BoxStream<Vec<u8>>,
237 codec: Decompress,
238 pending: VecDeque<Vec<u8>>,
239 finished: bool,
240 terminal: Option<Terminal>,
241}
242
243impl InflateStream {
244 fn new(input: BoxStream<Vec<u8>>) -> Self {
245 Self {
246 input,
247 codec: Decompress::new(true),
248 pending: VecDeque::new(),
249 finished: false,
250 terminal: None,
251 }
252 }
253
254 fn fail<T>(&mut self, error: StreamError) -> Option<StreamResult<T>> {
255 self.terminal = Some(Terminal::Error(error.clone()));
256 Some(Err(error))
257 }
258
259 fn pump(&mut self, mut remaining: &[u8], flush: FlushDecompress) -> StreamResult<bool> {
260 loop {
261 let before_in = self.codec.total_in();
262 let before_out = self.codec.total_out();
263 let mut output = vec![0_u8; DECOMPRESS_CHUNK_SIZE];
264 let status = self
265 .codec
266 .decompress(remaining, &mut output, flush)
267 .map_err(codec_error)?;
268 let consumed = (self.codec.total_in() - before_in) as usize;
269 let produced = (self.codec.total_out() - before_out) as usize;
270 output.truncate(produced);
271 if !output.is_empty() {
272 output.shrink_to_fit();
274 self.pending.push_back(output);
275 }
276 remaining = &remaining[consumed..];
277
278 if matches!(status, Status::StreamEnd) {
279 return Ok(true);
280 }
281 if consumed == 0 && produced == 0 {
282 return Ok(false);
283 }
284 if remaining.is_empty() && !matches!(flush, FlushDecompress::Finish) {
285 return Ok(false);
286 }
287 }
288 }
289}
290
291impl Iterator for InflateStream {
292 type Item = StreamResult<Vec<u8>>;
293
294 fn next(&mut self) -> Option<Self::Item> {
295 if let Some(chunk) = self.pending.pop_front() {
296 return Some(Ok(chunk));
297 }
298 if let Some(terminal) = &self.terminal {
299 return sticky_terminal(terminal);
300 }
301 if self.finished {
302 self.terminal = Some(Terminal::Complete);
303 return None;
304 }
305
306 loop {
307 match self.input.next() {
308 Some(Ok(chunk)) => match self.pump(&chunk, FlushDecompress::None) {
309 Ok(done) => {
310 if done {
311 self.finished = true;
312 }
313 if let Some(chunk) = self.pending.pop_front() {
314 return Some(Ok(chunk));
315 }
316 if self.finished {
317 self.terminal = Some(Terminal::Complete);
318 return None;
319 }
320 }
321 Err(error) => return self.fail(error),
322 },
323 Some(Err(error)) => {
324 self.terminal = Some(Terminal::Error(error.clone()));
325 return Some(Err(error));
326 }
327 None => match self.pump(&[], FlushDecompress::Finish) {
328 Ok(true) => {
329 self.finished = true;
330 if let Some(chunk) = self.pending.pop_front() {
331 return Some(Ok(chunk));
332 }
333 self.terminal = Some(Terminal::Complete);
334 return None;
335 }
336 Ok(false) => {
337 return self.fail(StreamError::Failed(
338 "truncated compressed stream".to_owned(),
339 ));
340 }
341 Err(error) => return self.fail(error),
342 },
343 }
344 }
345 }
346}
347
348impl Iterator for DecompressStream {
349 type Item = StreamResult<Vec<u8>>;
350
351 fn next(&mut self) -> Option<Self::Item> {
352 if let Some(chunk) = self.pending.pop_front() {
353 return Some(Ok(chunk));
354 }
355 if let Some(terminal) = &self.terminal {
356 return sticky_terminal(terminal);
357 }
358 if self.finished {
359 self.terminal = Some(Terminal::Complete);
360 return None;
361 }
362
363 loop {
364 match self.input.next() {
365 Some(Ok(chunk)) => {
366 if let Err(error) = self.codec.write_all(&chunk).map_err(codec_error) {
367 return self.fail(error);
368 }
369 self.harvest_output();
370 if let Some(chunk) = self.pending.pop_front() {
371 return Some(Ok(chunk));
372 }
373 }
374 Some(Err(error)) => {
375 self.terminal = Some(Terminal::Error(error.clone()));
376 return Some(Err(error));
377 }
378 None => match self.codec.try_finish().map_err(codec_error) {
379 Ok(()) => {
380 self.finished = true;
381 self.harvest_output();
382 if let Some(chunk) = self.pending.pop_front() {
383 return Some(Ok(chunk));
384 }
385 }
386 Err(error) => return self.fail(error),
387 },
388 }
389 }
390 }
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Sink, Source};

    /// Runs the fixed two-chunk payload `"hello "`, `"world"` through `flow`
    /// and returns every chunk the flow emitted.
    fn collect_chunks(flow: Flow<Vec<u8>, Vec<u8>>) -> Vec<Vec<u8>> {
        let payload = [b"hello ".to_vec(), b"world".to_vec()];
        let outcome = Source::from_iter(payload)
            .via(flow)
            .run_with(Sink::collect())
            .expect("flow materializes")
            .wait();
        outcome.expect("flow completes")
    }

    #[test]
    fn gzip_and_gunzip_round_trip() {
        let decoded = Source::from_iter(collect_chunks(Compression::gzip()))
            .via(Compression::gunzip())
            .run_with(Sink::collect())
            .expect("gunzip materializes")
            .wait()
            .expect("gunzip completes");

        let joined: Vec<u8> = decoded.concat();
        assert_eq!(joined, b"hello world");
    }

    #[test]
    fn deflate_and_inflate_round_trip() {
        let decoded = Source::from_iter(collect_chunks(Compression::deflate()))
            .via(Compression::inflate())
            .run_with(Sink::collect())
            .expect("inflate materializes")
            .wait()
            .expect("inflate completes");

        let joined: Vec<u8> = decoded.concat();
        assert_eq!(joined, b"hello world");
    }

    #[test]
    fn gunzip_fails_on_truncated_input() {
        // Drop the last two bytes of the gzip output.
        let mut truncated = collect_chunks(Compression::gzip()).concat();
        let keep = truncated.len().saturating_sub(2);
        truncated.truncate(keep);

        let outcome = Source::single(truncated)
            .via(Compression::gunzip())
            .run_with(Sink::collect())
            .expect("gunzip materializes")
            .wait();

        assert!(matches!(outcome, Err(StreamError::Failed(_))));
    }

    #[test]
    fn inflate_fails_on_truncated_input() {
        // Keep only the first half of the compressed bytes.
        let mut truncated = collect_chunks(Compression::deflate()).concat();
        let keep = truncated.len() / 2;
        truncated.truncate(keep);

        let outcome = Source::single(truncated)
            .via(Compression::inflate())
            .run_with(Sink::collect())
            .expect("inflate materializes")
            .wait();

        assert!(matches!(outcome, Err(StreamError::Failed(_))));
    }
}