use alloc::vec;
use alloc::vec::Vec;
use crate::error::Error;
use crate::zstd::bitreader::RevBitReader;
use crate::zstd::fse::FseState;
/// Longest Huffman code length, in bits, that this module accepts.
///
/// `table_from_lengths` rejects any code length above this value, and
/// `decode_fse_weights` passes it as the largest symbol value allowed in the
/// FSE table used for compressed weights.
pub const HUF_MAX_BITS: u8 = 11;
/// Flat lookup table used to decode Huffman-coded symbols.
pub struct HuffTable {
    /// Length in bits of the longest code; `decode` reads up to this many bits
    /// to form an index into `lookup`.
    pub max_bits: u8,
    /// `(symbol, code_length)` entries indexed by a `max_bits`-wide bit pattern.
    /// `decode` treats an entry whose `code_length` is 0 as invalid.
    pub lookup: Vec<(u8, u8)>,
}
impl HuffTable {
pub fn decode(&self, br: &mut RevBitReader<'_>) -> Result<u8, Error> {
if br.remaining() == 0 {
return Err(Error::Corrupt);
}
let max = self.max_bits as u32;
let avail = br.remaining() as u32;
let take = core::cmp::min(max, avail);
let raw = br.read(take)?;
let idx = (raw << (max - take)) as usize;
if idx >= self.lookup.len() {
return Err(Error::Corrupt);
}
let (sym, len) = self.lookup[idx];
if len == 0 || (len as u32) > take {
return Err(Error::Corrupt);
}
if take > len as u32 {
br.unread(take - len as u32);
}
Ok(sym)
}
}
/// Builds a [`HuffTable`] from per-symbol code lengths.
///
/// `lengths[sym]` is the code length of symbol `sym`; 0 means the symbol is
/// unused. Codes are assigned so that the longest codes receive the lowest code
/// values and, within one length, symbols are numbered in ascending order. Each
/// code of length `len` fills `1 << (max_bits - len)` consecutive table slots.
///
/// # Errors
///
/// Returns [`Error::Corrupt`] when a length exceeds [`HUF_MAX_BITS`], when every
/// length is 0, when the lengths do not exactly fill the code space (Kraft sum
/// different from `1 << max_bits`), or when a used symbol index does not fit in
/// a `u8`.
fn table_from_lengths(lengths: &[u8]) -> Result<HuffTable, Error> {
    // One pass validates every length, finds the longest and counts each length.
    let mut counts = [0u32; (HUF_MAX_BITS as usize) + 1];
    let mut max_bits = 0u8;
    for &len in lengths {
        if len > HUF_MAX_BITS {
            return Err(Error::Corrupt);
        }
        if len > 0 {
            counts[usize::from(len)] += 1;
        }
        max_bits = max_bits.max(len);
    }
    if max_bits == 0 {
        return Err(Error::Corrupt);
    }

    // The codes must tile the `max_bits`-wide index space exactly.
    let kraft: u64 = (1..=max_bits)
        .map(|len| u64::from(counts[usize::from(len)]) << (max_bits - len))
        .sum();
    if kraft != (1u64 << max_bits) {
        return Err(Error::Corrupt);
    }

    // First code value per length: the longest length starts at 0 and every
    // shorter length continues after the longer codes, dropped by one bit.
    let mut next_code = [0u32; (HUF_MAX_BITS as usize) + 2];
    for len in (1..max_bits).rev() {
        let longer = usize::from(len) + 1;
        next_code[usize::from(len)] = (next_code[longer] + counts[longer]) >> 1;
    }

    // Each length owns its own counter, so a single ascending pass over the
    // symbols yields the same assignment as scanning once per length.
    let mut lookup = vec![(0u8, 0u8); 1usize << max_bits];
    for (sym, &len) in lengths.iter().enumerate() {
        if len == 0 {
            continue;
        }
        if sym > usize::from(u8::MAX) {
            // A table entry stores the symbol as `u8`; refuse to wrap it.
            return Err(Error::Corrupt);
        }
        let code = next_code[usize::from(len)];
        next_code[usize::from(len)] += 1;
        let shift = max_bits - len;
        let start = (code << shift) as usize;
        let span = 1usize << shift;
        lookup[start..start + span].fill((sym as u8, len));
    }
    Ok(HuffTable { max_bits, lookup })
}
pub fn decode_huffman_tree(data: &[u8]) -> Result<(HuffTable, usize), Error> {
if data.is_empty() {
return Err(Error::Corrupt);
}
let hb = data[0];
let (weights, consumed) = if hb >= 128 {
let count = (hb as usize) - 127;
let bytes_needed = count.div_ceil(2);
if data.len() < 1 + bytes_needed {
return Err(Error::Corrupt);
}
let mut weights = Vec::with_capacity(count);
for i in 0..count {
let b = data[1 + i / 2];
let nib = if i % 2 == 0 { b >> 4 } else { b & 0x0F };
weights.push(nib);
}
(weights, 1 + bytes_needed)
} else {
let fse_payload_len = hb as usize;
if data.len() < 1 + fse_payload_len {
return Err(Error::Corrupt);
}
let fse_bytes = &data[1..1 + fse_payload_len];
let weights = decode_fse_weights(fse_bytes)?;
(weights, 1 + fse_payload_len)
};
let mut sum: u64 = 0;
for &w in &weights {
if w > 0 {
sum += 1u64 << (w - 1);
}
}
if sum == 0 {
return Err(Error::Corrupt);
}
let max_num_bits = if sum.is_power_of_two() {
sum.trailing_zeros() as u8
} else {
(64 - sum.leading_zeros()) as u8
};
let left_over = (1u64 << max_num_bits) - sum;
let last_weight = if left_over == 0 {
0
} else {
if !left_over.is_power_of_two() {
return Err(Error::Corrupt);
}
(left_over.trailing_zeros() as u8) + 1
};
let mut all_weights = weights.clone();
all_weights.push(last_weight);
let mut lengths = vec![0u8; 256];
for (sym, &w) in all_weights.iter().enumerate() {
if sym >= 256 {
return Err(Error::Corrupt);
}
if w > 0 {
if w > max_num_bits {
return Err(Error::Corrupt);
}
lengths[sym] = max_num_bits + 1 - w;
}
}
let table = table_from_lengths(&lengths)?;
Ok((table, consumed))
}
pub(crate) fn decode_huffman_tree_weights_for_test(data: &[u8]) -> Result<Vec<u8>, Error> {
if data.is_empty() {
return Err(Error::Corrupt);
}
let hb = data[0];
if hb >= 128 {
let count = (hb as usize) - 127;
let bytes_needed = count.div_ceil(2);
if data.len() < 1 + bytes_needed {
return Err(Error::Corrupt);
}
let mut weights = Vec::with_capacity(count);
for i in 0..count {
let b = data[1 + i / 2];
let nib = if i % 2 == 0 { b >> 4 } else { b & 0x0F };
weights.push(nib);
}
Ok(weights)
} else {
let fse_payload_len = hb as usize;
if data.len() < 1 + fse_payload_len {
return Err(Error::Corrupt);
}
decode_fse_weights(&data[1..1 + fse_payload_len])
}
}
fn decode_fse_weights(payload: &[u8]) -> Result<Vec<u8>, Error> {
let max_accuracy_log = 6; let max_symbol: u16 = HUF_MAX_BITS as u16;
let (table, header_bytes) =
crate::zstd::fse::decode_fse_table(payload, max_accuracy_log, max_symbol)?;
if header_bytes > payload.len() {
return Err(Error::Corrupt);
}
let bitstream = &payload[header_bytes..];
if bitstream.is_empty() {
return Err(Error::Corrupt);
}
let mut br = RevBitReader::new(bitstream)?;
let mut s1 = FseState::init(&table, &mut br)?;
let mut s2 = FseState::init(&table, &mut br)?;
let mut weights: Vec<u8> = Vec::new();
loop {
let w1 = s1.symbol(&table) as u8;
weights.push(w1);
let nb = table.entries[s1.state as usize].num_bits as usize;
if br.remaining() < nb {
let w2 = s2.symbol(&table) as u8;
weights.push(w2);
break;
}
s1.advance(&table, &mut br)?;
let w2 = s2.symbol(&table) as u8;
weights.push(w2);
let nb = table.entries[s2.state as usize].num_bits as usize;
if br.remaining() < nb {
let w1f = s1.symbol(&table) as u8;
weights.push(w1f);
break;
}
s2.advance(&table, &mut br)?;
}
Ok(weights)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Four symbols that all use two-bit codes must each own exactly one slot,
    /// in ascending symbol order.
    #[test]
    fn table_from_simple_lengths() {
        let mut code_lengths = vec![0u8; 256];
        for entry in code_lengths.iter_mut().take(4) {
            *entry = 2;
        }

        let table = table_from_lengths(&code_lengths).unwrap();

        assert_eq!(table.max_bits, 2);
        assert_eq!(table.lookup.len(), 4);
        for (slot, symbol) in (0u8..4).enumerate() {
            assert_eq!(table.lookup[slot], (symbol, 2));
        }
    }
}