/// Table-driven decoder for a canonical prefix code over byte symbols.
///
/// Built from per-symbol code lengths with [`Decoder::from_weights`] and then
/// used to decode an LSB-first bitstream with [`Decoder::decode`].
#[derive(Debug, Clone)]
pub struct Decoder {
    /// Lookup table with `1 << max_bits` entries of `(symbol, code length)`.
    /// A code length of `0` marks an entry that no code maps to.
    table: Vec<(u16, u8)>,
    /// Length in bits of the longest code; the table is indexed by this many bits.
    max_bits: u8,
}

impl Decoder {
    /// Builds a decoder from per-symbol code lengths.
    ///
    /// `weights[s]` is the code length in bits of symbol `s`; `0` means the
    /// symbol is unused. Codes are assigned canonically: shorter codes first,
    /// and within one length in increasing symbol order.
    ///
    /// Returns `None` when:
    /// * no symbol is used, or some length exceeds 11 bits;
    /// * the lengths do not form a *complete* prefix code, i.e. the Kraft sum
    ///   `sum(2^-len)` is not exactly 1. The sum is computed with integers so
    ///   that an over-subscribed set of lengths can never slip through a
    ///   tolerance and index past the end of the lookup table.
    pub fn from_weights(weights: &[u8; 256]) -> Option<Self> {
        const MAX_CODE_BITS: u8 = 11;
        const LENGTH_SLOTS: usize = MAX_CODE_BITS as usize + 1;

        let max_bits = weights.iter().copied().max()?;
        if max_bits == 0 || max_bits > MAX_CODE_BITS {
            return None;
        }
        let longest = usize::from(max_bits);

        // count[len] = number of symbols whose code is `len` bits long.
        // Unused symbols (length 0) are skipped, so count[0] stays 0.
        let mut count = [0u32; LENGTH_SLOTS];
        for &weight in weights.iter().filter(|&&w| w != 0) {
            count[usize::from(weight)] += 1;
        }

        // Exact Kraft check, scaled by 2^longest: a code of length `len`
        // occupies 2^(longest - len) table slots and together the codes must
        // fill the table exactly. At most 256 symbols * 2^10 slots fits in u32.
        let used_slots: u32 = (1..=longest)
            .map(|len| count[len] << (longest - len))
            .sum();
        if used_slots != 1u32 << longest {
            return None;
        }

        // Canonical code assignment: first code value for every length.
        let mut next_code = [0u32; LENGTH_SLOTS];
        let mut code = 0u32;
        for len in 1..=longest {
            code = (code + count[len - 1]) << 1;
            next_code[len] = code;
        }

        // NOTE(review): a code shorter than `max_bits` is placed in the HIGH
        // bits of the table index, while `decode` indexes with the low
        // `max_bits` bits of its accumulator and then discards the low bits.
        // For codes of mixed length the bits that select a symbol are
        // therefore not the bits that get consumed -- confirm against the
        // encoder before relying on mixed-length codes.
        let mut table = vec![(0u16, 0u8); 1usize << longest];
        for (symbol, &weight) in (0u16..).zip(weights.iter()) {
            if weight == 0 {
                continue;
            }
            let len = usize::from(weight);
            let pad = longest - len;
            let start = usize::try_from(next_code[len]).ok()? << pad;
            let end = start + (1usize << pad);
            // Checked range: never panics, even if the invariant above breaks.
            table.get_mut(start..end)?.fill((symbol, weight));
            next_code[len] += 1;
        }

        Some(Self { table, max_bits })
    }

    /// Decodes exactly `num_symbols` symbols from `data`.
    ///
    /// Bytes are loaded into a bit accumulator least-significant bit first.
    /// For every symbol the low `max_bits` bits of the accumulator index the
    /// lookup table, and the entry's code length is then shifted out.
    ///
    /// Returns `None` when:
    /// * `num_symbols` exceeds the number of bits in `data` (every symbol
    ///   consumes at least one bit, so such a request can never succeed; this
    ///   also keeps the output allocation bounded by the input size);
    /// * fewer than `max_bits` bits remain before all symbols are decoded;
    /// * the looked-up table entry is unassigned.
    pub fn decode(&self, data: &[u8], num_symbols: usize) -> Option<Vec<u8>> {
        let total_bits = data.len().saturating_mul(8);
        if num_symbols > total_bits {
            return None;
        }

        let max = u32::from(self.max_bits);
        let table_mask = (1u64 << max) - 1;

        let mut result = Vec::with_capacity(num_symbols);
        let mut bytes = data.iter();
        let mut acc: u64 = 0;
        let mut acc_bits: u32 = 0;

        while result.len() < num_symbols {
            // Refill until a full table index is available; running out of
            // input before that is an error.
            while acc_bits < max {
                let &byte = bytes.next()?;
                acc |= u64::from(byte) << acc_bits;
                acc_bits += 8;
            }

            let index = usize::try_from(acc & table_mask).ok()?;
            let &(symbol, bits_used) = self.table.get(index)?;
            if bits_used == 0 {
                return None;
            }

            result.push(u8::try_from(symbol).ok()?);

            // `bits_used <= max_bits <= acc_bits` for tables built by
            // `from_weights`, so this never reads past the input.
            let consumed = u32::from(bits_used);
            acc >>= consumed;
            acc_bits = acc_bits.checked_sub(consumed)?;
        }

        Some(result)
    }
}
/// Test helper: extracts `num_bits` bits (1 to 16) starting at bit offset
/// `bit_pos`, counting bits LSB-first within each byte.
///
/// Returns `None` when `num_bits` is outside `1..=16` or when the byte that
/// contains `bit_pos` lies past the end of `data`. Bits requested beyond the
/// end of `data` read as zero.
#[cfg(test)]
fn read_bits(data: &[u8], bit_pos: usize, num_bits: u8) -> Option<u16> {
    if !(1..=16).contains(&num_bits) {
        return None;
    }

    let start = bit_pos / 8;
    let offset = bit_pos % 8;
    let first = *data.get(start)?;

    // Little-endian window of up to three bytes: 16 bits plus a bit offset
    // of at most 7 never needs more than 23 bits.
    let window = data[start + 1..]
        .iter()
        .take(2)
        .enumerate()
        .fold(u32::from(first), |acc, (k, &byte)| {
            acc | (u32::from(byte) << (8 * (k + 1)))
        });

    let field = (window >> offset) & ((1u32 << num_bits) - 1);
    Some(u16::try_from(field).unwrap_or(0))
}
/// Parses a weight-table header from the start of `data`.
///
/// Layout read by this function:
/// * byte 0: low 7 bits hold `number of weights - 1` (so 1 to 128 weights);
///   the high bit selects the packed encoding;
/// * packed encoding: two weights per byte, the low nibble first;
/// * plain encoding: one weight per byte.
///
/// Every weight is clamped to at most 11. Weights that are not present in the
/// header stay 0.
///
/// Returns the 256-entry weight table together with the number of bytes
/// consumed (header byte included), or `None` when `data` is empty or too
/// short for the announced number of weights.
pub fn parse_tree(data: &[u8]) -> Option<([u8; 256], usize)> {
    const MAX_WEIGHT: u8 = 11;

    let (&header, payload) = data.split_first()?;
    let weight_count = usize::from(header & 0x7F) + 1;
    let packed = header & 0x80 != 0;

    let payload_len = if packed {
        weight_count.div_ceil(2)
    } else {
        weight_count
    };
    let encoded = payload.get(..payload_len)?;

    let mut weights = [0u8; 256];
    if packed {
        // Each encoded byte fills a pair of slots; for an odd count the last
        // pair has one slot and the high nibble is ignored.
        for (pair, &byte) in weights[..weight_count].chunks_mut(2).zip(encoded) {
            pair[0] = (byte & 0x0F).min(MAX_WEIGHT);
            if let Some(high) = pair.get_mut(1) {
                *high = (byte >> 4).min(MAX_WEIGHT);
            }
        }
    } else {
        for (slot, &byte) in weights.iter_mut().zip(encoded) {
            *slot = byte.min(MAX_WEIGHT);
        }
    }

    Some((weights, 1 + payload_len))
}
/// Decodes `num_literals` bytes from a buffer that holds a weight-table
/// header followed by the compressed bitstream.
///
/// Returns `None` when:
/// * `num_literals` or `data.len()` exceeds 128 KiB;
/// * the header cannot be parsed, or no bytes follow it;
/// * `num_literals` is larger than the number of bits after the header;
/// * the weights do not yield a decoder, or decoding the bitstream fails.
pub fn decode_literals(data: &[u8], num_literals: usize) -> Option<Vec<u8>> {
    const SIZE_LIMIT: usize = 128 * 1024;

    if num_literals > SIZE_LIMIT || data.len() > SIZE_LIMIT {
        return None;
    }

    let (weights, tree_size) = parse_tree(data)?;

    // The bitstream is whatever follows the header and must not be empty.
    let bitstream = data.get(tree_size..).filter(|rest| !rest.is_empty())?;
    if num_literals > bitstream.len() * 8 {
        return None;
    }

    Decoder::from_weights(&weights)?.decode(bitstream, num_literals)
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// `read_bits` extracts LSB-first bit fields, including across a byte boundary.
    #[test]
    fn test_read_bits() {
        let data = [0xB3u8, 0x55];
        let cases: [(usize, u8, u16); 4] = [(0, 4, 0x3), (4, 4, 0xB), (4, 8, 0x5B), (6, 4, 6)];
        for (bit_pos, num_bits, expected) in cases {
            assert_eq!(read_bits(&data, bit_pos, num_bits), Some(expected));
        }
    }

    /// Four symbols with 2-bit codes: each 2-bit group of the byte, read from
    /// the low bits upwards, is the symbol itself.
    #[test]
    fn test_huffman_basic() {
        let mut weights = [0u8; 256];
        weights[..4].fill(2);

        let decoder = Decoder::from_weights(&weights).unwrap();
        assert_eq!(decoder.max_bits, 2);

        let decoded_symbols = decoder.decode(&[0b0001_1011], 4).unwrap();
        assert_eq!(decoded_symbols, vec![3, 2, 1, 0]);
    }

    /// Packed header: two weights stored in one byte, low nibble first.
    #[test]
    fn test_parse_tree_4bit() {
        let (weights, consumed) = parse_tree(&[0x81, 0x12]).unwrap();
        assert_eq!(consumed, 2);
        assert_eq!(weights[0], 2);
        assert_eq!(weights[1], 1);
    }

    /// Plain header: a single weight stored in its own byte.
    #[test]
    fn test_parse_tree_8bit() {
        let (weights, consumed) = parse_tree(&[0x00, 0x05]).unwrap();
        assert_eq!(consumed, 2);
        assert_eq!(weights[0], 5);
    }

    /// Header with two 1-bit codes followed by one payload byte whose three
    /// low bits are set.
    #[test]
    fn test_decode_literals_flow() {
        let tree_and_data = [0x81u8, 0x11, 0x07];
        let literals = decode_literals(&tree_and_data, 3).unwrap();
        assert_eq!(literals, vec![1, 1, 1]);
    }

    /// Asking for more literals than the payload has bits must fail.
    #[test]
    fn test_decode_literals_bounds_check() {
        let tree_and_data = [0x81u8, 0x11, 0x07];
        let literals = decode_literals(&tree_and_data, 9);
        assert!(literals.is_none());
    }
}