use crate::error::CodecError;
use crate::jpeg::bitstream::BitStream;
use crate::jpeg::types::QuantTable;
/// Number of distinct code lengths tracked; one per entry of the 16-element
/// DHT length-count array accepted by `HuffmanTable::build`.
const MAX_CODE_LEN: usize = 16;
/// How many leading input bits the direct lookup table resolves at once.
const FAST_BITS: u8 = 11;
/// Lookup-table size: one slot for every possible `FAST_BITS`-bit prefix.
const FAST_SIZE: usize = 2usize.pow(FAST_BITS as u32);
/// Decoding table for one Huffman table described by a DHT segment.
///
/// Codes of at most `FAST_BITS` bits are resolved with a single lookup in
/// `fast`. Longer codes fall back to a bit-at-a-time canonical search using
/// `max_code` / `val_offset`.
#[derive(Debug, Clone)]
pub struct HuffmanTable {
    /// `FAST_SIZE` entries indexed by the next `FAST_BITS` input bits.
    /// Each entry is `(symbol, code_length)`. A length of 0 means the prefix
    /// does not complete any code of at most `FAST_BITS` bits.
    fast: Vec<(u8, u8)>,
    /// Symbols in DHT order (the first `sum(counts)` entries of `values`).
    symbols: Vec<u8>,
    /// `max_code[i]` is the largest code of length `i + 1`, or -1 if none.
    max_code: [i32; MAX_CODE_LEN],
    /// `val_offset[i]` maps a length-`(i + 1)` code to its index in `symbols`.
    val_offset: [i32; MAX_CODE_LEN],
}

impl HuffmanTable {
    /// Builds a decoding table from DHT code-length counts and symbol values.
    ///
    /// `counts[i]` is the number of codes of length `i + 1`. `values` lists the
    /// symbols in order of increasing code length. Only the first
    /// `sum(counts)` values are used; any extra trailing values are ignored.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::InvalidData` in two cases:
    ///
    /// - `values` is shorter than the sum of `counts`.
    /// - The counts describe more codes of some length than fit in that many
    ///   bits (an over-subscribed table). Such a table would otherwise index
    ///   past the end of the lookup table and panic.
    pub fn build(counts: &[u8; 16], values: &[u8]) -> crate::Result<Self> {
        let total: usize = counts.iter().map(|&c| usize::from(c)).sum();
        if values.len() < total {
            return Err(CodecError::InvalidData(
                "DHT: fewer values than count sum".into(),
            ));
        }

        let mut max_code = [-1i32; MAX_CODE_LEN];
        let mut val_offset = [0i32; MAX_CODE_LEN];
        let mut fast = vec![(0u8, 0u8); FAST_SIZE];

        // First canonical code not yet assigned at the current length.
        let mut code: u32 = 0;
        // Index in `values` of the first symbol at the current length.
        let mut si = 0usize;

        for (i, &count) in counts.iter().enumerate() {
            let count = u32::from(count);
            let bit_len = (i + 1) as u8;
            if count > 0 {
                // The codes `code..code + count` must all fit in `bit_len` bits.
                // Without this check, a hostile DHT drives `fast` indexing
                // out of bounds below.
                if code + count > 1u32 << bit_len {
                    return Err(CodecError::InvalidData(
                        "DHT: over-subscribed Huffman code lengths".into(),
                    ));
                }
                val_offset[i] = si as i32 - code as i32;
                max_code[i] = (code + count - 1) as i32;

                if bit_len <= FAST_BITS {
                    // Every FAST_BITS-bit prefix that starts with this code
                    // resolves to the same (symbol, length) pair.
                    let shift = FAST_BITS - bit_len;
                    let group = &values[si..si + count as usize];
                    for (offset, &symbol) in group.iter().enumerate() {
                        let base = ((code + offset as u32) << shift) as usize;
                        fast[base..base + (1usize << shift)].fill((symbol, bit_len));
                    }
                }

                si += count as usize;
                code += count;
            }
            // Canonical Huffman: the first code of the next length is the
            // successor of the last code here, with a 0 bit appended.
            code <<= 1;
        }

        Ok(Self {
            fast,
            symbols: values[..total].to_vec(),
            max_code,
            val_offset,
        })
    }

    /// Decodes the next symbol from `bs`, consuming exactly its code bits.
    ///
    /// Codes of at most `FAST_BITS` bits are resolved through the lookup
    /// table. Longer codes are handled by `decode_slow`.
    ///
    /// NOTE(review): relies on `BitStream::peek` returning the next bits
    /// MSB-first without consuming them (padding past end-of-data). Confirm
    /// against `BitStream`.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::InvalidData` if the bits do not form a valid code.
    #[inline]
    pub fn decode_symbol(&self, bs: &mut BitStream<'_>) -> crate::Result<u8> {
        let peek = bs.peek(FAST_BITS) as usize;
        let (symbol, len) = self.fast[peek];
        if len > 0 {
            bs.consume(len);
            return Ok(symbol);
        }
        self.decode_slow(bs)
    }

    /// Slow path for codes longer than `FAST_BITS` bits.
    ///
    /// Consumes the `FAST_BITS`-bit prefix, then extends the code one bit at a
    /// time for lengths `FAST_BITS + 1 ..= MAX_CODE_LEN`. On failure, the
    /// consumed bits are not restored.
    fn decode_slow(&self, bs: &mut BitStream<'_>) -> crate::Result<u8> {
        let mut code = bs.peek(FAST_BITS) as i32;
        bs.consume(FAST_BITS);
        for i in usize::from(FAST_BITS)..MAX_CODE_LEN {
            // After this step `code` holds `i + 1` bits.
            code = (code << 1) | bs.read_bits(1) as i32;
            if code <= self.max_code[i] {
                // A negative sum wraps to a huge usize and simply misses `get`.
                let idx = (code + self.val_offset[i]) as usize;
                if let Some(&symbol) = self.symbols.get(idx) {
                    return Ok(symbol);
                }
            }
        }
        Err(CodecError::InvalidData("JPEG: invalid Huffman code".into()))
    }
}
/// Decodes one baseline 8x8 block of Huffman-coded coefficients and
/// dequantizes it.
///
/// On success, `coeffs` holds the dequantized coefficients in natural
/// (de-zigzagged) order, and `dc_pred` is updated with this block's DC value.
///
/// AC symbols with a zero size nibble are handled as follows:
///
/// - `0xF0` (ZRL) skips 16 zeros.
/// - Every other size-0 symbol ends the block. `0x00` is EOB; the remaining
///   ones are undefined in baseline JPEG and are treated as EOB, as libjpeg
///   does.
///
/// # Errors
///
/// Returns `CodecError::InvalidData` in any of these cases:
///
/// - A Huffman code is invalid.
/// - A DC size exceeds 11.
/// - An AC size exceeds 10.
/// - A run or ZRL would move past coefficient 63.
#[inline]
pub fn decode_block(
    bs: &mut BitStream<'_>,
    dc_table: &HuffmanTable,
    ac_table: &HuffmanTable,
    quant: &QuantTable,
    coeffs: &mut [i32; 64],
    dc_pred: &mut i32,
) -> crate::Result<()> {
    *coeffs = [0i32; 64];

    // DC: a size category followed by that many raw bits of difference.
    let dc_size = dc_table.decode_symbol(bs)?;
    if dc_size > 0 {
        if dc_size > 11 {
            return Err(CodecError::InvalidData("JPEG: DC size > 11".into()));
        }
        let dc_val = bs.read_bits(dc_size);
        let dc_diff = BitStream::extend(dc_val, dc_size);
        // The predictor accumulates across blocks, so a corrupt stream can
        // push it arbitrarily far. Wrapping avoids a debug-build panic and
        // matches what release builds already do.
        *dc_pred = dc_pred.wrapping_add(dc_diff);
    }
    coeffs[0] = dc_pred.wrapping_mul(quant.values[0] as i32);

    // AC: (run, size) symbols in zigzag order, starting at position 1.
    let mut k = 1;
    while k < 64 {
        let symbol = ac_table.decode_symbol(bs)?;
        let run = usize::from(symbol >> 4);
        let size = symbol & 0x0F;

        if size == 0 {
            if run != 15 {
                // EOB (0x00), or an undefined size-0 symbol: end the block.
                break;
            }
            // ZRL (0xF0): sixteen zero coefficients.
            k += 16;
            if k > 64 {
                return Err(CodecError::InvalidData("JPEG: AC run past end".into()));
            }
            continue;
        }

        k += run;
        if k >= 64 {
            return Err(CodecError::InvalidData(
                "JPEG: AC coefficient past end".into(),
            ));
        }
        if size > 10 {
            return Err(CodecError::InvalidData(format!(
                "JPEG: AC coefficient size {size} exceeds spec maximum (10)"
            )));
        }
        let val = bs.read_bits(size);
        let coeff = BitStream::extend(val, size);
        let natural_idx = crate::jpeg::types::ZIGZAG[k] as usize;
        coeffs[natural_idx] = coeff * quant.values[natural_idx] as i32;
        k += 1;
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn build_simple_table() {
        // Two 1-bit codes: `0` -> 0x00, `1` -> 0x01.
        let mut counts = [0u8; 16];
        counts[0] = 2;
        let values = [0x00, 0x01];
        let table = HuffmanTable::build(&counts, &values).unwrap();
        let data = [0b1000_0000];
        let mut bs = BitStream::new(&data, 0);
        assert_eq!(table.decode_symbol(&mut bs).unwrap(), 1);
    }

    #[test]
    fn roundtrip_dc_values() {
        // Code-length counts of the standard luminance DC table (T.81 Table K.3).
        let counts = [0, 1, 5, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0];
        let values = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
        let table = HuffmanTable::build(&counts, &values).unwrap();
        assert_eq!(table.symbols.len(), 12);

        // Canonical codes: 0 -> `00`, 5 -> `110`, 6 -> `1110`, 1 -> `010`.
        // Bit string `00 110 1110 010` plus zero padding is 0x37 0x20.
        let data = [0x37, 0x20];
        let mut bs = BitStream::new(&data, 0);
        for expected in [0u8, 5, 6, 1] {
            assert_eq!(table.decode_symbol(&mut bs).unwrap(), expected);
        }
    }

    #[test]
    fn reject_invalid_counts() {
        // The counts promise five symbols, but only two values are supplied.
        let mut counts = [0u8; 16];
        counts[0] = 5;
        let values = [0u8; 2];
        let result = HuffmanTable::build(&counts, &values);
        assert!(result.is_err());
    }
}