use alloc::vec;
use alloc::vec::Vec;
use crate::error::Error;
use super::bits::BitReader;
/// Longest code length, in bits, that `Huffman::from_lengths` accepts.
///
/// It also sizes the per-length tables in `Huffman`, which are indexed by
/// code length over `0..=MAX_CODE_LEN`.
const MAX_CODE_LEN: u8 = 15;
/// Canonical Huffman decoding tables built from a list of per-symbol code lengths.
///
/// Every array is indexed by code length in bits; index `0` is never populated
/// because zero-length (unused) symbols are skipped when the tables are built.
#[derive(Debug, Clone)]
pub struct Huffman {
/// `counts[len]` is the number of symbols whose code is `len` bits long.
counts: [u16; (MAX_CODE_LEN as usize) + 1],
/// Symbol values grouped by code length and, within one length, in increasing
/// symbol order; `first_idx` records where each length's group starts.
symbols: Vec<u16>,
/// `first_code[len]` is the numerically smallest code of length `len`.
first_code: [u32; (MAX_CODE_LEN as usize) + 1],
/// `first_idx[len]` is the index in `symbols` of the first symbol of length `len`.
first_idx: [u16; (MAX_CODE_LEN as usize) + 1],
/// Longest code length in use; `0` means no symbol has a code and `decode` always fails.
max_length: u8,
}
impl Huffman {
pub fn from_lengths(lengths: &[u8]) -> Result<Self, Error> {
let mut counts = [0u16; (MAX_CODE_LEN as usize) + 1];
let mut max_length: u8 = 0;
for &len in lengths {
if len > MAX_CODE_LEN {
return Err(Error::InvalidHuffmanTree);
}
if len > 0 {
counts[len as usize] += 1;
if len > max_length {
max_length = len;
}
}
}
let symbols_cap = lengths.len();
if max_length == 0 {
return Ok(Self {
counts,
symbols: vec![0u16; symbols_cap],
first_code: [0u32; (MAX_CODE_LEN as usize) + 1],
first_idx: [0u16; (MAX_CODE_LEN as usize) + 1],
max_length: 0,
});
}
let mut kraft: u32 = 0;
for l in 1..=(MAX_CODE_LEN as u32) {
kraft += (counts[l as usize] as u32) << (MAX_CODE_LEN as u32 - l);
}
if kraft > (1u32 << MAX_CODE_LEN) {
return Err(Error::InvalidHuffmanTree);
}
let mut first_code = [0u32; (MAX_CODE_LEN as usize) + 1];
let mut first_idx = [0u16; (MAX_CODE_LEN as usize) + 1];
let mut code: u32 = 0;
let mut idx: u16 = 0;
for l in 1..=(MAX_CODE_LEN as usize) {
code <<= 1;
first_code[l] = code;
first_idx[l] = idx;
code += counts[l] as u32;
idx += counts[l];
}
let mut symbols = vec![0u16; symbols_cap];
let mut next = first_idx;
for (sym, &len) in lengths.iter().enumerate() {
if len > 0 {
let slot = next[len as usize] as usize;
symbols[slot] = sym as u16;
next[len as usize] += 1;
}
}
Ok(Self {
counts,
symbols,
first_code,
first_idx,
max_length,
})
}
pub fn decode(&self, reader: &mut BitReader) -> Result<u16, Error> {
if self.max_length == 0 {
return Err(Error::InvalidHuffmanTree);
}
let max = self.max_length as u32;
let lookahead = self.peek_padded(reader, max)?;
for length in 1..=max {
let code = lookahead >> (max - length);
let count = self.counts[length as usize] as u32;
if count > 0 {
let first = self.first_code[length as usize];
if code >= first && code < first + count {
let slot = self.first_idx[length as usize] as u32 + (code - first);
if (slot as usize) >= self.symbols.len() {
return Err(Error::InvalidHuffmanTree);
}
reader.drop_bits(length)?;
return Ok(self.symbols[slot as usize]);
}
}
}
Err(Error::InvalidHuffmanTree)
}
fn peek_padded(&self, reader: &mut BitReader, n: u32) -> Result<u32, Error> {
match reader.peek(n) {
Ok(v) => Ok(v),
Err(_) => {
let mut try_n = n;
while try_n > 0 {
try_n -= 1;
if let Ok(v) = reader.peek(try_n) {
return Ok(v << (n - try_n));
}
}
Err(Error::UnexpectedEnd)
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    extern crate std;

    /// Lengths `[2, 1, 3, 3]` give the canonical codes sym1=`0`, sym0=`10`,
    /// sym2=`110`, sym3=`111`; the fed bytes are expected to decode to
    /// symbols 1, 0, 3, 2 in that order.
    #[test]
    fn round_trip_simple_tree() {
        let tree = Huffman::from_lengths(&[2, 1, 3, 3]).unwrap();
        let mut reader = BitReader::new();
        reader.feed_slice(&[0x5F, 0x00]);
        for &expected in [1u16, 0, 3, 2].iter() {
            assert_eq!(tree.decode(&mut reader).unwrap(), expected);
        }
    }

    /// A length above `MAX_CODE_LEN` and an over-subscribed length set must both fail.
    #[test]
    fn invalid_lengths_rejected() {
        let over_limit = Huffman::from_lengths(&[16, 0, 0, 0]);
        assert!(over_limit.is_err());

        let oversubscribed = Huffman::from_lengths(&[1, 1, 2]);
        assert!(oversubscribed.is_err());
    }

    /// A tree with no coded symbols builds successfully but can never decode.
    #[test]
    fn empty_tree_rejects_decode() {
        let tree = Huffman::from_lengths(&[0, 0, 0, 0]).unwrap();
        let mut reader = BitReader::new();
        reader.feed_slice(&[0xFF]);
        let outcome = tree.decode(&mut reader);
        assert!(matches!(outcome, Err(Error::InvalidHuffmanTree)));
    }
}