#![allow(dead_code)]
use crate::error::Result;
use crate::huffman::HuffmanEncodeTable;
use crate::huffman::classic::{
depths_to_bits_values, generate_code_lengths, generate_optimal_table,
};
/// A ready-to-use Huffman encoding table bundled with the `bits`/`values`
/// pair it was constructed from, so the same data can also be emitted as a
/// DHT segment.
#[derive(Debug, Clone)]
pub struct OptimizedTable {
    /// Encoder lookup table built from `bits` and `values`.
    pub table: HuffmanEncodeTable,
    /// Code-length counts in DHT `BITS` form (assumed: entry `i` is the number
    /// of codes of length `i + 1`).
    pub bits: [u8; 16],
    /// Symbols in DHT `HUFFVAL` order, matching `bits`.
    pub values: Vec<u8>,
}
impl OptimizedTable {
    /// Builds the encoder table for `bits`/`values` and keeps both alongside it.
    ///
    /// # Errors
    /// Returns the error reported by `HuffmanEncodeTable::from_bits_values` when
    /// the pair does not describe a valid table.
    pub fn from_bits_values(bits: [u8; 16], values: Vec<u8>) -> crate::error::Result<Self> {
        HuffmanEncodeTable::from_bits_values(&bits, &values).map(|table| Self {
            table,
            bits,
            values,
        })
    }

    /// Like [`Self::from_bits_values`], but copies `values` from a borrowed slice
    /// and treats failure as a programming error.
    ///
    /// # Panics
    /// Panics if `bits`/`values` do not form a valid table; intended only for
    /// compile-time constant table data.
    pub fn from_bits_values_static(bits: [u8; 16], values: &[u8]) -> Self {
        Self::from_bits_values(bits, values.to_vec()).expect("static table data is valid")
    }

    /// Returns the `(code, length)` pair for `symbol`, delegating to the inner table.
    #[inline]
    pub fn encode(&self, symbol: u8) -> (u32, u8) {
        let encoded = self.table.encode(symbol);
        encoded
    }
}
/// The four Huffman tables used by a baseline colour JPEG: DC and AC tables for
/// the luma component and for the chroma components.
#[derive(Debug, Clone)]
pub struct HuffmanTableSet {
    /// DC table for luma.
    pub dc_luma: OptimizedTable,
    /// AC table for luma.
    pub ac_luma: OptimizedTable,
    /// DC table for chroma.
    pub dc_chroma: OptimizedTable,
    /// AC table for chroma.
    pub ac_chroma: OptimizedTable,
}
impl HuffmanTableSet {
    /// Builds all four tables from the `STD_*` constants in
    /// `crate::huffman::encode`.
    ///
    /// # Errors
    /// Propagates the first error from [`OptimizedTable::from_bits_values`]. The
    /// tables are built in the order DC luma, AC luma, DC chroma, AC chroma.
    pub fn from_standard() -> crate::error::Result<Self> {
        use crate::huffman::encode::{
            STD_AC_CHROMINANCE_BITS, STD_AC_CHROMINANCE_VALUES, STD_AC_LUMINANCE_BITS,
            STD_AC_LUMINANCE_VALUES, STD_DC_CHROMINANCE_BITS, STD_DC_CHROMINANCE_VALUES,
            STD_DC_LUMINANCE_BITS, STD_DC_LUMINANCE_VALUES,
        };

        let build = OptimizedTable::from_bits_values;
        let dc_luma = build(STD_DC_LUMINANCE_BITS, STD_DC_LUMINANCE_VALUES.to_vec())?;
        let ac_luma = build(STD_AC_LUMINANCE_BITS, STD_AC_LUMINANCE_VALUES.to_vec())?;
        let dc_chroma = build(STD_DC_CHROMINANCE_BITS, STD_DC_CHROMINANCE_VALUES.to_vec())?;
        let ac_chroma = build(STD_AC_CHROMINANCE_BITS, STD_AC_CHROMINANCE_VALUES.to_vec())?;

        Ok(Self {
            dc_luma,
            ac_luma,
            dc_chroma,
            ac_chroma,
        })
    }

    /// Alias for [`Self::from_standard`]. The standard tables are the ones listed
    /// in Annex K of the JPEG specification (ITU-T T.81).
    ///
    /// # Errors
    /// Same as [`Self::from_standard`].
    pub fn annex_k() -> crate::error::Result<Self> {
        Self::from_standard()
    }
}
/// Histogram of Huffman symbol occurrences.
///
/// Slots `0..256` hold the counts of the byte symbols. Slot 256 is an extra
/// reserved entry that is copied along when a table is generated but cannot be
/// addressed by a `u8` symbol.
#[derive(Debug, Clone)]
pub struct FrequencyCounter {
    /// One signed count per slot (256 byte symbols plus the reserved slot).
    counts: [i64; 257],
}
impl Default for FrequencyCounter {
fn default() -> Self {
Self::new()
}
}
impl FrequencyCounter {
    /// Creates a counter with all 257 slots (256 byte symbols plus the reserved
    /// slot 256) set to zero.
    #[must_use]
    pub fn new() -> Self {
        Self { counts: [0; 257] }
    }

    /// Zeroes every slot, including the reserved slot 256.
    pub fn reset(&mut self) {
        self.counts.fill(0);
    }

    /// Records one more occurrence of `symbol`.
    #[inline]
    pub fn count(&mut self, symbol: u8) {
        self.counts[symbol as usize] += 1;
    }

    /// Overwrites the count for `symbol` with `value`.
    ///
    /// Negative values are stored as-is. The cost estimators and the
    /// `JpegliCreateTree` builder in this type treat them as zero.
    pub fn set_count(&mut self, symbol: u8, value: i64) {
        self.counts[symbol as usize] = value;
    }

    /// Returns the stored count for `symbol`.
    #[must_use]
    pub fn get_count(&self, symbol: u8) -> i64 {
        self.counts[symbol as usize]
    }

    /// Sum of the counts of the 256 byte symbols; the reserved slot is excluded.
    ///
    /// Saturates at the `i64` bounds instead of overflowing, matching the
    /// saturating policy of [`Self::add`].
    #[must_use]
    pub fn total(&self) -> i64 {
        self.counts[..256]
            .iter()
            .fold(0i64, |acc, &c| acc.saturating_add(c))
    }

    /// Number of byte symbols with a strictly positive count.
    #[must_use]
    pub fn num_symbols(&self) -> usize {
        self.counts[..256].iter().filter(|&&c| c > 0).count()
    }

    /// Builds an encoder table with the `classic` optimal-table generator, without
    /// keeping the DHT data.
    ///
    /// # Errors
    /// Same as [`Self::generate_table_with_dht`].
    pub fn generate_table(&self) -> Result<HuffmanEncodeTable> {
        self.generate_table_with_dht().map(|optimized| optimized.table)
    }

    /// Builds an encoder table plus its `bits`/`values` using
    /// `generate_optimal_table` (the `MozjpegClassic` method).
    ///
    /// The generator takes `&mut`, so it is given a scratch copy of all 257 slots;
    /// `self` is left untouched.
    ///
    /// # Errors
    /// Propagates errors from `generate_optimal_table` or
    /// `HuffmanEncodeTable::from_bits_values`.
    pub fn generate_table_with_dht(&self) -> Result<OptimizedTable> {
        let mut scratch = self.counts;
        let (bits, values) = generate_optimal_table(&mut scratch)?;
        let table = HuffmanEncodeTable::from_bits_values(&bits, &values)?;
        Ok(OptimizedTable {
            table,
            bits,
            values,
        })
    }

    /// Builds an encoder table plus its `bits`/`values` with the chosen `method`.
    ///
    /// - `JpegliCreateTree`: clamps the 256 symbol counts to be non-negative and
    ///   appends a count of 1 for pseudo-symbol 256. It then calls
    ///   `crate::huffman::build_code_lengths` with a limit argument of 16
    ///   (presumably the JPEG maximum code length).
    /// - `MozjpegClassic`: identical to [`Self::generate_table_with_dht`].
    ///
    /// # Errors
    /// Propagates errors from the selected builder or from
    /// `HuffmanEncodeTable::from_bits_values`.
    pub fn generate_table_with_method(
        &self,
        method: crate::types::HuffmanMethod,
    ) -> Result<OptimizedTable> {
        use crate::types::HuffmanMethod;
        match method {
            HuffmanMethod::JpegliCreateTree => {
                let mut freqs: Vec<u64> = self.counts[..256]
                    .iter()
                    .map(|&c| Self::non_negative(c))
                    .collect();
                // Reserved pseudo-symbol 256 with a count of 1 (presumably the
                // libjpeg convention that keeps real symbols off the all-ones code).
                freqs.push(1);
                let depths = crate::huffman::build_code_lengths(&freqs, 16);
                let (bits, values) = depths_to_bits_values(&depths);
                let table = HuffmanEncodeTable::from_bits_values(&bits, &values)?;
                Ok(OptimizedTable {
                    table,
                    bits,
                    values,
                })
            }
            HuffmanMethod::MozjpegClassic => self.generate_table_with_dht(),
        }
    }

    /// Computes per-symbol code lengths with `generate_code_lengths`, working on a
    /// scratch copy of the counts.
    ///
    /// # Errors
    /// Propagates errors from `generate_code_lengths`.
    pub fn generate_lengths(&self) -> Result<[u8; 256]> {
        let mut scratch = self.counts;
        generate_code_lengths(&mut scratch)
    }

    /// Number of bits needed to code this histogram with the given code lengths:
    /// the sum over symbols of `count * length`.
    ///
    /// Negative counts contribute nothing. The arithmetic saturates at `u64::MAX`
    /// instead of overflowing.
    #[must_use]
    pub fn estimate_cost(&self, lengths: &[u8; 256]) -> u64 {
        self.counts[..256]
            .iter()
            .zip(lengths.iter())
            .map(|(&c, &len)| Self::non_negative(c).saturating_mul(u64::from(len)))
            .fold(0u64, u64::saturating_add)
    }

    /// Returns `true` when none of the 256 byte symbols has a non-zero count.
    #[must_use]
    pub fn is_empty_histogram(&self) -> bool {
        self.counts[..256].iter().all(|&c| c == 0)
    }

    /// Adds `other`'s counts into `self` slot by slot, including the reserved
    /// slot. Each sum saturates at the `i64` bounds.
    pub fn add(&mut self, other: &FrequencyCounter) {
        for (mine, &theirs) in self.counts.iter_mut().zip(other.counts.iter()) {
            *mine = mine.saturating_add(theirs);
        }
    }

    /// Returns a new counter holding `self + other`; neither input is modified.
    #[must_use]
    pub fn combined(&self, other: &FrequencyCounter) -> FrequencyCounter {
        let mut result = self.clone();
        result.add(other);
        result
    }

    /// Estimated size in bits of coding this histogram with its own optimal table.
    ///
    /// The estimate is the DHT table payload plus the entropy-coded data.
    /// Returns `f64::MAX` when no code lengths can be generated, so such a
    /// histogram compares as worse than any real estimate.
    #[must_use]
    pub fn estimate_encoding_cost(&self) -> f64 {
        let lengths = match self.generate_lengths() {
            Ok(lengths) => lengths,
            Err(_) => return f64::MAX,
        };
        let coded_symbols = lengths.iter().filter(|&&len| len > 0).count() as u64;
        // DHT table payload: one Tc/Th byte, sixteen code-length counts, then one
        // byte per coded symbol. The marker and segment length are not counted.
        let header_bits = (1 + 16 + coded_symbols) * 8;
        let data_bits = self.estimate_cost(&lengths);
        header_bits as f64 + data_bits as f64
    }

    /// Converts a stored count to `u64`, treating negative counts as zero.
    fn non_negative(count: i64) -> u64 {
        count.max(0) as u64
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Records `times` occurrences of `symbol` in `counter`.
    fn count_n(counter: &mut FrequencyCounter, symbol: u8, times: usize) {
        for _ in 0..times {
            counter.count(symbol);
        }
    }

    #[test]
    fn test_frequency_counter_basic() {
        let mut counter = FrequencyCounter::new();
        count_n(&mut counter, 0, 2);
        counter.count(1);
        assert_eq!(counter.get_count(0), 2);
        assert_eq!(counter.get_count(1), 1);
        assert_eq!(counter.get_count(2), 0);
        assert_eq!(counter.total(), 3);
        assert_eq!(counter.num_symbols(), 2);
    }

    #[test]
    fn test_frequency_counter_reset() {
        let mut counter = FrequencyCounter::new();
        counter.count(0);
        counter.count(1);
        counter.reset();
        assert_eq!(counter.total(), 0);
        assert_eq!(counter.num_symbols(), 0);
    }

    #[test]
    fn test_generate_table_uniform() {
        let mut counter = FrequencyCounter::new();
        for symbol in 0..8u8 {
            count_n(&mut counter, symbol, 100);
        }
        let table = counter.generate_table().unwrap();
        for symbol in 0..8u8 {
            let (_, len) = table.encode(symbol);
            assert!(len > 0, "Symbol {} should have a code", symbol);
            assert!(len <= 4, "Uniform 8 symbols should have codes <= 4 bits");
        }
    }

    #[test]
    fn test_generate_table_skewed() {
        let mut counter = FrequencyCounter::new();
        count_n(&mut counter, 0, 10000);
        count_n(&mut counter, 1, 100);
        count_n(&mut counter, 2, 10);
        counter.count(3);
        let table = counter.generate_table().unwrap();
        let (_, len0) = table.encode(0);
        let (_, len1) = table.encode(1);
        let (_, len2) = table.encode(2);
        let (_, len3) = table.encode(3);
        assert!(len0 <= len1, "More frequent symbol should have shorter code");
        assert!(len1 <= len2);
        assert!(len2 <= len3);
    }

    #[test]
    fn test_generate_table_single_symbol() {
        let mut counter = FrequencyCounter::new();
        count_n(&mut counter, 42, 3);
        let table = counter.generate_table().unwrap();
        let (_, len) = table.encode(42);
        assert_eq!(len, 1, "Single symbol should get length 1");
    }

    #[test]
    fn test_generate_table_empty() {
        // Whether an all-zero histogram yields a table or an error is left to the
        // generator; this only checks that the call returns instead of panicking.
        let counter = FrequencyCounter::new();
        let _ = counter.generate_table();
    }

    #[test]
    fn test_code_length_limit() {
        let mut counter = FrequencyCounter::new();
        // Roughly geometric counts (growth 1.5x) push an unconstrained Huffman
        // tree past 16 levels, exercising the length limit.
        let mut f: usize = 1;
        for symbol in 0..30u8 {
            count_n(&mut counter, symbol, f);
            f = f * 3 / 2 + 1;
        }
        let table = counter.generate_table().unwrap();
        for symbol in 0..30u8 {
            let (_, len) = table.encode(symbol);
            assert!(len <= 16, "Symbol {} has length {} > 16", symbol, len);
        }
    }

    #[test]
    fn test_estimate_cost() {
        let mut counter = FrequencyCounter::new();
        count_n(&mut counter, 0, 100);
        count_n(&mut counter, 1, 10);
        let lengths = counter.generate_lengths().unwrap();
        let cost = counter.estimate_cost(&lengths);
        assert!(cost > 0);
        assert!(cost < 1000);
    }

    #[test]
    fn test_combined_and_empty_histogram() {
        let mut a = FrequencyCounter::new();
        assert!(a.is_empty_histogram());
        a.count(5);
        let mut b = FrequencyCounter::new();
        b.count(5);
        b.count(7);
        let c = a.combined(&b);
        assert_eq!(c.get_count(5), 2);
        assert_eq!(c.get_count(7), 1);
        assert_eq!(a.get_count(7), 0, "combined must not mutate self");
        assert!(!c.is_empty_histogram());
    }
}