use crate::error::{Result, WebpError as Error};
use super::bit_reader::BitReader;
/// Transmission order of the 19 code-length-code lengths.
///
/// The i-th 3-bit length read by `HuffmanTree::read_normal` is stored at
/// index `CODE_LENGTH_ORDER[i]` of the code-length alphabet (symbols 0..=18);
/// entries that are not transmitted keep length 0.
const CODE_LENGTH_ORDER: [usize; 19] = [
17, 18, 0, 1, 2, 3, 4, 5, 16, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
];
/// A decoded symbol value (an index into the alphabet the tree was built for).
pub type HuffmanCode = u16;
/// Number of input bits peeked for the fast lookup-table path in `decode`.
const LUT_BITS: u8 = 8;
/// Number of lookup-table slots: one per possible `LUT_BITS`-bit pattern.
const LUT_SIZE: usize = 1 << LUT_BITS;
/// Packed lookup-table entry: code length in bits 16..=23, symbol in bits 0..=15.
/// A length of 0 marks a slot with no short code; `decode` then walks the tree.
type LutEntry = u32;
/// Packs a symbol and its code length into one lookup-table entry.
///
/// Layout: the length occupies bits 16..=23 and the symbol bits 0..=15.
#[inline]
fn lut_pack(symbol: HuffmanCode, length: u8) -> LutEntry {
    let length_field = u32::from(length) << 16;
    let symbol_field = u32::from(symbol);
    length_field | symbol_field
}
/// Extracts the code length (bits 16..=23) from a lookup-table entry.
#[inline]
fn lut_length(e: LutEntry) -> u8 {
    // In little-endian byte order, byte 2 holds exactly bits 16..=23.
    let [_, _, length, _] = e.to_le_bytes();
    length
}
/// Extracts the symbol (bits 0..=15) from a lookup-table entry.
#[inline]
fn lut_symbol(e: LutEntry) -> HuffmanCode {
    // The narrowing cast keeps only the low 16 bits, dropping the length field.
    e as HuffmanCode
}
/// A prefix-code decoder: a binary tree plus a `LUT_BITS`-bit lookup table
/// used as a fast path for short codes.
#[derive(Debug)]
pub struct HuffmanTree {
/// When `Some`, `decode` returns this symbol without reading any input.
only_symbol: Option<HuffmanCode>,
/// Tree nodes; index 0 is the root that the bit-by-bit walk starts from.
nodes: Vec<Node>,
/// Fast-path table indexed by the next `LUT_BITS` peeked bits. Left empty
/// for single-symbol trees, which never consult it.
lut: Vec<LutEntry>,
}
/// One node of the decoding tree stored in `HuffmanTree::nodes`.
#[derive(Clone, Copy, Debug)]
enum Node {
/// Terminal node carrying the decoded symbol.
Leaf(HuffmanCode),
/// Branch node; `zero` / `one` are indices into `HuffmanTree::nodes` of the
/// child followed for a 0 / 1 input bit. The builder uses index 0 to mean
/// "child not created yet".
Internal { zero: u32, one: u32 },
}
impl HuffmanTree {
pub fn read(br: &mut BitReader<'_>, alphabet: usize) -> Result<Self> {
let simple = br.read_bit()?;
if simple == 1 {
Self::read_simple(br, alphabet)
} else {
Self::read_normal(br, alphabet)
}
}
fn read_simple(br: &mut BitReader<'_>, alphabet: usize) -> Result<Self> {
let num_symbols = br.read_bit()? + 1; let is_first_8bits = br.read_bit()?;
let sym0 = br.read_bits(if is_first_8bits != 0 { 8 } else { 1 })? as HuffmanCode;
if (sym0 as usize) >= alphabet {
return Err(Error::invalid("VP8L: simple huffman symbol out of range"));
}
if num_symbols == 1 {
return Ok(Self {
only_symbol: Some(sym0),
nodes: vec![Node::Leaf(sym0)],
lut: Vec::new(),
});
}
let sym1 = br.read_bits(8)? as HuffmanCode;
if (sym1 as usize) >= alphabet {
return Err(Error::invalid("VP8L: simple huffman symbol out of range"));
}
let mut lut = vec![0 as LutEntry; LUT_SIZE];
let entry0 = lut_pack(sym0, 1);
let entry1 = lut_pack(sym1, 1);
for (i, slot) in lut.iter_mut().enumerate() {
*slot = if i & 1 == 0 { entry0 } else { entry1 };
}
Ok(Self {
only_symbol: None,
nodes: vec![
Node::Internal { zero: 1, one: 2 },
Node::Leaf(sym0),
Node::Leaf(sym1),
],
lut,
})
}
fn read_normal(br: &mut BitReader<'_>, alphabet: usize) -> Result<Self> {
let num_code_lengths = (br.read_bits(4)? + 4) as usize;
if num_code_lengths > CODE_LENGTH_ORDER.len() {
return Err(Error::invalid("VP8L: too many code-length lengths"));
}
let mut code_length_code_lengths = [0u8; 19];
for i in 0..num_code_lengths {
code_length_code_lengths[CODE_LENGTH_ORDER[i]] = br.read_bits(3)? as u8;
}
let meta_tree = build_from_lengths(&code_length_code_lengths)?;
let (max_symbol, use_length) = if br.read_bit()? == 1 {
let length_nbits = 2 + 2 * br.read_bits(3)? as usize;
let max = 2 + br.read_bits(length_nbits as u8)? as usize;
if max > alphabet {
return Err(Error::invalid("VP8L: max_symbol > alphabet"));
}
(max, true)
} else {
(alphabet, false)
};
let mut code_lengths = vec![0u8; alphabet];
let mut sym = 0usize;
let mut prev_len = 8u8;
let mut count = 0usize;
while sym < alphabet {
if use_length && count >= max_symbol {
break;
}
let code = meta_tree.decode(br)?;
count += 1;
match code {
0..=15 => {
code_lengths[sym] = code as u8;
if code != 0 {
prev_len = code as u8;
}
sym += 1;
}
16 => {
let repeat = 3 + br.read_bits(2)? as usize;
if sym + repeat > alphabet {
return Err(Error::invalid("VP8L: huffman repeat past alphabet"));
}
for _ in 0..repeat {
code_lengths[sym] = prev_len;
sym += 1;
}
}
17 => {
let repeat = 3 + br.read_bits(3)? as usize;
if sym + repeat > alphabet {
return Err(Error::invalid("VP8L: huffman zero-run past alphabet"));
}
for _ in 0..repeat {
code_lengths[sym] = 0;
sym += 1;
}
}
18 => {
let repeat = 11 + br.read_bits(7)? as usize;
if sym + repeat > alphabet {
return Err(Error::invalid("VP8L: huffman long-zero-run past alphabet"));
}
for _ in 0..repeat {
code_lengths[sym] = 0;
sym += 1;
}
}
_ => return Err(Error::invalid("VP8L: bad code length code")),
}
}
build_from_lengths(&code_lengths)
}
#[doc(hidden)]
pub fn from_code_lengths_for_bench(lengths: &[u8]) -> Self {
build_from_lengths(lengths).expect("test/bench-only helper, lengths must be valid")
}
#[inline]
pub fn decode(&self, br: &mut BitReader<'_>) -> Result<HuffmanCode> {
if let Some(s) = self.only_symbol {
return Ok(s);
}
br.refill(LUT_BITS);
let key = br.peek_bits(LUT_BITS) as usize;
let entry = self.lut[key];
let length = lut_length(entry);
if length != 0 {
br.consume(length);
return Ok(lut_symbol(entry));
}
let mut node = 0u32;
loop {
match self.nodes[node as usize] {
Node::Leaf(s) => return Ok(s),
Node::Internal { zero, one } => {
let b = br.read_bit()?;
node = if b == 0 { zero } else { one };
}
}
}
}
}
fn build_from_lengths(lengths: &[u8]) -> Result<HuffmanTree> {
let mut max_len = 0u8;
let mut total_nonzero = 0usize;
let mut lone_symbol: Option<u16> = None;
for (i, &l) in lengths.iter().enumerate() {
if l != 0 {
total_nonzero += 1;
if l > max_len {
max_len = l;
}
lone_symbol = Some(i as u16);
}
}
if total_nonzero == 0 {
return Ok(HuffmanTree {
only_symbol: Some(0),
nodes: vec![Node::Leaf(0)],
lut: Vec::new(),
});
}
if total_nonzero == 1 {
let s = lone_symbol.unwrap_or(0);
return Ok(HuffmanTree {
only_symbol: Some(s),
nodes: vec![Node::Leaf(s)],
lut: Vec::new(),
});
}
let mut bl_count = vec![0u32; (max_len + 1) as usize];
for &l in lengths {
if l > 0 {
bl_count[l as usize] += 1;
}
}
let mut next_code = vec![0u32; (max_len + 1) as usize];
let mut code = 0u32;
for bits in 1..=max_len as usize {
code = (code + bl_count[bits - 1]) << 1;
next_code[bits] = code;
}
let mut nodes: Vec<Node> = vec![Node::Internal { zero: 0, one: 0 }];
let mut lut: Vec<LutEntry> = vec![0 as LutEntry; LUT_SIZE];
for (sym, &len) in lengths.iter().enumerate() {
if len == 0 {
continue;
}
let code_val = next_code[len as usize];
next_code[len as usize] += 1;
if len <= LUT_BITS {
let mut prefix = 0u32;
for b in 0..len {
if ((code_val >> b) & 1) != 0 {
prefix |= 1u32 << (len - 1 - b);
}
}
let stride = 1usize << len;
let entry = lut_pack(sym as u16, len);
let mut k = prefix as usize;
while k < LUT_SIZE {
lut[k] = entry;
k += stride;
}
}
let mut node = 0u32;
for b in (0..len).rev() {
let bit = (code_val >> b) & 1;
if b == 0 {
let leaf_idx = nodes.len() as u32;
nodes.push(Node::Leaf(sym as u16));
match &mut nodes[node as usize] {
Node::Internal { zero, one } => {
if bit == 0 {
*zero = leaf_idx;
} else {
*one = leaf_idx;
}
}
Node::Leaf(_) => {
return Err(Error::invalid(
"VP8L: canonical Huffman length table self-collides",
))
}
}
} else {
let child = match nodes[node as usize] {
Node::Internal { zero, one } => {
if bit == 0 {
zero
} else {
one
}
}
Node::Leaf(_) => {
return Err(Error::invalid(
"VP8L: canonical Huffman length table self-collides",
))
}
};
let next = if child == 0 {
let new_idx = nodes.len() as u32;
nodes.push(Node::Internal { zero: 0, one: 0 });
match &mut nodes[node as usize] {
Node::Internal { zero, one } => {
if bit == 0 {
*zero = new_idx;
} else {
*one = new_idx;
}
}
_ => unreachable!(),
}
new_idx
} else {
child
};
node = next;
}
}
}
Ok(HuffmanTree {
only_symbol: None,
nodes,
lut,
})
}
// Unit tests for tree construction, the simple/normal header readers and
// symbol decoding. The hand-written byte patterns rely on `BitReader`
// handing out the least-significant bit of each byte first, which is what
// the expectations below imply.
#[cfg(test)]
mod tests {
use super::*;
// Two 1-bit codes: bit 0 selects symbol 0, bit 1 selects symbol 1.
#[test]
fn canonical_two_symbols() {
let tree = build_from_lengths(&[1, 1]).unwrap();
let buf = [0b0000_0010u8];
let mut br = BitReader::new(&buf);
assert_eq!(tree.decode(&mut br).unwrap(), 0);
assert_eq!(tree.decode(&mut br).unwrap(), 1);
}
// A table with no used symbol always decodes symbol 0 and reads nothing.
#[test]
fn build_all_zeros_returns_leaf0() {
let tree = build_from_lengths(&[0u8; 40]).unwrap();
let buf = [0u8];
let mut br = BitReader::new(&buf);
assert_eq!(tree.decode(&mut br).unwrap(), 0);
assert_eq!(br.byte_pos(), 0);
assert_eq!(tree.decode(&mut br).unwrap(), 0);
assert_eq!(tree.decode(&mut br).unwrap(), 0);
}
// A single used symbol is returned without advancing the reader.
#[test]
fn build_one_nonzero_returns_lone_symbol() {
let mut lens = vec![0u8; 40];
lens[17] = 1;
let tree = build_from_lengths(&lens).unwrap();
let buf = [0u8];
let mut br = BitReader::new(&buf);
assert_eq!(tree.decode(&mut br).unwrap(), 17);
assert_eq!(br.byte_pos(), 0);
}
// The stated length of a lone symbol is irrelevant: still a zero-bit code.
#[test]
fn build_one_nonzero_with_long_length_still_lone() {
let mut lens = vec![0u8; 256];
lens[200] = 15;
let tree = build_from_lengths(&lens).unwrap();
let buf = [0u8];
let mut br = BitReader::new(&buf);
assert_eq!(tree.decode(&mut br).unwrap(), 200);
assert_eq!(br.byte_pos(), 0);
}
// Complete code {0, 10, 11}; 0x1A supplies the bits 0, 1,0, 1,1 in read order.
#[test]
fn build_kraft_equality_three_symbols() {
let tree = build_from_lengths(&[1, 2, 2]).unwrap();
let buf = [0x1Au8];
let mut br = BitReader::new(&buf);
assert_eq!(tree.decode(&mut br).unwrap(), 0);
assert_eq!(tree.decode(&mut br).unwrap(), 1);
assert_eq!(tree.decode(&mut br).unwrap(), 2);
}
// Incomplete code {00, 01} is accepted; 0x08 supplies the bits 0,0, 0,1.
#[test]
fn build_kraft_under_equality_is_accepted() {
let tree = build_from_lengths(&[2, 2]).unwrap();
let buf = [0x08u8];
let mut br = BitReader::new(&buf);
assert_eq!(tree.decode(&mut br).unwrap(), 0);
assert_eq!(tree.decode(&mut br).unwrap(), 1);
}
// Over-subscribed table whose third code must pass through an existing leaf.
#[test]
fn build_kraft_over_equality_errors() {
let err = build_from_lengths(&[1, 1, 2]).unwrap_err();
let msg = format!("{err:?}");
assert!(
msg.contains("self-collides"),
"expected self-collide error, got {msg}"
);
}
// Pins current behaviour: three 1-bit codes are accepted and bit 0 decodes
// to either the first or the third symbol.
#[test]
fn build_kraft_over_equality_three_length_one_silently_truncates() {
let tree = build_from_lengths(&[1, 1, 1]).expect("currently accepted");
let buf = [0u8];
let mut br = BitReader::new(&buf);
let s = tree.decode(&mut br).unwrap();
assert!(s == 0 || s == 2, "got sym {s}");
}
// Lengths 1, 2, ..., 15, 15 form a complete code; all-zero input decodes
// the 1-bit code of symbol 0.
#[test]
fn build_length_15_max_per_spec() {
let mut lens = vec![0u8; 16];
for (i, l) in (1u8..=15).enumerate() {
lens[i] = l;
}
lens[15] = 15;
let tree = build_from_lengths(&lens).expect("length-15 tree should build");
let buf = [0u8; 4];
let mut br = BitReader::new(&buf);
assert_eq!(tree.decode(&mut br).unwrap(), 0);
}
// 0b1010_1010 read LSB-first alternates 0, 1, 0, 1.
#[test]
fn build_length_1_two_symbol_fast_tree() {
let tree = build_from_lengths(&[1, 1]).unwrap();
let buf = [0b1010_1010u8]; let mut br = BitReader::new(&buf);
assert_eq!(tree.decode(&mut br).unwrap(), 0);
assert_eq!(tree.decode(&mut br).unwrap(), 1);
assert_eq!(tree.decode(&mut br).unwrap(), 0);
assert_eq!(tree.decode(&mut br).unwrap(), 1);
}
// Header bits in read order: simple=1, one symbol (0), 1-bit field (0), symbol=1.
#[test]
fn simple_one_symbol() {
let buf = [0b0000_1001u8];
let mut br = BitReader::new(&buf);
let tree = HuffmanTree::read(&mut br, 256).unwrap();
assert_eq!(tree.decode(&mut br).unwrap(), 1);
}
// Header bits: simple=1, one symbol (0), 8-bit field (1), then 0xa5 split
// across the two bytes.
#[test]
fn simple_one_symbol_8bit_field() {
let buf = [0x2du8, 0x05];
let mut br = BitReader::new(&buf);
let tree = HuffmanTree::read(&mut br, 256).unwrap();
assert_eq!(tree.decode(&mut br).unwrap(), 0xa5);
}
// Simple code with symbols 10 and 20, followed by the data bits 0 and 1.
#[test]
fn simple_two_symbols() {
let mut bw = TestBitWriter::new();
bw.write(1, 1); bw.write(1, 1); bw.write(1, 1); bw.write(10, 8); bw.write(20, 8); bw.write(0, 1); bw.write(1, 1); let buf = bw.finish();
let mut br = BitReader::new(&buf);
let tree = HuffmanTree::read(&mut br, 256).unwrap();
assert_eq!(tree.decode(&mut br).unwrap(), 10);
assert_eq!(tree.decode(&mut br).unwrap(), 20);
}
// Simple code listing the same symbol twice: every decode yields it.
#[test]
fn simple_two_symbols_degenerate_duplicate() {
let mut bw = TestBitWriter::new();
bw.write(1, 1); bw.write(1, 1); bw.write(1, 1); bw.write(42, 8); bw.write(42, 8); bw.write(0, 1);
bw.write(1, 1);
let buf = bw.finish();
let mut br = BitReader::new(&buf);
let tree = HuffmanTree::read(&mut br, 256).unwrap();
assert_eq!(tree.decode(&mut br).unwrap(), 42);
assert_eq!(tree.decode(&mut br).unwrap(), 42);
}
// One-symbol simple code with a 1-bit field, read against a 40-symbol alphabet.
#[test]
fn simple_one_symbol_distance_alphabet_in_range() {
let mut bw = TestBitWriter::new();
bw.write(1, 1); bw.write(0, 1); bw.write(0, 1); bw.write(1, 1); let buf = bw.finish();
let mut br = BitReader::new(&buf);
let tree = HuffmanTree::read(&mut br, 40).unwrap();
assert_eq!(tree.decode(&mut br).unwrap(), 1);
}
// An 8-bit first symbol of 100 does not fit a 40-symbol alphabet.
#[test]
fn simple_one_symbol_distance_alphabet_8bit_out_of_range_errors() {
let mut bw = TestBitWriter::new();
bw.write(1, 1); bw.write(0, 1); bw.write(1, 1); bw.write(100, 8); let buf = bw.finish();
let mut br = BitReader::new(&buf);
let err = HuffmanTree::read(&mut br, 40).unwrap_err();
let msg = format!("{err:?}");
assert!(
msg.contains("out of range"),
"expected out-of-range error, got {msg}"
);
}
// The second symbol (200) is range-checked against the alphabet as well.
#[test]
fn simple_two_symbols_sym1_out_of_range_errors() {
let mut bw = TestBitWriter::new();
bw.write(1, 1); bw.write(1, 1); bw.write(1, 1); bw.write(0, 8); bw.write(200, 8); let buf = bw.finish();
let mut br = BitReader::new(&buf);
let err = HuffmanTree::read(&mut br, 40).unwrap_err();
let msg = format!("{err:?}");
assert!(msg.contains("out of range"), "got {msg}");
}
// Normal code: four code-length-code lengths (0, 0, 1, 0), then the length
// flag with length_nbits = 2 and value 3, i.e. max_symbol = 5 > alphabet = 4.
#[test]
fn normal_max_symbol_greater_than_alphabet_errors() {
let mut bw = TestBitWriter::new();
bw.write(0, 1); bw.write(0, 4); bw.write(0, 3); bw.write(0, 3); bw.write(1, 3); bw.write(0, 3); bw.write(1, 1); bw.write(0, 3); bw.write(3, 2); let buf = bw.finish();
let mut br = BitReader::new(&buf);
let err = HuffmanTree::read(&mut br, 4).unwrap_err();
let msg = format!("{err:?}");
assert!(
msg.contains("max_symbol") || msg.to_lowercase().contains("invalid"),
"expected max_symbol error, got {msg}"
);
}
// The last symbol of a 40-entry table works as a lone symbol.
#[test]
fn build_lengths_in_distance_alphabet_size() {
let mut lens = vec![0u8; 40];
lens[39] = 1;
let tree = build_from_lengths(&lens).unwrap();
let buf = [0u8];
let mut br = BitReader::new(&buf);
assert_eq!(tree.decode(&mut br).unwrap(), 39);
}
// Two 1-bit codes inside a 40-entry table: the lower symbol gets bit 0.
#[test]
fn build_lengths_full_two_codes_at_distance_alphabet_size() {
let mut lens = vec![0u8; 40];
lens[5] = 1;
lens[37] = 1;
let tree = build_from_lengths(&lens).unwrap();
let buf = [0b1010_1010u8];
let mut br = BitReader::new(&buf);
assert_eq!(tree.decode(&mut br).unwrap(), 5);
assert_eq!(tree.decode(&mut br).unwrap(), 37);
assert_eq!(tree.decode(&mut br).unwrap(), 5);
assert_eq!(tree.decode(&mut br).unwrap(), 37);
}
// Test-only bit packer: appends fields starting at the least-significant
// free bit of the current byte.
struct TestBitWriter {
// Completed bytes.
out: Vec<u8>,
// Pending bits, lowest bit first.
cur: u32,
// Number of valid bits in `cur`.
nbits: u32,
}
impl TestBitWriter {
fn new() -> Self {
Self {
out: Vec::new(),
cur: 0,
nbits: 0,
}
}
// Appends the low `n` bits of `value`, flushing every completed byte.
fn write(&mut self, value: u32, n: u32) {
debug_assert!(n <= 24);
let mask = ((1u64 << n) - 1) as u32;
self.cur |= (value & mask) << self.nbits;
self.nbits += n;
while self.nbits >= 8 {
self.out.push((self.cur & 0xff) as u8);
self.cur >>= 8;
self.nbits -= 8;
}
}
// Flushes a final partial byte (upper bits zero) and returns the buffer.
fn finish(mut self) -> Vec<u8> {
if self.nbits > 0 {
self.out.push((self.cur & 0xff) as u8);
}
self.out
}
}
}