#![allow(dead_code)]
extern crate alloc;
use alloc::vec;
use alloc::vec::Vec;
use crate::error::Error;
/// Longest code length, in bits, that `DecodeTable::from_lengths` accepts.
///
/// The per-length tables of `DecodeTable` are sized from this constant
/// (`MAX_CODE_LEN + 2` entries each, indexed directly by code length).
pub(crate) const MAX_CODE_LEN: usize = 20;
/// Lookup tables for decoding a canonical Huffman code one bit at a time.
///
/// Built by `DecodeTable::from_lengths`; the per-length arrays are indexed
/// directly by code length (index 0 is unused).
pub(crate) struct DecodeTable {
/// Number of symbols in the alphabet (length of the `lengths` slice the table was built from).
n: usize,
/// For lengths that have codes: first code of that length minus the index in
/// `perm` of its first symbol, so `code - base[len]` is an index into `perm`.
base: [i32; MAX_CODE_LEN + 2],
/// Largest code value of each length, or `-1` when no code has that length.
limit: [i32; MAX_CODE_LEN + 2],
/// Symbols ordered by code length, then by symbol index within one length.
perm: Vec<u16>,
/// Shortest code length present in the table.
min_len: u8,
/// Longest code length present in the table.
max_len: u8,
}
impl DecodeTable {
pub(crate) fn from_lengths(lengths: &[u8]) -> Result<Self, Error> {
let n = lengths.len();
if n == 0 {
return Err(Error::InvalidHuffmanTree);
}
let mut min_len: u8 = 0xFF;
let mut max_len: u8 = 0;
for &l in lengths {
if l == 0 || l as usize > MAX_CODE_LEN {
return Err(Error::InvalidHuffmanTree);
}
if l < min_len {
min_len = l;
}
if l > max_len {
max_len = l;
}
}
if max_len == 0 {
return Err(Error::InvalidHuffmanTree);
}
let mut count = [0i32; MAX_CODE_LEN + 2];
for &l in lengths {
count[l as usize] += 1;
}
let mut left: i64 = 0;
for (i, &c) in count[1..=(max_len as usize)].iter().enumerate() {
let len = i + 1;
left = (left << 1) + c as i64;
if left > (1i64 << len) {
return Err(Error::InvalidHuffmanTree);
}
}
if left != (1i64 << max_len) {
return Err(Error::InvalidHuffmanTree);
}
let mut base = [0i32; MAX_CODE_LEN + 2];
let mut limit = [0i32; MAX_CODE_LEN + 2];
let mut vec_pos: i32 = 0;
for len in 1..=MAX_CODE_LEN {
base[len] = vec_pos;
vec_pos += count[len];
}
let mut code: i32 = 0;
for len in 1..=MAX_CODE_LEN {
let cnt = count[len];
if cnt == 0 {
limit[len] = -1;
} else {
limit[len] = code + cnt - 1;
base[len] = code - base[len];
}
code = (code + cnt) << 1;
}
limit[MAX_CODE_LEN + 1] = -1;
let mut cursor = [0usize; MAX_CODE_LEN + 2];
let mut acc = 0usize;
for len in 1..=MAX_CODE_LEN {
cursor[len] = acc;
acc += count[len] as usize;
}
let mut perm = vec![0u16; n];
for (sym, &l) in lengths.iter().enumerate() {
let len = l as usize;
perm[cursor[len]] = sym as u16;
cursor[len] += 1;
}
Ok(Self {
n,
base,
limit,
perm,
min_len,
max_len,
})
}
pub(crate) fn decode_symbol(&self, br: &mut super::bits::BitReader<'_>) -> Result<u16, Error> {
let mut len = self.min_len as usize;
let mut code = br.read_bits(len as u32)? as i32;
while len <= MAX_CODE_LEN {
if code <= self.limit[len] {
let pos = (code - self.base[len]) as usize;
if pos >= self.n {
return Err(Error::InvalidHuffmanTree);
}
return Ok(self.perm[pos]);
}
len += 1;
code = (code << 1) | (br.read_bit()? as i32);
}
Err(Error::InvalidHuffmanTree)
}
pub(crate) fn max_len(&self) -> u8 {
self.max_len
}
}
/// Computes Huffman code lengths for `freqs`, one length per input frequency.
///
/// The caller's `max_len` is honoured only up to 17 bits: the limit handed to
/// the length builder is `min(max_len, 17)`.
pub(crate) fn build_canonical_lengths(freqs: &[u32], max_len: usize) -> Vec<u8> {
    // Upper bound on the designed code length, regardless of what the caller allows.
    const DESIGN_CAP: usize = 17;
    let design_max = if max_len > DESIGN_CAP {
        DESIGN_CAP
    } else {
        max_len
    };
    hb_make_code_lengths(freqs, design_max)
}
/// Computes Huffman code lengths for `freqs`, limited to `max_len` bits.
///
/// A zero frequency is treated as one, so every symbol receives a code. An
/// empty alphabet yields an empty vector and a single-symbol alphabet yields
/// `[1]`.
///
/// The tree is built with a binary min-heap over packed node weights
/// (`frequency << 8 | subtree depth`, so equal frequencies are ordered by
/// depth). If any leaf ends up deeper than `max_len`, every frequency is
/// replaced by `1 + f / 2` and the tree is rebuilt.
///
/// Scaling stops changing anything once every frequency is 1 or 2. If the
/// tree is still too deep at that point, retrying cannot help, so a balanced
/// complete code is returned instead (see [`balanced_code_lengths`]). When
/// `2^max_len` is smaller than the alphabet size no code can satisfy the
/// limit; the returned lengths then exceed `max_len` by as little as possible.
fn hb_make_code_lengths(freqs: &[u32], max_len: usize) -> Vec<u8> {
    /// The low byte of a packed weight holds the subtree depth.
    const DEPTH_MASK: i64 = 0x0000_00ff;

    /// Packed weight of a parent: summed frequencies, depth one more than the
    /// deeper child.
    fn add_weights(a: i64, b: i64) -> i64 {
        let total = (a & !DEPTH_MASK) + (b & !DEPTH_MASK);
        let depth = 1 + core::cmp::max(a & DEPTH_MASK, b & DEPTH_MASK);
        total | depth
    }

    /// Removes and returns the lightest node of the heap.
    fn pop_min(heap: &mut [i32], weight: &[i64], n_heap: &mut i32) -> i32 {
        let top = heap[1];
        heap[1] = heap[*n_heap as usize];
        *n_heap -= 1;
        downheap(heap, weight, *n_heap, 1);
        top
    }

    let alpha_size = freqs.len();
    match alpha_size {
        0 => return Vec::new(),
        1 => return vec![1],
        _ => {}
    }

    // Nodes are numbered from 1: leaves are 1..=alpha_size, internal nodes follow.
    let cap_nodes = alpha_size * 2 + 2;
    let mut weight = vec![0i64; cap_nodes];
    let mut parent = vec![0i32; cap_nodes];
    let mut heap = vec![0i32; alpha_size + 2];
    let mut cur_freq: Vec<i64> = freqs.iter().map(|&f| i64::from(f.max(1))).collect();

    loop {
        for (w, &f) in weight[1..].iter_mut().zip(&cur_freq) {
            *w = f << 8;
        }

        // Slot 0 is a sentinel lighter than any real node, so sifting up
        // always stops at the root without an explicit bounds test.
        heap[0] = 0;
        weight[0] = 0;
        parent[0] = -2;

        let mut n_nodes = alpha_size as i32;
        let mut n_heap = 0i32;
        for leaf in 1..=alpha_size as i32 {
            parent[leaf as usize] = -1;
            n_heap += 1;
            heap[n_heap as usize] = leaf;
            upheap(&mut heap, &weight, n_heap);
        }

        // Repeatedly merge the two lightest nodes until only the root remains.
        while n_heap > 1 {
            let n1 = pop_min(&mut heap, &weight, &mut n_heap);
            let n2 = pop_min(&mut heap, &weight, &mut n_heap);
            n_nodes += 1;
            let merged = n_nodes as usize;
            parent[n1 as usize] = n_nodes;
            parent[n2 as usize] = n_nodes;
            weight[merged] = add_weights(weight[n1 as usize], weight[n2 as usize]);
            parent[merged] = -1;
            n_heap += 1;
            heap[n_heap as usize] = n_nodes;
            upheap(&mut heap, &weight, n_heap);
        }

        // A leaf's code length is its distance to the root.
        let mut lengths = vec![0u8; alpha_size];
        let mut too_long = false;
        for (i, slot) in lengths.iter_mut().enumerate() {
            let mut depth = 0usize;
            let mut node = (i + 1) as i32;
            while parent[node as usize] >= 0 {
                node = parent[node as usize];
                depth += 1;
            }
            // Truncation only matters for over-long trees, which are discarded below.
            *slot = depth as u8;
            too_long |= depth > max_len;
        }
        if !too_long {
            return lengths;
        }

        // `1 + f / 2` maps 1 -> 1 and 2 -> 2: once every frequency is that
        // small another round would rebuild exactly the same tree forever.
        if cur_freq.iter().all(|&f| f <= 2) {
            return balanced_code_lengths(freqs);
        }
        for f in cur_freq.iter_mut() {
            *f = 1 + *f / 2;
        }
    }
}

/// Returns the complete prefix code with the smallest possible maximum length.
///
/// With `k = ceil(log2(n))`, `2^k - n` symbols get `k - 1` bits and the rest
/// get `k` bits; the shorter codes go to the most frequent symbols (ties
/// broken by lower symbol index). Requires `freqs.len() >= 2`.
fn balanced_code_lengths(freqs: &[u32]) -> Vec<u8> {
    let n = freqs.len();
    // ceil(log2(n)) for n >= 2; at most `usize::BITS`, so it fits in a `u8`.
    let long = (usize::BITS - (n - 1).leading_zeros()) as u8;
    let n_short = (1usize << long) - n;

    let mut order: Vec<usize> = (0..n).collect();
    order.sort_unstable_by_key(|&i| (core::cmp::Reverse(freqs[i]), i));

    let mut lengths = vec![long; n];
    for &i in &order[..n_short] {
        lengths[i] = long - 1;
    }
    lengths
}

/// Sifts the node stored at heap slot `n_heap` up towards the root.
///
/// Relies on `heap[0]` naming a sentinel node whose weight is smaller than
/// every real weight, which stops the climb at slot 1.
fn upheap(heap: &mut [i32], weight: &[i64], n_heap: i32) {
    let mut zz = n_heap;
    let tmp = heap[zz as usize];
    while weight[tmp as usize] < weight[heap[(zz >> 1) as usize] as usize] {
        heap[zz as usize] = heap[(zz >> 1) as usize];
        zz >>= 1;
    }
    heap[zz as usize] = tmp;
}

/// Sifts the node stored at heap slot `z` down until neither child is lighter.
///
/// `heap[1..=n_heap]` is the live part of the 1-based binary min-heap and
/// `weight` is indexed by the node numbers stored in `heap`.
fn downheap(heap: &mut [i32], weight: &[i64], n_heap: i32, z: i32) {
    let mut zz = z;
    let tmp = heap[zz as usize];
    loop {
        let mut yy = zz << 1;
        if yy > n_heap {
            break;
        }
        // Follow the lighter of the two children.
        if yy < n_heap
            && weight[heap[(yy + 1) as usize] as usize] < weight[heap[yy as usize] as usize]
        {
            yy += 1;
        }
        if weight[tmp as usize] < weight[heap[yy as usize] as usize] {
            break;
        }
        heap[zz as usize] = heap[yy as usize];
        zz = yy;
    }
    heap[zz as usize] = tmp;
}
/// Assigns canonical Huffman code values for the given code lengths.
///
/// Codes are numbered in order of increasing length and, within one length,
/// in order of increasing symbol index. A symbol whose length is zero has no
/// code; its entry in the result is `0`.
pub(crate) fn build_canonical_codes(lengths: &[u8]) -> Vec<u32> {
    let longest = lengths.iter().copied().max().map_or(0, usize::from);

    // Number of codes of each length; zero-length symbols take no code space.
    let mut per_len = vec![0u32; longest + 2];
    for &len in lengths.iter().filter(|&&len| len != 0) {
        per_len[usize::from(len)] += 1;
    }

    // First code value of every length.
    let mut next_code = vec![0u32; longest + 2];
    let mut running = 0u32;
    for bits in 1..=longest {
        running = (running + per_len[bits - 1]) << 1;
        next_code[bits] = running;
    }

    // Hand out consecutive values within each length, in symbol order.
    lengths
        .iter()
        .map(|&len| {
            if len == 0 {
                return 0;
            }
            let slot = &mut next_code[usize::from(len)];
            let assigned = *slot;
            *slot += 1;
            assigned
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Four 2-bit codes are numbered 0..=3 and the table reports a maximum length of 2.
    #[test]
    fn round_trip_decode_table_simple() {
        let lengths = [2u8; 4];
        let table = DecodeTable::from_lengths(&lengths).unwrap();
        let codes = build_canonical_codes(&lengths);
        assert_eq!(codes, vec![0b00, 0b01, 0b10, 0b11]);
        assert_eq!(table.max_len(), 2);
    }

    /// Symbols written with the canonical codes decode back in the same order.
    #[test]
    fn round_trip_encode_decode() {
        let lengths = [3u8, 3, 3, 3, 2, 2];
        let codes = build_canonical_codes(&lengths);
        let stream = [0u16, 5, 3, 1, 4];

        let mut writer = super::super::bits::BitWriter::new();
        for &sym in &stream {
            let idx = usize::from(sym);
            writer.write_bits(u32::from(lengths[idx]), codes[idx]);
        }
        writer.align_to_byte();
        let bytes = writer.into_bytes();

        let table = DecodeTable::from_lengths(&lengths).unwrap();
        let mut reader = super::super::bits::BitReader::new(&bytes);
        for &expected in &stream {
            let decoded = table.decode_symbol(&mut reader).unwrap();
            assert_eq!(decoded, expected);
        }
    }

    /// A skewed distribution yields one in-range length per symbol.
    #[test]
    fn build_lengths_does_not_explode() {
        let freqs = [50u32, 30, 20, 10, 5, 3, 2, 1];
        let lens = build_canonical_lengths(&freqs, MAX_CODE_LEN);
        assert!(lens.iter().all(|&l| (1..=20).contains(&l)));
        assert_eq!(lens.len(), freqs.len());
    }

    /// Built lengths never exceed 17 bits and always form a table the decoder accepts.
    #[test]
    fn build_lengths_caps_at_17_and_is_kraft_valid() {
        let cases: alloc::vec::Vec<alloc::vec::Vec<u32>> = alloc::vec![
            alloc::vec![1, 1],
            alloc::vec![0, 0, 0, 5],
            alloc::vec![1000000, 1, 1, 1, 1, 1, 1, 1],
            (0..50u32).map(|i| 1 << (i % 24)).collect(),
            alloc::vec![1u32; 258],
        ];
        for freqs in cases.iter() {
            let lens = build_canonical_lengths(freqs, MAX_CODE_LEN);
            assert_eq!(lens.len(), freqs.len());
            assert!(
                lens.iter().all(|&l| (1..=17).contains(&l)),
                "length out of 1..=17: {lens:?}"
            );
            DecodeTable::from_lengths(&lens).expect("builder produced a non-Kraft-valid table");
        }
    }
}