use oxiarc_core::BitReader;
use oxiarc_core::error::{OxiArcError, Result};
use std::io::Read;
/// Longest Huffman code length, in bits, that `HuffmanTree::from_code_lengths` accepts;
/// longer lengths are rejected with an invalid-header error.
pub const MAX_CODE_LENGTH: usize = 15;
/// Size of the literal/length alphabet.
/// NOTE(review): not referenced in this part of the file; the value matches the DEFLATE
/// (RFC 1951) literal/length alphabet (symbols 0..=285) — confirm against callers.
pub const LITLEN_ALPHABET_SIZE: usize = 286;
/// Size of the distance alphabet.
/// NOTE(review): not referenced in this part of the file; the value matches the DEFLATE
/// (RFC 1951) distance alphabet (symbols 0..=29) — confirm against callers.
pub const DISTANCE_ALPHABET_SIZE: usize = 30;
/// Size of the code-length alphabet.
/// NOTE(review): not referenced in this part of the file; the value matches the DEFLATE
/// (RFC 1951) code-length alphabet (symbols 0..=18) — confirm against callers.
pub const CODELEN_ALPHABET_SIZE: usize = 19;
/// Symbol value that marks the end of a block.
/// NOTE(review): not referenced in this part of the file; 256 is the end-of-block symbol
/// in DEFLATE (RFC 1951) — confirm against callers.
pub const END_OF_BLOCK: u16 = 256;
/// Canonical Huffman decoding table built from per-symbol code lengths.
///
/// Decoding first tries a single table lookup on the next `fast_bits` bits and, when
/// that lookup has no entry, falls back to a bit-by-bit canonical search.
#[derive(Debug, Clone)]
pub struct HuffmanTree {
/// Lookup table with `1 << fast_bits` entries of `(symbol, code length)`.
/// An entry whose length is 0 means no code of at most `fast_bits` bits matches.
fast_table: Vec<(u16, u8)>,
/// Number of bits peeked from the reader to index `fast_table`.
fast_bits: u8,
/// Longest code length present in the tree; 0 when no symbol has a code.
max_code_length: u8,
/// Symbols that have a code, grouped by code length and, within one length,
/// ordered by increasing symbol value.
symbols: Vec<u16>,
/// For each code length, the first (numerically smallest) canonical code of that length.
base_codes: [u32; MAX_CODE_LENGTH + 1],
/// For each code length, the index in `symbols` of the first symbol of that length.
symbol_offsets: [u16; MAX_CODE_LENGTH + 1],
}
impl HuffmanTree {
    /// Upper bound on the number of bits resolved by a single fast-table lookup.
    const FAST_BITS: u8 = 9;

    /// Builds a canonical Huffman decoder from per-symbol code lengths.
    ///
    /// `code_lengths[i]` is the code length in bits of symbol `i`; `0` marks a symbol
    /// that has no code. Codes are assigned canonically: shorter codes first and,
    /// within one length, in increasing symbol order.
    ///
    /// If every length is zero an empty tree is returned; [`decode`](Self::decode)
    /// on such a tree always fails.
    ///
    /// # Errors
    ///
    /// Returns an invalid-header error when `code_lengths` is empty, holds more than
    /// 65536 entries (symbols are stored as `u16`), contains a length greater than
    /// [`MAX_CODE_LENGTH`], or describes an over-subscribed code.
    pub fn from_code_lengths(code_lengths: &[u8]) -> Result<Self> {
        if code_lengths.is_empty() {
            return Err(OxiArcError::invalid_header("Empty code lengths"));
        }
        // Symbols and table offsets are stored as `u16`. A larger alphabet would
        // silently alias symbols through the `as u16` casts below, so reject it.
        if code_lengths.len() > usize::from(u16::MAX) + 1 {
            return Err(OxiArcError::invalid_header(
                "Too many symbols for a 16-bit Huffman alphabet",
            ));
        }

        // Count how many symbols use each code length and find the longest one.
        let mut bl_count = [0u32; MAX_CODE_LENGTH + 1];
        let mut max_length = 0u8;
        for &len in code_lengths {
            if len == 0 {
                continue;
            }
            if usize::from(len) > MAX_CODE_LENGTH {
                return Err(OxiArcError::invalid_header(format!(
                    "Code length {} exceeds maximum {}",
                    len, MAX_CODE_LENGTH
                )));
            }
            bl_count[usize::from(len)] += 1;
            max_length = max_length.max(len);
        }

        if max_length == 0 {
            return Ok(Self {
                fast_table: vec![(0, 0); 1 << Self::FAST_BITS],
                fast_bits: Self::FAST_BITS,
                max_code_length: 0,
                symbols: Vec::new(),
                base_codes: [0; MAX_CODE_LENGTH + 1],
                symbol_offsets: [0; MAX_CODE_LENGTH + 1],
            });
        }
        let longest = usize::from(max_length);

        // First canonical code of every length.
        let mut next_code = [0u32; MAX_CODE_LENGTH + 1];
        let mut code = 0u32;
        for bits in 1..=longest {
            code = (code + bl_count[bits - 1]) << 1;
            next_code[bits] = code;
        }

        // `code + bl_count[longest]` equals the Kraft sum scaled by 2^longest, so a
        // value above 2^longest means more codes were requested than can exist.
        if code + bl_count[longest] > (1u32 << max_length) {
            return Err(OxiArcError::invalid_header("Over-subscribed Huffman tree"));
        }

        let total_codes: u32 = bl_count[1..=longest].iter().sum();
        let mut symbols = vec![0u16; total_codes as usize];
        let mut symbol_offsets = [0u16; MAX_CODE_LENGTH + 1];
        let mut base_codes = [0u32; MAX_CODE_LENGTH + 1];

        // The Kraft check above bounds the number of coded symbols by 2^15, so the
        // running offset always fits in `u16`.
        let mut offset = 0u16;
        for bits in 1..=longest {
            symbol_offsets[bits] = offset;
            base_codes[bits] = next_code[bits];
            offset += bl_count[bits] as u16;
        }
        // One-past-the-end offset so `decode_slow` can derive the symbol count of the
        // longest length by subtraction.
        if longest < MAX_CODE_LENGTH {
            symbol_offsets[longest + 1] = offset;
        }

        // Place every coded symbol at the slot that matches its canonical code.
        let mut current_code = next_code;
        for (symbol, &len) in code_lengths.iter().enumerate() {
            if len == 0 {
                continue;
            }
            let width = usize::from(len);
            let idx = usize::from(symbol_offsets[width])
                + (current_code[width] - base_codes[width]) as usize;
            if let Some(slot) = symbols.get_mut(idx) {
                *slot = symbol as u16;
            }
            current_code[width] += 1;
        }

        // Fast table: every code of at most `fast_bits` bits is replicated into all
        // entries whose low `len` bits equal the bit-reversed code.
        let fast_bits = Self::FAST_BITS.min(max_length);
        let fast_table_size = 1usize << fast_bits;
        let mut fast_table = vec![(0u16, 0u8); fast_table_size];
        for (symbol, &len) in code_lengths.iter().enumerate() {
            if len == 0 || len > fast_bits {
                continue;
            }
            let width = usize::from(len);
            let reversed = usize::from(Self::reverse_bits(next_code[width] as u16, len));
            next_code[width] += 1;
            let fill_count = 1usize << (fast_bits - len);
            for i in 0..fill_count {
                let index = reversed | (i << width);
                if index < fast_table_size {
                    fast_table[index] = (symbol as u16, len);
                }
            }
        }

        Ok(Self {
            fast_table,
            fast_bits,
            max_code_length: max_length,
            symbols,
            base_codes,
            symbol_offsets,
        })
    }

    /// Reverses the lowest `length` bits of `code`; higher bits of the result are zero.
    fn reverse_bits(mut code: u16, length: u8) -> u16 {
        let mut reversed = 0u16;
        for _ in 0..length {
            reversed = (reversed << 1) | (code & 1);
            code >>= 1;
        }
        reversed
    }

    /// Decodes the next symbol from `reader`.
    ///
    /// Peeks `fast_bits` bits and resolves the symbol with one table lookup when the
    /// code is short enough; otherwise (or when peeking fails, for example because
    /// fewer bits remain) falls back to the bit-by-bit search in `decode_slow`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-Huffman error for an empty tree or when the input bits do
    /// not form a known code, and propagates reader errors from consuming bits.
    #[inline]
    pub fn decode<R: Read>(&self, reader: &mut BitReader<R>) -> Result<u16> {
        if self.max_code_length == 0 {
            return Err(OxiArcError::invalid_huffman(reader.bit_position()));
        }
        if let Ok(bits) = reader.peek_bits(self.fast_bits) {
            // Checked lookup: the table holds `1 << fast_bits` entries, but this code
            // cannot prove what `peek_bits` returns, so an out-of-range value goes to
            // the slow path instead of reading out of bounds.
            if let Some(&(symbol, len)) = self.fast_table.get(bits as usize) {
                if len > 0 {
                    reader.skip_bits(len)?;
                    return Ok(symbol);
                }
            }
        }
        self.decode_slow(reader)
    }

    /// Decodes one symbol by reading a bit at a time and testing, for each length,
    /// whether the accumulated code falls in that length's canonical code range.
    ///
    /// # Errors
    ///
    /// Propagates reader errors and returns an invalid-Huffman error when no code
    /// matches within `max_code_length` bits.
    fn decode_slow<R: Read>(&self, reader: &mut BitReader<R>) -> Result<u16> {
        let mut code = 0u32;
        for len in 1..=usize::from(self.max_code_length) {
            let bit = reader.read_bits(1)?;
            code = (code << 1) | bit;

            // Number of symbols whose code is exactly `len` bits long.
            let count = if len < MAX_CODE_LENGTH {
                self.symbol_offsets[len + 1] - self.symbol_offsets[len]
            } else {
                self.symbols.len() as u16 - self.symbol_offsets[len]
            };

            if count > 0 && code >= self.base_codes[len] {
                let idx = code - self.base_codes[len];
                if idx < u32::from(count) {
                    let symbol_idx = usize::from(self.symbol_offsets[len]) + idx as usize;
                    if let Some(&symbol) = self.symbols.get(symbol_idx) {
                        return Ok(symbol);
                    }
                }
            }
        }
        Err(OxiArcError::invalid_huffman(reader.bit_position()))
    }
}
/// Collects symbol frequencies and derives length-limited Huffman code lengths.
#[derive(Debug)]
pub struct HuffmanBuilder {
    /// Occurrence count of each symbol, indexed by symbol value.
    frequencies: Vec<u32>,
    /// Longest code length, in bits, that may be assigned to a symbol.
    max_length: u8,
}

impl HuffmanBuilder {
    /// Creates a builder for `alphabet_size` symbols whose codes may be at most
    /// `max_length` bits long. All frequencies start at zero.
    pub fn new(alphabet_size: usize, max_length: u8) -> Self {
        Self {
            frequencies: vec![0; alphabet_size],
            max_length,
        }
    }

    /// Records one occurrence of `symbol`. Symbols outside the alphabet are ignored.
    pub fn add(&mut self, symbol: u16) {
        self.add_count(symbol, 1);
    }

    /// Records `count` occurrences of `symbol`. Symbols outside the alphabet are
    /// ignored; the stored frequency saturates at `u32::MAX` instead of overflowing.
    pub fn add_count(&mut self, symbol: u16, count: u32) {
        if let Some(freq) = self.frequencies.get_mut(usize::from(symbol)) {
            *freq = freq.saturating_add(count);
        }
    }

    /// Computes a code length for every symbol of the alphabet.
    ///
    /// Symbols that were never added get length 0. A single used symbol gets
    /// length 1. With two or more used symbols the result is an optimal
    /// length-limited prefix code: no length exceeds `max_length` and the lengths
    /// form a complete code (their Kraft sum is exactly 1).
    ///
    /// If `max_length` is too small to give every used symbol a distinct code
    /// (more than `2^max_length` used symbols), every used symbol is assigned
    /// `max_length`; that result is over-subscribed and callers must pick a
    /// feasible limit.
    pub fn build_lengths(&self) -> Vec<u8> {
        let mut lengths = vec![0u8; self.frequencies.len()];

        let mut symbols: Vec<(u32, usize)> = self
            .frequencies
            .iter()
            .enumerate()
            .filter(|&(_, &freq)| freq > 0)
            .map(|(index, &freq)| (freq, index))
            .collect();

        match symbols.len() {
            0 => return lengths,
            1 => {
                lengths[symbols[0].1] = 1;
                return lengths;
            }
            _ => {}
        }

        // Ascending by frequency, ties broken by symbol index. The pairs are
        // unique, so an unstable sort yields a deterministic order.
        symbols.sort_unstable();

        let code_lengths = self.package_merge(&symbols);
        for (&(_, symbol), &len) in symbols.iter().zip(&code_lengths) {
            lengths[symbol] = len;
        }
        lengths
    }

    /// Package-merge: optimal length-limited code lengths for `symbols`.
    ///
    /// `symbols` must be sorted by ascending frequency. The returned vector holds
    /// the code length of `symbols[i]` at index `i`.
    ///
    /// One sorted item list is built per code-length level, deepest level first.
    /// Each shallower list merges the leaves with "packages" formed by pairing
    /// neighbours of the previous list. Selecting the cheapest `2n - 2` items of
    /// the shallowest list, and recursively the items inside the selected
    /// packages, yields for every leaf the number of levels it was selected on,
    /// which is its code length. Because packages pair neighbours of a sorted
    /// list, the selection on every level is a prefix of that level's list, so
    /// only a leaf/package flag per item has to be remembered.
    fn package_merge(&self, symbols: &[(u32, usize)]) -> Vec<u8> {
        let n = symbols.len();
        match n {
            0 => return Vec::new(),
            1 => return vec![1u8.min(self.max_length)],
            _ => {}
        }

        let max_len = usize::from(self.max_length);
        // A prefix code with lengths <= max_len has at most 2^max_len codewords.
        if max_len < usize::BITS as usize && n > (1usize << max_len) {
            return vec![self.max_length; n];
        }

        // An unrestricted optimal code is never deeper than n - 1, so extra
        // levels beyond that cannot improve the result.
        let levels = max_len.min(n - 1);

        let leaf_weights: Vec<u64> = symbols
            .iter()
            .map(|&(freq, _)| u64::from(freq))
            .collect();

        // leaf_flags[k][i] tells whether item i of the level-k list is a leaf
        // (true) or a package (false). Index 0 is the deepest level.
        let mut leaf_flags: Vec<Vec<bool>> = Vec::with_capacity(levels);
        leaf_flags.push(vec![true; n]);
        let mut current = leaf_weights.clone();

        for _ in 1..levels {
            // Pair neighbours of the previous list; an unpaired last item is dropped.
            let packages: Vec<u64> = current
                .chunks_exact(2)
                .map(|pair| pair[0].saturating_add(pair[1]))
                .collect();

            let merged_len = n + packages.len();
            let mut merged = Vec::with_capacity(merged_len);
            let mut flags = Vec::with_capacity(merged_len);
            let mut leaf = 0usize;
            let mut package = 0usize;
            while leaf < n || package < packages.len() {
                // On equal weight prefer the leaf, which keeps codes shallower.
                let take_leaf = package == packages.len()
                    || (leaf < n && leaf_weights[leaf] <= packages[package]);
                if take_leaf {
                    merged.push(leaf_weights[leaf]);
                    leaf += 1;
                } else {
                    merged.push(packages[package]);
                    package += 1;
                }
                flags.push(take_leaf);
            }
            leaf_flags.push(flags);
            current = merged;
        }

        // Walk from the shallowest level down. On each level the selected prefix
        // contains the `leaves` lightest symbols (each gains one bit) and some
        // packages, each of which expands to two items on the next level.
        let mut lengths = vec![0u8; n];
        let mut take = 2 * n - 2;
        for flags in leaf_flags.iter().rev() {
            if take == 0 {
                break;
            }
            let taken = &flags[..take.min(flags.len())];
            let leaves = taken.iter().filter(|&&is_leaf| is_leaf).count();
            for len in &mut lengths[..leaves] {
                *len += 1;
            }
            take = 2 * (taken.len() - leaves);
        }
        lengths
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Lengths [1, 2, 2] give the canonical codes 0 -> "0", 1 -> "10", 2 -> "11".
    /// The expected sequence implies the reader hands out the least-significant
    /// bit of each byte first.
    #[test]
    fn test_huffman_tree_simple() {
        let code_lengths = [1u8, 2, 2];
        let tree = HuffmanTree::from_code_lengths(&code_lengths).unwrap();

        let bytes = vec![0b00011010u8];
        let mut reader = BitReader::new(Cursor::new(bytes));

        for &expected in &[0u16, 1, 2, 0] {
            assert_eq!(tree.decode(&mut reader).unwrap(), expected);
        }
    }

    /// More frequent symbols must not receive longer codes, and every used symbol
    /// must receive a code.
    #[test]
    fn test_huffman_builder() {
        let mut builder = HuffmanBuilder::new(4, 15);
        for &(symbol, count) in &[(0u16, 100u32), (1, 50), (2, 25), (3, 25)] {
            builder.add_count(symbol, count);
        }

        let lengths = builder.build_lengths();

        assert!(lengths[0] <= lengths[1]);
        assert!(lengths[1] <= lengths[2]);
        for symbol in 0..4 {
            assert!(lengths[symbol] > 0);
        }
    }

    /// All-zero lengths produce a tree without any codes.
    #[test]
    fn test_empty_tree() {
        let code_lengths: [u8; 4] = [0, 0, 0, 0];
        let tree = HuffmanTree::from_code_lengths(&code_lengths).unwrap();
        assert_eq!(tree.max_code_length, 0);
    }

    /// A lone one-bit code decodes from a zero bit.
    #[test]
    fn test_single_symbol() {
        let code_lengths = [1u8, 0, 0, 0];
        let tree = HuffmanTree::from_code_lengths(&code_lengths).unwrap();

        let bytes = vec![0b00000000u8];
        let mut reader = BitReader::new(Cursor::new(bytes));

        assert_eq!(tree.decode(&mut reader).unwrap(), 0);
    }

    /// `reverse_bits` mirrors the lowest `length` bits of the code.
    #[test]
    fn test_reverse_bits() {
        let cases: [(u16, u8, u16); 3] = [
            (0b101, 3, 0b101),
            (0b1100, 4, 0b0011),
            (0b10101010, 8, 0b01010101),
        ];
        for &(code, length, expected) in &cases {
            assert_eq!(HuffmanTree::reverse_bits(code, length), expected);
        }
    }
}