Skip to main content

phasm_core/codec/jpeg/
huffman.rs

1// Copyright (c) 2026 Christoph Gaffga
2// SPDX-License-Identifier: GPL-3.0-only
3// https://github.com/cgaffga/phasmcore
4
5//! Huffman coding tables for JPEG entropy decoding and encoding.
6
7use super::bitio::BitReader;
8use super::error::{JpegError, Result};
9
10/// Huffman decode table with two-level lookup.
11///
12/// Level 1: 8-bit fast lookup table (covers most codes).
13/// Level 2: slow path for codes longer than 8 bits.
14pub struct HuffmanDecodeTable {
15    /// Fast lookup: indexed by top 8 bits of the code stream.
16    /// Each entry: (symbol, code_length). If code_length == 0, use slow path.
17    fast: [(u8, u8); 256],
18    /// For codes > 8 bits: (code, length, symbol) sorted by (length, code).
19    slow: Vec<(u16, u8, u8)>,
20    /// Maximum code length in this table.
21    max_len: u8,
22}
23
24impl HuffmanDecodeTable {
25    /// Build a decode table from JPEG-style counts and symbols.
26    ///
27    /// `bits`: counts[i] = number of codes of length i+1 (16 entries).
28    /// `huffval`: the symbols, in order of increasing code length.
29    pub fn build(bits: &[u8; 16], huffval: &[u8]) -> Result<Self> {
30        let mut fast = [(0u8, 0u8); 256];
31        let mut slow = Vec::new();
32        let mut max_len = 0u8;
33
34        // Generate canonical Huffman codes per ITU-T T.81 Annex C
35        let mut code: u32 = 0;
36        let mut si = 0; // symbol index into huffval
37
38        for length in 1..=16u8 {
39            let count = bits[(length - 1) as usize] as usize;
40            for _ in 0..count {
41                if si >= huffval.len() {
42                    return Err(JpegError::InvalidMarkerData("DHT symbol count mismatch"));
43                }
44                let symbol = huffval[si];
45                si += 1;
46                max_len = length;
47
48                if length <= 8 {
49                    // Fill fast table: this code, left-shifted to 8 bits,
50                    // covers 2^(8-length) entries
51                    let base = (code << (8 - length)) as usize;
52                    let fill = 1usize << (8 - length);
53                    for j in 0..fill {
54                        fast[base + j] = (symbol, length);
55                    }
56                } else {
57                    slow.push((code as u16, length, symbol));
58                }
59                code += 1;
60            }
61            code <<= 1;
62        }
63
64        Ok(Self {
65            fast,
66            slow,
67            max_len,
68        })
69    }
70
71    /// Decode one Huffman symbol from the bit stream.
72    pub fn decode(&self, reader: &mut BitReader) -> Result<u8> {
73        // Peek up to 8 bits for fast table lookup
74        let peek_len = 8.min(self.max_len.max(1));
75        let peek = reader.peek_bits(peek_len)?;
76        let idx = if self.max_len >= 8 {
77            peek as usize
78        } else {
79            (peek << (8 - self.max_len)) as usize
80        };
81
82        let (symbol, length) = self.fast[idx];
83        if length > 0 {
84            reader.skip_bits(length);
85            return Ok(symbol);
86        }
87
88        // Slow path: try longer codes
89        self.decode_slow(reader)
90    }
91
92    fn decode_slow(&self, reader: &mut BitReader) -> Result<u8> {
93        // Read up to max_len bits and try to match
94        for &(code, length, symbol) in &self.slow {
95            let bits = reader.peek_bits(length)?;
96            if bits == code {
97                reader.skip_bits(length);
98                return Ok(symbol);
99            }
100        }
101        Err(JpegError::HuffmanDecode)
102    }
103}
104
105/// Huffman encode table: maps symbol → (code_bits, code_length).
106pub struct HuffmanEncodeTable {
107    /// For each of the 256 possible symbols: (code, length).
108    /// Length 0 means the symbol is not in the table.
109    table: [(u16, u8); 256],
110}
111
112impl HuffmanEncodeTable {
113    /// Build an encode table from JPEG-style counts and symbols.
114    pub fn build(bits: &[u8; 16], huffval: &[u8]) -> Self {
115        let mut table = [(0u16, 0u8); 256];
116        let mut code: u32 = 0;
117        let mut si = 0;
118
119        for length in 1..=16u8 {
120            let count = bits[(length - 1) as usize] as usize;
121            for _ in 0..count {
122                if si < huffval.len() {
123                    let symbol = huffval[si] as usize;
124                    table[symbol] = (code as u16, length);
125                    si += 1;
126                }
127                code += 1;
128            }
129            code <<= 1;
130        }
131
132        Self { table }
133    }
134
135    /// Encode a symbol: returns (code_bits, code_length).
136    /// Returns `Err` if the symbol has no code in this table.
137    pub fn encode(&self, symbol: u8) -> Result<(u16, u8)> {
138        let (code, len) = self.table[symbol as usize];
139        if len == 0 {
140            Err(JpegError::InvalidMarkerData(
141                "Huffman table missing code for symbol",
142            ))
143        } else {
144            Ok((code, len))
145        }
146    }
147}
148
/// Recover a signed value from its JPEG "additional bits" representation.
///
/// `value` holds the raw additional bits and `bits` is the category (the
/// number of additional bits). Per ITU-T T.81 Table F.1, a leading 0 bit
/// marks a negative value; category 0 always decodes to 0.
pub fn extend_sign(value: u16, bits: u8) -> i16 {
    if bits == 0 {
        return 0;
    }
    let raw = value as i32;
    // Smallest raw value whose top bit (within `bits` bits) is set.
    let threshold = 1i32 << (bits - 1);
    if raw >= threshold {
        // Top bit set: the raw bits already are the positive value.
        return value as i16;
    }
    // Top bit clear: shift the range down to -(2^bits - 1)..
    (raw - (1i32 << bits) + 1) as i16
}
164
/// Encode a signed value into JPEG "additional bits" representation.
/// Returns (magnitude_bits, category/size).
///
/// The category is the number of bits needed for `|value|` (0 for zero).
/// Positive values are stored as-is; negative values are stored as the low
/// `size` bits of `value - 1` (one's complement), so the result round-trips
/// through `extend_sign`.
///
/// The function is total: `i16::MIN` yields category 16 instead of
/// overflowing.
pub fn encode_value(value: i16) -> (u16, u8) {
    if value == 0 {
        return (0, 0);
    }
    // 1..=16: position of the highest set bit of the magnitude.
    let size = (u16::BITS - value.unsigned_abs().leading_zeros()) as u8;
    let raw = if value > 0 {
        value as u16
    } else {
        // For negative values, JPEG uses one's complement. `wrapping_sub`
        // keeps i16::MIN from overflowing; only the low bits are kept anyway.
        value.wrapping_sub(1) as u16
    };
    // Low `size` bits set; a right shift avoids the overflowing `1 << 16`.
    let mask = u16::MAX >> (16 - size);
    (raw & mask, size)
}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Standard JPEG luminance DC Huffman table (ITU-T T.81 Table K.3):
    /// twelve symbols with code lengths from 2 to 9 bits.
    fn lum_dc_table() -> ([u8; 16], Vec<u8>) {
        let bits = [0, 1, 5, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0];
        let vals = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
        (bits, vals)
    }

    #[test]
    fn build_decode_table() {
        let (bits, vals) = lum_dc_table();
        let table = HuffmanDecodeTable::build(&bits, &vals).unwrap();

        // The longest code in Table K.3 is 9 bits, and it is the only code
        // that does not fit the 8-bit fast table.
        assert_eq!(table.max_len, 9);
        assert_eq!(table.slow, vec![(0b1_1111_1110u16, 9u8, 11u8)]);

        // Symbol 0 has the 2-bit code `00`, so it owns fast entries 0..64.
        assert_eq!(table.fast[0], (0, 2));
        assert_eq!(table.fast[63], (0, 2));
    }

    #[test]
    fn build_rejects_truncated_symbol_list() {
        let (bits, vals) = lum_dc_table();
        let truncated = &vals[..vals.len() - 1];
        assert!(HuffmanDecodeTable::build(&bits, truncated).is_err());
    }

    #[test]
    fn encode_decode_roundtrip() {
        let (bits, vals) = lum_dc_table();
        let enc = HuffmanEncodeTable::build(&bits, &vals);
        let dec = HuffmanDecodeTable::build(&bits, &vals).unwrap();

        // Encode all symbols, then decode and verify
        for &sym in &vals {
            let (code, len) = enc.encode(sym).unwrap();

            // Place the code in the top bits of a 4-byte big-endian buffer
            let shifted = (code as u32) << (32 - len);
            let byte_data = shifted.to_be_bytes();

            // Handle byte-stuffing: if any byte is 0xFF, we need 0x00 after it
            let mut stuffed = Vec::new();
            for &b in &byte_data {
                stuffed.push(b);
                if b == 0xFF {
                    stuffed.push(0x00);
                }
            }

            let mut reader = BitReader::new(&stuffed, 0);
            let decoded = dec.decode(&mut reader).unwrap();
            assert_eq!(decoded, sym, "symbol {sym} round-trip failed");
        }
    }

    #[test]
    fn extend_sign_values() {
        // Category 1: value 0 → -1, value 1 → +1
        assert_eq!(extend_sign(0, 1), -1);
        assert_eq!(extend_sign(1, 1), 1);

        // Category 3: values 0–3 → -7 to -4, values 4–7 → +4 to +7
        assert_eq!(extend_sign(0, 3), -7);
        assert_eq!(extend_sign(3, 3), -4);
        assert_eq!(extend_sign(4, 3), 4);
        assert_eq!(extend_sign(7, 3), 7);

        // Category 0
        assert_eq!(extend_sign(0, 0), 0);
    }

    #[test]
    fn encode_value_roundtrip() {
        // Every i16 except i16::MIN, i.e. categories 0 through 15.
        for v in (i16::MIN + 1)..=i16::MAX {
            let (bits, size) = encode_value(v);
            if v == 0 {
                assert_eq!(size, 0);
            } else {
                // The additional bits must fit in `size` bits.
                assert!(u32::from(bits) < (1u32 << size), "bits overflow for {v}");
                let recovered = extend_sign(bits, size);
                assert_eq!(recovered, v, "round-trip failed for {v}");
            }
        }
    }
}