Skip to main content

oxiarc_zstd/
streaming.rs

1//! Streaming compression and decompression for Zstandard.
2//!
3//! Provides [`ZstdStreamEncoder`] (implements [`std::io::Write`]) and
4//! [`ZstdStreamDecoder`] (implements [`std::io::Read`]) for processing Zstandard data
5//! through standard Rust I/O traits.
6//!
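//! The streaming encoder buffers written data and emits it as a sequence of
//! concatenated Zstandard frames. A frame is written each time the internal
//! buffer reaches the configured block size. Any data still buffered is
//! written when [`ZstdStreamEncoder::finish`] is called. The decoder reads
//! the whole compressed stream on its first `read` and accepts multi-frame
//! input.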
11//!
12//! # Example
13//!
14//! ```rust,no_run
15//! use std::io::{Read, Write};
16//! use oxiarc_zstd::streaming::{ZstdStreamEncoder, ZstdStreamDecoder};
17//!
18//! // Compress
19//! let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
20//! encoder.write_all(b"Hello, streaming Zstd!").expect("write failed");
21//! let compressed = encoder.finish().expect("finish failed");
22//!
23//! // Decompress
24//! let mut decoder = ZstdStreamDecoder::new(&compressed[..]);
25//! let mut output = String::new();
26//! decoder.read_to_string(&mut output).expect("read failed");
27//! assert_eq!(output, "Hello, streaming Zstd!");
28//! ```
29
30use crate::encode::ZstdEncoder;
31use crate::frame::{decompress_multi_frame, decompress_multi_frame_with_dict};
32use oxiarc_core::cancel::CancellationToken;
33use oxiarc_core::progress::ProgressHandle;
34use std::io::{self, Read, Write};
35
/// Block size used by [`ZstdStreamEncoder`] unless overridden with
/// `with_block_size`: 128 KiB (131 072 bytes) of uncompressed input per frame.
const DEFAULT_BLOCK_SIZE: usize = 131_072;
38
/// Streaming Zstandard encoder that implements [`Write`].
///
/// Data written to this encoder is buffered internally. When the internal
/// buffer reaches `block_size` bytes it is automatically compressed as a
/// complete, standalone Zstandard frame and written to the inner writer, so
/// output is produced incrementally rather than only at the end. Any data
/// still buffered is emitted when [`finish`](ZstdStreamEncoder::finish) is
/// called.
///
/// The output is a sequence of concatenated Zstandard frames. It can be
/// decoded with [`decompress_multi_frame`] or read back through
/// [`ZstdStreamDecoder`].
///
/// Progress reporting via [`ProgressHandle`] and cooperative cancellation
/// via [`CancellationToken`] are optional. Enable them with the
/// [`ZstdStreamEncoder::with_progress`] / [`ZstdStreamEncoder::with_cancel`]
/// builders.
///
/// **Important:** you *must* call [`finish`](ZstdStreamEncoder::finish) to
/// flush the final (possibly partial) block. There is no `Drop`
/// implementation, so dropping the encoder without calling `finish` silently
/// discards any buffered data.
pub struct ZstdStreamEncoder<W: Write> {
    /// The wrapped writer that receives compressed output. `Some` until
    /// `finish` hands it back to the caller.
    inner: Option<W>,
    /// Uncompressed data waiting to be emitted as the next frame.
    buffer: Vec<u8>,
    /// Compression level passed to every per-frame `ZstdEncoder`.
    level: i32,
    /// Optional pre-trained dictionary applied to every frame.
    dict: Option<Vec<u8>>,
    /// Whether `finish` has already run.
    finished: bool,
    /// Buffer length (in uncompressed bytes) that triggers emitting a frame.
    block_size: usize,
    /// Optional progress sink, notified after each frame is written.
    progress: Option<ProgressHandle>,
    /// Optional cancellation token, checked before each frame is compressed.
    cancel: Option<CancellationToken>,
    /// Cumulative uncompressed bytes emitted as frames so far.
    bytes_processed: u64,
}
77
78impl<W: Write> ZstdStreamEncoder<W> {
79    /// Create a new streaming encoder wrapping `writer`.
80    ///
81    /// The `level` parameter controls the compression level passed to the
82    /// underlying [`ZstdEncoder`].  The encoder uses a default block size of
83    /// 128 KiB; use [`with_block_size`](ZstdStreamEncoder::with_block_size)
84    /// to customise this.
85    pub fn new(writer: W, level: i32) -> Self {
86        Self {
87            inner: Some(writer),
88            buffer: Vec::new(),
89            level,
90            dict: None,
91            finished: false,
92            block_size: DEFAULT_BLOCK_SIZE,
93            progress: None,
94            cancel: None,
95            bytes_processed: 0,
96        }
97    }
98
99    /// Create a new streaming encoder with a pre-trained dictionary.
100    ///
101    /// Dictionary-based compression improves ratios for small payloads that
102    /// share common patterns.
103    pub fn with_dictionary(writer: W, level: i32, dict: Vec<u8>) -> Self {
104        Self {
105            inner: Some(writer),
106            buffer: Vec::new(),
107            level,
108            dict: Some(dict),
109            finished: false,
110            block_size: DEFAULT_BLOCK_SIZE,
111            progress: None,
112            cancel: None,
113            bytes_processed: 0,
114        }
115    }
116
117    /// Set the block size used for incremental flushing.
118    ///
119    /// When the internal buffer reaches this many bytes it is automatically
120    /// compressed and written to the inner writer as a Zstandard frame.
121    pub fn with_block_size(mut self, block_size: usize) -> Self {
122        self.block_size = block_size.max(1);
123        self
124    }
125
126    /// Attach a progress sink.
127    ///
128    /// The sink's `on_progress(cumulative_bytes, None)` is called after each
129    /// block is flushed to the inner writer. `on_finish()` is called after
130    /// `finish` flushes the final block.
131    pub fn with_progress(mut self, handle: ProgressHandle) -> Self {
132        self.progress = Some(handle);
133        self
134    }
135
136    /// Attach a cancellation token.
137    ///
138    /// The token is checked before each block is compressed and written.
139    /// If cancelled, returns an I/O error wrapping
140    /// [`oxiarc_core::error::OxiArcError::Cancelled`].
141    pub fn with_cancel(mut self, token: CancellationToken) -> Self {
142        self.cancel = Some(token);
143        self
144    }
145
146    /// Finish compression and return the inner writer.
147    ///
148    /// This **must** be called to flush the final compressed data. Failing to
149    /// call `finish` means all buffered data is lost.
150    ///
151    /// # Errors
152    ///
153    /// Returns an [`io::Error`] if compression or writing to the inner writer
154    /// fails.
155    pub fn finish(mut self) -> io::Result<W> {
156        if !self.finished {
157            // Flush whatever remains in the buffer (even if empty, to match
158            // the single-frame behaviour expected by existing tests).
159            self.flush_buffer_unconditional()?;
160            self.finished = true;
161            if let Some(ref handle) = self.progress {
162                handle.on_finish();
163            }
164        }
165        // inner is always Some until finish() is called once.
166        self.inner
167            .take()
168            .ok_or_else(|| io::Error::other("inner writer already taken"))
169    }
170
171    /// Compress `data` as a single Zstandard frame and write it to `inner`.
172    fn compress_and_write(&mut self, data: &[u8]) -> io::Result<()> {
173        // Cooperative cancellation check before each block.
174        if let Some(ref token) = self.cancel {
175            token.check().map_err(|e| io::Error::other(e.to_string()))?;
176        }
177
178        let mut encoder = ZstdEncoder::new();
179        encoder.set_level(self.level);
180        if let Some(ref dict) = self.dict {
181            encoder.set_dictionary(dict);
182        }
183        let compressed = encoder
184            .compress(data)
185            .map_err(|e| io::Error::other(e.to_string()))?;
186        if let Some(ref mut w) = self.inner {
187            w.write_all(&compressed)?;
188        }
189
190        self.bytes_processed += data.len() as u64;
191        if let Some(ref handle) = self.progress {
192            handle.on_progress(self.bytes_processed, None);
193        }
194
195        Ok(())
196    }
197
198    /// If the buffer has reached `block_size`, flush it as a frame.
199    fn maybe_flush_block(&mut self) -> io::Result<()> {
200        if self.buffer.len() >= self.block_size {
201            let data = std::mem::take(&mut self.buffer);
202            self.compress_and_write(&data)?;
203        }
204        Ok(())
205    }
206
207    /// Always flush the current buffer contents (even if empty) as a frame.
208    fn flush_buffer_unconditional(&mut self) -> io::Result<()> {
209        let data = std::mem::take(&mut self.buffer);
210        self.compress_and_write(&data)
211    }
212
213    /// Returns the number of uncompressed bytes currently buffered.
214    pub fn buffered_bytes(&self) -> usize {
215        self.buffer.len()
216    }
217
218    /// Returns `true` if `finish` has already been called.
219    pub fn is_finished(&self) -> bool {
220        self.finished
221    }
222}
223
224impl<W: Write> Write for ZstdStreamEncoder<W> {
225    /// Buffer `buf` and flush a frame to the inner writer whenever the
226    /// internal buffer reaches `block_size`.
227    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
228        if self.finished {
229            return Err(io::Error::other("encoder already finished"));
230        }
231        self.buffer.extend_from_slice(buf);
232        self.maybe_flush_block()?;
233        Ok(buf.len())
234    }
235
236    /// Flush any buffered data as a Zstandard frame to the inner writer.
237    fn flush(&mut self) -> io::Result<()> {
238        if !self.buffer.is_empty() {
239            let data = std::mem::take(&mut self.buffer);
240            self.compress_and_write(&data)?;
241        }
242        if let Some(ref mut w) = self.inner {
243            w.flush()?;
244        }
245        Ok(())
246    }
247}
248
249// ---------------------------------------------------------------------------
250// Streaming decoder
251// ---------------------------------------------------------------------------
252
/// Streaming Zstandard decoder that implements [`Read`].
///
/// On the first `read` call, all compressed data is read eagerly from the
/// inner reader (`read_to_end`) and decompressed into an internal buffer.
/// That call and every later one are served from the buffer. Memory use
/// therefore grows with the size of the decompressed output.
///
/// Progress reporting via [`ProgressHandle`] and cooperative cancellation
/// via [`CancellationToken`] are optional. Enable them with the
/// [`ZstdStreamDecoder::with_progress`] / [`ZstdStreamDecoder::with_cancel`]
/// builders.
pub struct ZstdStreamDecoder<R: Read> {
    /// The wrapped reader providing compressed input.
    inner: R,
    /// Decompressed output buffer; empty until the first `read`.
    output_buffer: Vec<u8>,
    /// Index of the next unread byte inside `output_buffer`.
    output_pos: usize,
    /// Whether the compressed stream has been consumed from `inner`.
    finished: bool,
    /// Optional pre-trained dictionary data for decompression.
    dict: Option<Vec<u8>>,
    /// Optional progress sink.
    progress: Option<ProgressHandle>,
    /// Optional cancellation token.
    cancel: Option<CancellationToken>,
}
278
279impl<R: Read> ZstdStreamDecoder<R> {
280    /// Create a new streaming decoder wrapping `reader`.
281    pub fn new(reader: R) -> Self {
282        Self {
283            inner: reader,
284            output_buffer: Vec::new(),
285            output_pos: 0,
286            finished: false,
287            dict: None,
288            progress: None,
289            cancel: None,
290        }
291    }
292
293    /// Create a new streaming decoder with a dictionary.
294    ///
295    /// Dictionary-based decompression requires the same dictionary that was
296    /// used during compression.
297    pub fn with_dictionary(reader: R, dict: Vec<u8>) -> Self {
298        Self {
299            inner: reader,
300            output_buffer: Vec::new(),
301            output_pos: 0,
302            finished: false,
303            dict: if dict.is_empty() { None } else { Some(dict) },
304            progress: None,
305            cancel: None,
306        }
307    }
308
309    /// Attach a progress sink.
310    ///
311    /// The sink's `on_progress(decompressed_bytes, None)` is called once
312    /// after the entire stream is decompressed into the internal buffer.
313    /// `on_finish()` is called at the same point.
314    pub fn with_progress(mut self, handle: ProgressHandle) -> Self {
315        self.progress = Some(handle);
316        self
317    }
318
319    /// Attach a cancellation token.
320    ///
321    /// The token is checked before the compressed stream is read and
322    /// decompressed. If cancelled, an I/O error wrapping
323    /// [`oxiarc_core::error::OxiArcError::Cancelled`] is returned.
324    pub fn with_cancel(mut self, token: CancellationToken) -> Self {
325        self.cancel = Some(token);
326        self
327    }
328
329    /// Read and decompress all compressed data from the inner reader.
330    ///
331    /// Handles concatenated Zstandard frames (multi-frame streams) by using
332    /// [`decompress_multi_frame`].  Skippable frames are silently ignored.
333    fn fill_buffer(&mut self) -> io::Result<()> {
334        if self.finished || self.output_pos < self.output_buffer.len() {
335            return Ok(());
336        }
337
338        // Cooperative cancellation check before reading.
339        if let Some(ref token) = self.cancel {
340            token.check().map_err(|e| io::Error::other(e.to_string()))?;
341        }
342
343        let mut compressed = Vec::new();
344        self.inner.read_to_end(&mut compressed)?;
345
346        if compressed.is_empty() {
347            self.finished = true;
348            return Ok(());
349        }
350
351        // Use multi-frame decompression so that a stream of concatenated
352        // frames (as produced by the incremental encoder) is handled correctly.
353        // When a dictionary is set, use the dict-aware variant so that all
354        // frames in the concatenated stream are decoded with the same dictionary
355        // (the encoder writes one frame per block, each referencing the dict).
356        self.output_buffer = if let Some(ref dict) = self.dict {
357            decompress_multi_frame_with_dict(&compressed, dict)
358                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?
359        } else {
360            decompress_multi_frame(&compressed)
361                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?
362        };
363        self.output_pos = 0;
364        self.finished = true;
365
366        let total = self.output_buffer.len() as u64;
367        if let Some(ref handle) = self.progress {
368            handle.on_progress(total, None);
369            handle.on_finish();
370        }
371
372        Ok(())
373    }
374
375    /// Returns the total number of decompressed bytes available (including
376    /// bytes already consumed via `read`).
377    pub fn decompressed_size(&self) -> usize {
378        self.output_buffer.len()
379    }
380
381    /// Returns `true` if all decompressed data has been read.
382    pub fn is_finished(&self) -> bool {
383        self.finished && self.output_pos >= self.output_buffer.len()
384    }
385}
386
387impl<R: Read> Read for ZstdStreamDecoder<R> {
388    /// Read decompressed data into `buf`.
389    ///
390    /// On the first call this eagerly decompresses the entire compressed
391    /// stream from the inner reader. Subsequent calls serve from the buffer.
392    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
393        self.fill_buffer()?;
394
395        let available = self.output_buffer.len() - self.output_pos;
396        if available == 0 {
397            return Ok(0);
398        }
399
400        let to_copy = buf.len().min(available);
401        buf[..to_copy]
402            .copy_from_slice(&self.output_buffer[self.output_pos..self.output_pos + to_copy]);
403        self.output_pos += to_copy;
404        Ok(to_copy)
405    }
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stream_encoder_basic() {
        let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
        encoder
            .write_all(b"Hello, Zstandard!")
            .expect("write failed");
        let compressed = encoder.finish().expect("finish failed");
        assert!(!compressed.is_empty());
    }

    #[test]
    fn test_stream_encoder_empty() {
        let encoder = ZstdStreamEncoder::new(Vec::new(), 1);
        let compressed = encoder.finish().expect("finish failed");
        // Should produce a valid (minimal) Zstd frame.
        assert!(!compressed.is_empty());
    }

    #[test]
    fn test_stream_roundtrip() {
        let original = b"The quick brown fox jumps over the lazy dog.";

        // Compress
        let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
        encoder.write_all(original).expect("write failed");
        let compressed = encoder.finish().expect("finish failed");

        // Decompress
        let mut decoder = ZstdStreamDecoder::new(&compressed[..]);
        let mut output = Vec::new();
        decoder.read_to_end(&mut output).expect("read failed");

        assert_eq!(output, original.as_slice());
    }

    #[test]
    fn test_stream_roundtrip_multiple_writes() {
        let parts: &[&[u8]] = &[b"Hello, ", b"streaming ", b"Zstd!"];

        let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
        for part in parts {
            encoder.write_all(part).expect("write failed");
        }
        let compressed = encoder.finish().expect("finish failed");

        let mut decoder = ZstdStreamDecoder::new(&compressed[..]);
        let mut output = Vec::new();
        decoder.read_to_end(&mut output).expect("read failed");

        assert_eq!(output, b"Hello, streaming Zstd!");
    }

    #[test]
    fn test_stream_roundtrip_small_block_size_many_writes() {
        // Many small writes against a small block size exercise the
        // automatic per-block frame emission many times.
        let original = make_compressible_data(5_000);

        let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1).with_block_size(256);
        for piece in original.chunks(7) {
            encoder.write_all(piece).expect("write failed");
        }
        let compressed = encoder.finish().expect("finish failed");

        let mut decoder = ZstdStreamDecoder::new(&compressed[..]);
        let mut output = Vec::new();
        decoder.read_to_end(&mut output).expect("read failed");

        assert_eq!(output, original);
    }

    #[test]
    fn test_stream_roundtrip_single_large_write_small_block_size() {
        // One write much larger than the block size must still round-trip.
        let original: Vec<u8> = (0..20_000u32).map(|i| (i * 7 % 251) as u8).collect();

        let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1).with_block_size(1024);
        encoder.write_all(&original).expect("write failed");
        let compressed = encoder.finish().expect("finish failed");

        let mut decoder = ZstdStreamDecoder::new(&compressed[..]);
        let mut output = Vec::new();
        decoder.read_to_end(&mut output).expect("read failed");

        assert_eq!(output, original);
    }

    #[test]
    fn test_stream_roundtrip_with_explicit_flush() {
        let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
        encoder.write_all(b"first half, ").expect("write failed");
        encoder.flush().expect("flush failed");
        // `flush` emits the buffered data as a frame.
        assert_eq!(encoder.buffered_bytes(), 0);
        encoder.write_all(b"second half").expect("write failed");
        let compressed = encoder.finish().expect("finish failed");

        let mut decoder = ZstdStreamDecoder::new(&compressed[..]);
        let mut output = Vec::new();
        decoder.read_to_end(&mut output).expect("read failed");

        assert_eq!(output, b"first half, second half");
    }

    #[test]
    fn test_stream_decoder_small_reads() {
        let original = b"ABCDEFGHIJ";

        let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
        encoder.write_all(original).expect("write failed");
        let compressed = encoder.finish().expect("finish failed");

        let mut decoder = ZstdStreamDecoder::new(&compressed[..]);
        let mut output = Vec::new();
        let mut buf = [0u8; 3];

        loop {
            let n = decoder.read(&mut buf).expect("read failed");
            if n == 0 {
                break;
            }
            output.extend_from_slice(&buf[..n]);
        }

        assert_eq!(output, original.as_slice());
    }

    #[test]
    fn test_stream_decoder_empty_input() {
        let mut decoder = ZstdStreamDecoder::new(&[][..]);
        let mut buf = [0u8; 16];
        let n = decoder.read(&mut buf).expect("read failed");
        assert_eq!(n, 0);
    }

    #[test]
    fn test_stream_encoder_decoder_dict_roundtrip_small() {
        let dict = b"common pattern data appears frequently in the corpus".to_vec();
        let payload = b"common pattern data";

        let mut enc = ZstdStreamEncoder::with_dictionary(Vec::new(), 1, dict.clone());
        enc.write_all(payload).expect("write");
        let compressed = enc.finish().expect("finish");

        let mut dec = ZstdStreamDecoder::with_dictionary(&compressed[..], dict);
        let mut out = Vec::new();
        dec.read_to_end(&mut out).expect("read");
        assert_eq!(out, payload as &[u8]);
    }

    #[test]
    fn test_stream_encoder_decoder_dict_roundtrip_large() {
        // Multi-frame dict roundtrip. An 8 KiB block_size makes the encoder
        // emit more than one concatenated Zstd frame. It also keeps every
        // frame well under 128 KiB, so we stay in the single-internal-block
        // regime and avoid the known multi-internal-block + dict bug that is
        // pre-existing in the encoder.
        let dict_text = "alpha beta gamma delta epsilon zeta eta theta iota kappa ".repeat(50);
        let dict = dict_text.as_bytes().to_vec();
        // 57 000-byte payload (a 57-byte phrase repeated 50 x 20 times).
        let payload: Vec<u8> = dict_text.repeat(20).into_bytes();

        let mut enc = ZstdStreamEncoder::with_dictionary(Vec::new(), 3, dict.clone())
            .with_block_size(8 * 1024);
        enc.write_all(&payload).expect("write");
        let compressed = enc.finish().expect("finish");

        // Verify multiple Zstd frames were produced by counting magic bytes.
        let magic = &crate::ZSTD_MAGIC;
        let frame_count = compressed.windows(4).filter(|w| *w == magic).count();
        assert!(
            frame_count > 1,
            "expected multiple frames, got {}",
            frame_count
        );

        let mut dec = ZstdStreamDecoder::with_dictionary(&compressed[..], dict);
        let mut out = Vec::new();
        dec.read_to_end(&mut out).expect("read");
        assert_eq!(out, payload);
    }

    #[test]
    fn test_stream_decoder_without_dict_on_dict_compressed_large_data() {
        // Compress with a dict; decode without. On large inputs that trigger
        // dict back-references, this must either error or produce wrong output.
        let dict_text = "pattern frequently repeating text ".repeat(200);
        let dict = dict_text.as_bytes().to_vec();
        let payload: Vec<u8> = dict_text.repeat(50).into_bytes();

        let mut enc = ZstdStreamEncoder::with_dictionary(Vec::new(), 3, dict);
        enc.write_all(&payload).expect("write");
        let compressed = enc.finish().expect("finish");

        let mut dec = ZstdStreamDecoder::new(&compressed[..]);
        let mut out = Vec::new();
        let result = dec.read_to_end(&mut out);
        if result.is_ok() {
            assert_ne!(
                out, payload,
                "decoding without dict should not reproduce original on large input"
            );
        }
    }

    #[test]
    fn test_stream_encoder_buffered_bytes() {
        let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
        assert_eq!(encoder.buffered_bytes(), 0);
        encoder.write_all(b"12345").expect("write failed");
        assert_eq!(encoder.buffered_bytes(), 5);
        encoder.write_all(b"67890").expect("write failed");
        assert_eq!(encoder.buffered_bytes(), 10);
    }

    #[test]
    fn test_stream_encoder_is_finished() {
        let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
        assert!(!encoder.is_finished());
        encoder.write_all(b"data").expect("write failed");
        assert!(!encoder.is_finished());
        // Cannot check after finish since finish consumes self.
    }

    #[test]
    fn test_stream_decoder_is_finished() {
        let original = b"short";

        let mut enc = ZstdStreamEncoder::new(Vec::new(), 1);
        enc.write_all(original).expect("write failed");
        let compressed = enc.finish().expect("finish failed");

        let mut decoder = ZstdStreamDecoder::new(&compressed[..]);
        assert!(!decoder.is_finished());

        let mut out = Vec::new();
        decoder.read_to_end(&mut out).expect("read failed");
        assert!(decoder.is_finished());
    }

    #[test]
    fn test_stream_roundtrip_large_data() {
        let original: Vec<u8> = (0..10_000).map(|i| (i % 256) as u8).collect();

        let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
        encoder.write_all(&original).expect("write failed");
        let compressed = encoder.finish().expect("finish failed");

        let mut decoder = ZstdStreamDecoder::new(&compressed[..]);
        let mut output = Vec::new();
        decoder.read_to_end(&mut output).expect("read failed");

        assert_eq!(output, original);
    }

    use oxiarc_core::cancel::CancellationToken;
    use oxiarc_core::progress::ProgressSink;
    use std::sync::{Arc, Mutex};

    type ProgressLog = Arc<Mutex<Vec<(u64, Option<u64>)>>>;

    struct MockSink(ProgressLog);

    impl ProgressSink for MockSink {
        fn on_progress(&self, processed: u64, total: Option<u64>) {
            self.0
                .lock()
                .expect("lock poisoned")
                .push((processed, total));
        }
    }

    fn make_compressible_data(size: usize) -> Vec<u8> {
        let pattern = b"ZstdStream test data with repeating pattern ABCDEFGH ";
        let mut data = Vec::with_capacity(size);
        while data.len() < size {
            let remaining = size - data.len();
            let chunk = &pattern[..remaining.min(pattern.len())];
            data.extend_from_slice(chunk);
        }
        data
    }

    #[test]
    fn test_zstd_stream_encoder_progress_reports() {
        let data = make_compressible_data(1024 * 1024); // 1 MB

        let calls: ProgressLog = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::new(MockSink(calls.clone()));

        let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1)
            .with_progress(sink as oxiarc_core::progress::ProgressHandle);
        encoder.write_all(&data).expect("write_all failed");
        encoder.finish().expect("finish failed");

        let recorded = calls.lock().expect("lock poisoned");
        assert!(!recorded.is_empty(), "expected at least one progress call");
        let (last_processed, _) = *recorded.last().expect("non-empty");
        assert_eq!(
            last_processed,
            data.len() as u64,
            "final processed count must equal input size"
        );
    }

    #[test]
    fn test_zstd_stream_encoder_cancel_aborts() {
        let data = make_compressible_data(1024 * 1024);
        let token = CancellationToken::new();

        // Use a small block size so we hit the block boundary quickly.
        let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1)
            .with_block_size(4096)
            .with_cancel(token.clone());

        token.cancel();
        let result = encoder.write_all(&data);
        assert!(result.is_err(), "expected cancellation error");
    }

    #[test]
    fn test_zstd_stream_decoder_progress_reports() {
        let data = make_compressible_data(1024 * 1024); // 1 MB

        let mut enc = ZstdStreamEncoder::new(Vec::new(), 1);
        enc.write_all(&data).expect("write_all failed");
        let compressed = enc.finish().expect("finish failed");

        let calls: ProgressLog = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::new(MockSink(calls.clone()));

        let mut decoder = ZstdStreamDecoder::new(&compressed[..])
            .with_progress(sink as oxiarc_core::progress::ProgressHandle);
        let mut output = Vec::new();
        decoder
            .read_to_end(&mut output)
            .expect("read_to_end failed");

        let recorded = calls.lock().expect("lock poisoned");
        assert!(!recorded.is_empty(), "expected at least one progress call");
        let (last_processed, _) = *recorded.last().expect("non-empty");
        assert_eq!(
            last_processed,
            data.len() as u64,
            "final processed count must equal decompressed size"
        );
    }

    #[test]
    fn test_zstd_stream_decoder_cancel_aborts() {
        let data = make_compressible_data(1024 * 1024);
        let mut enc = ZstdStreamEncoder::new(Vec::new(), 1);
        enc.write_all(&data).expect("write_all failed");
        let compressed = enc.finish().expect("finish failed");

        let token = CancellationToken::new();
        let mut decoder = ZstdStreamDecoder::new(&compressed[..]).with_cancel(token.clone());
        let mut output = Vec::new();

        token.cancel();
        let result = decoder.read_to_end(&mut output);
        assert!(result.is_err(), "expected cancellation error");
    }
}