// oxiarc_zstd/encode.rs

1//! Zstandard encoder (frame construction).
2//!
3//! This module provides Zstandard compression with multiple strategies:
4//! - **Level 0**: Raw/RLE blocks only (no LZ77 compression)
5//! - **Levels 1-22**: Full LZ77 + Huffman + FSE compressed blocks
6//!
7//! Creates valid Zstd frames compatible with any decoder.
8
9use crate::compressed_block::encode_compressed_block;
10use crate::lz77::{LevelConfig, MatchFinder};
11use crate::xxhash::xxhash64_checksum;
12use crate::{MAX_BLOCK_SIZE, ZSTD_MAGIC};
13use oxiarc_core::error::Result;
14
15#[cfg(feature = "parallel")]
16use rayon::prelude::*;
17
/// Compression strategy for block encoding when `level == 0`.
///
/// Levels 1-22 do not consult this setting: `write_compressed_blocks`
/// always tries RLE and LZ77 compressed blocks, falling back to raw.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CompressionStrategy {
    /// Use raw (stored) blocks only — data is copied through unchanged.
    Raw,
    /// Use RLE blocks for homogeneous (single repeated byte) data, raw otherwise.
    #[default]
    RleOnly,
}
27
/// Zstandard encoder.
///
/// Supports multiple compression levels (0-22) with LZ77 matching,
/// Huffman literal encoding, and FSE sequence encoding.
///
/// Configure via the builder-style `set_*` methods, then call
/// [`ZstdEncoder::compress`] (or `compress_parallel` with the
/// `parallel` feature).
#[derive(Debug, Clone)]
pub struct ZstdEncoder {
    /// Append an XXH64 content checksum after the last block.
    include_checksum: bool,
    /// Request content size in the header. NOTE(review): the frames emitted
    /// here always write a Frame_Content_Size field regardless — see
    /// `write_frame_header`.
    include_content_size: bool,
    /// Block strategy; only consulted when `level == 0`.
    strategy: CompressionStrategy,
    /// Compression level (0 = raw/RLE, 1-22 = LZ77 compression).
    level: i32,
    /// Optional dictionary for improved compression of small data.
    dictionary: Option<Vec<u8>>,
    /// Dictionary ID (XXH64 of dictionary data, lower 32 bits).
    dict_id: Option<u32>,
}
47
48impl ZstdEncoder {
49    /// Create a new encoder with default settings (level 0, RLE strategy).
50    pub fn new() -> Self {
51        Self {
52            include_checksum: true,
53            include_content_size: true,
54            strategy: CompressionStrategy::default(),
55            level: 0,
56            dictionary: None,
57            dict_id: None,
58        }
59    }
60
61    /// Set whether to include content checksum.
62    pub fn set_checksum(&mut self, include: bool) -> &mut Self {
63        self.include_checksum = include;
64        self
65    }
66
67    /// Set whether to include content size in header.
68    pub fn set_content_size(&mut self, include: bool) -> &mut Self {
69        self.include_content_size = include;
70        self
71    }
72
73    /// Set compression strategy (only effective when level == 0).
74    pub fn set_strategy(&mut self, strategy: CompressionStrategy) -> &mut Self {
75        self.strategy = strategy;
76        self
77    }
78
79    /// Set compression level (0-22).
80    ///
81    /// - Level 0: Raw/RLE blocks (fastest, no compression)
82    /// - Levels 1-3: Fast compression (greedy matching)
83    /// - Levels 4-9: Balanced compression (lazy matching)
84    /// - Levels 10-22: High compression (deep search)
85    pub fn set_level(&mut self, level: i32) -> &mut Self {
86        self.level = level.clamp(0, 22);
87        self
88    }
89
90    /// Set a pre-trained dictionary for improved compression of small data.
91    pub fn set_dictionary(&mut self, dict: &[u8]) -> &mut Self {
92        if dict.is_empty() {
93            self.dictionary = None;
94            self.dict_id = None;
95        } else {
96            let id = crate::xxhash::xxhash64(dict) as u32;
97            self.dictionary = Some(dict.to_vec());
98            self.dict_id = Some(id);
99        }
100        self
101    }
102
103    /// Compress data into a Zstandard frame.
104    ///
105    /// Uses the configured compression level and strategy.
106    pub fn compress(&self, data: &[u8]) -> Result<Vec<u8>> {
107        let mut output = Vec::with_capacity(data.len() + 32);
108
109        // Write magic number
110        output.extend_from_slice(&ZSTD_MAGIC);
111
112        // Write frame header
113        self.write_frame_header(&mut output, data.len());
114
115        // Write blocks with compression
116        if self.level > 0 {
117            self.write_compressed_blocks(&mut output, data)?;
118        } else {
119            self.write_blocks(&mut output, data);
120        }
121
122        // Write content checksum if enabled
123        if self.include_checksum {
124            let checksum = xxhash64_checksum(data);
125            output.extend_from_slice(&checksum.to_le_bytes());
126        }
127
128        Ok(output)
129    }
130
131    /// Compress data into a Zstandard frame using parallel block compression
132    /// (requires `parallel` feature).
133    #[cfg(feature = "parallel")]
134    pub fn compress_parallel(&self, data: &[u8]) -> Result<Vec<u8>> {
135        let mut output = Vec::with_capacity(data.len() + 32);
136
137        // Write magic number
138        output.extend_from_slice(&ZSTD_MAGIC);
139
140        // Write frame header
141        self.write_frame_header(&mut output, data.len());
142
143        // Split data into blocks
144        if data.is_empty() {
145            write_empty_block(&mut output);
146        } else {
147            let chunks: Vec<&[u8]> = data.chunks(MAX_BLOCK_SIZE).collect();
148
149            // Process blocks in parallel
150            let block_data: Vec<(bool, Vec<u8>)> = chunks
151                .par_iter()
152                .enumerate()
153                .map(|(idx, chunk)| {
154                    let is_last = idx == chunks.len() - 1;
155
156                    // Try RLE encoding if strategy allows
157                    if self.strategy == CompressionStrategy::RleOnly {
158                        if let Some(rle_byte) = detect_rle(chunk) {
159                            let mut block_output = Vec::new();
160                            write_rle_block_to(&mut block_output, rle_byte, chunk.len(), is_last);
161                            return (is_last, block_output);
162                        }
163                    }
164
165                    // Fall back to raw block
166                    let mut block_output = Vec::new();
167                    write_raw_block_to(&mut block_output, chunk, is_last);
168                    (is_last, block_output)
169                })
170                .collect();
171
172            // Assemble blocks sequentially
173            for (_is_last, block_bytes) in block_data {
174                output.extend_from_slice(&block_bytes);
175            }
176        }
177
178        // Write content checksum if enabled
179        if self.include_checksum {
180            let checksum = xxhash64_checksum(data);
181            output.extend_from_slice(&checksum.to_le_bytes());
182        }
183
184        Ok(output)
185    }
186
187    /// Write frame header descriptor.
188    fn write_frame_header(&self, output: &mut Vec<u8>, content_size: usize) {
189        let mut descriptor: u8 = 0;
190
191        if self.include_checksum {
192            descriptor |= 0x04; // Content_Checksum_flag
193        }
194
195        // Single_Segment_flag = 1 (no window descriptor needed)
196        descriptor |= 0x20;
197
198        // Dictionary ID flag
199        let dict_id_flag = if self.dict_id.is_some() { 3u8 } else { 0u8 };
200        descriptor |= dict_id_flag;
201
202        // Determine content size encoding
203        let (fcs_flag, fcs_bytes) = if !self.include_content_size || content_size <= 255 {
204            (0u8, 1)
205        } else if content_size <= 65535 + 256 {
206            (1u8, 2)
207        } else if content_size <= u32::MAX as usize {
208            (2u8, 4)
209        } else {
210            (3u8, 8)
211        };
212
213        descriptor |= fcs_flag << 6;
214        output.push(descriptor);
215
216        // Write Dictionary_ID (4 bytes if present)
217        if let Some(id) = self.dict_id {
218            output.extend_from_slice(&id.to_le_bytes());
219        }
220
221        // Write Frame_Content_Size (required for single segment)
222        match fcs_bytes {
223            1 => {
224                output.push(content_size as u8);
225            }
226            2 => {
227                let adjusted = (content_size - 256) as u16;
228                output.extend_from_slice(&adjusted.to_le_bytes());
229            }
230            4 => {
231                output.extend_from_slice(&(content_size as u32).to_le_bytes());
232            }
233            8 => {
234                output.extend_from_slice(&(content_size as u64).to_le_bytes());
235            }
236            _ => unreachable!(),
237        }
238    }
239
240    /// Write data as raw/RLE blocks (level 0).
241    fn write_blocks(&self, output: &mut Vec<u8>, data: &[u8]) {
242        if data.is_empty() {
243            write_empty_block(output);
244            return;
245        }
246
247        let mut offset = 0;
248        while offset < data.len() {
249            let remaining = data.len() - offset;
250            let block_size = remaining.min(MAX_BLOCK_SIZE);
251            let is_last = offset + block_size >= data.len();
252            let block_data = &data[offset..offset + block_size];
253
254            // Try RLE encoding if strategy allows
255            if self.strategy == CompressionStrategy::RleOnly {
256                if let Some(rle_byte) = detect_rle(block_data) {
257                    write_rle_block_to(output, rle_byte, block_size, is_last);
258                    offset += block_size;
259                    continue;
260                }
261            }
262
263            // Fall back to raw block
264            write_raw_block_to(output, block_data, is_last);
265            offset += block_size;
266        }
267    }
268
269    /// Write data as compressed blocks using LZ77 (levels 1-22).
270    fn write_compressed_blocks(&self, output: &mut Vec<u8>, data: &[u8]) -> Result<()> {
271        if data.is_empty() {
272            write_empty_block(output);
273            return Ok(());
274        }
275
276        let config = LevelConfig::for_level(self.level);
277        let mut finder = MatchFinder::new(&config);
278        let dict = self.dictionary.as_deref().unwrap_or(&[]);
279
280        let mut offset = 0;
281        while offset < data.len() {
282            let remaining = data.len() - offset;
283            let block_size = remaining.min(config.target_block_size);
284            let is_last = offset + block_size >= data.len();
285            let block_data = &data[offset..offset + block_size];
286
287            // Try RLE first (always efficient for homogeneous data)
288            if let Some(rle_byte) = detect_rle(block_data) {
289                write_rle_block_to(output, rle_byte, block_size, is_last);
290                offset += block_size;
291                continue;
292            }
293
294            // Find LZ77 matches
295            let sequences = finder.find_sequences(block_data, dict)?;
296
297            // Try to encode as a compressed block
298            match encode_compressed_block(&sequences) {
299                Ok(compressed_content) => {
300                    // Only use compressed block if it's actually smaller
301                    if compressed_content.len() < block_data.len() {
302                        write_compressed_block_to(output, &compressed_content, is_last);
303                    } else {
304                        // Compressed is larger, use raw block
305                        write_raw_block_to(output, block_data, is_last);
306                    }
307                }
308                Err(_) => {
309                    // Compression failed, fall back to raw block
310                    write_raw_block_to(output, block_data, is_last);
311                }
312            }
313
314            finder.reset();
315            offset += block_size;
316        }
317
318        Ok(())
319    }
320}
321
/// `Default` mirrors [`ZstdEncoder::new`] (level 0, RLE strategy,
/// checksum and content size enabled).
impl Default for ZstdEncoder {
    fn default() -> Self {
        Self::new()
    }
}
327
328// --- Block writing helpers ---
329
/// Emit the 3-byte header of an empty final block: Last_Block = 1,
/// Block_Type = Raw (0), Block_Size = 0, with no payload.
fn write_empty_block(output: &mut Vec<u8>) {
    // Header layout: bit 0 = last flag, bits 1-2 = block type,
    // bits 3..24 = size; stored little-endian in 3 bytes.
    let header: u32 = 1;
    output.extend_from_slice(&header.to_le_bytes()[..3]);
}
337
/// Emit a Raw (stored) block: 3-byte header followed by the bytes verbatim.
///
/// Header layout: bit 0 = Last_Block flag, bits 1-2 = Block_Type (Raw = 0),
/// bits 3..24 = Block_Size; stored little-endian in 3 bytes.
fn write_raw_block_to(output: &mut Vec<u8>, data: &[u8], is_last: bool) {
    let header = ((data.len() as u32) << 3) | u32::from(is_last);
    output.extend_from_slice(&header.to_le_bytes()[..3]);
    output.extend_from_slice(data);
}
347
/// Emit an RLE block: 3-byte header plus the single repeated byte.
///
/// Header layout: bit 0 = Last_Block flag, bits 1-2 = Block_Type
/// (RLE = 1, so `0b010`), bits 3..24 = decoded size; little-endian,
/// 3 bytes.
fn write_rle_block_to(output: &mut Vec<u8>, byte: u8, size: usize, is_last: bool) {
    let header = ((size as u32) << 3) | 0b010 | u32::from(is_last);
    output.extend_from_slice(&header.to_le_bytes()[..3]);
    output.push(byte);
}
358
/// Emit a Compressed block: 3-byte header plus the pre-encoded content.
///
/// Header layout: bit 0 = Last_Block flag, bits 1-2 = Block_Type
/// (Compressed = 2, so `0b100`), bits 3..24 = compressed content size;
/// little-endian, 3 bytes.
fn write_compressed_block_to(output: &mut Vec<u8>, content: &[u8], is_last: bool) {
    let header = ((content.len() as u32) << 3) | 0b100 | u32::from(is_last);
    output.extend_from_slice(&header.to_le_bytes()[..3]);
    output.extend_from_slice(content);
}
369
/// Return `Some(byte)` when every byte of `data` equals the first one
/// (eligible for an RLE block); `None` for empty or mixed data.
fn detect_rle(data: &[u8]) -> Option<u8> {
    let (&first, rest) = data.split_first()?;
    rest.iter().all(|&b| b == first).then_some(first)
}
383
384// --- Convenience functions ---
385
386/// Compress data using default settings (raw/RLE blocks, level 0).
387///
388/// For actual LZ77 compression, use [`compress_with_level`] or configure
389/// [`ZstdEncoder`] with [`set_level`](ZstdEncoder::set_level).
390pub fn compress(data: &[u8]) -> Result<Vec<u8>> {
391    ZstdEncoder::new().compress(data)
392}
393
394/// Compress data with a specific compression level (1-22).
395///
396/// This is the primary compression function for most use cases.
397///
398/// # Arguments
399/// * `data` - Data to compress
400/// * `level` - Compression level (1 = fastest, 22 = best compression)
401pub fn compress_with_level(data: &[u8], level: i32) -> Result<Vec<u8>> {
402    let mut encoder = ZstdEncoder::new();
403    encoder.set_level(level);
404    encoder.compress(data)
405}
406
407/// Compress data without checksum.
408pub fn compress_no_checksum(data: &[u8]) -> Result<Vec<u8>> {
409    let mut encoder = ZstdEncoder::new();
410    encoder.set_checksum(false);
411    encoder.compress(data)
412}
413
/// Compress data using parallel block compression (requires `parallel` feature).
#[cfg(feature = "parallel")]
pub fn compress_parallel(data: &[u8]) -> Result<Vec<u8>> {
    let encoder = ZstdEncoder::new();
    encoder.compress_parallel(data)
}
419
420/// Convenience function: compress data and return bytes (compatible with
421/// `zstd::encode_all` pattern).
422///
423/// # Arguments
424/// * `data` - Data to compress (implements `AsRef<[u8]>`)
425/// * `level` - Compression level (1-22)
426pub fn encode_all(data: &[u8], level: i32) -> Result<Vec<u8>> {
427    compress_with_level(data, level)
428}
429
430/// Convenience function: decompress data (compatible with `zstd::decode_all` pattern).
431pub fn decode_all(data: &[u8]) -> Result<Vec<u8>> {
432    crate::decompress(data)
433}
434
#[cfg(test)]
mod tests {
    //! Roundtrip and header-level tests for the encoder; all decoding goes
    //! through `crate::decompress`.
    use super::*;
    use crate::decompress;

    // --- Level-0 (raw/RLE) roundtrips ---

    #[test]
    fn test_compress_empty() {
        let data: &[u8] = &[];
        let compressed = compress(data).unwrap();
        // Frame must start with the Zstd magic number even for empty input.
        assert_eq!(&compressed[0..4], &ZSTD_MAGIC);
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_compress_small() {
        let data = b"Hello, Zstandard!";
        let compressed = compress(data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data.as_slice());
    }

    #[test]
    fn test_compress_larger() {
        let data = vec![0x42u8; 1000];
        let compressed = compress(&data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_compress_multi_block() {
        // Exceeds MAX_BLOCK_SIZE so the encoder must emit more than one block.
        let data = vec![0xABu8; MAX_BLOCK_SIZE + 1000];
        let compressed = compress(&data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_compress_no_checksum() {
        let data = b"Test without checksum";
        let compressed = compress_no_checksum(data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data.as_slice());
    }

    #[test]
    fn test_encoder_builder() {
        let data = b"Builder pattern test";
        let mut encoder = ZstdEncoder::new();
        encoder.set_checksum(true).set_content_size(true);
        let compressed = encoder.compress(data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data.as_slice());
    }

    #[test]
    fn test_various_sizes() {
        // Sizes straddle the 1-byte/2-byte/4-byte FCS encoding boundaries.
        for size in [0, 1, 10, 100, 255, 256, 257, 1000, 65535, 65536, 100000] {
            let data = vec![0x55u8; size];
            let compressed = compress(&data).unwrap();
            let decompressed = decompress(&compressed).unwrap();
            assert_eq!(decompressed, data, "Failed for size {}", size);
        }
    }

    // --- RLE strategy ---

    #[test]
    fn test_rle_compression() {
        let data = vec![0xAAu8; 10000];
        let compressed = compress(&data).unwrap();
        assert!(
            compressed.len() < data.len() / 10,
            "RLE compression failed: {} vs {}",
            compressed.len(),
            data.len()
        );
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_rle_multi_block() {
        let data = vec![0xBBu8; MAX_BLOCK_SIZE * 3];
        let compressed = compress(&data).unwrap();
        assert!(
            compressed.len() < 100,
            "Expected small output, got {}",
            compressed.len()
        );
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_rle_mixed_data() {
        // Homogeneous runs interleaved with text: some blocks RLE, some raw.
        let mut data = vec![0xCCu8; 1000];
        data.extend_from_slice(b"Hello, World!");
        data.extend_from_slice(&vec![0xDDu8; 1000]);
        let compressed = compress(&data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_detect_rle() {
        assert_eq!(detect_rle(&[0xAA; 100]), Some(0xAA));
        assert_eq!(detect_rle(&[0x00; 50]), Some(0x00));
        assert_eq!(detect_rle(&[0xFF]), Some(0xFF));
        assert_eq!(detect_rle(&[0xAA, 0xAA, 0xBB]), None);
        assert_eq!(detect_rle(&[0x00, 0x01]), None);
        assert_eq!(detect_rle(&[]), None);
    }

    #[test]
    fn test_raw_strategy() {
        // Raw strategy never uses RLE, so output exceeds input (headers added).
        let data = vec![0xEEu8; 1000];
        let mut encoder = ZstdEncoder::new();
        encoder.set_strategy(CompressionStrategy::Raw);
        let compressed = encoder.compress(&data).unwrap();
        assert!(compressed.len() > data.len());
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    // --- Level-based (LZ77) compression ---

    #[test]
    fn test_compress_with_level() {
        // Test that level-based compression produces valid output
        let data = b"The quick brown fox jumps over the lazy dog. \
                     The quick brown fox jumps over the lazy dog. \
                     The quick brown fox jumps over the lazy dog.";

        for level in [1, 3, 6, 9, 15, 22] {
            let compressed = compress_with_level(data, level).unwrap();
            let decompressed = decompress(&compressed).unwrap();
            assert_eq!(
                decompressed,
                data.as_slice(),
                "Roundtrip failed for level {}",
                level
            );
        }
    }

    #[test]
    fn test_encode_all_decode_all() {
        let data = b"Testing encode_all and decode_all convenience functions";
        let compressed = encode_all(data, 3).unwrap();
        let decompressed = decode_all(&compressed).unwrap();
        assert_eq!(decompressed, data.as_slice());
    }

    #[test]
    fn test_level_compression_ratio() {
        // Repetitive data should compress with LZ77
        let mut data = Vec::new();
        for _ in 0..100 {
            data.extend_from_slice(b"ABCDEFGHIJKLMNOP");
        }

        let raw = compress(&data).unwrap();
        let level3 = compress_with_level(&data, 3).unwrap();

        // Level 3 should produce smaller output than raw for repetitive data
        assert!(
            level3.len() <= raw.len(),
            "Level 3 ({}) should be <= raw ({}) for repetitive data",
            level3.len(),
            raw.len()
        );

        // Both should decompress correctly
        assert_eq!(decompress(&raw).unwrap(), data);
        assert_eq!(decompress(&level3).unwrap(), data);
    }

    #[test]
    fn test_large_data_roundtrip() {
        // Simulate compressible data similar to what network compression tests use.
        let mut data = Vec::with_capacity(16384);
        let pattern = b"RDF triple: <http://example.org/subject> <http://example.org/predicate> \"value\" .\n";
        while data.len() < 16384 {
            data.extend_from_slice(pattern);
        }
        data.truncate(16384);

        for level in [1, 3] {
            let compressed = encode_all(&data, level).unwrap();
            let decompressed = decode_all(&compressed).unwrap();
            assert_eq!(
                decompressed, data,
                "Large roundtrip failed for level {}",
                level
            );
        }
    }

    // --- Parallel path (feature-gated) ---

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_roundtrip_basic() {
        let data = b"Hello, World! Parallel Zstandard compression.";
        let compressed = compress_parallel(data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data.as_slice());
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_roundtrip_large() {
        let data = vec![0xABu8; 5_000_000];
        let compressed = compress_parallel(&data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_rle_compression() {
        let data = vec![0xCCu8; 2_000_000];
        let compressed = compress_parallel(&data).unwrap();
        assert!(compressed.len() < data.len() / 100);
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_empty() {
        let data: &[u8] = &[];
        let compressed = compress_parallel(data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_vs_serial() {
        // Both paths must roundtrip; byte-identical output is not required.
        let data = b"Testing parallel vs serial compression output.";
        let serial = compress(data).unwrap();
        let parallel = compress_parallel(data).unwrap();
        let serial_decompressed = decompress(&serial).unwrap();
        let parallel_decompressed = decompress(&parallel).unwrap();
        assert_eq!(serial_decompressed, data.as_slice());
        assert_eq!(parallel_decompressed, data.as_slice());
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_encoder_options() {
        let data = vec![0xFFu8; 1_000_000];
        let mut encoder = ZstdEncoder::new();
        encoder
            .set_checksum(false)
            .set_strategy(CompressionStrategy::RleOnly);
        let compressed = encoder.compress_parallel(&data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_multi_block() {
        let data = vec![0x55u8; MAX_BLOCK_SIZE * 3 + 5000];
        let compressed = compress_parallel(&data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }
}
701}