use crate::bits::BitWriter;
use crate::deflate::tables::{encode_distance, encode_length, CODE_LENGTH_ORDER};
use crate::deflate::tokens::LZ77Token;
use crate::error::Result;
/// Longest code length requested from `compute_code_lengths` for the literal/length and
/// distance alphabets of a dynamic block.
const MAX_CODE_LENGTH: u8 = 15;
/// Longest code length requested from `compute_code_lengths` for the 19-symbol code-length
/// alphabet; the dynamic header writes each of those lengths in a 3-bit field.
const MAX_CL_CODE_LENGTH: u8 = 7;
/// Occurrence counts gathered from a token stream, used to build dynamic Huffman codes.
#[derive(Debug, Clone)]
pub struct FrequencyCounter {
    /// Counts per literal/length symbol: byte values at `0..=255`, end-of-block at `256`,
    /// length codes in the remaining slots.
    pub literal_freq: [u32; 286],
    /// Counts per distance code.
    pub distance_freq: [u32; 30],
}
impl FrequencyCounter {
pub fn new() -> Self {
Self { literal_freq: [0; 286], distance_freq: [0; 30] }
}
pub fn count_tokens(&mut self, tokens: &[LZ77Token]) {
for token in tokens {
match token {
LZ77Token::Literal(byte) => {
self.literal_freq[*byte as usize] += 1;
}
LZ77Token::Copy { length, distance } => {
if let Some((len_code, _, _)) = encode_length(*length) {
self.literal_freq[len_code as usize] += 1;
}
if let Some((dist_code, _, _)) = encode_distance(*distance) {
self.distance_freq[dist_code as usize] += 1;
}
}
LZ77Token::EndOfBlock => {
self.literal_freq[256] += 1;
}
}
}
if self.literal_freq[256] == 0 {
self.literal_freq[256] = 1;
}
}
pub fn num_literal_codes(&self) -> usize {
let mut last = 256; for i in (257..286).rev() {
if self.literal_freq[i] > 0 {
last = i;
break;
}
}
last + 1
}
pub fn num_distance_codes(&self) -> usize {
let mut last = 0;
for i in (0..30).rev() {
if self.distance_freq[i] > 0 {
last = i;
break;
}
}
(last + 1).max(1)
}
}
impl Default for FrequencyCounter {
fn default() -> Self {
Self::new()
}
}
/// Computes a Huffman code length for every entry of `frequencies`.
///
/// The result has the same length as `frequencies`. Symbols with frequency zero get length
/// zero. When only one or two symbols are in use, each gets length one; otherwise the lengths
/// come from `build_huffman_lengths` and are then passed through `limit_code_lengths` with
/// `max_bits`.
pub fn compute_code_lengths(frequencies: &[u32], max_bits: u8) -> Vec<u8> {
    let total = frequencies.len();
    let used: Vec<(usize, u32)> = frequencies
        .iter()
        .copied()
        .enumerate()
        .filter(|&(_, freq)| freq > 0)
        .collect();

    match used.len() {
        // Nothing in use (this also covers an empty input slice).
        0 => vec![0u8; total],
        // One or two symbols: a single bit each is all that is needed.
        1 | 2 => {
            let mut lengths = vec![0u8; total];
            for &(symbol, _) in &used {
                lengths[symbol] = 1;
            }
            lengths
        }
        _ => {
            let mut lengths = build_huffman_lengths(&used, total);
            limit_code_lengths(&mut lengths, &used, max_bits);
            lengths
        }
    }
}
/// Returns unrestricted Huffman code lengths for `symbols` in a table of `n` entries.
///
/// `symbols` holds `(symbol index, frequency)` pairs; every index must be below `n`.
/// Entries of the result that are not listed in `symbols` stay zero.
///
/// The tree construction lives entirely in `compute_depths_bfs`. An earlier version of this
/// function first ran a second, throw-away Huffman merge (concatenating symbol vectors at
/// every step) only to learn whether `symbols` was empty; `compute_depths_bfs` already
/// leaves `lengths` untouched for empty input, so that pass was pure overhead.
fn build_huffman_lengths(symbols: &[(usize, u32)], n: usize) -> Vec<u8> {
    let mut lengths = vec![0u8; n];
    compute_depths_bfs(symbols, &mut lengths);
    lengths
}
/// Builds a Huffman tree over `symbols` and stores each symbol's depth in `lengths`.
///
/// `symbols` holds `(symbol index, frequency)` pairs. The two lightest subtrees are merged
/// repeatedly until one tree remains; each leaf then receives its depth as code length,
/// with a minimum of one so that a lone symbol still gets a one-bit code. Empty input
/// leaves `lengths` untouched.
fn compute_depths_bfs(symbols: &[(usize, u32)], lengths: &mut [u8]) {
    use std::cmp::{Ordering, Reverse};
    use std::collections::BinaryHeap;

    enum Tree {
        Symbol(usize),
        Branch(Box<Tree>, Box<Tree>),
    }

    /// A subtree together with the total frequency of its leaves; ordered by weight only.
    struct Weighted {
        weight: u64,
        tree: Tree,
    }

    impl PartialEq for Weighted {
        fn eq(&self, rhs: &Self) -> bool {
            self.weight == rhs.weight
        }
    }

    impl Eq for Weighted {}

    impl PartialOrd for Weighted {
        fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
            Some(self.cmp(rhs))
        }
    }

    impl Ord for Weighted {
        fn cmp(&self, rhs: &Self) -> Ordering {
            self.weight.cmp(&rhs.weight)
        }
    }

    // `Reverse` turns the max-heap into a min-heap on weight.
    let mut queue: BinaryHeap<Reverse<Weighted>> = symbols
        .iter()
        .map(|&(sym, freq)| Reverse(Weighted { weight: freq as u64, tree: Tree::Symbol(sym) }))
        .collect();

    // Merge the two lightest subtrees until a single root is left.
    let root = loop {
        let lighter = match queue.pop() {
            Some(Reverse(item)) => item,
            None => return,
        };
        let heavier = match queue.pop() {
            Some(Reverse(item)) => item,
            None => break lighter.tree,
        };
        queue.push(Reverse(Weighted {
            weight: lighter.weight + heavier.weight,
            tree: Tree::Branch(Box::new(lighter.tree), Box::new(heavier.tree)),
        }));
    };

    // Depth-first walk with an explicit stack; the right child is pushed first so that the
    // left subtree is visited first.
    let mut pending: Vec<(&Tree, u8)> = vec![(&root, 0)];
    while let Some((node, depth)) = pending.pop() {
        match node {
            Tree::Symbol(sym) => lengths[*sym] = depth.max(1),
            Tree::Branch(left, right) => {
                pending.push((&**right, depth + 1));
                pending.push((&**left, depth + 1));
            }
        }
    }
}
/// Rewrites `lengths` so that no code is longer than `max_bits` while keeping the code
/// complete (Kraft sum exactly one) whenever the input lengths came from a Huffman tree.
///
/// `symbols` holds the `(symbol index, frequency)` pairs that are in use. Nothing happens if
/// every length already fits. Otherwise:
///
/// 1. Lengths above the limit are clamped to `max_bits`, which over-subscribes the code.
/// 2. The over-subscription is measured exactly, in units of `2^-max_bits`. Each repair
///    step moves the deepest leaf shorter than the limit one level down and pairs it with
///    one of the clamped leaves, which removes exactly one unit, so the loop runs exactly
///    that many times. (Counting one step per clamped leaf, as an earlier version did,
///    runs too many steps and yields an incomplete code that inflaters reject.)
/// 3. The resulting per-length counts are handed out shortest-first to the symbols ordered
///    by descending frequency, ties broken by ascending symbol index.
///
/// Limits outside `1..=32` are left untouched: zero bits cannot encode anything, and the
/// code tables in this file store codes in a `u32`.
fn limit_code_lengths(lengths: &mut [u8], symbols: &[(usize, u32)], max_bits: u8) {
    let max_len = lengths.iter().copied().max().unwrap_or(0);
    if max_len <= max_bits {
        return;
    }
    let limit = max_bits as usize;
    if limit == 0 || limit > 32 {
        return;
    }

    // Histogram of code lengths with everything above the limit clamped onto it.
    let mut bl_count = vec![0u32; limit + 1];
    for &(sym, _) in symbols {
        let len = lengths[sym] as usize;
        if len > 0 {
            bl_count[len.min(limit)] += 1;
        }
    }

    // Kraft sum scaled by 2^limit; a complete code sums to exactly `capacity`.
    let capacity = 1u64 << limit;
    let kraft: u64 = (1..=limit)
        .map(|bits| u64::from(bl_count[bits]) << (limit - bits))
        .sum();
    let mut excess = kraft.saturating_sub(capacity);

    while excess > 0 {
        // Only reachable for input that was over-subscribed before clamping.
        if bl_count[limit] == 0 {
            break;
        }
        match (1..limit).rev().find(|&bits| bl_count[bits] > 0) {
            Some(bits) => {
                bl_count[bits] -= 1; // this leaf becomes an internal node ...
                bl_count[bits + 1] += 2; // ... with itself and a clamped leaf as children
                bl_count[limit] -= 1;
                excess -= 1;
            }
            // Every code already sits at the limit: more symbols than 2^limit slots.
            None => break,
        }
    }

    // Most frequent symbols receive the shortest codes.
    let mut ranked: Vec<(usize, u32)> =
        symbols.iter().copied().filter(|&(sym, _)| lengths[sym] > 0).collect();
    ranked.sort_by(|a, b| b.1.cmp(&a.1).then(a.0.cmp(&b.0)));

    let mut next_symbol = ranked.iter();
    for bits in 1..=limit {
        for _ in 0..bl_count[bits] {
            if let Some(&(sym, _)) = next_symbol.next() {
                lengths[sym] = bits as u8;
            }
        }
    }
}
/// Encodes LZ77 token streams into Huffman-coded blocks.
pub struct HuffmanEncoder {
    /// `true` selects the fixed code tables (block type 1); `false` selects per-block
    /// dynamic codes (block type 2).
    use_fixed: bool,
    /// `(code, bit length)` for each fixed literal/length symbol.
    fixed_lit_codes: Vec<(u32, u8)>,
    /// `(code, bit length)` for each fixed distance symbol.
    fixed_dist_codes: Vec<(u32, u8)>,
}
impl HuffmanEncoder {
pub fn new(use_fixed: bool) -> Self {
let fixed_lit_codes = build_fixed_literal_codes();
let fixed_dist_codes = build_fixed_distance_codes();
Self { use_fixed, fixed_lit_codes, fixed_dist_codes }
}
pub fn encode(&mut self, tokens: &[LZ77Token], is_final: bool) -> Result<Vec<u8>> {
let mut writer = BitWriter::with_capacity(tokens.len() * 2);
writer.write_bit(is_final); if self.use_fixed {
writer.write_bits(1, 2); self.encode_fixed(&mut writer, tokens)?;
} else {
writer.write_bits(2, 2); self.encode_dynamic(&mut writer, tokens)?;
}
Ok(writer.finish())
}
pub(crate) fn fixed_lit_codes(&self) -> &[(u32, u8)] {
&self.fixed_lit_codes
}
pub(crate) fn fixed_dist_codes(&self) -> &[(u32, u8)] {
&self.fixed_dist_codes
}
fn encode_fixed(&self, writer: &mut BitWriter, tokens: &[LZ77Token]) -> Result<()> {
for token in tokens {
match token {
LZ77Token::Literal(byte) => {
let (code, len) = self.fixed_lit_codes[*byte as usize];
writer.write_bits(code, len);
}
LZ77Token::Copy { length, distance } => {
if let Some((len_code, extra_val, extra_bits)) = encode_length(*length) {
let (code, code_len) = self.fixed_lit_codes[len_code as usize];
writer.write_bits(code, code_len);
if extra_bits > 0 {
writer.write_bits(extra_val as u32, extra_bits);
}
}
if let Some((dist_code, extra_val, extra_bits)) = encode_distance(*distance) {
let (code, code_len) = self.fixed_dist_codes[dist_code as usize];
writer.write_bits(code, code_len);
if extra_bits > 0 {
writer.write_bits(extra_val as u32, extra_bits);
}
}
}
LZ77Token::EndOfBlock => {
let (code, len) = self.fixed_lit_codes[256];
writer.write_bits(code, len);
}
}
}
let (code, len) = self.fixed_lit_codes[256];
writer.write_bits(code, len);
Ok(())
}
fn encode_dynamic(&self, writer: &mut BitWriter, tokens: &[LZ77Token]) -> Result<()> {
let mut freq = FrequencyCounter::new();
freq.count_tokens(tokens);
let num_lit = freq.num_literal_codes();
let num_dist = freq.num_distance_codes();
let mut lit_lengths = compute_code_lengths(&freq.literal_freq[..num_lit], MAX_CODE_LENGTH);
let mut dist_lengths =
compute_code_lengths(&freq.distance_freq[..num_dist], MAX_CODE_LENGTH);
if lit_lengths.len() > 256 && lit_lengths[256] == 0 {
lit_lengths[256] = 1;
}
if dist_lengths.iter().all(|&l| l == 0) {
if dist_lengths.is_empty() {
dist_lengths = vec![1];
} else {
dist_lengths[0] = 1;
}
}
let lit_codes = build_codes_from_lengths(&lit_lengths);
let dist_codes = build_codes_from_lengths(&dist_lengths);
self.write_dynamic_header(writer, &lit_lengths, &dist_lengths)?;
self.encode_with_codes(writer, tokens, &lit_codes, &dist_codes)?;
let (code, len) = lit_codes[256];
writer.write_bits(code, len);
Ok(())
}
fn write_dynamic_header(
&self,
writer: &mut BitWriter,
lit_lengths: &[u8],
dist_lengths: &[u8],
) -> Result<()> {
let hlit = lit_lengths.len() - 257; let hdist = dist_lengths.len() - 1;
let combined_lengths: Vec<u8> =
lit_lengths.iter().chain(dist_lengths.iter()).copied().collect();
let rle_encoded = rle_encode_lengths(&combined_lengths);
let mut cl_freq = [0u32; 19];
for &(sym, _) in &rle_encoded {
cl_freq[sym as usize] += 1;
}
let cl_lengths = compute_code_lengths(&cl_freq, MAX_CL_CODE_LENGTH);
let cl_codes = build_codes_from_lengths(&cl_lengths);
let mut hclen = 4usize; for i in (0..19).rev() {
if cl_lengths[CODE_LENGTH_ORDER[i]] > 0 {
hclen = i + 1;
break;
}
}
hclen = hclen.max(4);
writer.write_bits(hlit as u32, 5);
writer.write_bits(hdist as u32, 5);
writer.write_bits((hclen - 4) as u32, 4);
for &sym in CODE_LENGTH_ORDER.iter().take(hclen) {
writer.write_bits(cl_lengths[sym] as u32, 3);
}
for &(sym, extra) in &rle_encoded {
let (code, len) = cl_codes[sym as usize];
writer.write_bits(code, len);
match sym {
16 => writer.write_bits(extra as u32, 2), 17 => writer.write_bits(extra as u32, 3), 18 => writer.write_bits(extra as u32, 7), _ => {}
}
}
Ok(())
}
fn encode_with_codes(
&self,
writer: &mut BitWriter,
tokens: &[LZ77Token],
lit_codes: &[(u32, u8)],
dist_codes: &[(u32, u8)],
) -> Result<()> {
for token in tokens {
match token {
LZ77Token::Literal(byte) => {
let (code, len) = lit_codes[*byte as usize];
writer.write_bits(code, len);
}
LZ77Token::Copy { length, distance } => {
if let Some((len_code, extra_val, extra_bits)) = encode_length(*length) {
let (code, code_len) = lit_codes[len_code as usize];
writer.write_bits(code, code_len);
if extra_bits > 0 {
writer.write_bits(extra_val as u32, extra_bits);
}
}
if let Some((dist_code, extra_val, extra_bits)) = encode_distance(*distance) {
let (code, code_len) = dist_codes[dist_code as usize];
writer.write_bits(code, code_len);
if extra_bits > 0 {
writer.write_bits(extra_val as u32, extra_bits);
}
}
}
LZ77Token::EndOfBlock => {
let (code, len) = lit_codes[256];
writer.write_bits(code, len);
}
}
}
Ok(())
}
}
/// Run-length encodes a sequence of code lengths into `(symbol, extra)` pairs.
///
/// Symbols `0..=15` stand for a literal code length (`extra` is zero). Symbol 16 repeats the
/// previous length 3 to 6 times, symbol 17 emits 3 to 10 zeros and symbol 18 emits 11 to 138
/// zeros; for these, `extra` is the repeat count minus the symbol's minimum.
fn rle_encode_lengths(lengths: &[u8]) -> Vec<(u8, u8)> {
    let mut encoded = Vec::new();
    let mut pos = 0;

    while pos < lengths.len() {
        let value = lengths[pos];
        let run_len = lengths[pos..].iter().take_while(|&&l| l == value).count();
        let mut remaining = run_len;

        if value == 0 {
            while remaining > 0 {
                // (symbol, entries consumed, minimum count of the symbol)
                let (symbol, consumed, base) = if remaining >= 11 {
                    (18u8, remaining.min(138), 11)
                } else if remaining >= 3 {
                    (17u8, remaining.min(10), 3)
                } else {
                    (0u8, 1, 1)
                };
                encoded.push((symbol, (consumed - base) as u8));
                remaining -= consumed;
            }
        } else {
            // A repeat symbol needs a preceding literal length to copy.
            encoded.push((value, 0));
            remaining -= 1;
            while remaining > 0 {
                if remaining < 3 {
                    encoded.push((value, 0));
                    remaining -= 1;
                } else {
                    let repeat = remaining.min(6);
                    encoded.push((16, (repeat - 3) as u8));
                    remaining -= repeat;
                }
            }
        }

        pos += run_len;
    }

    encoded
}
/// Builds the `(code, bit length)` table for the fixed literal/length lengths supplied by
/// `tables::fixed_literal_lengths`.
fn build_fixed_literal_codes() -> Vec<(u32, u8)> {
    build_codes_from_lengths(&super::tables::fixed_literal_lengths())
}
/// Builds the `(code, bit length)` table for the fixed distance lengths supplied by
/// `tables::fixed_distance_lengths`.
fn build_fixed_distance_codes() -> Vec<(u32, u8)> {
    build_codes_from_lengths(&super::tables::fixed_distance_lengths())
}
/// Assigns canonical codes to a table of code lengths.
///
/// Returns one `(code, bit length)` pair per entry of `lengths`. Within each length, codes
/// are consecutive in symbol order, and every code is passed through `reverse_bits` before
/// being stored. Symbols with length zero get `(0, 0)`.
fn build_codes_from_lengths(lengths: &[u8]) -> Vec<(u32, u8)> {
    let longest = lengths.iter().copied().max().unwrap_or(0) as usize;

    // How many symbols use each code length (length zero means "unused").
    let mut per_length = vec![0u32; longest + 1];
    for &len in lengths.iter().filter(|&&len| len > 0) {
        per_length[len as usize] += 1;
    }

    // First code of each length.
    let mut next = vec![0u32; longest + 1];
    let mut running = 0u32;
    for bits in 1..=longest {
        running = (running + per_length[bits - 1]) << 1;
        next[bits] = running;
    }

    lengths
        .iter()
        .map(|&len| {
            if len == 0 {
                return (0u32, 0u8);
            }
            let slot = len as usize;
            let entry = (crate::bits::writer::reverse_bits(next[slot], len), len);
            next[slot] += 1;
            entry
        })
        .collect()
}
// Unit tests for frequency counting, code-length construction, run-length encoding of the
// header, and whole-block encoding (including round trips through flate2's inflater).
#[cfg(test)]
mod tests {
use super::*;
// The fixed literal/length table has 288 entries; checks the bit length at both ends of
// each of its four ranges (0-143, 144-255, 256-279, 280-287).
#[test]
fn test_build_fixed_literal_codes() {
let codes = build_fixed_literal_codes();
assert_eq!(codes.len(), 288);
assert_eq!(codes[0].1, 8); assert_eq!(codes[143].1, 8);
assert_eq!(codes[144].1, 9);
assert_eq!(codes[255].1, 9);
assert_eq!(codes[256].1, 7); assert_eq!(codes[279].1, 7);
assert_eq!(codes[280].1, 8);
assert_eq!(codes[287].1, 8);
}
// Smoke test: a fixed-code block for two literals produces some output.
#[test]
fn test_encode_literals() {
let mut encoder = HuffmanEncoder::new(true);
let tokens = vec![LZ77Token::Literal(b'H'), LZ77Token::Literal(b'i')];
let data = encoder.encode(&tokens, true).unwrap();
assert!(!data.is_empty());
}
// A dynamic block is expected to start with the low three bits 0b101: the final-block
// flag (1) followed by the 2-bit block type 2.
#[test]
fn test_encode_dynamic() {
let mut encoder = HuffmanEncoder::new(false); let tokens = vec![
LZ77Token::Literal(b'H'),
LZ77Token::Literal(b'e'),
LZ77Token::Literal(b'l'),
LZ77Token::Literal(b'l'),
LZ77Token::Literal(b'o'),
];
let data = encoder.encode(&tokens, true).unwrap();
assert!(!data.is_empty());
assert_eq!(data[0] & 0x07, 0x05); }
// Literals are counted per byte value, the end-of-block slot (256) is forced to one, and
// a copy of length 3 / distance 1 is expected to land in length code 257 and distance
// code 0.
#[test]
fn test_frequency_counter() {
let mut freq = FrequencyCounter::new();
let tokens = vec![
LZ77Token::Literal(b'a'),
LZ77Token::Literal(b'a'),
LZ77Token::Literal(b'b'),
LZ77Token::Copy { length: 3, distance: 1 },
];
freq.count_tokens(&tokens);
assert_eq!(freq.literal_freq[b'a' as usize], 2);
assert_eq!(freq.literal_freq[b'b' as usize], 1);
assert_eq!(freq.literal_freq[256], 1); assert_eq!(freq.literal_freq[257], 1);
assert_eq!(freq.distance_freq[0], 1);
}
// Four equally likely symbols: every symbol gets a code of at most 3 bits and the Kraft
// sum does not exceed one (the code is not over-subscribed).
#[test]
fn test_compute_code_lengths() {
let freqs = [1u32, 1, 1, 1];
let lengths = compute_code_lengths(&freqs, 15);
assert!(lengths.iter().all(|&l| l > 0));
assert!(lengths.iter().all(|&l| l <= 3));
let kraft: f64 = lengths.iter().map(|&l| 2f64.powi(-(l as i32))).sum();
assert!(kraft <= 1.0 + 0.001); }
// The dominant symbol must not receive a longer code than any of the rare ones.
#[test]
fn test_compute_code_lengths_skewed() {
let freqs = [100u32, 1, 1, 1];
let lengths = compute_code_lengths(&freqs, 15);
assert!(lengths[0] <= lengths[1]);
assert!(lengths[0] <= lengths[2]);
assert!(lengths[0] <= lengths[3]);
}
// Twenty zeros collapse into a single symbol 18 whose extra value is 20 - 11 = 9.
#[test]
fn test_rle_encode_zeros() {
let lengths = vec![0u8; 20];
let encoded = rle_encode_lengths(&lengths);
assert_eq!(encoded.len(), 1);
assert_eq!(encoded[0].0, 18); assert_eq!(encoded[0].1, 9); }
// A run of a non-zero length starts with the literal length itself, followed by at least
// one more entry for the repeats.
#[test]
fn test_rle_encode_repeat() {
let lengths = vec![5u8; 10];
let encoded = rle_encode_lengths(&lengths);
assert!(encoded.len() >= 2);
assert_eq!(encoded[0].0, 5); }
// Round trip: a literal-only fixed-code block must inflate back to the input with
// flate2's DeflateDecoder.
#[test]
fn test_encode_roundtrip_fixed() {
use std::io::Read;
let input = b"Hello, World! This is a round-trip test for fixed Huffman encoding.";
let tokens: Vec<LZ77Token> = input.iter().map(|&b| LZ77Token::Literal(b)).collect();
let mut encoder = HuffmanEncoder::new(true);
let deflate_data = encoder.encode(&tokens, true).unwrap();
let mut inflated = Vec::new();
flate2::read::DeflateDecoder::new(&deflate_data[..])
.read_to_end(&mut inflated)
.expect("flate2 should inflate fixed Huffman output");
assert_eq!(inflated, input);
}
// Round trip: a literal-only dynamic-code block over 512 bytes cycling through 0..=127
// must inflate back to the input with flate2's DeflateDecoder.
#[test]
fn test_encode_roundtrip_dynamic() {
use std::io::Read;
let input: Vec<u8> = (0u8..=127).cycle().take(512).collect();
let tokens: Vec<LZ77Token> = input.iter().map(|&b| LZ77Token::Literal(b)).collect();
let mut encoder = HuffmanEncoder::new(false);
let deflate_data = encoder.encode(&tokens, true).unwrap();
let mut inflated = Vec::new();
flate2::read::DeflateDecoder::new(&deflate_data[..])
.read_to_end(&mut inflated)
.expect("flate2 should inflate dynamic Huffman output");
assert_eq!(inflated, input);
}
}