use alloc::vec;
use alloc::vec::Vec;
use crate::error::Error;
/// Canonical prefix-code (Huffman) decoder that consumes input one bit at a
/// time.
///
/// Codes are assigned canonically: all codes of length `l` are consecutive
/// integers starting at `first_code[l]`, and within one length the symbols
/// are ordered by increasing symbol value.
#[derive(Debug, Clone)]
pub(crate) struct HuffmanDecoder {
    /// `counts[l]` is the number of codes of length `l`; index 0 is unused.
    counts: [u32; 16],
    /// All symbols, grouped by code length and ascending within each group.
    symbols: Vec<u32>,
    /// `first_code[l]` is the numerically smallest code of length `l`.
    first_code: [u32; 16],
    /// `first_idx[l]` is the index in `symbols` of the first length-`l` symbol.
    first_idx: [u32; 16],
    /// Longest code length in use; 0 for the single-symbol form.
    max_length: u8,
    /// Set for a one-symbol alphabet, which decodes without reading bits.
    single_symbol: Option<u32>,
}

impl HuffmanDecoder {
    /// Longest code length supported (the per-length tables have 16 slots).
    const MAX_BITS: usize = 15;

    /// Builds a decoder for a one-symbol alphabet.
    ///
    /// [`Self::decode`] always returns `sym` and consumes no input bits.
    pub(crate) fn single(sym: u32) -> Self {
        Self {
            counts: [0; 16],
            symbols: Vec::new(),
            first_code: [0; 16],
            first_idx: [0; 16],
            max_length: 0,
            single_symbol: Some(sym),
        }
    }

    /// Builds a decoder from explicit `(symbol, code_length)` pairs.
    ///
    /// The order of `pairs` does not matter; codes are assigned canonically.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidHuffmanTree` if `pairs` is empty, any length is
    /// outside `1..=15`, a symbol appears more than once, or the lengths do
    /// not form a complete prefix code (Kraft sum not exactly 1).
    pub(crate) fn from_lengths_sparse(pairs: &[(u32, u8)]) -> Result<Self, Error> {
        // Sorting by symbol puts duplicates next to each other, and makes the
        // bucket fill below place each length's symbols in ascending order,
        // which is what canonical code assignment requires.
        let mut owned: Vec<(u32, u8)> = pairs.to_vec();
        owned.sort_unstable();
        if owned.windows(2).any(|w| w[0].0 == w[1].0) {
            return Err(Error::InvalidHuffmanTree);
        }

        let mut counts = [0u32; 16];
        let mut max_length = 0u8;
        for &(_, len) in &owned {
            if len == 0 || usize::from(len) > Self::MAX_BITS {
                return Err(Error::InvalidHuffmanTree);
            }
            counts[usize::from(len)] += 1;
            max_length = max_length.max(len);
        }
        if max_length == 0 {
            return Err(Error::InvalidHuffmanTree);
        }

        // Kraft sum scaled by 2^15. Accumulated in u64: with u32, a very
        // large count shifted left silently drops high bits and can wrap to
        // exactly 2^15, making an invalid code look complete.
        let kraft: u64 = (1..=Self::MAX_BITS)
            .map(|l| u64::from(counts[l]) << (Self::MAX_BITS - l))
            .sum();
        if kraft != 1u64 << Self::MAX_BITS {
            return Err(Error::InvalidHuffmanTree);
        }

        // Canonical assignment: the first code of each length is the code
        // after the last one of the previous length, shifted left by one.
        let mut first_code = [0u32; 16];
        let mut first_idx = [0u32; 16];
        let mut code = 0u32;
        let mut idx = 0u32;
        for l in 1..=Self::MAX_BITS {
            code <<= 1;
            first_code[l] = code;
            first_idx[l] = idx;
            code += counts[l];
            idx += counts[l];
        }

        // Drop each symbol into the next free slot of its length bucket.
        let mut symbols = vec![0u32; owned.len()];
        let mut next = first_idx;
        for &(sym, len) in &owned {
            let slot = &mut next[usize::from(len)];
            symbols[*slot as usize] = sym;
            *slot += 1;
        }

        Ok(Self {
            counts,
            symbols,
            first_code,
            first_idx,
            max_length,
            single_symbol: None,
        })
    }

    /// Builds a decoder from a dense table where `lengths[sym]` is the code
    /// length of symbol `sym` and 0 marks an unused symbol.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidHuffmanTree` if fewer than two symbols are used,
    /// or if the lengths are rejected by [`Self::from_lengths_sparse`]. Use
    /// [`Self::from_lengths_allow_single`] to accept a one-symbol alphabet.
    pub(crate) fn from_lengths(lengths: &[u8]) -> Result<Self, Error> {
        let pairs: Vec<(u32, u8)> = lengths
            .iter()
            .zip(0u32..)
            .filter(|&(&len, _)| len > 0)
            .map(|(&len, sym)| (sym, len))
            .collect();
        // A lone code of any length leaves the Kraft sum below 1, so an empty
        // or single-entry table can never describe a complete code.
        if pairs.len() < 2 {
            return Err(Error::InvalidHuffmanTree);
        }
        Self::from_lengths_sparse(&pairs)
    }

    /// Like [`Self::from_lengths`], but a table with exactly one nonzero
    /// length yields a [`Self::single`] decoder for that symbol (regardless of
    /// the length value) instead of an error.
    ///
    /// # Errors
    ///
    /// Same as [`Self::from_lengths`] when zero or at least two symbols are used.
    pub(crate) fn from_lengths_allow_single(lengths: &[u8]) -> Result<Self, Error> {
        let mut used = lengths.iter().zip(0u32..).filter(|&(&len, _)| len > 0);
        // Exactly one used symbol: the first `next()` finds it, the second
        // finds nothing.
        if let (Some((_, sym)), None) = (used.next(), used.next()) {
            return Ok(Self::single(sym));
        }
        Self::from_lengths(lengths)
    }

    /// Decodes one symbol, pulling bits from `br` one at a time.
    ///
    /// Each bit read becomes the new least-significant bit of the code being
    /// assembled, and after every bit the code is compared against the
    /// canonical range for the current length.
    ///
    /// # Errors
    ///
    /// Propagates any error from `br.read_bit()`. Returns
    /// `Error::InvalidHuffmanTree` if the decoder has no codes, or if no code
    /// matches within `max_length` bits.
    pub(crate) fn decode(&self, br: &mut BitSource<'_>) -> Result<u32, Error> {
        if let Some(sym) = self.single_symbol {
            return Ok(sym);
        }
        if self.max_length == 0 {
            return Err(Error::InvalidHuffmanTree);
        }
        let mut code = 0u32;
        for len in 1..=usize::from(self.max_length) {
            code = (code << 1) | br.read_bit()?;
            let count = self.counts[len];
            let first = self.first_code[len];
            // Codes of this length occupy the range [first, first + count).
            if count > 0 && code >= first && code - first < count {
                let sym_idx = self.first_idx[len] + (code - first);
                return Ok(self.symbols[sym_idx as usize]);
            }
        }
        Err(Error::InvalidHuffmanTree)
    }
}
#[derive(Debug)]
pub(crate) struct BitSource<'a> {
data: &'a [u8],
pos: usize,
}
impl<'a> BitSource<'a> {
pub(crate) fn at(data: &'a [u8], pos: usize) -> Self {
Self { data, pos }
}
pub(crate) fn position(&self) -> usize {
self.pos
}
pub(crate) fn set_position(&mut self, p: usize) {
self.pos = p;
}
pub(crate) fn remaining(&self) -> usize {
self.data.len() * 8 - self.pos
}
pub(crate) fn read_bit(&mut self) -> Result<u32, Error> {
if self.pos >= self.data.len() * 8 {
return Err(Error::UnexpectedEnd);
}
let byte = self.data[self.pos >> 3];
let bit = (byte >> (self.pos & 7)) & 1;
self.pos += 1;
Ok(bit as u32)
}
pub(crate) fn read_bits(&mut self, n: u32) -> Result<u32, Error> {
debug_assert!(n <= 32);
if n == 0 {
return Ok(0);
}
if self.remaining() < n as usize {
return Err(Error::UnexpectedEnd);
}
let mut acc: u32 = 0;
let mut got: u32 = 0;
while got < n {
let byte_pos = self.pos >> 3;
let bit_off = (self.pos & 7) as u32;
let take = (8 - bit_off).min(n - got);
let mask: u32 = if take == 32 {
u32::MAX
} else {
(1u32 << take) - 1
};
let chunk = ((self.data[byte_pos] as u32) >> bit_off) & mask;
acc |= chunk << got;
got += take;
self.pos += take as usize;
}
Ok(acc)
}
pub(crate) fn align_to_byte(&mut self) {
let r = self.pos & 7;
if r != 0 {
self.pos += 8 - r;
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn single_symbol_zero_bits() {
        let d = HuffmanDecoder::single(42);
        let data = [0u8; 1];
        let mut src = BitSource::at(&data, 0);
        assert_eq!(d.decode(&mut src).unwrap(), 42);
        assert_eq!(src.position(), 0);
    }

    #[test]
    fn two_symbols_one_bit_each() {
        let d = HuffmanDecoder::from_lengths_sparse(&[(0, 1), (1, 1)]).unwrap();
        let data = [0b1010_1010u8];
        let mut src = BitSource::at(&data, 0);
        assert_eq!(d.decode(&mut src).unwrap(), 0);
        assert_eq!(d.decode(&mut src).unwrap(), 1);
    }

    #[test]
    fn multi_length_canonical_code() {
        // Lengths [2, 1, 3, 3] give canonical codes
        // sym1 = 0, sym0 = 10, sym2 = 110, sym3 = 111.
        let d = HuffmanDecoder::from_lengths(&[2, 1, 3, 3]).unwrap();
        // Read order 0 | 10 | 111 | 110, packed LSB-first into bytes.
        let data = [0xFAu8, 0x00];
        let mut src = BitSource::at(&data, 0);
        for &expected in &[1, 0, 3, 2] {
            assert_eq!(d.decode(&mut src).unwrap(), expected);
        }
        assert_eq!(src.position(), 9);
    }

    #[test]
    fn rejects_invalid_trees() {
        // Empty, all-unused and single-symbol tables.
        assert!(HuffmanDecoder::from_lengths(&[]).is_err());
        assert!(HuffmanDecoder::from_lengths(&[0, 0, 0]).is_err());
        assert!(HuffmanDecoder::from_lengths(&[0, 4, 0]).is_err());
        // Incomplete and oversubscribed codes.
        assert!(HuffmanDecoder::from_lengths(&[1, 2]).is_err());
        assert!(HuffmanDecoder::from_lengths(&[1, 1, 1]).is_err());
        // Length above 15.
        assert!(HuffmanDecoder::from_lengths_sparse(&[(0, 1), (1, 16)]).is_err());
    }

    #[test]
    fn allow_single_accepts_one_symbol() {
        let d = HuffmanDecoder::from_lengths_allow_single(&[0, 0, 9]).unwrap();
        let data: [u8; 0] = [];
        let mut src = BitSource::at(&data, 0);
        assert_eq!(d.decode(&mut src).unwrap(), 2);
        assert_eq!(src.position(), 0);
    }

    #[test]
    fn read_bits_lsb_first() {
        let data = [0b1011_0100u8, 0b0000_0001];
        let mut src = BitSource::at(&data, 0);
        assert_eq!(src.read_bits(4).unwrap(), 4);
        assert_eq!(src.read_bits(8).unwrap(), 0x1B);
    }

    #[test]
    fn reading_past_end_fails_without_moving() {
        let data = [0xFFu8];
        let mut src = BitSource::at(&data, 4);
        assert!(src.read_bits(5).is_err());
        assert_eq!(src.position(), 4);
        assert_eq!(src.read_bits(4).unwrap(), 0xF);
        assert!(src.read_bit().is_err());
        assert_eq!(src.position(), 8);
    }

    #[test]
    fn align_to_byte_rounds_up() {
        let data = [0u8; 2];
        let mut src = BitSource::at(&data, 3);
        src.align_to_byte();
        assert_eq!(src.position(), 8);
        src.align_to_byte();
        assert_eq!(src.position(), 8);
    }
}