Skip to main content

rar_stream/decompress/
huffman.rs

1//! Huffman decoder for RAR compression.
2//!
3//! RAR uses canonical Huffman codes with up to 15-bit code lengths.
4
5use super::{DecompressError, Result, BitReader};
6
/// Longest Huffman code, in bits, that a table in this module can hold.
pub const MAX_CODE_LENGTH: usize = 15;

/// Number of code lengths in the main-code section of the RAR3/4 length table.
pub const MAINCODE_SIZE: usize = 299;
/// Number of code lengths in the offset-code section of the RAR3/4 length table.
pub const OFFSETCODE_SIZE: usize = 60;
/// Number of code lengths in the low-offset-code section of the RAR3/4 length table.
pub const LOWOFFSETCODE_SIZE: usize = 17;
/// Number of code lengths in the length-code section of the RAR3/4 length table.
pub const LENGTHCODE_SIZE: usize = 28;
/// Size of the combined length table: the four sections stored back to back.
pub const HUFFMAN_TABLE_SIZE: usize =
    MAINCODE_SIZE + OFFSETCODE_SIZE + LOWOFFSETCODE_SIZE + LENGTHCODE_SIZE;
16
/// One slot of the quick lookup table used by `HuffmanTable`.
///
/// The default value (`symbol == 0`, `length == 0`) represents a slot that no
/// short code maps to; the decoder treats `length == 0` as "not resolved by
/// the quick table".
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct HuffmanEntry {
    /// Decoded symbol value.
    pub symbol: u16,
    /// Length of the symbol's code in bits (`0` for an unused slot).
    pub length: u8,
}
25
26/// Huffman decoding table.
27/// Uses a lookup table for fast decoding of short codes.
28pub struct HuffmanTable {
29    /// Quick lookup table for codes up to QUICK_BITS
30    quick_table: Vec<HuffmanEntry>,
31    /// Sorted symbols for longer codes
32    symbols: Vec<u16>,
33    /// Code length counts
34    length_counts: [u16; MAX_CODE_LENGTH + 1],
35    /// First code value for each length
36    first_code: [u32; MAX_CODE_LENGTH + 1],
37    /// First symbol index for each length
38    first_symbol: [u16; MAX_CODE_LENGTH + 1],
39}
40
/// Number of leading stream bits used to index the quick lookup table.
const QUICK_BITS: u32 = 10;
/// Number of slots in the quick lookup table (`2^QUICK_BITS`).
const QUICK_SIZE: usize = 1 << QUICK_BITS;
44
45impl HuffmanTable {
46    /// Create a new Huffman table from code lengths.
47    pub fn new(lengths: &[u8]) -> Result<Self> {
48        let mut table = Self {
49            quick_table: vec![HuffmanEntry::default(); QUICK_SIZE],
50            symbols: vec![0; lengths.len()],
51            length_counts: [0; MAX_CODE_LENGTH + 1],
52            first_code: [0; MAX_CODE_LENGTH + 1],
53            first_symbol: [0; MAX_CODE_LENGTH + 1],
54        };
55
56        // Count code lengths
57        for &len in lengths {
58            if len > 0 && (len as usize) <= MAX_CODE_LENGTH {
59                table.length_counts[len as usize] += 1;
60            }
61        }
62
63        // Calculate first code for each length (canonical Huffman)
64        let mut code = 0u32;
65        for i in 1..=MAX_CODE_LENGTH {
66            code = (code + table.length_counts[i - 1] as u32) << 1;
67            table.first_code[i] = code;
68        }
69
70        // Calculate first symbol index for each length
71        let mut idx = 0u16;
72        for i in 1..=MAX_CODE_LENGTH {
73            table.first_symbol[i] = idx;
74            idx += table.length_counts[i];
75        }
76
77        // Build symbol list sorted by code
78        let mut indices = table.first_symbol;
79        for (symbol, &len) in lengths.iter().enumerate() {
80            if len > 0 && (len as usize) <= MAX_CODE_LENGTH {
81                let i = indices[len as usize] as usize;
82                if i < table.symbols.len() {
83                    table.symbols[i] = symbol as u16;
84                    indices[len as usize] += 1;
85                }
86            }
87        }
88
89        // Build quick lookup table
90        for (symbol, &len) in lengths.iter().enumerate() {
91            if len > 0 && len as u32 <= QUICK_BITS {
92                let len = len as u32;
93                // Calculate the canonical code for this symbol
94                let symbol_idx = table.symbols[..table.first_symbol[len as usize + 1] as usize]
95                    .iter()
96                    .position(|&s| s == symbol as u16);
97                
98                if let Some(idx) = symbol_idx {
99                    let code = table.first_code[len as usize] + idx as u32 
100                        - table.first_symbol[len as usize] as u32;
101                    
102                    // Fill all table entries that start with this code
103                    let fill_bits = QUICK_BITS - len;
104                    let start = (code << fill_bits) as usize;
105                    let count = 1 << fill_bits;
106                    
107                    for j in 0..count {
108                        let entry_idx = start + j;
109                        if entry_idx < QUICK_SIZE {
110                            table.quick_table[entry_idx] = HuffmanEntry {
111                                symbol: symbol as u16,
112                                length: len as u8,
113                            };
114                        }
115                    }
116                }
117            }
118        }
119
120        Ok(table)
121    }
122
123    /// Debug: dump the canonical codes for each symbol
124    #[cfg(test)]
125    pub fn dump_codes(&self, name: &str, lengths: &[u8]) {
126        eprintln!("=== {} Huffman codes ===", name);
127        eprintln!("length_counts: {:?}", &self.length_counts[1..=5]);
128        eprintln!("first_code: {:?}", &self.first_code[1..=5]);
129        eprintln!("first_symbol: {:?}", &self.first_symbol[1..=5]);
130        eprintln!("symbols: {:?}", &self.symbols);
131        
132        for (symbol, &len) in lengths.iter().enumerate() {
133            if len > 0 && (len as usize) <= MAX_CODE_LENGTH {
134                // Find where this symbol is in the sorted list
135                let first_sym = self.first_symbol[len as usize] as usize;
136                let count = self.length_counts[len as usize] as usize;
137                let end = first_sym + count;
138                
139                for i in first_sym..end {
140                    if i < self.symbols.len() && self.symbols[i] == symbol as u16 {
141                        let code = self.first_code[len as usize] + (i as u32 - self.first_symbol[len as usize] as u32);
142                        // Print code in binary with proper length padding
143                        let code_str: String = format!("{:0width$b}", code, width = len as usize);
144                        eprintln!("  symbol {:>2}: len={}, code={}", symbol, len, code_str);
145                        break;
146                    }
147                }
148            }
149        }
150    }
151
152    /// Decode a symbol from the bit reader.
153    pub fn decode(&self, reader: &mut BitReader) -> Result<u16> {
154        let bits = reader.peek_bits(QUICK_BITS);
155        let entry = &self.quick_table[bits as usize];
156        
157        if entry.length > 0 {
158            reader.advance_bits(entry.length as u32);
159            return Ok(entry.symbol);
160        }
161
162        // Slow path for longer codes
163        let code = reader.peek_bits(MAX_CODE_LENGTH as u32);
164        
165        for len in (QUICK_BITS as usize + 1)..=MAX_CODE_LENGTH {
166            let shift = MAX_CODE_LENGTH - len;
167            let masked = code >> shift;
168            
169            if masked >= self.first_code[len] {
170                let count = self.length_counts[len] as u32;
171                let first = self.first_code[len];
172                
173                if masked < first + count {
174                    let idx = self.first_symbol[len] as u32 + (masked - first);
175                    if (idx as usize) < self.symbols.len() {
176                        reader.advance_bits(len as u32);
177                        return Ok(self.symbols[idx as usize]);
178                    }
179                }
180            }
181        }
182
183        Err(DecompressError::InvalidHuffmanCode)
184    }
185}
186
187/// Huffman decoder that can read code lengths from the stream.
188pub struct HuffmanDecoder {
189    /// Main code table (literals + lengths)
190    pub main_table: Option<HuffmanTable>,
191    /// Distance/offset table
192    pub dist_table: Option<HuffmanTable>,
193    /// Low distance table
194    pub low_dist_table: Option<HuffmanTable>,
195    /// Length table
196    pub len_table: Option<HuffmanTable>,
197    /// Stored length table for incremental updates
198    length_table: [u8; HUFFMAN_TABLE_SIZE],
199}
200
201impl HuffmanDecoder {
202    pub fn new() -> Self {
203        Self {
204            main_table: None,
205            dist_table: None,
206            low_dist_table: None,
207            len_table: None,
208            length_table: [0; HUFFMAN_TABLE_SIZE],
209        }
210    }
211
212    /// Reset the length table.
213    pub fn reset_tables(&mut self) {
214        self.length_table = [0; HUFFMAN_TABLE_SIZE];
215    }
216
217    /// Read code lengths from the bit stream and build tables.
218    /// This matches the RAR3/4 format.
219    pub fn read_tables(&mut self, reader: &mut BitReader) -> Result<()> {
220        // Read reset flag - if 0, we keep previous length table
221        let reset_tables = reader.read_bit()?;
222        if reset_tables {
223            self.length_table = [0; HUFFMAN_TABLE_SIZE];
224        }
225
226        #[cfg(test)]
227        eprintln!("reset_tables={}, bit_pos={}", reset_tables, reader.bit_position());
228
229        self.read_tables_inner(reader)
230    }
231
232    /// Read tables after header bits have been consumed.
233    pub fn read_tables_after_header(&mut self, reader: &mut BitReader) -> Result<()> {
234        self.read_tables_inner(reader)
235    }
236
237    /// Internal table reading.
238    fn read_tables_inner(&mut self, reader: &mut BitReader) -> Result<()> {
239        // Read bit lengths for the precode (20 symbols, 4 bits each)
240        let mut precode_lengths = [0u8; 20];
241        let mut i = 0;
242        while i < 20 {
243            let len = reader.read_bits(4)? as u8;
244            if len == 0x0F {
245                // Special case: zero run
246                let zero_count = reader.read_bits(4)? as usize;
247                if zero_count > 0 {
248                    for _ in 0..(zero_count + 2).min(20 - i) {
249                        precode_lengths[i] = 0;
250                        i += 1;
251                    }
252                    continue;
253                }
254            }
255            precode_lengths[i] = len;
256            i += 1;
257        }
258
259        #[cfg(test)]
260        eprintln!("precode_lengths={:?}", precode_lengths);
261
262        let precode_table = HuffmanTable::new(&precode_lengths)?;
263
264        // Read main length table using precode
265        i = 0;
266        #[cfg(test)]
267        let mut sym_count = 0;
268        while i < HUFFMAN_TABLE_SIZE {
269            let sym = precode_table.decode(reader)?;
270            
271            #[cfg(test)]
272            {
273                if sym_count < 30 {
274                    eprint!("sym[{}]={} ", i, sym);
275                    sym_count += 1;
276                }
277            }
278            
279            if sym < 16 {
280                // Add to previous value (mod 16)
281                self.length_table[i] = (self.length_table[i] + sym as u8) & 0x0F;
282                i += 1;
283            } else if sym == 16 {
284                // Repeat previous length, count = 3 + 3bits
285                if i == 0 {
286                    return Err(DecompressError::InvalidHuffmanCode);
287                }
288                let count = 3 + reader.read_bits(3)? as usize;
289                let prev = self.length_table[i - 1];
290                for _ in 0..count.min(HUFFMAN_TABLE_SIZE - i) {
291                    self.length_table[i] = prev;
292                    i += 1;
293                }
294            } else if sym == 17 {
295                // Repeat previous length, count = 11 + 7bits
296                if i == 0 {
297                    return Err(DecompressError::InvalidHuffmanCode);
298                }
299                let count = 11 + reader.read_bits(7)? as usize;
300                let prev = self.length_table[i - 1];
301                for _ in 0..count.min(HUFFMAN_TABLE_SIZE - i) {
302                    self.length_table[i] = prev;
303                    i += 1;
304                }
305            } else if sym == 18 {
306                // Insert zeros, count = 3 + 3bits
307                let count = 3 + reader.read_bits(3)? as usize;
308                for _ in 0..count.min(HUFFMAN_TABLE_SIZE - i) {
309                    self.length_table[i] = 0;
310                    i += 1;
311                }
312            } else {
313                // sym == 19: Insert zeros, count = 11 + 7bits
314                let count = 11 + reader.read_bits(7)? as usize;
315                for _ in 0..count.min(HUFFMAN_TABLE_SIZE - i) {
316                    self.length_table[i] = 0;
317                    i += 1;
318                }
319            }
320        }
321        #[cfg(test)]
322        eprintln!();
323
324        #[cfg(test)]
325        eprintln!("length_table first 20: {:?}", &self.length_table[..20]);
326
327        // Build the four Huffman tables from length_table
328        let mut offset = 0;
329        
330        self.main_table = Some(HuffmanTable::new(&self.length_table[offset..offset + MAINCODE_SIZE])?);
331        offset += MAINCODE_SIZE;
332        
333        self.dist_table = Some(HuffmanTable::new(&self.length_table[offset..offset + OFFSETCODE_SIZE])?);
334        offset += OFFSETCODE_SIZE;
335        
336        #[cfg(test)]
337        {
338            let low_lengths = &self.length_table[offset..offset + LOWOFFSETCODE_SIZE];
339            eprintln!("low_dist_table lengths: {:?}", low_lengths);
340        }
341        
342        self.low_dist_table = Some(HuffmanTable::new(&self.length_table[offset..offset + LOWOFFSETCODE_SIZE])?);
343        
344        #[cfg(test)]
345        {
346            let low_lengths = &self.length_table[offset..offset + LOWOFFSETCODE_SIZE];
347            self.low_dist_table.as_ref().unwrap().dump_codes("low_dist", low_lengths);
348        }
349        
350        offset += LOWOFFSETCODE_SIZE;
351        
352        self.len_table = Some(HuffmanTable::new(&self.length_table[offset..offset + LENGTHCODE_SIZE])?);
353
354        Ok(())
355    }
356}
357
358impl Default for HuffmanDecoder {
359    fn default() -> Self {
360        Self::new()
361    }
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// Two one-bit codes: symbol 0 is `0`, symbol 1 is `1`.
    #[test]
    fn test_huffman_table_simple() {
        let table = HuffmanTable::new(&[1u8, 1]).unwrap();

        // The stream opens with a `1` bit, so symbol 1 must come out.
        let stream = [0b1000_0000];
        let mut reader = BitReader::new(&stream);
        let decoded = table.decode(&mut reader).unwrap();

        assert_eq!(decoded, 1);
    }

    /// Mixed lengths: symbol 0 is `0`, symbol 1 is `10`, symbol 2 is `11`.
    #[test]
    fn test_huffman_table_varying_lengths() {
        let table = HuffmanTable::new(&[1u8, 2, 2]).unwrap();

        // Bit stream: `0` `10` `11`, then padding.
        let stream = [0b0101_1000];
        let mut reader = BitReader::new(&stream);

        for expected in 0u16..3 {
            assert_eq!(table.decode(&mut reader).unwrap(), expected);
        }
    }
}
395}