1use crate::compressed_block::encode_compressed_block;
10use crate::lz77::{LevelConfig, MatchFinder};
11use crate::xxhash::xxhash64_checksum;
12use crate::{MAX_BLOCK_SIZE, ZSTD_MAGIC};
13use oxiarc_core::error::Result;
14
15#[cfg(feature = "parallel")]
16use rayon::prelude::*;
17
/// Block-emission strategy used by the store path (`level == 0`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CompressionStrategy {
    /// Always emit raw (stored) blocks, even for runs of identical bytes.
    Raw,
    /// Emit an RLE block whenever an entire block is a single repeated
    /// byte, falling back to raw blocks otherwise. This is the default.
    #[default]
    RleOnly,
}
27
/// Builder-style Zstandard frame encoder.
///
/// Configure via the `set_*` methods, then call `compress` (or
/// `compress_parallel` with the `parallel` feature) to produce a frame.
#[derive(Debug, Clone)]
pub struct ZstdEncoder {
    /// Append an 8-byte xxHash64 content checksum after the last block.
    include_checksum: bool,
    /// Advertise the decompressed size in the frame header.
    include_content_size: bool,
    /// Block strategy for the store path (`level == 0`).
    strategy: CompressionStrategy,
    /// Compression level, clamped to 0..=22; 0 means store-only.
    level: i32,
    /// Optional raw-content dictionary passed to the match finder.
    dictionary: Option<Vec<u8>>,
    /// 4-byte dictionary ID written to the frame header (derived from
    /// xxHash64 of the dictionary bytes).
    dict_id: Option<u32>,
}
47
48impl ZstdEncoder {
49 pub fn new() -> Self {
51 Self {
52 include_checksum: true,
53 include_content_size: true,
54 strategy: CompressionStrategy::default(),
55 level: 0,
56 dictionary: None,
57 dict_id: None,
58 }
59 }
60
61 pub fn set_checksum(&mut self, include: bool) -> &mut Self {
63 self.include_checksum = include;
64 self
65 }
66
67 pub fn set_content_size(&mut self, include: bool) -> &mut Self {
69 self.include_content_size = include;
70 self
71 }
72
73 pub fn set_strategy(&mut self, strategy: CompressionStrategy) -> &mut Self {
75 self.strategy = strategy;
76 self
77 }
78
79 pub fn set_level(&mut self, level: i32) -> &mut Self {
86 self.level = level.clamp(0, 22);
87 self
88 }
89
90 pub fn set_dictionary(&mut self, dict: &[u8]) -> &mut Self {
92 if dict.is_empty() {
93 self.dictionary = None;
94 self.dict_id = None;
95 } else {
96 let id = crate::xxhash::xxhash64(dict) as u32;
97 self.dictionary = Some(dict.to_vec());
98 self.dict_id = Some(id);
99 }
100 self
101 }
102
103 pub fn compress(&self, data: &[u8]) -> Result<Vec<u8>> {
107 let mut output = Vec::with_capacity(data.len() + 32);
108
109 output.extend_from_slice(&ZSTD_MAGIC);
111
112 self.write_frame_header(&mut output, data.len());
114
115 if self.level > 0 {
117 self.write_compressed_blocks(&mut output, data)?;
118 } else {
119 self.write_blocks(&mut output, data);
120 }
121
122 if self.include_checksum {
124 let checksum = xxhash64_checksum(data);
125 output.extend_from_slice(&checksum.to_le_bytes());
126 }
127
128 Ok(output)
129 }
130
131 #[cfg(feature = "parallel")]
134 pub fn compress_parallel(&self, data: &[u8]) -> Result<Vec<u8>> {
135 let mut output = Vec::with_capacity(data.len() + 32);
136
137 output.extend_from_slice(&ZSTD_MAGIC);
139
140 self.write_frame_header(&mut output, data.len());
142
143 if data.is_empty() {
145 write_empty_block(&mut output);
146 } else {
147 let chunks: Vec<&[u8]> = data.chunks(MAX_BLOCK_SIZE).collect();
148
149 let block_data: Vec<(bool, Vec<u8>)> = chunks
151 .par_iter()
152 .enumerate()
153 .map(|(idx, chunk)| {
154 let is_last = idx == chunks.len() - 1;
155
156 if self.strategy == CompressionStrategy::RleOnly {
158 if let Some(rle_byte) = detect_rle(chunk) {
159 let mut block_output = Vec::new();
160 write_rle_block_to(&mut block_output, rle_byte, chunk.len(), is_last);
161 return (is_last, block_output);
162 }
163 }
164
165 let mut block_output = Vec::new();
167 write_raw_block_to(&mut block_output, chunk, is_last);
168 (is_last, block_output)
169 })
170 .collect();
171
172 for (_is_last, block_bytes) in block_data {
174 output.extend_from_slice(&block_bytes);
175 }
176 }
177
178 if self.include_checksum {
180 let checksum = xxhash64_checksum(data);
181 output.extend_from_slice(&checksum.to_le_bytes());
182 }
183
184 Ok(output)
185 }
186
187 fn write_frame_header(&self, output: &mut Vec<u8>, content_size: usize) {
189 let mut descriptor: u8 = 0;
190
191 if self.include_checksum {
192 descriptor |= 0x04; }
194
195 descriptor |= 0x20;
197
198 let dict_id_flag = if self.dict_id.is_some() { 3u8 } else { 0u8 };
200 descriptor |= dict_id_flag;
201
202 let (fcs_flag, fcs_bytes) = if !self.include_content_size || content_size <= 255 {
204 (0u8, 1)
205 } else if content_size <= 65535 + 256 {
206 (1u8, 2)
207 } else if content_size <= u32::MAX as usize {
208 (2u8, 4)
209 } else {
210 (3u8, 8)
211 };
212
213 descriptor |= fcs_flag << 6;
214 output.push(descriptor);
215
216 if let Some(id) = self.dict_id {
218 output.extend_from_slice(&id.to_le_bytes());
219 }
220
221 match fcs_bytes {
223 1 => {
224 output.push(content_size as u8);
225 }
226 2 => {
227 let adjusted = (content_size - 256) as u16;
228 output.extend_from_slice(&adjusted.to_le_bytes());
229 }
230 4 => {
231 output.extend_from_slice(&(content_size as u32).to_le_bytes());
232 }
233 8 => {
234 output.extend_from_slice(&(content_size as u64).to_le_bytes());
235 }
236 _ => unreachable!(),
237 }
238 }
239
240 fn write_blocks(&self, output: &mut Vec<u8>, data: &[u8]) {
242 if data.is_empty() {
243 write_empty_block(output);
244 return;
245 }
246
247 let mut offset = 0;
248 while offset < data.len() {
249 let remaining = data.len() - offset;
250 let block_size = remaining.min(MAX_BLOCK_SIZE);
251 let is_last = offset + block_size >= data.len();
252 let block_data = &data[offset..offset + block_size];
253
254 if self.strategy == CompressionStrategy::RleOnly {
256 if let Some(rle_byte) = detect_rle(block_data) {
257 write_rle_block_to(output, rle_byte, block_size, is_last);
258 offset += block_size;
259 continue;
260 }
261 }
262
263 write_raw_block_to(output, block_data, is_last);
265 offset += block_size;
266 }
267 }
268
269 fn write_compressed_blocks(&self, output: &mut Vec<u8>, data: &[u8]) -> Result<()> {
271 if data.is_empty() {
272 write_empty_block(output);
273 return Ok(());
274 }
275
276 let config = LevelConfig::for_level(self.level);
277 let mut finder = MatchFinder::new(&config);
278 let dict = self.dictionary.as_deref().unwrap_or(&[]);
279
280 let mut offset = 0;
281 while offset < data.len() {
282 let remaining = data.len() - offset;
283 let block_size = remaining.min(config.target_block_size);
284 let is_last = offset + block_size >= data.len();
285 let block_data = &data[offset..offset + block_size];
286
287 if let Some(rle_byte) = detect_rle(block_data) {
289 write_rle_block_to(output, rle_byte, block_size, is_last);
290 offset += block_size;
291 continue;
292 }
293
294 let sequences = finder.find_sequences(block_data, dict)?;
296
297 match encode_compressed_block(&sequences) {
299 Ok(compressed_content) => {
300 if compressed_content.len() < block_data.len() {
302 write_compressed_block_to(output, &compressed_content, is_last);
303 } else {
304 write_raw_block_to(output, block_data, is_last);
306 }
307 }
308 Err(_) => {
309 write_raw_block_to(output, block_data, is_last);
311 }
312 }
313
314 finder.reset();
315 offset += block_size;
316 }
317
318 Ok(())
319 }
320}
321
322impl Default for ZstdEncoder {
323 fn default() -> Self {
324 Self::new()
325 }
326}
327
/// Appends the 3-byte header of an empty raw block marked as last
/// (last-block bit set, block type Raw, Block_Size 0).
fn write_empty_block(output: &mut Vec<u8>) {
    // Header value 0x000001 encoded little-endian over 3 bytes.
    output.extend_from_slice(&[0x01, 0x00, 0x00]);
}
337
/// Appends a raw (stored) block: 3-byte header followed by `data` verbatim.
///
/// Header layout: bit 0 = last-block flag, bits 1-2 = block type
/// (0 = Raw), bits 3.. = Block_Size, all little-endian over 3 bytes.
fn write_raw_block_to(output: &mut Vec<u8>, data: &[u8], is_last: bool) {
    let header = ((data.len() as u32) << 3) | u32::from(is_last);
    output.extend_from_slice(&header.to_le_bytes()[..3]);
    output.extend_from_slice(data);
}
347
/// Appends an RLE block: 3-byte header plus the single repeated byte.
///
/// Block type 1 (RLE) occupies bits 1-2; the Block_Size field (bits 3..)
/// holds the *regenerated* size, i.e. how many times `byte` repeats.
fn write_rle_block_to(output: &mut Vec<u8>, byte: u8, size: usize, is_last: bool) {
    let header = u32::from(is_last) | (1u32 << 1) | ((size as u32) << 3);
    output.extend_from_slice(&header.to_le_bytes()[..3]);
    output.push(byte);
}
358
/// Appends a compressed block: 3-byte header plus the already-encoded
/// block content.
///
/// Block type 2 (Compressed) occupies bits 1-2; here Block_Size (bits
/// 3..) is the *compressed* content length.
fn write_compressed_block_to(output: &mut Vec<u8>, content: &[u8], is_last: bool) {
    let header = u32::from(is_last) | (2u32 << 1) | ((content.len() as u32) << 3);
    output.extend_from_slice(&header.to_le_bytes()[..3]);
    output.extend_from_slice(content);
}
369
/// Returns the repeated byte when `data` is non-empty and consists of a
/// single byte value throughout; `None` otherwise (including for empty
/// input, which has no byte to repeat).
fn detect_rle(data: &[u8]) -> Option<u8> {
    let (&first, rest) = data.split_first()?;
    rest.iter().all(|&b| b == first).then_some(first)
}
383
384pub fn compress(data: &[u8]) -> Result<Vec<u8>> {
391 ZstdEncoder::new().compress(data)
392}
393
394pub fn compress_with_level(data: &[u8], level: i32) -> Result<Vec<u8>> {
402 let mut encoder = ZstdEncoder::new();
403 encoder.set_level(level);
404 encoder.compress(data)
405}
406
407pub fn compress_no_checksum(data: &[u8]) -> Result<Vec<u8>> {
409 let mut encoder = ZstdEncoder::new();
410 encoder.set_checksum(false);
411 encoder.compress(data)
412}
413
/// Compresses `data` using the rayon-parallel block path with default
/// encoder settings. Available only with the `parallel` feature.
#[cfg(feature = "parallel")]
pub fn compress_parallel(data: &[u8]) -> Result<Vec<u8>> {
    let encoder = ZstdEncoder::new();
    encoder.compress_parallel(data)
}
419
420pub fn encode_all(data: &[u8], level: i32) -> Result<Vec<u8>> {
427 compress_with_level(data, level)
428}
429
/// Decompresses a complete zstd frame, mirroring the `zstd` crate's
/// `decode_all` convenience API; delegates to [`crate::decompress`].
pub fn decode_all(data: &[u8]) -> Result<Vec<u8>> {
    crate::decompress(data)
}
434
#[cfg(test)]
mod tests {
    //! Round-trip and unit tests for the encoder. Round-trips decode
    //! through `crate::decompress`, so they exercise the real decoder.
    //!
    //! Fix: `test_parallel_vs_serial` contained the mis-encoded token
    //! `¶llel` (mojibake of `&parallel`), which did not compile.
    use super::*;
    use crate::decompress;

    #[test]
    fn test_compress_empty() {
        let data: &[u8] = &[];
        let compressed = compress(data).unwrap();
        assert_eq!(&compressed[0..4], &ZSTD_MAGIC);
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_compress_small() {
        let data = b"Hello, Zstandard!";
        let compressed = compress(data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data.as_slice());
    }

    #[test]
    fn test_compress_larger() {
        let data = vec![0x42u8; 1000];
        let compressed = compress(&data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_compress_multi_block() {
        // Forces at least two blocks on the store path.
        let data = vec![0xABu8; MAX_BLOCK_SIZE + 1000];
        let compressed = compress(&data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_compress_no_checksum() {
        let data = b"Test without checksum";
        let compressed = compress_no_checksum(data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data.as_slice());
    }

    #[test]
    fn test_encoder_builder() {
        let data = b"Builder pattern test";
        let mut encoder = ZstdEncoder::new();
        encoder.set_checksum(true).set_content_size(true);
        let compressed = encoder.compress(data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data.as_slice());
    }

    #[test]
    fn test_various_sizes() {
        // Sizes straddle the FCS field-width boundaries (255/256, 65535/65536).
        for size in [0, 1, 10, 100, 255, 256, 257, 1000, 65535, 65536, 100000] {
            let data = vec![0x55u8; size];
            let compressed = compress(&data).unwrap();
            let decompressed = decompress(&compressed).unwrap();
            assert_eq!(decompressed, data, "Failed for size {}", size);
        }
    }

    #[test]
    fn test_rle_compression() {
        let data = vec![0xAAu8; 10000];
        let compressed = compress(&data).unwrap();
        assert!(
            compressed.len() < data.len() / 10,
            "RLE compression failed: {} vs {}",
            compressed.len(),
            data.len()
        );
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_rle_multi_block() {
        let data = vec![0xBBu8; MAX_BLOCK_SIZE * 3];
        let compressed = compress(&data).unwrap();
        assert!(
            compressed.len() < 100,
            "Expected small output, got {}",
            compressed.len()
        );
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_rle_mixed_data() {
        let mut data = vec![0xCCu8; 1000];
        data.extend_from_slice(b"Hello, World!");
        data.extend_from_slice(&vec![0xDDu8; 1000]);
        let compressed = compress(&data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_detect_rle() {
        assert_eq!(detect_rle(&[0xAA; 100]), Some(0xAA));
        assert_eq!(detect_rle(&[0x00; 50]), Some(0x00));
        assert_eq!(detect_rle(&[0xFF]), Some(0xFF));
        assert_eq!(detect_rle(&[0xAA, 0xAA, 0xBB]), None);
        assert_eq!(detect_rle(&[0x00, 0x01]), None);
        assert_eq!(detect_rle(&[]), None);
    }

    #[test]
    fn test_raw_strategy() {
        // Raw strategy never uses RLE, so constant data must expand.
        let data = vec![0xEEu8; 1000];
        let mut encoder = ZstdEncoder::new();
        encoder.set_strategy(CompressionStrategy::Raw);
        let compressed = encoder.compress(&data).unwrap();
        assert!(compressed.len() > data.len());
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_compress_with_level() {
        let data = b"The quick brown fox jumps over the lazy dog. \
                     The quick brown fox jumps over the lazy dog. \
                     The quick brown fox jumps over the lazy dog.";

        for level in [1, 3, 6, 9, 15, 22] {
            let compressed = compress_with_level(data, level).unwrap();
            let decompressed = decompress(&compressed).unwrap();
            assert_eq!(
                decompressed,
                data.as_slice(),
                "Roundtrip failed for level {}",
                level
            );
        }
    }

    #[test]
    fn test_encode_all_decode_all() {
        let data = b"Testing encode_all and decode_all convenience functions";
        let compressed = encode_all(data, 3).unwrap();
        let decompressed = decode_all(&compressed).unwrap();
        assert_eq!(decompressed, data.as_slice());
    }

    #[test]
    fn test_level_compression_ratio() {
        let mut data = Vec::new();
        for _ in 0..100 {
            data.extend_from_slice(b"ABCDEFGHIJKLMNOP");
        }

        let raw = compress(&data).unwrap();
        let level3 = compress_with_level(&data, 3).unwrap();

        assert!(
            level3.len() <= raw.len(),
            "Level 3 ({}) should be <= raw ({}) for repetitive data",
            level3.len(),
            raw.len()
        );

        assert_eq!(decompress(&raw).unwrap(), data);
        assert_eq!(decompress(&level3).unwrap(), data);
    }

    #[test]
    fn test_large_data_roundtrip() {
        let mut data = Vec::with_capacity(16384);
        let pattern = b"RDF triple: <http://example.org/subject> <http://example.org/predicate> \"value\" .\n";
        while data.len() < 16384 {
            data.extend_from_slice(pattern);
        }
        data.truncate(16384);

        for level in [1, 3] {
            let compressed = encode_all(&data, level).unwrap();
            let decompressed = decode_all(&compressed).unwrap();
            assert_eq!(
                decompressed, data,
                "Large roundtrip failed for level {}",
                level
            );
        }
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_roundtrip_basic() {
        let data = b"Hello, World! Parallel Zstandard compression.";
        let compressed = compress_parallel(data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data.as_slice());
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_roundtrip_large() {
        let data = vec![0xABu8; 5_000_000];
        let compressed = compress_parallel(&data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_rle_compression() {
        let data = vec![0xCCu8; 2_000_000];
        let compressed = compress_parallel(&data).unwrap();
        assert!(compressed.len() < data.len() / 100);
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_empty() {
        let data: &[u8] = &[];
        let compressed = compress_parallel(data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_vs_serial() {
        let data = b"Testing parallel vs serial compression output.";
        let serial = compress(data).unwrap();
        let parallel = compress_parallel(data).unwrap();
        let serial_decompressed = decompress(&serial).unwrap();
        let parallel_decompressed = decompress(&parallel).unwrap();
        assert_eq!(serial_decompressed, data.as_slice());
        assert_eq!(parallel_decompressed, data.as_slice());
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_encoder_options() {
        let data = vec![0xFFu8; 1_000_000];
        let mut encoder = ZstdEncoder::new();
        encoder
            .set_checksum(false)
            .set_strategy(CompressionStrategy::RleOnly);
        let compressed = encoder.compress_parallel(&data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_multi_block() {
        let data = vec![0x55u8; MAX_BLOCK_SIZE * 3 + 5000];
        let compressed = compress_parallel(&data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }
}