extern crate alloc;
use alloc::vec::Vec;
use crate::error::Error;
/// Number of symbols in the alphabet; every code-length and code array in
/// this module has exactly this many entries.
pub const NUM_SYMBOLS: usize = 512;
/// Longest code length, in bits, accepted by `build_decode_table`.
pub const MAX_CODE_LEN: u32 = 15;
/// Base-2 logarithm of the decode table size.
///
/// `build_decode_table` sizes each symbol's run of entries from
/// `MAX_CODE_LEN`, so this value is expected to stay equal to it.
pub const DECODE_TABLE_BITS: u32 = 15;
/// Number of entries in a `DecodeTable` (`2^DECODE_TABLE_BITS`).
pub const DECODE_TABLE_SIZE: usize = 1 << DECODE_TABLE_BITS;
/// Expands 256 packed bytes into one code length per symbol.
///
/// Each byte carries two 4-bit lengths: the low nibble belongs to the
/// even-numbered symbol `2 * k`, the high nibble to the odd-numbered symbol
/// `2 * k + 1`. This is the inverse of `pack_lengths` for lengths in `0..=15`.
pub fn unpack_lengths(packed: &[u8; 256]) -> [u8; NUM_SYMBOLS] {
    let mut lens = [0u8; NUM_SYMBOLS];
    // Walk the output two symbols at a time, alongside the byte that encodes them.
    for (pair, &byte) in lens.chunks_exact_mut(2).zip(packed.iter()) {
        pair[0] = byte & 0x0F;
        pair[1] = byte >> 4;
    }
    lens
}
/// Packs one code length per symbol into 256 bytes, two lengths per byte.
///
/// Symbol `2 * k` goes into the low nibble of byte `k` and symbol `2 * k + 1`
/// into the high nibble. Only the low four bits of each length are kept, so
/// values above 15 are silently truncated.
pub fn pack_lengths(lens: &[u8; NUM_SYMBOLS]) -> [u8; 256] {
    let mut packed = [0u8; 256];
    for (byte, pair) in packed.iter_mut().zip(lens.chunks_exact(2)) {
        let low = pair[0] & 0x0F;
        let high = pair[1] & 0x0F;
        *byte = (high << 4) | low;
    }
    packed
}
pub fn build_decode_table(
lengths: &[u8; NUM_SYMBOLS],
) -> Result<alloc::boxed::Box<DecodeTable>, Error> {
for &l in lengths.iter() {
if l as u32 > MAX_CODE_LEN {
return Err(Error::InvalidHuffmanTree);
}
}
let mut table = alloc::boxed::Box::new([(0u16, 0u8); DECODE_TABLE_SIZE]);
let mut entry: usize = 0;
for bit_length in 1..=MAX_CODE_LEN {
for symbol in 0..NUM_SYMBOLS as u16 {
if lengths[symbol as usize] as u32 == bit_length {
let count = 1usize << (MAX_CODE_LEN - bit_length);
if entry + count > DECODE_TABLE_SIZE {
return Err(Error::InvalidHuffmanTree);
}
let value = (symbol, bit_length as u8);
for slot in &mut table[entry..entry + count] {
*slot = value;
}
entry += count;
}
}
}
if entry != DECODE_TABLE_SIZE {
return Err(Error::InvalidHuffmanTree);
}
Ok(table)
}
/// Flat lookup table produced by `build_decode_table`.
///
/// Every entry is `(symbol, code length in bits)`. A symbol whose code is
/// `len` bits long occupies `2^(MAX_CODE_LEN - len)` consecutive entries.
///
/// NOTE(review): presumably indexed by the next `DECODE_TABLE_BITS` bits of
/// the input read most-significant-bit first — confirm against the decoder.
pub type DecodeTable = [(u16, u8); DECODE_TABLE_SIZE];
/// Assigns canonical code values to symbols from their code lengths.
///
/// Codes of the same length are consecutive integers handed out in
/// increasing symbol order, and the first code of each length follows from
/// the number of shorter codes. Symbols with length 0 keep code 0.
///
/// The lengths are not validated: an over-subscribed set produces codes that
/// are truncated to 16 bits.
///
/// # Panics
///
/// Panics if any length is greater than `MAX_CODE_LEN`.
pub fn lengths_to_codes(lengths: &[u8; NUM_SYMBOLS]) -> [u16; NUM_SYMBOLS] {
    const SLOTS: usize = (MAX_CODE_LEN + 1) as usize;

    // How many symbols use each code length (length 0 is never counted).
    let mut histogram = [0u32; SLOTS];
    for &len in lengths.iter().filter(|&&len| len != 0) {
        histogram[usize::from(len)] += 1;
    }

    // First code value for each length.
    let mut next = [0u32; SLOTS];
    let mut running = 0u32;
    for bits in 1..SLOTS {
        running = (running + histogram[bits - 1]) << 1;
        next[bits] = running;
    }

    let mut codes = [0u16; NUM_SYMBOLS];
    for (code, &len) in codes.iter_mut().zip(lengths.iter()) {
        if len != 0 {
            let slot = &mut next[usize::from(len)];
            *code = *slot as u16;
            *slot += 1;
        }
    }
    codes
}
/// Computes length-limited Huffman code lengths with the package-merge
/// algorithm.
///
/// `freqs[s]` is the frequency of symbol `s`; symbols with frequency 0 get
/// length 0 and take no part in the code. The result has one length per
/// symbol, indexed by symbol.
///
/// Special cases:
/// * no symbol has a non-zero frequency: every length is 0;
/// * exactly one symbol is active: it gets length 1;
/// * `max_length` is too small to give every active symbol a distinct code
///   (more than `2^max_length` active symbols): the limit is raised to the
///   smallest feasible depth, `ceil(log2(active symbols))`, so the result is
///   always a valid prefix code instead of an over-subscribed length set.
///
/// Ties between equal weights are broken by symbol index, and a leaf is
/// placed before a package of the same weight.
///
/// # Panics
///
/// In debug builds, panics if `freqs` has more than `NUM_SYMBOLS` entries or
/// `max_length` is outside `1..=MAX_CODE_LEN`. In any build, indexing panics
/// if an active symbol index is `NUM_SYMBOLS` or larger.
pub fn length_limited_huffman(freqs: &[u32], max_length: u32) -> [u8; NUM_SYMBOLS] {
    debug_assert!(freqs.len() <= NUM_SYMBOLS);
    debug_assert!((1..=MAX_CODE_LEN).contains(&max_length));

    let mut lengths = [0u8; NUM_SYMBOLS];

    // (frequency, symbol) for every symbol that actually occurs.
    let mut active: Vec<(u32, u32)> = freqs
        .iter()
        .enumerate()
        .filter(|&(_, &f)| f > 0)
        .map(|(s, &f)| (f, s as u32))
        .collect();
    match active.len() {
        0 => return lengths,
        1 => {
            lengths[active[0].1 as usize] = 1;
            return lengths;
        }
        _ => {}
    }
    // Keys are unique (the symbol is part of the key), so an unstable sort
    // yields the same order as a stable one.
    active.sort_unstable();

    let leaves: Vec<Item> = active
        .iter()
        .map(|&(f, s)| Item {
            weight: u64::from(f),
            leaf: Some(s as u16),
            children: None,
        })
        .collect();
    let n = leaves.len();

    // A prefix code for n symbols needs at least ceil(log2(n)) bits; with a
    // smaller limit the selection below would be over-subscribed.
    let min_depth = usize::BITS - (n - 1).leading_zeros();
    let depth = max_length.max(min_depth);

    // `row` is the current level's list, sorted by weight. Only the latest
    // level is needed, so earlier rows are consumed rather than kept.
    let mut row: Vec<Item> = leaves.clone();
    for _ in 1..depth {
        let mut merged: Vec<Item> = Vec::with_capacity(n + row.len() / 2);
        let mut pending = leaves.iter().peekable();
        let mut prev = row.into_iter();
        // Package adjacent pairs of the previous row (an unpaired last item
        // is dropped) and merge the packages with a fresh copy of the leaves.
        while let (Some(a), Some(b)) = (prev.next(), prev.next()) {
            let weight = a.weight + b.weight;
            // Leaves that are not heavier than this package go first.
            while let Some(leaf) = pending.next_if(|l| l.weight <= weight) {
                merged.push(leaf.clone());
            }
            merged.push(Item {
                weight,
                leaf: None,
                children: Some((a.into(), b.into())),
            });
        }
        merged.extend(pending.cloned());
        row = merged;
    }

    // The cheapest 2n - 2 items of the final row define the code: a symbol's
    // length is the number of times it occurs inside them.
    for item in row.iter().take(2 * n - 2) {
        item.count_leaves(&mut lengths);
    }
    lengths
}
/// One entry in a package-merge row: either a single symbol (a leaf) or a
/// package formed from two entries of the previous row.
#[derive(Clone)]
struct Item {
/// Frequency of the leaf, or the summed weight of both children for a package.
weight: u64,
/// Symbol index when this entry is a leaf; `None` for a package.
leaf: Option<u16>,
/// The two packaged entries when this entry is a package; `None` for a leaf.
children: Option<(alloc::boxed::Box<Item>, alloc::boxed::Box<Item>)>,
}
impl Item {
    /// Adds 1 to `lengths[symbol]` for every leaf occurrence reachable from
    /// this item.
    ///
    /// A leaf counts itself; a package recurses into both children. An item
    /// with neither a symbol nor children contributes nothing. If both
    /// fields were set, the symbol would win and the children be ignored.
    fn count_leaves(&self, lengths: &mut [u8; NUM_SYMBOLS]) {
        match (self.leaf, self.children.as_ref()) {
            (Some(symbol), _) => lengths[usize::from(symbol)] += 1,
            (None, Some((left, right))) => {
                left.count_leaves(lengths);
                right.count_leaves(lengths);
            }
            (None, None) => {}
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Packing then unpacking returns the original lengths for every 4-bit value.
    #[test]
    fn pack_unpack_roundtrip() {
        let mut lens = [0u8; NUM_SYMBOLS];
        for (i, l) in lens.iter_mut().enumerate() {
            *l = ((i * 7) % 16) as u8;
        }
        let packed = pack_lengths(&lens);
        let back = unpack_lengths(&packed);
        assert_eq!(back, lens);
    }

    /// The even symbol of each pair lives in the low nibble, the odd one in
    /// the high nibble.
    #[test]
    fn pack_lengths_puts_even_symbol_in_low_nibble() {
        let mut lens = [0u8; NUM_SYMBOLS];
        lens[0] = 5;
        lens[1] = 6;
        lens[2] = 7;
        lens[3] = 8;
        let packed = pack_lengths(&lens);
        assert_eq!(packed[0], 0x65);
        assert_eq!(packed[1], 0x87);
    }

    /// Two 1-bit codes each own exactly half of the table.
    #[test]
    fn build_decode_table_accepts_simple() {
        let mut lens = [0u8; NUM_SYMBOLS];
        lens[0] = 1;
        lens[1] = 1;
        let t = build_decode_table(&lens).unwrap();
        assert_eq!(t[0], (0, 1));
        assert_eq!(t[(1 << 14) - 1], (0, 1));
        assert_eq!(t[1 << 14], (1, 1));
        assert_eq!(t[(1 << 15) - 1], (1, 1));
    }

    /// A single 1-bit code leaves half the table empty and must be rejected.
    #[test]
    fn build_decode_table_rejects_invalid() {
        let mut lens = [0u8; NUM_SYMBOLS];
        lens[0] = 1;
        assert!(matches!(
            build_decode_table(&lens),
            Err(Error::InvalidHuffmanTree)
        ));
    }

    /// Lengths (1, 2, 3, 3) must map to the canonical codes 0, 10, 110, 111.
    #[test]
    fn canonical_codes_match_spec() {
        let mut lens = [0u8; NUM_SYMBOLS];
        lens[0] = 1;
        lens[1] = 2;
        lens[2] = 3;
        lens[3] = 3;
        let codes = lengths_to_codes(&lens);
        assert_eq!(&codes[..4], &[0b0u16, 0b10, 0b110, 0b111][..]);
        assert!(codes[4..].iter().all(|&c| c == 0));
    }

    /// No generated length may exceed the requested limit.
    #[test]
    fn length_limited_caps_lengths() {
        let mut freqs = [0u32; 200];
        for (i, f) in freqs.iter_mut().enumerate() {
            *f = (i as u32) + 1;
        }
        let lens = length_limited_huffman(&freqs, 15);
        for &l in lens.iter() {
            assert!(l <= 15);
        }
    }

    /// Lengths produced by the encoder side must form a complete code that
    /// the decoder-side table builder accepts.
    #[test]
    fn length_limited_output_builds_decode_table() {
        let mut freqs = [0u32; 200];
        for (i, f) in freqs.iter_mut().enumerate() {
            *f = (i as u32) + 1;
        }
        let lens = length_limited_huffman(&freqs, 15);
        assert!(build_decode_table(&lens).is_ok());
    }

    /// A lone active symbol gets a 1-bit code and nothing else is assigned.
    #[test]
    fn single_symbol_gets_length_one() {
        let mut freqs = [0u32; NUM_SYMBOLS];
        freqs[42] = 100;
        let lens = length_limited_huffman(&freqs, 15);
        assert_eq!(lens[42], 1);
        for (s, &l) in lens.iter().enumerate() {
            if s != 42 {
                assert_eq!(l, 0);
            }
        }
    }
}