1use crate::literals::LiteralsDecoder;
6use crate::sequences::{Sequence, SequencesDecoder};
7use crate::xxhash::xxhash64_checksum;
8use crate::{BlockType, MAX_BLOCK_SIZE, MAX_WINDOW_SIZE, ZSTD_MAGIC};
9use oxiarc_core::error::{OxiArcError, Result};
10
/// Frame header descriptor bit 5: single-segment frame (no window descriptor is present).
const FHD_SINGLE_SEGMENT: u8 = 0b0010_0000;
/// Frame header descriptor bit 2: a 4-byte content checksum follows the last block.
const FHD_CONTENT_CHECKSUM: u8 = 0b0000_0100;
/// Frame header descriptor bits 0-1: selects the width of the dictionary ID field.
const FHD_DICT_ID_FLAG_MASK: u8 = 0b0000_0011;
/// Frame header descriptor bits 6-7: selects the width of the frame content size field.
const FHD_CONTENT_SIZE_FLAG_MASK: u8 = 0b1100_0000;
16
/// Fields decoded from a Zstandard frame header by [`parse_frame_header`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FrameHeader {
    /// Window size in bytes, clamped to `MAX_WINDOW_SIZE`. For single-segment frames
    /// this is the declared content size.
    pub window_size: usize,
    /// Decompressed size declared in the header, if the header carries one.
    pub content_size: Option<u64>,
    /// Dictionary ID declared in the header, if the header carries one.
    // Not read by the decoder in this file; kept so callers can inspect it, hence the
    // `dead_code` allowance.
    #[allow(dead_code)]
    pub dict_id: Option<u32>,
    /// Whether a 4-byte content checksum follows the last block of the frame.
    pub has_checksum: bool,
    /// Number of bytes occupied by the header, including the 4-byte magic number.
    pub header_size: usize,
}
32
33pub fn parse_frame_header(data: &[u8]) -> Result<FrameHeader> {
35 if data.len() < 5 {
36 return Err(OxiArcError::CorruptedData {
37 offset: 0,
38 message: "truncated frame header".to_string(),
39 });
40 }
41
42 if data[0..4] != ZSTD_MAGIC {
44 return Err(OxiArcError::invalid_magic(ZSTD_MAGIC, &data[0..4]));
45 }
46
47 let descriptor = data[4];
48 let single_segment = (descriptor & FHD_SINGLE_SEGMENT) != 0;
49 let has_checksum = (descriptor & FHD_CONTENT_CHECKSUM) != 0;
50 let dict_id_flag = descriptor & FHD_DICT_ID_FLAG_MASK;
51 let content_size_flag = (descriptor & FHD_CONTENT_SIZE_FLAG_MASK) >> 6;
52
53 let mut pos = 5;
54
55 let window_size = if single_segment {
57 0 } else {
59 if data.len() <= pos {
60 return Err(OxiArcError::CorruptedData {
61 offset: pos as u64,
62 message: "missing window descriptor".to_string(),
63 });
64 }
65 let wd = data[pos];
66 pos += 1;
67
68 let exponent = (wd >> 3) as u32;
69 let mantissa = (wd & 0x07) as u32;
70 let base = 1u64 << (10 + exponent);
71 let window = base + (base >> 3) * mantissa as u64;
72 window.min(MAX_WINDOW_SIZE as u64) as usize
73 };
74
75 let dict_id = match dict_id_flag {
77 0 => None,
78 1 => {
79 if data.len() <= pos {
80 return Err(OxiArcError::CorruptedData {
81 offset: pos as u64,
82 message: "missing dictionary ID".to_string(),
83 });
84 }
85 let id = data[pos] as u32;
86 pos += 1;
87 Some(id)
88 }
89 2 => {
90 if data.len() < pos + 2 {
91 return Err(OxiArcError::CorruptedData {
92 offset: pos as u64,
93 message: "truncated dictionary ID".to_string(),
94 });
95 }
96 let id = u16::from_le_bytes([data[pos], data[pos + 1]]) as u32;
97 pos += 2;
98 Some(id)
99 }
100 3 => {
101 if data.len() < pos + 4 {
102 return Err(OxiArcError::CorruptedData {
103 offset: pos as u64,
104 message: "truncated dictionary ID".to_string(),
105 });
106 }
107 let id = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
108 pos += 4;
109 Some(id)
110 }
111 _ => unreachable!(),
112 };
113
114 let content_size = if single_segment || content_size_flag != 0 {
116 let size_bytes = match content_size_flag {
117 0 => 1, 1 => 2,
119 2 => 4,
120 3 => 8,
121 _ => unreachable!(),
122 };
123
124 if data.len() < pos + size_bytes {
125 return Err(OxiArcError::CorruptedData {
126 offset: pos as u64,
127 message: "truncated content size".to_string(),
128 });
129 }
130
131 let size = match size_bytes {
132 1 => data[pos] as u64,
133 2 => {
134 let s = u16::from_le_bytes([data[pos], data[pos + 1]]) as u64;
135 s + 256 }
137 4 => {
138 u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]) as u64
139 }
140 8 => u64::from_le_bytes([
141 data[pos],
142 data[pos + 1],
143 data[pos + 2],
144 data[pos + 3],
145 data[pos + 4],
146 data[pos + 5],
147 data[pos + 6],
148 data[pos + 7],
149 ]),
150 _ => unreachable!(),
151 };
152 pos += size_bytes;
153 Some(size)
154 } else {
155 None
156 };
157
158 let window_size = if single_segment {
160 content_size
161 .unwrap_or(MAX_WINDOW_SIZE as u64)
162 .min(MAX_WINDOW_SIZE as u64) as usize
163 } else {
164 window_size
165 };
166
167 Ok(FrameHeader {
168 window_size,
169 content_size,
170 dict_id,
171 has_checksum,
172 header_size: pos,
173 })
174}
175
176pub struct ZstdDecoder {
178 literals_decoder: LiteralsDecoder,
180 sequences_decoder: SequencesDecoder,
182 output: Vec<u8>,
184 window_size: usize,
186 dictionary: Option<Vec<u8>>,
188}
189
190impl ZstdDecoder {
191 pub fn new() -> Self {
193 Self {
194 literals_decoder: LiteralsDecoder::new(),
195 sequences_decoder: SequencesDecoder::new(),
196 output: Vec::new(),
197 window_size: MAX_WINDOW_SIZE,
198 dictionary: None,
199 }
200 }
201
202 pub fn set_dictionary(&mut self, dict: &[u8]) {
206 if dict.is_empty() {
207 self.dictionary = None;
208 } else {
209 self.dictionary = Some(dict.to_vec());
210 }
211 }
212
213 pub fn decode_frame(&mut self, data: &[u8]) -> Result<Vec<u8>> {
215 let header = parse_frame_header(data)?;
216 self.window_size = header.window_size;
217
218 if let Some(size) = header.content_size {
220 self.output.reserve(size as usize);
221 }
222
223 let mut pos = header.header_size;
224
225 loop {
227 if data.len() < pos + 3 {
228 return Err(OxiArcError::CorruptedData {
229 offset: pos as u64,
230 message: "truncated block header".to_string(),
231 });
232 }
233
234 let block_header = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], 0]);
236 pos += 3;
237
238 let last_block = (block_header & 1) != 0;
239 let block_type = BlockType::from_bits(((block_header >> 1) & 0x03) as u8)?;
240 let block_size = ((block_header >> 3) & 0x1FFFFF) as usize;
241
242 if block_size > MAX_BLOCK_SIZE {
243 return Err(OxiArcError::CorruptedData {
244 offset: pos as u64,
245 message: format!("block size {} exceeds maximum", block_size),
246 });
247 }
248
249 let compressed_size = match block_type {
251 BlockType::Rle => 1,
252 _ => block_size,
253 };
254
255 if data.len() < pos + compressed_size {
256 return Err(OxiArcError::CorruptedData {
257 offset: pos as u64,
258 message: "truncated block data".to_string(),
259 });
260 }
261
262 let block_data = &data[pos..pos + compressed_size];
263 pos += compressed_size;
264
265 match block_type {
266 BlockType::Raw => {
267 self.output.extend_from_slice(block_data);
268 }
269 BlockType::Rle => {
270 self.output
273 .extend(std::iter::repeat_n(block_data[0], block_size));
274 }
275 BlockType::Compressed => {
276 self.decode_compressed_block(block_data)?;
277 }
278 BlockType::Reserved => {
279 return Err(OxiArcError::CorruptedData {
280 offset: pos as u64,
281 message: "reserved block type".to_string(),
282 });
283 }
284 }
285
286 if last_block {
287 break;
288 }
289 }
290
291 if header.has_checksum {
293 if data.len() < pos + 4 {
294 return Err(OxiArcError::CorruptedData {
295 offset: pos as u64,
296 message: "missing content checksum".to_string(),
297 });
298 }
299
300 let expected =
301 u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
302 let computed = xxhash64_checksum(&self.output);
303
304 if expected != computed {
305 return Err(OxiArcError::CrcMismatch { expected, computed });
306 }
307 }
308
309 if let Some(expected_size) = header.content_size {
311 if self.output.len() as u64 != expected_size {
312 return Err(OxiArcError::CorruptedData {
313 offset: 0,
314 message: format!(
315 "content size mismatch: expected {}, got {}",
316 expected_size,
317 self.output.len()
318 ),
319 });
320 }
321 }
322
323 Ok(std::mem::take(&mut self.output))
324 }
325
326 fn decode_compressed_block(&mut self, data: &[u8]) -> Result<()> {
328 let (literals, literals_size) = self.literals_decoder.decode(data)?;
330
331 let sequences_data = &data[literals_size..];
333 let (sequences, _) = self.sequences_decoder.decode(sequences_data)?;
334
335 self.execute_sequences(&literals, &sequences)?;
337
338 Ok(())
339 }
340
341 fn execute_sequences(&mut self, literals: &[u8], sequences: &[Sequence]) -> Result<()> {
343 let mut lit_pos = 0;
344 let dict = self.dictionary.as_deref().unwrap_or(&[]);
345 let dict_len = dict.len();
346
347 for seq in sequences {
348 if seq.literal_length > 0 {
350 if lit_pos + seq.literal_length > literals.len() {
351 return Err(OxiArcError::CorruptedData {
352 offset: 0,
353 message: "literal length exceeds available literals".to_string(),
354 });
355 }
356 self.output
357 .extend_from_slice(&literals[lit_pos..lit_pos + seq.literal_length]);
358 lit_pos += seq.literal_length;
359 }
360
361 if seq.match_length > 0 {
363 let max_offset = self.output.len() + dict_len;
364 if seq.offset == 0 || seq.offset > max_offset {
365 return Err(OxiArcError::CorruptedData {
366 offset: 0,
367 message: format!(
368 "invalid offset {} (output length {}, dict length {})",
369 seq.offset,
370 self.output.len(),
371 dict_len
372 ),
373 });
374 }
375
376 if seq.offset <= self.output.len() {
377 let start = self.output.len() - seq.offset;
379 for i in 0..seq.match_length {
380 let byte = self.output[start + (i % seq.offset)];
381 self.output.push(byte);
382 }
383 } else {
384 let dict_and_output_len = dict_len + self.output.len();
387 let start_in_combined = dict_and_output_len - seq.offset;
388
389 for i in 0..seq.match_length {
390 let pos_in_combined = start_in_combined + (i % seq.offset);
391 let byte = if pos_in_combined < dict_len {
392 dict[pos_in_combined]
393 } else {
394 self.output[pos_in_combined - dict_len]
395 };
396 self.output.push(byte);
397 }
398 }
399 }
400 }
401
402 if lit_pos < literals.len() {
404 self.output.extend_from_slice(&literals[lit_pos..]);
405 }
406
407 Ok(())
408 }
409
410 pub fn reset(&mut self) {
412 self.output.clear();
413 self.sequences_decoder.reset();
414 }
415}
416
417impl Default for ZstdDecoder {
418 fn default() -> Self {
419 Self::new()
420 }
421}
422
423pub fn decompress(data: &[u8]) -> Result<Vec<u8>> {
425 let mut decoder = ZstdDecoder::new();
426 decoder.decode_frame(data)
427}
428
429pub fn decompress_with_dict(data: &[u8], dict: &[u8]) -> Result<Vec<u8>> {
431 let mut decoder = ZstdDecoder::new();
432 decoder.set_dictionary(dict);
433 decoder.decode_frame(data)
434}
435
436pub fn decompress_frame(data: &[u8]) -> Result<(Vec<u8>, usize)> {
442 let mut decoder = ZstdDecoder::new();
443 decompress_frame_with_decoder(data, &mut decoder)
444}
445
446pub fn decompress_multi_frame(data: &[u8]) -> Result<Vec<u8>> {
452 let mut output = Vec::new();
453 let mut pos = 0;
454
455 while pos < data.len() {
456 if data.len() - pos < 4 {
458 break;
459 }
460 let magic = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
461
462 if magic == 0xFD2FB528 {
463 let (decompressed, consumed) = decompress_frame(&data[pos..])?;
465 output.extend_from_slice(&decompressed);
466 pos += consumed;
467 } else if (crate::SKIPPABLE_MAGIC_LOW..=crate::SKIPPABLE_MAGIC_HIGH).contains(&magic) {
468 if data.len() - pos < 8 {
470 break;
471 }
472 let skip_size =
473 u32::from_le_bytes([data[pos + 4], data[pos + 5], data[pos + 6], data[pos + 7]])
474 as usize;
475 pos += 8 + skip_size;
476 } else {
477 break;
479 }
480 }
481
482 Ok(output)
483}
484
485pub fn decompress_multi_frame_with_dict(data: &[u8], dict: &[u8]) -> Result<Vec<u8>> {
496 let mut output = Vec::new();
497 let mut pos = 0;
498
499 while pos < data.len() {
500 if data.len() - pos < 4 {
502 break;
503 }
504 let magic = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
505
506 if magic == 0xFD2FB528 {
507 let mut decoder = ZstdDecoder::new();
511 decoder.set_dictionary(dict);
512 let (decompressed, consumed) =
513 decompress_frame_with_decoder(&data[pos..], &mut decoder)?;
514 output.extend_from_slice(&decompressed);
515 pos += consumed;
516 } else if (crate::SKIPPABLE_MAGIC_LOW..=crate::SKIPPABLE_MAGIC_HIGH).contains(&magic) {
517 if data.len() - pos < 8 {
519 break;
520 }
521 let skip_size =
522 u32::from_le_bytes([data[pos + 4], data[pos + 5], data[pos + 6], data[pos + 7]])
523 as usize;
524 pos += 8 + skip_size;
525 } else {
526 break;
528 }
529 }
530
531 Ok(output)
532}
533
534fn decompress_frame_with_decoder(
540 data: &[u8],
541 decoder: &mut ZstdDecoder,
542) -> Result<(Vec<u8>, usize)> {
543 let header = parse_frame_header(data)?;
544 decoder.window_size = header.window_size;
545
546 if let Some(size) = header.content_size {
547 decoder.output.reserve(size as usize);
548 }
549
550 let mut pos = header.header_size;
551
552 loop {
553 if data.len() < pos + 3 {
554 return Err(OxiArcError::CorruptedData {
555 offset: pos as u64,
556 message: "truncated block header".to_string(),
557 });
558 }
559
560 let block_header = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], 0]);
561 pos += 3;
562
563 let last_block = (block_header & 1) != 0;
564 let block_type = crate::BlockType::from_bits(((block_header >> 1) & 0x03) as u8)?;
565 let block_size = ((block_header >> 3) & 0x1FFFFF) as usize;
566
567 if block_size > crate::MAX_BLOCK_SIZE {
568 return Err(OxiArcError::CorruptedData {
569 offset: pos as u64,
570 message: format!("block size {} exceeds maximum", block_size),
571 });
572 }
573
574 let compressed_size = match block_type {
575 crate::BlockType::Rle => 1,
576 _ => block_size,
577 };
578
579 if data.len() < pos + compressed_size {
580 return Err(OxiArcError::CorruptedData {
581 offset: pos as u64,
582 message: "truncated block data".to_string(),
583 });
584 }
585
586 let block_data = &data[pos..pos + compressed_size];
587 pos += compressed_size;
588
589 match block_type {
590 crate::BlockType::Raw => {
591 decoder.output.extend_from_slice(block_data);
592 }
593 crate::BlockType::Rle => {
594 decoder
595 .output
596 .extend(std::iter::repeat_n(block_data[0], block_size));
597 }
598 crate::BlockType::Compressed => {
599 decoder.decode_compressed_block(block_data)?;
600 }
601 crate::BlockType::Reserved => {
602 return Err(OxiArcError::CorruptedData {
603 offset: pos as u64,
604 message: "reserved block type".to_string(),
605 });
606 }
607 }
608
609 if last_block {
610 break;
611 }
612 }
613
614 if header.has_checksum {
616 if data.len() < pos + 4 {
617 return Err(OxiArcError::CorruptedData {
618 offset: pos as u64,
619 message: "missing content checksum".to_string(),
620 });
621 }
622
623 let expected = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
624 let computed = xxhash64_checksum(&decoder.output);
625
626 if expected != computed {
627 return Err(OxiArcError::CrcMismatch { expected, computed });
628 }
629 pos += 4;
630 }
631
632 if let Some(expected_size) = header.content_size {
634 if decoder.output.len() as u64 != expected_size {
635 return Err(OxiArcError::CorruptedData {
636 offset: 0,
637 message: format!(
638 "content size mismatch: expected {}, got {}",
639 expected_size,
640 decoder.output.len()
641 ),
642 });
643 }
644 }
645
646 let decompressed = std::mem::take(&mut decoder.output);
647 Ok((decompressed, pos))
648}
649
650pub fn write_skippable_frame(user_data: &[u8], magic_nibble: u8) -> Vec<u8> {
656 let magic = crate::SKIPPABLE_MAGIC_LOW | (magic_nibble & 0xF) as u32;
657 let mut out = Vec::with_capacity(8 + user_data.len());
658 out.extend_from_slice(&magic.to_le_bytes());
659 out.extend_from_slice(&(user_data.len() as u32).to_le_bytes());
660 out.extend_from_slice(user_data);
661 out
662}
663
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the Zstandard magic number followed by `descriptor_and_fields`.
    fn frame_prefix(descriptor_and_fields: &[u8]) -> Vec<u8> {
        let mut bytes = ZSTD_MAGIC.to_vec();
        bytes.extend_from_slice(descriptor_and_fields);
        bytes
    }

    #[test]
    fn test_parse_frame_header_minimal() {
        // Descriptor 0x20: single segment, no checksum, no dictionary ID;
        // the one-byte content size is 5.
        let header = parse_frame_header(&frame_prefix(&[0x20, 5])).expect("operation failed");

        assert_eq!(header.content_size, Some(5));
        assert!(!header.has_checksum);
        assert!(header.dict_id.is_none());
    }

    #[test]
    fn test_parse_frame_header_with_checksum() {
        // Descriptor 0x24: single segment plus the content-checksum flag.
        let header = parse_frame_header(&frame_prefix(&[0x24, 10])).expect("operation failed");

        assert!(header.has_checksum);
        assert_eq!(header.content_size, Some(10));
    }

    #[test]
    fn test_invalid_magic() {
        assert!(parse_frame_header(&[0u8; 5]).is_err());
    }

    #[test]
    fn test_decoder_creation() {
        assert_eq!(ZstdDecoder::new().window_size, MAX_WINDOW_SIZE);
    }

    #[test]
    fn test_block_type_parsing() {
        let valid = [
            (0, BlockType::Raw),
            (1, BlockType::Rle),
            (2, BlockType::Compressed),
        ];
        for (bits, expected) in valid {
            assert_eq!(BlockType::from_bits(bits).expect("operation failed"), expected);
        }
        assert!(BlockType::from_bits(3).is_err());
    }
}