use crate::bitstream::BitStream;
use crate::error::{Result, SmkError};
/// Flag bit marking a `Huff8` tree entry as a branch node rather than a leaf.
const HUFF8_BRANCH: u16 = 0x8000;
/// Mask selecting the payload bits of a `Huff8` entry. For a branch entry the
/// payload is the index of the node taken when the next stream bit is 1.
const HUFF8_LEAF_MASK: u16 = 0x7FFF;
/// Cap on the number of entries a `Huff8` tree may hold while being built
/// (511 = 2 * 256 - 1, the node count of a full binary tree with 256 leaves).
const HUFF8_MAX_SIZE: usize = 511;
/// Huffman tree whose leaves are byte values, stored as a flat node array.
pub(crate) struct Huff8 {
// Nodes in the order they were read from the stream. A branch entry has
// `HUFF8_BRANCH` set and stores, in its low bits, the index of the child
// taken on a 1 bit; the child taken on a 0 bit is always the next entry.
// A leaf entry holds the decoded byte in its low 8 bits.
tree: Vec<u16>,
}
impl Huff8 {
    /// Reads an 8-bit Huffman tree from `bs`.
    ///
    /// The stream starts with a presence bit. When it is set the node stream
    /// follows; otherwise the tree is a single leaf with value 0. Either way
    /// one terminator bit is read afterwards and must be 0.
    ///
    /// # Errors
    ///
    /// Propagates bitstream read errors, and returns
    /// `SmkError::TreeBuildFailed` if the tree grows past `HUFF8_MAX_SIZE`
    /// entries or the terminator bit is 1.
    pub fn build(bs: &mut BitStream) -> Result<Self> {
        let mut huff = Huff8 { tree: Vec::new() };

        let tree_present = bs.read_bit()?;
        if tree_present {
            huff.build_rec(bs)?;
        } else {
            huff.tree.push(0);
        }

        let terminator = bs.read_bit()?;
        if terminator {
            Err(SmkError::TreeBuildFailed("expected trailing 0 bit"))
        } else {
            Ok(huff)
        }
    }

    /// Reads one subtree from `bs` and appends its nodes to `self.tree`.
    ///
    /// A 1 bit introduces a branch (left subtree, then right subtree); a 0 bit
    /// introduces a leaf whose value is the following byte.
    fn build_rec(&mut self, bs: &mut BitStream) -> Result<()> {
        if self.tree.len() >= HUFF8_MAX_SIZE {
            return Err(SmkError::TreeBuildFailed("huff8 tree size exceeded"));
        }

        let is_branch = bs.read_bit()?;
        if !is_branch {
            let leaf = bs.read_byte()?;
            self.tree.push(u16::from(leaf));
            return Ok(());
        }

        // Reserve the branch entry now; it is patched once the left subtree
        // is complete and the position of the right subtree is known.
        let branch_at = self.tree.len();
        self.tree.push(0);
        self.build_rec(bs)?;

        let right_child = self.tree.len() as u16;
        self.tree[branch_at] = HUFF8_BRANCH | right_child;
        self.build_rec(bs)
    }

    /// Decodes one byte by walking the tree, consuming one bit of `bs` per
    /// branch node visited.
    ///
    /// # Errors
    ///
    /// Propagates bitstream read errors.
    pub fn lookup(&self, bs: &mut BitStream) -> Result<u8> {
        let mut pos = 0usize;
        loop {
            let node = self.tree[pos];
            if node & HUFF8_BRANCH == 0 {
                // Leaf: the byte lives in the low 8 bits.
                return Ok(node as u8);
            }
            pos = if bs.read_bit()? {
                usize::from(node & HUFF8_LEAF_MASK)
            } else {
                pos + 1
            };
        }
    }
}
/// Flag bit marking a `Huff16` tree entry as a branch node rather than a leaf.
const HUFF16_BRANCH: u32 = 0x8000_0000;
/// Flag bit marking a `Huff16` leaf whose value is read from the tree's cache
/// at lookup time instead of being stored in the entry.
const HUFF16_CACHE: u32 = 0x4000_0000;
/// Mask selecting the payload bits of a `Huff16` entry: the 1-bit child index
/// for a branch entry, or the cache slot index for a cache leaf.
const HUFF16_LEAF_MASK: u32 = 0x3FFF_FFFF;
/// Huffman tree whose leaves are 16-bit values, with a three-entry cache of
/// recently decoded values that leaves may refer to.
pub(crate) struct Huff16 {
// Nodes in the order they were read. Branch entries have `HUFF16_BRANCH` set
// and store the index of the child taken on a 1 bit (the 0-bit child is the
// next entry). Leaves hold either a literal 16-bit value or `HUFF16_CACHE`
// plus a cache slot index.
tree: Vec<u32>,
// Recently decoded values, most recent first; seeded by `build` and updated
// by every `lookup`.
cache: [u16; 3],
}
impl Default for Huff16 {
fn default() -> Self {
Huff16 {
tree: vec![0],
cache: [0; 3],
}
}
}
impl Huff16 {
pub fn build(bs: &mut BitStream, alloc_size: u32) -> Result<Self> {
let h;
if bs.read_bit()? {
let low8 = Huff8::build(bs)?;
let hi8 = Huff8::build(bs)?;
let mut cache = [0u16; 3];
for entry in &mut cache {
let lo = bs.read_byte()?;
let hi = bs.read_byte()?;
*entry = u16::from(lo) | (u16::from(hi) << 8);
}
if alloc_size < 12 || alloc_size % 4 != 0 {
return Err(SmkError::TreeBuildFailed("illegal alloc_size for huff16"));
}
let limit = ((alloc_size - 12) / 4) as usize;
h = Huff16 {
tree: Vec::with_capacity(limit),
cache,
};
let mut h = h;
h.build_rec(bs, &low8, &hi8, limit)?;
if h.tree.len() != limit {
return Err(SmkError::TreeBuildFailed(
"huff16 tree size does not match expected",
));
}
if bs.read_bit()? {
return Err(SmkError::TreeBuildFailed("expected trailing 0 bit"));
}
Ok(h)
} else {
h = Huff16 {
tree: vec![0],
cache: [0; 3],
};
if bs.read_bit()? {
return Err(SmkError::TreeBuildFailed("expected trailing 0 bit"));
}
Ok(h)
}
}
fn build_rec(
&mut self,
bs: &mut BitStream,
low8: &Huff8,
hi8: &Huff8,
limit: usize,
) -> Result<()> {
if self.tree.len() >= limit {
return Err(SmkError::TreeBuildFailed("huff16 tree size exceeded"));
}
if bs.read_bit()? {
let slot = self.tree.len();
self.tree.push(0);
self.build_rec(bs, low8, hi8, limit)?;
self.tree[slot] = HUFF16_BRANCH | self.tree.len() as u32;
self.build_rec(bs, low8, hi8, limit)?;
} else {
let lo = low8.lookup(bs)?;
let hi = hi8.lookup(bs)?;
let value = u16::from(lo) | (u16::from(hi) << 8);
let entry = if value == self.cache[0] {
HUFF16_CACHE
} else if value == self.cache[1] {
HUFF16_CACHE | 1
} else if value == self.cache[2] {
HUFF16_CACHE | 2
} else {
u32::from(value)
};
self.tree.push(entry);
}
Ok(())
}
pub fn reset_cache(&mut self) {
self.cache = [0; 3];
}
pub fn lookup(&mut self, bs: &mut BitStream) -> Result<u16> {
let mut index = 0usize;
while self.tree[index] & HUFF16_BRANCH != 0 {
if bs.read_bit()? {
index = (self.tree[index] & HUFF16_LEAF_MASK) as usize;
} else {
index += 1;
}
}
let raw = self.tree[index];
let value = if raw & HUFF16_CACHE != 0 {
let idx = (raw & HUFF16_LEAF_MASK) as usize;
if idx >= self.cache.len() {
return Err(SmkError::InvalidData("huff16 cache index out of range"));
}
self.cache[idx]
} else {
raw as u16
};
if self.cache[0] != value {
self.cache[2] = self.cache[1];
self.cache[1] = self.cache[0];
self.cache[0] = value;
}
Ok(value)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a `Huff8` tree from raw bytes.
    fn build_huff8(data: &[u8]) -> Result<Huff8> {
        let mut stream = BitStream::new(data);
        Huff8::build(&mut stream)
    }

    /// Splits `byte` into its eight bits, least significant first.
    fn lsb_bits(byte: u8) -> [u8; 8] {
        let mut out = [0u8; 8];
        for (shift, slot) in out.iter_mut().enumerate() {
            *slot = (byte >> shift) & 1;
        }
        out
    }

    /// Packs a sequence of 0/1 values into bytes, least significant bit first.
    fn bits_to_bytes(bits: &[u8]) -> Vec<u8> {
        bits.chunks(8)
            .map(|chunk| {
                chunk
                    .iter()
                    .enumerate()
                    .fold(0u8, |acc, (i, &bit)| acc | (bit << i))
            })
            .collect()
    }

    /// Encodes a `Huff16` stream with two empty `Huff8` trees, the given cache
    /// seeds, and a single leaf node.
    fn one_leaf_huff16_stream(cache: [u16; 3]) -> Vec<u8> {
        // Presence bit, then two empty Huff8 trees (presence 0 + terminator 0).
        let mut bits: Vec<u8> = vec![1, 0, 0, 0, 0];
        for word in cache {
            for byte in word.to_le_bytes() {
                bits.extend_from_slice(&lsb_bits(byte));
            }
        }
        // Leaf marker for the only node, then the terminator bit.
        bits.extend_from_slice(&[0, 0]);
        bits_to_bytes(&bits)
    }

    /// Runs one `Huff8` lookup against a fresh one-byte stream.
    fn lookup8(h: &Huff8, code: u8) -> u8 {
        let buf = [code];
        let mut bs = BitStream::new(&buf);
        h.lookup(&mut bs).unwrap()
    }

    /// Runs one `Huff16` lookup against a fresh one-byte stream.
    fn lookup16(h: &mut Huff16, code: u8) -> u16 {
        let buf = [code];
        let mut bs = BitStream::new(&buf);
        h.lookup(&mut bs).unwrap()
    }

    #[test]
    fn empty_tree() {
        let h = build_huff8(&[0x00]).unwrap();
        assert_eq!(h.tree, vec![0u16]);
    }

    #[test]
    fn single_leaf_tree() {
        let h = build_huff8(&[0x09, 0x01]).unwrap();
        assert_eq!(h.tree.len(), 1);
        assert_eq!(lookup8(&h, 0x00), 0x42);
    }

    #[test]
    fn two_leaf_tree() {
        let h = build_huff8(&[0x53, 0xB5, 0x0B]).unwrap();
        assert_eq!(h.tree.len(), 3);
        assert_eq!(lookup8(&h, 0x00), 0xAA);
        assert_eq!(lookup8(&h, 0x01), 0xBB);
    }

    #[test]
    fn trailing_one_is_error() {
        assert!(build_huff8(&[0x02]).is_err());
    }

    #[test]
    fn huff16_empty_tree() {
        let mut bs = BitStream::new(&[0x00]);
        let mut h = Huff16::build(&mut bs, 16).unwrap();
        assert_eq!(h.tree, vec![0u32]);
        assert_eq!(lookup16(&mut h, 0x00), 0);
    }

    #[test]
    fn huff16_single_leaf() {
        let bytes = one_leaf_huff16_stream([0x0001, 0x0002, 0x0003]);
        let mut bs = BitStream::new(&bytes);
        let mut h = Huff16::build(&mut bs, 16).unwrap();
        assert_eq!(h.tree.len(), 1);
        assert_eq!(lookup16(&mut h, 0x00), 0x0000);
    }

    #[test]
    fn huff16_cache_substitution() {
        let bytes = one_leaf_huff16_stream([0x0000, 0x0001, 0x0002]);
        let mut bs = BitStream::new(&bytes);
        let mut h = Huff16::build(&mut bs, 16).unwrap();
        // The leaf value 0 equals cache seed 0, so it is stored as a cache ref.
        assert_eq!(h.tree[0], HUFF16_CACHE);
        assert_eq!(lookup16(&mut h, 0x00), 0x0000);
        // Changing the slot changes what the same leaf decodes to.
        h.cache[0] = 0xBEEF;
        assert_eq!(lookup16(&mut h, 0x00), 0xBEEF);
    }

    #[test]
    fn huff16_mru_cache_update() {
        let mut bs = BitStream::new(&[0x00]);
        let mut h = Huff16::build(&mut bs, 16).unwrap();
        h.cache = [0x0A, 0x0B, 0x0C];
        // First lookup pushes 0 to the front; the second leaves the cache as is.
        for _ in 0..2 {
            assert_eq!(lookup16(&mut h, 0x00), 0);
            assert_eq!(h.cache, [0x00, 0x0A, 0x0B]);
        }
    }
}