Skip to main content

oxiarc_zstd/
encode.rs

1//! Zstandard encoder (frame construction).
2//!
3//! This module provides Zstandard compression with multiple strategies:
4//! - **Level 0**: Raw/RLE blocks only (no LZ77 compression)
5//! - **Levels 1-22**: Full LZ77 + Huffman + FSE compressed blocks
6//!
7//! Creates valid Zstd frames compatible with any decoder.
8
9use crate::compressed_block::encode_compressed_block;
10use crate::lz77::{LevelConfig, MatchFinder};
11use crate::xxhash::xxhash64_checksum;
12use crate::{MAX_BLOCK_SIZE, ZSTD_MAGIC};
13use oxiarc_core::cancel::CancellationToken;
14use oxiarc_core::error::Result;
15use oxiarc_core::progress::ProgressHandle;
16
17#[cfg(feature = "parallel")]
18use rayon::prelude::*;
19
/// How blocks are emitted on the level-0 path (no LZ77 matching).
///
/// [`CompressionStrategy::RleOnly`] is the default.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum CompressionStrategy {
    /// Store every block verbatim as a raw block.
    Raw,
    /// Store blocks made of a single repeated byte as RLE blocks; store all
    /// other blocks as raw blocks.
    #[default]
    RleOnly,
}
29
30/// Zstandard encoder.
31///
32/// Supports multiple compression levels (0-22) with LZ77 matching,
33/// Huffman literal encoding, and FSE sequence encoding.
34///
35/// Supports optional progress reporting via [`ProgressHandle`] and
36/// cooperative cancellation via [`CancellationToken`] using the
37/// [`ZstdEncoder::with_progress`] / [`ZstdEncoder::with_cancel`] builders.
38#[derive(Clone)]
39pub struct ZstdEncoder {
40    /// Include content checksum in output.
41    include_checksum: bool,
42    /// Include content size in header.
43    include_content_size: bool,
44    /// Compression strategy (used when level == 0).
45    strategy: CompressionStrategy,
46    /// Compression level (0 = raw/RLE, 1-22 = LZ77 compression).
47    level: i32,
48    /// Optional dictionary for improved compression of small data.
49    dictionary: Option<Vec<u8>>,
50    /// Dictionary ID (XXH64 of dictionary data, lower 32 bits).
51    dict_id: Option<u32>,
52    /// Optional progress sink. Notified after each block is written.
53    progress: Option<ProgressHandle>,
54    /// Optional cancellation token. Checked before each block.
55    cancel: Option<CancellationToken>,
56}
57
58impl std::fmt::Debug for ZstdEncoder {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        f.debug_struct("ZstdEncoder")
61            .field("level", &self.level)
62            .field("include_checksum", &self.include_checksum)
63            .field("include_content_size", &self.include_content_size)
64            .finish()
65    }
66}
67
68impl ZstdEncoder {
69    /// Create a new encoder with default settings (level 0, RLE strategy).
70    pub fn new() -> Self {
71        Self {
72            include_checksum: true,
73            include_content_size: true,
74            strategy: CompressionStrategy::default(),
75            level: 0,
76            dictionary: None,
77            dict_id: None,
78            progress: None,
79            cancel: None,
80        }
81    }
82
83    /// Attach a progress sink.
84    ///
85    /// The sink's `on_progress(bytes_processed, None)` is called after each
86    /// block is written to the output. `on_finish()` is called after the
87    /// content checksum is written.
88    pub fn with_progress(mut self, handle: ProgressHandle) -> Self {
89        self.progress = Some(handle);
90        self
91    }
92
93    /// Attach a cancellation token.
94    ///
95    /// The token is checked before each block is encoded.
96    /// If cancelled, returns [`oxiarc_core::error::OxiArcError::Cancelled`].
97    pub fn with_cancel(mut self, token: CancellationToken) -> Self {
98        self.cancel = Some(token);
99        self
100    }
101
102    /// Set whether to include content checksum.
103    pub fn set_checksum(&mut self, include: bool) -> &mut Self {
104        self.include_checksum = include;
105        self
106    }
107
108    /// Set whether to include content size in header.
109    pub fn set_content_size(&mut self, include: bool) -> &mut Self {
110        self.include_content_size = include;
111        self
112    }
113
114    /// Set compression strategy (only effective when level == 0).
115    pub fn set_strategy(&mut self, strategy: CompressionStrategy) -> &mut Self {
116        self.strategy = strategy;
117        self
118    }
119
120    /// Set compression level (0-22).
121    ///
122    /// - Level 0: Raw/RLE blocks (fastest, no compression)
123    /// - Levels 1-3: Fast compression (greedy matching)
124    /// - Levels 4-9: Balanced compression (lazy matching)
125    /// - Levels 10-22: High compression (deep search)
126    pub fn set_level(&mut self, level: i32) -> &mut Self {
127        self.level = level.clamp(0, 22);
128        self
129    }
130
131    /// Set a pre-trained dictionary for improved compression of small data.
132    pub fn set_dictionary(&mut self, dict: &[u8]) -> &mut Self {
133        if dict.is_empty() {
134            self.dictionary = None;
135            self.dict_id = None;
136        } else {
137            let id = crate::xxhash::xxhash64(dict) as u32;
138            self.dictionary = Some(dict.to_vec());
139            self.dict_id = Some(id);
140        }
141        self
142    }
143
144    /// Compress data into a Zstandard frame.
145    ///
146    /// Uses the configured compression level and strategy.
147    pub fn compress(&self, data: &[u8]) -> Result<Vec<u8>> {
148        // Cancellation check at the start of the full operation.
149        if let Some(ref token) = self.cancel {
150            token.check()?;
151        }
152
153        let mut output = Vec::with_capacity(data.len() + 32);
154
155        // Write magic number
156        output.extend_from_slice(&ZSTD_MAGIC);
157
158        // Write frame header
159        self.write_frame_header(&mut output, data.len());
160
161        // Write blocks with compression
162        if self.level > 0 {
163            self.write_compressed_blocks(&mut output, data)?;
164        } else {
165            self.write_blocks(&mut output, data)?;
166        }
167
168        // Write content checksum if enabled
169        if self.include_checksum {
170            let checksum = xxhash64_checksum(data);
171            output.extend_from_slice(&checksum.to_le_bytes());
172        }
173
174        if let Some(ref handle) = self.progress {
175            handle.on_finish();
176        }
177
178        Ok(output)
179    }
180
181    /// Compress data into a Zstandard frame using parallel block compression
182    /// (requires `parallel` feature).
183    #[cfg(feature = "parallel")]
184    pub fn compress_parallel(&self, data: &[u8]) -> Result<Vec<u8>> {
185        let mut output = Vec::with_capacity(data.len() + 32);
186
187        // Write magic number
188        output.extend_from_slice(&ZSTD_MAGIC);
189
190        // Write frame header
191        self.write_frame_header(&mut output, data.len());
192
193        // Split data into blocks
194        if data.is_empty() {
195            write_empty_block(&mut output);
196        } else {
197            let chunks: Vec<&[u8]> = data.chunks(MAX_BLOCK_SIZE).collect();
198
199            // Process blocks in parallel
200            let block_data: Vec<(bool, Vec<u8>)> = chunks
201                .par_iter()
202                .enumerate()
203                .map(|(idx, chunk)| {
204                    let is_last = idx == chunks.len() - 1;
205
206                    // Try RLE encoding if strategy allows
207                    if self.strategy == CompressionStrategy::RleOnly {
208                        if let Some(rle_byte) = detect_rle(chunk) {
209                            let mut block_output = Vec::new();
210                            write_rle_block_to(&mut block_output, rle_byte, chunk.len(), is_last);
211                            return (is_last, block_output);
212                        }
213                    }
214
215                    // Fall back to raw block
216                    let mut block_output = Vec::new();
217                    write_raw_block_to(&mut block_output, chunk, is_last);
218                    (is_last, block_output)
219                })
220                .collect();
221
222            // Assemble blocks sequentially
223            for (_is_last, block_bytes) in block_data {
224                output.extend_from_slice(&block_bytes);
225            }
226        }
227
228        // Write content checksum if enabled
229        if self.include_checksum {
230            let checksum = xxhash64_checksum(data);
231            output.extend_from_slice(&checksum.to_le_bytes());
232        }
233
234        Ok(output)
235    }
236
237    /// Write frame header descriptor.
238    fn write_frame_header(&self, output: &mut Vec<u8>, content_size: usize) {
239        let mut descriptor: u8 = 0;
240
241        if self.include_checksum {
242            descriptor |= 0x04; // Content_Checksum_flag
243        }
244
245        // Single_Segment_flag = 1 (no window descriptor needed)
246        descriptor |= 0x20;
247
248        // Dictionary ID flag
249        let dict_id_flag = if self.dict_id.is_some() { 3u8 } else { 0u8 };
250        descriptor |= dict_id_flag;
251
252        // Determine content size encoding
253        let (fcs_flag, fcs_bytes) = if !self.include_content_size || content_size <= 255 {
254            (0u8, 1)
255        } else if content_size <= 65535 + 256 {
256            (1u8, 2)
257        } else if content_size <= u32::MAX as usize {
258            (2u8, 4)
259        } else {
260            (3u8, 8)
261        };
262
263        descriptor |= fcs_flag << 6;
264        output.push(descriptor);
265
266        // Write Dictionary_ID (4 bytes if present)
267        if let Some(id) = self.dict_id {
268            output.extend_from_slice(&id.to_le_bytes());
269        }
270
271        // Write Frame_Content_Size (required for single segment)
272        match fcs_bytes {
273            1 => {
274                output.push(content_size as u8);
275            }
276            2 => {
277                let adjusted = (content_size - 256) as u16;
278                output.extend_from_slice(&adjusted.to_le_bytes());
279            }
280            4 => {
281                output.extend_from_slice(&(content_size as u32).to_le_bytes());
282            }
283            8 => {
284                output.extend_from_slice(&(content_size as u64).to_le_bytes());
285            }
286            _ => unreachable!(),
287        }
288    }
289
290    /// Write data as raw/RLE blocks (level 0).
291    fn write_blocks(&self, output: &mut Vec<u8>, data: &[u8]) -> Result<()> {
292        if data.is_empty() {
293            write_empty_block(output);
294            return Ok(());
295        }
296
297        let mut offset = 0;
298        let mut bytes_processed: u64 = 0;
299
300        while offset < data.len() {
301            // Cooperative cancellation check before each block.
302            if let Some(ref token) = self.cancel {
303                token.check()?;
304            }
305
306            let remaining = data.len() - offset;
307            let block_size = remaining.min(MAX_BLOCK_SIZE);
308            let is_last = offset + block_size >= data.len();
309            let block_data = &data[offset..offset + block_size];
310
311            // Try RLE encoding if strategy allows
312            if self.strategy == CompressionStrategy::RleOnly {
313                if let Some(rle_byte) = detect_rle(block_data) {
314                    write_rle_block_to(output, rle_byte, block_size, is_last);
315                    offset += block_size;
316                    bytes_processed += block_size as u64;
317                    if let Some(ref handle) = self.progress {
318                        handle.on_progress(bytes_processed, None);
319                    }
320                    continue;
321                }
322            }
323
324            // Fall back to raw block
325            write_raw_block_to(output, block_data, is_last);
326            offset += block_size;
327            bytes_processed += block_size as u64;
328            if let Some(ref handle) = self.progress {
329                handle.on_progress(bytes_processed, None);
330            }
331        }
332
333        Ok(())
334    }
335
336    /// Write data as compressed blocks using LZ77 (levels 1-22).
337    fn write_compressed_blocks(&self, output: &mut Vec<u8>, data: &[u8]) -> Result<()> {
338        if data.is_empty() {
339            write_empty_block(output);
340            return Ok(());
341        }
342
343        let config = LevelConfig::for_level(self.level);
344        let mut finder = MatchFinder::new(&config);
345        let dict = self.dictionary.as_deref().unwrap_or(&[]);
346
347        let mut offset = 0;
348        let mut bytes_processed: u64 = 0;
349
350        while offset < data.len() {
351            // Cooperative cancellation check before each block.
352            if let Some(ref token) = self.cancel {
353                token.check()?;
354            }
355
356            let remaining = data.len() - offset;
357            let block_size = remaining.min(config.target_block_size);
358            let is_last = offset + block_size >= data.len();
359            let block_data = &data[offset..offset + block_size];
360
361            // Try RLE first (always efficient for homogeneous data)
362            if let Some(rle_byte) = detect_rle(block_data) {
363                write_rle_block_to(output, rle_byte, block_size, is_last);
364                offset += block_size;
365                bytes_processed += block_size as u64;
366                if let Some(ref handle) = self.progress {
367                    handle.on_progress(bytes_processed, None);
368                }
369                continue;
370            }
371
372            // Find LZ77 matches
373            let sequences = finder.find_sequences(block_data, dict)?;
374
375            // Try to encode as a compressed block
376            match encode_compressed_block(&sequences) {
377                Ok(compressed_content) => {
378                    // Only use compressed block if it's actually smaller
379                    if compressed_content.len() < block_data.len() {
380                        write_compressed_block_to(output, &compressed_content, is_last);
381                    } else {
382                        // Compressed is larger, use raw block
383                        write_raw_block_to(output, block_data, is_last);
384                    }
385                }
386                Err(_) => {
387                    // Compression failed, fall back to raw block
388                    write_raw_block_to(output, block_data, is_last);
389                }
390            }
391
392            finder.reset();
393            offset += block_size;
394            bytes_processed += block_size as u64;
395            if let Some(ref handle) = self.progress {
396                handle.on_progress(bytes_processed, None);
397            }
398        }
399
400        Ok(())
401    }
402}
403
404impl Default for ZstdEncoder {
405    fn default() -> Self {
406        Self::new()
407    }
408}
409
410// --- Block writing helpers ---
411
/// Append a zero-length raw block flagged as the last block of the frame.
///
/// The 3-byte little-endian header packs Last_Block = 1 (bit 0),
/// Block_Type = Raw (bits 1-2 = 0) and Block_Size = 0 (bits 3-23).
fn write_empty_block(output: &mut Vec<u8>) {
    const EMPTY_LAST_RAW_BLOCK: [u8; 3] = [0x01, 0x00, 0x00];
    output.extend_from_slice(&EMPTY_LAST_RAW_BLOCK);
}
419
/// Append a raw (stored) block: a 3-byte header followed by `data` verbatim.
///
/// Header bits: 0 = Last_Block, 1-2 = Block_Type (0 = Raw),
/// 3-23 = Block_Size (`data.len()`), stored little-endian.
fn write_raw_block_to(output: &mut Vec<u8>, data: &[u8], is_last: bool) {
    let header = u32::from(is_last) | ((data.len() as u32) << 3);
    output.extend_from_slice(&header.to_le_bytes()[..3]);
    output.extend_from_slice(data);
}
429
/// Append an RLE block that expands to `size` copies of `byte`.
///
/// The 3-byte little-endian header has Block_Type = RLE (1) in bits 1-2 and
/// the regenerated size in bits 3-23; the single repeated byte follows.
fn write_rle_block_to(output: &mut Vec<u8>, byte: u8, size: usize, is_last: bool) {
    const RLE_TYPE_BITS: u32 = 1 << 1;
    let header = u32::from(is_last) | RLE_TYPE_BITS | ((size as u32) << 3);
    output.extend_from_slice(&header.to_le_bytes()[..3]);
    output.push(byte);
}
440
/// Append a compressed block whose already-encoded body is `content`.
///
/// The 3-byte little-endian header has Block_Type = Compressed (2) in bits
/// 1-2 and `content.len()` in bits 3-23, followed by `content` verbatim.
fn write_compressed_block_to(output: &mut Vec<u8>, content: &[u8], is_last: bool) {
    const COMPRESSED_TYPE_BITS: u32 = 2 << 1;
    let header = u32::from(is_last) | COMPRESSED_TYPE_BITS | ((content.len() as u32) << 3);
    output.extend_from_slice(&header.to_le_bytes()[..3]);
    output.extend_from_slice(content);
}
451
/// Return the repeated byte if `data` is non-empty and every byte equals the
/// first one (i.e. the block can be stored as an RLE block), else `None`.
fn detect_rle(data: &[u8]) -> Option<u8> {
    let (&first, rest) = data.split_first()?;
    rest.iter().all(|&byte| byte == first).then_some(first)
}
465
466// --- Convenience functions ---
467
468/// Compress data using default settings (raw/RLE blocks, level 0).
469///
470/// For actual LZ77 compression, use [`compress_with_level`] or configure
471/// [`ZstdEncoder`] with [`set_level`](ZstdEncoder::set_level).
472pub fn compress(data: &[u8]) -> Result<Vec<u8>> {
473    ZstdEncoder::new().compress(data)
474}
475
476/// Compress data with a specific compression level (1-22).
477///
478/// This is the primary compression function for most use cases.
479///
480/// # Arguments
481/// * `data` - Data to compress
482/// * `level` - Compression level (1 = fastest, 22 = best compression)
483pub fn compress_with_level(data: &[u8], level: i32) -> Result<Vec<u8>> {
484    let mut encoder = ZstdEncoder::new();
485    encoder.set_level(level);
486    encoder.compress(data)
487}
488
489/// Compress data without checksum.
490pub fn compress_no_checksum(data: &[u8]) -> Result<Vec<u8>> {
491    let mut encoder = ZstdEncoder::new();
492    encoder.set_checksum(false);
493    encoder.compress(data)
494}
495
/// Compress `data` with a default-configured encoder (level 0), encoding
/// blocks in parallel (requires the `parallel` feature).
#[cfg(feature = "parallel")]
pub fn compress_parallel(data: &[u8]) -> Result<Vec<u8>> {
    let encoder = ZstdEncoder::default();
    encoder.compress_parallel(data)
}
501
502/// Convenience function: compress data and return bytes (compatible with
503/// `zstd::encode_all` pattern).
504///
505/// # Arguments
506/// * `data` - Data to compress (implements `AsRef<[u8]>`)
507/// * `level` - Compression level (1-22)
508pub fn encode_all(data: &[u8], level: i32) -> Result<Vec<u8>> {
509    compress_with_level(data, level)
510}
511
512/// Convenience function: decompress data (compatible with `zstd::decode_all` pattern).
513pub fn decode_all(data: &[u8]) -> Result<Vec<u8>> {
514    crate::decompress(data)
515}
516
#[cfg(test)]
mod tests {
    use super::*;
    use crate::decompress;
    use crate::frame::decompress_with_dict;
    use oxiarc_core::cancel::CancellationToken;
    use oxiarc_core::progress::ProgressSink;
    use std::sync::{Arc, Mutex};

    /// Shared log of `(processed, total)` pairs recorded by [`MockSink`].
    type ProgressLog = Arc<Mutex<Vec<(u64, Option<u64>)>>>;

    /// Progress sink that records every `on_progress` call.
    struct MockSink(ProgressLog);

    impl ProgressSink for MockSink {
        fn on_progress(&self, processed: u64, total: Option<u64>) {
            self.0
                .lock()
                .expect("lock poisoned")
                .push((processed, total));
        }
    }

    /// Decompress a frame, panicking with a message that names the
    /// decompression stage (not the compression stage) on failure.
    fn decompress_ok(compressed: &[u8]) -> Vec<u8> {
        decompress(compressed).expect("decompression failed")
    }

    /// Build `size` bytes by repeating a short text pattern.
    fn make_compressible_data(size: usize) -> Vec<u8> {
        let pattern = b"ZstdEncoder test data with repeating pattern ABCDEFGH ";
        let mut data = Vec::with_capacity(size);
        while data.len() < size {
            let remaining = size - data.len();
            let chunk = &pattern[..remaining.min(pattern.len())];
            data.extend_from_slice(chunk);
        }
        data
    }

    #[test]
    fn test_encoder_decoder_dict_roundtrip() {
        let dict = b"prefix-shared-corpus-prefix-shared-corpus";
        let data = b"prefix-shared-corpus is here, and prefix-shared-corpus is here again";

        let mut enc = ZstdEncoder::new();
        enc.set_level(3);
        enc.set_dictionary(dict);
        let compressed = enc.compress(data).expect("compress");

        let out = decompress_with_dict(&compressed, dict).expect("decompress with dict");
        assert_eq!(out.as_slice(), data as &[u8]);
    }

    #[test]
    fn test_compress_empty() {
        let data: &[u8] = &[];
        let compressed = compress(data).expect("compression failed");
        assert_eq!(&compressed[0..4], &ZSTD_MAGIC);
        assert_eq!(decompress_ok(&compressed), data);
    }

    #[test]
    fn test_compress_small() {
        let data = b"Hello, Zstandard!";
        let compressed = compress(data).expect("compression failed");
        assert_eq!(decompress_ok(&compressed), data.as_slice());
    }

    #[test]
    fn test_compress_larger() {
        let data = vec![0x42u8; 1000];
        let compressed = compress(&data).expect("compression failed");
        assert_eq!(decompress_ok(&compressed), data);
    }

    #[test]
    fn test_compress_multi_block() {
        let data = vec![0xABu8; MAX_BLOCK_SIZE + 1000];
        let compressed = compress(&data).expect("compression failed");
        assert_eq!(decompress_ok(&compressed), data);
    }

    #[test]
    fn test_compress_no_checksum() {
        let data = b"Test without checksum";
        let compressed = compress_no_checksum(data).expect("compression failed");
        assert_eq!(decompress_ok(&compressed), data.as_slice());
    }

    #[test]
    fn test_encoder_builder() {
        let data = b"Builder pattern test";
        let mut encoder = ZstdEncoder::new();
        encoder.set_checksum(true).set_content_size(true);
        let compressed = encoder.compress(data).expect("compression failed");
        assert_eq!(decompress_ok(&compressed), data.as_slice());
    }

    #[test]
    fn test_set_level_clamps_out_of_range() {
        let data = make_compressible_data(4096);
        let mut encoder = ZstdEncoder::new();

        encoder.set_level(99);
        assert!(format!("{:?}", encoder).contains("level: 22"));
        let high = encoder.compress(&data).expect("compression failed");
        assert_eq!(decompress_ok(&high), data);

        encoder.set_level(-5);
        assert!(format!("{:?}", encoder).contains("level: 0"));
        let low = encoder.compress(&data).expect("compression failed");
        assert_eq!(decompress_ok(&low), data);
    }

    #[test]
    fn test_various_sizes() {
        for size in [0, 1, 10, 100, 255, 256, 257, 1000, 65535, 65536, 100000] {
            let data = vec![0x55u8; size];
            let compressed = compress(&data).expect("compression failed");
            assert_eq!(decompress_ok(&compressed), data, "Failed for size {}", size);
        }
    }

    #[test]
    fn test_rle_compression() {
        let data = vec![0xAAu8; 10000];
        let compressed = compress(&data).expect("compression failed");
        assert!(
            compressed.len() < data.len() / 10,
            "RLE compression failed: {} vs {}",
            compressed.len(),
            data.len()
        );
        assert_eq!(decompress_ok(&compressed), data);
    }

    #[test]
    fn test_rle_multi_block() {
        let data = vec![0xBBu8; MAX_BLOCK_SIZE * 3];
        let compressed = compress(&data).expect("compression failed");
        assert!(
            compressed.len() < 100,
            "Expected small output, got {}",
            compressed.len()
        );
        assert_eq!(decompress_ok(&compressed), data);
    }

    #[test]
    fn test_rle_mixed_data() {
        let mut data = vec![0xCCu8; 1000];
        data.extend_from_slice(b"Hello, World!");
        data.extend_from_slice(&vec![0xDDu8; 1000]);
        let compressed = compress(&data).expect("compression failed");
        assert_eq!(decompress_ok(&compressed), data);
    }

    #[test]
    fn test_detect_rle() {
        assert_eq!(detect_rle(&[0xAA; 100]), Some(0xAA));
        assert_eq!(detect_rle(&[0x00; 50]), Some(0x00));
        assert_eq!(detect_rle(&[0xFF]), Some(0xFF));
        assert_eq!(detect_rle(&[0xAA, 0xAA, 0xBB]), None);
        assert_eq!(detect_rle(&[0x00, 0x01]), None);
        assert_eq!(detect_rle(&[]), None);
    }

    #[test]
    fn test_raw_strategy() {
        let data = vec![0xEEu8; 1000];
        let mut encoder = ZstdEncoder::new();
        encoder.set_strategy(CompressionStrategy::Raw);
        let compressed = encoder.compress(&data).expect("compression failed");
        assert!(compressed.len() > data.len());
        assert_eq!(decompress_ok(&compressed), data);
    }

    #[test]
    fn test_compress_with_level() {
        // Level-based compression must produce valid, round-trippable output.
        let data = b"The quick brown fox jumps over the lazy dog. \
                     The quick brown fox jumps over the lazy dog. \
                     The quick brown fox jumps over the lazy dog.";

        for level in [1, 3, 6, 9, 15, 22] {
            let compressed = compress_with_level(data, level).expect("compression failed");
            assert_eq!(
                decompress_ok(&compressed),
                data.as_slice(),
                "Roundtrip failed for level {}",
                level
            );
        }
    }

    #[test]
    fn test_encode_all_decode_all() {
        let data = b"Testing encode_all and decode_all convenience functions";
        let compressed = encode_all(data, 3).expect("compression failed");
        let decompressed = decode_all(&compressed).expect("decompression failed");
        assert_eq!(decompressed, data.as_slice());
    }

    #[test]
    fn test_level_compression_ratio() {
        // Repetitive data should compress with LZ77.
        let mut data = Vec::new();
        for _ in 0..100 {
            data.extend_from_slice(b"ABCDEFGHIJKLMNOP");
        }

        let raw = compress(&data).expect("compression failed");
        let level3 = compress_with_level(&data, 3).expect("compression failed");

        assert!(
            level3.len() <= raw.len(),
            "Level 3 ({}) should be <= raw ({}) for repetitive data",
            level3.len(),
            raw.len()
        );

        assert_eq!(decompress_ok(&raw), data);
        assert_eq!(decompress_ok(&level3), data);
    }

    #[test]
    fn test_large_data_roundtrip() {
        // Compressible data similar to what network compression tests use.
        let mut data = Vec::with_capacity(16384);
        let pattern = b"RDF triple: <http://example.org/subject> <http://example.org/predicate> \"value\" .\n";
        while data.len() < 16384 {
            data.extend_from_slice(pattern);
        }
        data.truncate(16384);

        for level in [1, 3] {
            let compressed = encode_all(&data, level).expect("compression failed");
            let decompressed = decode_all(&compressed).expect("decompression failed");
            assert_eq!(
                decompressed, data,
                "Large roundtrip failed for level {}",
                level
            );
        }
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_roundtrip_basic() {
        let data = b"Hello, World! Parallel Zstandard compression.";
        let compressed = compress_parallel(data).expect("compression failed");
        assert_eq!(decompress_ok(&compressed), data.as_slice());
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_roundtrip_large() {
        let data = vec![0xABu8; 5_000_000];
        let compressed = compress_parallel(&data).expect("compression failed");
        assert_eq!(decompress_ok(&compressed), data);
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_rle_compression() {
        let data = vec![0xCCu8; 2_000_000];
        let compressed = compress_parallel(&data).expect("compression failed");
        assert!(compressed.len() < data.len() / 100);
        assert_eq!(decompress_ok(&compressed), data);
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_empty() {
        let data: &[u8] = &[];
        let compressed = compress_parallel(data).expect("compression failed");
        assert_eq!(decompress_ok(&compressed), data);
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_vs_serial() {
        let data = b"Testing parallel vs serial compression output.";
        let serial = compress(data).expect("compression failed");
        let parallel = compress_parallel(data).expect("compression failed");
        assert_eq!(decompress_ok(&serial), data.as_slice());
        assert_eq!(decompress_ok(&parallel), data.as_slice());
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_encoder_options() {
        let data = vec![0xFFu8; 1_000_000];
        let mut encoder = ZstdEncoder::new();
        encoder
            .set_checksum(false)
            .set_strategy(CompressionStrategy::RleOnly);
        let compressed = encoder
            .compress_parallel(&data)
            .expect("compression failed");
        assert_eq!(decompress_ok(&compressed), data);
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_multi_block() {
        let data = vec![0x55u8; MAX_BLOCK_SIZE * 3 + 5000];
        let compressed = compress_parallel(&data).expect("compression failed");
        assert_eq!(decompress_ok(&compressed), data);
    }

    #[test]
    fn test_zstd_encoder_progress_reports() {
        let data = make_compressible_data(1024 * 1024); // 1 MB

        let calls: ProgressLog = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::new(MockSink(calls.clone()));

        let encoder =
            ZstdEncoder::new().with_progress(sink as oxiarc_core::progress::ProgressHandle);
        encoder.compress(&data).expect("compress failed");

        let recorded = calls.lock().expect("lock poisoned");
        assert!(!recorded.is_empty(), "expected at least one progress call");
        let (last_processed, _) = *recorded.last().expect("non-empty");
        assert_eq!(
            last_processed,
            data.len() as u64,
            "final processed count must equal input size"
        );
    }

    #[test]
    fn test_zstd_encoder_cancel_aborts() {
        let data = make_compressible_data(1024 * 1024);
        let token = CancellationToken::new();
        let encoder = ZstdEncoder::new().with_cancel(token.clone());

        token.cancel();
        let result = encoder.compress(&data);
        assert!(result.is_err(), "expected cancellation error");
    }
}