1use super::bitio::BitReader;
8use super::error::{JpegError, Result};
9
/// Lookup tables for decoding one JPEG Huffman table.
///
/// Codes of up to 8 bits are resolved with a single index into `fast`;
/// longer codes fall back to a linear scan of `slow`.
#[derive(Debug, Clone)]
pub struct HuffmanDecodeTable {
    /// Indexed by the next 8 stream bits (left-aligned when fewer bits are
    /// peeked); each entry is `(symbol, code_length)`, where a length of 0
    /// means no code of at most 8 bits starts with that prefix.
    fast: [(u8, u8); 256],
    /// `(code, code_length, symbol)` for every code longer than 8 bits, stored
    /// in ascending length order.
    slow: Vec<(u16, u8, u8)>,
    /// Length of the longest code in the table; 0 when the table has no codes.
    max_len: u8,
}
23
24impl HuffmanDecodeTable {
25 pub fn build(bits: &[u8; 16], huffval: &[u8]) -> Result<Self> {
30 let mut fast = [(0u8, 0u8); 256];
31 let mut slow = Vec::new();
32 let mut max_len = 0u8;
33
34 let mut code: u32 = 0;
36 let mut si = 0; for length in 1..=16u8 {
39 let count = bits[(length - 1) as usize] as usize;
40 for _ in 0..count {
41 if si >= huffval.len() {
42 return Err(JpegError::InvalidMarkerData("DHT symbol count mismatch"));
43 }
44 let symbol = huffval[si];
45 si += 1;
46 max_len = length;
47
48 if length <= 8 {
49 let base = (code << (8 - length)) as usize;
52 let fill = 1usize << (8 - length);
53 for j in 0..fill {
54 fast[base + j] = (symbol, length);
55 }
56 } else {
57 slow.push((code as u16, length, symbol));
58 }
59 code += 1;
60 }
61 code <<= 1;
62 }
63
64 Ok(Self {
65 fast,
66 slow,
67 max_len,
68 })
69 }
70
71 pub fn decode(&self, reader: &mut BitReader) -> Result<u8> {
73 let peek_len = 8.min(self.max_len.max(1));
75 let peek = reader.peek_bits(peek_len)?;
76 let idx = if self.max_len >= 8 {
77 peek as usize
78 } else {
79 (peek << (8 - self.max_len)) as usize
80 };
81
82 let (symbol, length) = self.fast[idx];
83 if length > 0 {
84 reader.skip_bits(length);
85 return Ok(symbol);
86 }
87
88 self.decode_slow(reader)
90 }
91
92 fn decode_slow(&self, reader: &mut BitReader) -> Result<u8> {
93 for &(code, length, symbol) in &self.slow {
95 let bits = reader.peek_bits(length)?;
96 if bits == code {
97 reader.skip_bits(length);
98 return Ok(symbol);
99 }
100 }
101 Err(JpegError::HuffmanDecode)
102 }
103}
104
/// Symbol-indexed Huffman code table used when encoding.
#[derive(Debug, Clone)]
pub struct HuffmanEncodeTable {
    /// `table[symbol] = (code, code_length)` with the code right-aligned; a
    /// length of 0 means the symbol has no code.
    table: [(u16, u8); 256],
}
111
112impl HuffmanEncodeTable {
113 pub fn build(bits: &[u8; 16], huffval: &[u8]) -> Self {
115 let mut table = [(0u16, 0u8); 256];
116 let mut code: u32 = 0;
117 let mut si = 0;
118
119 for length in 1..=16u8 {
120 let count = bits[(length - 1) as usize] as usize;
121 for _ in 0..count {
122 if si < huffval.len() {
123 let symbol = huffval[si] as usize;
124 table[symbol] = (code as u16, length);
125 si += 1;
126 }
127 code += 1;
128 }
129 code <<= 1;
130 }
131
132 Self { table }
133 }
134
135 pub fn encode(&self, symbol: u8) -> Result<(u16, u8)> {
138 let (code, len) = self.table[symbol as usize];
139 if len == 0 {
140 Err(JpegError::InvalidMarkerData(
141 "Huffman table missing code for symbol",
142 ))
143 } else {
144 Ok((code, len))
145 }
146 }
147}
148
/// Converts a `bits`-wide appended value back into its signed coefficient
/// (the EXTEND procedure of ITU-T T.81).
///
/// The value is treated as non-negative and returned unchanged when bit
/// `bits - 1` is set. Otherwise it represents the negative number
/// `value - (2^bits - 1)`. A `bits` of 0 yields 0.
pub fn extend_sign(value: u16, bits: u8) -> i16 {
    match bits {
        0 => 0,
        n => {
            let widened = i32::from(value);
            let top_bit = 1i32 << (n - 1);
            if widened >= top_bit {
                value as i16
            } else {
                (widened - (1i32 << n) + 1) as i16
            }
        }
    }
}
164
/// Splits a coefficient into its JPEG magnitude category and appended bits.
///
/// Returns `(bits, size)`:
/// - `size` is the number of significant bits in `|value|`.
/// - `bits` holds the low `size` bits to emit after the Huffman code. That is
///   `value` itself when positive, or `value - 1` in two's complement when
///   negative.
///
/// Zero yields `(0, 0)`. `extend_sign(bits, size)` recovers `value`. Every
/// `i16` is accepted, including `i16::MIN`, which maps to `(0x7FFF, 16)`.
pub fn encode_value(value: i16) -> (u16, u8) {
    if value == 0 {
        return (0, 0);
    }
    let size = 16 - value.unsigned_abs().leading_zeros() as u8;
    let raw = if value > 0 {
        value as u16
    } else {
        // `i16::MIN - 1` would overflow (a panic in debug builds). The wrapped
        // result has the same low 16 bits as the exact value, which is all we keep.
        value.wrapping_sub(1) as u16
    };
    // Build the mask by right-shifting so `size == 16` never shifts a u16 by 16.
    let mask = u16::MAX >> (16 - size);
    (raw & mask, size)
}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Standard luminance DC Huffman table (ITU-T T.81 Annex K, Table K.3).
    fn lum_dc_table() -> ([u8; 16], Vec<u8>) {
        let bits = [0, 1, 5, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0];
        let vals = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
        (bits, vals)
    }

    /// Places a `len`-bit `code` at the very start of a 4-byte big-endian buffer.
    ///
    /// It then applies JPEG byte stuffing (a 0x00 after every 0xFF). This
    /// presumes `BitReader` strips stuffed bytes; confirm against bitio.
    fn pack_code(code: u16, len: u8) -> Vec<u8> {
        let word = u32::from(code) << (32 - u32::from(len));
        let mut out = Vec::with_capacity(8);
        for b in word.to_be_bytes() {
            out.push(b);
            if b == 0xFF {
                out.push(0x00);
            }
        }
        out
    }

    #[test]
    fn build_decode_table() {
        let (bits, vals) = lum_dc_table();
        let table = HuffmanDecodeTable::build(&bits, &vals).unwrap();
        // The longest code in this table is 9 bits (symbol 11).
        assert_eq!(table.max_len, 9);
    }

    #[test]
    fn encode_decode_roundtrip() {
        let (bits, vals) = lum_dc_table();
        let enc = HuffmanEncodeTable::build(&bits, &vals);
        let dec = HuffmanDecodeTable::build(&bits, &vals).unwrap();

        for &sym in &vals {
            let (code, len) = enc.encode(sym).unwrap();
            let stuffed = pack_code(code, len);
            let mut reader = BitReader::new(&stuffed, 0);
            let decoded = dec.decode(&mut reader).unwrap();
            assert_eq!(decoded, sym, "symbol {sym} round-trip failed");
        }
    }

    #[test]
    fn extend_sign_values() {
        assert_eq!(extend_sign(0, 1), -1);
        assert_eq!(extend_sign(1, 1), 1);

        assert_eq!(extend_sign(0, 3), -7);
        assert_eq!(extend_sign(3, 3), -4);
        assert_eq!(extend_sign(4, 3), 4);
        assert_eq!(extend_sign(7, 3), 7);

        assert_eq!(extend_sign(0, 11), -2047);
        assert_eq!(extend_sign(1023, 11), -1024);
        assert_eq!(extend_sign(1024, 11), 1024);

        assert_eq!(extend_sign(0, 0), 0);
    }

    #[test]
    fn encode_value_roundtrip() {
        // Every value whose magnitude fits in 11 bits.
        for v in -2047i16..=2047 {
            let (bits, size) = encode_value(v);
            if v == 0 {
                assert_eq!((bits, size), (0, 0));
            } else {
                assert!(u32::from(bits) < (1u32 << size), "stray high bits for {v}");
                let recovered = extend_sign(bits, size);
                assert_eq!(recovered, v, "round-trip failed for {v}");
            }
        }
    }
}