use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use flate2::{Compress, Compression, FlushCompress, Status};
use tokio::io::AsyncWrite;
/// Async writer adapter that deflate-compresses everything written to it and
/// forwards the compressed bytes to `inner`.
///
/// The compressor is created with `Compress::new(level, false)`, i.e. it
/// produces a raw deflate stream without a zlib header. `poll_shutdown`
/// finishes the deflate stream but never shuts down `inner`.
pub(crate) struct DeflateEncoder<W: AsyncWrite + Unpin> {
    /// Destination for the compressed bytes.
    inner: W,
    /// Streaming deflate state.
    compress: Compress,
    /// Staging buffer for compressor output; `out_buf[out_pos..out_len]` holds
    /// compressed bytes that have not been written to `inner` yet.
    out_buf: Vec<u8>,
    /// Index of the first staged byte not yet written to `inner`.
    out_pos: usize,
    /// End of the staged bytes in `out_buf`.
    out_len: usize,
    /// Set once `poll_shutdown` has finished the deflate stream; after that,
    /// non-empty writes fail.
    finished: bool,
}
impl<W: AsyncWrite + Unpin> DeflateEncoder<W> {
    /// Size of the staging buffer between the compressor and `inner`.
    const OUT_BUF_SIZE: usize = 8192;

    /// Creates an encoder that writes a raw deflate stream (no zlib header or
    /// trailer) to `inner`, compressing at `level`.
    pub(crate) fn new(inner: W, level: Compression) -> Self {
        Self {
            inner,
            // `false`: emit bare deflate data without the zlib wrapper.
            compress: Compress::new(level, false),
            out_buf: vec![0u8; Self::OUT_BUF_SIZE],
            out_pos: 0,
            out_len: 0,
            finished: false,
        }
    }

    /// Returns a shared reference to the underlying writer.
    #[allow(dead_code)] // used by the unit tests; silences the lint where nothing else calls it
    pub(crate) fn get_ref(&self) -> &W {
        &self.inner
    }

    /// Consumes the encoder and returns the underlying writer.
    ///
    /// Compressed bytes still held by the encoder (staged in `out_buf` or
    /// buffered inside the compressor) are dropped, so finish the stream with
    /// `shutdown` before calling this.
    pub(crate) fn into_inner(self) -> W {
        self.inner
    }

    /// Writes the staged bytes `out_buf[out_pos..out_len]` to `inner`.
    ///
    /// Resolves to `Ok(())` once everything has been written, leaving the
    /// staging buffer empty. On `Pending` or an error the unwritten bytes stay
    /// staged, so a later call resumes where this one stopped.
    ///
    /// # Errors
    /// Propagates errors from `inner`, and reports `WriteZero` if `inner`
    /// accepts 0 bytes.
    fn poll_drain(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        while self.out_pos < self.out_len {
            let staged = &self.out_buf[self.out_pos..self.out_len];
            match std::task::ready!(Pin::new(&mut self.inner).poll_write(cx, staged)) {
                Ok(0) => {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "underlying writer returned 0 bytes",
                    )));
                }
                Ok(n) => self.out_pos += n,
                Err(e) => return Poll::Ready(Err(e)),
            }
        }
        self.out_pos = 0;
        self.out_len = 0;
        Poll::Ready(Ok(()))
    }

    /// Runs one compressor step over `input` and stages its output in `out_buf`.
    ///
    /// Returns `(consumed, produced, status)`: the input bytes accepted, the
    /// output bytes now staged in `out_buf[..produced]`, and the raw flate2
    /// status. `Status::BufError` is passed through rather than turned into an
    /// error: during a flush it just means there was nothing left to emit, so
    /// each caller decides what a step without progress means.
    ///
    /// The staging buffer must already be drained, because the new output
    /// overwrites it from the start.
    ///
    /// # Errors
    /// Returns an error if flate2 reports a compression failure.
    fn do_compress(
        &mut self,
        input: &[u8],
        flush: FlushCompress,
    ) -> io::Result<(usize, usize, Status)> {
        debug_assert_eq!(
            self.out_pos, self.out_len,
            "staged output must be drained before compressing more"
        );
        let before_in = self.compress.total_in();
        let before_out = self.compress.total_out();
        let status = self
            .compress
            .compress(input, &mut self.out_buf, flush)
            .map_err(io::Error::other)?;
        // Both deltas are bounded by the lengths of the slices passed in.
        let consumed = (self.compress.total_in() - before_in) as usize;
        let produced = (self.compress.total_out() - before_out) as usize;
        self.out_pos = 0;
        self.out_len = produced;
        Ok((consumed, produced, status))
    }
}

impl<W: AsyncWrite + Unpin> AsyncWrite for DeflateEncoder<W> {
    /// Compresses a prefix of `buf` and reports how many bytes were accepted.
    ///
    /// Never returns `Ok(0)` for a non-empty `buf`, because callers such as
    /// `write_all` treat that as a closed writer. Fails once the stream has
    /// been shut down.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        if this.finished {
            return Poll::Ready(if buf.is_empty() {
                Ok(0)
            } else {
                Err(io::Error::other("write after shutdown"))
            });
        }
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        loop {
            // `do_compress` overwrites the staging buffer, so empty it first.
            std::task::ready!(this.poll_drain(cx))?;
            let (consumed, produced, _status) = this.do_compress(buf, FlushCompress::None)?;
            if consumed > 0 {
                if produced > 0 {
                    // Best effort: start pushing the new output now. A `Pending`
                    // or error here is not lost. The bytes stay staged, and the
                    // next write/flush/shutdown drains them (surfacing a
                    // persistent error) before compressing anything else.
                    let _ = this.poll_drain(cx);
                }
                return Poll::Ready(Ok(consumed));
            }
            if produced == 0 {
                return Poll::Ready(Err(io::Error::other(
                    "deflate made no progress on non-empty input",
                )));
            }
            // The compressor only emitted output it had buffered internally
            // (typically the rest of a block too large for `out_buf` on an
            // earlier call). Reporting `Ok(0)` would look like a closed writer,
            // so drain that output and try again.
        }
    }

    /// Sync-flushes the compressor, writes all staged output to `inner`, then
    /// flushes `inner`.
    ///
    /// Afterwards, everything written so far is decodable from the bytes
    /// delivered to `inner`, even though the stream is not finished.
    /// Re-polling after `Pending` may issue another sync flush, which at worst
    /// appends an empty sync block. That is still valid deflate.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        if !this.finished {
            loop {
                std::task::ready!(this.poll_drain(cx))?;
                let (_, produced, _status) = this.do_compress(&[], FlushCompress::Sync)?;
                // A sync flush is complete once a call leaves room in the output
                // buffer. A completely full buffer means more flush output may
                // be pending, so drain and call again. `BufError` with no output
                // (nothing to flush) lands here with `produced == 0`.
                if produced < this.out_buf.len() {
                    break;
                }
            }
        }
        std::task::ready!(this.poll_drain(cx))?;
        Pin::new(&mut this.inner).poll_flush(cx)
    }

    /// Finishes the deflate stream and writes all remaining output to `inner`.
    ///
    /// `inner` itself is neither flushed nor shut down, so it stays usable
    /// (see `into_inner`). Call `flush` afterwards if the bytes must reach
    /// their destination. Once finished, later calls return `Ok(())`
    /// immediately.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        while !this.finished {
            std::task::ready!(this.poll_drain(cx))?;
            let (_, produced, status) = this.do_compress(&[], FlushCompress::Finish)?;
            if matches!(status, Status::StreamEnd) {
                // The final call's output must reach `inner` before reporting done.
                std::task::ready!(this.poll_drain(cx))?;
                this.finished = true;
            } else if produced == 0 {
                // Without this guard, an `Ok`/`BufError` step that emits nothing
                // would make this loop spin forever.
                return Poll::Ready(Err(io::Error::other(
                    "deflate made no progress while finishing the stream",
                )));
            }
        }
        Poll::Ready(Ok(()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncWriteExt;

    /// Compresses `data` with a single `write_all`, finishes the stream with
    /// `shutdown`, and returns the bytes that reached the inner writer.
    async fn compress(data: &[u8], level: Compression) -> Vec<u8> {
        let mut encoder = DeflateEncoder::new(Vec::new(), level);
        encoder.write_all(data).await.unwrap();
        encoder.shutdown().await.unwrap();
        encoder.into_inner()
    }

    /// Inflates a complete raw deflate stream (no zlib header).
    ///
    /// Panics if the data is malformed. Also panics, rather than looping
    /// forever, if the decoder stops making progress before the end-of-stream
    /// marker (a truncated stream).
    fn inflate(compressed: &[u8]) -> Vec<u8> {
        let mut decompressor = flate2::Decompress::new(false);
        let mut out = Vec::new();
        loop {
            // `decompress_vec` only writes into spare capacity, so guarantee some.
            out.reserve(8192);
            let before_in = decompressor.total_in();
            let before_out = decompressor.total_out();
            let status = decompressor
                .decompress_vec(
                    &compressed[before_in as usize..],
                    &mut out,
                    flate2::FlushDecompress::Finish,
                )
                .expect("encoder output should be a valid deflate stream");
            if matches!(status, flate2::Status::StreamEnd) {
                return out;
            }
            assert!(
                decompressor.total_in() != before_in || decompressor.total_out() != before_out,
                "deflate stream is truncated: decoder made no progress"
            );
        }
    }

    #[tokio::test]
    async fn test_encoder_produces_valid_deflate() {
        let data = b"Hello, World! This is a test of the deflate encoder.";
        let compressed = compress(data, Compression::default()).await;
        assert_eq!(
            inflate(&compressed),
            data,
            "round-trip should produce original data"
        );
    }

    #[tokio::test]
    async fn test_encoder_compresses_repeated_data() {
        let data = vec![b'A'; 4096];
        let compressed = compress(&data, Compression::best()).await;
        assert!(
            compressed.len() < data.len(),
            "compressed size {} should be less than uncompressed {}",
            compressed.len(),
            data.len()
        );
        assert_eq!(inflate(&compressed), data);
    }

    #[tokio::test]
    async fn test_encoder_no_compression_level_0() {
        let data = vec![b'A'; 1024];
        let compressed = compress(&data, Compression::none()).await;
        assert!(
            compressed.len() >= data.len(),
            "level 0 should not compress (got {} < {})",
            compressed.len(),
            data.len()
        );
        assert_eq!(inflate(&compressed), data);
    }

    #[tokio::test]
    async fn test_encoder_empty_input() {
        let compressed = compress(b"", Compression::default()).await;
        assert!(
            !compressed.is_empty(),
            "empty input should produce deflate end marker"
        );
        assert!(
            inflate(&compressed).is_empty(),
            "decompressed should be empty"
        );
    }

    #[tokio::test]
    async fn test_encoder_large_data() {
        let data: Vec<u8> = (0..100_000u32).map(|i| (i % 256) as u8).collect();
        let compressed = compress(&data, Compression::default()).await;
        assert!(
            compressed.len() < data.len(),
            "100KB of cyclic data should compress"
        );
        assert_eq!(
            inflate(&compressed),
            data,
            "round-trip should match for large data"
        );
    }

    #[tokio::test]
    async fn test_many_small_writes_round_trip() {
        let data: Vec<u8> = (0..20_000u32).map(|i| (i * 7 % 251) as u8).collect();
        let mut encoder = DeflateEncoder::new(Vec::new(), Compression::fast());
        for chunk in data.chunks(7) {
            encoder.write_all(chunk).await.unwrap();
        }
        encoder.shutdown().await.unwrap();
        assert_eq!(inflate(&encoder.into_inner()), data);
    }

    #[tokio::test]
    async fn test_flush_makes_written_data_decodable() {
        let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default());
        encoder.write_all(b"hello").await.unwrap();
        encoder.flush().await.unwrap();

        // A sync flush must leave a prefix that decodes completely before the
        // stream is finished.
        let mut decompressor = flate2::Decompress::new(false);
        let mut prefix = Vec::with_capacity(64);
        decompressor
            .decompress_vec(
                encoder.get_ref(),
                &mut prefix,
                flate2::FlushDecompress::Sync,
            )
            .unwrap();
        assert_eq!(prefix, b"hello");

        encoder.shutdown().await.unwrap();
        assert_eq!(inflate(&encoder.into_inner()), b"hello");
    }

    #[tokio::test]
    async fn test_shutdown_does_not_propagate_to_inner() {
        struct ShutdownTracker {
            data: Vec<u8>,
            shutdown_called: bool,
        }
        impl AsyncWrite for ShutdownTracker {
            fn poll_write(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<io::Result<usize>> {
                self.data.extend_from_slice(buf);
                Poll::Ready(Ok(buf.len()))
            }
            fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                Poll::Ready(Ok(()))
            }
            fn poll_shutdown(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<io::Result<()>> {
                self.shutdown_called = true;
                Poll::Ready(Ok(()))
            }
        }
        let tracker = ShutdownTracker {
            data: Vec::new(),
            shutdown_called: false,
        };
        let mut encoder = DeflateEncoder::new(tracker, Compression::default());
        encoder.write_all(b"test data").await.unwrap();
        encoder.shutdown().await.unwrap();
        assert!(
            !encoder.get_ref().shutdown_called,
            "encoder should not propagate shutdown to inner"
        );
        assert_eq!(
            inflate(&encoder.get_ref().data),
            b"test data",
            "compressed data should be available and complete"
        );
    }

    #[tokio::test]
    async fn test_into_inner_after_shutdown() {
        let mut buf = Vec::new();
        let mut encoder = DeflateEncoder::new(&mut buf, Compression::default());
        encoder.write_all(b"hello").await.unwrap();
        encoder.shutdown().await.unwrap();
        let inner = encoder.into_inner();
        assert!(
            !inner.is_empty(),
            "compressed data should have been written before shutdown"
        );
        assert_eq!(inflate(inner.as_slice()), b"hello");
    }
}