use crate::encode::ZstdEncoder;
use crate::frame::{ZstdDecoder, decompress_multi_frame};
use std::io::{self, Read, Write};
/// Default number of buffered input bytes (128 KiB) at which
/// `ZstdStreamEncoder` compresses its buffer and writes it out.
const DEFAULT_BLOCK_SIZE: usize = 128 * 1024;
/// Buffering encoder that accepts bytes through [`Write`] and forwards them,
/// compressed with `ZstdEncoder`, to an inner writer.
///
/// Input is accumulated in memory; once the buffer reaches `block_size` bytes
/// (or on `flush`/`finish`) the whole buffer is compressed with a fresh
/// `ZstdEncoder` and the result is written to the inner writer.
pub struct ZstdStreamEncoder<W: Write> {
// Destination for compressed bytes; `finish` takes it out, leaving `None`.
inner: Option<W>,
// Uncompressed bytes accepted by `write` that have not been compressed yet.
buffer: Vec<u8>,
// Compression level handed to `ZstdEncoder::set_level` for every block.
level: i32,
// Optional dictionary handed to `ZstdEncoder::set_dictionary` for every block.
dict: Option<Vec<u8>>,
// Set by `finish` after the final flush; `write` rejects input once it is set.
finished: bool,
// Buffered-byte threshold at which `write` triggers compression (at least 1).
block_size: usize,
}
impl<W: Write> ZstdStreamEncoder<W> {
pub fn new(writer: W, level: i32) -> Self {
Self {
inner: Some(writer),
buffer: Vec::new(),
level,
dict: None,
finished: false,
block_size: DEFAULT_BLOCK_SIZE,
}
}
pub fn with_dictionary(writer: W, level: i32, dict: Vec<u8>) -> Self {
Self {
inner: Some(writer),
buffer: Vec::new(),
level,
dict: Some(dict),
finished: false,
block_size: DEFAULT_BLOCK_SIZE,
}
}
pub fn with_block_size(mut self, block_size: usize) -> Self {
self.block_size = block_size.max(1);
self
}
pub fn finish(mut self) -> io::Result<W> {
if !self.finished {
self.flush_buffer_unconditional()?;
self.finished = true;
}
self.inner
.take()
.ok_or_else(|| io::Error::other("inner writer already taken"))
}
fn compress_and_write(&mut self, data: &[u8]) -> io::Result<()> {
let mut encoder = ZstdEncoder::new();
encoder.set_level(self.level);
if let Some(ref dict) = self.dict {
encoder.set_dictionary(dict);
}
let compressed = encoder
.compress(data)
.map_err(|e| io::Error::other(e.to_string()))?;
if let Some(ref mut w) = self.inner {
w.write_all(&compressed)?;
}
Ok(())
}
fn maybe_flush_block(&mut self) -> io::Result<()> {
if self.buffer.len() >= self.block_size {
let data = std::mem::take(&mut self.buffer);
self.compress_and_write(&data)?;
}
Ok(())
}
fn flush_buffer_unconditional(&mut self) -> io::Result<()> {
let data = std::mem::take(&mut self.buffer);
self.compress_and_write(&data)
}
pub fn buffered_bytes(&self) -> usize {
self.buffer.len()
}
pub fn is_finished(&self) -> bool {
self.finished
}
}
impl<W: Write> Write for ZstdStreamEncoder<W> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if self.finished {
return Err(io::Error::other("encoder already finished"));
}
self.buffer.extend_from_slice(buf);
self.maybe_flush_block()?;
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
if !self.buffer.is_empty() {
let data = std::mem::take(&mut self.buffer);
self.compress_and_write(&data)?;
}
if let Some(ref mut w) = self.inner {
w.flush()?;
}
Ok(())
}
}
/// Decoder that exposes decompressed data through [`Read`].
///
/// On the first read the entire inner reader is consumed and decompressed
/// into memory; subsequent reads are served from that in-memory buffer.
pub struct ZstdStreamDecoder<R: Read> {
// Source of compressed bytes; read to its end in one go.
inner: R,
// Fully decompressed payload produced by `fill_buffer`.
output_buffer: Vec<u8>,
// Index into `output_buffer` of the next byte to hand to the caller.
output_pos: usize,
// Set once the input was read and decoded (or turned out to be empty).
finished: bool,
// Optional dictionary; `with_dictionary` stores `None` for an empty one.
dict: Option<Vec<u8>>,
}
impl<R: Read> ZstdStreamDecoder<R> {
    /// Creates a decoder over `reader` that decompresses without a
    /// dictionary.
    pub fn new(reader: R) -> Self {
        Self {
            dict: None,
            finished: false,
            output_pos: 0,
            output_buffer: Vec::new(),
            inner: reader,
        }
    }

    /// Creates a decoder over `reader` that uses `dict` when decompressing.
    /// An empty `dict` is treated as "no dictionary".
    pub fn with_dictionary(reader: R, dict: Vec<u8>) -> Self {
        let dict = (!dict.is_empty()).then_some(dict);
        Self {
            dict,
            finished: false,
            output_pos: 0,
            output_buffer: Vec::new(),
            inner: reader,
        }
    }

    /// Ensures decompressed data is available, reading and decoding the
    /// whole input on first use.
    ///
    /// Does nothing when decoding already happened or unread output remains.
    /// Empty input simply marks the decoder as finished. Decode failures are
    /// reported as `io::ErrorKind::InvalidData`.
    fn fill_buffer(&mut self) -> io::Result<()> {
        let has_unread = self.output_pos < self.output_buffer.len();
        if self.finished || has_unread {
            return Ok(());
        }

        let mut compressed = Vec::new();
        self.inner.read_to_end(&mut compressed)?;

        if !compressed.is_empty() {
            let invalid = |msg: String| io::Error::new(io::ErrorKind::InvalidData, msg);
            self.output_buffer = match &self.dict {
                None => decompress_multi_frame(&compressed)
                    .map_err(|e| invalid(e.to_string()))?,
                // NOTE(review): this path calls `decode_frame` rather than
                // `decompress_multi_frame`; if `decode_frame` handles only a
                // single frame, input consisting of several frames would not
                // be fully decoded when a dictionary is set. Confirm.
                Some(dict) => {
                    let mut decoder = ZstdDecoder::new();
                    decoder.set_dictionary(dict);
                    decoder
                        .decode_frame(&compressed)
                        .map_err(|e| invalid(e.to_string()))?
                }
            };
            self.output_pos = 0;
        }

        self.finished = true;
        Ok(())
    }

    /// Total length of the decompressed payload currently held in memory
    /// (zero until the first read has triggered decoding).
    pub fn decompressed_size(&self) -> usize {
        self.output_buffer.len()
    }

    /// Whether the input has been decoded and every decompressed byte has
    /// been handed out.
    pub fn is_finished(&self) -> bool {
        let drained = self.output_pos >= self.output_buffer.len();
        self.finished && drained
    }
}
impl<R: Read> Read for ZstdStreamDecoder<R> {
    /// Copies as many decompressed bytes as fit into `buf`, decoding the
    /// input first if that has not happened yet.
    ///
    /// Returns `Ok(0)` once all decompressed data has been handed out (or
    /// when `buf` is empty).
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.fill_buffer()?;

        // Everything from the cursor onwards is still owed to the caller.
        let unread = &self.output_buffer[self.output_pos..];
        let count = unread.len().min(buf.len());
        buf[..count].copy_from_slice(&unread[..count]);

        self.output_pos += count;
        Ok(count)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds each element of `parts` to a level-1 encoder with one
    /// `write_all` call per element and returns the finished output.
    fn compress_parts(parts: &[&[u8]]) -> Vec<u8> {
        let mut enc = ZstdStreamEncoder::new(Vec::new(), 1);
        for part in parts {
            enc.write_all(part).unwrap();
        }
        enc.finish().unwrap()
    }

    /// Decodes `compressed` with a dictionary-less decoder and returns
    /// everything it yields.
    fn decompress_all(compressed: &[u8]) -> Vec<u8> {
        let mut plain = Vec::new();
        ZstdStreamDecoder::new(compressed)
            .read_to_end(&mut plain)
            .unwrap();
        plain
    }

    // Encoding some data produces output.
    #[test]
    fn test_stream_encoder_basic() {
        let payload: &[u8] = b"Hello, Zstandard!";
        let compressed = compress_parts(&[payload]);
        assert!(!compressed.is_empty());
    }

    // Finishing without any input still produces output.
    #[test]
    fn test_stream_encoder_empty() {
        let compressed = compress_parts(&[]);
        assert!(!compressed.is_empty());
    }

    // A single write survives an encode/decode round trip.
    #[test]
    fn test_stream_roundtrip() {
        let original: &[u8] = b"The quick brown fox jumps over the lazy dog.";
        let compressed = compress_parts(&[original]);
        let output = decompress_all(&compressed);
        assert_eq!(output, original);
    }

    // Several writes decode to their concatenation.
    #[test]
    fn test_stream_roundtrip_multiple_writes() {
        let parts: &[&[u8]] = &[b"Hello, ", b"streaming ", b"Zstd!"];
        let compressed = compress_parts(parts);
        let output = decompress_all(&compressed);
        assert_eq!(output, b"Hello, streaming Zstd!");
    }

    // Reading through a tiny destination buffer yields the full payload.
    #[test]
    fn test_stream_decoder_small_reads() {
        let original: &[u8] = b"ABCDEFGHIJ";
        let compressed = compress_parts(&[original]);
        let mut decoder = ZstdStreamDecoder::new(&compressed[..]);
        let mut output = Vec::new();
        let mut chunk = [0u8; 3];
        loop {
            match decoder.read(&mut chunk).unwrap() {
                0 => break,
                got => output.extend_from_slice(&chunk[..got]),
            }
        }
        assert_eq!(output, original);
    }

    // Empty input reads as immediate end-of-stream.
    #[test]
    fn test_stream_decoder_empty_input() {
        let mut decoder = ZstdStreamDecoder::new(&[][..]);
        let mut buf = [0u8; 16];
        let n = decoder.read(&mut buf).unwrap();
        assert_eq!(n, 0);
    }

    // Output of a dictionary-configured encoder decodes to the input.
    #[test]
    fn test_stream_encoder_with_dictionary() {
        let dict = b"common pattern data".to_vec();
        let mut encoder = ZstdStreamEncoder::with_dictionary(Vec::new(), 1, dict);
        encoder.write_all(b"test data").unwrap();
        let compressed = encoder.finish().unwrap();
        let output = decompress_all(&compressed);
        assert_eq!(output, b"test data");
    }

    // `buffered_bytes` tracks input below the block threshold.
    #[test]
    fn test_stream_encoder_buffered_bytes() {
        let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
        assert_eq!(encoder.buffered_bytes(), 0);
        let writes: [(&[u8], usize); 2] = [(b"12345", 5), (b"67890", 10)];
        for (chunk, expected_total) in writes {
            encoder.write_all(chunk).unwrap();
            assert_eq!(encoder.buffered_bytes(), expected_total);
        }
    }

    // An encoder is not finished before `finish` is called.
    #[test]
    fn test_stream_encoder_is_finished() {
        let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
        assert!(!encoder.is_finished());
        encoder.write_all(b"data").unwrap();
        assert!(!encoder.is_finished());
    }

    // A decoder reports finished only after its output is drained.
    #[test]
    fn test_stream_decoder_is_finished() {
        let original: &[u8] = b"short";
        let compressed = compress_parts(&[original]);
        let mut decoder = ZstdStreamDecoder::new(&compressed[..]);
        assert!(!decoder.is_finished());
        let mut sink = Vec::new();
        decoder.read_to_end(&mut sink).unwrap();
        assert!(decoder.is_finished());
    }

    // A 10 000-byte repeating pattern survives a round trip.
    #[test]
    fn test_stream_roundtrip_large_data() {
        let original: Vec<u8> = (0..10_000).map(|i| (i % 256) as u8).collect();
        let compressed = compress_parts(&[&original[..]]);
        let output = decompress_all(&compressed);
        assert_eq!(output, original);
    }
}