use alloc::vec::Vec;
use crate::zstd::encoder_bitwriter::RevBitWriter;
use crate::zstd::huffman::HUF_MAX_BITS;
/// Huffman code lengths in bits, indexed by byte value.
///
/// A length of 0 marks a symbol that has no code.
pub type HuffLengths = [u8; 256];
/// Huffman encoding table holding one code value and one bit length per byte symbol.
pub struct HuffEncoder {
/// Code value for each symbol; handed to the bit writer together with `lengths[sym]`.
pub codes: [u16; 256],
/// Code length in bits for each symbol; 0 means the symbol has no code.
pub lengths: HuffLengths,
}
impl HuffEncoder {
    /// Appends the code assigned to `sym` to `writer`.
    ///
    /// The symbol's code value and bit length are looked up in the table and
    /// passed to `RevBitWriter::write_bits`. Debug builds assert that the
    /// symbol has a non-zero code length.
    pub fn encode_symbol(&self, writer: &mut RevBitWriter, sym: u8) {
        let slot = usize::from(sym);
        let bit_count = self.lengths[slot];
        debug_assert!(bit_count > 0, "encoding absent symbol {sym}");
        writer.write_bits(u64::from(self.codes[slot]), u32::from(bit_count));
    }
}
/// Computes a Huffman code length for every symbol with a non-zero frequency.
///
/// A Huffman tree is built by repeatedly merging the two lightest nodes. Each
/// symbol's depth in that tree becomes its code length, and the lengths are
/// then passed through `cap_code_lengths` with `HUF_MAX_BITS` as the limit.
///
/// Returns `None` when fewer than two symbols have a non-zero frequency.
/// Otherwise returns a table in which absent symbols keep length 0.
pub fn build_huff_lengths(freq: &[u32; 256]) -> Option<HuffLengths> {
let present: Vec<usize> = (0..256).filter(|&s| freq[s] > 0).collect();
// At least two leaves are required to build a tree.
if present.len() < 2 {
return None;
}
// Tree nodes only record their parent; `u32::MAX` means "no parent".
#[derive(Clone, Copy)]
struct Node {
parent: u32,
}
let mut nodes: Vec<Node> = Vec::with_capacity(2 * present.len());
// `heap` is a plain Vec of (weight, node id) kept sorted so that `pop()`
// returns the lightest entry.
let mut heap: Vec<(u64, u32)> = Vec::with_capacity(present.len());
let mut leaf_id_of_sym = [u32::MAX; 256];
// One leaf per present symbol, created in ascending symbol order.
for &s in &present {
let id = nodes.len() as u32;
nodes.push(Node { parent: u32::MAX });
heap.push((freq[s] as u64, id));
leaf_id_of_sym[s] = id;
}
// Descending weight, ties broken by descending id: the tail of the Vec
// holds the lightest entry, and among equal weights the lowest id.
heap.sort_by(|a, b| b.0.cmp(&a.0).then(b.1.cmp(&a.1)));
while heap.len() >= 2 {
// Merge the two lightest nodes under a freshly created parent.
let (w1, id1) = heap.pop().unwrap();
let (w2, id2) = heap.pop().unwrap();
let new_id = nodes.len() as u32;
let new_weight = w1.saturating_add(w2);
nodes.push(Node { parent: u32::MAX });
nodes[id1 as usize].parent = new_id;
nodes[id2 as usize].parent = new_id;
let entry = (new_weight, new_id);
// Insert right after the last entry that sorts before the new one. The new
// id is larger than every existing id, so only the weight comparison can
// match and the merged node lands ahead of entries with equal weight,
// meaning it is popped after them.
let pos = heap
.iter()
.rposition(|x| x.0 > new_weight || (x.0 == new_weight && x.1 > new_id))
.map(|i| i + 1)
.unwrap_or(0);
heap.insert(pos, entry);
}
let mut lengths: HuffLengths = [0u8; 256];
for s in 0..256 {
let leaf = leaf_id_of_sym[s];
// Symbols without a leaf keep length 0.
if leaf == u32::MAX {
continue;
}
// Code length = number of parent links between the leaf and the root.
let mut depth = 0u32;
let mut cur = nodes[leaf as usize].parent;
while cur != u32::MAX {
depth += 1;
cur = nodes[cur as usize].parent;
}
// Safeguard: a present symbol never gets length 0.
if depth == 0 {
depth = 1;
}
// Saturate into the u8 table; the cap below enforces the real limit.
lengths[s] = if depth > 255 { 255 } else { depth as u8 };
}
cap_code_lengths(&mut lengths, HUF_MAX_BITS);
Some(lengths)
}
/// Limits every code length to `max_len`, then adjusts lengths in place to
/// bring the Kraft sum back towards exactly `2^max_len`.
///
/// A present symbol of length `l` contributes `2^(max_len - l)` to the sum.
/// Entries equal to 0 (absent symbols) are never modified. Both repair loops
/// stop early when no suitable symbol can be found, so the sum is not
/// guaranteed to reach the budget for every possible input.
fn cap_code_lengths(lengths: &mut HuffLengths, max_len: u8) {
// Phase 1: clamp over-long codes.
for l in lengths.iter_mut() {
if *l > max_len {
*l = max_len;
}
}
// Kraft sum scaled by 2^max_len so that it stays an integer.
let mut total: u64 = 0;
for &l in lengths.iter() {
if l > 0 {
total += 1u64 << (max_len - l);
}
}
let budget: u64 = 1u64 << max_len;
// Phase 2: while over budget, lengthen the shortest code that is still
// below `max_len` (first such symbol on ties); this halves its contribution.
while total > budget {
let mut best_sym = usize::MAX;
let mut best_len = u8::MAX;
for (s, &l) in lengths.iter().enumerate() {
if l > 0 && l < max_len && l < best_len {
best_sym = s;
best_len = l;
}
}
// Every present code is already at `max_len`: nothing left to lengthen.
if best_sym == usize::MAX {
break;
}
let old_contrib = 1u64 << (max_len - best_len);
let new_contrib = 1u64 << (max_len - best_len - 1);
let delta = old_contrib - new_contrib;
lengths[best_sym] = best_len + 1;
total -= delta;
}
// Phase 3: while under budget, shorten codes. Shortening a code by one bit
// doubles its contribution, i.e. adds its current contribution to the sum.
while total < budget {
// Prefer the first symbol sitting at `max_len` (its contribution is 1).
let mut best_sym = usize::MAX;
for (s, &l) in lengths.iter().enumerate() {
if l == max_len {
best_sym = s;
break;
}
}
// Otherwise fall back to the first symbol whose length can still shrink.
if best_sym == usize::MAX {
for (s, &l) in lengths.iter().enumerate() {
if l > 1 {
best_sym = s;
break;
}
}
if best_sym == usize::MAX {
break;
}
}
let cur = lengths[best_sym];
let cur_contrib = 1u64 << (max_len - cur);
if total + cur_contrib > budget {
// The chosen symbol would overshoot; look for the first longer code whose
// smaller contribution still fits.
let mut alt = usize::MAX;
for (s, &l) in lengths.iter().enumerate() {
if l > cur {
let lc = 1u64 << (max_len - l);
if total + lc <= budget {
alt = s;
break;
}
}
}
// No candidate fits: give up and leave the sum below the budget.
if alt == usize::MAX {
break;
}
let alt_len = lengths[alt];
let alt_contrib = 1u64 << (max_len - alt_len);
lengths[alt] = alt_len - 1;
total += alt_contrib;
} else {
lengths[best_sym] = cur - 1;
total += cur_contrib;
}
}
}
/// Builds the encoding table (code values plus lengths) from code lengths.
///
/// Starting values are derived per length from the longest code downwards:
/// the longest length starts at 0, and each shorter length starts at
/// `(start[len + 1] + count[len + 1]) >> 1`. Within one length, symbols
/// receive consecutive code values in ascending symbol order. Symbols with
/// length 0 keep code 0.
///
/// # Panics
///
/// Panics if any length exceeds `HUF_MAX_BITS`, because the per-length
/// counters are sized for that limit.
pub fn build_huff_encoder(lengths: &HuffLengths) -> HuffEncoder {
    let longest = lengths.iter().copied().max().unwrap_or(0);

    // Number of symbols using each code length.
    let mut per_length = [0u32; (HUF_MAX_BITS as usize) + 1];
    for &len in lengths.iter().filter(|&&len| len > 0) {
        per_length[usize::from(len)] += 1;
    }

    // First code value for each length; the slot for the longest length
    // stays at its initial 0.
    let mut cursor = [0u32; (HUF_MAX_BITS as usize) + 2];
    for len in (1..usize::from(longest)).rev() {
        cursor[len] = (cursor[len + 1] + per_length[len + 1]) >> 1;
    }

    // Each length has its own counter, so a single ascending pass over the
    // symbols hands out the same values as scanning once per length.
    let mut codes = [0u16; 256];
    for (slot, &len) in codes.iter_mut().zip(lengths.iter()) {
        if len == 0 {
            continue;
        }
        let next = &mut cursor[usize::from(len)];
        *slot = *next as u16;
        *next += 1;
    }

    HuffEncoder {
        codes,
        lengths: *lengths,
    }
}
/// Converts code lengths into weights and reports the longest code length.
///
/// A present symbol of length `l` gets weight `max_len + 1 - l`; an absent
/// symbol gets weight 0. Weights are emitted for every symbol *before* the
/// last present one, so the last present symbol itself is not included.
///
/// Returns `(weights, max_len)`. When no symbol is present the result is an
/// empty vector and 0.
pub fn lengths_to_weights(lengths: &HuffLengths) -> (Vec<u8>, u8) {
    let max_num_bits = lengths.iter().copied().max().unwrap_or(0);
    // Index of the highest symbol that has a code (0 when there is none).
    let last_present = lengths.iter().rposition(|&len| len > 0).unwrap_or(0);

    let weights = lengths[..last_present]
        .iter()
        .map(|&len| if len == 0 { 0 } else { max_num_bits + 1 - len })
        .collect();

    (weights, max_num_bits)
}
/// Serialises weights in the direct (uncompressed) form: one header byte
/// followed by the weights packed two per byte.
///
/// The header byte is `127 + weights.len()`. Each following byte carries one
/// weight in its high nibble and the next weight in its low nibble; when the
/// count is odd the final low nibble is 0. Only the low four bits of each
/// weight are kept.
///
/// Debug builds assert that at most 128 weights are supplied.
pub fn encode_huff_tree_direct(weights: &[u8]) -> Vec<u8> {
    debug_assert!(
        weights.len() <= 128,
        "direct encoding limited to 128 weights (got {})",
        weights.len()
    );
    let count = weights.len();
    let mut encoded = Vec::with_capacity(1 + count.div_ceil(2));
    encoded.push(127 + count as u8);
    encoded.extend(weights.chunks(2).map(|pair| {
        let high = pair[0] & 0x0F;
        // A trailing single weight is padded with a zero low nibble.
        let low = pair.get(1).map_or(0, |&w| w & 0x0F);
        (high << 4) | low
    }));
    encoded
}
/// Encodes Huffman weights with FSE.
///
/// The result is a single size byte followed by the FSE table header and the
/// bitstream; the size byte holds the combined length of header and bitstream.
///
/// Returns `None` when fewer than two weights are given, when any weight is 12
/// or larger, when `build_normalised_counts` fails, or when header plus
/// bitstream would be 128 bytes or more (so the size byte is always below 128).
pub fn encode_huff_tree_fse(weights: &[u8]) -> Option<Vec<u8>> {
use crate::zstd::encoder_fse::{FseEncoder, build_normalised_counts, encode_fse_table_header};
let n = weights.len();
if n < 2 {
return None;
}
// Histogram over weight values; `WALPHA` is the exclusive upper bound on a weight.
const WALPHA: usize = 12; let mut hist = [0u32; WALPHA];
let mut max_w = 0usize;
for &w in weights {
let w = w as usize;
if w >= WALPHA {
return None;
}
hist[w] += 1;
if w > max_w {
max_w = w;
}
}
let max_symbol = max_w;
// Start at accuracy log 6 and drop to 5 when the table (64 entries) is more
// than four times larger than max(weight count, distinct weight values).
let mut accuracy_log: u8 = 6;
let distinct = hist.iter().filter(|&&c| c > 0).count();
while accuracy_log > 5 && (1u32 << accuracy_log) > (n as u32).max(distinct as u32) * 4 {
accuracy_log -= 1;
}
if accuracy_log < 5 {
accuracy_log = 5;
}
let counts = build_normalised_counts(&hist[..=max_symbol], n as u32, accuracy_log)?;
let header = encode_fse_table_header(&counts, accuracy_log);
let enc = FseEncoder::from_normalized(&counts, accuracy_log);
// Two interleaved states: `s1` carries the even-indexed weights and `s2` the
// odd-indexed ones. Each state is seeded with the highest-indexed weight of
// its parity.
let last_even = (n - 1).is_multiple_of(2);
let s1_high = if last_even { n - 1 } else { n - 2 };
let s2_high = if last_even { n - 2 } else { n - 1 };
let mut writer = RevBitWriter::new();
let mut s1 = enc.init_state(weights[s1_high] as usize);
let mut s2 = enc.init_state(weights[s2_high] as usize);
// Signed cursors so that stepping below index 0 marks a finished state.
let mut i1: isize = s1_high as isize - 2;
let mut i2: isize = s2_high as isize - 2;
loop {
if i1 < 0 && i2 < 0 {
break;
}
// Always take the larger remaining index, so the weights are consumed in
// strictly descending index order.
if i1 >= i2 {
s1 = enc.encode_symbol(s1, weights[i1 as usize] as usize, &mut writer);
i1 -= 2;
} else {
s2 = enc.encode_symbol(s2, weights[i2 as usize] as usize, &mut writer);
i2 -= 2;
}
}
// NOTE(review): `s2` is flushed before `s1`; presumably the decoder reads the
// stream in reverse and therefore sees `s1` first — confirm against the decoder.
enc.write_final_state(s2, &mut writer);
enc.write_final_state(s1, &mut writer);
let bitstream = writer.finish();
let mut payload = Vec::with_capacity(1 + header.len() + bitstream.len());
let fse_len = header.len() + bitstream.len();
// The size byte must stay below 128.
if fse_len >= 128 {
return None;
}
payload.push(fse_len as u8);
payload.extend_from_slice(&header);
payload.extend_from_slice(&bitstream);
Some(payload)
}
/// Huffman-encodes `data` into a single bitstream.
///
/// Symbols are fed to the bit writer from the last byte to the first, and the
/// finished writer's bytes are returned.
pub fn encode_huff_stream(enc: &HuffEncoder, data: &[u8]) -> Vec<u8> {
    let mut bits = RevBitWriter::new();
    data.iter()
        .rev()
        .for_each(|&symbol| enc.encode_symbol(&mut bits, symbol));
    bits.finish()
}
/// Splits `data` into four consecutive segments and Huffman-encodes each one
/// independently with `encode_huff_stream`.
///
/// The first three segments each hold `ceil(data.len() / 4)` bytes and the
/// fourth holds whatever remains. The streams are returned in segment order.
///
/// # Panics
///
/// Panics when `3 * ceil(data.len() / 4)` exceeds `data.len()`, which is the
/// case for inputs of length 1, 2 and 5.
pub fn encode_huff_4streams(
    enc: &HuffEncoder,
    data: &[u8],
) -> (Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>) {
    let quarter = data.len().div_ceil(4);
    let tail_len = data.len() - 3 * quarter;

    let (first, rest) = data.split_at(quarter);
    let (second, rest) = rest.split_at(quarter);
    let (third, rest) = rest.split_at(quarter);
    let fourth = &rest[..tail_len];

    (
        encode_huff_stream(enc, first),
        encode_huff_stream(enc, second),
        encode_huff_stream(enc, third),
        encode_huff_stream(enc, fourth),
    )
}
/// Returns the total number of bits needed to encode data with the given
/// symbol frequencies: the sum over all symbols of `length * frequency`.
///
/// Symbols with length 0 contribute nothing.
pub fn predicted_bits(lengths: &HuffLengths, freq: &[u32; 256]) -> u64 {
    lengths
        .iter()
        .zip(freq.iter())
        .map(|(&len, &count)| u64::from(len) * u64::from(count))
        .sum()
}
/// Counts how many times each byte value occurs in `data`.
///
/// The returned table is indexed by byte value.
pub fn histogram(data: &[u8]) -> [u32; 256] {
    let mut counts = [0u32; 256];
    data.iter()
        .for_each(|&byte| counts[usize::from(byte)] += 1);
    counts
}
// Unit tests: round trips through the crate's decoder plus direct checks of
// `cap_code_lengths`.
#[cfg(test)]
mod tests {
use super::*;
use crate::zstd::bitreader::RevBitReader;
use crate::zstd::huffman::decode_huffman_tree;
// Builds lengths for `freq`, asserts that their Kraft sum equals exactly
// `2^max_len`, and returns the encoder together with the direct-form tree bytes.
fn round_trip_huff(freq: &[u32; 256]) -> (HuffEncoder, Vec<u8>) {
let lengths = build_huff_lengths(freq).unwrap();
let mut max_len = 0u8;
for &l in &lengths {
if l > max_len {
max_len = l;
}
}
let mut kraft: u64 = 0;
for &l in &lengths {
if l > 0 {
kraft += 1u64 << (max_len - l);
}
}
assert_eq!(kraft, 1u64 << max_len, "Kraft not satisfied");
let enc = build_huff_encoder(&lengths);
let (weights, _max) = lengths_to_weights(&lengths);
let tree_bytes = encode_huff_tree_direct(&weights);
(enc, tree_bytes)
}
// Four-symbol alphabet: encode a short message and decode it with the table
// reconstructed from the serialised tree.
#[test]
fn simple_huff_round_trip() {
let mut freq = [0u32; 256];
freq[b'a' as usize] = 10;
freq[b'b' as usize] = 5;
freq[b'c' as usize] = 3;
freq[b'd' as usize] = 2;
let (enc, tree_bytes) = round_trip_huff(&freq);
let (dec_table, _consumed) = decode_huffman_tree(&tree_bytes).unwrap();
let symbols: &[u8] = b"abcdabcdab";
let stream = encode_huff_stream(&enc, symbols);
let mut br = RevBitReader::new(&stream).unwrap();
let mut decoded: Vec<u8> = Vec::new();
for _ in 0..symbols.len() {
decoded.push(dec_table.decode(&mut br).unwrap());
}
assert_eq!(decoded, symbols);
}
// Same round trip with frequencies taken from English text.
#[test]
fn larger_alphabet_round_trip() {
let text = b"the quick brown fox jumps over the lazy dog. the lazy dog sleeps.";
let mut freq = [0u32; 256];
for &b in text {
freq[b as usize] += 1;
}
let (enc, tree_bytes) = round_trip_huff(&freq);
let (dec_table, _consumed) = decode_huffman_tree(&tree_bytes).unwrap();
let stream = encode_huff_stream(&enc, text);
let mut br = RevBitReader::new(&stream).unwrap();
let mut decoded: Vec<u8> = Vec::new();
for _ in 0..text.len() {
decoded.push(dec_table.decode(&mut br).unwrap());
}
assert_eq!(decoded, text);
}
// Encodes with `encode_huff_4streams` and decodes each stream separately,
// using the same segment sizes the encoder computes.
#[test]
fn four_stream_round_trip() {
let mut input: Vec<u8> = Vec::new();
for _ in 0..32 {
input.extend_from_slice(b"Lorem ipsum dolor sit amet, consectetur adipiscing elit. ");
}
let freq = histogram(&input);
let lengths = build_huff_lengths(&freq).unwrap();
let enc = build_huff_encoder(&lengths);
let (weights, _) = lengths_to_weights(&lengths);
let tree_bytes = encode_huff_tree_direct(&weights);
let (dec_table, _) = decode_huffman_tree(&tree_bytes).unwrap();
let (s1, s2, s3, s4) = encode_huff_4streams(&enc, &input);
let regen = input.len();
let per = regen.div_ceil(4);
let last = regen - 3 * per;
let mut out: Vec<u8> = Vec::new();
for (stream_bytes, n) in [(s1, per), (s2, per), (s3, per), (s4, last)].into_iter() {
let mut br = RevBitReader::new(&stream_bytes).unwrap();
for _ in 0..n {
out.push(dec_table.decode(&mut br).unwrap());
}
}
assert_eq!(out, input);
}
// 200 present symbols give more than 128 weights, which is beyond what the
// direct form allows, so the FSE form is exercised.
#[test]
fn fse_weights_round_trip() {
use crate::zstd::huffman::decode_huffman_tree_weights_for_test;
let mut freq = [0u32; 256];
for b in 0u32..200 {
freq[b as usize] = 200 - b + 1;
}
let lengths = build_huff_lengths(&freq).unwrap();
let (weights, _max) = lengths_to_weights(&lengths);
assert!(weights.len() > 128, "test needs > 128 weights");
let payload = encode_huff_tree_fse(&weights).expect("fse weight encode");
let decoded = decode_huffman_tree_weights_for_test(&payload).unwrap();
assert_eq!(decoded, weights, "FSE weight round-trip mismatch");
}
// The FSE encoder may decline (`None`) for this input; the round trip is
// only checked when it produces a payload.
#[test]
fn fse_weights_round_trip_small_alphabet() {
use crate::zstd::huffman::decode_huffman_tree_weights_for_test;
let text =
b"the quick brown fox jumps over the lazy dog. pack my box with five dozen liquor jugs.";
let mut freq = [0u32; 256];
for &b in text {
freq[b as usize] += 1;
}
let lengths = build_huff_lengths(&freq).unwrap();
let (weights, _max) = lengths_to_weights(&lengths);
if let Some(payload) = encode_huff_tree_fse(&weights) {
let decoded = decode_huffman_tree_weights_for_test(&payload).unwrap();
assert_eq!(decoded, weights);
}
}
// Two codes of length 1 already meet the budget exactly, so nothing changes.
#[test]
fn cap_code_lengths_idempotent_under_limit() {
let mut lengths = [0u8; 256];
lengths[0] = 1;
lengths[1] = 1;
cap_code_lengths(&mut lengths, 11);
assert_eq!(lengths[0], 1);
assert_eq!(lengths[1], 1);
}
// Lengths above the limit must come back at or below it.
#[test]
fn cap_code_lengths_caps_long_codes() {
let mut lengths = [0u8; 256];
lengths[0] = 1;
lengths[1] = 15;
lengths[2] = 15;
cap_code_lengths(&mut lengths, 11);
assert!(lengths[1] <= 11 && lengths[2] <= 11);
}
}