use crate::fse::{FseBitReader, FseDecoder, read_fse_table_description};
use oxiarc_core::error::{OxiArcError, Result};
/// Longest Huffman code length, in bits, that a [`HuffmanTable`] is built for.
pub const MAX_CODE_LENGTH: u8 = 11;
/// Size of the symbol alphabet: decoded symbols are single bytes.
pub const MAX_SYMBOLS: usize = 256;
/// One slot of a [`HuffmanTable`] decoding table.
#[derive(Debug, Clone, Copy, Default)]
pub struct HuffmanEntry {
/// Decoded symbol (a byte value).
pub symbol: u8,
/// Length in bits of the code that decodes to `symbol`.
pub num_bits: u8,
}
/// Single-level Huffman decoding table with `1 << max_bits` entries.
#[derive(Debug, Clone)]
pub struct HuffmanTable {
/// Lookup table indexed by the next `max_bits` bits of input (see `decode`).
entries: Vec<HuffmanEntry>,
/// Longest code length in bits; `entries.len() == 1 << max_bits`.
max_bits: u8,
}
impl HuffmanTable {
pub fn from_weights(weights: &[u8]) -> Result<Self> {
if weights.is_empty() {
return Err(OxiArcError::CorruptedData {
offset: 0,
message: "empty Huffman weights".to_string(),
});
}
let mut total_weight = 0u32;
let mut max_weight = 0u8;
for &w in weights {
if w > 0 {
total_weight += 1u32 << (w - 1);
if w > max_weight {
max_weight = w;
}
}
}
if total_weight == 0 {
return Err(OxiArcError::CorruptedData {
offset: 0,
message: "all Huffman weights are zero".to_string(),
});
}
let max_bits = 32 - total_weight.leading_zeros();
let max_bits = max_bits.min(MAX_CODE_LENGTH as u32) as u8;
let table_size = 1usize << max_bits;
let mut entries = vec![HuffmanEntry::default(); table_size];
let mut code = 0u32;
let mut code_lengths = vec![0u8; weights.len()];
for (symbol, &weight) in weights.iter().enumerate() {
if weight > 0 {
code_lengths[symbol] = max_bits + 1 - weight;
}
}
let mut symbols: Vec<(usize, u8)> = weights
.iter()
.enumerate()
.filter(|&(_, w)| *w > 0)
.map(|(s, _)| (s, code_lengths[s]))
.collect();
symbols.sort_by_key(|&(_, len)| len);
let mut prev_len = 0u8;
for (symbol, len) in symbols {
if len > prev_len {
code <<= len - prev_len;
prev_len = len;
}
let num_entries = 1 << (max_bits - len);
let base_code = (code as usize) << (max_bits - len);
for i in 0..num_entries {
entries[base_code + i] = HuffmanEntry {
symbol: symbol as u8,
num_bits: len,
};
}
code += 1;
}
Ok(Self { entries, max_bits })
}
#[inline]
pub fn decode(&self, bits: u32) -> &HuffmanEntry {
let idx = bits as usize & ((1 << self.max_bits) - 1);
&self.entries[idx]
}
pub fn max_bits(&self) -> u8 {
self.max_bits
}
}
/// Parses a Huffman tree description from the start of `data`.
///
/// The first byte selects the encoding:
/// - below 128, the byte is the size of an FSE-compressed weight block that follows it;
/// - 128 or above, `header - 127` weights follow directly as packed 4-bit values.
///
/// Returns the decoding table together with the number of bytes consumed.
///
/// # Errors
///
/// Returns `CorruptedData` if `data` is empty. Errors from the selected sub-reader are
/// propagated unchanged.
pub fn read_huffman_table(data: &[u8]) -> Result<(HuffmanTable, usize)> {
    let Some(&header) = data.first() else {
        return Err(OxiArcError::CorruptedData {
            offset: 0,
            message: "empty Huffman table data".to_string(),
        });
    };
    match header {
        0..=127 => read_huffman_table_fse(data),
        _ => read_huffman_table_direct(data),
    }
}
/// Reads a Huffman description whose weights are stored directly as 4-bit values.
///
/// `data[0]` is the header (expected to be 128..=255, as routed by
/// `read_huffman_table`). It encodes the weight count as `header - 127`. The weights
/// follow, two per byte: the first in the high nibble, the second in the low nibble.
/// An odd count leaves the final low nibble unused.
///
/// Returns the table and the number of bytes consumed (header plus packed weights).
fn read_huffman_table_direct(data: &[u8]) -> Result<(HuffmanTable, usize)> {
    let symbol_count = usize::from(data[0] - 127);
    let count_in_range = (1..=MAX_SYMBOLS).contains(&symbol_count);
    if !count_in_range {
        return Err(OxiArcError::CorruptedData {
            offset: 0,
            message: format!("invalid number of Huffman symbols: {}", symbol_count),
        });
    }

    let packed_len = symbol_count.div_ceil(2);
    let Some(packed) = data.get(1..1 + packed_len) else {
        return Err(OxiArcError::CorruptedData {
            offset: 0,
            message: "truncated Huffman table".to_string(),
        });
    };

    // Expand every byte into (high nibble, low nibble) and drop the spare trailing nibble.
    let weights: Vec<u8> = packed
        .iter()
        .flat_map(|&byte| [byte >> 4, byte & 0x0F])
        .take(symbol_count)
        .collect();

    let table = HuffmanTable::from_weights(&weights)?;
    Ok((table, 1 + packed_len))
}
/// Reads a Huffman description whose weights are FSE-compressed.
///
/// The layout, as parsed below, has three parts:
/// - `data[0]` is the compressed size `N`. It is 1..=127, because `read_huffman_table`
///   only routes headers below 128 here and `N == 0` is rejected.
/// - The next `N` bytes begin with an FSE table description.
/// - The remainder of those `N` bytes is the FSE bitstream of weights.
///
/// Weights are decoded until `reader.is_empty()` or until `MAX_SYMBOLS` weights have been
/// produced. Returns the table and the number of bytes consumed (`1 + N`).
///
/// # Errors
///
/// Returns `CorruptedData` if `N` is zero, if fewer than `N` bytes follow the header, or
/// if no weights are decoded. Errors from the FSE helpers and from
/// `HuffmanTable::from_weights` are propagated.
///
/// NOTE(review): RFC 8878 §4.2.1.2 decodes Huffman weights with two interleaved FSE
/// states that share one table (even-indexed weights from the first state, odd-indexed
/// from the second). This code uses a single `FseDecoder`. Confirm against the
/// `crate::fse` semantics.
/// NOTE(review): the spec caps the accuracy log of the weight table at 6. If the `12`
/// passed to `read_fse_table_description` is a maximum accuracy log, it is looser than
/// the spec allows.
/// NOTE(review): per the spec, the decoded weight list omits the last symbol, whose
/// weight is implied. Confirm that `HuffmanTable::from_weights` accounts for it.
fn read_huffman_table_fse(data: &[u8]) -> Result<(HuffmanTable, usize)> {
// Indexing is safe here: `read_huffman_table` rejects empty input before dispatching.
let compressed_size = data[0] as usize;
if compressed_size == 0 {
return Err(OxiArcError::CorruptedData {
offset: 0,
message: "zero-length FSE Huffman table".to_string(),
});
}
if data.len() < 1 + compressed_size {
return Err(OxiArcError::CorruptedData {
offset: 0,
message: "truncated FSE Huffman table".to_string(),
});
}
// Only the `compressed_size` bytes after the header belong to this description.
let fse_data = &data[1..1 + compressed_size];
let (fse_table, fse_bytes) = read_fse_table_description(fse_data, 12)?;
// NOTE(review): panics if the helper ever reports `fse_bytes > fse_data.len()`; this
// assumes it never does.
let bitstream_data = &fse_data[fse_bytes..];
let mut reader = FseBitReader::new(bitstream_data)?;
let mut decoder = FseDecoder::new(&fse_table, &mut reader);
let mut weights = Vec::with_capacity(MAX_SYMBOLS);
// Decode one weight per iteration until the bitstream runs out or the alphabet is full.
while weights.len() < MAX_SYMBOLS && !reader.is_empty() {
let weight = decoder.decode(&mut reader);
weights.push(weight);
}
if weights.is_empty() {
return Err(OxiArcError::CorruptedData {
offset: 0,
message: "no Huffman weights decoded".to_string(),
});
}
let table = HuffmanTable::from_weights(&weights)?;
Ok((table, 1 + compressed_size))
}
/// Bit reader for a Huffman-coded stream that is consumed from its last byte toward
/// its first (see `peek_bits`).
pub struct HuffmanBitReader<'a> {
/// Underlying bytes.
data: &'a [u8],
/// Number of bits consumed so far.
bit_pos: usize,
/// Number of readable payload bits, derived in `new` from the highest set bit of the
/// last byte.
total_bits: usize,
}
impl<'a> HuffmanBitReader<'a> {
    /// Creates a reader over a backward-read bitstream.
    ///
    /// The highest set bit of the final byte is a stop marker. It and the zero bits above
    /// it are padding. Every bit below it, down to bit 0 of `data[0]`, is payload.
    /// Reading starts just below the marker and moves toward the start of the buffer.
    ///
    /// # Errors
    ///
    /// Returns `CorruptedData` if `data` is empty or if its last byte is zero (no marker
    /// bit).
    pub fn new(data: &'a [u8]) -> Result<Self> {
        let Some(&last_byte) = data.last() else {
            return Err(OxiArcError::CorruptedData {
                offset: 0,
                message: "empty Huffman bitstream".to_string(),
            });
        };
        if last_byte == 0 {
            return Err(OxiArcError::CorruptedData {
                offset: 0,
                message: "Huffman stream ends with zero".to_string(),
            });
        }
        // `u8::leading_zeros` counts within the byte's 8 bits (not 32), so it is exactly
        // the number of zero padding bits above the marker bit.
        let padding = last_byte.leading_zeros() as usize;
        // Drop the padding and the marker bit itself.
        let total_bits = data.len() * 8 - padding - 1;
        Ok(Self {
            data,
            bit_pos: 0,
            total_bits,
        })
    }

    /// Returns the next `n` unread bits without consuming them.
    ///
    /// The first unread bit lands in the most significant position of the `n`-bit
    /// result. Positions before the start of the buffer read as zero, so a peek near the
    /// end of the stream is zero-padded on the right. Returns 0 when `n == 0` or when the
    /// stream is exhausted.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `n > 32`. Release builds treat such `n` as 32.
    pub fn peek_bits(&self, n: u8) -> u32 {
        if n == 0 || self.bit_pos >= self.total_bits {
            return 0;
        }
        debug_assert!(n <= 32, "peek_bits: at most 32 bits per call, got {}", n);
        let n = u32::from(n.min(32));

        // Absolute index of the next unread bit. Bit `k` is bit `k % 8` of byte `k / 8`.
        let read_pos = self.total_bits - self.bit_pos - 1;
        let byte_pos = read_pos / 8;
        let bit_offset = (read_pos % 8) as u32;

        // Pack up to 8 bytes ending at `byte_pos` into a 64-bit window, with
        // `data[byte_pos]` in the top byte. Bytes before index 0 stay zero.
        let first = byte_pos.saturating_sub(7);
        let bytes = &self.data[first..=byte_pos];
        let window = bytes
            .iter()
            .rev()
            .fold(0u64, |acc, &b| (acc << 8) | u64::from(b))
            << (8 * (8 - bytes.len()));

        // The next unread bit is window bit `56 + bit_offset`. Keep it and the `n - 1`
        // bits below it. `n <= 32` keeps the shift at 25 or more.
        let shift = 57 + bit_offset - n;
        ((window >> shift) & ((1u64 << n) - 1)) as u32
    }

    /// Marks `n` bits as consumed. Consuming past the end is allowed and leaves the
    /// reader empty.
    pub fn consume(&mut self, n: u8) {
        self.bit_pos += n as usize;
    }

    /// Returns `true` once every payload bit has been consumed.
    #[allow(dead_code)]
    pub fn is_empty(&self) -> bool {
        self.bit_pos >= self.total_bits
    }

    /// Number of payload bits not yet consumed (0 once the reader is exhausted).
    #[allow(dead_code)]
    pub fn remaining(&self) -> usize {
        self.total_bits.saturating_sub(self.bit_pos)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_huffman_table_from_weights() {
        let weights = [1u8, 1];
        let table = HuffmanTable::from_weights(&weights).unwrap();
        assert!(table.max_bits() >= 1);
    }

    #[test]
    fn test_huffman_table_varying_weights() {
        let weights = [4u8, 3, 2, 1, 1, 0, 0, 0];
        let table = HuffmanTable::from_weights(&weights).unwrap();
        assert!(table.max_bits() > 0);
    }

    #[test]
    fn test_direct_huffman_table() {
        // Header 127 + 3 means three 4-bit weights, packed high nibble first: [2, 1, 1].
        // Under RFC 8878 the last symbol's weight is implied. The explicit weights sum to
        // 2 + 1 + 1 = 4, which leaves exactly 4 = 2^2 for it, so the description is valid.
        let data = [127u8 + 3, 0x21, 0x10];
        let (table, consumed) = read_huffman_table(&data).unwrap();
        assert_eq!(consumed, 3);
        assert!(table.max_bits() > 0);
    }

    #[test]
    fn test_empty_weights_fails() {
        let weights: [u8; 0] = [];
        assert!(HuffmanTable::from_weights(&weights).is_err());
    }

    #[test]
    fn test_all_zero_weights_fails() {
        assert!(HuffmanTable::from_weights(&[0u8, 0, 0]).is_err());
    }

    #[test]
    fn test_empty_table_data_fails() {
        let data: [u8; 0] = [];
        assert!(read_huffman_table(&data).is_err());
    }

    #[test]
    fn test_truncated_direct_table_fails() {
        // Four weights need two payload bytes, but only one is present.
        assert!(read_huffman_table(&[127u8 + 4, 0x21]).is_err());
    }

    #[test]
    fn test_zero_length_fse_table_fails() {
        assert!(read_huffman_table(&[0u8]).is_err());
    }

    #[test]
    fn test_truncated_fse_table_fails() {
        // The header claims 5 compressed bytes, but only 2 follow.
        assert!(read_huffman_table(&[5u8, 0x01, 0x02]).is_err());
    }
}