Skip to main content

oxiarc_zstd/
frame.rs

1//! Zstandard frame parsing and decompression.
2//!
3//! Handles the top-level frame format including header, blocks, and checksum.
4
5use crate::literals::LiteralsDecoder;
6use crate::sequences::{Sequence, SequencesDecoder};
7use crate::xxhash::xxhash64_checksum;
8use crate::{BlockType, MAX_BLOCK_SIZE, MAX_WINDOW_SIZE, ZSTD_MAGIC};
9use oxiarc_core::error::{OxiArcError, Result};
10
// Frame header descriptor flags.
//
// Bit layout of the descriptor byte as used by `parse_frame_header`:
// bits 7-6 = content size flag, bit 5 = single segment,
// bit 2 = content checksum, bits 1-0 = dictionary ID flag.

/// Set when the frame is a single segment (no window descriptor byte).
const FHD_SINGLE_SEGMENT: u8 = 1 << 5;
/// Set when a 4-byte content checksum follows the last block.
const FHD_CONTENT_CHECKSUM: u8 = 1 << 2;
/// Selects the width of the dictionary ID field.
const FHD_DICT_ID_FLAG_MASK: u8 = 0b0000_0011;
/// Selects the width of the frame content size field.
const FHD_CONTENT_SIZE_FLAG_MASK: u8 = 0b1100_0000;
16
/// Zstandard frame header.
///
/// Produced by `parse_frame_header`; every field is derived from the bytes at
/// the start of a frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FrameHeader {
    /// Window size for decompression buffer.
    ///
    /// Taken from the window descriptor byte, or from the content size when
    /// the frame is a single segment. In both cases the value is clamped to
    /// `MAX_WINDOW_SIZE`.
    pub window_size: usize,
    /// Uncompressed content size (if known).
    pub content_size: Option<u64>,
    /// Dictionary ID (if present).
    // The field is parsed and exposed, but nothing in this module outside the
    // tests reads it, hence the lint allowance.
    #[allow(dead_code)]
    pub dict_id: Option<u32>,
    /// Whether content checksum is present.
    pub has_checksum: bool,
    /// Header size in bytes, counted from the start of the frame (magic
    /// number included); the first block header starts at this offset.
    pub header_size: usize,
}
32
33/// Parse frame header.
34pub fn parse_frame_header(data: &[u8]) -> Result<FrameHeader> {
35    if data.len() < 5 {
36        return Err(OxiArcError::CorruptedData {
37            offset: 0,
38            message: "truncated frame header".to_string(),
39        });
40    }
41
42    // Check magic
43    if data[0..4] != ZSTD_MAGIC {
44        return Err(OxiArcError::invalid_magic(ZSTD_MAGIC, &data[0..4]));
45    }
46
47    let descriptor = data[4];
48    let single_segment = (descriptor & FHD_SINGLE_SEGMENT) != 0;
49    let has_checksum = (descriptor & FHD_CONTENT_CHECKSUM) != 0;
50    let dict_id_flag = descriptor & FHD_DICT_ID_FLAG_MASK;
51    let content_size_flag = (descriptor & FHD_CONTENT_SIZE_FLAG_MASK) >> 6;
52
53    let mut pos = 5;
54
55    // Window descriptor (absent if single segment)
56    let window_size = if single_segment {
57        0 // Will be determined from content size
58    } else {
59        if data.len() <= pos {
60            return Err(OxiArcError::CorruptedData {
61                offset: pos as u64,
62                message: "missing window descriptor".to_string(),
63            });
64        }
65        let wd = data[pos];
66        pos += 1;
67
68        let exponent = (wd >> 3) as u32;
69        let mantissa = (wd & 0x07) as u32;
70        let base = 1u64 << (10 + exponent);
71        let window = base + (base >> 3) * mantissa as u64;
72        window.min(MAX_WINDOW_SIZE as u64) as usize
73    };
74
75    // Dictionary ID
76    let dict_id = match dict_id_flag {
77        0 => None,
78        1 => {
79            if data.len() <= pos {
80                return Err(OxiArcError::CorruptedData {
81                    offset: pos as u64,
82                    message: "missing dictionary ID".to_string(),
83                });
84            }
85            let id = data[pos] as u32;
86            pos += 1;
87            Some(id)
88        }
89        2 => {
90            if data.len() < pos + 2 {
91                return Err(OxiArcError::CorruptedData {
92                    offset: pos as u64,
93                    message: "truncated dictionary ID".to_string(),
94                });
95            }
96            let id = u16::from_le_bytes([data[pos], data[pos + 1]]) as u32;
97            pos += 2;
98            Some(id)
99        }
100        3 => {
101            if data.len() < pos + 4 {
102                return Err(OxiArcError::CorruptedData {
103                    offset: pos as u64,
104                    message: "truncated dictionary ID".to_string(),
105                });
106            }
107            let id = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
108            pos += 4;
109            Some(id)
110        }
111        _ => unreachable!(),
112    };
113
114    // Content size
115    let content_size = if single_segment || content_size_flag != 0 {
116        let size_bytes = match content_size_flag {
117            0 => 1, // Single segment implies 1 byte
118            1 => 2,
119            2 => 4,
120            3 => 8,
121            _ => unreachable!(),
122        };
123
124        if data.len() < pos + size_bytes {
125            return Err(OxiArcError::CorruptedData {
126                offset: pos as u64,
127                message: "truncated content size".to_string(),
128            });
129        }
130
131        let size = match size_bytes {
132            1 => data[pos] as u64,
133            2 => {
134                let s = u16::from_le_bytes([data[pos], data[pos + 1]]) as u64;
135                s + 256 // Add 256 for 2-byte size
136            }
137            4 => {
138                u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]) as u64
139            }
140            8 => u64::from_le_bytes([
141                data[pos],
142                data[pos + 1],
143                data[pos + 2],
144                data[pos + 3],
145                data[pos + 4],
146                data[pos + 5],
147                data[pos + 6],
148                data[pos + 7],
149            ]),
150            _ => unreachable!(),
151        };
152        pos += size_bytes;
153        Some(size)
154    } else {
155        None
156    };
157
158    // Adjust window size for single segment
159    let window_size = if single_segment {
160        content_size
161            .unwrap_or(MAX_WINDOW_SIZE as u64)
162            .min(MAX_WINDOW_SIZE as u64) as usize
163    } else {
164        window_size
165    };
166
167    Ok(FrameHeader {
168        window_size,
169        content_size,
170        dict_id,
171        has_checksum,
172        header_size: pos,
173    })
174}
175
176/// Zstandard decoder.
177pub struct ZstdDecoder {
178    /// Literals decoder.
179    literals_decoder: LiteralsDecoder,
180    /// Sequences decoder.
181    sequences_decoder: SequencesDecoder,
182    /// Output buffer (sliding window).
183    output: Vec<u8>,
184    /// Window size.
185    window_size: usize,
186    /// Optional dictionary for decompression.
187    dictionary: Option<Vec<u8>>,
188}
189
190impl ZstdDecoder {
191    /// Create a new decoder.
192    pub fn new() -> Self {
193        Self {
194            literals_decoder: LiteralsDecoder::new(),
195            sequences_decoder: SequencesDecoder::new(),
196            output: Vec::new(),
197            window_size: MAX_WINDOW_SIZE,
198            dictionary: None,
199        }
200    }
201
202    /// Set a dictionary for decompression.
203    ///
204    /// Must match the dictionary used during compression.
205    pub fn set_dictionary(&mut self, dict: &[u8]) {
206        if dict.is_empty() {
207            self.dictionary = None;
208        } else {
209            self.dictionary = Some(dict.to_vec());
210        }
211    }
212
213    /// Decode a complete Zstandard frame.
214    pub fn decode_frame(&mut self, data: &[u8]) -> Result<Vec<u8>> {
215        let header = parse_frame_header(data)?;
216        self.window_size = header.window_size;
217
218        // Reserve space for output
219        if let Some(size) = header.content_size {
220            self.output.reserve(size as usize);
221        }
222
223        let mut pos = header.header_size;
224
225        // Decode blocks
226        loop {
227            if data.len() < pos + 3 {
228                return Err(OxiArcError::CorruptedData {
229                    offset: pos as u64,
230                    message: "truncated block header".to_string(),
231                });
232            }
233
234            // Read block header (3 bytes, little-endian)
235            let block_header = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], 0]);
236            pos += 3;
237
238            let last_block = (block_header & 1) != 0;
239            let block_type = BlockType::from_bits(((block_header >> 1) & 0x03) as u8)?;
240            let block_size = ((block_header >> 3) & 0x1FFFFF) as usize;
241
242            if block_size > MAX_BLOCK_SIZE {
243                return Err(OxiArcError::CorruptedData {
244                    offset: pos as u64,
245                    message: format!("block size {} exceeds maximum", block_size),
246                });
247            }
248
249            // For RLE blocks, block_size is the regenerated size and only 1 byte of data follows
250            let compressed_size = match block_type {
251                BlockType::Rle => 1,
252                _ => block_size,
253            };
254
255            if data.len() < pos + compressed_size {
256                return Err(OxiArcError::CorruptedData {
257                    offset: pos as u64,
258                    message: "truncated block data".to_string(),
259                });
260            }
261
262            let block_data = &data[pos..pos + compressed_size];
263            pos += compressed_size;
264
265            match block_type {
266                BlockType::Raw => {
267                    self.output.extend_from_slice(block_data);
268                }
269                BlockType::Rle => {
270                    // block_size is the regenerated size for RLE
271                    // The actual data is just 1 byte
272                    self.output
273                        .extend(std::iter::repeat_n(block_data[0], block_size));
274                }
275                BlockType::Compressed => {
276                    self.decode_compressed_block(block_data)?;
277                }
278                BlockType::Reserved => {
279                    return Err(OxiArcError::CorruptedData {
280                        offset: pos as u64,
281                        message: "reserved block type".to_string(),
282                    });
283                }
284            }
285
286            if last_block {
287                break;
288            }
289        }
290
291        // Verify checksum if present
292        if header.has_checksum {
293            if data.len() < pos + 4 {
294                return Err(OxiArcError::CorruptedData {
295                    offset: pos as u64,
296                    message: "missing content checksum".to_string(),
297                });
298            }
299
300            let expected =
301                u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
302            let computed = xxhash64_checksum(&self.output);
303
304            if expected != computed {
305                return Err(OxiArcError::CrcMismatch { expected, computed });
306            }
307        }
308
309        // Verify content size if known
310        if let Some(expected_size) = header.content_size {
311            if self.output.len() as u64 != expected_size {
312                return Err(OxiArcError::CorruptedData {
313                    offset: 0,
314                    message: format!(
315                        "content size mismatch: expected {}, got {}",
316                        expected_size,
317                        self.output.len()
318                    ),
319                });
320            }
321        }
322
323        Ok(std::mem::take(&mut self.output))
324    }
325
326    /// Decode a compressed block.
327    fn decode_compressed_block(&mut self, data: &[u8]) -> Result<()> {
328        // Decode literals
329        let (literals, literals_size) = self.literals_decoder.decode(data)?;
330
331        // Decode sequences
332        let sequences_data = &data[literals_size..];
333        let (sequences, _) = self.sequences_decoder.decode(sequences_data)?;
334
335        // Execute sequences
336        self.execute_sequences(&literals, &sequences)?;
337
338        Ok(())
339    }
340
341    /// Execute sequences to produce output.
342    fn execute_sequences(&mut self, literals: &[u8], sequences: &[Sequence]) -> Result<()> {
343        let mut lit_pos = 0;
344        let dict = self.dictionary.as_deref().unwrap_or(&[]);
345        let dict_len = dict.len();
346
347        for seq in sequences {
348            // Copy literals
349            if seq.literal_length > 0 {
350                if lit_pos + seq.literal_length > literals.len() {
351                    return Err(OxiArcError::CorruptedData {
352                        offset: 0,
353                        message: "literal length exceeds available literals".to_string(),
354                    });
355                }
356                self.output
357                    .extend_from_slice(&literals[lit_pos..lit_pos + seq.literal_length]);
358                lit_pos += seq.literal_length;
359            }
360
361            // Copy match
362            if seq.match_length > 0 {
363                let max_offset = self.output.len() + dict_len;
364                if seq.offset == 0 || seq.offset > max_offset {
365                    return Err(OxiArcError::CorruptedData {
366                        offset: 0,
367                        message: format!(
368                            "invalid offset {} (output length {}, dict length {})",
369                            seq.offset,
370                            self.output.len(),
371                            dict_len
372                        ),
373                    });
374                }
375
376                if seq.offset <= self.output.len() {
377                    // Normal case: offset within output buffer
378                    let start = self.output.len() - seq.offset;
379                    for i in 0..seq.match_length {
380                        let byte = self.output[start + (i % seq.offset)];
381                        self.output.push(byte);
382                    }
383                } else {
384                    // Dictionary reference: offset extends into dictionary
385                    // The logical buffer is [dict | output], and we go back `offset` from end
386                    let dict_and_output_len = dict_len + self.output.len();
387                    let start_in_combined = dict_and_output_len - seq.offset;
388
389                    for i in 0..seq.match_length {
390                        let pos_in_combined = start_in_combined + (i % seq.offset);
391                        let byte = if pos_in_combined < dict_len {
392                            dict[pos_in_combined]
393                        } else {
394                            self.output[pos_in_combined - dict_len]
395                        };
396                        self.output.push(byte);
397                    }
398                }
399            }
400        }
401
402        // Copy remaining literals
403        if lit_pos < literals.len() {
404            self.output.extend_from_slice(&literals[lit_pos..]);
405        }
406
407        Ok(())
408    }
409
410    /// Reset decoder state for a new frame.
411    pub fn reset(&mut self) {
412        self.output.clear();
413        self.sequences_decoder.reset();
414    }
415}
416
417impl Default for ZstdDecoder {
418    fn default() -> Self {
419        Self::new()
420    }
421}
422
423/// Decompress Zstandard data.
424pub fn decompress(data: &[u8]) -> Result<Vec<u8>> {
425    let mut decoder = ZstdDecoder::new();
426    decoder.decode_frame(data)
427}
428
429/// Decompress Zstandard data using a dictionary.
430pub fn decompress_with_dict(data: &[u8], dict: &[u8]) -> Result<Vec<u8>> {
431    let mut decoder = ZstdDecoder::new();
432    decoder.set_dictionary(dict);
433    decoder.decode_frame(data)
434}
435
436/// Decompress a single Zstandard frame, returning the decompressed data and
437/// the number of bytes consumed from `data`.
438///
439/// This allows callers to locate the end of one frame in a concatenated
440/// stream and proceed to the next.
441pub fn decompress_frame(data: &[u8]) -> Result<(Vec<u8>, usize)> {
442    let header = parse_frame_header(data)?;
443    let mut decoder = ZstdDecoder::new();
444    decoder.window_size = header.window_size;
445
446    if let Some(size) = header.content_size {
447        decoder.output.reserve(size as usize);
448    }
449
450    let mut pos = header.header_size;
451
452    // Decode blocks, tracking how many bytes we consume.
453    loop {
454        if data.len() < pos + 3 {
455            return Err(OxiArcError::CorruptedData {
456                offset: pos as u64,
457                message: "truncated block header".to_string(),
458            });
459        }
460
461        let block_header = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], 0]);
462        pos += 3;
463
464        let last_block = (block_header & 1) != 0;
465        let block_type = crate::BlockType::from_bits(((block_header >> 1) & 0x03) as u8)?;
466        let block_size = ((block_header >> 3) & 0x1FFFFF) as usize;
467
468        if block_size > crate::MAX_BLOCK_SIZE {
469            return Err(OxiArcError::CorruptedData {
470                offset: pos as u64,
471                message: format!("block size {} exceeds maximum", block_size),
472            });
473        }
474
475        let compressed_size = match block_type {
476            crate::BlockType::Rle => 1,
477            _ => block_size,
478        };
479
480        if data.len() < pos + compressed_size {
481            return Err(OxiArcError::CorruptedData {
482                offset: pos as u64,
483                message: "truncated block data".to_string(),
484            });
485        }
486
487        let block_data = &data[pos..pos + compressed_size];
488        pos += compressed_size;
489
490        match block_type {
491            crate::BlockType::Raw => {
492                decoder.output.extend_from_slice(block_data);
493            }
494            crate::BlockType::Rle => {
495                decoder
496                    .output
497                    .extend(std::iter::repeat_n(block_data[0], block_size));
498            }
499            crate::BlockType::Compressed => {
500                decoder.decode_compressed_block(block_data)?;
501            }
502            crate::BlockType::Reserved => {
503                return Err(OxiArcError::CorruptedData {
504                    offset: pos as u64,
505                    message: "reserved block type".to_string(),
506                });
507            }
508        }
509
510        if last_block {
511            break;
512        }
513    }
514
515    // Verify checksum if present.
516    if header.has_checksum {
517        if data.len() < pos + 4 {
518            return Err(OxiArcError::CorruptedData {
519                offset: pos as u64,
520                message: "missing content checksum".to_string(),
521            });
522        }
523
524        let expected = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
525        let computed = xxhash64_checksum(&decoder.output);
526
527        if expected != computed {
528            return Err(OxiArcError::CrcMismatch { expected, computed });
529        }
530        pos += 4;
531    }
532
533    // Verify content size if known.
534    if let Some(expected_size) = header.content_size {
535        if decoder.output.len() as u64 != expected_size {
536            return Err(OxiArcError::CorruptedData {
537                offset: 0,
538                message: format!(
539                    "content size mismatch: expected {}, got {}",
540                    expected_size,
541                    decoder.output.len()
542                ),
543            });
544        }
545    }
546
547    let decompressed = std::mem::take(&mut decoder.output);
548    Ok((decompressed, pos))
549}
550
551/// Decompress one or more concatenated Zstandard frames.
552///
553/// Skippable frames (magic `0x184D2A50`–`0x184D2A5F`) are silently skipped.
554/// Unknown magic values cause iteration to stop and the accumulated output is
555/// returned (trailing garbage is tolerated).
556pub fn decompress_multi_frame(data: &[u8]) -> Result<Vec<u8>> {
557    let mut output = Vec::new();
558    let mut pos = 0;
559
560    while pos < data.len() {
561        // Need at least 4 bytes to read a magic number.
562        if data.len() - pos < 4 {
563            break;
564        }
565        let magic = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
566
567        if magic == 0xFD2FB528 {
568            // Normal Zstd frame.
569            let (decompressed, consumed) = decompress_frame(&data[pos..])?;
570            output.extend_from_slice(&decompressed);
571            pos += consumed;
572        } else if (crate::SKIPPABLE_MAGIC_LOW..=crate::SKIPPABLE_MAGIC_HIGH).contains(&magic) {
573            // Skippable frame: 4 bytes magic + 4 bytes size + <size> bytes data.
574            if data.len() - pos < 8 {
575                break;
576            }
577            let skip_size =
578                u32::from_le_bytes([data[pos + 4], data[pos + 5], data[pos + 6], data[pos + 7]])
579                    as usize;
580            pos += 8 + skip_size;
581        } else {
582            // Unknown magic — stop gracefully.
583            break;
584        }
585    }
586
587    Ok(output)
588}
589
590/// Write a skippable Zstandard frame containing arbitrary user data.
591///
592/// `magic_nibble` selects which skippable magic to use; it is masked to the
593/// lower 4 bits so the resulting magic is always in the range
594/// `0x184D2A50`–`0x184D2A5F`.
595pub fn write_skippable_frame(user_data: &[u8], magic_nibble: u8) -> Vec<u8> {
596    let magic = crate::SKIPPABLE_MAGIC_LOW | (magic_nibble & 0xF) as u32;
597    let mut out = Vec::with_capacity(8 + user_data.len());
598    out.extend_from_slice(&magic.to_le_bytes());
599    out.extend_from_slice(&(user_data.len() as u32).to_le_bytes());
600    out.extend_from_slice(user_data);
601    out
602}
603
#[cfg(test)]
mod tests {
    use super::*;

    /// Build header bytes: magic, the descriptor byte, then `tail`.
    fn header_bytes(descriptor: u8, tail: &[u8]) -> Vec<u8> {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&ZSTD_MAGIC);
        bytes.push(descriptor);
        bytes.extend_from_slice(tail);
        bytes
    }

    #[test]
    fn test_parse_frame_header_minimal() {
        // Minimal frame: magic + descriptor (single segment, 1 byte content size)
        let bytes = header_bytes(0x20, &[5]);

        let parsed = parse_frame_header(&bytes).unwrap();

        assert_eq!(parsed.content_size, Some(5));
        assert!(!parsed.has_checksum);
        assert!(parsed.dict_id.is_none());
    }

    #[test]
    fn test_parse_frame_header_with_checksum() {
        // Single segment + checksum, content size = 10
        let bytes = header_bytes(0x24, &[10]);

        let parsed = parse_frame_header(&bytes).unwrap();

        assert!(parsed.has_checksum);
        assert_eq!(parsed.content_size, Some(10));
    }

    #[test]
    fn test_invalid_magic() {
        // Five zero bytes: long enough to be inspected, wrong magic.
        let bytes = [0x00u8; 5];
        assert!(parse_frame_header(&bytes).is_err());
    }

    #[test]
    fn test_decoder_creation() {
        let fresh = ZstdDecoder::new();
        assert_eq!(fresh.window_size, MAX_WINDOW_SIZE);
    }

    #[test]
    fn test_block_type_parsing() {
        let valid = [
            (0, BlockType::Raw),
            (1, BlockType::Rle),
            (2, BlockType::Compressed),
        ];
        for (bits, expected) in valid {
            assert_eq!(BlockType::from_bits(bits).unwrap(), expected);
        }
        assert!(BlockType::from_bits(3).is_err());
    }
}