use std::io::{self, Error, Seek};
use bitstream_io::{BitRead, BitReader, Endianness};
/// One slot of the decoder's lookup table: a symbol together with the length
/// of its codeword.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct TableEntry {
    /// Symbol stored in this slot.
    pub sym: u16,
    /// Codeword length in bits; zero marks a slot that holds no codeword.
    pub len: u8,
}
/// Capacity of the decoder's symbol array (`HuffmanDecoder::syms`).
const MAX_HUFFMAN_SYMBOLS: usize = 288;
/// Longest codeword length, in bits, covered by the per-length arrays
/// (`sentinel_bits` and `offset_first_sym_idx` have one slot per length
/// `0..=MAX_HUFFMAN_BITS`).
const MAX_HUFFMAN_BITS: usize = 16;
/// Number of input bits used to index the lookup table, which therefore holds
/// `1 << HUFFMAN_LOOKUP_TABLE_BITS` entries.
const HUFFMAN_LOOKUP_TABLE_BITS: u8 = 8;
/// Decoding tables for one Huffman code; they are filled in by
/// `HuffmanDecoder::init` and read by `HuffmanDecoder::huffman_decode`.
pub struct HuffmanDecoder {
/// Lookup table indexed by `HUFFMAN_LOOKUP_TABLE_BITS` input bits. An entry
/// whose `len` is zero means no codeword of at most that many bits matches,
/// and decoding falls back to the per-length arrays below.
pub table: [TableEntry; 1 << HUFFMAN_LOOKUP_TABLE_BITS],
/// `sentinel_bits[l]` is one past the last codeword of length `l`,
/// left-aligned to `MAX_HUFFMAN_BITS` bits.
pub sentinel_bits: [u32; MAX_HUFFMAN_BITS + 1],
/// `offset_first_sym_idx[l]` is the index in `syms` of the first symbol with
/// an `l`-bit codeword minus that first codeword (wrapping), so adding an
/// `l`-bit codeword to it yields the codeword's index in `syms`.
pub offset_first_sym_idx: [u16; MAX_HUFFMAN_BITS + 1],
/// Symbols grouped by increasing codeword length and, within one length, by
/// increasing symbol value.
pub syms: [u16; MAX_HUFFMAN_SYMBOLS],
}
impl Default for HuffmanDecoder {
fn default() -> Self {
Self {
table: [TableEntry::default(); 1 << HUFFMAN_LOOKUP_TABLE_BITS],
sentinel_bits: [0; MAX_HUFFMAN_BITS + 1],
offset_first_sym_idx: [0; MAX_HUFFMAN_BITS + 1],
syms: [0; MAX_HUFFMAN_SYMBOLS],
}
}
}
/// Returns the `n` least-significant bits of `x` in reversed order, placed in
/// the low `n` bits of the result.
///
/// `n` is expected to lie in `1..=16`; this is only checked in debug builds.
pub fn reverse_lsb(x: u16, n: usize) -> u16 {
    debug_assert!(n > 0);
    debug_assert!(n <= 16);
    // Mirror the whole 16-bit word, then drop the bits that came from the
    // positions above `n`.
    let mirrored = x.reverse_bits();
    let discarded = 16 - n;
    mirrored >> discarded
}
impl HuffmanDecoder {
pub fn init(&mut self, lengths: &[u8], n: usize) -> std::io::Result<()> {
let mut count = [0; MAX_HUFFMAN_BITS + 1];
let mut code = [0; MAX_HUFFMAN_BITS + 1];
let mut sym_idx = [0u16; MAX_HUFFMAN_BITS + 1];
for i in 0..n {
debug_assert!(lengths[i] as usize <= MAX_HUFFMAN_BITS);
count[lengths[i] as usize] += 1;
}
count[0] = 0; code[0] = 0;
sym_idx[0] = 0;
for l in 1..=MAX_HUFFMAN_BITS {
code[l] = (code[l - 1] + count[l - 1]) << 1;
if count[l] != 0 && u32::from(code[l]) + u32::from(count[l]) - 1 > (1u32 << l) - 1 {
return Err(Error::new(
io::ErrorKind::InvalidData,
"the last codeword is longer than len bits",
));
}
let s = (u32::from(code[l]) + u32::from(count[l])) << (MAX_HUFFMAN_BITS - l);
self.sentinel_bits[l] = s;
debug_assert!(self.sentinel_bits[l] >= u32::from(code[l]), "No overflow!");
sym_idx[l] = sym_idx[l - 1] + count[l - 1];
self.offset_first_sym_idx[l] = sym_idx[l].wrapping_sub(code[l]);
}
self.table.fill(TableEntry::default());
lengths
.iter()
.enumerate()
.take(n)
.for_each(|(i, code_len)| {
let l = *code_len as usize;
if l == 0 {
return;
}
self.syms[sym_idx[l] as usize] = i as u16;
sym_idx[l] += 1;
if l <= HUFFMAN_LOOKUP_TABLE_BITS as usize {
self.table_insert(i, l, code[l]);
code[l] += 1;
}
});
Ok(())
}
pub fn table_insert(&mut self, sym: usize, len: usize, codeword: u16) {
debug_assert!(len <= HUFFMAN_LOOKUP_TABLE_BITS as usize);
let codeword = reverse_lsb(codeword, len); let pad_len = HUFFMAN_LOOKUP_TABLE_BITS as usize - len;
for padding in 0..(1 << pad_len) {
let index = (codeword | (padding << len)) as usize;
debug_assert!(sym <= u16::MAX as usize);
self.table[index].sym = sym as u16;
debug_assert!(len <= u8::MAX as usize);
self.table[index].len = len as u8;
}
}
pub fn huffman_decode<T: std::io::Read + Seek, E: Endianness>(
&mut self,
length: u64,
is: &mut BitReader<T, E>,
) -> std::io::Result<u16> {
let read_bits1 = u64::from(HUFFMAN_LOOKUP_TABLE_BITS).min(length - is.position_in_bits()?);
let lookup_bits = !is.read_var::<u8>(read_bits1 as u32)? as usize;
debug_assert!(lookup_bits < self.table.len());
if self.table[lookup_bits].len != 0 {
debug_assert!(self.table[lookup_bits].len <= HUFFMAN_LOOKUP_TABLE_BITS);
is.seek_bits(io::SeekFrom::Current(
-(read_bits1 as i64) + i64::from(self.table[lookup_bits].len),
))?;
return Ok(self.table[lookup_bits].sym);
}
let read_bits2 = u64::from(HUFFMAN_LOOKUP_TABLE_BITS).min(length - is.position_in_bits()?);
let mut bits = reverse_lsb(
(lookup_bits | ((!is.read_var::<u8>(read_bits2 as u32)? as usize) << read_bits1))
as u16,
MAX_HUFFMAN_BITS,
);
for l in (HUFFMAN_LOOKUP_TABLE_BITS as usize + 1)..=MAX_HUFFMAN_BITS {
if u32::from(bits) < self.sentinel_bits[l] {
bits >>= MAX_HUFFMAN_BITS - l;
let sym_idx = (self.offset_first_sym_idx[l] as usize + bits as usize) & 0xFFFF;
is.seek_bits(io::SeekFrom::Current(
-(read_bits1 as i64 + read_bits2 as i64) + l as i64,
))?;
return Ok(self.syms[sym_idx]);
}
}
Err(Error::new(
io::ErrorKind::InvalidData,
"huffman decode failed",
))
}
}
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use bitstream_io::{BitReader, LittleEndian};

    use super::HuffmanDecoder;

    /// Decodes one symbol from a stream made of the single byte `byte`,
    /// declared to be 8 bits long.
    fn decode_byte(decoder: &mut HuffmanDecoder, byte: u8) -> std::io::Result<u16> {
        let data = [byte];
        let mut cursor = Cursor::new(&data);
        let mut reader = BitReader::endian(&mut cursor, LittleEndian);
        decoder.huffman_decode(8, &mut reader)
    }

    #[test]
    fn test_huffman_decode_basic() {
        let lens = [3, 3, 3, 3, 3, 3, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4];
        let mut d = HuffmanDecoder::default();
        d.init(&lens, lens.len()).unwrap();

        // (bit pattern before complementing, expected symbol)
        let cases: [(u8, u16); 4] = [
            (0x0, 0),
            (0b110, 0b011),
            (0b1111, 0b10001),
            (0b11111, 0b10000),
        ];
        for &(pattern, symbol) in &cases {
            assert_eq!(
                decode_byte(&mut d, !pattern).unwrap(),
                symbol,
                "pattern {:#b}",
                pattern
            );
        }

        // This pattern matches no codeword of the code above.
        assert!(decode_byte(&mut d, !0x7f).is_err());
    }
}