#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use core::cmp::{self, Ordering};
const NONE: u32 = u32::MAX;
struct Thing<'a> {
nodes: &'a mut Vec<Node>,
leaves: &'a mut Vec<Leaf>,
lists: [List; 15],
}
struct Node {
weight: usize,
count: usize,
tail: u32, }
struct Leaf {
weight: usize,
count: usize,
}
impl PartialEq for Leaf {
fn eq(&self, other: &Self) -> bool {
self.weight == other.weight
}
}
impl Eq for Leaf {}
impl Ord for Leaf {
fn cmp(&self, other: &Self) -> Ordering {
self.weight.cmp(&other.weight)
}
}
impl PartialOrd for Leaf {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
#[derive(Clone, Copy)]
struct List {
lookahead0: u32, lookahead1: u32, }
pub(crate) struct HuffmanScratch {
leaves: Vec<Leaf>,
nodes: Vec<Node>,
}
impl HuffmanScratch {
pub(crate) fn new() -> Self {
Self {
leaves: Vec::new(),
nodes: Vec::new(),
}
}
}
pub(crate) fn length_limited_code_lengths(
frequencies: &[u32],
max_bits: usize,
scratch: &mut HuffmanScratch,
bit_lengths: &mut [u8],
) {
let num_freqs = frequencies.len();
assert!(num_freqs <= 288);
assert!(bit_lengths.len() >= num_freqs);
bit_lengths[..num_freqs].fill(0);
scratch.leaves.clear();
for (i, &freq) in frequencies.iter().enumerate() {
if freq != 0 {
scratch.leaves.push(Leaf {
weight: freq as usize,
count: i,
});
}
}
let num_symbols = scratch.leaves.len();
if num_symbols <= 2 {
for i in 0..num_symbols {
bit_lengths[scratch.leaves[i].count] = 1;
}
return;
}
scratch.leaves.sort();
let max_bits = cmp::min(num_symbols - 1, max_bits);
assert!(max_bits <= 15);
let capacity = max_bits * 2 * num_symbols;
scratch.nodes.clear();
scratch
.nodes
.reserve(capacity.saturating_sub(scratch.nodes.capacity()));
scratch.nodes.push(Node {
weight: scratch.leaves[0].weight,
count: 1,
tail: NONE,
});
scratch.nodes.push(Node {
weight: scratch.leaves[1].weight,
count: 2,
tail: NONE,
});
let lists = [List {
lookahead0: 0,
lookahead1: 1,
}; 15];
let mut thing = Thing {
nodes: &mut scratch.nodes,
leaves: &mut scratch.leaves,
lists,
};
let num_boundary_pm_runs = 2 * num_symbols - 4;
for _ in 0..num_boundary_pm_runs - 1 {
thing.boundary_pm(max_bits - 1);
}
thing.boundary_pm_final(max_bits - 1);
thing.extract_bit_lengths(max_bits, bit_lengths);
}
impl Thing<'_> {
fn boundary_pm(&mut self, index: usize) {
let num_symbols = self.leaves.len();
let last_count = self.nodes[self.lists[index].lookahead1 as usize].count;
if index == 0 && last_count >= num_symbols {
return;
}
self.lists[index].lookahead0 = self.lists[index].lookahead1;
if index == 0 {
let tail = self.nodes[self.lists[index].lookahead0 as usize].tail;
let idx = self.nodes.len() as u32;
self.nodes.push(Node {
weight: self.leaves[last_count].weight,
count: last_count + 1,
tail,
});
self.lists[index].lookahead1 = idx;
} else {
let weight_sum = {
let previous_list = &self.lists[index - 1];
self.nodes[previous_list.lookahead0 as usize].weight
+ self.nodes[previous_list.lookahead1 as usize].weight
};
if last_count < num_symbols && weight_sum > self.leaves[last_count].weight {
let tail = self.nodes[self.lists[index].lookahead0 as usize].tail;
let idx = self.nodes.len() as u32;
self.nodes.push(Node {
weight: self.leaves[last_count].weight,
count: last_count + 1,
tail,
});
self.lists[index].lookahead1 = idx;
} else {
let tail = self.lists[index - 1].lookahead1;
let idx = self.nodes.len() as u32;
self.nodes.push(Node {
weight: weight_sum,
count: last_count,
tail,
});
self.lists[index].lookahead1 = idx;
self.boundary_pm(index - 1);
self.boundary_pm(index - 1);
}
}
}
fn boundary_pm_final(&mut self, index: usize) {
let num_symbols = self.leaves.len();
let last_count = self.nodes[self.lists[index].lookahead1 as usize].count;
let weight_sum = {
let previous_list = &self.lists[index - 1];
self.nodes[previous_list.lookahead0 as usize].weight
+ self.nodes[previous_list.lookahead1 as usize].weight
};
if last_count < num_symbols && weight_sum > self.leaves[last_count].weight {
let tail = self.nodes[self.lists[index].lookahead1 as usize].tail;
let idx = self.nodes.len() as u32;
self.nodes.push(Node {
weight: 0,
count: last_count + 1,
tail,
});
self.lists[index].lookahead1 = idx;
} else {
let node_idx = self.lists[index].lookahead1 as usize;
self.nodes[node_idx].tail = self.lists[index - 1].lookahead1;
}
}
fn extract_bit_lengths(&self, max_bits: usize, bit_lengths: &mut [u8]) {
let mut counts = [0usize; 16];
let mut end = 16;
let mut ptr = 15;
let mut value = 1u8;
let mut node_idx = self.lists[max_bits - 1].lookahead1 as usize;
end -= 1;
counts[end] = self.nodes[node_idx].count;
let mut tail = self.nodes[node_idx].tail;
while tail != NONE {
end -= 1;
node_idx = tail as usize;
counts[end] = self.nodes[node_idx].count;
tail = self.nodes[node_idx].tail;
}
let mut val = counts[15];
while ptr >= end {
while val > counts[ptr - 1] {
bit_lengths[self.leaves[val - 1].count] = value;
val -= 1;
}
ptr -= 1;
value += 1;
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn llcl(frequencies: &[u32], max_bits: usize) -> Vec<u8> {
let mut scratch = HuffmanScratch::new();
let mut result = vec![0u8; frequencies.len()];
length_limited_code_lengths(frequencies, max_bits, &mut scratch, &mut result);
result
}
#[test]
fn test_from_paper_3() {
let input = [1u32, 1, 5, 7, 10, 14];
let output = llcl(&input, 3);
assert_eq!(output, [3, 3, 3, 3, 2, 2]);
}
#[test]
fn test_from_paper_4() {
let input = [1u32, 1, 5, 7, 10, 14];
let output = llcl(&input, 4);
assert_eq!(output, [4, 4, 3, 2, 2, 2]);
}
#[test]
fn max_bits_7() {
let input: [u32; 19] = [252, 0, 1, 6, 9, 10, 6, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
let output = llcl(&input, 7);
assert_eq!(
output,
[1, 0, 6, 4, 3, 3, 3, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
);
}
#[test]
fn max_bits_15() {
let input: [u32; 32] = [
0, 0, 0, 0, 0, 0, 18, 0, 6, 0, 12, 2, 14, 9, 27, 15, 23, 15, 17, 8, 1, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0,
];
let output = llcl(&input, 15);
assert_eq!(
output,
[
0, 0, 0, 0, 0, 0, 3, 0, 5, 0, 4, 6, 4, 4, 3, 4, 3, 3, 3, 4, 6, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0
]
);
}
#[test]
fn no_frequencies() {
let input = [0u32, 0, 0, 0, 0];
let output = llcl(&input, 7);
assert_eq!(output, [0, 0, 0, 0, 0]);
}
#[test]
fn only_one_frequency() {
let input = [0u32, 10, 0];
let output = llcl(&input, 7);
assert_eq!(output, [0, 1, 0]);
}
#[test]
fn only_two_frequencies() {
let input: [u32; 17] = [0, 0, 0, 0, 252, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
let output = llcl(&input, 7);
assert_eq!(output, [0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]);
}
#[test]
fn scratch_reuse_produces_same_results() {
let mut scratch = HuffmanScratch::new();
let inputs: &[&[u32]] = &[
&[1, 1, 5, 7, 10, 14],
&[252, 0, 1, 6, 9, 10, 6, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
&[0, 0, 0, 0, 252, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
];
let max_bits = [3, 7, 7];
for (input, &mb) in inputs.iter().zip(max_bits.iter()) {
let expected = llcl(input, mb);
let mut result = vec![0u8; input.len()];
length_limited_code_lengths(input, mb, &mut scratch, &mut result);
assert_eq!(result, expected);
}
}
}