use alloc::vec;
use alloc::vec::Vec;
use crate::error::Error;
/// Number of input bits used to index the decoder's primary lookup table.
const PRIMARY_BITS: u32 = 9;
/// Number of entries in the primary lookup table (`2^PRIMARY_BITS`).
const PRIMARY_SIZE: usize = 1 << PRIMARY_BITS;
/// Bit position of the code-length field inside a packed lookup-table entry.
const LUT_LEN_SHIFT: u32 = 16;
/// Mask selecting the symbol field (the bits below `LUT_LEN_SHIFT`) of a packed lookup-table entry.
const LUT_SYM_MASK: u32 = (1 << LUT_LEN_SHIFT) - 1;
/// Canonical Huffman decoder combining a `PRIMARY_BITS`-wide lookup table with a
/// bit-by-bit canonical search used when the table cannot resolve a code.
#[derive(Debug, Clone)]
pub(crate) struct HuffmanDecoder {
/// Number of codes of each length, indexed by code length.
counts: [u32; 16],
/// Symbol values ordered by (code length, symbol value).
symbols: Vec<u32>,
/// First canonical code value of each length, indexed by code length.
first_code: [u32; 16],
/// Index into `symbols` of the first symbol of each length, indexed by code length.
first_idx: [u32; 16],
/// Longest code length present; 0 for a decoder built with `single`.
max_length: u8,
/// When `Some`, `decode` returns this symbol without reading any bits.
single_symbol: Option<u32>,
/// Primary table indexed by the next `PRIMARY_BITS` input bits. Each entry packs
/// `length << LUT_LEN_SHIFT | symbol`; an entry with length 0 means the table has
/// no answer for that index and the bit-by-bit search must be used.
lut: alloc::boxed::Box<[u32; PRIMARY_SIZE]>,
}
impl HuffmanDecoder {
    /// Creates a decoder for an alphabet that contains exactly one symbol.
    ///
    /// [`decode`](Self::decode) on the result returns `sym` without reading any bits.
    pub(crate) fn single(sym: u32) -> Self {
        Self {
            counts: [0; 16],
            symbols: Vec::new(),
            first_code: [0; 16],
            first_idx: [0; 16],
            max_length: 0,
            single_symbol: Some(sym),
            lut: alloc::boxed::Box::new([0u32; PRIMARY_SIZE]),
        }
    }

    /// Builds a canonical Huffman decoder from `(symbol, code length)` pairs.
    ///
    /// Codes are assigned in order of increasing length, ties broken by symbol
    /// value. Duplicate symbols are not detected.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidHuffmanTree`] when `pairs` is empty, when any length
    /// is 0 or greater than 15, or when the lengths do not form a complete prefix
    /// code (Kraft sum different from 1).
    pub(crate) fn from_lengths_sparse(pairs: &[(u32, u8)]) -> Result<Self, Error> {
        // Canonical order: by length, then by symbol. Equal keys are identical
        // tuples, so an unstable sort gives the same result as a stable one.
        let mut owned: Vec<(u32, u8)> = pairs.to_vec();
        owned.sort_unstable_by_key(|&(sym, len)| (len, sym));

        let mut counts = [0u32; 16];
        let mut max_length = 0u8;
        for &(_, len) in &owned {
            if len == 0 || len > 15 {
                return Err(Error::InvalidHuffmanTree);
            }
            counts[usize::from(len)] += 1;
            max_length = max_length.max(len);
        }
        if max_length == 0 {
            // No pairs at all.
            return Err(Error::InvalidHuffmanTree);
        }

        // Kraft sum scaled by 2^15. Accumulated in u64 so that an oversized input
        // cannot overflow and wrap around to the "complete code" value.
        let kraft: u64 = (1..=15usize)
            .map(|l| u64::from(counts[l]) << (15 - l))
            .sum();
        if kraft != 1u64 << 15 {
            return Err(Error::InvalidHuffmanTree);
        }

        // First canonical code and first symbol index for every length.
        let mut first_code = [0u32; 16];
        let mut first_idx = [0u32; 16];
        let mut code: u32 = 0;
        let mut idx: u32 = 0;
        for l in 1..=15 {
            code <<= 1;
            first_code[l] = code;
            first_idx[l] = idx;
            code += counts[l];
            idx += counts[l];
        }

        // `owned` is already in canonical order, so the symbol table is simply
        // its symbol column.
        let symbols: Vec<u32> = owned.iter().map(|&(sym, _)| sym).collect();

        let mut lut = alloc::boxed::Box::new([0u32; PRIMARY_SIZE]);
        let mut next_code = first_code;
        for &(sym, len) in &owned {
            let code = next_code[usize::from(len)];
            next_code[usize::from(len)] += 1;

            let len = u32::from(len);
            // Codes longer than the table index, and symbols that do not fit in
            // the packed symbol field, are left to the bit-by-bit search; packing
            // an oversized symbol would corrupt the length field of the entry.
            if len > PRIMARY_BITS || sym > LUT_SYM_MASK {
                continue;
            }

            // The stream is LSB-first, so the table is indexed by the bit-reversed
            // code; every index whose low `len` bits match gets the same entry.
            let reversed = reverse_bits_lo(code, len) as usize;
            let entry = sym | (len << LUT_LEN_SHIFT);
            for slot in (reversed..PRIMARY_SIZE).step_by(1usize << len) {
                lut[slot] = entry;
            }
        }

        Ok(Self {
            counts,
            symbols,
            first_code,
            first_idx,
            max_length,
            single_symbol: None,
            lut,
        })
    }

    /// Builds a decoder from a dense table of code lengths, where the index is
    /// the symbol and a length of 0 marks an unused symbol.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidHuffmanTree`] when fewer than two symbols are used
    /// or when [`from_lengths_sparse`](Self::from_lengths_sparse) rejects the lengths.
    pub(crate) fn from_lengths(lengths: &[u8]) -> Result<Self, Error> {
        let pairs: Vec<(u32, u8)> = lengths
            .iter()
            .enumerate()
            .filter(|&(_, &len)| len > 0)
            // Symbol indices are stored as u32, matching the decoder's symbol type.
            .map(|(i, &len)| (i as u32, len))
            .collect();
        if pairs.len() < 2 {
            return Err(Error::InvalidHuffmanTree);
        }
        Self::from_lengths_sparse(&pairs)
    }

    /// Like [`from_lengths`](Self::from_lengths), but a table with exactly one
    /// used symbol yields a zero-bit [`single`](Self::single) decoder instead of
    /// an error.
    ///
    /// # Errors
    ///
    /// Same as [`from_lengths`](Self::from_lengths) for every other input.
    pub(crate) fn from_lengths_allow_single(lengths: &[u8]) -> Result<Self, Error> {
        let mut used = lengths
            .iter()
            .enumerate()
            .filter(|&(_, &len)| len > 0)
            .map(|(i, _)| i);
        if let (Some(only), None) = (used.next(), used.next()) {
            return Ok(Self::single(only as u32));
        }
        Self::from_lengths(lengths)
    }

    /// Decodes the next symbol from `br`.
    ///
    /// A single-symbol decoder returns its symbol without touching `br`.
    /// Otherwise the primary table is tried first (when at least `PRIMARY_BITS`
    /// bits remain), falling back to a bit-by-bit canonical search.
    ///
    /// # Errors
    ///
    /// Returns the error from the bit source when the input ends inside a code,
    /// and [`Error::InvalidHuffmanTree`] when the decoder is empty or no code
    /// matches within `max_length` bits.
    pub(crate) fn decode(&self, br: &mut BitSource<'_>) -> Result<u32, Error> {
        if let Some(sym) = self.single_symbol {
            return Ok(sym);
        }
        if self.max_length == 0 {
            return Err(Error::InvalidHuffmanTree);
        }

        if br.remaining() >= PRIMARY_BITS as usize {
            let entry = self.lut[br.peek_bits(PRIMARY_BITS) as usize];
            let len = entry >> LUT_LEN_SHIFT;
            if len > 0 {
                // Consume the code from the already-filled accumulator rather
                // than repositioning, which would throw the buffered bits away.
                br.read_bits(len)?;
                return Ok(entry & LUT_SYM_MASK);
            }
        }

        // Canonical search: extend the code one bit at a time and test it
        // against the code range of each length.
        let mut code: u32 = 0;
        for length in 1..=usize::from(self.max_length) {
            code = (code << 1) | br.read_bit()?;
            let count = self.counts[length];
            if count == 0 {
                continue;
            }
            let first = self.first_code[length];
            if code >= first && code < first + count {
                let sym_idx = self.first_idx[length] + (code - first);
                return Ok(self.symbols[sym_idx as usize]);
            }
        }
        Err(Error::InvalidHuffmanTree)
    }
}
/// Reverses the order of the low `n` bits of `v`.
///
/// Bit 0 of `v` becomes bit `n - 1` of the result; bits of `v` at or above
/// position `n` are ignored. Returns 0 when `n` is 0.
const fn reverse_bits_lo(mut v: u32, n: u32) -> u32 {
    let mut reversed = 0u32;
    let mut left = n;
    // `for` loops are not available in `const fn`, hence the manual countdown.
    while left > 0 {
        reversed <<= 1;
        reversed |= v & 1;
        v >>= 1;
        left -= 1;
    }
    reversed
}
/// LSB-first bit reader over a byte slice that buffers up to 64 bits in an
/// accumulator.
#[derive(Debug)]
pub(crate) struct BitSource<'a> {
    /// Bytes being read.
    data: &'a [u8],
    /// Bit index, from the start of `data`, of the next bit to load into `acc`.
    load_pos: usize,
    /// Buffered bits; bit 0 is the next bit to be returned.
    acc: u64,
    /// Number of valid bits currently held in `acc`.
    nbits: u32,
}

impl<'a> BitSource<'a> {
    /// Creates a reader over `data` whose first bit is at bit index `pos`.
    pub(crate) fn at(data: &'a [u8], pos: usize) -> Self {
        Self {
            data,
            load_pos: pos,
            acc: 0,
            nbits: 0,
        }
    }

    /// Returns the bit index of the next bit that will be read.
    pub(crate) fn position(&self) -> usize {
        self.load_pos - self.nbits as usize
    }

    /// Moves the reader to bit index `p`, discarding any buffered bits.
    pub(crate) fn set_position(&mut self, p: usize) {
        self.load_pos = p;
        self.acc = 0;
        self.nbits = 0;
    }

    /// Returns the number of bits that can still be read.
    ///
    /// Yields 0 rather than underflowing when the reader was created or
    /// repositioned past the end of the data.
    #[allow(dead_code)]
    pub(crate) fn remaining(&self) -> usize {
        let total_bits = self.data.len().saturating_mul(8);
        total_bits.saturating_sub(self.load_pos) + self.nbits as usize
    }

    /// Tops up the accumulator from `data`; loads nothing once the data is
    /// exhausted.
    fn refill(&mut self) {
        // Fast path: byte-aligned with at least 8 bytes left, load one
        // little-endian word. Only the low `64 - nbits` bits of it fit.
        if (self.load_pos & 7) == 0 && self.nbits <= 56 {
            let byte_pos = self.load_pos >> 3;
            if byte_pos + 8 <= self.data.len() {
                let mut bytes = [0u8; 8];
                bytes.copy_from_slice(&self.data[byte_pos..byte_pos + 8]);
                let chunk = u64::from_le_bytes(bytes);
                self.acc |= chunk << self.nbits;
                let added = 64 - self.nbits;
                self.load_pos += added as usize;
                self.nbits = 64;
                return;
            }
        }

        // Slow path: one (possibly partial) byte at a time.
        while self.nbits <= 56 {
            let byte_pos = self.load_pos >> 3;
            if byte_pos >= self.data.len() {
                break;
            }
            let bit_off = (self.load_pos & 7) as u32;
            let take = 8 - bit_off;
            let chunk = u64::from(self.data[byte_pos]) >> bit_off;
            self.acc |= chunk << self.nbits;
            self.nbits += take;
            self.load_pos += take as usize;
        }
    }

    /// Reads a single bit.
    ///
    /// # Errors
    ///
    /// Returns [`Error::UnexpectedEnd`] when no bits are left.
    pub(crate) fn read_bit(&mut self) -> Result<u32, Error> {
        if self.nbits == 0 {
            self.refill();
            if self.nbits == 0 {
                return Err(Error::UnexpectedEnd);
            }
        }
        let bit = (self.acc & 1) as u32;
        self.acc >>= 1;
        self.nbits -= 1;
        Ok(bit)
    }

    /// Returns the next `n` bits (1..=32) without consuming them.
    ///
    /// The caller must ensure that at least `n` bits remain; this is only
    /// checked in debug builds.
    pub(crate) fn peek_bits(&mut self, n: u32) -> u32 {
        debug_assert!(n > 0 && n <= 32);
        debug_assert!(n as usize <= self.remaining());
        if self.nbits < n {
            self.refill();
        }
        debug_assert!(self.nbits >= n);
        // The mask is computed in u64, so n == 32 needs no special case.
        (self.acc & ((1u64 << n) - 1)) as u32
    }

    /// Reads `n` bits (0..=32), least-significant bit first.
    ///
    /// # Errors
    ///
    /// Returns [`Error::UnexpectedEnd`] when fewer than `n` bits are left; the
    /// reader's position is unchanged in that case.
    pub(crate) fn read_bits(&mut self, n: u32) -> Result<u32, Error> {
        debug_assert!(n <= 32);
        if n == 0 {
            return Ok(0);
        }
        if self.nbits < n {
            self.refill();
            if self.nbits < n {
                return Err(Error::UnexpectedEnd);
            }
        }
        let value = (self.acc & ((1u64 << n) - 1)) as u32;
        self.acc >>= n;
        self.nbits -= n;
        Ok(value)
    }

    /// Skips forward to the next byte boundary; does nothing when already
    /// aligned.
    pub(crate) fn align_to_byte(&mut self) {
        let misalignment = (self.position() & 7) as u32;
        if misalignment == 0 {
            return;
        }
        let drop = 8 - misalignment;
        if drop <= self.nbits {
            self.acc >>= drop;
            self.nbits -= drop;
        } else {
            // Not enough buffered bits: empty the buffer and skip the rest in
            // the load position.
            let extra = drop - self.nbits;
            self.acc = 0;
            self.nbits = 0;
            self.load_pos += extra as usize;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single-symbol decoder yields its symbol and consumes nothing.
    #[test]
    fn single_symbol_zero_bits() {
        let d = HuffmanDecoder::single(42);
        let data = [0u8; 1];
        let mut src = BitSource::at(&data, 0);
        assert_eq!(d.decode(&mut src).unwrap(), 42);
        assert_eq!(src.position(), 0);
    }

    /// With only one byte of input the bit-by-bit path is used.
    #[test]
    fn two_symbols_one_bit_each() {
        let d = HuffmanDecoder::from_lengths_sparse(&[(0, 1), (1, 1)]).unwrap();
        let data = [0b1010_1010u8];
        let mut src = BitSource::at(&data, 0);
        assert_eq!(d.decode(&mut src).unwrap(), 0);
        assert_eq!(d.decode(&mut src).unwrap(), 1);
    }

    /// Codes of at most 9 bits are resolved through the lookup table when at
    /// least 9 bits of input remain.
    #[test]
    fn lut_fast_path_decodes_short_codes() {
        // Canonical codes: sym0 = "0", sym1 = "10", sym2 = "11".
        let d = HuffmanDecoder::from_lengths_sparse(&[(0, 1), (1, 2), (2, 2)]).unwrap();
        // Stream bits (in read order): 11 0 10 10 0, then zero padding.
        let data = [0b0010_1011u8, 0x00, 0x00];
        let mut src = BitSource::at(&data, 0);
        let expected = [(2u32, 2usize), (0, 3), (1, 5), (1, 7), (0, 8)];
        for &(sym, pos_after) in &expected {
            assert_eq!(d.decode(&mut src).unwrap(), sym);
            assert_eq!(src.position(), pos_after);
        }
    }

    /// A code longer than the table index falls back to the canonical search,
    /// and the table keeps working afterwards.
    #[test]
    fn long_code_falls_back_to_bitwise_search() {
        // Lengths 1..=9 followed by two 10-bit codes form a complete code.
        let d = HuffmanDecoder::from_lengths(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10]).unwrap();
        // Ten 1-bits encode symbol 10; six 1-bits then a 0-bit encode symbol 6.
        let data = [0xFFu8, 0xFF, 0x00];
        let mut src = BitSource::at(&data, 0);
        assert_eq!(d.decode(&mut src).unwrap(), 10);
        assert_eq!(src.position(), 10);
        assert_eq!(d.decode(&mut src).unwrap(), 6);
        assert_eq!(src.position(), 17);
    }

    /// Invalid length sets are rejected.
    #[test]
    fn invalid_trees_are_rejected() {
        // Empty.
        assert!(HuffmanDecoder::from_lengths_sparse(&[]).is_err());
        // Incomplete.
        assert!(HuffmanDecoder::from_lengths_sparse(&[(0, 1), (1, 2)]).is_err());
        // Over-subscribed.
        assert!(HuffmanDecoder::from_lengths_sparse(&[(0, 1), (1, 1), (2, 1)]).is_err());
        // Length out of range.
        assert!(HuffmanDecoder::from_lengths_sparse(&[(0, 16), (1, 1)]).is_err());
        assert!(HuffmanDecoder::from_lengths_sparse(&[(0, 0), (1, 1)]).is_err());
    }

    /// `from_lengths` refuses empty and single-symbol tables, while
    /// `from_lengths_allow_single` turns the latter into a zero-bit decoder.
    #[test]
    fn single_symbol_tables() {
        assert!(HuffmanDecoder::from_lengths(&[]).is_err());
        assert!(HuffmanDecoder::from_lengths(&[0, 3, 0]).is_err());
        assert!(HuffmanDecoder::from_lengths(&[1, 1]).is_ok());

        let d = HuffmanDecoder::from_lengths_allow_single(&[0, 0, 5]).unwrap();
        let data = [0xFFu8];
        let mut src = BitSource::at(&data, 0);
        assert_eq!(d.decode(&mut src).unwrap(), 2);
        assert_eq!(src.position(), 0);
    }

    #[test]
    fn read_bits_lsb_first() {
        let data = [0b1011_0100u8, 0b0000_0001];
        let mut src = BitSource::at(&data, 0);
        assert_eq!(src.read_bits(4).unwrap(), 4);
        assert_eq!(src.read_bits(8).unwrap(), 0x1B);
    }

    #[test]
    fn fast_path_byte_aligned_refill() {
        let data: [u8; 9] = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0xFF];
        let mut src = BitSource::at(&data, 0);
        for &expected in &[0x01u32, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0xFF] {
            assert_eq!(src.read_bits(8).unwrap(), expected);
        }
        assert_eq!(src.position(), 9 * 8);
    }

    #[test]
    fn unaligned_start_refill() {
        let data: [u8; 5] = [0xAB, 0xCD, 0xEF, 0x12, 0x34];
        let mut src = BitSource::at(&data, 3);
        assert_eq!(src.read_bits(5).unwrap(), 0xAB >> 3);
        assert_eq!(src.read_bits(8).unwrap(), 0xCD);
        assert_eq!(src.position(), 16);
    }

    /// A failed read leaves the position untouched.
    #[test]
    fn unexpected_end_short_input() {
        let data = [0x55u8];
        let mut src = BitSource::at(&data, 0);
        let before = src.position();
        assert!(src.read_bits(16).is_err());
        assert_eq!(src.position(), before);
        assert_eq!(src.read_bits(8).unwrap(), 0x55);
        assert!(src.read_bit().is_err());
    }

    #[test]
    fn set_position_rolls_back_accumulator() {
        let data = [0xFFu8, 0x00, 0xAA, 0x55];
        let mut src = BitSource::at(&data, 0);
        let saved = src.position();
        assert_eq!(src.read_bits(12).unwrap(), 0x0FF);
        src.set_position(saved);
        assert_eq!(src.read_bits(8).unwrap(), 0xFF);
        assert_eq!(src.read_bits(8).unwrap(), 0x00);
    }

    #[test]
    fn align_to_byte_drops_partial() {
        let data = [0b1111_0000u8, 0b1010_1010];
        let mut src = BitSource::at(&data, 0);
        assert_eq!(src.read_bits(3).unwrap(), 0b000);
        src.align_to_byte();
        assert_eq!(src.position(), 8);
        assert_eq!(src.read_bits(8).unwrap(), 0b1010_1010);
    }
}