use super::table::HuffmanTable;
use crate::fse::BitReader;
use haagenti_core::{Error, Result};
#[derive(Debug)]
pub struct HuffmanDecoder<'a> {
table: &'a HuffmanTable,
}
impl<'a> HuffmanDecoder<'a> {
pub fn new(table: &'a HuffmanTable) -> Self {
Self { table }
}
pub fn decode_symbol(&self, bits: &mut BitReader) -> Result<u8> {
let max_bits = self.table.max_bits() as usize;
let peek_value = bits.peek_bits_padded(max_bits)? as usize;
let entry = self.table.decode(peek_value);
bits.read_bits(entry.num_bits as usize)?;
Ok(entry.symbol)
}
pub fn table(&self) -> &HuffmanTable {
self.table
}
}
/// Parses a Huffman tree description header and returns its weights.
///
/// The first byte selects the encoding. Values `0..=127` mean the weights are
/// FSE-compressed; `128..=255` mean they are stored directly as 4-bit values.
///
/// Returns the decoded weights together with the number of header bytes
/// consumed.
///
/// # Errors
/// Returns a corruption error if `data` is empty. Otherwise propagates errors
/// from the selected parser.
pub fn parse_huffman_weights(data: &[u8]) -> Result<(Vec<u8>, usize)> {
    let Some(&selector) = data.first() else {
        return Err(Error::corrupted("Empty Huffman header"));
    };
    match selector {
        0..=127 => parse_fse_compressed_weights(data),
        _ => parse_direct_weights(data),
    }
}
/// Parses FSE-compressed Huffman weights.
///
/// The first byte gives the compressed payload size in bytes; the payload
/// follows immediately. Returns the weights and the total header size
/// (size byte plus payload).
///
/// # Errors
/// Returns a corruption error if `data` is empty, if the size byte is zero,
/// or if fewer payload bytes are present than announced. Otherwise propagates
/// errors from the FSE weight decoder.
fn parse_fse_compressed_weights(data: &[u8]) -> Result<(Vec<u8>, usize)> {
    let Some((&size_byte, rest)) = data.split_first() else {
        return Err(Error::corrupted("Empty FSE header for Huffman weights"));
    };
    let compressed_size = size_byte as usize;
    if compressed_size == 0 {
        return Err(Error::corrupted("Zero compressed size for Huffman weights"));
    }
    let consumed = compressed_size + 1;
    let Some(payload) = rest.get(..compressed_size) else {
        return Err(Error::corrupted(format!(
            "Huffman header too short: need {} bytes, have {}",
            consumed,
            data.len()
        )));
    };
    decompress_huffman_weights_fse(payload).map(|weights| (weights, consumed))
}
/// Parses directly stored (uncompressed) Huffman weights.
///
/// A header byte `h` of 128 or more announces `h - 127` symbols. Their weights
/// follow as packed 4-bit nibbles, high nibble first; a trailing unused
/// nibble is ignored. Returns the weights and the total header size (header
/// byte plus packed bytes).
///
/// # Errors
/// Returns a corruption error in these cases:
/// - `data` is empty;
/// - the symbol count is zero, which includes any header byte below 128;
/// - the symbol count exceeds `HUFFMAN_MAX_SYMBOLS`;
/// - the packed weights are truncated.
fn parse_direct_weights(data: &[u8]) -> Result<(Vec<u8>, usize)> {
    let Some((&header_byte, packed)) = data.split_first() else {
        return Err(Error::corrupted("Empty direct weights header"));
    };
    // Saturate instead of subtracting: a header below 127 would otherwise
    // underflow (panic in debug builds) rather than be reported as corrupt.
    let num_symbols = header_byte.saturating_sub(127) as usize;
    if num_symbols == 0 || num_symbols > super::HUFFMAN_MAX_SYMBOLS {
        return Err(Error::corrupted(format!(
            "Invalid number of Huffman symbols: {}",
            num_symbols
        )));
    }
    let num_weight_bytes = num_symbols.div_ceil(2);
    let total_header_size = 1 + num_weight_bytes;
    if data.len() < total_header_size {
        return Err(Error::corrupted(format!(
            "Direct weights header too short: need {} bytes, have {}",
            total_header_size,
            data.len()
        )));
    }
    // Each byte carries two weights, high nibble first; `take` drops the
    // padding nibble when the symbol count is odd.
    let mut weights = Vec::with_capacity(num_symbols);
    weights.extend(
        packed[..num_weight_bytes]
            .iter()
            .flat_map(|&byte| [byte >> 4, byte & 0x0F])
            .take(num_symbols),
    );
    Ok((weights, total_header_size))
}
fn decompress_huffman_weights_fse(data: &[u8]) -> Result<Vec<u8>> {
use crate::fse::{BitReader, FseDecoder, FseTable};
if data.is_empty() {
return Err(Error::corrupted("Empty FSE data for Huffman weights"));
}
const MAX_WEIGHT_SYMBOL: u8 = 12;
let (table, header_bytes) = FseTable::parse(data, MAX_WEIGHT_SYMBOL)?;
let accuracy_log = table.accuracy_log();
if !(5..=7).contains(&accuracy_log) {
return Err(Error::corrupted(format!(
"Huffman weight FSE accuracy log {} outside valid range 5-7",
accuracy_log
)));
}
let bitstream = &data[header_bytes..];
if bitstream.is_empty() {
return Err(Error::corrupted("No bitstream data after FSE header"));
}
let mut bits = BitReader::new_reversed(bitstream)?;
let mut decoder = FseDecoder::new(&table);
decoder.init_state(&mut bits)?;
let mut weights = Vec::with_capacity(256);
loop {
let bits_needed = decoder.peek_num_bits() as usize;
if bits.bits_remaining() < bits_needed {
let final_weight = decoder.peek_symbol();
if final_weight <= MAX_WEIGHT_SYMBOL {
weights.push(final_weight);
}
break;
}
let weight = decoder.decode_symbol(&mut bits)?;
if weight > MAX_WEIGHT_SYMBOL {
return Err(Error::corrupted(format!(
"Invalid Huffman weight {} (max {})",
weight, MAX_WEIGHT_SYMBOL
)));
}
weights.push(weight);
if weights.len() > super::HUFFMAN_MAX_SYMBOLS {
return Err(Error::corrupted("Too many Huffman symbols decoded"));
}
}
if weights.is_empty() {
return Err(Error::corrupted(
"No Huffman weights decoded from FSE stream",
));
}
Ok(weights)
}
pub fn build_table_from_weights(mut weights: Vec<u8>) -> Result<HuffmanTable> {
if weights.is_empty() {
return Err(Error::corrupted("Empty Huffman weights"));
}
let max_explicit_weight = *weights.iter().max().unwrap_or(&0);
if max_explicit_weight == 0 {
return Err(Error::corrupted("All explicit Huffman weights are zero"));
}
let weight_sum: u32 = weights.iter().filter(|&&w| w > 0).map(|&w| 1u32 << w).sum();
let target = weight_sum.next_power_of_two();
let remaining = target - weight_sum;
if remaining > 0 {
let implicit_weight = (32 - remaining.leading_zeros() - 1) as u8;
weights.push(implicit_weight);
}
HuffmanTable::from_weights(&weights)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Table with weights [2, 1, 1]: symbol 0 gets a 1-bit code, symbols 1 and 2 get 2-bit codes.
    fn three_symbol_table() -> HuffmanTable {
        HuffmanTable::from_weights(&[2u8, 1, 1]).unwrap()
    }

    #[test]
    fn test_decoder_creation() {
        let table = three_symbol_table();
        let decoder = HuffmanDecoder::new(&table);
        assert_eq!(decoder.table().num_symbols(), 3);
    }

    #[test]
    fn test_decode_simple_symbols() {
        let table = three_symbol_table();
        let decoder = HuffmanDecoder::new(&table);
        let stream = [0b11_10_01_00u8];
        let mut reader = BitReader::new(&stream);
        assert_eq!(decoder.decode_symbol(&mut reader).unwrap(), 0);
    }

    #[test]
    fn test_direct_weights_parsing() {
        // 0x83 announces 4 symbols packed as nibbles 2,1 | 1,0.
        let (weights, consumed) = parse_direct_weights(&[0x83, 0x21, 0x10]).unwrap();
        assert_eq!(consumed, 3);
        assert_eq!(weights, vec![2, 1, 1, 0]);
    }

    #[test]
    fn test_direct_weights_odd_count() {
        // 0x82 announces 3 symbols; the final low nibble is padding.
        let (weights, consumed) = parse_direct_weights(&[0x82, 0x32, 0x10]).unwrap();
        assert_eq!(consumed, 3);
        assert_eq!(weights, vec![3, 2, 1]);
    }

    #[test]
    fn test_direct_weights_single_symbol() {
        let (weights, consumed) = parse_direct_weights(&[0x80, 0x40]).unwrap();
        assert_eq!(consumed, 2);
        assert_eq!(weights, vec![4]);
    }

    #[test]
    fn test_fse_header_detection() {
        // Header 0x10 selects FSE mode with a 16-byte payload that is not present.
        assert!(parse_huffman_weights(&[0x10, 0x00, 0x00]).is_err());
    }

    #[test]
    fn test_empty_header_error() {
        assert!(parse_huffman_weights(&[]).is_err());
    }

    #[test]
    fn test_direct_weights_too_short() {
        // Four symbols need two packed bytes, but only one follows the header.
        assert!(parse_direct_weights(&[0x83, 0x21]).is_err());
    }

    #[test]
    fn test_build_table_with_implicit_weight() {
        let table = build_table_from_weights(vec![2u8, 1]).unwrap();
        assert_eq!(table.num_symbols(), 3);
        assert_eq!(table.max_bits(), 2);
    }

    #[test]
    fn test_build_table_no_implicit_needed() {
        let table = build_table_from_weights(vec![1u8, 1]).unwrap();
        assert_eq!(table.num_symbols(), 2);
    }

    #[test]
    fn test_build_table_empty_error() {
        assert!(build_table_from_weights(vec![]).is_err());
    }

    #[test]
    fn test_build_table_all_zero_error() {
        assert!(build_table_from_weights(vec![0, 0, 0]).is_err());
    }

    #[test]
    fn test_decode_multiple_symbols() {
        let table = three_symbol_table();
        let decoder = HuffmanDecoder::new(&table);
        // All-zero bits decode as the 1-bit code of symbol 0, repeatedly.
        let stream = [0b00000000u8, 0b00000000];
        let mut reader = BitReader::new(&stream);
        let decoded: Vec<u8> = (0..8)
            .map(|_| decoder.decode_symbol(&mut reader).unwrap())
            .collect();
        assert_eq!(decoded, vec![0u8; 8]);
    }
}