use alloc::vec::Vec;
use crate::zstd::encoder_bitwriter::RevBitWriter;
use crate::zstd::huffman::HUF_MAX_BITS;
/// Huffman code length of every byte value, in bits, indexed by the byte.
/// A length of 0 marks a symbol that has no code.
pub type HuffLengths = [u8; 256];

/// Encoding table for a byte-oriented Huffman code.
///
/// `codes[s]` is the code value of symbol `s`; it is only meaningful when
/// `lengths[s] > 0`. [`build_huff_encoder`] fills both arrays from a length
/// table.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HuffEncoder {
    /// Code value of each symbol, emitted using `lengths[s]` bits.
    pub codes: [u16; 256],
    /// Code length of each symbol; 0 means the symbol cannot be encoded.
    pub lengths: HuffLengths,
}
impl HuffEncoder {
    /// Writes the code of `sym` to `writer`, using `self.lengths[sym]` bits.
    ///
    /// Debug builds panic when `sym` has no code (length 0). Release builds
    /// skip that check and pass a bit count of 0 to the writer.
    pub fn encode_symbol(&self, writer: &mut RevBitWriter, sym: u8) {
        let idx = usize::from(sym);
        let nbits = self.lengths[idx];
        debug_assert!(nbits > 0, "encoding absent symbol {sym}");
        writer.write_bits(u64::from(self.codes[idx]), u32::from(nbits));
    }
}
/// Builds Huffman code lengths for the byte histogram `freq`.
///
/// The tree is built by repeatedly merging the two lightest nodes. Leaves get
/// ids in increasing symbol order and merged nodes get increasing ids after
/// them; weight ties are broken in favour of the smaller id. Each leaf's depth
/// becomes its code length (at least 1, at most 255). The table is then
/// limited to `HUF_MAX_BITS` with `cap_code_lengths`.
///
/// Returns `None` when fewer than two symbols have a non-zero count. Symbols
/// with a zero count get length 0.
pub fn build_huff_lengths(freq: &[u32; 256]) -> Option<HuffLengths> {
    let symbols: Vec<usize> = (0..256).filter(|&s| freq[s] > 0).collect();
    if symbols.len() < 2 {
        return None;
    }

    const NO_PARENT: u32 = u32::MAX;
    // `parent[id]` is the node `id` was merged into. Leaves occupy ids
    // `0..symbols.len()`, in the same order as `symbols`.
    let mut parent: Vec<u32> = Vec::with_capacity(2 * symbols.len());
    parent.resize(symbols.len(), NO_PARENT);

    // Kept sorted in descending (weight, id) order, so `pop` yields the
    // lightest node and, among equal weights, the one with the smaller id.
    let mut queue: Vec<(u64, u32)> = symbols
        .iter()
        .enumerate()
        .map(|(id, &s)| (u64::from(freq[s]), id as u32))
        .collect();
    queue.sort_by(|a, b| b.cmp(a));

    while queue.len() > 1 {
        let (w_a, id_a) = queue.pop().unwrap();
        let (w_b, id_b) = queue.pop().unwrap();
        let merged_id = parent.len() as u32;
        parent.push(NO_PARENT);
        parent[id_a as usize] = merged_id;
        parent[id_b as usize] = merged_id;
        let merged = (w_a.saturating_add(w_b), merged_id);
        // Entries sorting above `merged` form a prefix; insert right after it.
        let slot = queue.partition_point(|&entry| entry > merged);
        queue.insert(slot, merged);
    }

    // A parent is always created after its children, so visiting ids from
    // highest to lowest sees every parent's depth before its children's.
    let mut depth: Vec<u32> = Vec::with_capacity(parent.len());
    depth.resize(parent.len(), 0);
    for id in (0..parent.len()).rev() {
        let p = parent[id];
        if p != NO_PARENT {
            depth[id] = depth[p as usize] + 1;
        }
    }

    let mut lengths: HuffLengths = [0u8; 256];
    for (leaf, &sym) in symbols.iter().enumerate() {
        lengths[sym] = depth[leaf].clamp(1, 255) as u8;
    }
    cap_code_lengths(&mut lengths, HUF_MAX_BITS);
    Some(lengths)
}
/// Limits every code length in `lengths` to `max_len` bits while keeping the
/// prefix code complete (Kraft sum exactly 1).
///
/// Lengths of 0 mark absent symbols and are never changed. The repair counts
/// each present symbol as `2^(max_len - len)` units against a budget of
/// `2^max_len`:
///
/// 1. Lengths above `max_len` are clamped, which over-subscribes the code.
/// 2. While over budget, the *longest* code still below `max_len` is
///    lengthened by one bit. Those codes belong to the rarest symbols, so the
///    short codes of frequent symbols are preserved. Each step also removes
///    the smallest possible amount, which limits overshoot.
/// 3. While under budget, the shortest code whose shortening still fits the
///    remaining slack is shortened by one bit. Such a code always exists when
///    at least two symbols are present, so the budget is met exactly. Lengths
///    never drop below 1.
///
/// Expects `1 <= max_len < 64`. With fewer than two present symbols, or more
/// than `2^max_len`, no complete code exists and the best effort is left.
fn cap_code_lengths(lengths: &mut HuffLengths, max_len: u8) {
    debug_assert!(
        (1..64).contains(&max_len),
        "cap_code_lengths: max_len out of range: {max_len}"
    );
    // Kraft contribution of a code of `len` bits, in units of 2^-max_len.
    let weight = |len: u8| 1u64 << (max_len - len);
    let budget = 1u64 << max_len;

    for l in lengths.iter_mut() {
        if *l > max_len {
            *l = max_len;
        }
    }
    let mut total: u64 = lengths
        .iter()
        .filter(|&&l| l > 0)
        .map(|&l| weight(l))
        .sum();

    // Over-subscribed: lengthen the longest code that is still allowed to grow.
    while total > budget {
        let Some((sym, l)) = lengths
            .iter()
            .copied()
            .enumerate()
            .filter(|&(_, l)| l > 0 && l < max_len)
            .max_by_key(|&(_, l)| l)
        else {
            break; // more than 2^max_len symbols: cannot be satisfied
        };
        lengths[sym] = l + 1;
        // weight(l) - weight(l + 1) == weight(l + 1)
        total -= weight(l + 1);
    }

    // Under-subscribed: spend the slack on the shortest code that fits it.
    while total < budget {
        let slack = budget - total;
        let Some((sym, l)) = lengths
            .iter()
            .copied()
            .enumerate()
            .filter(|&(_, l)| l > 1 && weight(l) <= slack)
            .min_by_key(|&(_, l)| l)
        else {
            break; // fewer than two present symbols: no complete code exists
        };
        lengths[sym] = l - 1;
        // weight(l - 1) - weight(l) == weight(l)
        total += weight(l);
    }
}
/// Builds the canonical Huffman code table for `lengths`.
///
/// Codes are laid out from the longest length to the shortest. The longest
/// codes start at value 0. The first code of each shorter length `len` is
/// `(first[len + 1] + count[len + 1]) >> 1`. Within one length, symbols get
/// consecutive values in increasing symbol order. Absent symbols (length 0)
/// keep code 0.
///
/// # Panics
/// Panics with an out-of-bounds index if any length exceeds `HUF_MAX_BITS`.
pub fn build_huff_encoder(lengths: &HuffLengths) -> HuffEncoder {
    let mut count_of_len = [0u32; (HUF_MAX_BITS as usize) + 1];
    for &len in lengths.iter().filter(|&&len| len > 0) {
        count_of_len[usize::from(len)] += 1;
    }
    let longest = lengths.iter().copied().max().unwrap_or(0);

    // next_code[len] is the next unused code value of that length.
    let mut next_code = [0u32; (HUF_MAX_BITS as usize) + 1];
    let mut first = 0u32;
    for len in (1..longest).rev() {
        first = (first + count_of_len[usize::from(len) + 1]) >> 1;
        next_code[usize::from(len)] = first;
    }

    // Counters are per length, so a single pass in symbol order assigns the
    // same values as walking each length separately.
    let mut codes = [0u16; 256];
    for (sym, &len) in lengths.iter().enumerate() {
        if len > 0 {
            let slot = &mut next_code[usize::from(len)];
            codes[sym] = *slot as u16;
            *slot += 1;
        }
    }

    HuffEncoder {
        codes,
        lengths: *lengths,
    }
}
/// Converts code lengths into zstd Huffman weights.
///
/// A present symbol gets `weight = max_len + 1 - len`, so the longest codes
/// get weight 1. Absent symbols get weight 0. Weights are produced for symbols
/// `0..last`, where `last` is the highest symbol with a non-zero length; the
/// last present symbol's weight is omitted because zstd decoders reconstruct
/// it from the others.
///
/// Returns the weights together with the longest code length. An all-zero
/// table yields no weights and a maximum of 0.
pub fn lengths_to_weights(lengths: &HuffLengths) -> (Vec<u8>, u8) {
    let longest = lengths.iter().copied().max().unwrap_or(0);
    let last_present = lengths.iter().rposition(|&len| len > 0).unwrap_or(0);
    let weights: Vec<u8> = lengths[..last_present]
        .iter()
        .map(|&len| match len {
            0 => 0,
            _ => longest + 1 - len,
        })
        .collect();
    (weights, longest)
}
/// Serializes Huffman weights with zstd's direct (4-bit) tree description.
///
/// The output is a header byte `127 + n` (where `n = weights.len()`) followed
/// by the weights packed two per byte, with the earlier weight in the high
/// nibble. An odd final weight is paired with a zero low nibble. Each weight
/// is masked to 4 bits, so callers must pass weights in `0..=15`. Callers are
/// expected to pass at least one weight; a header of 127 would not be read as
/// a direct description.
///
/// # Panics
/// Panics if more than 128 weights are given, because the header byte cannot
/// represent them. Such tables need the FSE-compressed description instead.
pub fn encode_huff_tree_direct(weights: &[u8]) -> Vec<u8> {
    let n = weights.len();
    // Checked in release builds too: a wrapped header byte would silently
    // corrupt the stream.
    assert!(
        n <= 128,
        "direct encoding limited to 128 weights (got {})",
        n
    );
    let mut out = Vec::with_capacity(1 + n.div_ceil(2));
    // n <= 128, so 127 + n <= 255 fits in the header byte.
    out.push((127 + n) as u8);
    out.extend(weights.chunks(2).map(|pair| {
        let hi = pair[0] & 0x0F;
        let lo = pair.get(1).map_or(0, |&w| w & 0x0F);
        (hi << 4) | lo
    }));
    out
}
/// Huffman-encodes `data` as a single bit stream and returns the finished
/// bytes.
///
/// Symbols are handed to the `RevBitWriter` from last to first. The round-trip
/// tests in this module then read them back in original order with
/// `RevBitReader`.
pub fn encode_huff_stream(enc: &HuffEncoder, data: &[u8]) -> Vec<u8> {
    let mut bits = RevBitWriter::new();
    data.iter()
        .rev()
        .for_each(|&sym| enc.encode_symbol(&mut bits, sym));
    bits.finish()
}
/// Encodes `data` as the four independent Huffman streams used by zstd's
/// 4-stream literals mode.
///
/// The first three streams each cover `ceil(len / 4)` bytes and the fourth
/// covers the remainder. Each segment is encoded with [`encode_huff_stream`].
/// Empty input yields four empty-input streams.
///
/// # Panics
/// Panics when `data.len()` is 1, 2 or 5. For those sizes three full segments
/// are longer than the input, so no valid 4-stream split exists and callers
/// must use single-stream mode.
pub fn encode_huff_4streams(
    enc: &HuffEncoder,
    data: &[u8],
) -> (Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>) {
    let regen = data.len();
    let per = regen.div_ceil(4);
    assert!(
        3 * per <= regen,
        "4-stream Huffman mode cannot split {regen} bytes into segments of {per}"
    );
    let (first, rest) = data.split_at(per);
    let (second, rest) = rest.split_at(per);
    let (third, fourth) = rest.split_at(per);
    (
        encode_huff_stream(enc, first),
        encode_huff_stream(enc, second),
        encode_huff_stream(enc, third),
        encode_huff_stream(enc, fourth),
    )
}
/// Returns the number of payload bits needed to encode a message with symbol
/// counts `freq` using code lengths `lengths`, i.e. the sum of
/// `lengths[s] * freq[s]` over all symbols.
///
/// Symbols without a code (length 0) contribute nothing. Any padding or
/// framing that the bit writer adds is not included.
pub fn predicted_bits(lengths: &HuffLengths, freq: &[u32; 256]) -> u64 {
    lengths
        .iter()
        .zip(freq)
        .map(|(&len, &count)| u64::from(len) * u64::from(count))
        .sum()
}
/// Counts how many times each byte value occurs in `data`.
///
/// The result is indexed by byte value.
pub fn histogram(data: &[u8]) -> [u32; 256] {
    let mut counts = [0u32; 256];
    data.iter().for_each(|&byte| counts[usize::from(byte)] += 1);
    counts
}
#[cfg(test)]
mod tests {
// Tests for the Huffman encoder. The *_round_trip tests decode encoder output
// with the crate's decoder (`decode_huffman_tree` + `RevBitReader`) and compare
// it with the input; the cap_code_lengths_* tests exercise length limiting
// directly.
use super::*;
use crate::zstd::bitreader::RevBitReader;
use crate::zstd::huffman::decode_huffman_tree;
// Builds the encoder and the direct-encoded tree description for `freq`.
// Before returning, it asserts that the generated lengths satisfy the Kraft
// equality (sum of 2^(max_len - l) == 2^max_len), i.e. the code is complete.
// Despite the name it does not decode anything; callers do the decoding.
// Panics (via unwrap) if `freq` has fewer than two present symbols.
fn round_trip_huff(freq: &[u32; 256]) -> (HuffEncoder, Vec<u8>) {
let lengths = build_huff_lengths(freq).unwrap();
let mut max_len = 0u8;
for &l in &lengths {
if l > max_len {
max_len = l;
}
}
let mut kraft: u64 = 0;
for &l in &lengths {
if l > 0 {
kraft += 1u64 << (max_len - l);
}
}
assert_eq!(kraft, 1u64 << max_len, "Kraft not satisfied");
let enc = build_huff_encoder(&lengths);
let (weights, _max) = lengths_to_weights(&lengths);
let tree_bytes = encode_huff_tree_direct(&weights);
(enc, tree_bytes)
}
// Four-symbol alphabet: encode a short message as one stream, then decode it
// symbol by symbol with the table rebuilt from the serialized weights.
#[test]
fn simple_huff_round_trip() {
let mut freq = [0u32; 256];
freq[b'a' as usize] = 10;
freq[b'b' as usize] = 5;
freq[b'c' as usize] = 3;
freq[b'd' as usize] = 2;
let (enc, tree_bytes) = round_trip_huff(&freq);
let (dec_table, _consumed) = decode_huffman_tree(&tree_bytes).unwrap();
let symbols: &[u8] = b"abcdabcdab";
let stream = encode_huff_stream(&enc, symbols);
let mut br = RevBitReader::new(&stream).unwrap();
let mut decoded: Vec<u8> = Vec::new();
for _ in 0..symbols.len() {
decoded.push(dec_table.decode(&mut br).unwrap());
}
assert_eq!(decoded, symbols);
}
// Same single-stream round trip over an English sentence, whose histogram
// comes straight from the text being encoded.
#[test]
fn larger_alphabet_round_trip() {
let text = b"the quick brown fox jumps over the lazy dog. the lazy dog sleeps.";
let mut freq = [0u32; 256];
for &b in text {
freq[b as usize] += 1;
}
let (enc, tree_bytes) = round_trip_huff(&freq);
let (dec_table, _consumed) = decode_huffman_tree(&tree_bytes).unwrap();
let stream = encode_huff_stream(&enc, text);
let mut br = RevBitReader::new(&stream).unwrap();
let mut decoded: Vec<u8> = Vec::new();
for _ in 0..text.len() {
decoded.push(dec_table.decode(&mut br).unwrap());
}
assert_eq!(decoded, text);
}
// 4-stream mode: the segment sizes are recomputed here the same way
// encode_huff_4streams computes them, and each stream is decoded with its
// own reader before the outputs are concatenated.
#[test]
fn four_stream_round_trip() {
let mut input: Vec<u8> = Vec::new();
for _ in 0..32 {
input.extend_from_slice(b"Lorem ipsum dolor sit amet, consectetur adipiscing elit. ");
}
let freq = histogram(&input);
let lengths = build_huff_lengths(&freq).unwrap();
let enc = build_huff_encoder(&lengths);
let (weights, _) = lengths_to_weights(&lengths);
let tree_bytes = encode_huff_tree_direct(&weights);
let (dec_table, _) = decode_huffman_tree(&tree_bytes).unwrap();
let (s1, s2, s3, s4) = encode_huff_4streams(&enc, &input);
let regen = input.len();
let per = regen.div_ceil(4);
let last = regen - 3 * per;
let mut out: Vec<u8> = Vec::new();
for (stream_bytes, n) in [(s1, per), (s2, per), (s3, per), (s4, last)].into_iter() {
let mut br = RevBitReader::new(&stream_bytes).unwrap();
for _ in 0..n {
out.push(dec_table.decode(&mut br).unwrap());
}
}
assert_eq!(out, input);
}
// A complete two-symbol code already within the limit must be left unchanged.
#[test]
fn cap_code_lengths_idempotent_under_limit() {
let mut lengths = [0u8; 256];
lengths[0] = 1;
lengths[1] = 1;
cap_code_lengths(&mut lengths, 11);
assert_eq!(lengths[0], 1);
assert_eq!(lengths[1], 1);
}
// Over-long codes are brought within the limit. Only the upper bound is
// checked here; the resulting code's completeness is not asserted.
#[test]
fn cap_code_lengths_caps_long_codes() {
let mut lengths = [0u8; 256];
lengths[0] = 1;
lengths[1] = 15;
lengths[2] = 15;
cap_code_lengths(&mut lengths, 11);
assert!(lengths[1] <= 11 && lengths[2] <= 11);
}
}