1use crate::block::{decode_raw_block, decode_rle_block, LiteralsSection, SequencesSection};
6use crate::frame::{xxhash64, BlockHeader, BlockType, FrameHeader, ZSTD_MAGIC};
7use haagenti_core::{Error, Result};
8
/// Mutable state threaded through the decoding of one frame.
///
/// It owns the bytes produced so far, remembers the window size it was
/// created with, and tracks the three most recently used match offsets.
#[derive(Debug)]
pub struct DecompressContext {
    /// Decoded bytes accumulated so far.
    output: Vec<u8>,
    /// Window size handed to [`DecompressContext::new`]. The decoding code
    /// visible in this file never reads it, hence the lint suppression.
    #[allow(dead_code)]
    window_size: usize,
    /// Repeat-offset history, most recent entry first.
    repeat_offsets: [u32; 3],
}

impl DecompressContext {
    /// Creates an empty context for a frame with the given window size.
    ///
    /// The output buffer is pre-allocated to `window_size`, capped at 1 MiB,
    /// and the repeat-offset history starts out as `[1, 4, 8]`.
    pub fn new(window_size: usize) -> Self {
        // Cap on the eager allocation so a huge declared window does not
        // translate into a huge up-front reservation.
        const MAX_INITIAL_CAPACITY: usize = 1024 * 1024;

        let initial_capacity = window_size.min(MAX_INITIAL_CAPACITY);
        Self {
            output: Vec::with_capacity(initial_capacity),
            window_size,
            repeat_offsets: [1, 4, 8],
        }
    }

    /// Borrows the bytes decoded so far.
    pub fn output(&self) -> &[u8] {
        self.output.as_slice()
    }

    /// Consumes the context and hands back the decoded bytes.
    pub fn into_output(self) -> Vec<u8> {
        let Self { output, .. } = self;
        output
    }

    /// Records `offset` as the most recent match offset.
    ///
    /// When `offset` already is the most recent entry the history is left
    /// untouched; otherwise the existing entries shift down by one slot and
    /// the oldest one is dropped.
    pub fn update_offsets(&mut self, offset: u32) {
        let [newest, middle, _oldest] = self.repeat_offsets;
        if newest == offset {
            return;
        }
        self.repeat_offsets = [offset, newest, middle];
    }

    /// Resolves an offset code.
    ///
    /// Codes 1, 2 and 3 select the corresponding repeat-offset slot (1 being
    /// the most recent); every other value is returned unchanged.
    pub fn get_repeat_offset(&self, code: u32) -> u32 {
        match code {
            1..=3 => self.repeat_offsets[(code - 1) as usize],
            literal => literal,
        }
    }
}
60
61pub fn decompress_frame(input: &[u8]) -> Result<Vec<u8>> {
69 if input.len() < 4 {
71 return Err(Error::corrupted("Input too short for Zstd frame"));
72 }
73
74 let magic = u32::from_le_bytes([input[0], input[1], input[2], input[3]]);
76 if magic != ZSTD_MAGIC {
77 return Err(Error::corrupted(format!(
78 "Invalid Zstd magic: expected 0x{:08X}, got 0x{:08X}",
79 ZSTD_MAGIC, magic
80 )));
81 }
82
83 let header = FrameHeader::parse(&input[4..])?;
85 let mut ctx = DecompressContext::new(header.window_size);
86
87 let mut pos = header.header_size;
89 loop {
90 if pos + BlockHeader::SIZE > input.len() {
91 return Err(Error::corrupted("Frame truncated at block header"));
92 }
93
94 let block_header = BlockHeader::parse(&input[pos..])?;
95 pos += BlockHeader::SIZE;
96
97 let compressed_size = block_header.compressed_size();
98 if pos + compressed_size > input.len() {
99 return Err(Error::corrupted("Frame truncated at block data"));
100 }
101
102 let block_data = &input[pos..pos + compressed_size];
103 pos += compressed_size;
104
105 match block_header.block_type {
107 BlockType::Raw => {
108 decode_raw_block(block_data, &mut ctx.output)?;
109 }
110 BlockType::Rle => {
111 decode_rle_block(
112 block_data,
113 block_header.decompressed_size(),
114 &mut ctx.output,
115 )?;
116 }
117 BlockType::Compressed => {
118 decode_compressed_block(block_data, &mut ctx)?;
119 }
120 BlockType::Reserved => {
121 return Err(Error::corrupted("Reserved block type"));
122 }
123 }
124
125 if block_header.last_block {
126 break;
127 }
128 }
129
130 if header.has_checksum {
132 if pos + 4 > input.len() {
133 return Err(Error::corrupted("Frame truncated at checksum"));
134 }
135 let expected =
136 u32::from_le_bytes([input[pos], input[pos + 1], input[pos + 2], input[pos + 3]]);
137 let actual = (xxhash64(&ctx.output, 0) & 0xFFFFFFFF) as u32;
138
139 if expected != actual {
140 return Err(Error::corrupted(format!(
141 "Checksum mismatch: expected 0x{:08X}, got 0x{:08X}",
142 expected, actual
143 )));
144 }
145 }
146
147 if let Some(expected_size) = header.frame_content_size {
149 if ctx.output.len() as u64 != expected_size {
150 return Err(Error::corrupted(format!(
151 "Content size mismatch: expected {}, got {}",
152 expected_size,
153 ctx.output.len()
154 )));
155 }
156 }
157
158 Ok(ctx.into_output())
159}
160
161pub fn decompress_frame_with_dict(
170 input: &[u8],
171 dict: Option<&crate::dictionary::ZstdDictionary>,
172) -> Result<Vec<u8>> {
173 if dict.is_none() {
178 return decompress_frame(input);
179 }
180
181 let dictionary = dict.unwrap();
182
183 if input.len() < 4 {
185 return Err(Error::corrupted("Input too short for Zstd frame"));
186 }
187
188 let magic = u32::from_le_bytes([input[0], input[1], input[2], input[3]]);
190 if magic != ZSTD_MAGIC {
191 return Err(Error::corrupted(format!(
192 "Invalid Zstd magic: expected 0x{:08X}, got 0x{:08X}",
193 ZSTD_MAGIC, magic
194 )));
195 }
196
197 let header = FrameHeader::parse(&input[4..])?;
199 let mut ctx = DecompressContext::new(header.window_size);
200
201 ctx.output.extend_from_slice(dictionary.content());
203 let dict_len = dictionary.content().len();
204
205 let mut pos = header.header_size;
207 loop {
208 if pos + BlockHeader::SIZE > input.len() {
209 return Err(Error::corrupted("Frame truncated at block header"));
210 }
211
212 let block_header = BlockHeader::parse(&input[pos..])?;
213 pos += BlockHeader::SIZE;
214
215 let compressed_size = block_header.compressed_size();
216 if pos + compressed_size > input.len() {
217 return Err(Error::corrupted("Frame truncated at block data"));
218 }
219
220 let block_data = &input[pos..pos + compressed_size];
221 pos += compressed_size;
222
223 match block_header.block_type {
225 BlockType::Raw => {
226 decode_raw_block(block_data, &mut ctx.output)?;
227 }
228 BlockType::Rle => {
229 decode_rle_block(
230 block_data,
231 block_header.decompressed_size(),
232 &mut ctx.output,
233 )?;
234 }
235 BlockType::Compressed => {
236 decode_compressed_block(block_data, &mut ctx)?;
237 }
238 BlockType::Reserved => {
239 return Err(Error::corrupted("Reserved block type"));
240 }
241 }
242
243 if block_header.last_block {
244 break;
245 }
246 }
247
248 if header.has_checksum {
250 if pos + 4 > input.len() {
251 return Err(Error::corrupted("Frame truncated at checksum"));
252 }
253 let expected =
254 u32::from_le_bytes([input[pos], input[pos + 1], input[pos + 2], input[pos + 3]]);
255 let content = &ctx.output[dict_len..];
257 let actual = (xxhash64(content, 0) & 0xFFFFFFFF) as u32;
258
259 if expected != actual {
260 return Err(Error::corrupted(format!(
261 "Checksum mismatch: expected 0x{:08X}, got 0x{:08X}",
262 expected, actual
263 )));
264 }
265 }
266
267 if let Some(expected_size) = header.frame_content_size {
269 let actual_size = (ctx.output.len() - dict_len) as u64;
270 if actual_size != expected_size {
271 return Err(Error::corrupted(format!(
272 "Content size mismatch: expected {}, got {}",
273 expected_size, actual_size
274 )));
275 }
276 }
277
278 Ok(ctx.output[dict_len..].to_vec())
280}
281
282fn decode_compressed_block(input: &[u8], ctx: &mut DecompressContext) -> Result<()> {
284 if input.is_empty() {
285 return Err(Error::corrupted("Empty compressed block"));
286 }
287
288 let (literals, literals_consumed) = LiteralsSection::parse(input)?;
290
291 let sequences_data = &input[literals_consumed..];
293 let sequences = SequencesSection::parse(sequences_data, &literals)?;
294
295 execute_sequences(&literals, &sequences, ctx)?;
297
298 Ok(())
299}
300
301fn execute_sequences(
306 literals: &LiteralsSection,
307 sequences: &SequencesSection,
308 ctx: &mut DecompressContext,
309) -> Result<()> {
310 let literal_bytes = literals.data();
311 let mut literal_pos = 0;
312
313 let total_output: usize = sequences
315 .sequences
316 .iter()
317 .map(|s| s.literal_length as usize + s.match_length as usize)
318 .sum();
319 ctx.output
320 .reserve(total_output + literal_bytes.len() - literal_pos);
321
322 for seq in &sequences.sequences {
323 let literal_end = literal_pos + seq.literal_length as usize;
325 if literal_end > literal_bytes.len() {
326 return Err(Error::corrupted(
327 "Literal length exceeds available literals",
328 ));
329 }
330 ctx.output
331 .extend_from_slice(&literal_bytes[literal_pos..literal_end]);
332 literal_pos = literal_end;
333
334 let offset = seq.offset as usize;
337 let match_length = seq.match_length as usize;
338
339 if match_length > 0 && offset > 0 {
341 let out_len = ctx.output.len();
342 if offset > out_len {
343 return Err(Error::corrupted(format!(
344 "Match offset {} exceeds output size {}",
345 offset, out_len
346 )));
347 }
348
349 let match_start = out_len - offset;
350
351 if offset >= match_length {
353 ctx.output
355 .extend_from_within(match_start..match_start + match_length);
356 } else {
357 copy_match_overlapping(&mut ctx.output, match_start, offset, match_length);
359 }
360 }
361 }
362
363 if literal_pos < literal_bytes.len() {
365 ctx.output.extend_from_slice(&literal_bytes[literal_pos..]);
366 }
367
368 Ok(())
369}
370
/// Appends `match_length` bytes to `output` by repeating the `offset`-byte
/// pattern that starts at `match_start`.
///
/// This is the copy used when a match overlaps the bytes it produces
/// (`offset < match_length`): byte `i` of the appended run equals
/// `output[match_start + i % offset]`.
///
/// The copy is done entirely with safe `Vec` operations. After the pattern
/// has been appended once, the already-appended run is itself periodic, so it
/// can be copied onto the end in chunks that double in size each round.
///
/// A zero `offset` or a zero `match_length` appends nothing.
///
/// # Panics
///
/// Panics if `match_start + offset` exceeds `output.len()`.
#[inline]
fn copy_match_overlapping(
    output: &mut Vec<u8>,
    match_start: usize,
    offset: usize,
    match_length: usize,
) {
    if offset == 0 || match_length == 0 {
        return;
    }

    output.reserve(match_length);

    // Start of the run being appended; later rounds copy from here.
    let run_start = output.len();

    // Round one: a single copy of the pattern (or a prefix of it).
    let first = offset.min(match_length);
    output.extend_from_within(match_start..match_start + first);
    let mut written = first;

    // Until the final partial chunk, `written` is a multiple of `offset`, so
    // copying from `run_start` continues the pattern without a phase shift.
    while written < match_length {
        let chunk = written.min(match_length - written);
        output.extend_from_within(run_start..run_start + chunk);
        written += chunk;
    }
}
451
/// Unit tests for the frame decoder.
///
/// Frames are assembled by hand. Byte values follow the Zstandard frame
/// format: a block header is three little-endian bytes with the last-block
/// flag in bit 0, the block type in bits 1-2 and the size in the bits above.
#[cfg(test)]
mod tests {
    use super::*;

    /// Little-endian encoding of the Zstandard magic number.
    const MAGIC_BYTES: [u8; 4] = [0x28, 0xB5, 0x2F, 0xFD];

    /// Starts a frame: magic, frame-header descriptor, one content-size byte.
    ///
    /// Every descriptor used below has the single-segment bit (`0x20`) set,
    /// which makes the byte after it the frame content size.
    fn frame_start(descriptor: u8, content_size: u8) -> Vec<u8> {
        let mut frame = MAGIC_BYTES.to_vec();
        frame.push(descriptor);
        frame.push(content_size);
        frame
    }

    /// Appends `block` as the final compressed block of the frame.
    ///
    /// The low three header bits are `0b101`: last-block flag plus block
    /// type 2 (compressed).
    fn push_last_compressed_block(frame: &mut Vec<u8>, block: &[u8]) {
        let header = (block.len() << 3) | 5;
        frame.push((header & 0xFF) as u8);
        frame.push(((header >> 8) & 0xFF) as u8);
        frame.push(((header >> 16) & 0xFF) as u8);
        frame.extend_from_slice(block);
    }

    #[test]
    fn test_decompress_context_creation() {
        let ctx = DecompressContext::new(1024);
        assert_eq!(ctx.window_size, 1024);
        assert!(ctx.output.is_empty());
    }

    #[test]
    fn test_repeat_offsets() {
        let mut ctx = DecompressContext::new(1024);

        // Initial history.
        assert_eq!(ctx.get_repeat_offset(1), 1);
        assert_eq!(ctx.get_repeat_offset(2), 4);
        assert_eq!(ctx.get_repeat_offset(3), 8);

        // A new offset pushes the older ones down one slot.
        ctx.update_offsets(100);
        assert_eq!(ctx.get_repeat_offset(1), 100);
        assert_eq!(ctx.get_repeat_offset(2), 1);
        assert_eq!(ctx.get_repeat_offset(3), 4);

        ctx.update_offsets(200);
        assert_eq!(ctx.get_repeat_offset(1), 200);
        assert_eq!(ctx.get_repeat_offset(2), 100);
        assert_eq!(ctx.get_repeat_offset(3), 1);
    }

    #[test]
    fn test_repeat_offset_same_value() {
        let mut ctx = DecompressContext::new(1024);
        ctx.update_offsets(100);

        // Repeating the most recent offset must not shift the history.
        ctx.update_offsets(100);
        assert_eq!(ctx.get_repeat_offset(1), 100);
        assert_eq!(ctx.get_repeat_offset(2), 1);
    }

    #[test]
    fn test_magic_validation() {
        // Four bytes, but not the magic number.
        let result = decompress_frame(&[0x00, 0x00, 0x00, 0x00]);
        assert!(result.is_err());

        // Too short to contain a magic number at all.
        let result = decompress_frame(&[0x28, 0xB5]);
        assert!(result.is_err());
    }

    #[test]
    fn test_valid_magic() {
        // Correct magic followed by a single byte: there is no room for any
        // block, so decoding must still fail.
        let data = [0x28, 0xB5, 0x2F, 0xFD, 0x00];
        let result = decompress_frame(&data);
        assert!(result.is_err());
    }

    #[test]
    fn test_simple_raw_frame() {
        let mut frame = frame_start(0x20, 5);

        // 0x29 = last block, raw type, size 5.
        frame.extend_from_slice(&[0x29, 0x00, 0x00]);
        frame.extend_from_slice(b"Hello");

        let result = decompress_frame(&frame).unwrap();
        assert_eq!(result, b"Hello");
    }

    #[test]
    fn test_rle_frame() {
        let mut frame = frame_start(0x20, 10);

        // 0x53 = last block, RLE type, size 10; the payload is one byte.
        frame.extend_from_slice(&[0x53, 0x00, 0x00]);
        frame.push(b'X');

        let result = decompress_frame(&frame).unwrap();
        assert_eq!(result, vec![b'X'; 10]);
    }

    #[test]
    fn test_multi_block_frame() {
        let mut frame = frame_start(0x20, 8);

        // 0x28 = not last, raw type, size 5.
        frame.extend_from_slice(&[0x28, 0x00, 0x00]);
        frame.extend_from_slice(b"Hello");

        // 0x19 = last block, raw type, size 3.
        frame.extend_from_slice(&[0x19, 0x00, 0x00]);
        frame.extend_from_slice(b"!!!");

        let result = decompress_frame(&frame).unwrap();
        assert_eq!(result, b"Hello!!!");
    }

    #[test]
    fn test_content_size_mismatch() {
        // The header declares 10 bytes but the only block carries 5.
        let mut frame = frame_start(0x20, 10);

        frame.extend_from_slice(&[0x29, 0x00, 0x00]);
        frame.extend_from_slice(b"Hello");

        let result = decompress_frame(&frame);
        assert!(result.is_err());
    }

    #[test]
    fn test_frame_with_checksum() {
        // 0x24 = single segment plus content-checksum flag.
        let mut frame = frame_start(0x24, 5);

        frame.extend_from_slice(&[0x29, 0x00, 0x00]);
        frame.extend_from_slice(b"Hello");

        // The frame stores the low 32 bits of the hash, little-endian.
        let hash = xxhash64(b"Hello", 0);
        let checksum = (hash & 0xFFFFFFFF) as u32;
        frame.extend_from_slice(&checksum.to_le_bytes());

        let result = decompress_frame(&frame).unwrap();
        assert_eq!(result, b"Hello");
    }

    #[test]
    fn test_checksum_mismatch() {
        let mut frame = frame_start(0x24, 5);

        frame.extend_from_slice(&[0x29, 0x00, 0x00]);
        frame.extend_from_slice(b"Hello");

        // Deliberately wrong checksum.
        frame.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]);

        let result = decompress_frame(&frame);
        assert!(result.is_err());
    }

    #[test]
    fn test_compressed_block_literals_only() {
        let mut frame = frame_start(0x20, 5);

        let literals = b"Hello";
        let compressed_block = build_compressed_block_literals_only(literals);
        push_last_compressed_block(&mut frame, &compressed_block);

        let result = decompress_frame(&frame).unwrap();
        assert_eq!(result, b"Hello");
    }

    /// Builds a compressed-block body made of raw literals and no sequences.
    ///
    /// Sizes up to 31 use the one-byte literals header, sizes up to 4095 the
    /// two-byte form; the trailing zero byte declares zero sequences.
    fn build_compressed_block_literals_only(literals: &[u8]) -> Vec<u8> {
        let mut block = vec![];

        let size = literals.len();

        if size <= 31 {
            // Literals type 0 (raw) in the low bits, size in the top five.
            block.push((size << 3) as u8);
        } else if size <= 4095 {
            // Size-format bit set: low four size bits in byte 0, the
            // remaining eight in byte 1.
            let byte0 = ((size & 0xF) << 4) | (1 << 2);
            let byte1 = (size >> 4) & 0xFF;
            block.push(byte0 as u8);
            block.push(byte1 as u8);
        } else {
            unreachable!("Size too large for test");
        }

        block.extend_from_slice(literals);

        // Sequence count of zero.
        block.push(0);

        block
    }

    #[test]
    fn test_compressed_block_with_rle_literals() {
        let mut frame = frame_start(0x20, 10);

        let compressed_block = build_compressed_block_rle_literals(b'A', 10);
        push_last_compressed_block(&mut frame, &compressed_block);

        let result = decompress_frame(&frame).unwrap();
        assert_eq!(result, vec![b'A'; 10]);
    }

    /// Builds a compressed-block body whose literals are one repeated byte.
    ///
    /// Same header layout as the raw variant, with literals type 1 (RLE) in
    /// the low bits; the trailing zero byte declares zero sequences.
    fn build_compressed_block_rle_literals(byte: u8, repeat_count: usize) -> Vec<u8> {
        let mut block = vec![];

        if repeat_count <= 31 {
            block.push(((repeat_count << 3) | 1) as u8);
        } else if repeat_count <= 4095 {
            let byte0 = ((repeat_count & 0xF) << 4) | (1 << 2) | 1;
            let byte1 = (repeat_count >> 4) & 0xFF;
            block.push(byte0 as u8);
            block.push(byte1 as u8);
        } else {
            unreachable!("Size too large for test");
        }

        // The byte to repeat.
        block.push(byte);

        // Sequence count of zero.
        block.push(0);

        block
    }

    #[test]
    fn test_compressed_block_multi_literals() {
        // 100 literals exercise the two-byte literals header.
        let literals: Vec<u8> = (0..100).map(|i| (i % 256) as u8).collect();

        let mut frame = frame_start(0x20, 100);

        let compressed_block = build_compressed_block_literals_only(&literals);
        push_last_compressed_block(&mut frame, &compressed_block);

        let result = decompress_frame(&frame).unwrap();
        assert_eq!(result, literals);
    }
}