Skip to main content

haagenti_zstd/block/
literals.rs

1//! Literals section decoding.
2//!
3//! The literals section contains raw byte data that is copied to the output.
4
5use crate::fse::BitReader;
6use crate::huffman::{build_table_from_weights, parse_huffman_weights, HuffmanDecoder};
7use haagenti_core::{Error, Result};
8
/// Literals block type, as encoded in the two low bits of the first byte of
/// a literals section header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LiteralsBlockType {
    /// Raw literals - uncompressed bytes.
    Raw,
    /// RLE literals - single byte repeated.
    Rle,
    /// Huffman compressed literals with new tree.
    Compressed,
    /// Huffman compressed using previous tree.
    Treeless,
}

impl LiteralsBlockType {
    /// Parse block type from a 2-bit field.
    ///
    /// Only the two low bits of `field` are inspected; any higher bits are
    /// ignored. This makes the function total, so passing an unmasked header
    /// byte can never panic.
    pub fn from_field(field: u8) -> Self {
        match field & 0x03 {
            0 => LiteralsBlockType::Raw,
            1 => LiteralsBlockType::Rle,
            2 => LiteralsBlockType::Compressed,
            // The mask leaves 3 as the only remaining value.
            _ => LiteralsBlockType::Treeless,
        }
    }
}
34
35/// Parsed literals section.
36#[derive(Debug, Clone)]
37pub struct LiteralsSection {
38    /// Block type.
39    pub block_type: LiteralsBlockType,
40    /// Regenerated (uncompressed) size.
41    pub regenerated_size: usize,
42    /// Compressed size (for compressed modes).
43    pub compressed_size: usize,
44    /// The literal data.
45    data: Vec<u8>,
46}
47
48impl LiteralsSection {
49    /// Create a new raw literals section for testing.
50    pub fn new_raw(data: Vec<u8>) -> Self {
51        let size = data.len();
52        Self {
53            block_type: LiteralsBlockType::Raw,
54            regenerated_size: size,
55            compressed_size: size,
56            data,
57        }
58    }
59
60    /// Parse a literals section from input.
61    ///
62    /// Returns the parsed section and the number of bytes consumed.
63    pub fn parse(input: &[u8]) -> Result<(Self, usize)> {
64        if input.is_empty() {
65            return Err(Error::corrupted("Empty literals section"));
66        }
67
68        let header_byte = input[0];
69        let block_type = LiteralsBlockType::from_field(header_byte & 0x03);
70        let size_format = (header_byte >> 2) & 0x03;
71
72        match block_type {
73            LiteralsBlockType::Raw | LiteralsBlockType::Rle => {
74                Self::parse_raw_rle(input, block_type, size_format)
75            }
76            LiteralsBlockType::Compressed | LiteralsBlockType::Treeless => {
77                Self::parse_compressed(input, block_type, size_format)
78            }
79        }
80    }
81
82    /// Parse raw or RLE literals.
83    fn parse_raw_rle(
84        input: &[u8],
85        block_type: LiteralsBlockType,
86        size_format: u8,
87    ) -> Result<(Self, usize)> {
88        let (regenerated_size, header_size) = match size_format {
89            // Size_Format = 0b00 or 0b10: 5-bit size
90            0 | 2 => {
91                let size = (input[0] >> 3) as usize;
92                (size, 1)
93            }
94            // Size_Format = 0b01: 12-bit size
95            1 => {
96                if input.len() < 2 {
97                    return Err(Error::corrupted("Literals header truncated"));
98                }
99                let size = ((input[0] >> 4) as usize) | ((input[1] as usize) << 4);
100                (size, 2)
101            }
102            // Size_Format = 0b11: 20-bit size
103            3 => {
104                if input.len() < 3 {
105                    return Err(Error::corrupted("Literals header truncated"));
106                }
107                let size = ((input[0] >> 4) as usize)
108                    | ((input[1] as usize) << 4)
109                    | ((input[2] as usize) << 12);
110                (size, 3)
111            }
112            _ => unreachable!(),
113        };
114
115        let data_start = header_size;
116        let data = match block_type {
117            LiteralsBlockType::Raw => {
118                if input.len() < data_start + regenerated_size {
119                    return Err(Error::corrupted("Raw literals truncated"));
120                }
121                input[data_start..data_start + regenerated_size].to_vec()
122            }
123            LiteralsBlockType::Rle => {
124                if input.len() < data_start + 1 {
125                    return Err(Error::corrupted("RLE literals missing byte"));
126                }
127                vec![input[data_start]; regenerated_size]
128            }
129            _ => unreachable!(),
130        };
131
132        let total_size = match block_type {
133            LiteralsBlockType::Raw => header_size + regenerated_size,
134            LiteralsBlockType::Rle => header_size + 1,
135            _ => unreachable!(),
136        };
137
138        Ok((
139            Self {
140                block_type,
141                regenerated_size,
142                compressed_size: match block_type {
143                    LiteralsBlockType::Raw => regenerated_size,
144                    LiteralsBlockType::Rle => 1,
145                    _ => unreachable!(),
146                },
147                data,
148            },
149            total_size,
150        ))
151    }
152
153    /// Parse compressed literals (Huffman).
154    fn parse_compressed(
155        input: &[u8],
156        block_type: LiteralsBlockType,
157        size_format: u8,
158    ) -> Result<(Self, usize)> {
159        // Determine stream count and parse sizes
160        let is_single_stream = size_format == 3;
161
162        // Parse sizes based on format
163        let (regenerated_size, compressed_size, header_size) = match size_format {
164            // 4 streams, 10-bit sizes (3-byte header)
165            // RFC 8878: regen[3:0] = byte0[7:4], regen[9:4] = byte1[5:0]
166            //           comp[1:0] = byte1[7:6], comp[9:2] = byte2[7:0]
167            0 => {
168                if input.len() < 3 {
169                    return Err(Error::corrupted("Compressed literals header truncated"));
170                }
171                let regen = ((input[0] >> 4) as usize) | (((input[1] & 0x3F) as usize) << 4);
172                let comp = ((input[1] >> 6) as usize) | ((input[2] as usize) << 2);
173                (regen, comp, 3)
174            }
175            // 4 streams, 14-bit regen size, 10-bit comp size (4-byte header)
176            // RFC 8878: byte0[7:4]=regen[3:0], byte1=regen[11:4], byte2[1:0]=regen[13:12]
177            //           byte2[7:2]=comp[5:0], byte3=comp[9:2]? No...
178            // Actually: byte2[7:6]=comp[1:0], byte3=comp[9:2]
179            1 => {
180                if input.len() < 4 {
181                    return Err(Error::corrupted("Compressed literals header truncated"));
182                }
183                let regen = ((input[0] >> 4) as usize)
184                    | ((input[1] as usize) << 4)
185                    | (((input[2] & 0x03) as usize) << 12);
186                let comp = ((input[2] >> 6) as usize) | ((input[3] as usize) << 2);
187                (regen, comp, 4)
188            }
189            // 4 streams, 18-bit sizes
190            2 => {
191                if input.len() < 5 {
192                    return Err(Error::corrupted("Compressed literals header truncated"));
193                }
194                let regen = (((input[0] >> 4) & 0x3F) as usize)
195                    | ((input[1] as usize) << 4)
196                    | (((input[2] & 0x0F) as usize) << 12);
197                let comp = ((input[2] >> 4) as usize)
198                    | ((input[3] as usize) << 4)
199                    | (((input[4] & 0x03) as usize) << 12);
200                (regen, comp, 5)
201            }
202            // 1 stream, 10-bit sizes (3-byte header, single stream)
203            // Same format as Size_Format=0 but single stream instead of 4
204            3 => {
205                if input.len() < 3 {
206                    return Err(Error::corrupted("Compressed literals header truncated"));
207                }
208                let regen = ((input[0] >> 4) as usize) | (((input[1] & 0x3F) as usize) << 4);
209                let comp = ((input[1] >> 6) as usize) | ((input[2] as usize) << 2);
210                (regen, comp, 3)
211            }
212            _ => unreachable!(),
213        };
214
215        if input.len() < header_size + compressed_size {
216            return Err(Error::corrupted("Compressed literals data truncated"));
217        }
218
219        let compressed_data = &input[header_size..header_size + compressed_size];
220
221        // For treeless mode, we'd need a previously stored Huffman table
222        if block_type == LiteralsBlockType::Treeless {
223            return Err(Error::Unsupported(
224                "Treeless Huffman literals require previous table state".into(),
225            ));
226        }
227
228        // Decode Huffman-compressed literals
229        let data =
230            Self::decode_huffman_literals(compressed_data, regenerated_size, is_single_stream)?;
231
232        let total_size = header_size + compressed_size;
233
234        Ok((
235            Self {
236                block_type,
237                regenerated_size,
238                compressed_size,
239                data,
240            },
241            total_size,
242        ))
243    }
244
245    /// Decode Huffman-compressed literals.
246    fn decode_huffman_literals(
247        data: &[u8],
248        regenerated_size: usize,
249        is_single_stream: bool,
250    ) -> Result<Vec<u8>> {
251        if data.is_empty() {
252            return Err(Error::corrupted("Empty Huffman literals data"));
253        }
254
255        // Parse Huffman weights from the beginning of data
256        let (weights, weights_consumed) = parse_huffman_weights(data)?;
257
258        // Build Huffman table
259        let table = build_table_from_weights(weights)?;
260        let decoder = HuffmanDecoder::new(&table);
261
262        let stream_data = &data[weights_consumed..];
263
264        if is_single_stream {
265            Self::decode_single_stream(&decoder, stream_data, regenerated_size)
266        } else {
267            Self::decode_four_streams(&decoder, stream_data, regenerated_size)
268        }
269    }
270
271    /// Decode a single Huffman stream.
272    fn decode_single_stream(
273        decoder: &HuffmanDecoder,
274        data: &[u8],
275        regenerated_size: usize,
276    ) -> Result<Vec<u8>> {
277        if data.is_empty() {
278            if regenerated_size == 0 {
279                return Ok(Vec::new());
280            }
281            return Err(Error::corrupted("Empty stream data for Huffman decoding"));
282        }
283
284        // Huffman streams are read backwards (from end to start)
285        let mut output = Vec::with_capacity(regenerated_size);
286        let mut bits = BitReader::new_reversed(data)?;
287
288        for _ in 0..regenerated_size {
289            let symbol = decoder.decode_symbol(&mut bits)?;
290            output.push(symbol);
291        }
292
293        Ok(output)
294    }
295
296    /// Decode four parallel Huffman streams.
297    fn decode_four_streams(
298        decoder: &HuffmanDecoder,
299        data: &[u8],
300        regenerated_size: usize,
301    ) -> Result<Vec<u8>> {
302        // 4-stream format has 6-byte header with stream sizes
303        if data.len() < 6 {
304            return Err(Error::corrupted("4-stream header too short"));
305        }
306
307        // Read jump table: 3 x 2-byte offsets (little-endian)
308        let jump1 = u16::from_le_bytes([data[0], data[1]]) as usize;
309        let jump2 = u16::from_le_bytes([data[2], data[3]]) as usize;
310        let jump3 = u16::from_le_bytes([data[4], data[5]]) as usize;
311
312        // Stream boundaries
313        let stream1_start = 6;
314        let stream2_start = 6 + jump1;
315        let stream3_start = 6 + jump2;
316        let stream4_start = 6 + jump3;
317        let stream4_end = data.len();
318
319        // Validate boundaries
320        if stream2_start > data.len() || stream3_start > data.len() || stream4_start > data.len() {
321            return Err(Error::corrupted(
322                "Invalid stream boundaries in 4-stream literals",
323            ));
324        }
325
326        // Calculate output size per stream (regenerated_size split into 4)
327        let base_size = regenerated_size / 4;
328        let remainder = regenerated_size % 4;
329
330        let sizes = [
331            base_size + if remainder > 0 { 1 } else { 0 },
332            base_size + if remainder > 1 { 1 } else { 0 },
333            base_size + if remainder > 2 { 1 } else { 0 },
334            base_size,
335        ];
336
337        let stream_ranges = [
338            (stream1_start, stream2_start),
339            (stream2_start, stream3_start),
340            (stream3_start, stream4_start),
341            (stream4_start, stream4_end),
342        ];
343
344        let mut output = Vec::with_capacity(regenerated_size);
345
346        // Decode each stream
347        for (i, &(start, end)) in stream_ranges.iter().enumerate() {
348            if start >= end {
349                // Empty stream
350                if sizes[i] > 0 {
351                    return Err(Error::corrupted(format!(
352                        "Stream {} is empty but expects {} symbols",
353                        i, sizes[i]
354                    )));
355                }
356                continue;
357            }
358
359            let stream_data = &data[start..end];
360            let stream_output = Self::decode_single_stream(decoder, stream_data, sizes[i])?;
361            output.extend(stream_output);
362        }
363
364        Ok(output)
365    }
366
367    /// Get the literal data.
368    pub fn data(&self) -> &[u8] {
369        &self.data
370    }
371}
372
373// =============================================================================
374// Tests
375// =============================================================================
376
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_literals_block_type_parsing() {
        assert_eq!(LiteralsBlockType::from_field(0), LiteralsBlockType::Raw);
        assert_eq!(LiteralsBlockType::from_field(1), LiteralsBlockType::Rle);
        assert_eq!(
            LiteralsBlockType::from_field(2),
            LiteralsBlockType::Compressed
        );
        assert_eq!(
            LiteralsBlockType::from_field(3),
            LiteralsBlockType::Treeless
        );
    }

    #[test]
    fn test_raw_literals_5bit_size() {
        // Raw, size_format=0, size=5.
        // header byte = (size << 3) | (size_format << 2) | type
        //             = (5 << 3) | (0 << 2) | 0 = 0x28
        let mut input = vec![0x28];
        input.extend_from_slice(b"Hello");

        let (section, consumed) = LiteralsSection::parse(&input).unwrap();
        assert_eq!(section.block_type, LiteralsBlockType::Raw);
        assert_eq!(section.regenerated_size, 5);
        assert_eq!(section.data, b"Hello");
        assert_eq!(consumed, 6); // 1 header + 5 data
    }

    #[test]
    fn test_rle_literals_5bit_size() {
        // RLE, size_format=0, size=10.
        // header = (10 << 3) | (0 << 2) | 1 = 0x51
        let input = vec![0x51, b'X'];

        let (section, consumed) = LiteralsSection::parse(&input).unwrap();
        assert_eq!(section.block_type, LiteralsBlockType::Rle);
        assert_eq!(section.regenerated_size, 10);
        assert_eq!(section.data, vec![b'X'; 10]);
        assert_eq!(consumed, 2); // 1 header + 1 byte
    }

    #[test]
    fn test_raw_literals_12bit_size() {
        // Raw, size_format=1, size=256.
        // byte0 = ((size & 0x0F) << 4) | (1 << 2) | 0 = 0x04
        // byte1 = size >> 4 = 0x10
        let mut input = vec![0x04, 0x10];
        input.resize(2 + 256, b'A');

        let (section, consumed) = LiteralsSection::parse(&input).unwrap();
        assert_eq!(section.block_type, LiteralsBlockType::Raw);
        assert_eq!(section.regenerated_size, 256);
        assert_eq!(section.data().len(), 256);
        assert_eq!(consumed, 2 + 256);
    }

    #[test]
    fn test_raw_literals_20bit_size() {
        // Raw, size_format=3, size=5.
        // byte0 = ((size & 0x0F) << 4) | (3 << 2) | 0 = 0x5C
        // byte1 = (size >> 4) & 0xFF = 0, byte2 = size >> 12 = 0
        let mut input = vec![0x5C, 0x00, 0x00];
        input.extend_from_slice(b"Hello");

        let (section, consumed) = LiteralsSection::parse(&input).unwrap();
        assert_eq!(section.block_type, LiteralsBlockType::Raw);
        assert_eq!(section.regenerated_size, 5);
        assert_eq!(section.data(), b"Hello");
        assert_eq!(consumed, 3 + 5);
    }

    #[test]
    fn test_empty_input_error() {
        let result = LiteralsSection::parse(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_truncated_raw_error() {
        // Raw, size_format=0, size=10 (0x50 >> 3), but only 5 bytes of data.
        let input = vec![0x50, b'H', b'e', b'l', b'l', b'o'];
        let result = LiteralsSection::parse(&input);
        assert!(result.is_err());
    }

    #[test]
    fn test_truncated_header_error() {
        // Raw, size_format=1 needs a 2-byte header.
        assert!(LiteralsSection::parse(&[0x04]).is_err());
        // Raw, size_format=3 needs a 3-byte header.
        assert!(LiteralsSection::parse(&[0x0C, 0x00]).is_err());
    }

    #[test]
    fn test_rle_missing_byte_error() {
        // RLE, size=10, but the byte to repeat is missing.
        assert!(LiteralsSection::parse(&[0x51]).is_err());
    }

    #[test]
    fn test_new_raw_helper() {
        let section = LiteralsSection::new_raw(b"test".to_vec());
        assert_eq!(section.block_type, LiteralsBlockType::Raw);
        assert_eq!(section.regenerated_size, 4);
        assert_eq!(section.compressed_size, 4);
        assert_eq!(section.data(), b"test");
    }

    #[test]
    fn test_compressed_header_type_detection() {
        // The block type lives in the two low bits, whatever the size format.
        let header_byte = 0x0E; // Type=2 (Compressed), Size_Format=3
        let block_type = LiteralsBlockType::from_field(header_byte & 0x03);
        assert_eq!(block_type, LiteralsBlockType::Compressed);

        let header_byte = 0x02; // Type=2 (Compressed), Size_Format=0
        let block_type = LiteralsBlockType::from_field(header_byte & 0x03);
        assert_eq!(block_type, LiteralsBlockType::Compressed);
    }

    #[test]
    fn test_treeless_requires_previous_table() {
        // Treeless mode (type=3) must fail without previous table state.
        // Header: type=3, size_format=0 (3-byte header, 10-bit sizes),
        // regen=5, comp=10.
        // byte0 = ((regen & 0x0F) << 4) | (0 << 2) | 3          = 0x53
        // byte1 = ((regen >> 4) & 0x3F) | ((comp & 0x03) << 6)  = 0x80
        // byte2 = comp >> 2                                     = 0x02
        let mut input = vec![0x53, 0x80, 0x02];
        input.extend(vec![0x80; 10]); // Fake compressed data with sentinel

        let err = LiteralsSection::parse(&input)
            .expect_err("treeless literals must be rejected");

        let msg = format!("{:?}", err);
        assert!(
            msg.contains("previous table") || msg.contains("Treeless"),
            "Expected 'previous table' or 'Treeless' error, got: {}",
            msg
        );
    }

    #[test]
    fn test_compressed_literals_truncated_data_error() {
        // Compressed literals whose payload is shorter than declared.
        // Header: type=2, size_format=0 (3-byte header, 10-bit sizes),
        // regen=10, comp=20.
        // byte0 = ((regen & 0x0F) << 4) | (0 << 2) | 2          = 0xA2
        // byte1 = ((regen >> 4) & 0x3F) | ((comp & 0x03) << 6)  = 0x00
        // byte2 = comp >> 2                                     = 0x05
        // Only 5 payload bytes follow instead of 20.
        let input = vec![0xA2, 0x00, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05];
        let result = LiteralsSection::parse(&input);

        // Should fail due to truncated data
        assert!(result.is_err());
    }

    #[test]
    fn test_size_format_detection() {
        // Verify size_format extraction from the header byte.
        for size_format in 0..4u8 {
            let header_byte = 0x02 | (size_format << 2); // Compressed type
            let extracted = (header_byte >> 2) & 0x03;
            assert_eq!(extracted, size_format);
        }
    }
}