use alloc::vec;
use alloc::vec::Vec;
use super::BitWriter;
/// Reverses the order of the lowest `n` bits of `v`.
///
/// Bit 0 of `v` becomes bit `n - 1` of the result, bit 1 becomes bit `n - 2`,
/// and so on. Bits of `v` at position `n` and above are ignored. With
/// `n == 0` the result is 0.
pub(crate) const fn reverse_bits(mut v: u32, n: u32) -> u32 {
    let mut reversed = 0u32;
    let mut remaining = n;
    // `for` loops are not available in `const fn`, hence the manual countdown.
    while remaining != 0 {
        // Move the current lowest bit of `v` onto the low end of the result.
        reversed = (reversed << 1) | (v & 1);
        v >>= 1;
        remaining -= 1;
    }
    reversed
}
/// How one entry of the package-merge pool was formed.
#[derive(Clone, Copy)]
enum PoolKind {
    /// A single symbol; the payload is the symbol's index in `freqs`.
    ///
    /// Stored as `usize` so that indices of any size survive unchanged.
    Coin(usize),
    /// Two entries of the previous level packaged together; the payloads are
    /// indices into the pool.
    Pair(u32, u32),
}

/// One pool entry: its accumulated frequency and how it was formed.
struct PoolElement {
    cost: u64,
    kind: PoolKind,
}

/// Computes prefix-code lengths, none longer than `max_length`, for the
/// symbols whose frequency in `freqs` is non-zero.
///
/// The result has one entry per element of `freqs`. Symbols with frequency 0
/// get length 0. If exactly one symbol has a non-zero frequency it gets
/// length 1; if none do, every length is 0.
///
/// The lengths are found level by level: adjacent pairs of the current
/// cost-sorted list are packaged, the packages are merged with a fresh copy
/// of all symbols, and after `max_length - 1` rounds the cheapest `2n - 2`
/// entries are selected. A symbol's length is the number of selected entries
/// that contain it.
///
/// # Panics
///
/// Panics if `max_length` is not in `1..=15`, or if two or more symbols are
/// used and their count exceeds `2^max_length`.
pub(crate) fn length_limited_huffman(freqs: &[u32], max_length: u8) -> Vec<u8> {
    assert!(
        max_length > 0 && max_length <= 15,
        "max_length must be 1..=15"
    );
    let mut out = vec![0u8; freqs.len()];

    // (frequency, symbol index) for every symbol that actually occurs.
    let mut coins: Vec<(u32, usize)> = freqs
        .iter()
        .enumerate()
        .filter(|&(_, &f)| f > 0)
        .map(|(sym, &f)| (f, sym))
        .collect();
    let n = coins.len();
    if n == 0 {
        return out;
    }
    if n == 1 {
        out[coins[0].1] = 1;
        return out;
    }
    assert!(n <= 1usize << max_length, "alphabet too big for max_length");

    // Stable sort: symbols with equal frequency keep their index order.
    coins.sort_by_key(|&(f, _)| f);

    let mut pool: Vec<PoolElement> = Vec::with_capacity(n * usize::from(max_length) * 2 + 8);

    // First level: just the coins, in ascending cost order.
    let mut current: Vec<u32> = Vec::with_capacity(2 * n);
    for &(f, sym) in &coins {
        pool.push(PoolElement {
            cost: u64::from(f),
            kind: PoolKind::Coin(sym),
        });
        current.push((pool.len() - 1) as u32);
    }

    for _ in 1..max_length {
        // Package adjacent pairs of the previous level; an odd trailing entry
        // is dropped.
        let mut packages: Vec<u32> = Vec::with_capacity(current.len() / 2);
        for pair in current.chunks_exact(2) {
            let (a, b) = (pair[0], pair[1]);
            let cost = pool[a as usize].cost + pool[b as usize].cost;
            pool.push(PoolElement {
                cost,
                kind: PoolKind::Pair(a, b),
            });
            packages.push((pool.len() - 1) as u32);
        }

        // A fresh copy of every coin for this level, occupying the contiguous
        // pool range `coin_start..coin_end` in ascending cost order.
        let coin_start = pool.len();
        pool.extend(coins.iter().map(|&(f, sym)| PoolElement {
            cost: u64::from(f),
            kind: PoolKind::Coin(sym),
        }));
        let coin_end = pool.len();

        // Merge the two sorted sequences; on equal cost the coin goes first.
        let mut merged: Vec<u32> = Vec::with_capacity(n + packages.len());
        let mut ci = coin_start;
        let mut pi = 0usize;
        while ci < coin_end && pi < packages.len() {
            if pool[ci].cost <= pool[packages[pi] as usize].cost {
                merged.push(ci as u32);
                ci += 1;
            } else {
                merged.push(packages[pi]);
                pi += 1;
            }
        }
        merged.extend((ci..coin_end).map(|idx| idx as u32));
        merged.extend_from_slice(&packages[pi..]);
        current = merged;
    }

    // Select the cheapest 2n - 2 entries of the last level and count how many
    // times each symbol occurs inside them.
    let pick = 2 * n - 2;
    let mut stack: Vec<u32> = Vec::with_capacity(32);
    for &root in &current[..pick] {
        stack.clear();
        stack.push(root);
        while let Some(idx) = stack.pop() {
            match pool[idx as usize].kind {
                PoolKind::Coin(sym) => out[sym] += 1,
                PoolKind::Pair(a, b) => {
                    stack.push(a);
                    stack.push(b);
                }
            }
        }
    }
    out
}
/// Assigns canonical prefix codes to the given code lengths.
///
/// Codes are handed out in order of increasing length and, within one
/// length, in order of increasing symbol index. The returned vector has one
/// entry per input length; symbols with length 0 get code 0. The codes are
/// returned as plain integers whose `len` low bits hold the code, most
/// significant code bit first.
///
/// Lengths are expected to be at most 15 (checked with `debug_assert!`).
pub(crate) fn canonical_codes_from_lengths(lengths: &[u8]) -> Vec<u16> {
    // Number of symbols using each length 1..=15 (slot 0 stays empty).
    let mut histogram = [0u32; 16];
    for &len in lengths {
        debug_assert!(len <= 15);
        if len != 0 {
            histogram[usize::from(len)] += 1;
        }
    }

    // First code of each length: skip past every code of the previous length
    // and append one bit.
    let mut next_code = [0u32; 16];
    let mut running: u32 = 0;
    for bits in 1..16 {
        running = (running + histogram[bits - 1]) << 1;
        next_code[bits] = running;
    }

    lengths
        .iter()
        .map(|&len| {
            if len == 0 {
                return 0;
            }
            let slot = &mut next_code[usize::from(len)];
            let code = *slot as u16;
            *slot += 1;
            code
        })
        .collect()
}
/// Returns the number of bits needed to write any symbol index of an
/// alphabet with `alphabet_size` symbols, i.e. the smallest `n` such that
/// `2^n >= alphabet_size`.
///
/// Alphabets of size 0 or 1 need no bits. The result is at most 32 and is
/// well defined for every `u32` input, including sizes above `2^31`.
pub(crate) const fn alphabet_bits(alphabet_size: u32) -> u32 {
    if alphabet_size <= 1 {
        return 0;
    }
    // The largest symbol index is `alphabet_size - 1`; its bit length is the
    // answer. Computing it from `leading_zeros` avoids shifting `1u32` by 32.
    32 - (alphabet_size - 1).leading_zeros()
}
/// Writes a simple prefix-code header describing a single symbol.
///
/// Emits, in order: the 2-bit value 1, the 2-bit value 0, and then `symbol`
/// using `alphabet_bits(alphabet_size)` bits. All bits go through `bw` into
/// `out`.
pub(crate) fn emit_simple_nsym1(
    bw: &mut BitWriter,
    out: &mut Vec<u8>,
    symbol: u32,
    alphabet_size: u32,
) {
    // Two fixed 2-bit header fields: value 1 followed by value 0.
    let header: [(u32, u32); 2] = [(1, 2), (0, 2)];
    for &(value, width) in header.iter() {
        bw.write(value, width, out);
    }
    // The symbol itself, in as many bits as the alphabet requires.
    bw.write(symbol, alphabet_bits(alphabet_size), out);
}
/// Writes a simple prefix-code header describing two symbols and returns the
/// 1-bit code assigned to each of them.
///
/// Emits, in order: the 2-bit value 1, the 2-bit value 1, then the smaller
/// symbol and the larger symbol, each in `alphabet_bits(alphabet_size)` bits.
///
/// The returned array is parallel to `symbols`: the smaller symbol gets code
/// 0 and the larger gets code 1. The two symbols are expected to differ
/// (checked with `debug_assert!`).
pub(crate) fn emit_simple_nsym2(
    bw: &mut BitWriter,
    out: &mut Vec<u8>,
    symbols: [u32; 2],
    alphabet_size: u32,
) -> [u32; 2] {
    debug_assert!(symbols[0] != symbols[1], "NSYM=2 requires distinct symbols");

    // Two fixed 2-bit header fields, both with value 1.
    for _ in 0..2 {
        bw.write(1, 2, out);
    }

    let [first, second] = symbols;
    let ascending = first < second;
    let (smaller, larger) = if ascending {
        (first, second)
    } else {
        (second, first)
    };

    // Symbols go out in ascending order regardless of the caller's order.
    let width = alphabet_bits(alphabet_size);
    bw.write(smaller, width, out);
    bw.write(larger, width, out);

    // Report the codes back in the caller's original order.
    if ascending {
        [0, 1]
    } else {
        [1, 0]
    }
}
/// Fixed code used by `emit_complex_prefix_code` to write the length (0..=5)
/// assigned to each code-length-alphabet symbol.
///
/// The table is indexed by that length value. Each entry is
/// `(bit count, code bits)` and is written as `bw.write(code bits, bit count, out)`.
// NOTE(review): the values appear to match the fixed variable-length code of
// RFC 7932 section 3.5 — confirm against the spec before changing them.
const CL_CL_CODES: [(u32, u32); 6] = [
(2, 0b00), (4, 0b0111), (3, 0b011), (2, 0b10), (2, 0b01), (4, 0b1111), ];
/// Order in which `emit_complex_prefix_code` visits the 18 code-length-alphabet
/// symbols when writing their lengths: entry `k` is the symbol whose length is
/// written in position `k`.
// NOTE(review): looks like the code length code order from RFC 7932
// section 3.5 — confirm against the spec before changing it.
const CODE_LENGTH_ORDER: [usize; 18] =
[1, 2, 3, 4, 0, 5, 17, 6, 16, 7, 8, 9, 10, 11, 12, 13, 14, 15];
/// One entry of a run-length-coded sequence of code lengths.
#[derive(Clone, Copy)]
struct RleSymbol {
    /// Either a length copied verbatim from the input, or a repeat marker:
    /// 16 (repeat the preceding non-zero length) or 17 (repeat zero).
    sym: u8,
    /// Value of the extra bits that follow the symbol (repeat count minus 3
    /// for repeat markers, 0 otherwise).
    extra_value: u32,
    /// Number of extra bits that follow the symbol (2 for 16, 3 for 17,
    /// 0 for literals).
    extra_bits: u32,
}

/// Run-length encodes a sequence of code lengths.
///
/// Each maximal run of equal values is turned into literals and repeat
/// markers:
///
/// * a run of zeros becomes markers 17 covering 3..=10 zeros each;
/// * a run of a non-zero value becomes one literal followed by markers 16
///   covering 3..=6 repetitions each.
///
/// Two repeat markers are never emitted back to back within a run: one
/// literal of the run's value is placed between them. Leftovers shorter than
/// 3 are emitted as literals. Finally, every trailing zero literal and
/// trailing marker 17 is removed, so the output never ends in zeros.
fn rle_encode_lengths(lengths: &[u8]) -> Vec<RleSymbol> {
    /// A verbatim length with no extra bits.
    fn literal(sym: u8) -> RleSymbol {
        RleSymbol {
            sym,
            extra_value: 0,
            extra_bits: 0,
        }
    }

    /// Appends `count` repetitions of `fill` using `repeat_sym` markers of at
    /// most `max_chunk` repetitions each, separated by single literals.
    fn push_repeats(
        dst: &mut Vec<RleSymbol>,
        count: usize,
        fill: u8,
        repeat_sym: u8,
        max_chunk: usize,
        extra_bits: u32,
    ) {
        let mut remaining = count;
        let mut after_marker = false;
        while remaining >= 3 {
            if after_marker {
                // Keep consecutive markers apart with one literal.
                dst.push(literal(fill));
                remaining -= 1;
                if remaining < 3 {
                    break;
                }
            }
            let chunk = remaining.min(max_chunk);
            dst.push(RleSymbol {
                sym: repeat_sym,
                extra_value: (chunk - 3) as u32,
                extra_bits,
            });
            remaining -= chunk;
            after_marker = true;
        }
        // Whatever is left is too short for a marker.
        dst.resize(dst.len() + remaining, literal(fill));
    }

    let mut encoded: Vec<RleSymbol> = Vec::new();
    let mut pos = 0usize;
    while pos < lengths.len() {
        let value = lengths[pos];
        let run = lengths[pos..]
            .iter()
            .take_while(|&&len| len == value)
            .count();
        if value == 0 {
            push_repeats(&mut encoded, run, 0, 17, 10, 3);
        } else {
            // Marker 16 repeats the preceding length, so the first one is
            // always written out as a literal.
            encoded.push(literal(value));
            push_repeats(&mut encoded, run - 1, value, 16, 6, 2);
        }
        pos += run;
    }

    // Drop everything after the last entry that is neither a zero literal
    // nor a zero-repeat marker.
    let keep = encoded
        .iter()
        .rposition(|s| s.sym != 0 && s.sym != 17)
        .map_or(0, |last| last + 1);
    encoded.truncate(keep);
    encoded
}
/// Writes a complex prefix-code description for `lengths` and returns the
/// canonical codes for those lengths.
///
/// The stream produced through `bw` into `out` consists of:
///
/// 1. the 2-bit value 0 (HSKIP = 0 in RFC 7932 terms);
/// 2. the lengths of the code-length-alphabet symbols, visited in
///    `CODE_LENGTH_ORDER` and written with `CL_CL_CODES`;
/// 3. the run-length-coded `lengths`, each entry written with its
///    bit-reversed code followed by its extra bits, if any.
///
/// The code-length-alphabet lengths are limited to 5 bits.
///
/// When the run-length-coded sequence uses exactly one distinct symbol, that
/// symbol has no sibling to share the code space with. RFC 7932 section 3.5
/// then requires all 18 lengths to be written and gives the symbol a
/// zero-length code, so no bits are written per entry in step 3.
///
/// The returned codes come straight from `canonical_codes_from_lengths` and
/// are not bit-reversed.
///
/// `lengths` is expected to contain at least one non-zero entry; otherwise
/// the debug assertions below fire.
pub(crate) fn emit_complex_prefix_code(
    bw: &mut BitWriter,
    out: &mut Vec<u8>,
    lengths: &[u8],
) -> Vec<u16> {
    bw.write(0, 2, out);

    // Histogram of the code-length-alphabet symbols used by the RLE stream.
    let rle = rle_encode_lengths(lengths);
    let mut cl_freq = [0u32; 18];
    for s in &rle {
        cl_freq[s.sym as usize] += 1;
    }

    let cl_lengths_vec = length_limited_huffman(&cl_freq, 5);
    let mut cl_lengths = [0u8; 18];
    for (i, &l) in cl_lengths_vec.iter().enumerate() {
        cl_lengths[i] = l;
    }

    // With a single used symbol the code cannot be complete (its Kraft sum is
    // 16 out of 32); this case is encoded differently below.
    let single_symbol = cl_lengths.iter().filter(|&&l| l != 0).count() == 1;

    let emit_up_to = if single_symbol {
        // The reader only stops early once the code space is used up, which
        // never happens here, so every entry must be present.
        CODE_LENGTH_ORDER.len()
    } else {
        debug_assert!({
            let mut sum: u32 = 0;
            for &l in cl_lengths.iter() {
                if l > 0 {
                    sum += 32u32 >> (l as u32);
                }
            }
            sum == 32
        });
        // Find the entry at which the code space (32 units) is used up;
        // nothing after it needs to be written.
        let mut space: i32 = 32;
        let mut last_nonzero_idx: i32 = -1;
        for (idx, &sym_pos) in CODE_LENGTH_ORDER.iter().enumerate() {
            let v = cl_lengths[sym_pos];
            if v != 0 {
                space -= 32 >> (v as i32);
                last_nonzero_idx = idx as i32;
                if space <= 0 {
                    break;
                }
            }
        }
        debug_assert!(
            last_nonzero_idx >= 0 && space == 0,
            "cl-cl Kraft sum {} did not reach 32 (lengths={:?})",
            32 - space,
            cl_lengths
        );
        (last_nonzero_idx + 1) as usize
    };

    for &sym_pos in CODE_LENGTH_ORDER.iter().take(emit_up_to) {
        let v = cl_lengths[sym_pos];
        let (n, code) = CL_CL_CODES[v as usize];
        bw.write(code, n, out);
    }

    let cl_codes = canonical_codes_from_lengths(&cl_lengths);
    for s in &rle {
        // A lone symbol has a zero-length code: nothing to write for it.
        if !single_symbol {
            let len_b = cl_lengths[s.sym as usize] as u32;
            debug_assert!(
                len_b > 0,
                "RLE symbol {} has no cl-cl code (cl_lengths={:?})",
                s.sym,
                cl_lengths
            );
            let code = cl_codes[s.sym as usize];
            // Canonical codes are most-significant-bit first; reverse them
            // before handing them to the bit writer.
            let rev = reverse_bits(code as u32, len_b);
            bw.write(rev, len_b, out);
        }
        if s.extra_bits > 0 {
            bw.write(s.extra_value, s.extra_bits, out);
        }
    }

    canonical_codes_from_lengths(lengths)
}
pub(crate) fn build_huffman_lengths(freqs: &[u32], alphabet_size: usize) -> Vec<u8> {
debug_assert!(freqs.len() <= alphabet_size);
let lens = length_limited_huffman(freqs, 15);
let mut out = alloc::vec![0u8; alphabet_size];
for (i, &l) in lens.iter().enumerate() {
out[i] = l;
}
out
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum of `32768 >> len` over all non-zero lengths; a complete code with
    /// lengths up to 15 sums to exactly 32768.
    fn kraft_sum(lens: &[u8]) -> u64 {
        lens.iter()
            .filter(|&&len| len > 0)
            .map(|&len| 32768u64 >> u32::from(len))
            .sum()
    }

    #[test]
    fn alphabet_bits_known_values() {
        // (alphabet size, expected bit count)
        let expected: [(u32, u32); 6] = [(1, 0), (2, 1), (16, 4), (64, 6), (256, 8), (704, 10)];
        for &(size, bits) in expected.iter() {
            assert_eq!(alphabet_bits(size), bits);
        }
    }

    #[test]
    fn rle_drops_trailing_zeros() {
        let rle = rle_encode_lengths(&[3u8, 4, 0, 0, 0, 0, 0]);
        let tail = rle.last().unwrap();
        assert!(tail.sym != 0 && tail.sym != 17);
    }

    #[test]
    fn rle_long_zero_run() {
        // One non-zero length followed by thirty zeros.
        let mut lens = vec![0u8; 31];
        lens[0] = 3;
        let rle = rle_encode_lengths(&lens);
        assert_eq!(rle.last().unwrap().sym, 3);
    }

    #[test]
    fn huffman_three_equal_symbols_balances() {
        let lens = length_limited_huffman(&[3u32, 3, 3], 15);
        assert_eq!(kraft_sum(&lens), 32768, "lens {:?}", lens);
    }

    #[test]
    fn huffman_four_equal_symbols_balances() {
        let lens = length_limited_huffman(&[1u32, 1, 1, 1], 15);
        assert_eq!(kraft_sum(&lens), 32768, "lens {:?}", lens);
    }

    #[test]
    fn huffman_skewed_balances() {
        // Four used symbols scattered through a 256-symbol alphabet.
        let mut freqs = vec![0u32; 256];
        let used: [(usize, u32); 4] = [(0, 1000), (1, 100), (200, 50), (201, 25)];
        for &(sym, freq) in used.iter() {
            freqs[sym] = freq;
        }
        let lens = length_limited_huffman(&freqs, 15);
        assert_eq!(
            kraft_sum(&lens),
            32768,
            "lens with kraft={}",
            kraft_sum(&lens)
        );
    }
}