use flate2::write::GzDecoder;
use std::io::{self, Write};
/// The two-byte magic number that opens every gzip member (RFC 1952, section 2.3.1).
const GZIP_MAGIC: [u8; 2] = [0x1f, 0x8b];

/// Internal state machine for [`AutoDecompressWriter`].
enum WriterState<W: Write> {
    /// Format not decided yet. Carries the sink and, when the first write was
    /// the single byte `GZIP_MAGIC[0]`, that byte held back until the next
    /// byte shows whether it really starts a gzip header.
    Detecting(W, Option<u8>),
    /// Not gzip: bytes are forwarded to the sink unchanged.
    Plain(W),
    /// Gzip magic seen: bytes are decompressed into the sink.
    Gzipped(GzDecoder<W>),
    /// Transient placeholder used while the sink is moved between variants.
    Consumed,
}

/// Builds the error reported when the writer is found in the `Consumed` state.
fn consumed_error() -> io::Error {
    io::Error::new(io::ErrorKind::Other, "Writer already consumed")
}

/// A [`Write`] adapter that gunzips its input when the stream begins with the
/// gzip magic bytes (`1f 8b`) and forwards it unchanged otherwise.
///
/// The format is decided from the first two bytes of the stream, even when
/// they arrive in separate `write` calls.
pub struct AutoDecompressWriter<W: Write> {
    state: WriterState<W>,
}

impl<W: Write> AutoDecompressWriter<W> {
    /// Wraps `inner`. Nothing is written to it until data arrives.
    pub fn new(inner: W) -> Self {
        Self {
            state: WriterState::Detecting(inner, None),
        }
    }

    /// Completes the stream and returns the underlying writer.
    ///
    /// If the stream so far is a single `0x1f` byte that is still held back,
    /// that byte is written to the sink as plain data, since one byte cannot
    /// form a gzip header.
    ///
    /// # Errors
    ///
    /// Returns any error from writing that held-back byte, any error from
    /// [`GzDecoder::finish`] in gzip mode, or an `ErrorKind::Other` error if
    /// the writer is in the consumed state.
    pub fn finish(self) -> io::Result<W> {
        match self.state {
            WriterState::Detecting(mut inner, pending) => {
                if let Some(byte) = pending {
                    inner.write_all(&[byte])?;
                }
                Ok(inner)
            }
            WriterState::Plain(inner) => Ok(inner),
            WriterState::Gzipped(decoder) => decoder.finish(),
            WriterState::Consumed => Err(consumed_error()),
        }
    }

    /// Reports the detected format: `Some(true)` for gzip, `Some(false)` for
    /// plain data, `None` while still undecided (or in the consumed state).
    pub fn is_gzipped(&self) -> Option<bool> {
        match &self.state {
            WriterState::Detecting(..) | WriterState::Consumed => None,
            WriterState::Plain(_) => Some(false),
            WriterState::Gzipped(_) => Some(true),
        }
    }

    /// Borrows the underlying writer, or returns `None` in the consumed state.
    pub fn get_ref(&self) -> Option<&W> {
        match &self.state {
            WriterState::Detecting(inner, _) | WriterState::Plain(inner) => Some(inner),
            WriterState::Gzipped(decoder) => Some(decoder.get_ref()),
            WriterState::Consumed => None,
        }
    }

    /// Leaves the `Detecting` state for `Plain` (or `Gzipped` when `gzipped`),
    /// then replays any held-back first byte into the new state.
    ///
    /// The held-back byte was already reported as written by an earlier
    /// `write`, so it must reach the new sink before any later data. If
    /// replaying it fails, the error is returned and that byte is lost.
    /// Calling this after the format is decided leaves the state untouched.
    fn commit_format(&mut self, gzipped: bool) -> io::Result<()> {
        let (inner, pending) = match std::mem::replace(&mut self.state, WriterState::Consumed) {
            WriterState::Detecting(inner, pending) => (inner, pending),
            decided => {
                // Already decided: restore the state exactly as it was.
                self.state = decided;
                return Ok(());
            }
        };
        self.state = if gzipped {
            WriterState::Gzipped(GzDecoder::new(inner))
        } else {
            WriterState::Plain(inner)
        };
        match pending {
            // Routed through our own `Write` impl so it hits the new variant.
            Some(byte) => self.write_all(&[byte]),
            None => Ok(()),
        }
    }
}

impl<W: Write> Write for AutoDecompressWriter<W> {
    /// Decides the stream format on the first bytes seen, then forwards `buf`
    /// to the plain sink or the gzip decoder.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        if let WriterState::Detecting(_, pending) = &mut self.state {
            if buf.is_empty() {
                return Ok(0);
            }
            let held = *pending;
            // The first two stream bytes, or `None` when the stream cannot be gzip.
            let header = match held {
                Some(first) => Some([first, buf[0]]),
                None if buf.len() >= 2 => Some([buf[0], buf[1]]),
                None if buf[0] == GZIP_MAGIC[0] => {
                    // Could still be gzip: hold the byte until the next write
                    // and report it consumed so callers like `write_all` progress.
                    *pending = Some(buf[0]);
                    return Ok(1);
                }
                // A lone byte that cannot start a gzip header.
                None => None,
            };
            self.commit_format(header == Some(GZIP_MAGIC))?;
        }
        match &mut self.state {
            WriterState::Plain(inner) => inner.write(buf),
            WriterState::Gzipped(decoder) => decoder.write(buf),
            WriterState::Detecting(..) => {
                unreachable!("the format is always committed before forwarding")
            }
            WriterState::Consumed => Err(consumed_error()),
        }
    }

    /// Flushes the underlying sink (through the decoder in gzip mode).
    fn flush(&mut self) -> io::Result<()> {
        match &mut self.state {
            // A held-back first byte stays pending: forcing a decision here
            // could misclassify a gzip stream whose second byte is still to come.
            WriterState::Detecting(inner, _) => inner.flush(),
            WriterState::Plain(inner) => inner.flush(),
            WriterState::Gzipped(decoder) => decoder.flush(),
            WriterState::Consumed => Ok(()),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use flate2::write::GzEncoder;
    use flate2::Compression;

    /// Gzip-compresses `data` in memory.
    fn gzip(data: &[u8]) -> Vec<u8> {
        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
        encoder.write_all(data).unwrap();
        encoder.finish().unwrap()
    }

    /// Streams `input` through a fresh `AutoDecompressWriter` in writes of at
    /// most `chunk_size` bytes, finishes it, and returns what reached the sink.
    fn run_chunked(input: &[u8], chunk_size: usize) -> Vec<u8> {
        let mut output = Vec::new();
        let mut writer = AutoDecompressWriter::new(&mut output);
        for chunk in input.chunks(chunk_size) {
            writer.write_all(chunk).unwrap();
        }
        writer.finish().unwrap();
        output
    }

    #[test]
    fn test_plain_passthrough() {
        assert_eq!(run_chunked(b"Hello, world!", 64), b"Hello, world!");
    }

    #[test]
    fn test_plain_starting_with_first_magic_byte() {
        // 0x1f alone is not enough: the second byte rules out gzip.
        let input = [0x1f, b'x', b'y'];
        assert_eq!(run_chunked(&input, input.len()), input);
    }

    #[test]
    fn test_gzip_decompression() {
        let original = b"Hello, gzip world!";
        let compressed = gzip(original);
        assert_eq!(run_chunked(&compressed, compressed.len()), original);
    }

    #[test]
    fn test_gzip_chunked() {
        let original = b"Hello, this is a longer message for chunked testing!";
        assert_eq!(run_chunked(&gzip(original), 5), original);
    }

    #[test]
    fn test_gzip_magic_in_its_own_write() {
        let original = b"split right after the magic";
        let compressed = gzip(original);
        let mut output = Vec::new();
        let mut writer = AutoDecompressWriter::new(&mut output);
        writer.write_all(&compressed[..2]).unwrap();
        assert_eq!(writer.is_gzipped(), Some(true));
        writer.write_all(&compressed[2..]).unwrap();
        writer.finish().unwrap();
        assert_eq!(output, original);
    }

    #[test]
    fn test_is_gzipped_detection() {
        let mut plain_out = Vec::new();
        let mut writer = AutoDecompressWriter::new(&mut plain_out);
        assert_eq!(writer.is_gzipped(), None);
        writer.write_all(b"plain").unwrap();
        assert_eq!(writer.is_gzipped(), Some(false));
        writer.finish().unwrap();
        assert_eq!(plain_out, b"plain");

        let compressed = gzip(b"data");
        let mut gz_out = Vec::new();
        let mut writer = AutoDecompressWriter::new(&mut gz_out);
        writer.write_all(&compressed).unwrap();
        assert_eq!(writer.is_gzipped(), Some(true));
        writer.finish().unwrap();
        assert_eq!(gz_out, b"data");
    }

    #[test]
    fn test_empty_write() {
        let mut output = Vec::new();
        let mut writer = AutoDecompressWriter::new(&mut output);
        writer.write_all(b"").unwrap();
        // An empty write must not force a format decision.
        assert_eq!(writer.is_gzipped(), None);
        writer.write_all(b"data").unwrap();
        assert_eq!(writer.is_gzipped(), Some(false));
        writer.finish().unwrap();
        assert_eq!(output, b"data");
    }

    #[test]
    fn test_single_byte_non_gzip() {
        assert_eq!(run_chunked(b"X", 1), b"X");
    }

    #[test]
    fn test_get_ref() {
        let mut output = Vec::new();
        let mut writer = AutoDecompressWriter::new(&mut output);
        assert_eq!(writer.get_ref().map(|sink| sink.len()), Some(0));
        writer.write_all(b"abc").unwrap();
        assert_eq!(writer.get_ref().map(|sink| sink.len()), Some(3));
    }
}