Skip to main content

oxiarc_zstd/
streaming.rs

1//! Streaming compression and decompression for Zstandard.
2//!
3//! Provides [`ZstdStreamEncoder`] (implements [`std::io::Write`]) and
4//! [`ZstdStreamDecoder`] (implements [`std::io::Read`]) for processing Zstandard data
5//! through standard Rust I/O traits.
6//!
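//! The streaming encoder buffers written data and emits it as a sequence of
//! concatenated Zstandard frames: a frame is written to the inner writer once
//! a full block (128 KiB by default) has accumulated or `flush` is called, and
//! any remainder is compressed when [`ZstdStreamEncoder::finish`] is called.
//! Without a dictionary, [`ZstdStreamDecoder`] decodes such multi-frame
//! streams.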
11//!
12//! # Example
13//!
14//! ```rust,no_run
15//! use std::io::{Read, Write};
16//! use oxiarc_zstd::streaming::{ZstdStreamEncoder, ZstdStreamDecoder};
17//!
18//! // Compress
19//! let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
20//! encoder.write_all(b"Hello, streaming Zstd!").unwrap();
21//! let compressed = encoder.finish().unwrap();
22//!
23//! // Decompress
24//! let mut decoder = ZstdStreamDecoder::new(&compressed[..]);
25//! let mut output = String::new();
26//! decoder.read_to_string(&mut output).unwrap();
27//! assert_eq!(output, "Hello, streaming Zstd!");
28//! ```
29
30use crate::encode::ZstdEncoder;
31use crate::frame::{ZstdDecoder, decompress_multi_frame};
32use std::io::{self, Read, Write};
33
/// Buffered byte count at which the incremental encoder emits a frame
/// (128 KiB, i.e. 2^17 bytes).
const DEFAULT_BLOCK_SIZE: usize = 1 << 17;
36
/// Streaming Zstandard encoder that implements [`Write`].
///
/// Bytes written to the encoder are collected in an internal buffer. Once at
/// least `block_size` bytes have accumulated, the buffered bytes are
/// compressed into a self-contained Zstandard frame and written to the inner
/// writer. Whatever is still buffered is compressed when
/// [`finish`](ZstdStreamEncoder::finish) is called.
///
/// The resulting output is a series of concatenated Zstandard frames, which
/// [`decompress_multi_frame`] can decode.
///
/// **Important:** call [`finish`](ZstdStreamEncoder::finish) when done.
/// Dropping the encoder without calling `finish` silently discards any data
/// still sitting in the buffer.
pub struct ZstdStreamEncoder<W: Write> {
    /// Destination for compressed frames; taken out by `finish`.
    inner: Option<W>,
    /// Compression level handed to each per-frame encoder.
    level: i32,
    /// Optional pre-trained dictionary applied to every frame.
    dict: Option<Vec<u8>>,
    /// Buffered byte count that triggers emitting a frame.
    block_size: usize,
    /// Uncompressed bytes not yet emitted as a frame.
    buffer: Vec<u8>,
    /// Set by `finish`; `write` rejects input once this is `true`.
    finished: bool,
}
65
66impl<W: Write> ZstdStreamEncoder<W> {
67    /// Create a new streaming encoder wrapping `writer`.
68    ///
69    /// The `level` parameter controls the compression level passed to the
70    /// underlying [`ZstdEncoder`].  The encoder uses a default block size of
71    /// 128 KiB; use [`with_block_size`](ZstdStreamEncoder::with_block_size)
72    /// to customise this.
73    pub fn new(writer: W, level: i32) -> Self {
74        Self {
75            inner: Some(writer),
76            buffer: Vec::new(),
77            level,
78            dict: None,
79            finished: false,
80            block_size: DEFAULT_BLOCK_SIZE,
81        }
82    }
83
84    /// Create a new streaming encoder with a pre-trained dictionary.
85    ///
86    /// Dictionary-based compression improves ratios for small payloads that
87    /// share common patterns.
88    pub fn with_dictionary(writer: W, level: i32, dict: Vec<u8>) -> Self {
89        Self {
90            inner: Some(writer),
91            buffer: Vec::new(),
92            level,
93            dict: Some(dict),
94            finished: false,
95            block_size: DEFAULT_BLOCK_SIZE,
96        }
97    }
98
99    /// Set the block size used for incremental flushing.
100    ///
101    /// When the internal buffer reaches this many bytes it is automatically
102    /// compressed and written to the inner writer as a Zstandard frame.
103    pub fn with_block_size(mut self, block_size: usize) -> Self {
104        self.block_size = block_size.max(1);
105        self
106    }
107
108    /// Finish compression and return the inner writer.
109    ///
110    /// This **must** be called to flush the final compressed data. Failing to
111    /// call `finish` means all buffered data is lost.
112    ///
113    /// # Errors
114    ///
115    /// Returns an [`io::Error`] if compression or writing to the inner writer
116    /// fails.
117    pub fn finish(mut self) -> io::Result<W> {
118        if !self.finished {
119            // Flush whatever remains in the buffer (even if empty, to match
120            // the single-frame behaviour expected by existing tests).
121            self.flush_buffer_unconditional()?;
122            self.finished = true;
123        }
124        // inner is always Some until finish() is called once.
125        self.inner
126            .take()
127            .ok_or_else(|| io::Error::other("inner writer already taken"))
128    }
129
130    /// Compress `data` as a single Zstandard frame and write it to `inner`.
131    fn compress_and_write(&mut self, data: &[u8]) -> io::Result<()> {
132        let mut encoder = ZstdEncoder::new();
133        encoder.set_level(self.level);
134        if let Some(ref dict) = self.dict {
135            encoder.set_dictionary(dict);
136        }
137        let compressed = encoder
138            .compress(data)
139            .map_err(|e| io::Error::other(e.to_string()))?;
140        if let Some(ref mut w) = self.inner {
141            w.write_all(&compressed)?;
142        }
143        Ok(())
144    }
145
146    /// If the buffer has reached `block_size`, flush it as a frame.
147    fn maybe_flush_block(&mut self) -> io::Result<()> {
148        if self.buffer.len() >= self.block_size {
149            let data = std::mem::take(&mut self.buffer);
150            self.compress_and_write(&data)?;
151        }
152        Ok(())
153    }
154
155    /// Always flush the current buffer contents (even if empty) as a frame.
156    fn flush_buffer_unconditional(&mut self) -> io::Result<()> {
157        let data = std::mem::take(&mut self.buffer);
158        self.compress_and_write(&data)
159    }
160
161    /// Returns the number of uncompressed bytes currently buffered.
162    pub fn buffered_bytes(&self) -> usize {
163        self.buffer.len()
164    }
165
166    /// Returns `true` if `finish` has already been called.
167    pub fn is_finished(&self) -> bool {
168        self.finished
169    }
170}
171
172impl<W: Write> Write for ZstdStreamEncoder<W> {
173    /// Buffer `buf` and flush a frame to the inner writer whenever the
174    /// internal buffer reaches `block_size`.
175    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
176        if self.finished {
177            return Err(io::Error::other("encoder already finished"));
178        }
179        self.buffer.extend_from_slice(buf);
180        self.maybe_flush_block()?;
181        Ok(buf.len())
182    }
183
184    /// Flush any buffered data as a Zstandard frame to the inner writer.
185    fn flush(&mut self) -> io::Result<()> {
186        if !self.buffer.is_empty() {
187            let data = std::mem::take(&mut self.buffer);
188            self.compress_and_write(&data)?;
189        }
190        if let Some(ref mut w) = self.inner {
191            w.flush()?;
192        }
193        Ok(())
194    }
195}
196
197// ---------------------------------------------------------------------------
198// Streaming decoder
199// ---------------------------------------------------------------------------
200
/// Streaming Zstandard decoder that implements [`Read`].
///
/// On the first `read`, the decoder pulls the entire compressed stream from
/// the inner reader and decompresses it into an internal buffer. That call and
/// every later one hand out bytes from this buffer.
pub struct ZstdStreamDecoder<R: Read> {
    /// Source of compressed input.
    inner: R,
    /// Optional pre-trained dictionary used for decompression.
    dict: Option<Vec<u8>>,
    /// Fully decompressed output.
    output_buffer: Vec<u8>,
    /// Index of the next byte of `output_buffer` to hand out.
    output_pos: usize,
    /// Set once the compressed input has been fully consumed.
    finished: bool,
}
218
219impl<R: Read> ZstdStreamDecoder<R> {
220    /// Create a new streaming decoder wrapping `reader`.
221    pub fn new(reader: R) -> Self {
222        Self {
223            inner: reader,
224            output_buffer: Vec::new(),
225            output_pos: 0,
226            finished: false,
227            dict: None,
228        }
229    }
230
231    /// Create a new streaming decoder with a dictionary.
232    ///
233    /// Dictionary-based decompression requires the same dictionary that was
234    /// used during compression.
235    pub fn with_dictionary(reader: R, dict: Vec<u8>) -> Self {
236        Self {
237            inner: reader,
238            output_buffer: Vec::new(),
239            output_pos: 0,
240            finished: false,
241            dict: if dict.is_empty() { None } else { Some(dict) },
242        }
243    }
244
245    /// Read and decompress all compressed data from the inner reader.
246    ///
247    /// Handles concatenated Zstandard frames (multi-frame streams) by using
248    /// [`decompress_multi_frame`].  Skippable frames are silently ignored.
249    fn fill_buffer(&mut self) -> io::Result<()> {
250        if self.finished || self.output_pos < self.output_buffer.len() {
251            return Ok(());
252        }
253
254        let mut compressed = Vec::new();
255        self.inner.read_to_end(&mut compressed)?;
256
257        if compressed.is_empty() {
258            self.finished = true;
259            return Ok(());
260        }
261
262        // Use multi-frame decompression so that a stream of concatenated
263        // frames (as produced by the incremental encoder) is handled correctly.
264        // Dictionary support: if a dict is set we fall back to single-frame
265        // decoding (dict + multi-frame is a more complex scenario and not
266        // required by the current API surface).
267        self.output_buffer = if self.dict.is_none() {
268            decompress_multi_frame(&compressed)
269                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?
270        } else {
271            let mut decoder = ZstdDecoder::new();
272            if let Some(ref dict) = self.dict {
273                decoder.set_dictionary(dict);
274            }
275            decoder
276                .decode_frame(&compressed)
277                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?
278        };
279        self.output_pos = 0;
280        self.finished = true;
281
282        Ok(())
283    }
284
285    /// Returns the total number of decompressed bytes available (including
286    /// bytes already consumed via `read`).
287    pub fn decompressed_size(&self) -> usize {
288        self.output_buffer.len()
289    }
290
291    /// Returns `true` if all decompressed data has been read.
292    pub fn is_finished(&self) -> bool {
293        self.finished && self.output_pos >= self.output_buffer.len()
294    }
295}
296
297impl<R: Read> Read for ZstdStreamDecoder<R> {
298    /// Read decompressed data into `buf`.
299    ///
300    /// On the first call this eagerly decompresses the entire compressed
301    /// stream from the inner reader. Subsequent calls serve from the buffer.
302    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
303        self.fill_buffer()?;
304
305        let available = self.output_buffer.len() - self.output_pos;
306        if available == 0 {
307            return Ok(0);
308        }
309
310        let to_copy = buf.len().min(available);
311        buf[..to_copy]
312            .copy_from_slice(&self.output_buffer[self.output_pos..self.output_pos + to_copy]);
313        self.output_pos += to_copy;
314        Ok(to_copy)
315    }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// Compress `data` at level 1 through the streaming encoder.
    fn compress(data: &[u8]) -> Vec<u8> {
        let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
        encoder.write_all(data).unwrap();
        encoder.finish().unwrap()
    }

    /// Decompress `compressed` with a dictionary-less streaming decoder.
    fn decompress(compressed: &[u8]) -> Vec<u8> {
        let mut decoder = ZstdStreamDecoder::new(compressed);
        let mut output = Vec::new();
        decoder.read_to_end(&mut output).unwrap();
        output
    }

    #[test]
    fn test_stream_encoder_basic() {
        assert!(!compress(b"Hello, Zstandard!").is_empty());
    }

    #[test]
    fn test_stream_encoder_empty() {
        let compressed = ZstdStreamEncoder::new(Vec::new(), 1).finish().unwrap();
        // Even with no input, a valid (minimal) Zstd frame is produced.
        assert!(!compressed.is_empty());
    }

    #[test]
    fn test_stream_roundtrip() {
        let original = b"The quick brown fox jumps over the lazy dog.";
        assert_eq!(decompress(&compress(original)), original.as_slice());
    }

    #[test]
    fn test_stream_roundtrip_multiple_writes() {
        let parts: &[&[u8]] = &[b"Hello, ", b"streaming ", b"Zstd!"];

        let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
        parts
            .iter()
            .for_each(|part| encoder.write_all(part).unwrap());
        let compressed = encoder.finish().unwrap();

        assert_eq!(decompress(&compressed), b"Hello, streaming Zstd!");
    }

    #[test]
    fn test_stream_decoder_small_reads() {
        let original = b"ABCDEFGHIJ";
        let compressed = compress(original);

        let mut decoder = ZstdStreamDecoder::new(&compressed[..]);
        let mut output = Vec::new();
        let mut chunk = [0u8; 3];
        loop {
            match decoder.read(&mut chunk).unwrap() {
                0 => break,
                n => output.extend_from_slice(&chunk[..n]),
            }
        }

        assert_eq!(output, original.as_slice());
    }

    #[test]
    fn test_stream_decoder_empty_input() {
        let mut decoder = ZstdStreamDecoder::new(&[][..]);
        let mut scratch = [0u8; 16];
        assert_eq!(decoder.read(&mut scratch).unwrap(), 0);
    }

    #[test]
    fn test_stream_encoder_with_dictionary() {
        let dict = b"common pattern data".to_vec();
        let mut encoder = ZstdStreamEncoder::with_dictionary(Vec::new(), 1, dict);
        encoder.write_all(b"test data").unwrap();
        let compressed = encoder.finish().unwrap();

        // Should still decompress (dict is a placeholder for now).
        assert_eq!(decompress(&compressed), b"test data");
    }

    #[test]
    fn test_stream_encoder_buffered_bytes() {
        let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
        assert_eq!(encoder.buffered_bytes(), 0);
        for (chunk, expected) in [(&b"12345"[..], 5), (&b"67890"[..], 10)] {
            encoder.write_all(chunk).unwrap();
            assert_eq!(encoder.buffered_bytes(), expected);
        }
    }

    #[test]
    fn test_stream_encoder_is_finished() {
        let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
        assert!(!encoder.is_finished());
        encoder.write_all(b"data").unwrap();
        assert!(!encoder.is_finished());
        // `finish` consumes the encoder, so there is nothing to check after it.
    }

    #[test]
    fn test_stream_decoder_is_finished() {
        let compressed = compress(b"short");

        let mut decoder = ZstdStreamDecoder::new(&compressed[..]);
        assert!(!decoder.is_finished());

        let mut sink = Vec::new();
        decoder.read_to_end(&mut sink).unwrap();
        assert!(decoder.is_finished());
    }

    #[test]
    fn test_stream_roundtrip_large_data() {
        let original: Vec<u8> = (0..10_000).map(|i| (i % 256) as u8).collect();
        assert_eq!(decompress(&compress(&original)), original);
    }
}