use crate::stream::{BoxStream, Flow};
use crate::{StreamError, StreamResult};
use flate2::Compression as FlateCompression;
use flate2::write::{GzDecoder, GzEncoder, ZlibEncoder};
use flate2::{Decompress, FlushDecompress, Status};
use std::collections::VecDeque;
use std::io::Write;
/// Size in bytes of each scratch output buffer that `InflateStream::pump` hands to the decoder.
const DECOMPRESS_CHUNK_SIZE: usize = 8192;
/// Final state of a stream, remembered so that polls made after termination can be answered
/// from the stored state.
#[derive(Clone)]
enum Terminal {
/// The stream ended normally; `sticky_terminal` maps this to `None`.
Complete,
/// The stream failed; `sticky_terminal` yields a clone of this error.
Error(StreamError),
}
fn sticky_terminal<T>(terminal: &Terminal) -> Option<StreamResult<T>> {
match terminal {
Terminal::Complete => None,
Terminal::Error(error) => Some(Err(error.clone())),
}
}
fn codec_error<E: std::fmt::Display>(error: E) -> StreamError {
StreamError::Failed(error.to_string())
}
/// Namespace type whose associated functions build compression and decompression flows.
pub struct Compression;

impl Compression {
    /// Returns a flow that compresses byte chunks with the gzip encoder.
    ///
    /// The flow wraps its upstream in the stream built by `CompressStream::gzip`.
    #[must_use]
    pub fn gzip() -> Flow<Vec<u8>, Vec<u8>> {
        Flow::from_transform(|input| -> BoxStream<Vec<u8>> {
            Box::new(CompressStream::gzip(input))
        })
    }

    /// Returns a flow that compresses byte chunks with the stream built by
    /// `CompressStream::deflate`, which uses flate2's `ZlibEncoder`.
    #[must_use]
    pub fn deflate() -> Flow<Vec<u8>, Vec<u8>> {
        Flow::from_transform(|input| -> BoxStream<Vec<u8>> {
            Box::new(CompressStream::deflate(input))
        })
    }

    /// Returns a flow that decompresses gzip data produced by [`Compression::gzip`].
    ///
    /// The flow wraps its upstream in the stream built by `DecompressStream::gunzip`.
    #[must_use]
    pub fn gunzip() -> Flow<Vec<u8>, Vec<u8>> {
        Flow::from_transform(|input| -> BoxStream<Vec<u8>> {
            Box::new(DecompressStream::gunzip(input))
        })
    }

    /// Returns a flow that decompresses data produced by [`Compression::deflate`].
    ///
    /// The flow wraps its upstream in the stream built by `InflateStream::new`.
    #[must_use]
    pub fn inflate() -> Flow<Vec<u8>, Vec<u8>> {
        Flow::from_transform(|input| -> BoxStream<Vec<u8>> {
            Box::new(InflateStream::new(input))
        })
    }
}
/// The encoder behind a `CompressStream`; both variants write into an in-memory `Vec<u8>`.
enum EncoderKind {
/// gzip encoder, selected by `CompressStream::gzip`.
Gzip(GzEncoder<Vec<u8>>),
/// zlib encoder, selected by `CompressStream::deflate`.
Deflate(ZlibEncoder<Vec<u8>>),
}
impl EncoderKind {
    /// Writes all of `chunk` into the active encoder, or returns the I/O error it reported.
    fn write_all(&mut self, chunk: &[u8]) -> std::io::Result<()> {
        let encoder: &mut dyn Write = match self {
            Self::Gzip(gzip) => gzip,
            Self::Deflate(zlib) => zlib,
        };
        encoder.write_all(chunk)
    }

    /// Asks the active encoder to finish its stream without consuming the encoder.
    fn try_finish(&mut self) -> std::io::Result<()> {
        match self {
            Self::Deflate(zlib) => zlib.try_finish(),
            Self::Gzip(gzip) => gzip.try_finish(),
        }
    }

    /// Moves the bytes accumulated in the encoder's inner buffer out, leaving an empty `Vec`.
    fn take_output(&mut self) -> Vec<u8> {
        let buffer: &mut Vec<u8> = match self {
            Self::Gzip(gzip) => gzip.get_mut(),
            Self::Deflate(zlib) => zlib.get_mut(),
        };
        std::mem::take(buffer)
    }
}
/// Iterator adapter that compresses the chunks yielded by `input` with the chosen encoder.
struct CompressStream {
/// Upstream source of uncompressed chunks.
input: BoxStream<Vec<u8>>,
/// Encoder whose inner `Vec<u8>` collects compressed bytes until they are harvested.
codec: EncoderKind,
/// Compressed chunks waiting to be yielded.
pending: VecDeque<Vec<u8>>,
/// Set once `try_finish` has succeeded after the upstream returned `None`.
finished: bool,
/// Sticky end state; `None` while the stream is still live.
terminal: Option<Terminal>,
}
impl CompressStream {
fn gzip(input: BoxStream<Vec<u8>>) -> Self {
Self {
input,
codec: EncoderKind::Gzip(GzEncoder::new(Vec::new(), FlateCompression::default())),
pending: VecDeque::new(),
finished: false,
terminal: None,
}
}
fn deflate(input: BoxStream<Vec<u8>>) -> Self {
Self {
input,
codec: EncoderKind::Deflate(ZlibEncoder::new(Vec::new(), FlateCompression::default())),
pending: VecDeque::new(),
finished: false,
terminal: None,
}
}
fn fail<T>(&mut self, error: StreamError) -> Option<StreamResult<T>> {
self.terminal = Some(Terminal::Error(error.clone()));
Some(Err(error))
}
fn harvest_output(&mut self) {
let output = self.codec.take_output();
if !output.is_empty() {
self.pending.push_back(output);
}
}
}
impl Iterator for CompressStream {
type Item = StreamResult<Vec<u8>>;
fn next(&mut self) -> Option<Self::Item> {
if let Some(chunk) = self.pending.pop_front() {
return Some(Ok(chunk));
}
if let Some(terminal) = &self.terminal {
return sticky_terminal(terminal);
}
loop {
if self.finished {
self.terminal = Some(Terminal::Complete);
return None;
}
match self.input.next() {
Some(Ok(chunk)) => {
if let Err(error) = self.codec.write_all(&chunk).map_err(codec_error) {
return self.fail(error);
}
self.harvest_output();
if let Some(chunk) = self.pending.pop_front() {
return Some(Ok(chunk));
}
}
Some(Err(error)) => {
self.terminal = Some(Terminal::Error(error.clone()));
return Some(Err(error));
}
None => {
if let Err(error) = self.codec.try_finish().map_err(codec_error) {
return self.fail(error);
}
self.finished = true;
self.harvest_output();
if let Some(chunk) = self.pending.pop_front() {
return Some(Ok(chunk));
}
}
}
}
}
}
/// The decoder behind a `DecompressStream`; currently only gzip is represented.
enum DecoderKind {
/// gzip decoder writing decompressed bytes into an in-memory `Vec<u8>`.
Gzip(GzDecoder<Vec<u8>>),
}
impl DecoderKind {
    /// Writes all of `chunk` into the decoder, or returns the I/O error it reported.
    fn write_all(&mut self, chunk: &[u8]) -> std::io::Result<()> {
        let Self::Gzip(decoder) = self;
        decoder.write_all(chunk)
    }

    /// Asks the decoder to finish its stream without consuming the decoder.
    fn try_finish(&mut self) -> std::io::Result<()> {
        let Self::Gzip(decoder) = self;
        decoder.try_finish()
    }

    /// Moves the bytes accumulated in the decoder's inner buffer out, leaving an empty `Vec`.
    fn take_output(&mut self) -> Vec<u8> {
        let Self::Gzip(decoder) = self;
        std::mem::take(decoder.get_mut())
    }
}
/// Iterator adapter that decompresses the chunks yielded by `input` with a `DecoderKind`.
struct DecompressStream {
/// Upstream source of compressed chunks.
input: BoxStream<Vec<u8>>,
/// Decoder whose inner `Vec<u8>` collects decompressed bytes until they are harvested.
codec: DecoderKind,
/// Decompressed chunks waiting to be yielded.
pending: VecDeque<Vec<u8>>,
/// Set once `try_finish` has succeeded after the upstream returned `None`.
finished: bool,
/// Sticky end state; `None` while the stream is still live.
terminal: Option<Terminal>,
}
impl DecompressStream {
fn gunzip(input: BoxStream<Vec<u8>>) -> Self {
Self {
input,
codec: DecoderKind::Gzip(GzDecoder::new(Vec::new())),
pending: VecDeque::new(),
finished: false,
terminal: None,
}
}
fn fail<T>(&mut self, error: StreamError) -> Option<StreamResult<T>> {
self.terminal = Some(Terminal::Error(error.clone()));
Some(Err(error))
}
fn harvest_output(&mut self) {
let output = self.codec.take_output();
if !output.is_empty() {
self.pending.push_back(output);
}
}
}
/// Iterator adapter that decompresses the chunks yielded by `input` using flate2's low-level
/// `Decompress` state machine.
struct InflateStream {
/// Upstream source of compressed chunks.
input: BoxStream<Vec<u8>>,
/// Decoder state, created with `Decompress::new(true)`.
codec: Decompress,
/// Decompressed chunks waiting to be yielded.
pending: VecDeque<Vec<u8>>,
/// Set once the decoder has reported `Status::StreamEnd`.
finished: bool,
/// Sticky end state; `None` while the stream is still live.
terminal: Option<Terminal>,
}
impl InflateStream {
    /// Creates a stream that decompresses `input` with a decoder built by `Decompress::new(true)`.
    fn new(input: BoxStream<Vec<u8>>) -> Self {
        Self {
            input,
            codec: Decompress::new(true),
            pending: VecDeque::new(),
            finished: false,
            terminal: None,
        }
    }

    /// Stores `error` as the sticky terminal state and returns it as the item to yield now.
    fn fail<T>(&mut self, error: StreamError) -> Option<StreamResult<T>> {
        self.terminal = Some(Terminal::Error(error.clone()));
        Some(Err(error))
    }

    /// Runs the decoder over `remaining`, queueing every non-empty output buffer on `pending`.
    ///
    /// Returns `Ok(true)` once the decoder reports `Status::StreamEnd`. Returns `Ok(false)` when
    /// it stops before that: either a call made no progress at all, or (for flushes other than
    /// `Finish`) the input is exhausted and the last output buffer was not filled completely.
    ///
    /// # Errors
    ///
    /// Any error returned by `Decompress::decompress`, converted with `codec_error`.
    fn pump(&mut self, mut remaining: &[u8], flush: FlushDecompress) -> StreamResult<bool> {
        loop {
            let before_in = self.codec.total_in();
            let before_out = self.codec.total_out();
            let mut output = vec![0_u8; DECOMPRESS_CHUNK_SIZE];
            let status = self
                .codec
                .decompress(remaining, &mut output, flush)
                .map_err(codec_error)?;
            // Both deltas are bounded by the lengths of the slices passed to `decompress`.
            let consumed = (self.codec.total_in() - before_in) as usize;
            let produced = (self.codec.total_out() - before_out) as usize;
            // A completely filled buffer means the decoder may still be holding output (or the
            // end-of-stream marker) that it can deliver without any further input.
            let output_full = produced == output.len();
            output.truncate(produced);
            if !output.is_empty() {
                output.shrink_to_fit();
                self.pending.push_back(output);
            }
            remaining = &remaining[consumed..];
            if matches!(status, Status::StreamEnd) {
                return Ok(true);
            }
            if consumed == 0 && produced == 0 {
                return Ok(false);
            }
            // Only wait for more input once the decoder has been drained; otherwise the tail of
            // this chunk's output would be held back until the next upstream item arrives.
            if remaining.is_empty() && !output_full && !matches!(flush, FlushDecompress::Finish) {
                return Ok(false);
            }
        }
    }
}
impl Iterator for InflateStream {
    type Item = StreamResult<Vec<u8>>;

    /// Yields the next decompressed chunk.
    ///
    /// Queued output is served first, then a stored terminal state. Otherwise upstream chunks
    /// are pumped through the decoder until output appears, the compressed stream ends, or an
    /// error occurs. An upstream that ends before the decoder reports end-of-stream fails with
    /// `"truncated compressed stream"`.
    fn next(&mut self) -> Option<Self::Item> {
        if let Some(ready) = self.pending.pop_front() {
            return Some(Ok(ready));
        }
        if let Some(state) = self.terminal.as_ref() {
            return sticky_terminal(state);
        }
        if self.finished {
            self.terminal = Some(Terminal::Complete);
            return None;
        }

        loop {
            let bytes = match self.input.next() {
                Some(Ok(bytes)) => bytes,
                Some(Err(error)) => return self.fail(error),
                None => {
                    // Upstream is exhausted: give the decoder one last chance to finish.
                    return match self.pump(&[], FlushDecompress::Finish) {
                        Ok(true) => {
                            self.finished = true;
                            let tail = self.pending.pop_front();
                            if tail.is_none() {
                                self.terminal = Some(Terminal::Complete);
                            }
                            tail.map(Ok)
                        }
                        Ok(false) => self.fail(StreamError::Failed(
                            "truncated compressed stream".to_owned(),
                        )),
                        Err(error) => self.fail(error),
                    };
                }
            };

            match self.pump(&bytes, FlushDecompress::None) {
                Ok(done) => {
                    if done {
                        self.finished = true;
                    }
                }
                Err(error) => return self.fail(error),
            }

            if let Some(ready) = self.pending.pop_front() {
                return Some(Ok(ready));
            }
            if self.finished {
                self.terminal = Some(Terminal::Complete);
                return None;
            }
        }
    }
}
impl Iterator for DecompressStream {
    type Item = StreamResult<Vec<u8>>;

    /// Yields the next decompressed chunk.
    ///
    /// Queued output is served first, then a stored terminal state. Otherwise upstream chunks
    /// are written to the decoder until it emits output, the upstream fails, or the upstream
    /// ends and the decoder has been finished. Decoder errors (including those reported by
    /// `try_finish`) are converted with `codec_error` and become the sticky terminal state.
    fn next(&mut self) -> Option<Self::Item> {
        if let Some(chunk) = self.pending.pop_front() {
            return Some(Ok(chunk));
        }
        if let Some(terminal) = &self.terminal {
            return sticky_terminal(terminal);
        }
        loop {
            // Checked on every iteration: when finishing the decoder yields no further output,
            // the stream must complete here instead of polling the exhausted upstream again
            // and re-running `try_finish`.
            if self.finished {
                self.terminal = Some(Terminal::Complete);
                return None;
            }
            match self.input.next() {
                Some(Ok(chunk)) => {
                    if let Err(error) = self.codec.write_all(&chunk).map_err(codec_error) {
                        return self.fail(error);
                    }
                    self.harvest_output();
                    if let Some(chunk) = self.pending.pop_front() {
                        return Some(Ok(chunk));
                    }
                }
                Some(Err(error)) => return self.fail(error),
                None => match self.codec.try_finish().map_err(codec_error) {
                    Ok(()) => {
                        self.finished = true;
                        self.harvest_output();
                        if let Some(chunk) = self.pending.pop_front() {
                            return Some(Ok(chunk));
                        }
                    }
                    Err(error) => return self.fail(error),
                },
            }
        }
    }
}
// Round-trip and truncation tests for the compression flows.
#[cfg(test)]
mod tests {
use super::*;
use crate::Source;
// Runs the two chunks "hello " and "world" through `flow` and collects the chunks it emits.
fn collect_chunks(flow: Flow<Vec<u8>, Vec<u8>>) -> Vec<Vec<u8>> {
Source::from_iter([b"hello ".to_vec(), b"world".to_vec()])
.via(flow)
.run_with(crate::Sink::collect())
.expect("flow materializes")
.wait()
.expect("flow completes")
}
// gzip output fed chunk-by-chunk into gunzip must reproduce the original bytes.
#[test]
fn gzip_and_gunzip_round_trip() {
let compressed = collect_chunks(Compression::gzip());
let decoded = Source::from_iter(compressed)
.via(Compression::gunzip())
.run_with(crate::Sink::collect())
.expect("gunzip materializes")
.wait()
.expect("gunzip completes");
assert_eq!(decoded.concat(), b"hello world");
}
// deflate output fed chunk-by-chunk into inflate must reproduce the original bytes.
#[test]
fn deflate_and_inflate_round_trip() {
let compressed = collect_chunks(Compression::deflate());
let decoded = Source::from_iter(compressed)
.via(Compression::inflate())
.run_with(crate::Sink::collect())
.expect("inflate materializes")
.wait()
.expect("inflate completes");
assert_eq!(decoded.concat(), b"hello world");
}
// Dropping the last two bytes of the gzip output must surface as `StreamError::Failed`.
#[test]
fn gunzip_fails_on_truncated_input() {
let compressed = collect_chunks(Compression::gzip());
let mut truncated = compressed.concat();
truncated.truncate(truncated.len().saturating_sub(2));
let result = Source::single(truncated)
.via(Compression::gunzip())
.run_with(crate::Sink::collect())
.expect("gunzip materializes")
.wait();
assert!(matches!(result, Err(StreamError::Failed(_))));
}
// Keeping only the first half of the deflate output must surface as `StreamError::Failed`.
#[test]
fn inflate_fails_on_truncated_input() {
let compressed = collect_chunks(Compression::deflate());
let mut truncated = compressed.concat();
truncated.truncate(truncated.len() / 2);
let result = Source::single(truncated)
.via(Compression::inflate())
.run_with(crate::Sink::collect())
.expect("inflate materializes")
.wait();
assert!(matches!(result, Err(StreamError::Failed(_))));
}
}