use crate::bits::BitReader;
use crate::error::Error;
#[derive(Debug, Clone)]
pub struct CanonicalDecoder<const N: usize> {
counts: [u16; 16],
symbols: [u16; N],
first_code: [u32; 16],
first_idx: [u16; 16],
max_length: u8,
}
impl<const N: usize> CanonicalDecoder<N> {
pub fn from_lengths(code_lengths: &[u8]) -> Result<Self, Error> {
assert!(code_lengths.len() <= N);
let mut counts = [0u16; 16];
let mut max_length: u8 = 0;
for &len in code_lengths {
if len > 15 {
return Err(Error::InvalidHuffmanTree);
}
if len > 0 {
counts[len as usize] += 1;
if len > max_length {
max_length = len;
}
}
}
let mut kraft: u32 = 0;
for l in 1..=15u32 {
kraft += (counts[l as usize] as u32) << (15 - l);
}
if kraft > (1 << 15) {
return Err(Error::InvalidHuffmanTree);
}
let mut first_code = [0u32; 16];
let mut first_idx = [0u16; 16];
let mut code: u32 = 0;
let mut idx: u16 = 0;
for l in 1..=15 {
code <<= 1;
first_code[l] = code;
first_idx[l] = idx;
code += counts[l] as u32;
idx += counts[l];
}
let mut symbols = [0u16; N];
let mut next = first_idx;
for (sym, &len) in code_lengths.iter().enumerate() {
if len > 0 {
symbols[next[len as usize] as usize] = sym as u16;
next[len as usize] += 1;
}
}
Ok(Self {
counts,
symbols,
first_code,
first_idx,
max_length,
})
}
pub fn decode(&self, reader: &mut BitReader) -> Result<Option<u16>, Error> {
if self.max_length == 0 {
return Err(Error::InvalidHuffmanTree);
}
let available = reader.bits_available();
let max = self.max_length as u32;
let mut code: u32 = 0;
for length in 1..=max {
if length > available {
return Ok(None);
}
let bit = ((reader.peek(length) >> (length - 1)) & 1) as u32;
code = (code << 1) | bit;
let count = self.counts[length as usize] as u32;
if count > 0 {
let first = self.first_code[length as usize];
if code >= first && code < first + count {
let sym_idx = self.first_idx[length as usize] as u32 + (code - first);
reader.drop_bits(length);
return Ok(Some(self.symbols[sym_idx as usize]));
}
}
}
Err(Error::InvalidHuffmanTree)
}
}
#[cfg(feature = "alloc")]
use alloc::vec;
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
/// Item of the package-merge pool: either a single symbol occurrence or a package.
#[cfg(feature = "alloc")]
#[derive(Clone, Copy)]
enum PoolKind {
    /// One occurrence of the symbol with this index into `freqs`.
    Coin(usize),
    /// A package of two items from the previous level's list, as indices into the pool.
    Pair(u32, u32),
}

/// Node of the package-merge arena. Coins and packages live in one `Vec` and refer
/// to each other by index.
#[cfg(feature = "alloc")]
struct PoolElement {
    /// Ordering weight: the symbol's frequency for a coin, the sum of both children
    /// for a package.
    cost: u64,
    kind: PoolKind,
}

/// Computes prefix-code lengths whose longest code is at most `max_length` bits,
/// using the package-merge algorithm.
///
/// Returns a vector parallel to `freqs`:
/// - `0` for every symbol with zero frequency;
/// - a length in `1..=max_length` for every other symbol.
///
/// If exactly one symbol is used, it gets length `1`. If none is used, the result is
/// all zeros. Equal frequencies are ordered by ascending symbol index, so the result
/// is deterministic.
///
/// # Panics
///
/// Panics if:
/// - `max_length` is outside `1..=15`;
/// - two or more symbols are used and their count exceeds `2^max_length`, since no
///   prefix code of that depth can hold them.
#[cfg(feature = "alloc")]
pub fn length_limited_huffman(freqs: &[u32], max_length: u8) -> Vec<u8> {
    assert!(
        max_length > 0 && max_length <= 15,
        "max_length must be 1..=15"
    );
    let mut out = vec![0u8; freqs.len()];

    // One coin per used symbol. Symbols are kept as `usize` indices so alphabets
    // longer than 65536 entries cannot be truncated onto the wrong symbol.
    let mut coins: Vec<(u32, usize)> = freqs
        .iter()
        .enumerate()
        .filter_map(|(sym, &f)| if f > 0 { Some((f, sym)) } else { None })
        .collect();
    let n = coins.len();
    if n == 0 {
        return out;
    }
    if n == 1 {
        // A lone symbol still needs a non-empty code.
        out[coins[0].1] = 1;
        return out;
    }
    assert!(n <= 1usize << max_length, "alphabet too big for max_length");
    // Stable sort: equal frequencies keep ascending symbol order.
    coins.sort_by_key(|&(f, _)| f);

    let mut pool: Vec<PoolElement> = Vec::with_capacity(n * (max_length as usize) * 2 + 8);

    // Deepest level: just the sorted coins.
    let mut current: Vec<u32> = Vec::with_capacity(2 * n);
    for &(f, sym) in &coins {
        pool.push(PoolElement {
            cost: u64::from(f),
            kind: PoolKind::Coin(sym),
        });
        current.push((pool.len() - 1) as u32);
    }

    // Each shallower level pairs adjacent items of the previous list into packages
    // (an odd trailing item is dropped). It then merges those packages with a fresh
    // set of coins, ordered by cost.
    for _ in 1..max_length {
        let mut packages: Vec<u32> = Vec::with_capacity(current.len() / 2);
        for pair in current.chunks_exact(2) {
            let (a, b) = (pair[0], pair[1]);
            let cost = pool[a as usize].cost + pool[b as usize].cost;
            pool.push(PoolElement {
                cost,
                kind: PoolKind::Pair(a, b),
            });
            packages.push((pool.len() - 1) as u32);
        }

        let coin_start = pool.len();
        for &(f, sym) in &coins {
            pool.push(PoolElement {
                cost: u64::from(f),
                kind: PoolKind::Coin(sym),
            });
        }
        let fresh_coins: Vec<u32> = (coin_start..pool.len()).map(|i| i as u32).collect();

        // Two-way merge of two cost-sorted lists; on equal cost the coin goes first.
        let mut merged: Vec<u32> = Vec::with_capacity(fresh_coins.len() + packages.len());
        let (mut ci, mut pi) = (0usize, 0usize);
        while ci < fresh_coins.len() && pi < packages.len() {
            if pool[fresh_coins[ci] as usize].cost <= pool[packages[pi] as usize].cost {
                merged.push(fresh_coins[ci]);
                ci += 1;
            } else {
                merged.push(packages[pi]);
                pi += 1;
            }
        }
        merged.extend_from_slice(&fresh_coins[ci..]);
        merged.extend_from_slice(&packages[pi..]);
        current = merged;
    }

    // The solution is the 2n - 2 cheapest items of the final list. Every coin
    // reachable from them, directly or inside packages, adds one bit to its symbol's
    // length. The `n <= 2^max_length` check above is what keeps this slice in range.
    let pick = 2 * n - 2;
    let mut stack: Vec<u32> = Vec::with_capacity(32);
    for &root in &current[..pick] {
        stack.clear();
        stack.push(root);
        while let Some(idx) = stack.pop() {
            match pool[idx as usize].kind {
                PoolKind::Coin(sym) => out[sym] += 1,
                PoolKind::Pair(a, b) => {
                    stack.push(a);
                    stack.push(b);
                }
            }
        }
    }
    out
}
/// Assigns canonical Huffman codes (RFC 1951 §3.2.2) from per-symbol code lengths.
///
/// `lengths[i]` is the bit length of symbol `i`'s code; `0` means the symbol is
/// unused. The result is parallel to `lengths`:
/// - each used symbol gets the numeric value of its code, i.e. the code occupies the
///   low `lengths[i]` bits;
/// - unused symbols get `0`.
///
/// Codes of the same length are consecutive and increase with the symbol index.
///
/// The lengths are not validated:
/// - a length above 15 trips a debug assertion, and in release builds panics on an
///   out-of-bounds index;
/// - lengths that over-subscribe the code space silently produce codes that do not
///   form a prefix code.
#[cfg(feature = "alloc")]
pub fn canonical_codes_from_lengths(lengths: &[u8]) -> Vec<u16> {
    // bl_count[len] = number of symbols with a `len`-bit code; bl_count[0] stays 0.
    let mut bl_count = [0u32; 16];
    for len in lengths.iter().map(|&len| usize::from(len)) {
        debug_assert!(len <= 15);
        if len != 0 {
            bl_count[len] += 1;
        }
    }

    // First code of each length: one past the last code of the previous length,
    // shifted left by one bit.
    let mut next_code = [0u32; 16];
    for bits in 1..next_code.len() {
        next_code[bits] = (next_code[bits - 1] + bl_count[bits - 1]) << 1;
    }

    // Hand out codes in symbol order, advancing the per-length cursor each time.
    lengths
        .iter()
        .map(|&len| match usize::from(len) {
            0 => 0,
            len => {
                let code = next_code[len];
                next_code[len] += 1;
                code as u16
            }
        })
        .collect()
}
#[cfg(all(test, feature = "alloc"))]
mod tests {
    use super::*;

    /// RFC 1951 §3.2.2 example: lengths (3, 3, 3, 3, 3, 2, 4, 4) give symbol 5 the
    /// code `00`, symbols 0..=4 the codes `010`..=`110`, and symbols 6 and 7 the codes
    /// `1110` and `1111`.
    const RFC_LENS: [u8; 8] = [3, 3, 3, 3, 3, 2, 4, 4];

    #[test]
    fn canonical_decoder_rfc1951_example() {
        let dec = CanonicalDecoder::<8>::from_lengths(&RFC_LENS).unwrap();

        // Stream bits 0, 0 -> code `00` -> symbol 5.
        let mut r = BitReader::new();
        r.feed(0b0000_0000);
        let sym = dec.decode(&mut r).unwrap().unwrap();
        assert_eq!(sym, 5);

        // Stream bits 0, 1, 0 (first bit in the LSB) -> code `010` -> symbol 0.
        let mut r = BitReader::new();
        r.feed(0b0000_0010);
        let sym = dec.decode(&mut r).unwrap().unwrap();
        assert_eq!(sym, 0);
    }

    #[test]
    fn canonical_codes_roundtrip() {
        let codes = canonical_codes_from_lengths(&RFC_LENS);
        assert_eq!(
            codes,
            vec![0b010, 0b011, 0b100, 0b101, 0b110, 0b00, 0b1110, 0b1111]
        );
    }

    #[test]
    fn from_lengths_rejects_oversubscribed_code() {
        // Three 1-bit codes would need 3/2 of the code space.
        assert!(CanonicalDecoder::<3>::from_lengths(&[1, 1, 1]).is_err());
    }

    #[test]
    fn from_lengths_rejects_length_over_15() {
        assert!(CanonicalDecoder::<1>::from_lengths(&[16]).is_err());
    }

    #[test]
    fn length_limited_basic() {
        let lens = length_limited_huffman(&[1, 1, 1, 1], 15);
        assert_eq!(lens, vec![2, 2, 2, 2]);
    }

    #[test]
    fn length_limited_matches_unconstrained_huffman() {
        // The unique optimal Huffman lengths for these weights are 4, 4, 3, 2, 1.
        let lens = length_limited_huffman(&[1, 2, 4, 8, 16], 15);
        assert_eq!(lens, vec![4, 4, 3, 2, 1]);
    }

    #[test]
    fn length_limited_reshapes_tree_under_cap() {
        // Capped at 3 bits, the cheapest code is 3, 3, 3, 3, 1 (cost 61).
        let lens = length_limited_huffman(&[1, 2, 4, 8, 16], 3);
        assert_eq!(lens, vec![3, 3, 3, 3, 1]);
        assert!(CanonicalDecoder::<5>::from_lengths(&lens).is_ok());
    }

    #[test]
    fn length_limited_enforces_cap() {
        let freqs = [1u32, 1, 1, 1, 1, 1, 1, 100];
        let lens = length_limited_huffman(&freqs, 3);
        assert!(lens.iter().all(|&l| l <= 3));
        let min_len = *lens.iter().filter(|&&l| l > 0).min().unwrap();
        // The heaviest symbol must get one of the shortest codes.
        assert!(lens[7] <= min_len);
    }

    #[test]
    fn single_symbol_gets_length_one() {
        let lens = length_limited_huffman(&[0, 0, 5, 0], 15);
        assert_eq!(lens[2], 1);
        assert!(lens.iter().enumerate().all(|(i, &l)| (i == 2) == (l > 0)));
    }
}