use oxiarc_core::BitReader;
use oxiarc_core::error::{OxiArcError, Result};
use std::io::Read;
/// Longest Huffman code length, in bits, accepted by
/// `HuffmanTree::from_code_lengths`; also sizes the per-length tables.
pub const MAX_CODE_LENGTH: usize = 15;
/// Number of symbols in the literal/length alphabet.
// NOTE(review): not referenced in this part of the file; 286 matches the
// literal/length symbol count of DEFLATE (RFC 1951) — confirm against callers.
pub const LITLEN_ALPHABET_SIZE: usize = 286;
/// Number of symbols in the distance alphabet.
// NOTE(review): not referenced in this part of the file; 30 matches the
// distance code count of DEFLATE (RFC 1951) — confirm against callers.
pub const DISTANCE_ALPHABET_SIZE: usize = 30;
/// Number of symbols in the code-length alphabet.
// NOTE(review): not referenced in this part of the file; 19 matches the
// code-length alphabet of DEFLATE dynamic blocks — confirm against callers.
pub const CODELEN_ALPHABET_SIZE: usize = 19;
/// Literal/length symbol value that marks the end of a block.
// NOTE(review): not referenced in this part of the file; presumably compared
// against decoded literal/length symbols by the inflater — confirm.
pub const END_OF_BLOCK: u16 = 256;
/// Canonical Huffman decoder built from a list of per-symbol code lengths.
///
/// Two structures are kept side by side:
///
/// * a direct lookup table covering codes of at most `fast_bits` bits, and
/// * per-length tables (`base_codes`, `symbol_offsets`, `symbols`) used by the
///   bit-at-a-time fallback for longer codes or when too few bits remain.
#[derive(Debug, Clone)]
pub struct HuffmanTree {
    /// `(symbol, code length)` entries indexed by the value peeked from the
    /// reader. A length of 0 means no code of at most `fast_bits` bits maps
    /// to that index.
    fast_table: Vec<(u16, u8)>,
    /// Number of bits peeked for a fast lookup; `fast_table.len() == 1 << fast_bits`.
    fast_bits: u8,
    /// Longest code length present; 0 for a tree with no codes at all.
    max_code_length: u8,
    /// Symbols grouped by code length, and in ascending symbol order within
    /// each length.
    symbols: Vec<u16>,
    /// First canonical code value of each length (index = length in bits).
    base_codes: [u32; MAX_CODE_LENGTH + 1],
    /// Start index into `symbols` for each length (index = length in bits).
    symbol_offsets: [u16; MAX_CODE_LENGTH + 1],
}

impl HuffmanTree {
    /// Upper bound on the number of bits resolved by the direct lookup table.
    const FAST_BITS: u8 = 9;

    /// Builds the decoding tables from `code_lengths`, where entry `i` is the
    /// code length in bits of symbol `i` and 0 means "symbol unused".
    ///
    /// If every length is 0 an empty tree is returned; decoding with it
    /// always fails.
    ///
    /// # Errors
    ///
    /// Returns an invalid-header error when `code_lengths` is empty, when any
    /// length exceeds [`MAX_CODE_LENGTH`], or when the lengths describe more
    /// codes than fit in a prefix code (over-subscribed).
    pub fn from_code_lengths(code_lengths: &[u8]) -> Result<Self> {
        if code_lengths.is_empty() {
            return Err(OxiArcError::invalid_header("Empty code lengths"));
        }

        // Histogram of code lengths, plus the longest length in use.
        let mut bl_count = [0u32; MAX_CODE_LENGTH + 1];
        let mut max_length = 0u8;
        for &len in code_lengths {
            if len > 0 {
                if len as usize > MAX_CODE_LENGTH {
                    return Err(OxiArcError::invalid_header(format!(
                        "Code length {} exceeds maximum {}",
                        len, MAX_CODE_LENGTH
                    )));
                }
                bl_count[len as usize] += 1;
                max_length = max_length.max(len);
            }
        }

        if max_length == 0 {
            return Ok(Self {
                fast_table: vec![(0, 0); 1 << Self::FAST_BITS],
                fast_bits: Self::FAST_BITS,
                max_code_length: 0,
                symbols: Vec::new(),
                base_codes: [0; MAX_CODE_LENGTH + 1],
                symbol_offsets: [0; MAX_CODE_LENGTH + 1],
            });
        }

        // First canonical code of every length.
        let mut next_code = [0u32; MAX_CODE_LENGTH + 1];
        let mut code = 0u32;
        for bits in 1..=max_length as usize {
            code = (code + bl_count[bits - 1]) << 1;
            next_code[bits] = code;
        }

        // `code` is now the first code of the longest length. Shorter codes
        // have already been scaled into it, so this single comparison detects
        // over-subscription at any length. `max_length > 0` here, so at least
        // one code exists.
        let total_codes: u32 = bl_count[1..=max_length as usize].iter().sum();
        let max_codes = 1u32 << max_length;
        if code + bl_count[max_length as usize] > max_codes {
            return Err(OxiArcError::invalid_header("Over-subscribed Huffman tree"));
        }

        let mut symbols = vec![0u16; total_codes as usize];
        let mut symbol_offsets = [0u16; MAX_CODE_LENGTH + 1];
        let mut base_codes = [0u32; MAX_CODE_LENGTH + 1];
        let mut offset = 0u16;
        for bits in 1..=max_length as usize {
            symbol_offsets[bits] = offset;
            base_codes[bits] = next_code[bits];
            // Cannot overflow: a non-over-subscribed code with lengths of at
            // most 15 bits has at most 1 << 15 symbols.
            offset += bl_count[bits] as u16;
        }
        // Sentinel so `decode_slow` can compute the symbol count of the
        // longest length as a difference of adjacent offsets.
        if max_length < MAX_CODE_LENGTH as u8 {
            symbol_offsets[max_length as usize + 1] = offset;
        }

        // Place every used symbol in its slot: grouped by length, in symbol
        // order within a length.
        let mut current_code = next_code;
        for (symbol, &len) in code_lengths.iter().enumerate() {
            if len > 0 {
                let len = len as usize;
                let idx =
                    symbol_offsets[len] as usize + (current_code[len] - base_codes[len]) as usize;
                if idx < symbols.len() {
                    symbols[idx] = symbol as u16;
                }
                current_code[len] += 1;
            }
        }

        // Direct lookup table. Codes are bit-reversed before being used as an
        // index, so the first code bit ends up in the least-significant index
        // bit, and each code is replicated over every value of the unused
        // high index bits.
        let fast_bits = Self::FAST_BITS.min(max_length);
        let fast_table_size = 1usize << fast_bits;
        let mut fast_table = vec![(0u16, 0u8); fast_table_size];
        for (symbol, &len) in code_lengths.iter().enumerate() {
            if len > 0 && len <= fast_bits {
                let len = len as usize;
                let code = Self::reverse_bits(next_code[len] as u16, len as u8);
                next_code[len] += 1;
                let fill_count = 1usize << (fast_bits - len as u8);
                for i in 0..fill_count {
                    let index = code as usize | (i << len);
                    if index < fast_table_size {
                        fast_table[index] = (symbol as u16, len as u8);
                    }
                }
            }
        }

        Ok(Self {
            fast_table,
            fast_bits,
            max_code_length: max_length,
            symbols,
            base_codes,
            symbol_offsets,
        })
    }

    /// Returns the lowest `length` bits of `code` in reversed order.
    fn reverse_bits(mut code: u16, length: u8) -> u16 {
        let mut reversed = 0u16;
        for _ in 0..length {
            reversed = (reversed << 1) | (code & 1);
            code >>= 1;
        }
        reversed
    }

    /// Decodes one symbol from `reader`.
    ///
    /// Tries a direct table lookup on the next `fast_bits` bits first. If the
    /// peek fails, the peeked value has no table entry, or the entry carries
    /// no code, it falls back to bit-at-a-time decoding.
    ///
    /// # Errors
    ///
    /// Returns an invalid-Huffman error for an empty tree or when no code
    /// matches, and propagates errors from the reader's `skip_bits` /
    /// `read_bits`.
    #[inline]
    pub fn decode<R: Read>(&self, reader: &mut BitReader<R>) -> Result<u16> {
        if self.max_code_length == 0 {
            return Err(OxiArcError::invalid_huffman(reader.bit_position()));
        }

        // A failed peek is not reported here: the slow path re-reads bit by
        // bit and surfaces any error itself.
        if let Ok(bits) = reader.peek_bits(self.fast_bits) {
            // Checked lookup: the index comes from the reader, so memory
            // safety must not depend on the reader honouring the requested
            // bit count.
            if let Some(&(symbol, len)) = self.fast_table.get(bits as usize) {
                if len > 0 {
                    reader.skip_bits(len)?;
                    return Ok(symbol);
                }
            }
        }

        self.decode_slow(reader)
    }

    /// Decodes one symbol by reading a single bit at a time and checking,
    /// after each bit, whether the accumulated code falls into the range of
    /// canonical codes of the current length.
    fn decode_slow<R: Read>(&self, reader: &mut BitReader<R>) -> Result<u16> {
        let mut code = 0u32;
        for len in 1..=self.max_code_length as usize {
            let bit = reader.read_bits(1)?;
            code = (code << 1) | bit;

            // Number of symbols whose code has exactly `len` bits.
            let count = if len < MAX_CODE_LENGTH {
                self.symbol_offsets[len + 1] - self.symbol_offsets[len]
            } else {
                self.symbols.len() as u16 - self.symbol_offsets[len]
            };

            if count > 0 && code >= self.base_codes[len] {
                let idx = code - self.base_codes[len];
                if idx < count as u32 {
                    let symbol_idx = self.symbol_offsets[len] as usize + idx as usize;
                    if symbol_idx < self.symbols.len() {
                        return Ok(self.symbols[symbol_idx]);
                    }
                }
            }
        }
        Err(OxiArcError::invalid_huffman(reader.bit_position()))
    }
}
/// Converts code lengths into per-symbol bit costs.
///
/// A symbol with length 0 has no code, so its cost is `u32::MAX`; every other
/// symbol costs its code length in bits.
pub(crate) fn cost_table_from_lengths(lengths: &[u8]) -> Vec<u32> {
    let mut costs = Vec::with_capacity(lengths.len());
    for &bits in lengths {
        let cost = match bits {
            0 => u32::MAX,
            used => u32::from(used),
        };
        costs.push(cost);
    }
    costs
}
/// Estimates the cost in bits of encoding a match of `length` at `distance`.
///
/// The cost is the length symbol's cost plus its extra bits plus the distance
/// symbol's cost plus its extra bits, combined with saturating addition.
/// Returns `u32::MAX` when either symbol is missing from its cost table or is
/// marked there as `u32::MAX`.
pub(crate) fn cost_of_match(
    length: u16,
    distance: u16,
    litlen_costs: &[u32],
    dist_costs: &[u32],
) -> u32 {
    use crate::tables::{DISTANCE_EXTRA_BITS, LENGTH_EXTRA_BITS, distance_to_code, length_to_code};

    // Length symbol first: bail out before touching the distance at all.
    let (len_code, len_extra_bits, _) = length_to_code(length);
    let length_symbol_cost = match litlen_costs.get(len_code as usize) {
        Some(&cost) if cost != u32::MAX => cost,
        _ => return u32::MAX,
    };

    let (dist_code, dist_extra_bits, _) = distance_to_code(distance);
    let distance_symbol_cost = match dist_costs.get(dist_code as usize) {
        Some(&cost) if cost != u32::MAX => cost,
        _ => return u32::MAX,
    };

    // Extra-bit counts come from the tables; the value reported by the code
    // lookup is used only when the table has no entry at that index.
    let length_table_slot = (len_code as usize).saturating_sub(257);
    let length_extra = LENGTH_EXTRA_BITS
        .get(length_table_slot)
        .map_or(len_extra_bits, |&bits| bits) as u32;
    let distance_extra = DISTANCE_EXTRA_BITS
        .get(dist_code as usize)
        .map_or(dist_extra_bits, |&bits| bits) as u32;

    [length_extra, distance_symbol_cost, distance_extra]
        .iter()
        .fold(length_symbol_cost, |total, &part| total.saturating_add(part))
}
/// Collects symbol frequencies and derives length-limited Huffman code
/// lengths from them.
#[derive(Debug)]
pub struct HuffmanBuilder {
    /// Occurrence count per symbol, indexed by symbol value.
    frequencies: Vec<u32>,
    /// Requested upper bound on the code length, in bits.
    max_length: u8,
}

impl HuffmanBuilder {
    /// Creates a builder for `alphabet_size` symbols, all with frequency 0,
    /// whose code lengths should not exceed `max_length` bits.
    pub fn new(alphabet_size: usize, max_length: u8) -> Self {
        Self {
            frequencies: vec![0; alphabet_size],
            max_length,
        }
    }

    /// Records one occurrence of `symbol`. Symbols outside the alphabet are
    /// ignored.
    pub fn add(&mut self, symbol: u16) {
        self.add_count(symbol, 1);
    }

    /// Records `count` occurrences of `symbol`. Symbols outside the alphabet
    /// are ignored.
    ///
    /// The counter saturates at `u32::MAX` instead of overflowing.
    pub fn add_count(&mut self, symbol: u16, count: u32) {
        if let Some(frequency) = self.frequencies.get_mut(usize::from(symbol)) {
            *frequency = frequency.saturating_add(count);
        }
    }

    /// Computes one code length per alphabet symbol; 0 means "no code".
    ///
    /// * No symbol used: all lengths are 0.
    /// * Exactly one symbol used: it gets length 1, and a second, unused
    ///   symbol (symbol 1 if the used one is symbol 0, otherwise symbol 0) is
    ///   also given length 1 when the alphabet has room for it.
    /// * Otherwise lengths come from the package-merge procedure, limited to
    ///   `max_length` bits (raised to the minimum that can hold all used
    ///   symbols if `max_length` is too small).
    pub fn build_lengths(&self) -> Vec<u8> {
        let n = self.frequencies.len();
        let mut lengths = vec![0u8; n];

        // (frequency, symbol) for every symbol that occurs.
        let mut symbols: Vec<(u32, usize)> = self
            .frequencies
            .iter()
            .enumerate()
            .filter(|&(_, &frequency)| frequency > 0)
            .map(|(symbol, &frequency)| (frequency, symbol))
            .collect();

        if symbols.is_empty() {
            return lengths;
        }

        if symbols.len() == 1 {
            let only = symbols[0].1;
            lengths[only] = 1;
            // `n >= 1` here because at least one symbol is present.
            let phantom = if only == 0 { 1usize.min(n - 1) } else { 0 };
            if phantom != only {
                lengths[phantom] = 1;
            }
            return lengths;
        }

        // Ascending by frequency, ties broken by symbol. The pairs are
        // unique, so an unstable sort yields the same order.
        symbols.sort_unstable();

        let code_lengths = Self::package_merge(&symbols, self.max_length as usize);
        for (&(_, symbol), &length) in symbols.iter().zip(&code_lengths) {
            lengths[symbol] = length;
        }
        lengths
    }

    /// Length-limited code lengths for `symbols`, which must be sorted by
    /// ascending weight. The result is indexed like `symbols`.
    ///
    /// Each round pairs up adjacent items of the previous list into packages
    /// and merges those with the original leaves. A symbol's length is the
    /// number of times it is covered by the first `2n - 2` items of the final
    /// list.
    fn package_merge(symbols: &[(u32, usize)], max_len: usize) -> Vec<u8> {
        /// A leaf or a package, with the leaf indices it contains (repeated
        /// once per containment).
        #[derive(Clone)]
        struct Item {
            weight: u64,
            coverage: Vec<usize>,
        }

        let n = symbols.len();

        // Smallest depth that can hold `n` leaves; the limit is never lower.
        let mut min_bits = 1usize;
        while (1usize << min_bits) < n {
            min_bits += 1;
        }
        let limit = max_len.max(min_bits);

        let base: Vec<Item> = symbols
            .iter()
            .enumerate()
            .map(|(idx, &(weight, _))| Item {
                weight: u64::from(weight),
                coverage: vec![idx],
            })
            .collect();

        let mut prev: Vec<Item> = base.clone();
        for _ in 1..limit {
            // Package adjacent pairs; an unpaired trailing item is dropped.
            let packages: Vec<Item> = prev
                .chunks_exact(2)
                .map(|pair| {
                    let (a, b) = (&pair[0], &pair[1]);
                    let mut coverage = Vec::with_capacity(a.coverage.len() + b.coverage.len());
                    coverage.extend_from_slice(&a.coverage);
                    coverage.extend_from_slice(&b.coverage);
                    Item {
                        weight: a.weight + b.weight,
                        coverage,
                    }
                })
                .collect();

            // Merge leaves and packages by weight; leaves win ties.
            let mut merged: Vec<Item> = Vec::with_capacity(base.len() + packages.len());
            let mut leaves = base.iter().peekable();
            let mut packs = packages.into_iter().peekable();
            loop {
                let take_leaf = match (leaves.peek(), packs.peek()) {
                    (Some(leaf), Some(pack)) => leaf.weight <= pack.weight,
                    (Some(_), None) => true,
                    (None, Some(_)) => false,
                    (None, None) => break,
                };
                let next = if take_leaf {
                    leaves.next().cloned()
                } else {
                    packs.next()
                };
                merged.extend(next);
            }
            prev = merged;
        }

        // `saturating_sub` keeps an empty input from underflowing.
        let select = (2 * n).saturating_sub(2);
        let mut lengths = vec![0u8; n];
        for &sym_idx in prev.iter().take(select).flat_map(|item| item.coverage.iter()) {
            lengths[sym_idx] = lengths[sym_idx].saturating_add(1);
        }

        // Safety net: every symbol gets at least 1 bit and at most `limit`.
        for length in &mut lengths {
            *length = (*length as usize).clamp(1, limit) as u8;
        }
        lengths
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Lengths {1, 2, 2} give the canonical codes A=0, B=10, C=11; the byte
    /// below is consumed from its least-significant bit upward.
    #[test]
    fn test_huffman_tree_simple() {
        let tree = HuffmanTree::from_code_lengths(&[1u8, 2, 2]).expect("build huffman tree");
        let mut reader = BitReader::new(Cursor::new(vec![0b00011010u8]));

        let expected = [
            (0u16, "decode symbol A"),
            (1, "decode symbol B"),
            (2, "decode symbol C"),
            (0, "decode symbol A again"),
        ];
        for (symbol, context) in expected {
            assert_eq!(tree.decode(&mut reader).expect(context), symbol);
        }
    }

    /// More frequent symbols must not receive longer codes, and every symbol
    /// that occurs must receive a code.
    #[test]
    fn test_huffman_builder() {
        let mut builder = HuffmanBuilder::new(4, 15);
        for (symbol, count) in [(0u16, 100u32), (1, 50), (2, 25), (3, 25)] {
            builder.add_count(symbol, count);
        }

        let lengths = builder.build_lengths();
        assert!(lengths[0] <= lengths[1]);
        assert!(lengths[1] <= lengths[2]);
        for &length in &lengths[..4] {
            assert!(length > 0);
        }
    }

    /// All-zero lengths build a tree that contains no codes.
    #[test]
    fn test_empty_tree() {
        let all_unused: [u8; 4] = [0, 0, 0, 0];
        let tree = HuffmanTree::from_code_lengths(&all_unused).expect("build empty huffman tree");
        assert_eq!(tree.max_code_length, 0);
    }

    /// A lone 1-bit code decodes from a zero bit.
    #[test]
    fn test_single_symbol() {
        let tree =
            HuffmanTree::from_code_lengths(&[1u8, 0, 0, 0]).expect("build single symbol tree");
        let mut reader = BitReader::new(Cursor::new(vec![0b00000000u8]));
        assert_eq!(tree.decode(&mut reader).expect("decode single symbol"), 0);
    }

    /// Bit reversal over the stated width only.
    #[test]
    fn test_reverse_bits() {
        let cases = [
            (0b101u16, 3u8, 0b101u16),
            (0b1100, 4, 0b0011),
            (0b10101010, 8, 0b01010101),
        ];
        for (input, width, reversed) in cases {
            assert_eq!(HuffmanTree::reverse_bits(input, width), reversed);
        }
    }
}