1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
use super::Status;
use miniz_oxide::{deflate, deflate::core::CompressorOxide, MZError, MZFlush, MZStatus};
use quick_error::quick_error;
use std::io;

quick_error! {
    /// Errors reported by the deflate compressor in this module.
    #[derive(Debug)]
    pub enum Error {
        /// A compression failure that carries no further detail.
        Compression {
            display("The compression failed due to an unknown error")
        }
        /// The underlying compressor returned `MZStatus::NeedDict`.
        ZLibNeedDict {
            display("Need dictionary")
        }
        /// An `MZError` reported by the underlying compressor, other than `MZError::Buf`
        /// which is surfaced as a status instead of an error.
        Error(err: MZError) {
            display("A compression error occurred: {:?}", err)
        }
    }
}

/// A `CompressorOxide` paired with running totals of the bytes it consumed and produced.
pub struct Deflate {
    /// The miniz_oxide compressor state driving the actual deflate work.
    inner: CompressorOxide,
    /// Sum of all input bytes consumed by `compress()` calls so far.
    total_in: u64,
    /// Sum of all output bytes written by `compress()` calls so far.
    total_out: u64,
}

impl Default for Deflate {
    /// Creates a compressor in its default configuration with both byte counters at zero.
    fn default() -> Self {
        let inner = CompressorOxide::default();
        Self {
            inner,
            total_in: Default::default(),
            total_out: Default::default(),
        }
    }
}

impl Deflate {
    fn compress(&mut self, input: &[u8], output: &mut [u8], flush: MZFlush) -> Result<Status, Error> {
        let res = deflate::stream::deflate(&mut self.inner, input, output, flush);
        self.total_in += res.bytes_consumed as u64;
        self.total_out += res.bytes_written as u64;

        match res.status {
            Ok(status) => match status {
                MZStatus::Ok => Ok(Status::Ok),
                MZStatus::StreamEnd => Ok(Status::StreamEnd),
                MZStatus::NeedDict => Err(Error::ZLibNeedDict),
            },
            Err(status) => match status {
                MZError::Buf => Ok(Status::BufError),
                _ => Err(Error::Error(status)),
            },
        }
    }
}

/// Size in bytes (32 KiB) of the scratch buffer that receives compressed output.
const BUF_SIZE: usize = 4096 * 8;
/// A writer that compresses the bytes written to it and forwards the compressed output to `inner`.
pub struct DeflateWriter<W> {
    /// Compressor state together with its consumed/produced byte counters.
    compressor: Deflate,
    /// Destination that receives the compressed bytes.
    inner: W,
    /// Scratch space the compressor writes into before the bytes are passed on to `inner`.
    buf: [u8; BUF_SIZE],
}

impl<W> DeflateWriter<W>
where
    W: io::Write,
{
    /// Creates a writer that sends compressed output to `inner`.
    pub fn new(inner: W) -> DeflateWriter<W> {
        Self {
            inner,
            compressor: Default::default(),
            buf: [0u8; BUF_SIZE],
        }
    }

    /// Resets the underlying compressor state so this writer can be reused.
    ///
    /// The byte counters are left as they are; only differences between them are ever used.
    pub fn reset(&mut self) {
        self.compressor.inner.reset();
    }

    /// Consumes the writer and hands back the wrapped destination.
    pub fn into_inner(self) -> W {
        let DeflateWriter { inner, .. } = self;
        inner
    }

    /// Compresses `buf` with the given `flush` mode, forwarding all produced bytes to the
    /// inner writer, and returns how many input bytes were consumed by this call.
    ///
    /// The loop keeps going for as long as the compressor either emits output or consumes
    /// input. It stops once the stream ends or a step makes no progress at all.
    ///
    /// # Errors
    ///
    /// Compressor errors are wrapped into an `io::Error` of kind `Other`; errors from the
    /// inner writer are passed through unchanged.
    fn write_inner(&mut self, mut buf: &[u8], flush: MZFlush) -> io::Result<usize> {
        let start_in = self.compressor.total_in;

        loop {
            let in_before = self.compressor.total_in;
            let out_before = self.compressor.total_out;

            let status = self
                .compressor
                .compress(buf, &mut self.buf, flush)
                .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;

            // Forward whatever this step produced before deciding how to continue.
            let produced = (self.compressor.total_out - out_before) as usize;
            if produced > 0 {
                self.inner.write_all(&self.buf[..produced])?;
            }

            match status {
                Status::StreamEnd => break,
                Status::Ok | Status::BufError => {
                    let consumed = (self.compressor.total_in - in_before) as usize;
                    // Neither output nor input advanced: more input is needed, so stop here.
                    if produced == 0 && consumed == 0 {
                        break;
                    }
                    buf = &buf[consumed..];
                }
            }
        }

        Ok((self.compressor.total_in - start_in) as usize)
    }
}

impl<W> io::Write for DeflateWriter<W>
where
    W: io::Write,
{
    /// Compresses `buf` without requesting a flush and reports the number of input bytes consumed.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.write_inner(buf, MZFlush::None)
    }

    /// Runs the compressor with `MZFlush::Finish` on empty input, forwarding any remaining
    /// output to the inner writer.
    fn flush(&mut self) -> io::Result<()> {
        let _consumed = self.write_inner(&[], MZFlush::Finish)?;
        Ok(())
    }
}

// The tests live in a separate `tests` module that is only compiled for test builds.
#[cfg(test)]
mod tests;