use crate::error::Error;
use super::bitreader::BitReader;
/// Longest code length, in bits, that [`Rar2Huffman::from_lengths`] accepts.
///
/// The per-length tables inside [`Rar2Huffman`] hold `MAX_BITS + 1` entries so that they can be
/// indexed directly by code length.
pub const MAX_BITS: u32 = 15;
/// Canonical Huffman decoding table for an alphabet of at most `N` symbols.
///
/// Built by `from_lengths` from one code length per symbol and consumed by `decode`. Every
/// per-length array is indexed by code length in bits; slot 0 is never used for a real code.
#[derive(Debug, Clone)]
pub struct Rar2Huffman<const N: usize> {
// Number of symbols whose code is exactly `l` bits long, at index `l`.
counts: [u16; (MAX_BITS as usize) + 1],
// Symbol values ordered by code length, then by symbol index within one length.
symbols: [u16; N],
// Numeric value of the first (smallest) code of each length.
first_code: [u32; (MAX_BITS as usize) + 1],
// Position in `symbols` of the first symbol of each length.
first_idx: [u16; (MAX_BITS as usize) + 1],
// Longest code length in use; 0 means no symbol has a code (empty table).
max_length: u8,
// Number of code lengths the table was built from (the input slice length).
n_symbols: u16,
}
impl<const N: usize> Rar2Huffman<N> {
pub fn from_lengths(code_lengths: &[u8]) -> Result<Self, Error> {
let n = code_lengths.len();
assert!(n <= N);
let mut counts = [0u16; (MAX_BITS as usize) + 1];
let mut max_length: u8 = 0;
for &len in code_lengths {
if len as u32 > MAX_BITS {
return Err(Error::InvalidHuffmanTree);
}
if len > 0 {
counts[len as usize] += 1;
if len > max_length {
max_length = len;
}
}
}
if max_length == 0 {
return Ok(Self {
counts,
symbols: [0u16; N],
first_code: [0u32; (MAX_BITS as usize) + 1],
first_idx: [0u16; (MAX_BITS as usize) + 1],
max_length: 0,
n_symbols: n as u16,
});
}
let mut kraft: u64 = 0;
for l in 1..=MAX_BITS {
kraft += (counts[l as usize] as u64) << (MAX_BITS - l);
}
if kraft > (1u64 << MAX_BITS) {
return Err(Error::InvalidHuffmanTree);
}
let mut first_code = [0u32; (MAX_BITS as usize) + 1];
let mut first_idx = [0u16; (MAX_BITS as usize) + 1];
let mut code: u32 = 0;
let mut idx: u16 = 0;
for l in 1..=(MAX_BITS as usize) {
code <<= 1;
first_code[l] = code;
first_idx[l] = idx;
code += counts[l] as u32;
idx += counts[l];
}
let mut symbols = [0u16; N];
let mut next = first_idx;
for (sym, &len) in code_lengths.iter().enumerate() {
if len > 0 {
symbols[next[len as usize] as usize] = sym as u16;
next[len as usize] += 1;
}
}
Ok(Self {
counts,
symbols,
first_code,
first_idx,
max_length,
n_symbols: n as u16,
})
}
#[allow(dead_code)]
pub const fn is_empty(&self) -> bool {
self.max_length == 0
}
pub fn decode(&self, reader: &mut BitReader, input: &[u8]) -> Result<u16, Error> {
if self.max_length == 0 {
return Err(Error::InvalidHuffmanTree);
}
let max = self.max_length as u32;
let (peeked, have) = reader.peek_up_to(max, input);
for length in 1..=max {
let count = self.counts[length as usize] as u32;
if count == 0 {
continue;
}
let code = peeked >> (max - length);
let first = self.first_code[length as usize];
if code >= first && code < first + count {
if length > have {
return Err(Error::UnexpectedEnd);
}
let sym_idx = self.first_idx[length as usize] as u32 + (code - first);
if sym_idx >= self.n_symbols as u32 {
return Err(Error::InvalidHuffmanTree);
}
reader.drop_bits(length);
return Ok(self.symbols[sym_idx as usize]);
}
}
if have < max {
return Err(Error::UnexpectedEnd);
}
Err(Error::InvalidHuffmanTree)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Lengths `[2, 1, 3, 3]` yield the codes `10`, `0`, `110`, `111` for symbols 0 to 3; the
    /// two input bytes spell `0 10 110 111`, which this test expects to decode as 1, 0, 2, 3.
    #[test]
    fn small_tree_roundtrip() {
        let tree = Rar2Huffman::<4>::from_lengths(&[2, 1, 3, 3]).unwrap();
        let bytes = [0x5B, 0xC0];
        let mut reader = BitReader::new();
        for &expected in &[1u16, 0, 2, 3] {
            assert_eq!(tree.decode(&mut reader, &bytes).unwrap(), expected);
        }
    }

    /// A table built from all-zero lengths is empty and refuses to decode anything.
    #[test]
    fn empty_tree_rejects() {
        let tree = Rar2Huffman::<4>::from_lengths(&[0, 0, 0, 0]).unwrap();
        assert!(tree.is_empty());

        let bytes = [0xFF];
        let mut reader = BitReader::new();
        let outcome = tree.decode(&mut reader, &bytes);
        assert!(matches!(outcome, Err(Error::InvalidHuffmanTree)));
    }

    /// A code length above `MAX_BITS` must be refused at construction time.
    #[test]
    fn over_long_code_rejected() {
        let lens = [16u8, 0];
        assert!(Rar2Huffman::<2>::from_lengths(&lens).is_err());
    }

    /// Two 1-bit codes already fill the code space, so a third code over-subscribes it.
    #[test]
    fn over_full_kraft_rejected() {
        let built = Rar2Huffman::<3>::from_lengths(&[1, 1, 2]);
        assert!(built.is_err());
    }

    /// Decoding from an empty input reports a truncated stream rather than a bad table.
    #[test]
    fn unexpected_end_when_bits_short() {
        let tree = Rar2Huffman::<5>::from_lengths(&[3, 3, 2, 2, 2]).unwrap();
        let bytes: [u8; 0] = [];
        let mut reader = BitReader::new();
        let outcome = tree.decode(&mut reader, &bytes);
        assert!(matches!(outcome, Err(Error::UnexpectedEnd)));
    }

    /// A lone symbol with a 1-bit code is decoded from a `0` bit.
    #[test]
    fn single_symbol_tree() {
        let tree = Rar2Huffman::<1>::from_lengths(&[1]).unwrap();
        let bytes = [0x00];
        let mut reader = BitReader::new();
        assert_eq!(tree.decode(&mut reader, &bytes).unwrap(), 0);
    }
}