1use crate::encode::ZstdEncoder;
31use crate::frame::{decompress_multi_frame, decompress_multi_frame_with_dict};
32use oxiarc_core::cancel::CancellationToken;
33use oxiarc_core::progress::ProgressHandle;
34use std::io::{self, Read, Write};
35
/// Block size, in bytes, that [`ZstdStreamEncoder`] starts with unless
/// [`ZstdStreamEncoder::with_block_size`] overrides it (128 KiB).
const DEFAULT_BLOCK_SIZE: usize = 128 * 1024;
38
/// Buffering Zstandard compressor that implements [`Write`].
///
/// Written bytes are collected in an internal buffer. Buffered data is
/// compressed with a freshly configured [`ZstdEncoder`] and the result is
/// written to the wrapped writer. Call `finish` to emit whatever is still
/// buffered and get the writer back.
pub struct ZstdStreamEncoder<W: Write> {
    /// Destination for compressed bytes; `finish` takes it out, leaving `None`.
    inner: Option<W>,
    /// Uncompressed bytes accepted by `write` that have not been compressed yet.
    buffer: Vec<u8>,
    /// Compression level passed to `ZstdEncoder::set_level` for every block.
    level: i32,
    /// Optional dictionary passed to `ZstdEncoder::set_dictionary` for every block.
    dict: Option<Vec<u8>>,
    /// Set by `finish`; `write` returns an error once this is `true`.
    finished: bool,
    /// Buffer length at which buffered data is compressed and written out.
    block_size: usize,
    /// Optional sink notified after each compressed block and on finish.
    progress: Option<ProgressHandle>,
    /// Optional token checked before each block is compressed.
    cancel: Option<CancellationToken>,
    /// Running total of uncompressed bytes compressed so far.
    bytes_processed: u64,
}
77
78impl<W: Write> ZstdStreamEncoder<W> {
79 pub fn new(writer: W, level: i32) -> Self {
86 Self {
87 inner: Some(writer),
88 buffer: Vec::new(),
89 level,
90 dict: None,
91 finished: false,
92 block_size: DEFAULT_BLOCK_SIZE,
93 progress: None,
94 cancel: None,
95 bytes_processed: 0,
96 }
97 }
98
99 pub fn with_dictionary(writer: W, level: i32, dict: Vec<u8>) -> Self {
104 Self {
105 inner: Some(writer),
106 buffer: Vec::new(),
107 level,
108 dict: Some(dict),
109 finished: false,
110 block_size: DEFAULT_BLOCK_SIZE,
111 progress: None,
112 cancel: None,
113 bytes_processed: 0,
114 }
115 }
116
117 pub fn with_block_size(mut self, block_size: usize) -> Self {
122 self.block_size = block_size.max(1);
123 self
124 }
125
126 pub fn with_progress(mut self, handle: ProgressHandle) -> Self {
132 self.progress = Some(handle);
133 self
134 }
135
136 pub fn with_cancel(mut self, token: CancellationToken) -> Self {
142 self.cancel = Some(token);
143 self
144 }
145
146 pub fn finish(mut self) -> io::Result<W> {
156 if !self.finished {
157 self.flush_buffer_unconditional()?;
160 self.finished = true;
161 if let Some(ref handle) = self.progress {
162 handle.on_finish();
163 }
164 }
165 self.inner
167 .take()
168 .ok_or_else(|| io::Error::other("inner writer already taken"))
169 }
170
171 fn compress_and_write(&mut self, data: &[u8]) -> io::Result<()> {
173 if let Some(ref token) = self.cancel {
175 token.check().map_err(|e| io::Error::other(e.to_string()))?;
176 }
177
178 let mut encoder = ZstdEncoder::new();
179 encoder.set_level(self.level);
180 if let Some(ref dict) = self.dict {
181 encoder.set_dictionary(dict);
182 }
183 let compressed = encoder
184 .compress(data)
185 .map_err(|e| io::Error::other(e.to_string()))?;
186 if let Some(ref mut w) = self.inner {
187 w.write_all(&compressed)?;
188 }
189
190 self.bytes_processed += data.len() as u64;
191 if let Some(ref handle) = self.progress {
192 handle.on_progress(self.bytes_processed, None);
193 }
194
195 Ok(())
196 }
197
198 fn maybe_flush_block(&mut self) -> io::Result<()> {
200 if self.buffer.len() >= self.block_size {
201 let data = std::mem::take(&mut self.buffer);
202 self.compress_and_write(&data)?;
203 }
204 Ok(())
205 }
206
207 fn flush_buffer_unconditional(&mut self) -> io::Result<()> {
209 let data = std::mem::take(&mut self.buffer);
210 self.compress_and_write(&data)
211 }
212
213 pub fn buffered_bytes(&self) -> usize {
215 self.buffer.len()
216 }
217
218 pub fn is_finished(&self) -> bool {
220 self.finished
221 }
222}
223
224impl<W: Write> Write for ZstdStreamEncoder<W> {
225 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
228 if self.finished {
229 return Err(io::Error::other("encoder already finished"));
230 }
231 self.buffer.extend_from_slice(buf);
232 self.maybe_flush_block()?;
233 Ok(buf.len())
234 }
235
236 fn flush(&mut self) -> io::Result<()> {
238 if !self.buffer.is_empty() {
239 let data = std::mem::take(&mut self.buffer);
240 self.compress_and_write(&data)?;
241 }
242 if let Some(ref mut w) = self.inner {
243 w.flush()?;
244 }
245 Ok(())
246 }
247}
248
/// Zstandard decompressor that implements [`Read`].
///
/// On the first read, the whole input is read from the wrapped reader and
/// decompressed into memory; later reads are served from that decoded buffer.
pub struct ZstdStreamDecoder<R: Read> {
    /// Source of compressed bytes.
    inner: R,
    /// Decompressed data produced from the input.
    output_buffer: Vec<u8>,
    /// Index into `output_buffer` of the next byte to hand out.
    output_pos: usize,
    /// `true` once the input has been consumed and decoded.
    finished: bool,
    /// Optional dictionary used for decompression.
    dict: Option<Vec<u8>>,
    /// Optional sink notified once decompression has completed.
    progress: Option<ProgressHandle>,
    /// Optional token checked before the input is read and decoded.
    cancel: Option<CancellationToken>,
}
278
279impl<R: Read> ZstdStreamDecoder<R> {
280 pub fn new(reader: R) -> Self {
282 Self {
283 inner: reader,
284 output_buffer: Vec::new(),
285 output_pos: 0,
286 finished: false,
287 dict: None,
288 progress: None,
289 cancel: None,
290 }
291 }
292
293 pub fn with_dictionary(reader: R, dict: Vec<u8>) -> Self {
298 Self {
299 inner: reader,
300 output_buffer: Vec::new(),
301 output_pos: 0,
302 finished: false,
303 dict: if dict.is_empty() { None } else { Some(dict) },
304 progress: None,
305 cancel: None,
306 }
307 }
308
309 pub fn with_progress(mut self, handle: ProgressHandle) -> Self {
315 self.progress = Some(handle);
316 self
317 }
318
319 pub fn with_cancel(mut self, token: CancellationToken) -> Self {
325 self.cancel = Some(token);
326 self
327 }
328
329 fn fill_buffer(&mut self) -> io::Result<()> {
334 if self.finished || self.output_pos < self.output_buffer.len() {
335 return Ok(());
336 }
337
338 if let Some(ref token) = self.cancel {
340 token.check().map_err(|e| io::Error::other(e.to_string()))?;
341 }
342
343 let mut compressed = Vec::new();
344 self.inner.read_to_end(&mut compressed)?;
345
346 if compressed.is_empty() {
347 self.finished = true;
348 return Ok(());
349 }
350
351 self.output_buffer = if let Some(ref dict) = self.dict {
357 decompress_multi_frame_with_dict(&compressed, dict)
358 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?
359 } else {
360 decompress_multi_frame(&compressed)
361 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?
362 };
363 self.output_pos = 0;
364 self.finished = true;
365
366 let total = self.output_buffer.len() as u64;
367 if let Some(ref handle) = self.progress {
368 handle.on_progress(total, None);
369 handle.on_finish();
370 }
371
372 Ok(())
373 }
374
375 pub fn decompressed_size(&self) -> usize {
378 self.output_buffer.len()
379 }
380
381 pub fn is_finished(&self) -> bool {
383 self.finished && self.output_pos >= self.output_buffer.len()
384 }
385}
386
387impl<R: Read> Read for ZstdStreamDecoder<R> {
388 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
393 self.fill_buffer()?;
394
395 let available = self.output_buffer.len() - self.output_pos;
396 if available == 0 {
397 return Ok(0);
398 }
399
400 let to_copy = buf.len().min(available);
401 buf[..to_copy]
402 .copy_from_slice(&self.output_buffer[self.output_pos..self.output_pos + to_copy]);
403 self.output_pos += to_copy;
404 Ok(to_copy)
405 }
406}
407
408#[cfg(test)]
409mod tests {
410 use super::*;
411
412 #[test]
413 fn test_stream_encoder_basic() {
414 let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
415 encoder
416 .write_all(b"Hello, Zstandard!")
417 .expect("write failed");
418 let compressed = encoder.finish().expect("finish failed");
419 assert!(!compressed.is_empty());
420 }
421
422 #[test]
423 fn test_stream_encoder_empty() {
424 let encoder = ZstdStreamEncoder::new(Vec::new(), 1);
425 let compressed = encoder.finish().expect("finish failed");
426 assert!(!compressed.is_empty());
428 }
429
430 #[test]
431 fn test_stream_roundtrip() {
432 let original = b"The quick brown fox jumps over the lazy dog.";
433
434 let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
436 encoder.write_all(original).expect("write failed");
437 let compressed = encoder.finish().expect("finish failed");
438
439 let mut decoder = ZstdStreamDecoder::new(&compressed[..]);
441 let mut output = Vec::new();
442 decoder.read_to_end(&mut output).expect("read failed");
443
444 assert_eq!(output, original.as_slice());
445 }
446
447 #[test]
448 fn test_stream_roundtrip_multiple_writes() {
449 let parts: &[&[u8]] = &[b"Hello, ", b"streaming ", b"Zstd!"];
450
451 let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
452 for part in parts {
453 encoder.write_all(part).expect("write failed");
454 }
455 let compressed = encoder.finish().expect("finish failed");
456
457 let mut decoder = ZstdStreamDecoder::new(&compressed[..]);
458 let mut output = Vec::new();
459 decoder.read_to_end(&mut output).expect("read failed");
460
461 assert_eq!(output, b"Hello, streaming Zstd!");
462 }
463
464 #[test]
465 fn test_stream_decoder_small_reads() {
466 let original = b"ABCDEFGHIJ";
467
468 let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
469 encoder.write_all(original).expect("write failed");
470 let compressed = encoder.finish().expect("finish failed");
471
472 let mut decoder = ZstdStreamDecoder::new(&compressed[..]);
473 let mut output = Vec::new();
474 let mut buf = [0u8; 3];
475
476 loop {
477 let n = decoder.read(&mut buf).expect("read failed");
478 if n == 0 {
479 break;
480 }
481 output.extend_from_slice(&buf[..n]);
482 }
483
484 assert_eq!(output, original.as_slice());
485 }
486
487 #[test]
488 fn test_stream_decoder_empty_input() {
489 let mut decoder = ZstdStreamDecoder::new(&[][..]);
490 let mut buf = [0u8; 16];
491 let n = decoder.read(&mut buf).expect("read failed");
492 assert_eq!(n, 0);
493 }
494
495 #[test]
496 fn test_stream_encoder_decoder_dict_roundtrip_small() {
497 let dict = b"common pattern data appears frequently in the corpus".to_vec();
498 let payload = b"common pattern data";
499
500 let mut enc = ZstdStreamEncoder::with_dictionary(Vec::new(), 1, dict.clone());
501 enc.write_all(payload).expect("write");
502 let compressed = enc.finish().expect("finish");
503
504 let mut dec = ZstdStreamDecoder::with_dictionary(&compressed[..], dict);
505 let mut out = Vec::new();
506 dec.read_to_end(&mut out).expect("read");
507 assert_eq!(out, payload as &[u8]);
508 }
509
510 #[test]
511 fn test_stream_encoder_decoder_dict_roundtrip_large() {
512 let dict_text = "alpha beta gamma delta epsilon zeta eta theta iota kappa ".repeat(50);
517 let dict = dict_text.as_bytes().to_vec();
518 let payload: Vec<u8> = dict_text.repeat(20).into_bytes();
520
521 let mut enc = ZstdStreamEncoder::with_dictionary(Vec::new(), 3, dict.clone())
522 .with_block_size(8 * 1024);
523 enc.write_all(&payload).expect("write");
524 let compressed = enc.finish().expect("finish");
525
526 let magic = &crate::ZSTD_MAGIC;
528 let frame_count = compressed.windows(4).filter(|w| *w == magic).count();
529 assert!(
530 frame_count > 1,
531 "expected multiple frames, got {}",
532 frame_count
533 );
534
535 let mut dec = ZstdStreamDecoder::with_dictionary(&compressed[..], dict);
536 let mut out = Vec::new();
537 dec.read_to_end(&mut out).expect("read");
538 assert_eq!(out, payload);
539 }
540
541 #[test]
542 fn test_stream_decoder_without_dict_on_dict_compressed_large_data() {
543 let dict_text = "pattern frequently repeating text ".repeat(200);
546 let dict = dict_text.as_bytes().to_vec();
547 let payload: Vec<u8> = dict_text.repeat(50).into_bytes();
548
549 let mut enc = ZstdStreamEncoder::with_dictionary(Vec::new(), 3, dict);
550 enc.write_all(&payload).expect("write");
551 let compressed = enc.finish().expect("finish");
552
553 let mut dec = ZstdStreamDecoder::new(&compressed[..]);
554 let mut out = Vec::new();
555 let result = dec.read_to_end(&mut out);
556 if result.is_ok() {
557 assert_ne!(
558 out, payload,
559 "decoding without dict should not reproduce original on large input"
560 );
561 }
562 }
563
564 #[test]
565 fn test_stream_encoder_buffered_bytes() {
566 let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
567 assert_eq!(encoder.buffered_bytes(), 0);
568 encoder.write_all(b"12345").expect("write failed");
569 assert_eq!(encoder.buffered_bytes(), 5);
570 encoder.write_all(b"67890").expect("write failed");
571 assert_eq!(encoder.buffered_bytes(), 10);
572 }
573
574 #[test]
575 fn test_stream_encoder_is_finished() {
576 let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
577 assert!(!encoder.is_finished());
578 encoder.write_all(b"data").expect("write failed");
579 assert!(!encoder.is_finished());
580 }
582
583 #[test]
584 fn test_stream_decoder_is_finished() {
585 let original = b"short";
586
587 let mut enc = ZstdStreamEncoder::new(Vec::new(), 1);
588 enc.write_all(original).expect("write failed");
589 let compressed = enc.finish().expect("finish failed");
590
591 let mut decoder = ZstdStreamDecoder::new(&compressed[..]);
592 assert!(!decoder.is_finished());
593
594 let mut out = Vec::new();
595 decoder.read_to_end(&mut out).expect("read failed");
596 assert!(decoder.is_finished());
597 }
598
599 #[test]
600 fn test_stream_roundtrip_large_data() {
601 let original: Vec<u8> = (0..10_000).map(|i| (i % 256) as u8).collect();
602
603 let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1);
604 encoder.write_all(&original).expect("write failed");
605 let compressed = encoder.finish().expect("finish failed");
606
607 let mut decoder = ZstdStreamDecoder::new(&compressed[..]);
608 let mut output = Vec::new();
609 decoder.read_to_end(&mut output).expect("read failed");
610
611 assert_eq!(output, original);
612 }
613
614 use oxiarc_core::cancel::CancellationToken;
615 use oxiarc_core::progress::ProgressSink;
616 use std::sync::{Arc, Mutex};
617
618 type ProgressLog = Arc<Mutex<Vec<(u64, Option<u64>)>>>;
619
620 struct MockSink(ProgressLog);
621
622 impl ProgressSink for MockSink {
623 fn on_progress(&self, processed: u64, total: Option<u64>) {
624 self.0
625 .lock()
626 .expect("lock poisoned")
627 .push((processed, total));
628 }
629 }
630
631 fn make_compressible_data(size: usize) -> Vec<u8> {
632 let pattern = b"ZstdStream test data with repeating pattern ABCDEFGH ";
633 let mut data = Vec::with_capacity(size);
634 while data.len() < size {
635 let remaining = size - data.len();
636 let chunk = &pattern[..remaining.min(pattern.len())];
637 data.extend_from_slice(chunk);
638 }
639 data
640 }
641
642 #[test]
643 fn test_zstd_stream_encoder_progress_reports() {
644 let data = make_compressible_data(1024 * 1024); let calls: ProgressLog = Arc::new(Mutex::new(Vec::new()));
647 let sink = Arc::new(MockSink(calls.clone()));
648
649 let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1)
650 .with_progress(sink as oxiarc_core::progress::ProgressHandle);
651 encoder.write_all(&data).expect("write_all failed");
652 encoder.finish().expect("finish failed");
653
654 let recorded = calls.lock().expect("lock poisoned");
655 assert!(!recorded.is_empty(), "expected at least one progress call");
656 let (last_processed, _) = *recorded.last().expect("non-empty");
657 assert_eq!(
658 last_processed,
659 data.len() as u64,
660 "final processed count must equal input size"
661 );
662 }
663
664 #[test]
665 fn test_zstd_stream_encoder_cancel_aborts() {
666 let data = make_compressible_data(1024 * 1024);
667 let token = CancellationToken::new();
668
669 let mut encoder = ZstdStreamEncoder::new(Vec::new(), 1)
671 .with_block_size(4096)
672 .with_cancel(token.clone());
673
674 token.cancel();
675 let result = encoder.write_all(&data);
676 assert!(result.is_err(), "expected cancellation error");
677 }
678
679 #[test]
680 fn test_zstd_stream_decoder_progress_reports() {
681 let data = make_compressible_data(1024 * 1024); let mut enc = ZstdStreamEncoder::new(Vec::new(), 1);
684 enc.write_all(&data).expect("write_all failed");
685 let compressed = enc.finish().expect("finish failed");
686
687 let calls: ProgressLog = Arc::new(Mutex::new(Vec::new()));
688 let sink = Arc::new(MockSink(calls.clone()));
689
690 let mut decoder = ZstdStreamDecoder::new(&compressed[..])
691 .with_progress(sink as oxiarc_core::progress::ProgressHandle);
692 let mut output = Vec::new();
693 decoder
694 .read_to_end(&mut output)
695 .expect("read_to_end failed");
696
697 let recorded = calls.lock().expect("lock poisoned");
698 assert!(!recorded.is_empty(), "expected at least one progress call");
699 let (last_processed, _) = *recorded.last().expect("non-empty");
700 assert_eq!(
701 last_processed,
702 data.len() as u64,
703 "final processed count must equal decompressed size"
704 );
705 }
706
707 #[test]
708 fn test_zstd_stream_decoder_cancel_aborts() {
709 let data = make_compressible_data(1024 * 1024);
710 let mut enc = ZstdStreamEncoder::new(Vec::new(), 1);
711 enc.write_all(&data).expect("write_all failed");
712 let compressed = enc.finish().expect("finish failed");
713
714 let token = CancellationToken::new();
715 let mut decoder = ZstdStreamDecoder::new(&compressed[..]).with_cancel(token.clone());
716 let mut output = Vec::new();
717
718 token.cancel();
719 let result = decoder.read_to_end(&mut output);
720 assert!(result.is_err(), "expected cancellation error");
721 }
722}