use super::inflate::MAX_NUM_LIT;
use crate::math::bits;
use std::sync::OnceLock;
/// One entry of a Huffman code table: the code bits for a symbol and their length.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub(super) struct HCode {
    /// Code bits. Every producer in this module stores them bit-reversed via `reverse_bits`.
    pub code: u16,
    /// Code length in bits; 0 means the symbol currently has no code.
    pub len: u16,
}
/// A Huffman code table together with the scratch state used to (re)build it.
pub(super) struct HuffmanEncoder {
    /// Code table indexed by symbol value.
    pub codes: Vec<HCode>,
    /// `bit_count[k]` holds how many symbols were given a `k`-bit code by the last
    /// length computation; only entries up to that run's maximum length are meaningful.
    bit_count: [u32; 17],
    /// Scratch list of (symbol, frequency) nodes, allocated on first use and reused.
    freqcache: Option<Vec<LiteralNode>>,
}
/// A (symbol, frequency) pair used while building a Huffman code.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
struct LiteralNode {
    /// Symbol value; used as an index into `HuffmanEncoder::codes`.
    literal: u16,
    /// Frequency of `literal` as supplied to the code builder.
    freq: i32,
}
/// Per-level bookkeeping for the length-limited code-length computation.
#[derive(Debug, Default, Clone, Copy)]
struct LevelInfo {
    /// Index of this level (1 ..= max code length).
    level: u32,
    /// Frequency of the most recently taken item on this level.
    last_freq: i32,
    /// Frequency of the next unused leaf this level could take.
    next_char_freq: i32,
    /// Combined frequency of the next pair offered by the level below.
    next_pair_freq: i32,
    /// Number of items this level still has to produce.
    needed: u32,
}
impl HCode {
fn set(&mut self, code: u16, length: u16) {
self.len = length;
self.code = code;
}
}
/// Sentinel node with the largest possible symbol and frequency.
///
/// It is written one slot past the last real leaf before code lengths are computed.
/// There its `i32::MAX` frequency reads as "no leaf left".
fn max_node() -> LiteralNode {
    let literal = u16::MAX;
    let freq = i32::MAX;
    LiteralNode { literal, freq }
}
/// Creates an encoder with `size` default (zero-length) code slots.
///
/// The scratch buffer is not allocated yet.
pub(super) fn new_huffman_encoder(size: usize) -> HuffmanEncoder {
    let codes = vec![HCode::default(); size];
    HuffmanEncoder {
        codes,
        bit_count: [0; 17],
        freqcache: None,
    }
}
/// Builds the fixed literal/length code table for the first `MAX_NUM_LIT` symbols.
///
/// Symbols 0-143 get 8-bit codes starting at 48, and 144-255 get 9-bit codes starting at 400.
/// Symbols 256-279 get 7-bit codes starting at 0, and 280 and above get 8-bit codes
/// starting at 192. This layout matches the fixed Huffman table of RFC 1951 §3.2.6.
fn generate_fixed_literal_encoding() -> HuffmanEncoder {
    let mut encoder = new_huffman_encoder(MAX_NUM_LIT);
    for (ch, slot) in (0u16..).zip(encoder.codes.iter_mut()) {
        // (canonical code value, code length) for this symbol.
        let (bits, size): (u16, u16) = match ch {
            0..=143 => (ch + 48, 8),
            144..=255 => (ch + 400 - 144, 9),
            256..=279 => (ch - 256, 7),
            _ => (ch + 192 - 280, 8),
        };
        *slot = HCode {
            code: reverse_bits(bits, size as u8),
            len: size,
        };
    }
    encoder
}
/// Builds the fixed distance code table: 30 symbols, each coded by its own value in 5 bits.
fn generate_fixed_offset_encoding() -> HuffmanEncoder {
    let mut encoder = new_huffman_encoder(30);
    for (ch, slot) in (0u16..).zip(encoder.codes.iter_mut()) {
        *slot = HCode {
            code: reverse_bits(ch, 5),
            len: 5,
        };
    }
    encoder
}
/// Returns the shared fixed literal/length encoder.
///
/// The encoder is built on the first call and cached for the rest of the program.
pub(super) fn get_fixed_literal_encoding() -> &'static HuffmanEncoder {
    static FIXED_LITERAL: OnceLock<HuffmanEncoder> = OnceLock::new();
    FIXED_LITERAL.get_or_init(generate_fixed_literal_encoding)
}
/// Returns the shared fixed distance encoder.
///
/// The encoder is built on the first call and cached for the rest of the program.
pub(super) fn get_fixed_offset_encoding() -> &'static HuffmanEncoder {
    static FIXED_OFFSET: OnceLock<HuffmanEncoder> = OnceLock::new();
    FIXED_OFFSET.get_or_init(generate_fixed_offset_encoding)
}
/// Exclusive upper bound on the `max_bits` accepted when building a code.
///
/// It is also the side length of the per-level scratch tables in `bit_counts`.
const MAX_BITS_LIMIT: usize = 16;
impl HuffmanEncoder {
    /// Returns the total number of bits needed to encode a stream with symbol
    /// histogram `freq` using the current `codes`.
    ///
    /// The total is the sum of `freq[i] * codes[i].len` over all non-zero `freq[i]`.
    ///
    /// # Panics
    ///
    /// Panics if a non-zero entry of `freq` lies beyond the end of `self.codes`.
    pub fn bit_length(&self, freq: &[i32]) -> usize {
        freq.iter()
            .enumerate()
            .filter(|&(_, &f)| f != 0)
            .map(|(i, &f)| f as usize * usize::from(self.codes[i].len))
            .sum()
    }

    /// Computes how many literals receive each code length, with lengths capped at `max_bits`.
    ///
    /// Works on the first `list_size` entries of `freqcache`. It expects them sorted by
    /// ascending frequency, as `generate` arranges, and requires `list_size >= 3`.
    /// It writes `self.bit_count[k]` for `k` in `1..=max_bits`, where `k` is the code length.
    ///
    /// Returns the effective maximum code length: `max_bits` clamped to `list_size - 1`.
    ///
    /// # Panics
    ///
    /// Panics if `max_bits >= MAX_BITS_LIMIT`, or if the final leaf-count
    /// consistency check fails.
    fn bit_counts(&mut self, list_size: usize, max_bits: usize) -> usize {
        assert!(max_bits < MAX_BITS_LIMIT, "flate: max_bits too large");
        let n = list_size;
        let list = &mut self
            .freqcache
            .as_mut()
            .expect("freqcache is allocated by generate")[..n + 1];
        // Sentinel after the last real leaf: its i32::MAX frequency means "no leaf left".
        list[n] = max_node();

        // A tree over n leaves is never deeper than n - 1.
        let max_bits = max_bits.min(n - 1);

        // levels[0] is a dummy level whose `needed` stays 0, so level 1 never pulls
        // pairs from below it.
        let mut levels = [LevelInfo::default(); MAX_BITS_LIMIT];
        // leaf_counts[i][j]: leaves to the left of the level-j ancestor of the
        // rightmost node currently on level i.
        let mut leaf_counts = [[0usize; MAX_BITS_LIMIT]; MAX_BITS_LIMIT];
        for level in 1..=max_bits {
            // Every level starts as if it had already taken the two cheapest leaves.
            levels[level] = LevelInfo {
                level: level as u32,
                last_freq: list[1].freq,
                next_char_freq: list[2].freq,
                next_pair_freq: list[0].freq + list[1].freq,
                needed: 0,
            };
            leaf_counts[level][level] = 2;
            if level == 1 {
                // Level 1 has no real level below it to form pairs from.
                levels[level].next_pair_freq = i32::MAX;
            }
        }

        // The top level needs 2n - 2 items in total and already has 2.
        levels[max_bits].needed = 2 * n as u32 - 4;

        let mut level = max_bits;
        loop {
            if levels[level].next_pair_freq == i32::MAX && levels[level].next_char_freq == i32::MAX
            {
                // Out of both leaves and pairs: close this level and make sure the
                // level above never pulls from it again.
                levels[level].needed = 0;
                levels[level + 1].next_pair_freq = i32::MAX;
                level += 1;
                continue;
            }
            let prev_freq = levels[level].last_freq;
            if levels[level].next_char_freq < levels[level].next_pair_freq {
                // The next item on this level is a leaf.
                let next_leaf = leaf_counts[level][level] + 1;
                levels[level].last_freq = levels[level].next_char_freq;
                leaf_counts[level][level] = next_leaf;
                levels[level].next_char_freq = list[next_leaf].freq;
            } else {
                // The next item is a pair from the level below. Inherit that level's
                // lower leaf counts (counts[level] itself is unchanged), then ask the
                // level below for two more items.
                levels[level].last_freq = levels[level].next_pair_freq;
                let (lower, upper) = leaf_counts.split_at_mut(level);
                upper[0][..level].copy_from_slice(&lower[level - 1][..level]);
                let below = levels[level].level as usize - 1;
                levels[below].needed = 2;
            }
            levels[level].needed -= 1;
            if levels[level].needed == 0 {
                if levels[level].level == max_bits as u32 {
                    break;
                }
                // Offer the two items just produced as a pair to the level above.
                let above = levels[level].level as usize + 1;
                levels[above].next_pair_freq = prev_freq + levels[level].last_freq;
                level += 1;
            } else {
                // We took from below; descend to replenish it.
                while levels[level - 1].needed > 0 {
                    level -= 1;
                }
            }
        }
        assert!(
            leaf_counts[max_bits][max_bits] == n,
            "leafCounts[max_bits][max_bits] != n"
        );

        // bit_count[bits] for bits = max_bits + 1 - level.
        let counts = &leaf_counts[max_bits];
        for (bits, level) in (1..).zip((1..=max_bits).rev()) {
            self.bit_count[bits] = (counts[level] - counts[level - 1]) as u32;
        }
        max_bits
    }

    /// Assigns canonical codes to the first `list_size` nodes of `freqcache`, using
    /// the per-length counts in `self.bit_count[..=max_bits]`.
    ///
    /// The list is sorted by ascending frequency, so the tail (the most frequent
    /// literals) receives the shortest lengths. Within one length, consecutive code
    /// values are handed out in ascending literal order. Codes are stored bit-reversed.
    fn assign_encoding_and_size(&mut self, max_bits: usize, list_size: usize) {
        let bit_count = &self.bit_count[..max_bits + 1];
        let list = &mut self
            .freqcache
            .as_mut()
            .expect("freqcache is allocated by generate")[..list_size];
        // Literals in list[end..] already have codes.
        let mut end = list.len();
        let mut code = 0_u16;
        for (n, &bits) in bit_count.iter().enumerate() {
            code <<= 1;
            if n == 0 || bits == 0 {
                continue;
            }
            let start = end - bits as usize;
            let chunk = &mut list[start..end];
            // Literals are unique, so an unstable sort yields the same order.
            chunk.sort_unstable_by_key(|node| node.literal);
            for node in chunk.iter() {
                self.codes[node.literal as usize] = HCode {
                    code: reverse_bits(code, n as u8),
                    len: n as u16,
                };
                code += 1;
            }
            end = start;
        }
    }

    /// Builds a Huffman code with lengths limited to `max_bits` for the symbol
    /// frequencies in `freq`, and stores it in `self.codes`.
    ///
    /// Symbols with zero frequency get `len = 0`. When at most two symbols are
    /// used, each gets a 1-bit code (0 and 1, in ascending symbol order).
    ///
    /// # Panics
    ///
    /// Panics if:
    /// - `freq.len()` exceeds `self.codes.len()`;
    /// - more non-zero frequencies are given than the scratch buffer holds;
    /// - more than two symbols are used and `max_bits >= MAX_BITS_LIMIT`.
    pub fn generate(&mut self, freq: &[i32], max_bits: usize) {
        // Reusable scratch buffer sized for the largest alphabet, plus one slot for
        // the sentinel written by `bit_counts`.
        let freqcache = self
            .freqcache
            .get_or_insert_with(|| vec![LiteralNode::default(); MAX_NUM_LIT + 1]);
        let mut list_size = 0;
        for (i, &f) in freq.iter().enumerate() {
            if f != 0 {
                freqcache[list_size] = LiteralNode {
                    literal: i as u16,
                    freq: f,
                };
                list_size += 1;
            } else {
                self.codes[i].len = 0;
            }
        }
        let list = &mut freqcache[..list_size];
        if list_size <= 2 {
            // Too small for the general algorithm. `list` is still in ascending
            // symbol order here.
            for (i, node) in list.iter().enumerate() {
                self.codes[node.literal as usize].set(i as u16, 1);
            }
            return;
        }
        // Must stay a stable sort: equal frequencies keep ascending symbol order.
        list.sort_by_key(|node| node.freq);
        let max_bits = self.bit_counts(list_size, max_bits);
        self.assign_encoding_and_size(max_bits, list_size);
    }
}
/// Reverses the order of the low `bit_length` bits of `number`.
///
/// Bits of `number` above `bit_length` are discarded. The result occupies the low
/// `bit_length` bits. NOTE(review): this assumes `bits::reverse16` reverses all
/// 16 bits of its argument; confirm.
///
/// Returns 0 when `bit_length` is 0 or greater than 16: no `u16` code has such a length.
pub(super) fn reverse_bits(number: u16, bit_length: u8) -> u16 {
    match bit_length {
        // Move the code to the top of the word so that a full 16-bit reversal
        // lands it in the low `bit_length` bits.
        1..=16 => bits::reverse16(number << (16 - bit_length)),
        // Computing `number << (16 - bit_length)` here would be wrong in both cases:
        // - `u16 << 16` overflows: it panics in debug and shifts by 0 in release,
        //   giving a bogus non-zero result.
        // - `16 - bit_length` underflows for lengths above 16.
        _ => 0,
    }
}