#![allow(dead_code)]
extern crate alloc;
use alloc::vec;
use alloc::vec::Vec;
use crate::error::Error;
pub(crate) const MAX_CODE_LEN: usize = 20;
pub(crate) struct DecodeTable {
n: usize,
base: [i32; MAX_CODE_LEN + 2],
limit: [i32; MAX_CODE_LEN + 2],
perm: Vec<u16>,
min_len: u8,
max_len: u8,
}
impl DecodeTable {
pub(crate) fn from_lengths(lengths: &[u8]) -> Result<Self, Error> {
let n = lengths.len();
if n == 0 {
return Err(Error::InvalidHuffmanTree);
}
let mut min_len: u8 = 0xFF;
let mut max_len: u8 = 0;
for &l in lengths {
if l == 0 || l as usize > MAX_CODE_LEN {
return Err(Error::InvalidHuffmanTree);
}
if l < min_len {
min_len = l;
}
if l > max_len {
max_len = l;
}
}
if max_len == 0 {
return Err(Error::InvalidHuffmanTree);
}
let mut count = [0i32; MAX_CODE_LEN + 2];
for &l in lengths {
count[l as usize] += 1;
}
let mut left: i64 = 0;
for (i, &c) in count[1..=(max_len as usize)].iter().enumerate() {
let len = i + 1;
left = (left << 1) + c as i64;
if left > (1i64 << len) {
return Err(Error::InvalidHuffmanTree);
}
}
if left != (1i64 << max_len) {
return Err(Error::InvalidHuffmanTree);
}
let mut base = [0i32; MAX_CODE_LEN + 2];
let mut limit = [0i32; MAX_CODE_LEN + 2];
let mut vec_pos: i32 = 0;
for len in 1..=MAX_CODE_LEN {
base[len] = vec_pos;
vec_pos += count[len];
}
let mut code: i32 = 0;
for len in 1..=MAX_CODE_LEN {
let cnt = count[len];
if cnt == 0 {
limit[len] = -1;
} else {
limit[len] = code + cnt - 1;
base[len] = code - base[len];
}
code = (code + cnt) << 1;
}
limit[MAX_CODE_LEN + 1] = -1;
let mut cursor = [0usize; MAX_CODE_LEN + 2];
let mut acc = 0usize;
for len in 1..=MAX_CODE_LEN {
cursor[len] = acc;
acc += count[len] as usize;
}
let mut perm = vec![0u16; n];
for (sym, &l) in lengths.iter().enumerate() {
let len = l as usize;
perm[cursor[len]] = sym as u16;
cursor[len] += 1;
}
Ok(Self {
n,
base,
limit,
perm,
min_len,
max_len,
})
}
pub(crate) fn decode_symbol(&self, br: &mut super::bits::BitReader<'_>) -> Result<u16, Error> {
let mut len = self.min_len as usize;
let mut code = br.read_bits(len as u32)? as i32;
while len <= MAX_CODE_LEN {
if code <= self.limit[len] {
let pos = (code - self.base[len]) as usize;
if pos >= self.n {
return Err(Error::InvalidHuffmanTree);
}
return Ok(self.perm[pos]);
}
len += 1;
code = (code << 1) | (br.read_bit()? as i32);
}
Err(Error::InvalidHuffmanTree)
}
pub(crate) fn max_len(&self) -> u8 {
self.max_len
}
}
/// Computes Huffman code lengths for `freqs`, with no code longer than `max_len` bits.
///
/// Zero frequencies are treated as 1, so every symbol receives a code. If the
/// optimal tree is deeper than `max_len`, every weight is halved (rounding up) and
/// the tree is rebuilt, repeating until it fits. Repeated halving drives all weights
/// to 1, where the tree depth is `ceil(log2(n))`. The loop therefore terminates
/// whenever `n <= 2^max_len`.
///
/// An empty `freqs` yields an empty vector. A single symbol yields `[1]` regardless
/// of `max_len`. Note that `[1]` is an incomplete code, so `DecodeTable::from_lengths`
/// rejects it.
///
/// # Panics
///
/// Panics if there are at least two symbols and `freqs.len() > 2^max_len`: no prefix
/// code can give that many symbols codes of at most `max_len` bits.
pub(crate) fn build_canonical_lengths(freqs: &[u32], max_len: usize) -> Vec<u8> {
    let n = freqs.len();
    if n <= 1 {
        return vec![1u8; n];
    }
    // Without this check the halving loop below would never terminate.
    assert!(
        max_len >= usize::BITS as usize || n <= 1usize << max_len,
        "cannot fit {n} Huffman codes into at most {max_len} bits"
    );
    let mut weights: Vec<u32> = freqs.iter().map(|&f| f.max(1)).collect();
    loop {
        let lengths = compute_lengths(&weights);
        if lengths.iter().all(|&l| usize::from(l) <= max_len) {
            return lengths;
        }
        // Flatten the distribution; weights never drop below 1.
        for w in weights.iter_mut() {
            *w = (*w).div_ceil(2);
        }
    }
}

/// Builds an unrestricted Huffman tree over `weights` and returns each symbol's depth.
///
/// The two lightest live nodes are merged repeatedly. Among equal weights, the node
/// with the largest id (the most recently created) is taken first, which keeps the
/// result deterministic. Leaves are ids `0..n`; internal nodes get ids `n..2n-1` in
/// creation order. Every symbol gets at least 1 bit.
fn compute_lengths(weights: &[u32]) -> Vec<u8> {
    use core::cmp::Reverse;

    let n = weights.len();
    match n {
        0 => return Vec::new(),
        1 => return vec![1],
        _ => {}
    }

    // Max-heap on (Reverse(weight), id): pops the lightest node, and among equal
    // weights the one with the largest id. Keys are unique because ids are unique.
    let mut heap: alloc::collections::BinaryHeap<(Reverse<u64>, usize)> = weights
        .iter()
        .enumerate()
        .map(|(id, &w)| (Reverse(u64::from(w)), id))
        .collect();
    // parent[id] for each of the 2n - 1 tree nodes.
    let mut parent = vec![usize::MAX; 2 * n - 1];
    let mut next_node = n;
    while heap.len() > 1 {
        let (Reverse(w1), a) = heap.pop().expect("heap holds at least two nodes");
        let (Reverse(w2), b) = heap.pop().expect("heap holds at least two nodes");
        parent[a] = next_node;
        parent[b] = next_node;
        heap.push((Reverse(w1 + w2), next_node));
        next_node += 1;
    }

    // A parent is always created after its children, so its id is larger. Walking
    // ids downward from the root (id next_node - 1, depth 0) therefore visits each
    // parent before its children.
    let root = next_node - 1;
    let mut depth = vec![0u32; next_node];
    for id in (0..root).rev() {
        depth[id] = depth[parent[id]] + 1;
    }
    depth[..n]
        .iter()
        .map(|&d| u8::try_from(d.max(1)).expect("Huffman depth over u32 weights stays below 256"))
        .collect()
}
/// Assigns canonical Huffman codes to symbols from their code lengths.
///
/// Codes of a given length are consecutive integers, handed out in symbol order.
/// The first code of length `bits` is
/// `(first code of length bits - 1 + number of codes of length bits - 1) << 1`,
/// which is the DEFLATE (RFC 1951 §3.2.2) construction. Symbols with length 0
/// receive code 0.
pub(crate) fn build_canonical_codes(lengths: &[u8]) -> Vec<u32> {
    let longest = usize::from(lengths.iter().copied().max().unwrap_or(0));

    // How many symbols use each code length; slot 0 (unused symbols) stays zero.
    let mut per_len = vec![0u32; longest + 2];
    lengths
        .iter()
        .filter(|&&l| l != 0)
        .for_each(|&l| per_len[usize::from(l)] += 1);

    // next[bits] starts as the first code of that length and advances per symbol.
    let mut next = vec![0u32; longest + 2];
    let mut running = 0u32;
    for bits in 1..=longest {
        running = (running + per_len[bits - 1]) << 1;
        next[bits] = running;
    }

    lengths
        .iter()
        .map(|&l| {
            if l == 0 {
                return 0;
            }
            let slot = &mut next[usize::from(l)];
            let assigned = *slot;
            *slot += 1;
            assigned
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Encodes `stream` with the canonical codes for `lengths`, then decodes it back
    /// through a [`DecodeTable`] built from the same lengths.
    fn encode_then_decode(lengths: &[u8], stream: &[u16]) -> Vec<u16> {
        let codes = build_canonical_codes(lengths);
        let mut bw = super::super::bits::BitWriter::new();
        for &s in stream {
            bw.write_bits(lengths[s as usize] as u32, codes[s as usize]);
        }
        bw.align_to_byte();
        let buf = bw.into_bytes();
        let tbl = DecodeTable::from_lengths(lengths).unwrap();
        let mut br = super::super::bits::BitReader::new(&buf);
        stream
            .iter()
            .map(|_| tbl.decode_symbol(&mut br).unwrap())
            .collect()
    }

    #[test]
    fn round_trip_decode_table_simple() {
        let lengths = [2u8, 2, 2, 2];
        let tbl = DecodeTable::from_lengths(&lengths).unwrap();
        let codes = build_canonical_codes(&lengths);
        assert_eq!(codes, vec![0b00, 0b01, 0b10, 0b11]);
        assert_eq!(tbl.max_len(), 2);
        let stream = [3u16, 0, 2, 1, 1];
        assert_eq!(encode_then_decode(&lengths, &stream), stream);
    }

    #[test]
    fn round_trip_encode_decode() {
        let lengths = [3u8, 3, 3, 3, 2, 2];
        let stream = [0u16, 5, 3, 1, 4];
        assert_eq!(encode_then_decode(&lengths, &stream), stream);
    }

    #[test]
    fn built_lengths_round_trip_through_decoder() {
        let freqs = [50u32, 30, 20, 10, 5, 3, 2, 1];
        let lengths = build_canonical_lengths(&freqs, MAX_CODE_LEN);
        let stream: Vec<u16> = (0..freqs.len() as u16).rev().collect();
        assert_eq!(encode_then_decode(&lengths, &stream), stream);
    }

    #[test]
    fn build_lengths_does_not_explode() {
        let freqs = [50u32, 30, 20, 10, 5, 3, 2, 1];
        let lens = build_canonical_lengths(&freqs, MAX_CODE_LEN);
        assert!(lens.iter().all(|&l| (1..=MAX_CODE_LEN as u8).contains(&l)));
        assert_eq!(lens.len(), freqs.len());
    }

    #[test]
    fn from_lengths_rejects_invalid_trees() {
        // Cases, in order: empty; zero length; oversubscribed; incomplete;
        // longer than MAX_CODE_LEN.
        let bad: [&[u8]; 5] = [&[], &[0, 1], &[1, 1, 1], &[1, 2], &[21, 21]];
        for lengths in &bad {
            assert!(DecodeTable::from_lengths(lengths).is_err());
        }
    }
}