/// Number of upcoming input bits used to index the fast lookup table.
const FAST_BITS: u8 = 10;
/// Entry count of the fast lookup table: one slot per `FAST_BITS`-bit pattern.
const FAST_TABLE_SIZE: usize = 1 << FAST_BITS;
/// Capacity of the node array backing the bit-by-bit decoding tree.
const MAX_TREE_NODES: usize = 512;
/// Combines a symbol and its code length into a single table entry.
///
/// The low 5 bits carry `len`; the bits above them carry `symbol`.
#[inline(always)]
fn pack_entry(symbol: u16, len: u8) -> u16 {
    let length_field = u16::from(len);
    let symbol_field = symbol << 5;
    symbol_field | length_field
}
/// Extracts the symbol stored above the 5-bit length field of a table entry.
#[inline(always)]
fn unpack_symbol(entry: u16) -> u16 {
    // Width of the length field that sits below the symbol.
    const LEN_FIELD_BITS: u32 = 5;
    entry >> LEN_FIELD_BITS
}
/// Extracts the code length stored in the low 5 bits of a table entry.
#[inline(always)]
fn unpack_len(entry: u16) -> u8 {
    // Selects exactly the five low bits that hold the length.
    const LEN_MASK: u16 = (1 << 5) - 1;
    let len = entry & LEN_MASK;
    len as u8
}
/// Huffman decoding tables built from per-symbol code lengths.
///
/// Holds a direct lookup table for codes of at most `FAST_BITS` bits and a
/// binary tree that is walked one bit at a time for everything else.
pub struct HuffmanTree {
    /// Indexed by the next `FAST_BITS` bits of input. Each entry is a packed
    /// symbol/length pair (see `pack_entry`); a length of 0 means the slot
    /// has no short code assigned.
    fast_table: [u16; FAST_TABLE_SIZE],
    /// Tree nodes; `nodes[i][bit]` is the child of node `i` for that bit.
    /// A positive value is the index of another node, a negative value is a
    /// leaf stored as `-(symbol + 1)`, and 0 is the initial value of every
    /// slot (nothing assigned). Node 0 is the root.
    nodes: [[i32; 2]; MAX_TREE_NODES],
    /// Number of entries of `nodes` currently in use.
    n_nodes: usize,
}
impl HuffmanTree {
    /// Creates a tree that contains no symbols.
    ///
    /// Every fast-table entry has length 0 and every tree slot is unassigned,
    /// so decoding with it yields `None`.
    pub const fn empty() -> Self {
        Self {
            fast_table: [0; FAST_TABLE_SIZE],
            nodes: [[0; 2]; MAX_TREE_NODES],
            n_nodes: 0,
        }
    }

    /// Builds the decoding tables for the canonical Huffman code described by
    /// `lengths`, where `lengths[i]` is the code length in bits of symbol `i`
    /// and 0 means the symbol is unused.
    ///
    /// Codes are assigned in order of increasing length, ties broken by
    /// increasing symbol value. Incomplete codes (unused bit patterns) are
    /// accepted; those patterns simply decode to `None`.
    ///
    /// # Errors
    ///
    /// * `"no symbols"` – every length is 0 (or `lengths` is empty).
    /// * `"too many symbols"` – more than 258 symbols have a non-zero length.
    /// * `"code length too large"` – a length exceeds 32 bits.
    /// * `"symbol index too large"` – a used symbol does not fit in `u16`.
    /// * `"huffman code over-subscribed"` – the lengths cannot form a
    ///   prefix-free code.
    /// * `"huffman tree too large"` – the tree needs more than
    ///   `MAX_TREE_NODES` nodes.
    pub fn from_lengths(lengths: &[u8]) -> Result<Self, &'static str> {
        /// Capacity of the on-stack scratch array of used symbols.
        const MAX_SYMBOLS: usize = 258;
        /// Longest code the shift arithmetic below supports.
        const MAX_CODE_LEN: u8 = 32;
        /// Largest symbol that `pack_entry` can store without losing bits.
        const MAX_PACKED_SYMBOL: u16 = u16::MAX >> 5;

        // Gather the used symbols, validating everything that would
        // otherwise overflow an array or a shift further down.
        let mut symbols = [(0u16, 0u8); MAX_SYMBOLS];
        let mut n_sym = 0usize;
        for (i, &len) in lengths.iter().enumerate() {
            if len == 0 {
                continue;
            }
            if len > MAX_CODE_LEN {
                return Err("code length too large");
            }
            if i > usize::from(u16::MAX) {
                return Err("symbol index too large");
            }
            if n_sym == MAX_SYMBOLS {
                return Err("too many symbols");
            }
            symbols[n_sym] = (i as u16, len);
            n_sym += 1;
        }
        if n_sym == 0 {
            return Err("no symbols");
        }
        let symbols = &mut symbols[..n_sym];
        symbols.sort_unstable_by_key(|&(sym, len)| (len, sym));

        let mut tree = Self::empty();
        // Node 0 is the root; it is never the child of another node, which is
        // what lets 0 double as the "unassigned" marker in `nodes`.
        tree.n_nodes = 1;

        // `code` is 64 bits wide so that neither the increment nor the
        // left-shift can overflow for lengths up to `MAX_CODE_LEN`.
        let mut code: u64 = 0;
        let mut prev_len = symbols[0].1;
        for &(sym, len) in symbols.iter() {
            code <<= len - prev_len;
            prev_len = len;
            // A code that no longer fits in `len` bits means the lengths claim
            // more code space than exists.
            if code >> len != 0 {
                return Err("huffman code over-subscribed");
            }

            // Insert the code into the tree, most significant bit first.
            let mut node = 0usize;
            for bit_pos in (0..len).rev() {
                let bit = ((code >> bit_pos) & 1) as usize;
                if bit_pos == 0 {
                    tree.nodes[node][bit] = -(i32::from(sym) + 1);
                    continue;
                }
                let child = tree.nodes[node][bit];
                // Canonical codes that passed the check above are prefix-free,
                // so an interior step can never land on an existing leaf.
                debug_assert!(child >= 0);
                node = if child > 0 {
                    child as usize
                } else {
                    let new_idx = tree.n_nodes;
                    if new_idx >= MAX_TREE_NODES {
                        return Err("huffman tree too large");
                    }
                    tree.nodes[node][bit] = new_idx as i32;
                    tree.n_nodes = new_idx + 1;
                    new_idx
                };
            }

            // Short codes also get every fast-table slot whose leading bits
            // match the code. Symbols too wide to pack are left to the tree.
            if len <= FAST_BITS && sym <= MAX_PACKED_SYMBOL {
                let pad = FAST_BITS - len;
                let base = (code as usize) << pad;
                let entry = pack_entry(sym, len);
                tree.fast_table[base..base + (1usize << pad)].fill(entry);
            }

            code += 1;
        }
        Ok(tree)
    }

    /// Decodes the next symbol from `reader`.
    ///
    /// Tries the fast table first when `reader.peek(FAST_BITS)` succeeds and
    /// falls back to walking the tree bit by bit otherwise. Returns `None`
    /// when the reader cannot supply another bit or the bits read do not
    /// match any code.
    #[inline(always)]
    pub fn decode(&self, reader: &mut super::bitreader::BitReader<'_>) -> Option<u16> {
        if let Some(bits) = reader.peek(FAST_BITS) {
            // NOTE(review): `peek(FAST_BITS)` is expected to return a value
            // below `FAST_TABLE_SIZE`. The mask keeps the lookup in bounds
            // without `unsafe` even if it does not, and lets the compiler
            // drop the bounds check.
            let entry = self.fast_table[bits as usize & (FAST_TABLE_SIZE - 1)];
            let len = unpack_len(entry);
            if len > 0 {
                reader.consume(len);
                return Some(unpack_symbol(entry));
            }
        }
        self.decode_slow(reader)
    }

    /// Decodes one symbol by walking the tree from the root, one bit at a time.
    ///
    /// Returns `None` if `read_bit` yields `None` or the walk reaches a slot
    /// that no code was assigned to.
    #[cold]
    fn decode_slow(&self, reader: &mut super::bitreader::BitReader<'_>) -> Option<u16> {
        let mut node_idx: usize = 0;
        loop {
            let bit = reader.read_bit()? as usize;
            match self.nodes[node_idx][bit] {
                // Unassigned slot: the bits consumed so far are not a prefix
                // of any code. Following it would restart at the root and
                // silently desynchronise the stream.
                0 => return None,
                child if child < 0 => return Some((-child - 1) as u16),
                child => node_idx = child as usize,
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bitreader::BitReader;

    /// Decodes one symbol per element of `expected` and checks each in order.
    fn assert_decodes(tree: &HuffmanTree, reader: &mut BitReader<'_>, expected: &[u16]) {
        for &symbol in expected {
            assert_eq!(tree.decode(reader), Some(symbol));
        }
    }

    /// Three symbols with lengths 1, 2, 2, decoded from a single byte.
    #[test]
    fn simple_tree() {
        let tree = HuffmanTree::from_lengths(&[1, 2, 2]).unwrap();
        let data = [0b0_10_11_0_00];
        let mut r = BitReader::new(&data);
        assert_decodes(&tree, &mut r, &[0, 1, 2, 0]);
    }

    /// Four symbols with lengths 2, 2, 3, 3; the last code straddles two bytes.
    #[test]
    fn fast_table_coverage() {
        let tree = HuffmanTree::from_lengths(&[2, 2, 3, 3]).unwrap();
        let data = [0b00_01_100_1, 0b01_000000];
        let mut r = BitReader::new(&data);
        assert_decodes(&tree, &mut r, &[0, 1, 2, 3]);
    }
}