#![allow(dead_code)]
use crate::error::Result;
use crate::huffman::optimize::OptimizedTable;
/// Per-symbol occurrence counts for an alphabet of 256 byte-valued symbols.
///
/// Every operation that increases or aggregates counts saturates at
/// `u64::MAX`, so arbitrarily large inputs never panic and never wrap around.
#[derive(Clone, Debug)]
pub struct SymbolFrequencies {
    /// `counts[s]` holds the number of occurrences recorded for symbol `s`.
    counts: [u64; 256],
}

impl Default for SymbolFrequencies {
    /// Returns an all-zero table, exactly like [`SymbolFrequencies::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl SymbolFrequencies {
    /// Creates a table in which every symbol has a count of zero.
    #[must_use]
    pub fn new() -> Self {
        Self { counts: [0; 256] }
    }

    /// Records a single occurrence of `symbol` (saturating at `u64::MAX`).
    #[inline]
    pub fn count(&mut self, symbol: u8) {
        self.add(symbol, 1);
    }

    /// Adds `count` occurrences to `symbol`.
    ///
    /// The stored value saturates at `u64::MAX` instead of overflowing, which
    /// matches the behaviour of [`SymbolFrequencies::merge`].
    #[inline]
    pub fn add(&mut self, symbol: u8, count: u64) {
        let slot = &mut self.counts[usize::from(symbol)];
        *slot = slot.saturating_add(count);
    }

    /// Returns the count currently stored for `symbol`.
    #[must_use]
    pub fn get(&self, symbol: u8) -> u64 {
        self.counts[usize::from(symbol)]
    }

    /// Returns the sum of all 256 counts, saturating at `u64::MAX`.
    #[must_use]
    pub fn total(&self) -> u64 {
        // A plain `sum()` could overflow once individual counts get large
        // (for instance after saturating merges), so accumulate saturating.
        self.counts
            .iter()
            .fold(0u64, |acc, &c| acc.saturating_add(c))
    }

    /// Returns how many symbols have a non-zero count.
    #[must_use]
    pub fn num_symbols(&self) -> usize {
        self.counts.iter().filter(|&&c| c > 0).count()
    }

    /// Returns `true` when every count is zero.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.counts.iter().all(|&c| c == 0)
    }

    /// Sets every count back to zero.
    pub fn reset(&mut self) {
        self.counts.fill(0);
    }

    /// Adds the counts of `other` into `self`, symbol by symbol.
    ///
    /// Each per-symbol sum saturates at `u64::MAX`.
    pub fn merge(&mut self, other: &SymbolFrequencies) {
        for (mine, &theirs) in self.counts.iter_mut().zip(other.counts.iter()) {
            *mine = mine.saturating_add(theirs);
        }
    }

    /// Borrows the underlying array of 256 counts, indexed by symbol value.
    #[must_use]
    pub fn as_slice(&self) -> &[u64; 256] {
        &self.counts
    }

    /// Builds a table from a slice of counts indexed by symbol value.
    ///
    /// Returns `None` unless `counts` contains exactly 256 entries.
    pub fn from_slice(counts: &[u64]) -> Option<Self> {
        if counts.len() != 256 {
            return None;
        }
        let mut table = [0u64; 256];
        table.copy_from_slice(counts);
        Some(Self { counts: table })
    }
}
/// Huffman code length, in bits, for each of the 256 byte-valued symbols.
///
/// A length of `0` is skipped by [`CodeLengths::kraft_sum`] and
/// [`CodeLengths::to_bits_values`], i.e. it marks a symbol without a code.
#[derive(Clone, Debug)]
pub struct CodeLengths {
    /// `lengths[s]` is the code length assigned to symbol `s`.
    lengths: [u8; 256],
}

impl Default for CodeLengths {
    /// Returns a table with every length set to zero, like [`CodeLengths::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl CodeLengths {
    /// Creates a table in which every symbol has length zero.
    #[must_use]
    pub fn new() -> Self {
        Self { lengths: [0; 256] }
    }

    /// Wraps an existing array of per-symbol lengths without validating it.
    #[must_use]
    pub fn from_array(lengths: [u8; 256]) -> Self {
        Self { lengths }
    }

    /// Returns the length stored for `symbol`.
    #[must_use]
    pub fn get(&self, symbol: u8) -> u8 {
        self.lengths[usize::from(symbol)]
    }

    /// Borrows the underlying array of 256 lengths, indexed by symbol value.
    #[must_use]
    pub fn as_slice(&self) -> &[u8; 256] {
        &self.lengths
    }

    /// Returns the largest length in the table (zero for an all-zero table).
    #[must_use]
    pub fn max_length(&self) -> u8 {
        self.lengths.iter().copied().max().unwrap_or(0)
    }

    /// Returns the Kraft sum scaled by `2^16`.
    ///
    /// Each symbol with a length `len` in `1..=16` contributes `2^(16 - len)`;
    /// lengths of zero or above 16 contribute nothing. A value of `1 << 16`
    /// therefore corresponds to a Kraft sum of exactly one.
    #[must_use]
    pub fn kraft_sum(&self) -> u64 {
        self.lengths
            .iter()
            .filter(|&&len| (1..=16).contains(&len))
            .map(|&len| 1u64 << (16 - u32::from(len)))
            .sum()
    }

    /// Returns `true` when no length exceeds 16 bits and the scaled Kraft sum
    /// does not exceed `1 << 16`.
    #[must_use]
    pub fn is_valid(&self) -> bool {
        let within_limit = self.lengths.iter().all(|&len| len <= 16);
        within_limit && self.kraft_sum() <= (1 << 16)
    }

    /// Converts the table into a `(bits, values)` pair.
    ///
    /// `bits[i]` is the number of symbols whose length is `i + 1`, and
    /// `values` lists those symbols ordered by increasing length and, within
    /// one length, by increasing symbol value. Symbols whose length is zero or
    /// greater than 16 are left out of both.
    ///
    /// # Panics
    ///
    /// The per-length counters are `u8`. If all 256 symbols share one length
    /// the counter overflows, which panics in builds with overflow checks
    /// enabled and wraps to zero otherwise.
    #[must_use]
    pub fn to_bits_values(&self) -> ([u8; 16], Vec<u8>) {
        let mut bits = [0u8; 16];
        let mut values = Vec::with_capacity(self.lengths.len());
        // Scanning the table once per length visits symbols in ascending
        // order, so the output is already sorted and needs no temporary
        // buckets or explicit sort.
        for len in 1..=16u8 {
            let matching = self
                .lengths
                .iter()
                .enumerate()
                .filter(|&(_, &candidate)| candidate == len);
            for (symbol, _) in matching {
                bits[usize::from(len) - 1] += 1;
                // `symbol` indexes a 256-entry array, so it always fits in a u8.
                values.push(symbol as u8);
            }
        }
        (bits, values)
    }

    /// Estimates the encoded size in bits: the sum over all symbols of
    /// `frequency * code length`.
    ///
    /// Both the per-symbol product and the running total saturate at
    /// `u64::MAX` rather than overflowing.
    #[must_use]
    pub fn estimate_cost(&self, frequencies: &SymbolFrequencies) -> u64 {
        self.lengths
            .iter()
            .enumerate()
            .fold(0u64, |total, (symbol, &length)| {
                let symbol_bits = frequencies
                    .get(symbol as u8)
                    .saturating_mul(u64::from(length));
                total.saturating_add(symbol_bits)
            })
    }
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum HuffmanAlgorithm {
#[default]
MozjpegClassic,
JpegliTree,
}
impl HuffmanAlgorithm {
pub fn generate_code_lengths(&self, frequencies: &SymbolFrequencies) -> Result<CodeLengths> {
match self {
HuffmanAlgorithm::MozjpegClassic => generate_lengths_mozjpeg(frequencies),
HuffmanAlgorithm::JpegliTree => generate_lengths_jpegli(frequencies),
}
}
pub fn generate_table(&self, frequencies: &SymbolFrequencies) -> Result<OptimizedTable> {
let lengths = self.generate_code_lengths(frequencies)?;
let (bits, values) = lengths.to_bits_values();
OptimizedTable::from_bits_values(bits, values)
}
}
/// Computes code lengths with `crate::huffman::classic::generate_code_lengths`.
///
/// The 256 counts are copied into a 257-entry `i64` work buffer; the final
/// slot (index 256) is left at zero before the buffer is handed over.
///
/// NOTE(review): `as i64` turns counts above `i64::MAX` into negative values.
/// Whether the classic routine tolerates that, and whether it fills slot 256
/// itself, is not visible from here — confirm against its implementation.
///
/// # Errors
///
/// Propagates any error returned by the classic generator.
fn generate_lengths_mozjpeg(frequencies: &SymbolFrequencies) -> Result<CodeLengths> {
    use crate::huffman::classic::generate_code_lengths;

    let mut work = [0i64; 257];
    // `zip` stops after the 256 real symbols, leaving the extra slot untouched.
    work.iter_mut()
        .zip(frequencies.as_slice().iter())
        .for_each(|(slot, &count)| *slot = count as i64);

    Ok(CodeLengths::from_array(generate_code_lengths(&mut work)?))
}
/// Computes code lengths with `crate::huffman::build_code_lengths`, passing
/// `16` as its second argument.
///
/// The 256 counts are extended with one extra entry of frequency `1` before
/// the call (presumably a reserved pseudo-symbol — confirm against
/// `build_code_lengths`). Only the first 256 returned depths are kept.
fn generate_lengths_jpegli(frequencies: &SymbolFrequencies) -> Result<CodeLengths> {
    use crate::huffman::build_code_lengths;

    let mut extended: Vec<u64> = Vec::with_capacity(257);
    extended.extend_from_slice(frequencies.as_slice());
    extended.push(1);

    let depths = build_code_lengths(&extended, 16);

    // Drop the depth of the extra entry; only real symbols are reported.
    let mut per_symbol = [0u8; 256];
    per_symbol.copy_from_slice(&depths[..256]);
    Ok(CodeLengths::from_array(per_symbol))
}
/// Runs both Huffman algorithms on the same `frequencies`.
///
/// Returns `(mozjpeg_lengths, jpegli_lengths, mozjpeg_cost, jpegli_cost)`,
/// where each cost is [`CodeLengths::estimate_cost`] of the matching table
/// against `frequencies`.
///
/// # Errors
///
/// Returns the first error produced by either length generator; the
/// `MozjpegClassic` generator runs first.
pub fn compare_algorithms(
    frequencies: &SymbolFrequencies,
) -> Result<(CodeLengths, CodeLengths, u64, u64)> {
    let classic = HuffmanAlgorithm::MozjpegClassic.generate_code_lengths(frequencies)?;
    let tree = HuffmanAlgorithm::JpegliTree.generate_code_lengths(frequencies)?;

    let (classic_cost, tree_cost) = (
        classic.estimate_cost(frequencies),
        tree.estimate_cost(frequencies),
    );

    Ok((classic, tree, classic_cost, tree_cost))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a frequency table from `(symbol, count)` pairs.
    fn frequencies_of(pairs: &[(u8, u64)]) -> SymbolFrequencies {
        let mut table = SymbolFrequencies::new();
        for &(symbol, count) in pairs {
            table.add(symbol, count);
        }
        table
    }

    /// Builds a code-length table from `(symbol, length)` pairs.
    fn lengths_of(pairs: &[(u8, u8)]) -> CodeLengths {
        let mut table = CodeLengths::new();
        for &(symbol, length) in pairs {
            table.lengths[usize::from(symbol)] = length;
        }
        table
    }

    /// Asserts the properties both algorithm tests expect: the table is
    /// valid, stays within 16 bits, and gives the most frequent symbol (0) a
    /// code no longer than that of the rarest one (3).
    fn assert_sane_lengths(lengths: &CodeLengths) {
        assert!(lengths.is_valid());
        assert!(lengths.max_length() <= 16);
        assert!(lengths.get(0) <= lengths.get(3));
    }

    #[test]
    fn test_symbol_frequencies_basic() {
        let mut freq = SymbolFrequencies::new();
        assert!(freq.is_empty());

        for &symbol in &[0u8, 0, 1] {
            freq.count(symbol);
        }

        assert_eq!(freq.get(0), 2);
        assert_eq!(freq.get(1), 1);
        assert_eq!(freq.get(2), 0);
        assert_eq!(freq.total(), 3);
        assert_eq!(freq.num_symbols(), 2);
    }

    #[test]
    fn test_symbol_frequencies_merge() {
        let mut freq1 = SymbolFrequencies::new();
        freq1.count(0);
        freq1.count(1);

        let mut freq2 = SymbolFrequencies::new();
        freq2.count(0);
        freq2.count(2);

        freq1.merge(&freq2);

        let merged = [freq1.get(0), freq1.get(1), freq1.get(2)];
        assert_eq!(merged, [2, 1, 1]);
    }

    #[test]
    fn test_code_lengths_to_bits_values() {
        let lengths = lengths_of(&[(0, 2), (1, 2), (2, 3), (3, 3)]);

        let (bits, values) = lengths.to_bits_values();

        // Two codes of length 2 and two of length 3.
        assert_eq!(bits[1], 2);
        assert_eq!(bits[2], 2);
        assert_eq!(values, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_code_lengths_kraft_sum() {
        // Two one-bit codes use the whole code space: 2 * 2^15 == 2^16.
        let lengths = lengths_of(&[(0, 1), (1, 1)]);

        assert_eq!(lengths.kraft_sum(), 1 << 16);
        assert!(lengths.is_valid());
    }

    #[test]
    fn test_code_lengths_estimate_cost() {
        let lengths = lengths_of(&[(0, 1), (1, 3)]);
        let freq = frequencies_of(&[(0, 100), (1, 10)]);

        // 100 * 1 + 10 * 3
        assert_eq!(lengths.estimate_cost(&freq), 130);
    }

    #[test]
    fn test_algorithm_mozjpeg_produces_valid_code() {
        let freq = frequencies_of(&[(0, 1000), (1, 100), (2, 10), (3, 1)]);

        let lengths = HuffmanAlgorithm::MozjpegClassic
            .generate_code_lengths(&freq)
            .unwrap();

        assert_sane_lengths(&lengths);
    }

    #[test]
    fn test_algorithm_jpegli_produces_valid_code() {
        let freq = frequencies_of(&[(0, 1000), (1, 100), (2, 10), (3, 1)]);

        let lengths = HuffmanAlgorithm::JpegliTree
            .generate_code_lengths(&freq)
            .unwrap();

        assert_sane_lengths(&lengths);
    }

    #[test]
    fn test_algorithm_comparison_same_input() {
        let mut freq = frequencies_of(&[(0, 10000), (1, 5000), (17, 3000), (33, 2000)]);
        for symbol in 2..16u8 {
            freq.add(symbol, 1000 / (u64::from(symbol) + 1));
        }

        let (mozjpeg, jpegli, moz_cost, jpg_cost) = compare_algorithms(&freq).unwrap();
        assert!(mozjpeg.is_valid());
        assert!(jpegli.is_valid());

        let signed_gap = moz_cost as i64 - jpg_cost as i64;
        let relative_gap = (moz_cost as f64 - jpg_cost as f64) / moz_cost as f64 * 100.0;
        println!("mozjpeg cost: {} bits", moz_cost);
        println!("jpegli cost: {} bits", jpg_cost);
        println!(
            "difference: {} bits ({:.2}%)",
            signed_gap.abs(),
            relative_gap.abs()
        );

        // Allow the two costs to differ by 1% of the larger one, plus one bit.
        let max_diff = (moz_cost.max(jpg_cost) as f64 * 0.01) as u64 + 1;
        assert!(
            signed_gap.unsigned_abs() <= max_diff,
            "Costs differ by more than 1%: mozjpeg={}, jpegli={}",
            moz_cost,
            jpg_cost
        );
    }

    #[test]
    fn test_optimized_table_from_frequencies() {
        let freq = frequencies_of(&[(0, 100), (1, 50), (2, 25)]);

        let table = HuffmanAlgorithm::MozjpegClassic
            .generate_table(&freq)
            .unwrap();

        let (code0, len0) = table.encode(0);
        let (code1, len1) = table.encode(1);
        let (_code2, len2) = table.encode(2);

        // Every symbol that occurs must receive a code.
        assert!(len0 > 0);
        assert!(len1 > 0);
        assert!(len2 > 0);
        // The most frequent symbol must not get a longer code than the rarest.
        assert!(len0 <= len2);
        // Symbols 0 and 1 must be distinguishable.
        assert!(code0 != code1 || len0 != len1);
    }
}