1use crate::literals::LiteralsDecoder;
6use crate::sequences::{Sequence, SequencesDecoder};
7use crate::xxhash::xxhash64_checksum;
8use crate::{BlockType, MAX_BLOCK_SIZE, MAX_WINDOW_SIZE, ZSTD_MAGIC};
9use oxiarc_core::error::{OxiArcError, Result};
10
/// Frame header descriptor bit 5 (single-segment flag). When set, the parser
/// reads no window descriptor and always reads a content size field.
const FHD_SINGLE_SEGMENT: u8 = 0x20;
/// Frame header descriptor bit 2. When set, a 4-byte content checksum is
/// expected after the last block of the frame.
const FHD_CONTENT_CHECKSUM: u8 = 0x04;
/// Frame header descriptor bits 0-1. The value selects the width of the
/// dictionary ID field: 0, 1, 2 or 4 bytes.
const FHD_DICT_ID_FLAG_MASK: u8 = 0x03;
/// Frame header descriptor bits 6-7. After shifting down by 6, the value
/// selects the width of the frame content size field.
const FHD_CONTENT_SIZE_FLAG_MASK: u8 = 0xC0;
16
/// Decoded fields of a Zstandard frame header, as returned by
/// `parse_frame_header`.
#[derive(Debug, Clone)]
pub struct FrameHeader {
    /// Window size in bytes, clamped to `MAX_WINDOW_SIZE`. For single-segment
    /// frames it is derived from the content size instead of a window
    /// descriptor.
    pub window_size: usize,
    /// Declared size of the decompressed content, when the header carries one.
    pub content_size: Option<u64>,
    /// Dictionary ID stored in the header, if any.
    // NOTE(review): the `allow` suggests nothing outside the tests reads this
    // field yet — confirm before removing it.
    #[allow(dead_code)]
    pub dict_id: Option<u32>,
    /// Whether a content checksum follows the last block.
    pub has_checksum: bool,
    /// Total header length in bytes, including the 4-byte magic number; the
    /// first block header starts at this offset.
    pub header_size: usize,
}
32
33pub fn parse_frame_header(data: &[u8]) -> Result<FrameHeader> {
35 if data.len() < 5 {
36 return Err(OxiArcError::CorruptedData {
37 offset: 0,
38 message: "truncated frame header".to_string(),
39 });
40 }
41
42 if data[0..4] != ZSTD_MAGIC {
44 return Err(OxiArcError::invalid_magic(ZSTD_MAGIC, &data[0..4]));
45 }
46
47 let descriptor = data[4];
48 let single_segment = (descriptor & FHD_SINGLE_SEGMENT) != 0;
49 let has_checksum = (descriptor & FHD_CONTENT_CHECKSUM) != 0;
50 let dict_id_flag = descriptor & FHD_DICT_ID_FLAG_MASK;
51 let content_size_flag = (descriptor & FHD_CONTENT_SIZE_FLAG_MASK) >> 6;
52
53 let mut pos = 5;
54
55 let window_size = if single_segment {
57 0 } else {
59 if data.len() <= pos {
60 return Err(OxiArcError::CorruptedData {
61 offset: pos as u64,
62 message: "missing window descriptor".to_string(),
63 });
64 }
65 let wd = data[pos];
66 pos += 1;
67
68 let exponent = (wd >> 3) as u32;
69 let mantissa = (wd & 0x07) as u32;
70 let base = 1u64 << (10 + exponent);
71 let window = base + (base >> 3) * mantissa as u64;
72 window.min(MAX_WINDOW_SIZE as u64) as usize
73 };
74
75 let dict_id = match dict_id_flag {
77 0 => None,
78 1 => {
79 if data.len() <= pos {
80 return Err(OxiArcError::CorruptedData {
81 offset: pos as u64,
82 message: "missing dictionary ID".to_string(),
83 });
84 }
85 let id = data[pos] as u32;
86 pos += 1;
87 Some(id)
88 }
89 2 => {
90 if data.len() < pos + 2 {
91 return Err(OxiArcError::CorruptedData {
92 offset: pos as u64,
93 message: "truncated dictionary ID".to_string(),
94 });
95 }
96 let id = u16::from_le_bytes([data[pos], data[pos + 1]]) as u32;
97 pos += 2;
98 Some(id)
99 }
100 3 => {
101 if data.len() < pos + 4 {
102 return Err(OxiArcError::CorruptedData {
103 offset: pos as u64,
104 message: "truncated dictionary ID".to_string(),
105 });
106 }
107 let id = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
108 pos += 4;
109 Some(id)
110 }
111 _ => unreachable!(),
112 };
113
114 let content_size = if single_segment || content_size_flag != 0 {
116 let size_bytes = match content_size_flag {
117 0 => 1, 1 => 2,
119 2 => 4,
120 3 => 8,
121 _ => unreachable!(),
122 };
123
124 if data.len() < pos + size_bytes {
125 return Err(OxiArcError::CorruptedData {
126 offset: pos as u64,
127 message: "truncated content size".to_string(),
128 });
129 }
130
131 let size = match size_bytes {
132 1 => data[pos] as u64,
133 2 => {
134 let s = u16::from_le_bytes([data[pos], data[pos + 1]]) as u64;
135 s + 256 }
137 4 => {
138 u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]) as u64
139 }
140 8 => u64::from_le_bytes([
141 data[pos],
142 data[pos + 1],
143 data[pos + 2],
144 data[pos + 3],
145 data[pos + 4],
146 data[pos + 5],
147 data[pos + 6],
148 data[pos + 7],
149 ]),
150 _ => unreachable!(),
151 };
152 pos += size_bytes;
153 Some(size)
154 } else {
155 None
156 };
157
158 let window_size = if single_segment {
160 content_size
161 .unwrap_or(MAX_WINDOW_SIZE as u64)
162 .min(MAX_WINDOW_SIZE as u64) as usize
163 } else {
164 window_size
165 };
166
167 Ok(FrameHeader {
168 window_size,
169 content_size,
170 dict_id,
171 has_checksum,
172 header_size: pos,
173 })
174}
175
/// Decoder for Zstandard frames that accumulates decoded bytes in an internal
/// buffer.
pub struct ZstdDecoder {
    /// Decoder for the literals section of compressed blocks.
    literals_decoder: LiteralsDecoder,
    /// Decoder for the sequences section of compressed blocks.
    sequences_decoder: SequencesDecoder,
    /// Bytes decoded so far for the frame in progress; handed to the caller
    /// when a frame completes.
    output: Vec<u8>,
    /// Window size taken from the most recently parsed frame header.
    // NOTE(review): the code visible in this file stores this value but never
    // consults it while decoding — confirm whether enforcement lives elsewhere.
    window_size: usize,
    /// Optional dictionary bytes that match offsets may reach into before the
    /// start of the output.
    dictionary: Option<Vec<u8>>,
}
189
190impl ZstdDecoder {
191 pub fn new() -> Self {
193 Self {
194 literals_decoder: LiteralsDecoder::new(),
195 sequences_decoder: SequencesDecoder::new(),
196 output: Vec::new(),
197 window_size: MAX_WINDOW_SIZE,
198 dictionary: None,
199 }
200 }
201
202 pub fn set_dictionary(&mut self, dict: &[u8]) {
206 if dict.is_empty() {
207 self.dictionary = None;
208 } else {
209 self.dictionary = Some(dict.to_vec());
210 }
211 }
212
213 pub fn decode_frame(&mut self, data: &[u8]) -> Result<Vec<u8>> {
215 let header = parse_frame_header(data)?;
216 self.window_size = header.window_size;
217
218 if let Some(size) = header.content_size {
220 self.output.reserve(size as usize);
221 }
222
223 let mut pos = header.header_size;
224
225 loop {
227 if data.len() < pos + 3 {
228 return Err(OxiArcError::CorruptedData {
229 offset: pos as u64,
230 message: "truncated block header".to_string(),
231 });
232 }
233
234 let block_header = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], 0]);
236 pos += 3;
237
238 let last_block = (block_header & 1) != 0;
239 let block_type = BlockType::from_bits(((block_header >> 1) & 0x03) as u8)?;
240 let block_size = ((block_header >> 3) & 0x1FFFFF) as usize;
241
242 if block_size > MAX_BLOCK_SIZE {
243 return Err(OxiArcError::CorruptedData {
244 offset: pos as u64,
245 message: format!("block size {} exceeds maximum", block_size),
246 });
247 }
248
249 let compressed_size = match block_type {
251 BlockType::Rle => 1,
252 _ => block_size,
253 };
254
255 if data.len() < pos + compressed_size {
256 return Err(OxiArcError::CorruptedData {
257 offset: pos as u64,
258 message: "truncated block data".to_string(),
259 });
260 }
261
262 let block_data = &data[pos..pos + compressed_size];
263 pos += compressed_size;
264
265 match block_type {
266 BlockType::Raw => {
267 self.output.extend_from_slice(block_data);
268 }
269 BlockType::Rle => {
270 self.output
273 .extend(std::iter::repeat_n(block_data[0], block_size));
274 }
275 BlockType::Compressed => {
276 self.decode_compressed_block(block_data)?;
277 }
278 BlockType::Reserved => {
279 return Err(OxiArcError::CorruptedData {
280 offset: pos as u64,
281 message: "reserved block type".to_string(),
282 });
283 }
284 }
285
286 if last_block {
287 break;
288 }
289 }
290
291 if header.has_checksum {
293 if data.len() < pos + 4 {
294 return Err(OxiArcError::CorruptedData {
295 offset: pos as u64,
296 message: "missing content checksum".to_string(),
297 });
298 }
299
300 let expected =
301 u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
302 let computed = xxhash64_checksum(&self.output);
303
304 if expected != computed {
305 return Err(OxiArcError::CrcMismatch { expected, computed });
306 }
307 }
308
309 if let Some(expected_size) = header.content_size {
311 if self.output.len() as u64 != expected_size {
312 return Err(OxiArcError::CorruptedData {
313 offset: 0,
314 message: format!(
315 "content size mismatch: expected {}, got {}",
316 expected_size,
317 self.output.len()
318 ),
319 });
320 }
321 }
322
323 Ok(std::mem::take(&mut self.output))
324 }
325
326 fn decode_compressed_block(&mut self, data: &[u8]) -> Result<()> {
328 let (literals, literals_size) = self.literals_decoder.decode(data)?;
330
331 let sequences_data = &data[literals_size..];
333 let (sequences, _) = self.sequences_decoder.decode(sequences_data)?;
334
335 self.execute_sequences(&literals, &sequences)?;
337
338 Ok(())
339 }
340
341 fn execute_sequences(&mut self, literals: &[u8], sequences: &[Sequence]) -> Result<()> {
343 let mut lit_pos = 0;
344 let dict = self.dictionary.as_deref().unwrap_or(&[]);
345 let dict_len = dict.len();
346
347 for seq in sequences {
348 if seq.literal_length > 0 {
350 if lit_pos + seq.literal_length > literals.len() {
351 return Err(OxiArcError::CorruptedData {
352 offset: 0,
353 message: "literal length exceeds available literals".to_string(),
354 });
355 }
356 self.output
357 .extend_from_slice(&literals[lit_pos..lit_pos + seq.literal_length]);
358 lit_pos += seq.literal_length;
359 }
360
361 if seq.match_length > 0 {
363 let max_offset = self.output.len() + dict_len;
364 if seq.offset == 0 || seq.offset > max_offset {
365 return Err(OxiArcError::CorruptedData {
366 offset: 0,
367 message: format!(
368 "invalid offset {} (output length {}, dict length {})",
369 seq.offset,
370 self.output.len(),
371 dict_len
372 ),
373 });
374 }
375
376 if seq.offset <= self.output.len() {
377 let start = self.output.len() - seq.offset;
379 for i in 0..seq.match_length {
380 let byte = self.output[start + (i % seq.offset)];
381 self.output.push(byte);
382 }
383 } else {
384 let dict_and_output_len = dict_len + self.output.len();
387 let start_in_combined = dict_and_output_len - seq.offset;
388
389 for i in 0..seq.match_length {
390 let pos_in_combined = start_in_combined + (i % seq.offset);
391 let byte = if pos_in_combined < dict_len {
392 dict[pos_in_combined]
393 } else {
394 self.output[pos_in_combined - dict_len]
395 };
396 self.output.push(byte);
397 }
398 }
399 }
400 }
401
402 if lit_pos < literals.len() {
404 self.output.extend_from_slice(&literals[lit_pos..]);
405 }
406
407 Ok(())
408 }
409
410 pub fn reset(&mut self) {
412 self.output.clear();
413 self.sequences_decoder.reset();
414 }
415}
416
417impl Default for ZstdDecoder {
418 fn default() -> Self {
419 Self::new()
420 }
421}
422
423pub fn decompress(data: &[u8]) -> Result<Vec<u8>> {
425 let mut decoder = ZstdDecoder::new();
426 decoder.decode_frame(data)
427}
428
429pub fn decompress_with_dict(data: &[u8], dict: &[u8]) -> Result<Vec<u8>> {
431 let mut decoder = ZstdDecoder::new();
432 decoder.set_dictionary(dict);
433 decoder.decode_frame(data)
434}
435
436pub fn decompress_frame(data: &[u8]) -> Result<(Vec<u8>, usize)> {
442 let header = parse_frame_header(data)?;
443 let mut decoder = ZstdDecoder::new();
444 decoder.window_size = header.window_size;
445
446 if let Some(size) = header.content_size {
447 decoder.output.reserve(size as usize);
448 }
449
450 let mut pos = header.header_size;
451
452 loop {
454 if data.len() < pos + 3 {
455 return Err(OxiArcError::CorruptedData {
456 offset: pos as u64,
457 message: "truncated block header".to_string(),
458 });
459 }
460
461 let block_header = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], 0]);
462 pos += 3;
463
464 let last_block = (block_header & 1) != 0;
465 let block_type = crate::BlockType::from_bits(((block_header >> 1) & 0x03) as u8)?;
466 let block_size = ((block_header >> 3) & 0x1FFFFF) as usize;
467
468 if block_size > crate::MAX_BLOCK_SIZE {
469 return Err(OxiArcError::CorruptedData {
470 offset: pos as u64,
471 message: format!("block size {} exceeds maximum", block_size),
472 });
473 }
474
475 let compressed_size = match block_type {
476 crate::BlockType::Rle => 1,
477 _ => block_size,
478 };
479
480 if data.len() < pos + compressed_size {
481 return Err(OxiArcError::CorruptedData {
482 offset: pos as u64,
483 message: "truncated block data".to_string(),
484 });
485 }
486
487 let block_data = &data[pos..pos + compressed_size];
488 pos += compressed_size;
489
490 match block_type {
491 crate::BlockType::Raw => {
492 decoder.output.extend_from_slice(block_data);
493 }
494 crate::BlockType::Rle => {
495 decoder
496 .output
497 .extend(std::iter::repeat_n(block_data[0], block_size));
498 }
499 crate::BlockType::Compressed => {
500 decoder.decode_compressed_block(block_data)?;
501 }
502 crate::BlockType::Reserved => {
503 return Err(OxiArcError::CorruptedData {
504 offset: pos as u64,
505 message: "reserved block type".to_string(),
506 });
507 }
508 }
509
510 if last_block {
511 break;
512 }
513 }
514
515 if header.has_checksum {
517 if data.len() < pos + 4 {
518 return Err(OxiArcError::CorruptedData {
519 offset: pos as u64,
520 message: "missing content checksum".to_string(),
521 });
522 }
523
524 let expected = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
525 let computed = xxhash64_checksum(&decoder.output);
526
527 if expected != computed {
528 return Err(OxiArcError::CrcMismatch { expected, computed });
529 }
530 pos += 4;
531 }
532
533 if let Some(expected_size) = header.content_size {
535 if decoder.output.len() as u64 != expected_size {
536 return Err(OxiArcError::CorruptedData {
537 offset: 0,
538 message: format!(
539 "content size mismatch: expected {}, got {}",
540 expected_size,
541 decoder.output.len()
542 ),
543 });
544 }
545 }
546
547 let decompressed = std::mem::take(&mut decoder.output);
548 Ok((decompressed, pos))
549}
550
551pub fn decompress_multi_frame(data: &[u8]) -> Result<Vec<u8>> {
557 let mut output = Vec::new();
558 let mut pos = 0;
559
560 while pos < data.len() {
561 if data.len() - pos < 4 {
563 break;
564 }
565 let magic = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
566
567 if magic == 0xFD2FB528 {
568 let (decompressed, consumed) = decompress_frame(&data[pos..])?;
570 output.extend_from_slice(&decompressed);
571 pos += consumed;
572 } else if (crate::SKIPPABLE_MAGIC_LOW..=crate::SKIPPABLE_MAGIC_HIGH).contains(&magic) {
573 if data.len() - pos < 8 {
575 break;
576 }
577 let skip_size =
578 u32::from_le_bytes([data[pos + 4], data[pos + 5], data[pos + 6], data[pos + 7]])
579 as usize;
580 pos += 8 + skip_size;
581 } else {
582 break;
584 }
585 }
586
587 Ok(output)
588}
589
590pub fn write_skippable_frame(user_data: &[u8], magic_nibble: u8) -> Vec<u8> {
596 let magic = crate::SKIPPABLE_MAGIC_LOW | (magic_nibble & 0xF) as u32;
597 let mut out = Vec::with_capacity(8 + user_data.len());
598 out.extend_from_slice(&magic.to_le_bytes());
599 out.extend_from_slice(&(user_data.len() as u32).to_le_bytes());
600 out.extend_from_slice(user_data);
601 out
602}
603
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds header bytes: the magic number, `descriptor`, then `fields`.
    fn header_bytes(descriptor: u8, fields: &[u8]) -> Vec<u8> {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&ZSTD_MAGIC);
        bytes.push(descriptor);
        bytes.extend_from_slice(fields);
        bytes
    }

    /// Single-segment frame with a one-byte content size and nothing else.
    #[test]
    fn test_parse_frame_header_minimal() {
        let parsed = parse_frame_header(&header_bytes(0x20, &[5])).unwrap();

        assert_eq!(parsed.content_size, Some(5));
        assert!(!parsed.has_checksum);
        assert!(parsed.dict_id.is_none());
    }

    /// Single-segment frame with the content checksum flag set.
    #[test]
    fn test_parse_frame_header_with_checksum() {
        let parsed = parse_frame_header(&header_bytes(0x24, &[10])).unwrap();

        assert!(parsed.has_checksum);
        assert_eq!(parsed.content_size, Some(10));
    }

    /// Five zero bytes do not start with the frame magic.
    #[test]
    fn test_invalid_magic() {
        let zeros = [0x00u8; 5];
        assert!(parse_frame_header(&zeros).is_err());
    }

    /// A new decoder starts with the maximum window size.
    #[test]
    fn test_decoder_creation() {
        let fresh = ZstdDecoder::new();
        assert_eq!(fresh.window_size, MAX_WINDOW_SIZE);
    }

    /// Block type bits 0-2 map to variants; 3 is rejected.
    #[test]
    fn test_block_type_parsing() {
        let accepted = [
            (0, BlockType::Raw),
            (1, BlockType::Rle),
            (2, BlockType::Compressed),
        ];
        for (bits, expected) in accepted {
            assert_eq!(BlockType::from_bits(bits).unwrap(), expected);
        }
        assert!(BlockType::from_bits(3).is_err());
    }
}