use crate::encode::ZstdEncoder;
use crate::frame::ZstdDecoder;
use std::io::{self, Read, Write};
/// Buffering zstd encoder that implements [`Write`].
///
/// Bytes written through `Write` are only collected in memory; they are
/// compressed in one pass and written to the wrapped writer by `finish`.
pub struct ZstdStreamEncoder<W: Write> {
    /// Compression level handed to `ZstdEncoder::set_level`.
    level: i32,
    /// Optional dictionary handed to `ZstdEncoder::set_dictionary`.
    dict: Option<Vec<u8>>,
    /// Uncompressed bytes accumulated so far.
    buffer: Vec<u8>,
    /// Destination writer; moved out (leaving `None`) by `finish`.
    inner: Option<W>,
    /// Set once the compressed output has been written.
    finished: bool,
}
impl<W: Write> ZstdStreamEncoder<W> {
pub fn new(writer: W, level: i32) -> Self {
Self {
inner: Some(writer),
buffer: Vec::new(),
level,
dict: None,
finished: false,
}
}
pub fn with_dictionary(writer: W, level: i32, dict: Vec<u8>) -> Self {
Self {
inner: Some(writer),
buffer: Vec::new(),
level,
dict: Some(dict),
finished: false,
}
}
pub fn finish(mut self) -> io::Result<W> {
if !self.finished {
self.flush_buffer()?;
self.finished = true;
}
Ok(self
.inner
.take()
.expect("inner writer should still be present"))
}
fn flush_buffer(&mut self) -> io::Result<()> {
let data = if self.buffer.is_empty() {
vec![]
} else {
std::mem::take(&mut self.buffer)
};
let mut encoder = ZstdEncoder::new();
encoder.set_level(self.level);
if let Some(ref dict) = self.dict {
encoder.set_dictionary(dict);
}
let compressed = encoder
.compress(&data)
.map_err(|e| io::Error::other(e.to_string()))?;
if let Some(ref mut w) = self.inner {
w.write_all(&compressed)?;
}
Ok(())
}
pub fn buffered_bytes(&self) -> usize {
self.buffer.len()
}
pub fn is_finished(&self) -> bool {
self.finished
}
}
impl<W: Write> Write for ZstdStreamEncoder<W> {
    /// Appends `buf` to the in-memory buffer; compression is deferred to
    /// `finish`. Always accepts the whole slice.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        if self.finished {
            return Err(io::Error::other("encoder already finished"));
        }
        self.buffer.extend_from_slice(buf);
        Ok(buf.len())
    }

    /// Flushes the wrapped writer.
    ///
    /// Buffered uncompressed input is *not* emitted here: the compressed
    /// output is produced in one piece and can only be written by `finish`.
    fn flush(&mut self) -> io::Result<()> {
        match self.inner.as_mut() {
            Some(w) => w.flush(),
            None => Ok(()),
        }
    }
}
/// Streaming zstd decoder that implements [`Read`].
///
/// On the first `read`, the entire compressed input is pulled from the wrapped
/// reader and decoded with one call to `ZstdDecoder::decode_frame`; the decoded
/// bytes are then served from memory by subsequent reads.
// NOTE(review): whether several concatenated frames are decoded depends on
// `ZstdDecoder::decode_frame` — confirm before relying on multi-frame input.
pub struct ZstdStreamDecoder<R: Read> {
    /// Source of compressed bytes.
    inner: R,
    /// Compressed bytes read from `inner` but not yet successfully decoded.
    ///
    /// Kept across calls so an I/O error part-way through the input does not
    /// discard what was already read, and so a decode error is reported again
    /// on retry instead of being mistaken for end-of-stream.
    compressed: Vec<u8>,
    /// Decoded output.
    output_buffer: Vec<u8>,
    /// Index of the next byte of `output_buffer` to hand out.
    output_pos: usize,
    /// Set once the input has been fully read and decoded (or was empty).
    finished: bool,
    /// Optional dictionary; `None` when absent or empty.
    dict: Option<Vec<u8>>,
}

impl<R: Read> ZstdStreamDecoder<R> {
    /// Creates a decoder without a dictionary.
    pub fn new(reader: R) -> Self {
        Self::from_parts(reader, None)
    }

    /// Creates a decoder that uses `dict`; an empty `dict` means "no dictionary".
    pub fn with_dictionary(reader: R, dict: Vec<u8>) -> Self {
        let dict = if dict.is_empty() { None } else { Some(dict) };
        Self::from_parts(reader, dict)
    }

    /// Shared constructor: nothing read, nothing decoded.
    fn from_parts(reader: R, dict: Option<Vec<u8>>) -> Self {
        Self {
            inner: reader,
            compressed: Vec::new(),
            output_buffer: Vec::new(),
            output_pos: 0,
            finished: false,
            dict,
        }
    }

    /// Reads and decodes the whole input the first time output is needed.
    ///
    /// Does nothing while undelivered decoded bytes remain or once decoding
    /// has finished. Empty input marks the decoder finished with no output.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from the wrapped reader and maps decode failures
    /// to [`io::ErrorKind::InvalidData`]. In both cases the input read so far
    /// is retained: a retry resumes reading after an I/O error, and repeats
    /// the decode error rather than reporting end-of-stream.
    fn fill_buffer(&mut self) -> io::Result<()> {
        if self.finished || self.output_pos < self.output_buffer.len() {
            return Ok(());
        }
        // `read_to_end` appends to the buffer and keeps bytes read before an
        // error, so nothing is lost if this call fails and is retried.
        self.inner.read_to_end(&mut self.compressed)?;
        if self.compressed.is_empty() {
            self.finished = true;
            return Ok(());
        }
        let mut decoder = ZstdDecoder::new();
        if let Some(ref dict) = self.dict {
            decoder.set_dictionary(dict);
        }
        let decoded = decoder
            .decode_frame(&self.compressed)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;
        // The compressed input is no longer needed once decoded; release it.
        self.compressed = Vec::new();
        self.output_buffer = decoded;
        self.output_pos = 0;
        self.finished = true;
        Ok(())
    }

    /// Total number of decoded bytes held in memory.
    ///
    /// This is 0 until the first `read` has triggered decoding.
    pub fn decompressed_size(&self) -> usize {
        self.output_buffer.len()
    }

    /// Whether the input has been decoded and every decoded byte handed out.
    pub fn is_finished(&self) -> bool {
        self.finished && self.output_pos >= self.output_buffer.len()
    }
}
impl<R: Read> Read for ZstdStreamDecoder<R> {
    /// Copies as many decoded bytes as fit into `buf`, decoding the input on
    /// first use. Returns `Ok(0)` once every decoded byte has been handed out.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.fill_buffer()?;
        let pending = &self.output_buffer[self.output_pos..];
        let count = pending.len().min(buf.len());
        if count > 0 {
            buf[..count].copy_from_slice(&pending[..count]);
            self.output_pos += count;
        }
        Ok(count)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compresses `data` at level 1 with a single `write_all` and returns the output.
    fn compress(data: &[u8]) -> Vec<u8> {
        let mut enc = ZstdStreamEncoder::new(Vec::new(), 1);
        enc.write_all(data).unwrap();
        enc.finish().unwrap()
    }

    /// Decodes `frame` with a dictionary-less decoder, draining it via `read_to_end`.
    fn decompress(frame: &[u8]) -> Vec<u8> {
        let mut dec = ZstdStreamDecoder::new(frame);
        let mut out = Vec::new();
        dec.read_to_end(&mut out).unwrap();
        out
    }

    #[test]
    fn test_stream_encoder_basic() {
        let frame = compress(b"Hello, Zstandard!");
        assert!(!frame.is_empty());
    }

    #[test]
    fn test_stream_encoder_empty() {
        // No write at all: finishing must still produce output.
        let frame = ZstdStreamEncoder::new(Vec::new(), 1).finish().unwrap();
        assert!(!frame.is_empty());
    }

    #[test]
    fn test_stream_roundtrip() {
        let original: &[u8] = b"The quick brown fox jumps over the lazy dog.";
        assert_eq!(decompress(&compress(original)), original);
    }

    #[test]
    fn test_stream_roundtrip_multiple_writes() {
        let parts: [&[u8]; 3] = [b"Hello, ", b"streaming ", b"Zstd!"];
        let mut enc = ZstdStreamEncoder::new(Vec::new(), 1);
        parts.iter().for_each(|part| enc.write_all(part).unwrap());
        let frame = enc.finish().unwrap();
        assert_eq!(decompress(&frame), b"Hello, streaming Zstd!");
    }

    #[test]
    fn test_stream_decoder_small_reads() {
        let original = b"ABCDEFGHIJ";
        let frame = compress(original);
        let mut dec = ZstdStreamDecoder::new(frame.as_slice());
        let mut chunk = [0u8; 3];
        let mut output = Vec::new();
        loop {
            match dec.read(&mut chunk).unwrap() {
                0 => break,
                n => output.extend_from_slice(&chunk[..n]),
            }
        }
        assert_eq!(output, original.as_slice());
    }

    #[test]
    fn test_stream_decoder_empty_input() {
        let empty: &[u8] = &[];
        let mut dec = ZstdStreamDecoder::new(empty);
        let mut buf = [0u8; 16];
        assert_eq!(dec.read(&mut buf).unwrap(), 0);
    }

    #[test]
    fn test_stream_encoder_with_dictionary() {
        let dict = b"common pattern data".to_vec();
        let mut enc = ZstdStreamEncoder::with_dictionary(Vec::new(), 1, dict);
        enc.write_all(b"test data").unwrap();
        let frame = enc.finish().unwrap();
        // NOTE(review): decoded *without* the dictionary, so this only checks that
        // dictionary-primed output still decodes without it — confirm intended.
        assert_eq!(decompress(&frame), b"test data");
    }

    #[test]
    fn test_stream_encoder_buffered_bytes() {
        let mut enc = ZstdStreamEncoder::new(Vec::new(), 1);
        assert_eq!(enc.buffered_bytes(), 0);
        for (chunk, expected) in [(b"12345", 5), (b"67890", 10)] {
            enc.write_all(chunk).unwrap();
            assert_eq!(enc.buffered_bytes(), expected);
        }
    }

    #[test]
    fn test_stream_encoder_is_finished() {
        let mut enc = ZstdStreamEncoder::new(Vec::new(), 1);
        assert!(!enc.is_finished());
        enc.write_all(b"data").unwrap();
        assert!(!enc.is_finished());
    }

    #[test]
    fn test_stream_decoder_is_finished() {
        let frame = compress(b"short");
        let mut dec = ZstdStreamDecoder::new(&frame[..]);
        assert!(!dec.is_finished());
        dec.read_to_end(&mut Vec::new()).unwrap();
        assert!(dec.is_finished());
    }

    #[test]
    fn test_stream_roundtrip_large_data() {
        // 10_000 bytes cycling through 0..=255.
        let original: Vec<u8> = (0..=255u8).cycle().take(10_000).collect();
        assert_eq!(decompress(&compress(&original)), original);
    }
}