#![allow(dead_code)]
use super::ans::{ANSEncodingHistogram, ANSHistogramStrategy, AnsDistribution, AnsEncoder};
use super::context_map::move_to_front_transform;
use super::encode::{
ALPHABET_SIZE, PrefixCode, encode_token_value, encode_token_value_with_config,
write_var_len_uint16,
};
use super::encode_huffman::{
convert_bit_depths_to_symbols, create_huffman_tree, write_prefix_code,
};
use super::hybrid_uint::HybridUintConfig;
use super::lz77::Lz77Params;
use super::token::{Lz77UintCoder, Token, UintCoder};
use crate::bit_writer::BitWriter;
use crate::error::{Error, Result};
/// Default log2 of the ANS alphabet size (64 symbols), used when LZ77 is
/// disabled and every histogram's alphabet fits in 64 symbols.
pub const ANS_LOG_ALPHA_SIZE: usize = 6;
/// Fully built ANS entropy code, owning everything needed to serialize the
/// header and encode tokens: the context map, the normalized histograms, the
/// ready-to-encode distributions, and one HybridUint config per histogram.
#[derive(Debug)]
pub struct OwnedAnsEntropyCode {
    /// Maps each token context to a clustered histogram/distribution index.
    pub context_map: Vec<u8>,
    /// Normalized (serializable) histograms, one per cluster.
    pub histograms: Vec<ANSEncodingHistogram>,
    /// Encoder distributions built from `histograms`, one per cluster.
    pub distributions: Vec<AnsDistribution>,
    /// log2 of the ANS alphabet size actually used (5..=8).
    pub log_alpha_size: usize,
    /// HybridUint tokenization config per histogram (parallel to `histograms`).
    pub uint_configs: Vec<HybridUintConfig>,
}
impl OwnedAnsEntropyCode {
    /// Number of clustered ANS distributions in this code.
    pub fn num_distributions(&self) -> usize {
        let Self { distributions, .. } = self;
        distributions.len()
    }
}
/// Per-context statistics accumulated over one or more token streams before
/// clustering: tokenized-symbol histograms plus the raw value/LZ77-symbol
/// frequency maps needed to re-tokenize after clustering picks new configs.
pub struct AccumulatedAnsData {
    /// Tokenized-symbol histogram per context (default HybridUint config).
    pub histograms: Vec<super::histogram::Histogram>,
    /// Raw value -> count per context, for non-LZ77 tokens.
    pub value_freqs: Vec<alloc::collections::BTreeMap<u32, u32>>,
    /// LZ77 length symbol -> count per context.
    pub lz77_freqs: Vec<alloc::collections::BTreeMap<u32, u32>>,
    /// Number of contexts; all three vectors have this length.
    pub num_contexts: usize,
}
impl AccumulatedAnsData {
    /// Creates empty accumulators for `num_contexts` contexts.
    pub fn new(num_contexts: usize) -> Self {
        Self {
            histograms: (0..num_contexts)
                .map(|_| super::histogram::Histogram::new())
                .collect(),
            value_freqs: (0..num_contexts)
                .map(|_| alloc::collections::BTreeMap::new())
                .collect(),
            lz77_freqs: (0..num_contexts)
                .map(|_| alloc::collections::BTreeMap::new())
                .collect(),
            num_contexts,
        }
    }
    /// Records one token: bumps the tokenized-symbol histogram for its
    /// context, plus the raw value (or LZ77 length-symbol) frequency map so
    /// symbols can be re-derived later under a different HybridUint config.
    #[inline]
    pub fn add_token(&mut self, token: &Token, lz77: Option<&Lz77Params>) {
        let ctx = token.context() as usize;
        // Out-of-range contexts are silently dropped.
        if ctx < self.num_contexts {
            let (_encoded, sym) = encode_token_value(token, lz77);
            self.histograms[ctx].add(sym as usize);
            if token.is_lz77_length() {
                if let Some(lz77_params) = lz77 {
                    // LZ77 lengths live in their own symbol range starting at
                    // `min_symbol`; track the final symbol, not the raw value.
                    let encoded = Lz77UintCoder::encode(token.value);
                    let lz77_sym = encoded.token + lz77_params.min_symbol;
                    *self.lz77_freqs[ctx].entry(lz77_sym).or_insert(0) += 1;
                }
            } else {
                *self.value_freqs[ctx].entry(token.value).or_insert(0) += 1;
            }
        }
    }
    /// Accumulates a whole token slice.
    pub fn add_tokens(&mut self, tokens: &[Token], lz77: Option<&Lz77Params>) {
        for token in tokens {
            self.add_token(token, lz77);
        }
    }
    /// Adds another accumulator's counts into this one (context-wise).
    /// Both sides must have the same number of contexts.
    pub fn merge(&mut self, other: &Self) {
        debug_assert_eq!(self.num_contexts, other.num_contexts);
        for ctx in 0..self.num_contexts {
            self.histograms[ctx].add_histogram(&other.histograms[ctx]);
            for (&val, &count) in &other.value_freqs[ctx] {
                *self.value_freqs[ctx].entry(val).or_insert(0) += count;
            }
            for (&sym, &count) in &other.lz77_freqs[ctx] {
                *self.lz77_freqs[ctx].entry(sym).or_insert(0) += count;
            }
        }
    }
}
/// Builds a complete ANS entropy code from accumulated statistics:
/// clusters the per-context histograms, re-merges raw frequencies per
/// cluster, optionally searches HybridUint configs, then normalizes the
/// resulting histograms and builds encoder distributions.
///
/// `total_pixel_hint` caps the number of clustered histograms (roughly one
/// per 2048 pixels) so small images don't pay for many histogram headers.
///
/// Panics (via `expect`) if clustering, normalization, or distribution
/// building fails — these are encoder-internal invariant violations.
pub fn build_entropy_code_from_accumulated_ans(
    data: AccumulatedAnsData,
    enhanced_clustering: bool,
    optimize_uint_configs: bool,
    lz77: Option<&Lz77Params>,
    total_pixel_hint: Option<usize>,
) -> OwnedAnsEntropyCode {
    use crate::entropy_coding::cluster::{
        ClusteringType, EntropyType, cluster_histograms as enhanced_cluster,
    };
    use crate::entropy_coding::histogram::Histogram as EnhancedHistogram;
    let num_contexts = data.num_contexts;
    let cluster_type = if enhanced_clustering {
        ClusteringType::Best
    } else {
        ClusteringType::Fast
    };
    // Hard cap of 128 histograms, further reduced by the pixel budget.
    let mut max_histograms = num_contexts.min(128);
    if let Some(tp) = total_pixel_hint {
        max_histograms = max_histograms.min((tp / 2048).max(1));
    }
    let result = enhanced_cluster(
        cluster_type,
        EntropyType::Ans,
        &data.histograms,
        max_histograms,
    )
    .expect("ANS clustering failed");
    // result.symbols maps each original context to its cluster index.
    let context_map: Vec<u8> = result.symbols.iter().map(|&s| s as u8).collect();
    debug_assert_eq!(context_map.len(), num_contexts);
    let num_histograms = result.histograms.len();
    // Re-merge the raw value / LZ77-symbol frequencies per clustered
    // histogram, following the context map.
    let mut merged_value_freqs: Vec<alloc::collections::BTreeMap<u32, u32>> = (0..num_histograms)
        .map(|_| alloc::collections::BTreeMap::new())
        .collect();
    let mut merged_lz77_freqs: Vec<alloc::collections::BTreeMap<u32, u32>> = (0..num_histograms)
        .map(|_| alloc::collections::BTreeMap::new())
        .collect();
    for (ctx, &cm) in context_map.iter().enumerate() {
        let histo_idx = cm as usize;
        if histo_idx < num_histograms {
            for (&val, &count) in &data.value_freqs[ctx] {
                *merged_value_freqs[histo_idx].entry(val).or_insert(0) += count;
            }
            for (&sym, &count) in &data.lz77_freqs[ctx] {
                *merged_lz77_freqs[histo_idx].entry(sym).or_insert(0) += count;
            }
        }
    }
    // Pick one HybridUint config per histogram: fixed default, fast search,
    // or full search depending on the requested effort.
    let uint_configs = if !optimize_uint_configs {
        vec![HybridUintConfig::new(4, 2, 0); num_histograms]
    } else if enhanced_clustering {
        optimize_uint_configs_best_from_freqs(&merged_value_freqs, lz77)
    } else {
        optimize_uint_configs_fast_from_freqs(&merged_value_freqs, lz77)
    };
    let mut ans_histograms = Vec::with_capacity(num_histograms);
    let mut ans_distributions = Vec::with_capacity(num_histograms);
    let allowed_cache = super::ans::AllowedCountsCache::new();
    for h in 0..num_histograms {
        // Re-tokenize raw values under the chosen config to get final
        // symbol counts; LZ77 symbols are already final and added as-is.
        let config = &uint_configs[h];
        let mut counts: Vec<u32> = Vec::new();
        for (&val, &freq) in &merged_value_freqs[h] {
            let (tok, _, _) = config.encode(val);
            let sym = tok as usize;
            if sym >= counts.len() {
                counts.resize(sym + 1, 0);
            }
            counts[sym] += freq;
        }
        for (&sym, &freq) in &merged_lz77_freqs[h] {
            let s = sym as usize;
            if s >= counts.len() {
                counts.resize(s + 1, 0);
            }
            counts[s] += freq;
        }
        // Keep at least one (possibly zero-count) entry for empty clusters.
        if counts.is_empty() {
            counts.push(0);
        }
        let i32_counts: Vec<i32> = counts.iter().map(|&c| c as i32).collect();
        let histo = EnhancedHistogram::from_counts(&i32_counts);
        let ans_histo = ANSEncodingHistogram::from_histogram_cached(
            &histo,
            ANSHistogramStrategy::Precise,
            &allowed_cache,
        )
        .expect("ANS histogram normalization failed");
        ans_histograms.push(ans_histo);
    }
    let max_alphabet_size = ans_histograms
        .iter()
        .map(|h| h.counts.len())
        .max()
        .unwrap_or(1);
    // Alphabet size: LZ77 forces the full 8-bit alphabet; otherwise use the
    // default 6 bits when it fits, else the smallest size in 5..=8 that does.
    let log_alpha_size = if lz77.is_some_and(|p| p.enabled) {
        8
    } else if max_alphabet_size <= (1 << ANS_LOG_ALPHA_SIZE) {
        ANS_LOG_ALPHA_SIZE
    } else {
        // NOTE(review): the `<= 1` arm is unreachable here — this branch is
        // only taken when max_alphabet_size > 64.
        let min_bits = if max_alphabet_size <= 1 {
            5
        } else {
            (max_alphabet_size - 1).ilog2() as usize + 1
        };
        min_bits.clamp(5, 8)
    };
    for ans_histo in &ans_histograms {
        let ans_dist = AnsDistribution::from_normalized_counts_with_log_alpha(
            &ans_histo.counts,
            log_alpha_size,
        )
        .expect("ANS distribution building failed");
        ans_distributions.push(ans_dist);
    }
    OwnedAnsEntropyCode {
        context_map,
        histograms: ans_histograms,
        distributions: ans_distributions,
        log_alpha_size,
        uint_configs,
    }
}
pub fn build_entropy_code_ans(tokens: &[Token], num_contexts: usize) -> OwnedAnsEntropyCode {
build_entropy_code_ans_with_options(tokens, num_contexts, false, true, None, None)
}
/// Builds an ANS entropy code for a single token stream with explicit
/// clustering / config-search / LZ77 options. Thin wrapper around
/// `build_entropy_code_ans_from_token_groups` with a single group.
pub fn build_entropy_code_ans_with_options(
    tokens: &[Token],
    num_contexts: usize,
    enhanced_clustering: bool,
    optimize_uint_configs: bool,
    lz77: Option<&Lz77Params>,
    total_pixel_hint: Option<usize>,
) -> OwnedAnsEntropyCode {
    // One token slice is just the one-group case.
    let groups: [&[Token]; 1] = [tokens];
    build_entropy_code_ans_from_token_groups(
        &groups,
        num_contexts,
        enhanced_clustering,
        optimize_uint_configs,
        lz77,
        total_pixel_hint,
    )
}
/// Builds an ANS entropy code from several token streams that will share one
/// code (statistics are accumulated across all groups before clustering).
///
/// In debug builds, additionally validates that every token re-tokenizes to
/// a symbol that exists with nonzero frequency in its assigned distribution,
/// panicking with diagnostics otherwise.
pub fn build_entropy_code_ans_from_token_groups(
    groups: &[&[Token]],
    num_contexts: usize,
    enhanced_clustering: bool,
    optimize_uint_configs: bool,
    lz77: Option<&Lz77Params>,
    total_pixel_hint: Option<usize>,
) -> OwnedAnsEntropyCode {
    let mut accumulated = AccumulatedAnsData::new(num_contexts);
    for group in groups {
        accumulated.add_tokens(group, lz77);
    }
    let code = build_entropy_code_from_accumulated_ans(
        accumulated,
        enhanced_clustering,
        optimize_uint_configs,
        lz77,
        total_pixel_hint,
    );
    // Debug-only sanity pass: every token must be encodable with the built code.
    #[cfg(debug_assertions)]
    for group in groups {
        for (i, token) in group.iter().enumerate() {
            let ctx = token.context() as usize;
            let dist_idx = code.context_map.get(ctx).copied().unwrap_or(0) as usize;
            let config = &code.uint_configs[dist_idx];
            let (_encoded, sym) = encode_token_value_with_config(token, lz77, config);
            let dist = &code.distributions[dist_idx];
            let tok = sym as usize;
            if tok >= dist.symbols.len() {
                panic!(
                    "ANS validation: token[{}] ctx={} val={} tok={} exceeds distribution alphabet_size={} (dist_idx={})",
                    i,
                    ctx,
                    token.value,
                    tok,
                    dist.symbols.len(),
                    dist_idx
                );
            }
            if dist.symbols[tok].freq == 0 {
                panic!(
                    "ANS validation: token[{}] ctx={} val={} tok={} has zero frequency in distribution (dist_idx={})",
                    i, ctx, token.value, tok, dist_idx
                );
            }
        }
    }
    code
}
/// Shared cost-model search used by both the fast and best HybridUint config
/// optimizers (previously duplicated in each). For every histogram, picks the
/// candidate config minimizing
/// `entropy(tokenized histogram) + extra-bits total + header signaling bits`.
///
/// Candidates are skipped when their largest possible token index would
/// overflow `ANS_MAX_ALPHABET_SIZE`, or would collide with the LZ77
/// length-symbol range (`lz77.min_symbol` and above).
///
/// Histograms with no recorded values keep the default config (4, 2, 0).
fn select_uint_configs_from_freqs(
    candidates: &[HybridUintConfig],
    freqs_per_histo: &[alloc::collections::BTreeMap<u32, u32>],
    lz77: Option<&Lz77Params>,
) -> Vec<HybridUintConfig> {
    use crate::entropy_coding::ans::ANS_MAX_ALPHABET_SIZE;
    let max_alpha = ANS_MAX_ALPHABET_SIZE;
    let mut best_configs = vec![HybridUintConfig::new(4, 2, 0); freqs_per_histo.len()];
    // Reused scratch buffer for the tokenized counts of each candidate.
    let mut counts_buf: Vec<u32> = Vec::new();
    for (h, freqs) in freqs_per_histo.iter().enumerate() {
        if freqs.is_empty() {
            continue;
        }
        let max_value = freqs.keys().copied().max().unwrap_or(0);
        let total: u32 = freqs.values().sum();
        let mut best_cost = f64::MAX;
        for &cfg in candidates {
            let (max_tok, _, _) = cfg.encode(max_value);
            // LSB-in-token bits widen the token space; account for the largest
            // token index this config can produce for the observed values.
            let max_tok_with_lsb = max_tok | ((1u32 << cfg.lsb_in_token) - 1);
            if max_tok_with_lsb as usize >= max_alpha {
                continue;
            }
            if let Some(lz77_params) = lz77
                && max_tok_with_lsb >= lz77_params.min_symbol
            {
                continue;
            }
            let capacity = max_tok_with_lsb as usize + 1;
            counts_buf.clear();
            counts_buf.resize(capacity, 0);
            let mut extra_bits_total: u64 = 0;
            for (&val, &freq) in freqs {
                let (tok, _, nbits) = cfg.encode(val);
                counts_buf[tok as usize] += freq;
                extra_bits_total += nbits as u64 * freq as u64;
            }
            // Shannon cost (in bits) of the tokenized histogram.
            let inv_total = 1.0f32 / total as f32;
            let mut entropy_cost = 0.0f64;
            for &count in &counts_buf[..capacity] {
                if count > 0 {
                    let c = count as f32;
                    entropy_cost -= c as f64 * jxl_simd::fast_log2f(c * inv_total) as f64;
                }
            }
            // Bits spent in the header to signal msb_in_token / lsb_in_token
            // (split_exponent == 0 implies both are zero and unsignaled).
            let signaling_cost = if cfg.split_exponent == 0 {
                0.0
            } else {
                ceil_log2_nonzero_usize(cfg.split_exponent as usize + 1) as f64
                    + ceil_log2_nonzero_usize((cfg.split_exponent - cfg.msb_in_token) as usize + 1)
                        as f64
            };
            let cost = entropy_cost + extra_bits_total as f64 + signaling_cost;
            if cost < best_cost {
                best_cost = cost;
                best_configs[h] = cfg;
            }
        }
    }
    best_configs
}
/// Fast HybridUint config search over a handful of common configurations.
fn optimize_uint_configs_fast_from_freqs(
    freqs_per_histo: &[alloc::collections::BTreeMap<u32, u32>],
    lz77: Option<&Lz77Params>,
) -> Vec<HybridUintConfig> {
    let candidates = [
        HybridUintConfig::new(4, 2, 0),
        HybridUintConfig::new(4, 1, 2),
        HybridUintConfig::new(0, 0, 0),
        HybridUintConfig::new(2, 0, 1),
    ];
    select_uint_configs_from_freqs(&candidates, freqs_per_histo, lz77)
}
/// Thorough HybridUint config search over a wide candidate set; used together
/// with enhanced clustering at higher encode effort.
fn optimize_uint_configs_best_from_freqs(
    freqs_per_histo: &[alloc::collections::BTreeMap<u32, u32>],
    lz77: Option<&Lz77Params>,
) -> Vec<HybridUintConfig> {
    #[rustfmt::skip]
    let candidates = [
        HybridUintConfig::new(0,0,0), HybridUintConfig::new(1,0,0),
        HybridUintConfig::new(2,0,0), HybridUintConfig::new(2,0,1),
        HybridUintConfig::new(3,0,0), HybridUintConfig::new(3,1,0),
        HybridUintConfig::new(3,0,1), HybridUintConfig::new(3,1,1),
        HybridUintConfig::new(4,0,0), HybridUintConfig::new(4,2,0),
        HybridUintConfig::new(4,1,0), HybridUintConfig::new(4,0,1),
        HybridUintConfig::new(4,2,1), HybridUintConfig::new(4,1,1),
        HybridUintConfig::new(5,0,0), HybridUintConfig::new(5,2,0),
        HybridUintConfig::new(5,1,0), HybridUintConfig::new(5,0,1),
        HybridUintConfig::new(5,2,1), HybridUintConfig::new(6,0,0),
        HybridUintConfig::new(6,2,0), HybridUintConfig::new(6,1,0),
        HybridUintConfig::new(7,0,0), HybridUintConfig::new(7,2,0),
        HybridUintConfig::new(8,0,0), HybridUintConfig::new(8,2,0),
        HybridUintConfig::new(10,0,0), HybridUintConfig::new(12,0,0),
    ];
    select_uint_configs_from_freqs(&candidates, freqs_per_histo, lz77)
}
/// Serializes the full ANS entropy-code header: context map, the
/// use-prefix-code flag (0 = ANS), log_alpha_size, one HybridUint config per
/// histogram, then each normalized histogram.
pub fn write_entropy_code_ans(code: &OwnedAnsEntropyCode, writer: &mut BitWriter) -> Result<()> {
    #[cfg(feature = "debug-tokens")]
    {
        eprintln!("write_entropy_code_ans:");
        eprintln!("  num_contexts: {}", code.context_map.len());
        eprintln!("  num_histograms: {}", code.histograms.len());
        eprintln!(
            "  context_map: {:?}",
            &code.context_map[..code.context_map.len().min(20)]
        );
        for (i, h) in code.histograms.iter().enumerate() {
            eprintln!(
                "  histogram[{}]: alphabet_size={}, method={}, counts[..8]={:?}",
                i,
                h.alphabet_size,
                h.method,
                &h.counts[..h.counts.len().min(8)]
            );
        }
    }
    let _cm_start = writer.bits_written();
    write_context_map_for_ans(code, writer)?;
    #[cfg(feature = "debug-tokens")]
    eprintln!("  context_map: {} bits", writer.bits_written() - _cm_start);
    // use_prefix_code = 0: this code uses ANS, not prefix codes.
    writer.write(1, 0)?;
    let las = code.log_alpha_size;
    // log_alpha_size is stored biased by 5 (valid range 5..=8 fits 2 bits).
    writer.write(2, (las - 5) as u64)?;
    #[cfg(feature = "debug-tokens")]
    eprintln!("  use_prefix_code=0, log_alpha_size={}", las);
    let _cfg_start = writer.bits_written();
    // One HybridUint config per histogram; missing entries fall back to the
    // default config.
    for (i, _) in code.histograms.iter().enumerate() {
        let config = code.uint_configs.get(i).copied().unwrap_or_default();
        write_hybrid_uint_config_value(las, &config, writer)?;
    }
    #[cfg(feature = "debug-tokens")]
    eprintln!(
        "  HybridUint configs: {} bits ({} histograms)",
        writer.bits_written() - _cfg_start,
        code.histograms.len()
    );
    let _hist_start = writer.bits_written();
    #[allow(clippy::unused_enumerate_index)]
    for (_i, histo) in code.histograms.iter().enumerate() {
        let _h_start = writer.bits_written();
        histo.write(writer)?;
        #[cfg(feature = "debug-tokens")]
        eprintln!(
            "  histogram[{}]: {} bits",
            _i,
            writer.bits_written() - _h_start
        );
    }
    #[cfg(feature = "debug-tokens")]
    eprintln!(
        "  All histograms: {} bits",
        writer.bits_written() - _hist_start
    );
    Ok(())
}
/// Writes the context map, choosing the cheaper of two encodings.
///
/// With one histogram it writes the trivial "simple" map (0 bits/entry).
/// When entries fit in fewer than 4 bits, it trial-encodes the prefix-coded
/// ("nonsimple") form into a scratch writer and compares exact bit costs
/// against the simple fixed-width form, emitting whichever is smaller.
/// Otherwise it always uses the nonsimple form.
fn write_context_map_for_ans(code: &OwnedAnsEntropyCode, writer: &mut BitWriter) -> Result<()> {
    let num_histograms = code.histograms.len();
    if num_histograms == 1 {
        // simple=1 with 0 bits per entry: every context maps to histogram 0.
        writer.write(1, 1)?; writer.write(2, 0)?; return Ok(());
    }
    let entry_bits = ceil_log2_nonzero_usize(num_histograms);
    if entry_bits < 4 {
        // Cost of the simple form: 1-bit flag + 2-bit entry width + entries.
        let simple_cost = 3 + entry_bits * code.context_map.len();
        let mut scratch = BitWriter::with_capacity(code.context_map.len());
        write_context_map_nonsimple(&code.context_map, &mut scratch)?;
        let nonsimple_cost = scratch.bits_written();
        if simple_cost <= nonsimple_cost {
            writer.write(1, 1)?; writer.write(2, entry_bits as u64)?;
            for &ctx in &code.context_map {
                writer.write(entry_bits, ctx as u64)?;
            }
            return Ok(());
        }
        // Nonsimple won: splice the already-encoded scratch bits into the
        // real stream instead of re-encoding.
        let scratch_bytes = scratch.finish_with_padding();
        let bits_to_copy = nonsimple_cost;
        copy_bits(&scratch_bytes, bits_to_copy, writer)?;
        return Ok(());
    }
    write_context_map_nonsimple(&code.context_map, writer)
}
/// Writes the prefix-coded ("nonsimple") context map, optionally applying a
/// move-to-front transform first when the entropy estimate says it's cheaper.
/// Tokens are hybrid-uint encoded and Huffman-coded with a freshly built tree.
fn write_context_map_nonsimple(context_map: &[u8], writer: &mut BitWriter) -> Result<()> {
    let mtf_tokens = move_to_front_transform(context_map);
    // Compare estimated entropy of the raw map vs. its MTF transform.
    let direct_cost = estimate_context_map_cost(context_map);
    let mtf_cost = estimate_context_map_cost(&mtf_tokens);
    let use_mtf = mtf_cost < direct_cost;
    let tokens: &[u8] = if use_mtf { &mtf_tokens } else { context_map };
    // 3-bit header: 0b010 selects the MTF variant, 0b000 the direct one —
    // presumably {simple=0, use_mtf, rle} flags; confirm against the decoder.
    let header_bits = if use_mtf { 0b010u64 } else { 0b000u64 };
    writer.write(3, header_bits)?;
    // NOTE(review): these fixed writes appear to signal use_prefix_code=1 and
    // a fixed hybrid-uint config for the map tokens — verify against decoder.
    writer.write(1, 1)?;
    writer.write(4, 4)?; writer.write(3, 2)?; writer.write(2, 0)?;
    // Histogram of hybrid-uint token indices for the map entries.
    let mut histogram = [0u32; ALPHABET_SIZE];
    for &t in tokens {
        let encoded = UintCoder::encode(t as u32);
        histogram[encoded.token as usize] += 1;
    }
    // Trim trailing zero counts; keep at least one symbol.
    let mut length = ALPHABET_SIZE;
    while length > 0 && histogram[length - 1] == 0 {
        length -= 1;
    }
    length = length.max(1);
    let mut depths = [0u8; ALPHABET_SIZE];
    create_huffman_tree(&histogram, length, 15, &mut depths);
    let mut bits = [0u16; ALPHABET_SIZE];
    convert_bit_depths_to_symbols(&depths, &mut bits);
    write_var_len_uint16(length - 1, writer)?;
    // A single-symbol alphabet needs no prefix-code table.
    if length > 1 {
        let pc = PrefixCode { depths, bits };
        write_prefix_code(&pc, writer)?;
    }
    // Emit each token: Huffman code bits followed by the hybrid-uint extra bits.
    for &t in tokens {
        let encoded = UintCoder::encode(t as u32);
        let tok = encoded.token as usize;
        let depth = depths[tok] as usize;
        let b = bits[tok] as u64;
        let data = b | ((encoded.bits as u64) << depth);
        let total_bits = depth + encoded.nbits as usize;
        writer.write(total_bits, data)?;
    }
    Ok(())
}
/// Copies the first `num_bits` bits of `src` into `writer`: whole bytes
/// first, then the low bits of the final partial byte (if any).
fn copy_bits(src: &[u8], num_bits: usize, writer: &mut BitWriter) -> Result<()> {
    let whole_bytes = num_bits >> 3;
    let tail_bits = num_bits & 7;
    for &b in src.iter().take(whole_bytes) {
        writer.write(8, b as u64)?;
    }
    if tail_bits != 0 {
        let keep_mask = (1u64 << tail_bits) - 1;
        writer.write(tail_bits, (src[whole_bytes] as u64) & keep_mask)?;
    }
    Ok(())
}
/// Estimates the entropy-coded size of `tokens` in bits: per-symbol Shannon
/// entropy (via `fast_log2f`) scaled by the token count. Returns 0 for an
/// empty slice.
fn estimate_context_map_cost(tokens: &[u8]) -> f64 {
    if tokens.is_empty() {
        return 0.0;
    }
    let mut counts = [0u32; 256];
    tokens.iter().for_each(|&t| counts[t as usize] += 1);
    let inv_total = 1.0f32 / tokens.len() as f32;
    let entropy: f32 = counts
        .iter()
        .filter(|&&c| c > 0)
        .map(|&c| {
            let p = c as f32 * inv_total;
            -(p * jxl_simd::fast_log2f(p))
        })
        .sum();
    (entropy * tokens.len() as f32) as f64
}
/// Writes one HybridUint config to the stream.
///
/// `split_exponent` takes `ceil_log2(log_alpha_size + 1)` bits. When it
/// equals `log_alpha_size`, msb/lsb are implied and nothing more is written;
/// otherwise `msb_in_token` and `lsb_in_token` follow, each sized by the
/// remaining value range.
fn write_hybrid_uint_config_value(
    log_alpha_size: usize,
    config: &HybridUintConfig,
    writer: &mut BitWriter,
) -> Result<()> {
    let split_exponent = config.split_exponent;
    let msb_in_token = config.msb_in_token;
    let lsb_in_token = config.lsb_in_token;
    let se_bits = ceil_log2_nonzero_usize(log_alpha_size + 1);
    writer.write(se_bits, split_exponent as u64)?;
    if split_exponent as usize == log_alpha_size {
        // msb_in_token / lsb_in_token are implicitly zero in this case.
        return Ok(());
    }
    let msb_bits = ceil_log2_nonzero_usize(split_exponent as usize + 1);
    writer.write(msb_bits, msb_in_token as u64)?;
    let lsb_bits = ceil_log2_nonzero_usize((split_exponent - msb_in_token) as usize + 1);
    writer.write(lsb_bits, lsb_in_token as u64)?;
    Ok(())
}
/// Ceiling of log2(x) for x > 0 — e.g. 1 -> 0, 2 -> 1, 3 -> 2, 8 -> 3.
///
/// Computed on the full `usize` width. (The previous version truncated `x`
/// to `u32`, silently returning wrong results for values above `u32::MAX`
/// on 64-bit targets.)
fn ceil_log2_nonzero_usize(x: usize) -> usize {
    debug_assert!(x > 0);
    if x <= 1 {
        0
    } else {
        // For x > 1, ceil(log2(x)) equals the bit-width of (x - 1).
        (usize::BITS - (x - 1).leading_zeros()) as usize
    }
}
/// ANS-encodes `tokens` with the given code and writes the result.
///
/// rANS is last-in-first-out: tokens are fed to the encoder in reverse so the
/// decoder reads them back in original order. Extra (raw) bits are interleaved
/// before the symbol push for each token.
///
/// Panics if a token's context maps to a missing distribution or a symbol
/// absent from it — both are encoder-side invariant violations.
pub fn write_tokens_ans(
    tokens: &[Token],
    code: &OwnedAnsEntropyCode,
    lz77: Option<&Lz77Params>,
    writer: &mut BitWriter,
) -> Result<()> {
    let mut encoder = AnsEncoder::with_capacity(tokens.len());
    #[cfg(feature = "debug-tokens")]
    {
        eprintln!(
            "write_tokens_ans: {} tokens, {} distributions, context_map len={}",
            tokens.len(),
            code.distributions.len(),
            code.context_map.len()
        );
        eprintln!("  initial state: 0x{:08x}", encoder.state());
    }
    #[allow(clippy::unused_enumerate_index)]
    for (_i, token) in tokens.iter().rev().enumerate() {
        let ctx = token.context() as usize;
        let dist_idx = code.context_map.get(ctx).copied().unwrap_or(0) as usize;
        let config = code.uint_configs.get(dist_idx).copied().unwrap_or_default();
        let (encoded, sym) = encode_token_value_with_config(token, lz77, &config);
        let dist = code.distributions.get(dist_idx).unwrap_or_else(|| {
            panic!(
                "ANS: missing distribution at index {} for context {}",
                dist_idx, ctx
            )
        });
        // Raw extra bits first; in the reversed stream they end up after the
        // symbol from the decoder's point of view.
        encoder.push_bits(encoded.bits, encoded.nbits as u8);
        let info = dist.get(sym as usize).unwrap_or_else(|| {
            panic!(
                "ANS: symbol {} not in distribution (ctx={}, dist_idx={})",
                sym, ctx, dist_idx
            )
        });
        #[cfg(feature = "debug-tokens")]
        if _i < 5 || _i >= tokens.len() - 3 {
            eprintln!(
                "  token[{}]: ctx={}, val={}, tok={}, freq={}, state before=0x{:08x}",
                tokens.len() - 1 - _i,
                ctx,
                token.value,
                sym,
                info.freq,
                encoder.state()
            );
        }
        encoder.put_symbol(info);
    }
    #[cfg(feature = "debug-tokens")]
    eprintln!("  final state: 0x{:08x}", encoder.state());
    encoder.finalize(writer)?;
    Ok(())
}
/// Debug helper: round-trips every histogram through its serialized form and
/// the decoder's `AnsHistogram::decode`, checking that decoded frequencies
/// match the encoder's normalized counts. `label` tags diagnostic output.
///
/// Returns an error on decode failure or any frequency mismatch; detailed
/// diagnostics are printed only with the `debug-rect` / `debug-tokens`
/// features enabled.
pub fn verify_histogram_serialization(code: &OwnedAnsEntropyCode, label: &str) -> Result<()> {
    use crate::entropy_coding::ans_decode::{AnsHistogram, BitReader};
    for (i, histo) in code.histograms.iter().enumerate() {
        // Serialize the histogram alone, with a trailing zero byte so the
        // reader has slack past the histogram's end.
        let mut writer = BitWriter::new();
        histo.write(&mut writer)?;
        writer.write(8, 0)?;
        writer.zero_pad_to_byte();
        let bytes = writer.finish();
        let mut br = BitReader::new(&bytes);
        let decoded = match AnsHistogram::decode(&mut br, code.log_alpha_size) {
            Ok(d) => d,
            Err(e) => {
                #[cfg(feature = "debug-rect")]
                eprintln!(
                    "{} histo[{}]: DECODE FAILED - {} (method={} alpha={} omit={})",
                    label, i, e, histo.method, histo.alphabet_size, histo.omit_pos
                );
                return Err(e);
            }
        };
        // Compare decoded frequencies symbol-by-symbol with encoder counts.
        let mut mismatch = false;
        for j in 0..histo.alphabet_size {
            let expected = histo.counts[j] as u16;
            let got = decoded.frequencies[j];
            if expected != got {
                if !mismatch {
                    #[cfg(feature = "debug-rect")]
                    eprintln!(
                        "{} histo[{}]: FREQ MISMATCH (method={} alpha={})",
                        label, i, histo.method, histo.alphabet_size
                    );
                }
                #[cfg(feature = "debug-rect")]
                eprintln!("sym[{}]: expected={} got={}", j, expected, got);
                mismatch = true;
            }
        }
        if mismatch {
            // Reconstruct which symbol the encoder should have omitted (the
            // one with the largest log-count) to help diagnose the mismatch.
            #[cfg(feature = "debug-rect")]
            {
                eprintln!("counts: {:?}", &histo.counts[..histo.alphabet_size]);
                let mut encoder_omit_logcount = 0u32;
                let mut encoder_omit = 0;
                for (k, &c) in histo.counts.iter().enumerate().take(histo.alphabet_size) {
                    if c > 0 {
                        let lc = crate::entropy_coding::ans::floor_log2_ans(c as u32) + 1;
                        if lc > encoder_omit_logcount {
                            encoder_omit_logcount = lc;
                            encoder_omit = k;
                        }
                    }
                }
                eprintln!(
                    "omit_pos={} (stored: method={} omit={})",
                    encoder_omit, histo.method, histo.omit_pos
                );
                eprintln!(
                    "omit logcount={} count_at_omit={}",
                    encoder_omit_logcount, histo.counts[histo.omit_pos]
                );
                for k in 0..histo.alphabet_size.min(40) {
                    let c = histo.counts[k];
                    if c > 0 {
                        let lc = crate::entropy_coding::ans::floor_log2_ans(c as u32) + 1;
                        if lc == encoder_omit_logcount {
                            eprintln!("sym[{}]: count={} logcount={} (same as max)", k, c, lc);
                        }
                    }
                }
            }
            return Err(Error::InvalidHistogram(format!(
                "{} histogram[{}] serialization roundtrip failed",
                label, i
            )));
        }
        #[cfg(feature = "debug-tokens")]
        {
            let method_desc = match histo.method {
                0 => "flat",
                1 => "small",
                _ => "general",
            };
            eprintln!(
                "  {} histogram[{}]: OK ({}, {} symbols, {} bytes)",
                label,
                i,
                method_desc,
                histo.alphabet_size,
                bytes.len()
            );
        }
    }
    Ok(())
}
/// Debug helper: fully encodes header + tokens (with `lz77 = None`), then
/// replays the stream with decoder-side machinery and checks every decoded
/// symbol and extra-bit group against what the encoder produced, plus the
/// final ANS state. Returns a `Bitstream` error describing any mismatch.
pub fn verify_ans_roundtrip(tokens: &[Token], code: &OwnedAnsEntropyCode) -> Result<()> {
    use crate::entropy_coding::ans_decode::{AnsHistogram, AnsReader, BitReader};
    if tokens.is_empty() {
        return Ok(());
    }
    // Encode the header once separately just to learn its exact bit length.
    let mut header_writer = BitWriter::new();
    write_entropy_code_ans(code, &mut header_writer)?;
    let header_bits = header_writer.bits_written();
    let mut token_writer = BitWriter::new();
    write_tokens_ans(tokens, code, None, &mut token_writer)?;
    let _token_bits = token_writer.bits_written();
    // Now the real combined stream: header immediately followed by tokens.
    let mut combined_writer = BitWriter::new();
    write_entropy_code_ans(code, &mut combined_writer)?;
    write_tokens_ans(tokens, code, None, &mut combined_writer)?;
    combined_writer.zero_pad_to_byte();
    let encoded_bytes = combined_writer.finish();
    let mut br = BitReader::new(&encoded_bytes);
    let _num_histograms = code.histograms.len();
    // NOTE(review): `br` reads one bit and is then abandoned; actual decoding
    // restarts on `br2`, which skips the header bit-by-bit.
    let _simple = br.read(1)?; let mut br2 = BitReader::new(&encoded_bytes);
    for _ in 0..header_bits {
        br2.read(1)?;
    }
    let mut ans_reader = AnsReader::init(&mut br2)?;
    let _br_hist = BitReader::new(&encoded_bytes);
    // Rebuild decoder-side histograms directly from the encoder's
    // distributions (bypassing header parsing) so symbol reads can be checked.
    let log_alpha_size = code.log_alpha_size;
    let table_size = 1usize << log_alpha_size;
    let decoder_histograms: Vec<AnsHistogram> = code
        .distributions
        .iter()
        .map(|dist| {
            let mut freqs = vec![0u16; dist.symbols.len().max(table_size)];
            for (i, sym) in dist.symbols.iter().enumerate() {
                freqs[i] = sym.freq;
            }
            let log_bucket_size = 12 - log_alpha_size; let bucket_size = 1u16 << log_bucket_size;
            let bucket_mask = bucket_size as u32 - 1;
            // Degenerate single-symbol distribution (one symbol holds the
            // whole 4096 total): build its trivial alias table directly.
            if let Some(single_idx) = freqs.iter().position(|&f| f == 4096) {
                let buckets = freqs
                    .iter()
                    .enumerate()
                    .map(|(i, &f)| crate::entropy_coding::ans_decode::Bucket {
                        dist: f,
                        alias_symbol: single_idx as u8,
                        alias_offset: bucket_size * i as u16,
                        alias_cutoff: 0,
                        alias_dist_xor: f ^ 4096,
                    })
                    .collect();
                return AnsHistogram {
                    buckets,
                    log_bucket_size,
                    bucket_mask,
                    single_symbol: Some(single_idx as u32),
                    frequencies: freqs,
                };
            }
            let buckets =
                AnsHistogram::build_alias_map_from_freqs(freqs.len(), log_bucket_size, &freqs);
            AnsHistogram {
                buckets,
                log_bucket_size,
                bucket_mask,
                single_symbol: None,
                frequencies: freqs,
            }
        })
        .collect();
    // Decode token-by-token and compare symbol + extra bits with the encoder.
    let mut mismatches = 0;
    #[allow(clippy::unused_enumerate_index)] for (_i, token) in tokens.iter().enumerate() {
        let ctx = token.context() as usize;
        let dist_idx = code.context_map.get(ctx).copied().unwrap_or(0) as usize;
        let decoder_hist = &decoder_histograms[dist_idx];
        let decoded_symbol = decoder_hist.read(&mut br2, &mut ans_reader.0);
        let config = code.uint_configs.get(dist_idx).copied().unwrap_or_default();
        let (expected_encoded, _expected_sym) =
            encode_token_value_with_config(token, None, &config);
        let decoded_extra = if expected_encoded.nbits > 0 {
            br2.read(expected_encoded.nbits as usize).unwrap_or(0) as u32
        } else {
            0
        };
        if decoded_symbol != expected_encoded.token {
            if mismatches < 5 {
                #[cfg(feature = "debug-rect")]
                eprintln!(
                    "MISMATCH token[{}]: ctx={} val={} exp={} got={} state=0x{:08x}",
                    _i, ctx, token.value, expected_encoded.token, decoded_symbol, ans_reader.0
                );
            }
            mismatches += 1;
        }
        if decoded_extra != expected_encoded.bits {
            if mismatches < 5 {
                #[cfg(feature = "debug-rect")]
                eprintln!(
                    "BITS MISMATCH token[{}]: exp=0x{:x} got=0x{:x}",
                    _i, expected_encoded.bits, decoded_extra
                );
            }
            mismatches += 1;
        }
    }
    if let Err(e) = ans_reader.check_final_state() {
        return Err(Error::Bitstream(format!(
            "ANS roundtrip final state check failed ({} token mismatches): {}",
            mismatches, e
        )));
    }
    if mismatches > 0 {
        return Err(Error::Bitstream(format!(
            "ANS roundtrip had {} mismatches out of {} tokens",
            mismatches,
            tokens.len()
        )));
    }
    #[cfg(feature = "debug-tokens")]
    eprintln!(
        "ANS roundtrip OK: {} tokens, header={} bits, data={} bits",
        tokens.len(),
        header_bits,
        _token_bits
    );
    Ok(())
}
/// Debug-build-only helper: like `verify_ans_roundtrip`, but exercises the
/// *parsing* path too — it writes a minimal header by hand (no context map;
/// asserts a single distribution), parses it back field-by-field with
/// asserts and `eprintln!` tracing, then decodes every token through the
/// parsed histogram and checks symbols, extra bits, and the final ANS state.
#[cfg(debug_assertions)]
pub fn verify_ans_roundtrip_parsed(tokens: &[Token], code: &OwnedAnsEntropyCode) -> Result<()> {
    use crate::entropy_coding::ans_decode::{AnsHistogram, AnsReader, BitReader};
    // Local ceil(log2) on u32 (with ceil_log2(0..=1) == 0).
    #[inline]
    fn ceil_log2_nonzero(x: u32) -> u32 {
        if x <= 1 {
            0
        } else {
            u32::BITS - (x - 1).leading_zeros()
        }
    }
    if tokens.is_empty() {
        return Ok(());
    }
    assert_eq!(
        code.histograms.len(),
        1,
        "verify_ans_roundtrip_parsed only supports single-distribution"
    );
    // Hand-write the minimal header: lz77.enabled=0, use_prefix_code=0,
    // log_alpha_size (biased by 5), one HybridUint config, one histogram.
    let mut writer = BitWriter::new();
    writer.write(1, 0)?;
    writer.write(1, 0)?;
    let las = code.log_alpha_size;
    writer.write(2, (las - 5) as u64)?;
    let config = code
        .uint_configs
        .first()
        .copied()
        .unwrap_or(crate::entropy_coding::hybrid_uint::HybridUintConfig::default_config());
    let se_bits = ceil_log2_nonzero(las as u32 + 1) as usize;
    writer.write(se_bits, config.split_exponent as u64)?;
    if (config.split_exponent as usize) != las {
        let msb_bits = ceil_log2_nonzero(config.split_exponent + 1) as usize;
        writer.write(msb_bits, config.msb_in_token as u64)?;
        let lsb_bits = ceil_log2_nonzero(config.split_exponent - config.msb_in_token + 1) as usize;
        writer.write(lsb_bits, config.lsb_in_token as u64)?;
    }
    code.histograms[0].write(&mut writer)?;
    let header_bits = writer.bits_written();
    write_tokens_ans(tokens, code, None, &mut writer)?;
    writer.zero_pad_to_byte();
    let encoded_bytes = writer.finish();
    eprintln!(
        "verify_ans_roundtrip_parsed: header={} bits, total={} bytes",
        header_bits,
        encoded_bytes.len()
    );
    // Parse the header back and assert every field matches what was written.
    let mut br = BitReader::new(&encoded_bytes);
    let lz77_enabled = br.read(1)?;
    assert_eq!(lz77_enabled, 0, "expected lz77.enabled=0");
    let use_prefix_code = br.read(1)?;
    assert_eq!(use_prefix_code, 0, "expected use_prefix_code=0 (ANS)");
    let las = br.read(2)? as usize + 5;
    eprintln!("  parsed log_alpha_size={}", las);
    assert_eq!(las, code.log_alpha_size, "log_alpha_size mismatch");
    let se_bits = ceil_log2_nonzero(las as u32 + 1) as usize;
    let split_exponent = br.read(se_bits)? as u32;
    let (msb_in_token, lsb_in_token) = if split_exponent != las as u32 {
        let msb_bits = ceil_log2_nonzero(split_exponent + 1) as usize;
        let msb = br.read(msb_bits)? as u32;
        let lsb_bits = ceil_log2_nonzero(split_exponent - msb + 1) as usize;
        let lsb = br.read(lsb_bits)? as u32;
        (msb, lsb)
    } else {
        (0, 0)
    };
    let expected_config = code
        .uint_configs
        .first()
        .copied()
        .unwrap_or(crate::entropy_coding::hybrid_uint::HybridUintConfig::default_config());
    eprintln!(
        "  parsed uint_config: se={} msb={} lsb={} (expected se={} msb={} lsb={})",
        split_exponent,
        msb_in_token,
        lsb_in_token,
        expected_config.split_exponent,
        expected_config.msb_in_token,
        expected_config.lsb_in_token
    );
    assert_eq!(
        split_exponent, expected_config.split_exponent,
        "split_exponent mismatch"
    );
    assert_eq!(
        msb_in_token, expected_config.msb_in_token,
        "msb_in_token mismatch"
    );
    assert_eq!(
        lsb_in_token, expected_config.lsb_in_token,
        "lsb_in_token mismatch"
    );
    // Decode the histogram and cross-check its frequencies against the
    // encoder's distribution before decoding any tokens.
    let parsed_histo = AnsHistogram::decode(&mut br, las)?;
    let bits_after_histo = br.bits_read();
    eprintln!(
        "  histogram parsed OK at bit {}, freqs: {:?}",
        bits_after_histo,
        &parsed_histo.frequencies[..parsed_histo.frequencies.len().min(10)]
    );
    for (i, sym) in code.distributions[0].symbols.iter().enumerate() {
        let parsed_freq = parsed_histo.frequencies.get(i).copied().unwrap_or(0);
        if parsed_freq != sym.freq {
            eprintln!(
                "  FREQ MISMATCH at symbol {}: encoder={} parsed={}",
                i, sym.freq, parsed_freq
            );
            return Err(Error::Bitstream(format!(
                "Parsed histogram frequency mismatch at symbol {}: encoder={} parsed={}",
                i, sym.freq, parsed_freq
            )));
        }
    }
    eprintln!("  frequencies match encoder's distribution");
    let mut ans_reader = AnsReader::init(&mut br)?;
    eprintln!("  ANS initial state: 0x{:08x}", ans_reader.state());
    // Re-derive the config from the *parsed* fields so decoding exercises
    // exactly what a real decoder would use.
    let config = crate::entropy_coding::hybrid_uint::HybridUintConfig {
        split_exponent,
        split: 1 << split_exponent,
        msb_in_token,
        lsb_in_token,
    };
    let mut mismatches = 0;
    for (i, token) in tokens.iter().enumerate() {
        let (expected_encoded, _) =
            crate::entropy_coding::encode::encode_token_value_with_config(token, None, &config);
        let decoded_symbol = parsed_histo.read(&mut br, &mut ans_reader.0);
        let decoded_extra = if expected_encoded.nbits > 0 {
            br.read(expected_encoded.nbits as usize).unwrap_or(0) as u32
        } else {
            0
        };
        if decoded_symbol != expected_encoded.token || decoded_extra != expected_encoded.bits {
            if mismatches < 5 {
                eprintln!(
                    "  MISMATCH token[{}]: val={} exp_tok={} got_tok={} exp_bits=0x{:x} got_bits=0x{:x}",
                    i,
                    token.value,
                    expected_encoded.token,
                    decoded_symbol,
                    expected_encoded.bits,
                    decoded_extra
                );
            }
            mismatches += 1;
        }
    }
    if let Err(e) = ans_reader.check_final_state() {
        eprintln!(
            "  FINAL STATE FAILED: {} mismatches, state=0x{:08x}",
            mismatches,
            ans_reader.state()
        );
        return Err(Error::Bitstream(format!(
            "ANS parsed roundtrip final state check failed ({} mismatches): {}",
            mismatches, e
        )));
    }
    if mismatches > 0 {
        return Err(Error::Bitstream(format!(
            "ANS parsed roundtrip had {} mismatches",
            mismatches
        )));
    }
    eprintln!("  ANS parsed roundtrip OK: {} tokens", tokens.len());
    Ok(())
}