extern crate alloc;
use alloc::vec;
use alloc::vec::Vec;
use crate::error::Error;
/// Longest code length, in bits, that [`HuffTable::build`] accepts; also sizes the
/// per-length counting tables used in this module.
pub const MAX_BITS: u32 = 16;
/// Decoding table for a canonical Huffman code, produced by [`HuffTable::build`] or
/// [`HuffTable::build_single`].
pub struct HuffTable {
    /// Alphabet size; every decoded symbol is checked to be below this.
    num_symbols: usize,
    /// When set, this symbol is returned by every decode without reading any bits.
    single: Option<u16>,
    /// Width, in bits, of the index used for the primary lookup.
    table_bits: u32,
    /// Primary lookup table (left empty for the zero- and one-symbol cases).
    table: Vec<TableEntry>,
    /// Binary tree for codes longer than `table_bits`. Each node is
    /// `[child_for_bit_0, child_for_bit_1]`: `0` means "no child", a value with the
    /// `LEAF` flag set carries a symbol, and any other value indexes another node.
    long: Vec<[u32; 2]>,
}
/// One slot of the primary lookup table.
///
/// The meaning of `sym` depends on `len`:
/// * `len == 0`: unused slot;
/// * `1..=table_bits`: a complete code of `len` bits, and `sym` is the decoded symbol;
/// * `table_bits + 1`: the code continues past the table, and `sym` is the index of
///   the subtree root in `HuffTable::long`.
#[derive(Copy, Clone)]
struct TableEntry {
    /// Code length in bits, or one of the markers described above.
    len: u8,
    /// Decoded symbol or tree-node index, depending on `len`.
    sym: u32,
}
/// Flag bit marking a `long` tree child as a leaf; the remaining low 31 bits hold the symbol.
const LEAF: u32 = 1 << 31;
impl HuffTable {
pub fn build(lengths: &[u8], table_bits: u32) -> Result<Self, Error> {
let num_symbols = lengths.len();
let mut count = [0u32; (MAX_BITS + 1) as usize];
let mut used = 0u32;
let mut last_sym = 0u16;
for (s, &l) in lengths.iter().enumerate() {
if l as u32 > MAX_BITS {
return Err(Error::InvalidHuffmanTree);
}
if l > 0 {
count[l as usize] += 1;
used += 1;
last_sym = s as u16;
}
}
if used == 0 {
return Ok(Self {
table: Vec::new(),
table_bits,
long: Vec::new(),
num_symbols,
single: None,
});
}
if used == 1 {
return Ok(Self {
table: Vec::new(),
table_bits,
long: Vec::new(),
num_symbols,
single: Some(last_sym),
});
}
let mut total: u64 = 0;
for l in 1..=MAX_BITS {
total += (count[l as usize] as u64) << (MAX_BITS - l);
}
if total != (1u64 << MAX_BITS) {
return Err(Error::InvalidHuffmanTree);
}
let mut next_code = [0u32; (MAX_BITS + 1) as usize];
let mut code = 0u32;
for l in 1..=MAX_BITS as usize {
code = (code + count[l - 1]) << 1;
next_code[l] = code;
}
let table_size = 1usize << table_bits;
let mut table = vec![TableEntry { sym: 0, len: 0 }; table_size];
let mut long: Vec<[u32; 2]> = Vec::new();
for (sym, &l) in lengths.iter().enumerate() {
let l = l as u32;
if l == 0 {
continue;
}
let c = next_code[l as usize];
next_code[l as usize] += 1;
if l <= table_bits {
let shift = table_bits - l;
let start = (c << shift) as usize;
let span = 1usize << shift;
if start + span > table_size {
return Err(Error::InvalidHuffmanTree);
}
for slot in &mut table[start..start + span] {
slot.sym = sym as u32;
slot.len = l as u8;
}
} else {
let top = (c >> (l - table_bits)) as usize;
if top >= table_size {
return Err(Error::InvalidHuffmanTree);
}
if table[top].len as u32 <= table_bits {
long.push([0, 0]);
table[top].sym = (long.len() - 1) as u32;
table[top].len = (table_bits + 1) as u8; }
let mut node = table[top].sym as usize;
let extra = l - table_bits;
for i in (0..extra).rev() {
let bit = ((c >> i) & 1) as usize;
if i == 0 {
if node >= long.len() {
return Err(Error::InvalidHuffmanTree);
}
if long[node][bit] != 0 {
return Err(Error::InvalidHuffmanTree);
}
long[node][bit] = LEAF | sym as u32;
} else {
if node >= long.len() {
return Err(Error::InvalidHuffmanTree);
}
let child = long[node][bit];
if child == 0 {
long.push([0, 0]);
let idx = (long.len() - 1) as u32;
long[node][bit] = idx;
node = idx as usize;
} else if child & LEAF != 0 {
return Err(Error::InvalidHuffmanTree);
} else {
node = child as usize;
}
}
}
}
}
Ok(Self {
table,
table_bits,
long,
num_symbols,
single: None,
})
}
pub fn build_single(num_symbols: usize, sym: u16, table_bits: u32) -> Result<Self, Error> {
if sym as usize >= num_symbols {
return Err(Error::InvalidHuffmanTree);
}
Ok(Self {
table: Vec::new(),
table_bits,
long: Vec::new(),
num_symbols,
single: Some(sym),
})
}
pub fn decode(&self, br: &mut super::bits::BitReader<'_>) -> Result<u16, Error> {
if let Some(s) = self.single {
return Ok(s);
}
if self.table.is_empty() {
return Err(Error::Corrupt);
}
let idx = br.peek_bits(self.table_bits) as usize;
let entry = self.table[idx];
if entry.len == 0 {
return Err(Error::Corrupt);
}
if (entry.len as u32) <= self.table_bits {
br.consume(entry.len as u32);
let sym = entry.sym as usize;
if sym >= self.num_symbols {
return Err(Error::Corrupt);
}
return Ok(entry.sym as u16);
}
br.consume(self.table_bits);
let mut node = entry.sym as usize;
loop {
let bit = br.get_bits(1) as usize;
if node >= self.long.len() {
return Err(Error::Corrupt);
}
let next = self.long[node][bit];
if next == 0 {
return Err(Error::Corrupt);
}
if next & LEAF != 0 {
let sym = (next & !LEAF) as usize;
if sym >= self.num_symbols {
return Err(Error::Corrupt);
}
return Ok(sym as u16);
}
node = next as usize;
}
}
}
/// Computes length-limited Huffman code lengths with the package-merge algorithm.
///
/// `freqs[s]` is the weight of symbol `s`. Symbols with weight `0` get length `0`
/// (no code). When two or more symbols are used, the result is a complete prefix code
/// (Kraft sum exactly 1). Its longest code is at most `max_bits`, and it minimises
/// `sum(freqs[s] * lengths[s])` under that limit. Equal weights are ordered by symbol
/// index, so the result is deterministic.
///
/// Special cases:
/// * no symbol used: every length is `0`;
/// * exactly one symbol used: that symbol gets length `1`.
///
/// # Panics
///
/// Panics if two or more symbols are used and they cannot fit in codes of at most
/// `max_bits` bits. That is the case when more than `2^max_bits` symbols are used,
/// which includes `max_bits == 0`.
pub fn assign_lengths(freqs: &[u32], max_bits: u32) -> Vec<u8> {
    let mut lengths = vec![0u8; freqs.len()];
    // (weight, symbol) for every used symbol.
    let mut leaves: Vec<(u64, usize)> = freqs
        .iter()
        .enumerate()
        .filter(|&(_, &f)| f > 0)
        .map(|(s, &f)| (u64::from(f), s))
        .collect();
    match leaves.as_slice() {
        [] => return lengths,
        [(_, only)] => {
            lengths[*only] = 1;
            return lengths;
        }
        _ => {}
    }
    // Lightest first; the symbol index breaks ties (all keys are distinct).
    leaves.sort_unstable();
    let m = leaves.len();
    assert!(
        max_bits >= usize::BITS || m <= 1usize << max_bits,
        "assign_lengths: {m} symbols do not fit in codes of at most {max_bits} bits"
    );

    // Row items are (weight, Some(symbol)) for a leaf or (weight, None) for a package.
    // The k-th package of a row always combines items 2k and 2k+1 of the previous row,
    // so a row needs no explicit links to its children.
    let mut rows: Vec<Vec<(u64, Option<usize>)>> = Vec::new();
    rows.push(leaves.iter().map(|&(w, s)| (w, Some(s))).collect());
    for _ in 1..max_bits {
        let next = {
            let prev = &rows[rows.len() - 1];
            let mut packages = prev
                .chunks_exact(2)
                .map(|pair| pair[0].0 + pair[1].0)
                .peekable();
            let mut singles = leaves.iter().peekable();
            let mut merged = Vec::with_capacity(m + prev.len() / 2);
            loop {
                // Merge by weight; on equal weight the leaf is taken first.
                let item = match (singles.peek().map(|&&p| p), packages.peek().copied()) {
                    (Some((lw, s)), Some(pw)) if lw <= pw => {
                        singles.next();
                        (lw, Some(s))
                    }
                    (Some((lw, s)), None) => {
                        singles.next();
                        (lw, Some(s))
                    }
                    (_, Some(pw)) => {
                        packages.next();
                        (pw, None)
                    }
                    (None, None) => break,
                };
                merged.push(item);
            }
            merged
        };
        rows.push(next);
    }

    // Select the 2m-2 lightest items of the last row. Each selected leaf adds one to
    // its symbol's length. The p packages among them are the row's first p packages,
    // i.e. the first 2p items of the row before, which are selected in turn.
    let mut take = 2 * m - 2;
    for row in rows.iter().rev() {
        let mut packages = 0usize;
        for &(_, kind) in row.iter().take(take) {
            match kind {
                Some(s) => lengths[s] = lengths[s].saturating_add(1),
                None => packages += 1,
            }
        }
        take = 2 * packages;
    }
    lengths
}
/// Assigns canonical Huffman codes to a table of code lengths.
///
/// Shorter codes come numerically before longer ones, and codes of equal length are
/// consecutive in increasing symbol order. Each code is returned right-aligned in a
/// `u32` and occupies `lengths[s]` bits. Symbols whose length is `0` get code `0` and
/// have no code.
///
/// # Panics
///
/// Panics if any length exceeds [`MAX_BITS`].
pub fn lengths_to_codes(lengths: &[u8]) -> Vec<u32> {
    // per_len[l] = how many symbols use an l-bit code.
    let mut per_len = [0u32; (MAX_BITS + 1) as usize];
    lengths
        .iter()
        .filter(|&&l| l != 0)
        .for_each(|&l| per_len[usize::from(l)] += 1);

    // first[l] = the code handed to the next symbol of length l.
    let mut first = [0u32; (MAX_BITS + 1) as usize];
    let mut acc = 0u32;
    for (dst, &shorter) in first.iter_mut().skip(1).zip(per_len.iter()) {
        acc = (acc + shorter) << 1;
        *dst = acc;
    }

    lengths
        .iter()
        .map(|&l| {
            if l == 0 {
                0
            } else {
                let slot = &mut first[usize::from(l)];
                let assigned = *slot;
                *slot += 1;
                assigned
            }
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    extern crate alloc;
    use alloc::vec::Vec;

    /// Writes `sym`'s code MSB-first at the start of a zero-padded buffer of at least
    /// 24 bits, then asserts that `table` decodes it back to `sym`.
    fn check(lens: &[u8], codes: &[u32], sym: usize, table: &HuffTable) {
        let l = lens[sym] as u32;
        let code = codes[sym];
        let total = (l.div_ceil(8) * 8).max(24);
        let val = (code as u64) << (total - l);
        let mut bits = Vec::new();
        for b in (0..total / 8).rev() {
            bits.push(((val >> (b * 8)) & 0xFF) as u8);
        }
        let mut br = crate::lha::bits::BitReader::new(&bits);
        let got = table.decode(&mut br).expect("decode ok");
        assert_eq!(got as usize, sym, "sym {sym} len {l}");
    }

    #[test]
    fn short_codes_roundtrip() {
        let mut freqs = vec![0u32; 510];
        for (i, f) in freqs.iter_mut().take(300).enumerate() {
            *f = (i as u32 % 7) + 1;
        }
        let lens = assign_lengths(&freqs, MAX_BITS);
        let codes = lengths_to_codes(&lens);
        let table = HuffTable::build(&lens, 12).expect("build ok");
        for sym in 0..510 {
            if lens[sym] != 0 {
                check(&lens, &codes, sym, &table);
            }
        }
    }

    #[test]
    fn long_codes_roundtrip() {
        let mut freqs = vec![0u32; 510];
        let (mut a, mut b) = (1u32, 1u32);
        for f in freqs.iter_mut().take(40) {
            *f = a;
            let c = a.wrapping_add(b);
            a = b;
            b = c.min(1_000_000);
        }
        for f in freqs.iter_mut().take(300).skip(40) {
            *f = 1;
        }
        let lens = assign_lengths(&freqs, MAX_BITS);
        assert!(
            lens.iter().copied().max().unwrap() > 12,
            "test should force long codes"
        );
        let codes = lengths_to_codes(&lens);
        let table = HuffTable::build(&lens, 12).expect("build ok");
        for sym in 0..510 {
            if lens[sym] != 0 {
                check(&lens, &codes, sym, &table);
            }
        }
    }

    #[test]
    fn rejects_incomplete_code() {
        let mut lens = vec![0u8; 8];
        lens[0] = 2;
        lens[1] = 2;
        assert!(matches!(
            HuffTable::build(&lens, 4),
            Err(Error::InvalidHuffmanTree)
        ));
    }

    #[test]
    fn assign_lengths_degenerate_inputs() {
        assert!(assign_lengths(&[], MAX_BITS).is_empty());
        assert_eq!(assign_lengths(&[0, 0, 0], MAX_BITS), vec![0, 0, 0]);
        assert_eq!(assign_lengths(&[0, 9, 0], MAX_BITS), vec![0, 1, 0]);
        assert_eq!(assign_lengths(&[3, 3], MAX_BITS), vec![1, 1]);
    }

    #[test]
    fn assign_lengths_respects_limit_and_is_complete() {
        let freqs: Vec<u32> = (0..64u32).map(|i| 1 + i * i * i).collect();
        for max_bits in [6u32, 9, 12, MAX_BITS] {
            let lens = assign_lengths(&freqs, max_bits);
            assert!(lens.iter().all(|&l| l >= 1 && u32::from(l) <= max_bits));
            let kraft: u64 = lens
                .iter()
                .map(|&l| 1u64 << (MAX_BITS - u32::from(l)))
                .sum();
            assert_eq!(kraft, 1u64 << MAX_BITS, "max_bits {max_bits}");
            // A narrow primary table forces most codes through the long-code tree.
            let codes = lengths_to_codes(&lens);
            let table = HuffTable::build(&lens, 4).expect("build ok");
            for sym in 0..freqs.len() {
                check(&lens, &codes, sym, &table);
            }
        }
    }

    #[test]
    fn lengths_to_codes_matches_rfc1951_example() {
        let lens = [3u8, 3, 3, 3, 3, 2, 4, 4];
        assert_eq!(lengths_to_codes(&lens), vec![2, 3, 4, 5, 6, 0, 14, 15]);
    }

    #[test]
    fn rejects_oversubscribed_and_overlong_codes() {
        assert!(matches!(
            HuffTable::build(&[1, 1, 1], 4),
            Err(Error::InvalidHuffmanTree)
        ));
        let mut lens = vec![0u8; 4];
        lens[0] = (MAX_BITS + 1) as u8;
        lens[1] = 1;
        assert!(matches!(
            HuffTable::build(&lens, 4),
            Err(Error::InvalidHuffmanTree)
        ));
    }

    #[test]
    fn empty_and_single_symbol_tables() {
        let data = [0u8; 4];

        let empty = HuffTable::build(&[0, 0, 0], 8).expect("build ok");
        let mut br = crate::lha::bits::BitReader::new(&data[..]);
        assert!(matches!(empty.decode(&mut br), Err(Error::Corrupt)));

        let single = HuffTable::build(&[0, 0, 5, 0], 8).expect("build ok");
        let mut br = crate::lha::bits::BitReader::new(&data[..]);
        assert_eq!(single.decode(&mut br).expect("decode ok"), 2);
        assert_eq!(single.decode(&mut br).expect("decode ok"), 2);

        let forced = HuffTable::build_single(4, 3, 8).expect("build ok");
        assert_eq!(forced.decode(&mut br).expect("decode ok"), 3);
        assert!(matches!(
            HuffTable::build_single(4, 4, 8),
            Err(Error::InvalidHuffmanTree)
        ));
    }
}