use crate::bit_writer::BitWriter;
use crate::error::{Error, Result};
/// Base-2 logarithm of the ANS table size.
pub const ANS_LOG_TAB_SIZE: u32 = 12;
/// Number of slots in the ANS table; normalized frequencies must sum to this value.
pub const ANS_TAB_SIZE: u32 = 1 << ANS_LOG_TAB_SIZE;
/// Bit mask selecting a slot index inside the ANS table (`ANS_TAB_SIZE - 1`).
pub const ANS_TAB_MASK: u32 = ANS_TAB_SIZE - 1;
/// Upper bound on the alphabet size.
/// NOTE(review): not referenced in the visible part of this file — confirm where it is enforced.
pub const ANS_MAX_ALPHABET_SIZE: usize = 256;
/// Signature value; the encoder state is initialised to `ANS_SIGNATURE << 16`.
pub const ANS_SIGNATURE: u32 = 0x13;
/// Index into `LOGCOUNT_PREFIX_CODE` of the entry emitted as the run-length marker.
const RLE_MARKER_SYM: u8 = 13;
/// `(bit length, code bits)` pairs of the prefix code used when serializing histograms.
/// Indexed by a symbol's logcount (0..=12); index 13 is the run-length marker.
const LOGCOUNT_PREFIX_CODE: [(u8, u8); 14] = [
(5, 0b10001), (4, 0b1011), (4, 0b1111), (4, 0b0011), (4, 0b1001), (4, 0b0111), (3, 0b100), (3, 0b010), (3, 0b101), (3, 0b110), (3, 0b000), (6, 0b100001), (7, 0b0000001), (7, 0b1000001), ];
/// Builds the list of normalized counts that can be represented at the given
/// precision `shift`.
///
/// The list always contains `1`; for every bit length `1..ANS_LOG_TAB_SIZE` it
/// adds the values with that leading bit whose kept mantissa bits are given by
/// [`get_population_count_precision`]. The result is de-duplicated and sorted
/// in **descending** order.
fn build_allowed_counts(shift: u32) -> Vec<i32> {
    let table_size = ANS_TAB_SIZE as i32;
    let mut allowed: Vec<i32> = Vec::with_capacity(256);
    allowed.push(1);

    for log in 1..ANS_LOG_TAB_SIZE {
        let kept_bits = get_population_count_precision(log, shift);
        // Mantissa bits below the kept precision are always zero.
        let dropped_bits = log.saturating_sub(kept_bits);
        let leading_bit = 1i32 << log;
        allowed.extend(
            (0..(1u32 << kept_bits))
                .map(|mantissa| leading_bit | ((mantissa as i32) << dropped_bits))
                .filter(|&candidate| candidate > 0 && candidate < table_size),
        );
    }

    allowed.sort_unstable();
    allowed.dedup();
    allowed.reverse();
    allowed
}
/// Precomputed allowed-count tables, one per shift in `0..=ANS_LOG_TAB_SIZE`.
pub struct AllowedCountsCache {
    /// `tables[shift]` is the descending list produced by `build_allowed_counts(shift)`.
    tables: [Vec<i32>; ANS_LOG_TAB_SIZE as usize + 1],
}

impl Default for AllowedCountsCache {
    /// Equivalent to [`AllowedCountsCache::new`].
    fn default() -> Self {
        AllowedCountsCache::new()
    }
}

impl AllowedCountsCache {
    /// Eagerly builds the table for every supported shift.
    pub fn new() -> Self {
        let tables = core::array::from_fn(|index| build_allowed_counts(index as u32));
        Self { tables }
    }

    /// Returns the allowed counts for `shift`.
    ///
    /// # Panics
    /// Panics if `shift` is greater than `ANS_LOG_TAB_SIZE`.
    #[inline]
    pub fn get(&self, shift: u32) -> &[i32] {
        self.tables[shift as usize].as_slice()
    }
}
/// Returns the index of the largest entry of `allowed` that is `<= target`.
///
/// `allowed` must be sorted in descending order. When every entry is greater
/// than `target`, the index of the last (smallest) entry is returned.
///
/// # Panics
/// May panic on an empty slice (the fallback index `len - 1` underflows).
fn find_allowed_leq(allowed: &[i32], target: i32) -> usize {
    // Entries greater than `target` form a prefix of the descending slice.
    let first_not_greater = allowed.partition_point(|&candidate| candidate > target);
    if first_not_greater < allowed.len() {
        first_not_greater
    } else {
        allowed.len() - 1
    }
}
/// Estimates the number of bits needed to code the data described by
/// `histo_counts` with the normalized distribution `norm_counts`.
///
/// Computes `total_count * ANS_LOG_TAB_SIZE - sum(actual * log2(norm))` over
/// the first `alphabet_size` symbols, skipping symbols whose actual or
/// normalized count is not positive.
fn estimate_data_bits_normalized(
    histo_counts: &[i32],
    norm_counts: &[i32],
    total_count: usize,
    alphabet_size: usize,
) -> f64 {
    let weighted_log_sum = histo_counts
        .iter()
        .zip(norm_counts)
        .take(alphabet_size)
        .filter(|&(&observed, &normalized)| observed > 0 && normalized > 0)
        // Explicit left-to-right fold keeps the floating-point summation order.
        .fold(0.0f64, |acc, (&observed, &normalized)| {
            acc + observed as f64 * jxl_simd::fast_log2f(normalized as f32) as f64
        });
    total_count as f64 * ANS_LOG_TAB_SIZE as f64 - weighted_log_sum
}
/// Number of fractional bits in the fixed-point reciprocal stored in `ifreq`.
const RECIPROCAL_PRECISION: u32 = 44;

/// Per-symbol data used by the encoder.
#[derive(Debug, Clone)]
pub struct AnsEncSymbolInfo {
    /// Normalized frequency of the symbol.
    pub freq: u16,
    /// `ceil(2^RECIPROCAL_PRECISION / freq)`, or `0` when `freq` is zero.
    pub ifreq: u64,
    /// Maps an offset within the symbol's frequency range to a table slot.
    pub reverse_map: Vec<u16>,
}

impl AnsEncSymbolInfo {
    /// Creates the info for a symbol with normalized frequency `freq`.
    ///
    /// The reverse map starts out empty.
    pub fn new(freq: u16) -> Self {
        let ifreq = match freq {
            0 => 0,
            nonzero => (1u64 << RECIPROCAL_PRECISION).div_ceil(u64::from(nonzero)),
        };
        Self {
            freq,
            ifreq,
            reverse_map: Vec::new(),
        }
    }
}
/// ANS encoder that buffers its output and emits it when finalized.
pub struct AnsEncoder {
    /// Current 32-bit coder state.
    state: u32,
    /// Buffered `(value, bit count)` chunks in the order they were produced;
    /// `finalize` writes them in reverse.
    bits: Vec<(u32, u8)>,
}

impl AnsEncoder {
    /// Creates an encoder whose state is `ANS_SIGNATURE << 16`.
    pub fn new() -> Self {
        Self::with_capacity(0)
    }

    /// Creates an encoder with room reserved for `num_tokens * 2` chunks.
    pub fn with_capacity(num_tokens: usize) -> Self {
        Self {
            state: ANS_SIGNATURE << 16,
            bits: Vec::with_capacity(num_tokens * 2),
        }
    }

    /// Encodes one symbol into the state.
    ///
    /// # Panics
    /// Panics if the computed offset is outside `info.reverse_map`, which
    /// happens for a symbol whose reverse map was never built.
    #[inline]
    pub fn put_symbol(&mut self, info: &AnsEncSymbolInfo) {
        let freq = u32::from(info.freq);
        // Renormalize: flush the low 16 bits when the state is too large.
        if (self.state >> (32 - ANS_LOG_TAB_SIZE)) >= freq {
            self.bits.push((self.state & 0xFFFF, 16));
            self.state >>= 16;
        }
        // Division by `freq` through the precomputed fixed-point reciprocal.
        let quotient = ((u64::from(self.state) * info.ifreq) >> RECIPROCAL_PRECISION) as u32;
        let slot_offset = (self.state - quotient * freq) as usize;
        let table_slot = u32::from(info.reverse_map[slot_offset]);
        self.state = (quotient << ANS_LOG_TAB_SIZE) + table_slot;
    }

    /// Buffers `nbits` raw bits; a zero-width request is ignored.
    #[inline]
    pub fn push_bits(&mut self, bits: u32, nbits: u8) {
        if nbits == 0 {
            return;
        }
        self.bits.push((bits, nbits));
    }

    /// Writes the 32-bit final state followed by all buffered chunks in
    /// reverse order of production.
    ///
    /// # Errors
    /// Propagates any error returned by `writer`.
    pub fn finalize(self, writer: &mut BitWriter) -> Result<()> {
        #[cfg(feature = "debug-tokens")]
        eprintln!(
            "ANS finalize: state=0x{:08x}, {} bit chunks",
            self.state,
            self.bits.len()
        );
        let Self { state, bits } = self;
        writer.write(32, u64::from(state))?;
        for (value, width) in bits.into_iter().rev() {
            writer.write(usize::from(width), u64::from(value))?;
        }
        Ok(())
    }

    /// Returns the current coder state.
    pub fn state(&self) -> u32 {
        self.state
    }
}

impl Default for AnsEncoder {
    /// Equivalent to [`AnsEncoder::new`].
    fn default() -> Self {
        AnsEncoder::new()
    }
}
/// A normalized symbol distribution together with the per-symbol encoder tables.
#[derive(Debug, Clone)]
pub struct AnsDistribution {
/// Encoder info for every symbol of the alphabet, indexed by symbol value.
pub symbols: Vec<AnsEncSymbolInfo>,
/// NOTE(review): the constructors in this file store `ANS_LOG_TAB_SIZE` here, not the
/// log alphabet size that was used to build the reverse maps — confirm this is intended.
pub log_alpha_size: u32,
/// Sum of the normalized frequencies; the constructors in this file set it to `ANS_TAB_SIZE`.
pub total: u32,
}
impl AnsDistribution {
/// Builds a distribution from raw (unnormalized) symbol frequencies.
///
/// Every non-zero frequency is scaled to the table size and kept at least 1.
/// Any rounding surplus or deficit is then applied to the symbol with the
/// largest normalized frequency (the last one among ties).
///
/// # Errors
/// Returns `Error::InvalidHistogram` when `freqs` is empty, when all
/// frequencies are zero, or when the reverse maps cannot be built (for
/// example because the normalized frequencies still do not sum to
/// `ANS_TAB_SIZE`).
pub fn from_frequencies(freqs: &[u32]) -> Result<Self> {
    if freqs.is_empty() {
        return Err(Error::InvalidHistogram("empty distribution".to_string()));
    }
    let total_count: u64 = freqs.iter().map(|&f| u64::from(f)).sum();
    if total_count == 0 {
        return Err(Error::InvalidHistogram("all zero frequencies".to_string()));
    }

    let table_size = u64::from(ANS_TAB_SIZE);
    let mut normalized: Vec<u16> = freqs
        .iter()
        .map(|&freq| match freq {
            0 => 0,
            nonzero => ((u64::from(nonzero) * table_size) / total_count).max(1) as u16,
        })
        .collect();
    let running_total: u32 = normalized.iter().map(|&f| u32::from(f)).sum();

    let diff = running_total as i32 - ANS_TAB_SIZE as i32;
    if diff != 0 {
        let largest = normalized
            .iter()
            .enumerate()
            .filter(|&(_, &f)| f > 0)
            .max_by_key(|&(_, &f)| f)
            .map(|(idx, _)| idx);
        if let Some(idx) = largest {
            normalized[idx] = (i32::from(normalized[idx]) - diff).max(1) as u16;
        }
    }

    let mut symbols: Vec<AnsEncSymbolInfo> = normalized
        .into_iter()
        .map(AnsEncSymbolInfo::new)
        .collect();
    let log_alpha_size = Self::default_log_alpha_size(symbols.len());
    Self::build_reverse_maps(&mut symbols, log_alpha_size)?;

    Ok(Self {
        symbols,
        log_alpha_size: ANS_LOG_TAB_SIZE,
        total: ANS_TAB_SIZE,
    })
}
/// Builds a flat distribution over `alphabet_size` symbols.
///
/// Each symbol gets `ANS_TAB_SIZE / alphabet_size`; the first
/// `ANS_TAB_SIZE % alphabet_size` symbols get one more.
///
/// # Errors
/// Returns `Error::InvalidHistogram` when `alphabet_size` is zero or larger
/// than `ANS_TAB_SIZE`, or when `from_frequencies` fails.
pub fn flat(alphabet_size: usize) -> Result<Self> {
    let table_size = ANS_TAB_SIZE as usize;
    if !(1..=table_size).contains(&alphabet_size) {
        return Err(Error::InvalidHistogram(format!(
            "invalid alphabet size: {}",
            alphabet_size
        )));
    }
    let quotient = (table_size / alphabet_size) as u32;
    let extra = table_size % alphabet_size;
    let freqs: Vec<u32> = (0..alphabet_size)
        .map(|index| if index < extra { quotient + 1 } else { quotient })
        .collect();
    Self::from_frequencies(&freqs)
}
/// Builds a distribution from already-normalized counts, choosing the log
/// alphabet size with `default_log_alpha_size`.
///
/// # Errors
/// See [`Self::from_normalized_counts_with_log_alpha`].
pub fn from_normalized_counts(counts: &[i32]) -> Result<Self> {
    Self::from_normalized_counts_with_log_alpha(
        counts,
        Self::default_log_alpha_size(counts.len()),
    )
}
/// Builds a distribution from already-normalized counts using an explicit
/// log alphabet size for the reverse-map construction.
///
/// Negative counts are treated as zero frequencies.
///
/// # Errors
/// Returns `Error::InvalidHistogram` when `counts` is empty, when the counts
/// do not sum to `ANS_TAB_SIZE`, or when the reverse maps cannot be built.
pub fn from_normalized_counts_with_log_alpha(
    counts: &[i32],
    log_alpha_size: usize,
) -> Result<Self> {
    if counts.is_empty() {
        return Err(Error::InvalidHistogram("empty distribution".to_string()));
    }
    let total: i32 = counts.iter().sum();
    if total != ANS_TAB_SIZE as i32 {
        return Err(Error::InvalidHistogram(format!(
            "normalized counts sum to {} instead of {}",
            total, ANS_TAB_SIZE
        )));
    }

    let mut symbols = Vec::with_capacity(counts.len());
    for &count in counts {
        symbols.push(AnsEncSymbolInfo::new(count.max(0) as u16));
    }
    Self::build_reverse_maps(&mut symbols, log_alpha_size)?;

    Ok(Self {
        symbols,
        log_alpha_size: ANS_LOG_TAB_SIZE,
        total: ANS_TAB_SIZE,
    })
}
/// Picks the log alphabet size used to build the alias table.
///
/// Alphabets that fit in `1 << ANS_LOG_ALPHA_SIZE` use `ANS_LOG_ALPHA_SIZE`;
/// larger ones use the number of bits needed to index the alphabet, clamped
/// to `5..=8`.
fn default_log_alpha_size(alphabet_size: usize) -> usize {
    use super::encode_ans::ANS_LOG_ALPHA_SIZE;
    if alphabet_size <= (1 << ANS_LOG_ALPHA_SIZE) {
        return ANS_LOG_ALPHA_SIZE;
    }
    let needed_bits = match alphabet_size {
        0 | 1 => 5,
        n => (n - 1).ilog2() as usize + 1,
    };
    needed_bits.clamp(5, 8)
}
/// Fills `reverse_map` of every symbol so that offset `k` within a symbol's
/// frequency range maps to the table slot that decodes to that symbol at
/// offset `k`.
///
/// The slot layout comes from an alias table with `1 << log_alpha_size`
/// buckets of `ANS_TAB_SIZE >> log_alpha_size` slots each. A distribution
/// whose whole mass is on one symbol is special-cased to the identity map.
///
/// # Errors
/// Returns `Error::InvalidHistogram` when
/// * the frequencies do not sum to `ANS_TAB_SIZE`,
/// * `log_alpha_size` exceeds `ANS_LOG_TAB_SIZE`, or
/// * a symbol with non-zero frequency has an index that does not fit in the
///   `1 << log_alpha_size` buckets. Such symbols have no bucket of their
///   own, so the table could not be built correctly.
fn build_reverse_maps(symbols: &mut [AnsEncSymbolInfo], log_alpha_size: usize) -> Result<()> {
    let alphabet_size = symbols.len();
    if alphabet_size == 0 {
        return Ok(());
    }
    let total: u32 = symbols.iter().map(|s| s.freq as u32).sum();
    if total != ANS_TAB_SIZE {
        return Err(Error::InvalidHistogram(format!(
            "frequencies sum to {} instead of {}",
            total, ANS_TAB_SIZE
        )));
    }

    // Single-symbol distribution: every slot belongs to that symbol.
    if let Some(single_sym_idx) = symbols.iter().position(|s| s.freq == ANS_TAB_SIZE as u16) {
        for sym in symbols.iter_mut() {
            sym.reverse_map.clear();
        }
        symbols[single_sym_idx]
            .reverse_map
            .extend((0..ANS_TAB_SIZE).map(|slot| slot as u16));
        return Ok(());
    }

    // Reject parameters the alias construction below cannot handle; without
    // these checks the arithmetic underflows or produces corrupt maps.
    if log_alpha_size > ANS_LOG_TAB_SIZE as usize {
        return Err(Error::InvalidHistogram(format!(
            "log_alpha_size {} exceeds ANS_LOG_TAB_SIZE {}",
            log_alpha_size, ANS_LOG_TAB_SIZE
        )));
    }
    let table_size = 1usize << log_alpha_size;
    if symbols.iter().skip(table_size).any(|s| s.freq > 0) {
        return Err(Error::InvalidHistogram(format!(
            "alphabet of {} symbols does not fit in {} alias-table buckets",
            alphabet_size, table_size
        )));
    }

    let log_bucket_size = ANS_LOG_TAB_SIZE as usize - log_alpha_size;
    let bucket_size = 1u16 << log_bucket_size;

    /// Working state of one alias-table bucket.
    #[derive(Clone)]
    struct WorkingBucket {
        /// Symbol that owns the slots at or above `alias_cutoff`.
        alias_symbol: u16,
        /// Offset within `alias_symbol`'s range where this bucket's aliased slots start.
        alias_offset: u16,
        /// Number of leading slots that belong to the bucket's own symbol.
        alias_cutoff: u16,
    }

    let mut buckets: Vec<WorkingBucket> = (0..table_size)
        .map(|i| {
            let own = i < alphabet_size;
            WorkingBucket {
                alias_symbol: if own { i as u16 } else { 0 },
                alias_offset: 0,
                alias_cutoff: if own { symbols[i].freq } else { 0 },
            }
        })
        .collect();

    let mut underfull: Vec<usize> = Vec::with_capacity(table_size);
    let mut overfull: Vec<usize> = Vec::with_capacity(table_size);
    for (i, bucket) in buckets.iter().enumerate() {
        if bucket.alias_cutoff < bucket_size {
            underfull.push(i);
        } else if bucket.alias_cutoff > bucket_size {
            overfull.push(i);
        }
    }

    // Move surplus mass from overfull buckets into underfull ones.
    while let (Some(o), Some(u)) = (overfull.pop(), underfull.pop()) {
        let by = bucket_size - buckets[u].alias_cutoff;
        buckets[o].alias_cutoff -= by;
        buckets[u].alias_symbol = o as u16;
        buckets[u].alias_offset = buckets[o].alias_cutoff;
        match buckets[o].alias_cutoff.cmp(&bucket_size) {
            std::cmp::Ordering::Less => underfull.push(o),
            std::cmp::Ordering::Greater => overfull.push(o),
            std::cmp::Ordering::Equal => {}
        }
    }

    for sym in symbols.iter_mut() {
        sym.reverse_map.clear();
        sym.reverse_map.resize(sym.freq as usize, 0);
    }

    // Walk every table slot and record it in the owning symbol's reverse map.
    for idx in 0..ANS_TAB_SIZE {
        // Always < table_size because idx < ANS_TAB_SIZE.
        let bucket_idx = (idx >> log_bucket_size) as usize;
        let pos = (idx as u16) & (bucket_size - 1);
        let bucket = &buckets[bucket_idx];
        let alias_cutoff = bucket.alias_cutoff;
        let (symbol, offset) = if pos < alias_cutoff {
            (bucket_idx, pos)
        } else {
            let alias_sym = bucket.alias_symbol as usize;
            let offset = bucket.alias_offset - alias_cutoff + pos;
            (alias_sym, offset)
        };
        if symbol < alphabet_size {
            symbols[symbol].reverse_map[offset as usize] = idx as u16;
        }
    }
    Ok(())
}
pub fn alphabet_size(&self) -> usize {
self.symbols.len()
}
/// Returns the encoder info for `symbol`, or `None` when it is outside the alphabet.
pub fn get(&self, symbol: usize) -> Option<&AnsEncSymbolInfo> {
    self.symbols.as_slice().get(symbol)
}
pub fn write(&self, writer: &mut BitWriter) -> Result<()> {
let is_flat = self.is_flat();
writer.write(1, 0)?; writer.write(1, u64::from(is_flat))?;
if is_flat {
write_var_len_uint8(writer, (self.alphabet_size() - 1) as u8)?;
} else {
self.write_general(writer)?;
}
Ok(())
}
fn is_flat(&self) -> bool {
let first_freq = self.symbols.first().map(|s| s.freq).unwrap_or(0);
if first_freq == 0 {
return false;
}
self.symbols
.iter()
.all(|s| s.freq == first_freq || s.freq == first_freq - 1)
}
fn write_general(&self, writer: &mut BitWriter) -> Result<()> {
let method: u64 = 13; let upper_bound_log = 4; let log = floor_log2(method as u32);
writer.write(log as usize, (1u64 << log) - 1)?;
if log != upper_bound_log {
writer.write(1, 0)?;
}
writer.write(log as usize, ((1u64 << log) - 1) & method)?;
write_var_len_uint8(writer, (self.alphabet_size() - 3) as u8)?;
for sym in &self.symbols {
let freq = sym.freq;
if freq == 0 {
writer.write(1, 0)?;
} else {
writer.write(1, 1)?;
let bits = 16 - freq.leading_zeros();
writer.write(4, bits as u64)?;
if bits > 0 {
writer.write(bits as usize, freq as u64)?;
}
}
}
Ok(())
}
}
/// Writes `n` as a variable-length 8-bit integer.
///
/// Zero is a single `0` bit. Any other value is a `1` bit, a 3-bit field
/// holding `floor(log2(n))`, and then that many low bits of `n` (the leading
/// one bit is implicit).
fn write_var_len_uint8(writer: &mut BitWriter, n: u8) -> Result<()> {
    if n == 0 {
        writer.write(1, 0)?;
        return Ok(());
    }
    let extra_bits = n.ilog2();
    writer.write(1, 1)?;
    writer.write(3, u64::from(extra_bits))?;
    writer.write(extra_bits as usize, u64::from(n) - (1u64 << extra_bits))?;
    Ok(())
}
/// Returns `floor(log2(n))`, with `0` for `n == 0`.
#[inline]
pub fn floor_log2_ans(n: u32) -> u32 {
    n.checked_ilog2().unwrap_or(0)
}

/// File-local shorthand for [`floor_log2_ans`].
#[inline]
fn floor_log2(n: u32) -> u32 {
    floor_log2_ans(n)
}
/// Returns how many mantissa bits are kept for a count with the given
/// `logcount` at precision `shift`.
///
/// The result is `min(logcount, shift - ((ANS_LOG_TAB_SIZE - logcount) >> 1))`,
/// floored at zero.
pub fn get_population_count_precision(logcount: u32, shift: u32) -> u32 {
    let log = logcount as i32;
    let half_gap = (ANS_LOG_TAB_SIZE as i32 - log) >> 1;
    let limit = shift as i32 - half_gap;
    log.min(limit).max(0) as u32
}
/// Controls how many precision shifts are tried when searching for the cheapest
/// normalized histogram in `ANSEncodingHistogram::from_histogram_cached`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum ANSHistogramStrategy {
/// Tries the shifts 0, 6 and 12.
Fast,
/// Tries every even shift from 0 to 12.
Approximate,
/// Tries every shift from 0 to 11. This is the default.
#[default]
Precise,
}
/// A normalized histogram together with the information needed to serialize it.
#[derive(Clone, Debug)]
pub struct ANSEncodingHistogram {
/// Normalized count per symbol.
pub counts: Vec<i32>,
/// Number of symbols covered by `counts`.
pub alphabet_size: usize,
/// Estimated cost in bits of this encoding (`f32::MAX` until computed).
pub cost: f32,
/// Encoding method: 0 selects the flat encoding in `write`; for the general
/// encoding the precision shift is `method - 1`.
pub method: u32,
/// Index of the symbol whose count absorbs the remainder; `write_general`
/// does not emit mantissa bits for it.
pub omit_pos: usize,
/// Number of symbols with a non-zero count in the source histogram.
num_symbols: usize,
/// Indices of the first two symbols with a non-zero count.
symbols: [usize; 2],
}
impl ANSEncodingHistogram {
/// Creates an empty histogram with no counts and a cost of `f32::MAX`.
pub fn new() -> Self {
    let counts: Vec<i32> = Vec::new();
    Self {
        counts,
        alphabet_size: 0,
        cost: f32::MAX,
        method: 0,
        omit_pos: 0,
        num_symbols: 0,
        symbols: [0; 2],
    }
}
/// Normalizes `histo` using a freshly built allowed-counts cache.
///
/// Prefer `from_histogram_cached` when normalizing many histograms, since
/// this builds a new cache on every call.
///
/// # Errors
/// See `from_histogram_cached`.
pub fn from_histogram(
    histo: &super::histogram::Histogram,
    strategy: ANSHistogramStrategy,
) -> Result<Self> {
    Self::from_histogram_cached(histo, strategy, &AllowedCountsCache::new())
}
/// Normalizes `histo` and picks the cheapest encoding among the candidates
/// selected by `strategy`, using precomputed allowed-count tables from `cache`.
///
/// * An empty histogram yields a zero-cost result with `method == 0`.
/// * Histograms with at most two used symbols are normalized directly (`method == 1`).
/// * Otherwise a flat distribution is the baseline and one rebalanced candidate per
///   shift is evaluated; the cheapest one wins.
///
/// # Errors
/// Returns `Error::InvalidHistogram` if the best cost is still `f32::MAX` after the search.
pub fn from_histogram_cached(
histo: &super::histogram::Histogram,
strategy: ANSHistogramStrategy,
cache: &AllowedCountsCache,
) -> Result<Self> {
// Empty histogram: nothing to code.
if histo.total_count == 0 {
return Ok(Self {
counts: vec![0i32; histo.counts.len().max(1)],
alphabet_size: 1,
cost: 0.0,
method: 0, omit_pos: 0,
num_symbols: 0,
symbols: [0, 0],
});
}
let alphabet_size = histo.alphabet_size();
// Count the used symbols and remember the first two of them.
let mut num_symbols = 0;
let mut symbols = [0usize; 2];
for (i, &count) in histo.counts.iter().enumerate() {
if count > 0 {
if num_symbols < 2 {
symbols[num_symbols] = i;
}
num_symbols += 1;
}
}
// One or two used symbols: normalize directly without any search.
if num_symbols <= 2 {
let mut counts = vec![0i32; alphabet_size];
if num_symbols == 1 {
counts[symbols[0]] = ANS_TAB_SIZE as i32;
} else {
let total = histo.total_count as f64;
let count0 = histo.counts[symbols[0]] as f64;
let norm0 = ((count0 / total) * ANS_TAB_SIZE as f64).round() as i32;
// Keep both symbols at a non-zero count.
let norm0 = norm0.clamp(1, (ANS_TAB_SIZE - 1) as i32);
counts[symbols[0]] = norm0;
counts[symbols[1]] = ANS_TAB_SIZE as i32 - norm0;
}
let cost = if num_symbols <= 1 { 4.0 } else { 4.0 + 12.0 };
return Ok(Self {
counts,
alphabet_size,
cost,
method: 1, omit_pos: symbols[0],
num_symbols,
symbols,
});
}
// Baseline candidate: flat distribution over the whole alphabet.
let flat_data_cost = {
let log2_alpha = jxl_simd::fast_log2f(alphabet_size as f32);
histo.total_count as f32 * log2_alpha
};
let flat_header_cost = 2.0 + 8.0; let mut best = Self {
counts: {
let alpha = alphabet_size as u32;
let per = ANS_TAB_SIZE / alpha;
let remainder = (ANS_TAB_SIZE % alpha) as usize;
let mut c = vec![per as i32; alphabet_size];
for c in c.iter_mut().take(remainder) {
*c += 1;
}
c
},
alphabet_size,
cost: flat_header_cost + flat_data_cost,
method: 0, omit_pos: 0,
num_symbols,
symbols,
};
// Scratch buffer that is swapped in and out of the candidates so that it is
// reused across shifts.
let mut candidate_counts = vec![0i32; alphabet_size];
let shift_iter: &[u32] = match strategy {
ANSHistogramStrategy::Fast => &[0, 6, 12],
ANSHistogramStrategy::Approximate => &[0, 2, 4, 6, 8, 10, 12],
ANSHistogramStrategy::Precise => &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
};
for &shift in shift_iter {
candidate_counts.fill(0);
let mut candidate = Self {
counts: Vec::new(), alphabet_size,
cost: f32::MAX,
method: shift.min(ANS_LOG_TAB_SIZE - 1) + 1,
omit_pos: 0,
num_symbols,
symbols,
};
// Lend the scratch buffer to the candidate.
core::mem::swap(&mut candidate.counts, &mut candidate_counts);
if candidate.rebalance_histogram_cached(histo, shift, cache.get(shift)) {
candidate.cost = candidate.estimate_cost(histo);
if candidate.cost < best.cost {
// The candidate keeps the scratch buffer; the previous best's buffer
// becomes the new scratch buffer.
core::mem::swap(&mut candidate_counts, &mut best.counts);
best = candidate;
candidate_counts.resize(alphabet_size, 0);
} else {
// Take the scratch buffer back.
core::mem::swap(&mut candidate.counts, &mut candidate_counts);
}
} else {
// Rebalancing failed: take the scratch buffer back.
core::mem::swap(&mut candidate.counts, &mut candidate_counts);
}
}
if best.cost == f32::MAX {
eprintln!(
"ANS rebalance FAILED: alphabet_size={}, num_symbols={}, total_count={}",
alphabet_size, num_symbols, histo.total_count
);
for (i, &c) in histo.counts.iter().enumerate() {
if c > 0 {
eprintln!(" symbol {}: count={}", i, c);
}
}
return Err(Error::InvalidHistogram(
"Failed to rebalance histogram".to_string(),
));
}
Ok(best)
}
/// Normalizes `histo` into `self.counts` so that every count except the one at
/// `self.omit_pos` is a value from `allowed` and all counts sum to `ANS_TAB_SIZE`.
///
/// `allowed` must be sorted in descending order (a smaller index is a larger count).
/// `_shift` is unused here; the precision is already reflected in `allowed`.
///
/// Returns `false` when no valid normalization was found: an empty histogram, a
/// non-positive remainder, a logcount-ordering violation that could not be repaired,
/// or counts that do not sum to `ANS_TAB_SIZE`.
fn rebalance_histogram_cached(
&mut self,
histo: &super::histogram::Histogram,
_shift: u32,
allowed: &[i32],
) -> bool {
let total_count = histo.total_count;
if total_count == 0 {
return false;
}
// Scale factor from raw counts to table slots.
let norm = ANS_TAB_SIZE as f64 / total_count as f64;
// Index of the most frequent symbol (the first one among ties).
let mut remainder_pos = 0;
let mut max_freq = 0i32;
// Adjustable bins as (raw frequency, index into `allowed`, symbol index).
let mut bins: Vec<(i32, usize, usize)> = Vec::with_capacity(self.alphabet_size);
// Table slots not yet assigned to any symbol.
let mut rest = ANS_TAB_SIZE as i32;
for (n, &freq) in histo.counts.iter().enumerate().take(self.alphabet_size) {
if freq > max_freq {
remainder_pos = n;
max_freq = freq;
}
if freq == 0 {
self.counts[n] = 0;
continue;
}
let target = freq as f64 * norm;
let rounded = target.round().max(1.0).min((ANS_TAB_SIZE - 1) as f64) as i32;
// Largest allowed count that does not exceed the rounded target.
let ai = find_allowed_leq(allowed, rounded);
let count = allowed[ai];
self.counts[n] = count;
rest -= count;
// Only symbols whose target exceeds one slot take part in the adjustment.
if target > 1.0 {
bins.push((freq, ai, n));
}
}
// The most frequent symbol is not adjusted; it receives whatever is left over.
if let Some(pos) = bins.iter().position(|b| b.2 == remainder_pos) {
bins.remove(pos);
}
rest += self.counts[remainder_pos];
if !bins.is_empty() {
let max_freq_f = max_freq as f64;
// log2 that maps non-positive inputs to 0.
let lg2 = |v: i32| -> f64 {
if v <= 0 {
0.0
} else {
jxl_simd::fast_log2f(v as f32) as f64
}
};
// Greedy search: in each round apply the single step with the best positive
// net gain; an increase is preferred over a decrease.
loop {
let mut best_inc_net = 0.0f64; let mut best_inc_bi = None;
let mut best_dec_net = 0.0f64; let mut best_dec_bi = None;
for (bi, &(freq, ai, _bin)) in bins.iter().enumerate() {
let count = allowed[ai];
let freq_f = freq as f64;
let lg2_count = lg2(count);
// Candidate step: move to the next larger allowed count.
if ai > 0 {
let new_count = allowed[ai - 1];
let step = new_count - count;
let new_rest = rest - step;
if new_rest > 0 || rest >= ANS_TAB_SIZE as i32 {
let gain = freq_f * (lg2(new_count) - lg2_count);
// Cost of shrinking the remainder symbol's share.
let cost = if rest >= ANS_TAB_SIZE as i32 {
0.0 } else if rest > 0 && new_rest > 0 {
max_freq_f * (lg2(rest) - lg2(new_rest))
} else {
f64::MAX
};
let net = gain - cost;
// Normalize the net gain by the power-of-two size of the step.
let step_log = floor_log2(step as u32);
let norm_net = if step_log > 0 {
net / (1u32 << step_log) as f64
} else {
net
};
if norm_net > best_inc_net {
best_inc_net = norm_net;
best_inc_bi = Some(bi);
}
}
}
// Candidate step: move to the next smaller allowed count.
if ai + 1 < allowed.len() && allowed[ai + 1] > 0 {
let new_count = allowed[ai + 1];
let step = count - new_count;
let new_rest = rest + step;
if new_rest < ANS_TAB_SIZE as i32 || rest <= 1 {
let loss = freq_f * (lg2_count - lg2(new_count));
// Gain of growing the remainder symbol's share.
let gain = if rest <= 1 {
f64::MAX } else if rest > 0 && new_rest < ANS_TAB_SIZE as i32 {
max_freq_f * (lg2(new_rest) - lg2(rest))
} else {
0.0
};
let net = gain - loss;
let step_log = floor_log2(step as u32);
let norm_net = if step_log > 0 {
net / (1u32 << step_log) as f64
} else {
net
};
if norm_net > best_dec_net {
best_dec_net = norm_net;
best_dec_bi = Some(bi);
}
}
}
}
if best_inc_net > 0.0 {
if let Some(bi) = best_inc_bi {
let step = allowed[bins[bi].1 - 1] - allowed[bins[bi].1];
bins[bi].1 -= 1; rest -= step;
}
} else if best_dec_net > 0.0 {
if let Some(bi) = best_dec_bi {
let step = allowed[bins[bi].1] - allowed[bins[bi].1 + 1];
bins[bi].1 += 1; rest += step;
}
} else {
break; }
}
// Write the adjusted counts back.
for &(_freq, ai, bin) in &bins {
self.counts[bin] = allowed[ai];
}
// If an earlier symbol already holds at least half of the table, make it the
// remainder symbol instead and give its count to the previous remainder position.
for n in 0..remainder_pos {
if self.counts[n] >= 2048 {
self.counts[remainder_pos] = self.counts[n];
remainder_pos = n;
break;
}
}
}
self.counts[remainder_pos] = rest;
self.omit_pos = remainder_pos;
if rest <= 0 {
return false;
}
// Repair passes (at most 10): no symbol may have a larger logcount than the omitted
// symbol, and no earlier symbol may have an equal one.
// NOTE(review): presumably this lets a decoder locate the omitted symbol as the first
// one with the largest logcount — confirm against the decoder.
for _ in 0..10 {
let omit_logcount = floor_log2(self.counts[remainder_pos] as u32) + 1;
let mut adjusted = false;
for i in 0..self.alphabet_size {
if i == remainder_pos || self.counts[i] <= 0 {
continue;
}
let logcount = floor_log2(self.counts[i] as u32) + 1;
let needs_fix =
logcount > omit_logcount || (logcount == omit_logcount && i < remainder_pos);
if needs_fix {
let target_logcount = if i < remainder_pos {
omit_logcount.saturating_sub(1)
} else {
omit_logcount
};
// Largest value that still has the target logcount.
let max_value = (1i32 << target_logcount) - 1;
let new_ai = find_allowed_leq(allowed, max_value);
let new_count = allowed[new_ai].max(1);
let reduction = self.counts[i] - new_count;
if reduction > 0 {
// Hand the removed slots to the omitted symbol.
self.counts[i] = new_count;
self.counts[remainder_pos] += reduction;
adjusted = true;
}
}
}
if !adjusted {
break;
}
}
// Final validation of the logcount ordering.
let omit_logcount = floor_log2(self.counts[remainder_pos] as u32) + 1;
for (i, &count) in self.counts.iter().enumerate().take(self.alphabet_size) {
if i == remainder_pos || count <= 0 {
continue;
}
let logcount = floor_log2(count as u32) + 1;
if logcount > omit_logcount || (logcount == omit_logcount && i < remainder_pos) {
return false;
}
}
// The normalized counts must fill the table exactly.
let sum: i32 = self.counts.iter().sum();
sum == ANS_TAB_SIZE as i32
}
/// Estimates the total cost in bits of coding `histo` with this histogram:
/// the estimated header cost plus the estimated data cost.
fn estimate_cost(&self, histo: &super::histogram::Histogram) -> f32 {
    let data_bits = estimate_data_bits_normalized(
        &histo.counts,
        &self.counts,
        histo.total_count,
        self.alphabet_size,
    );
    self.estimate_header_cost() + data_bits as f32
}
/// Returns a rough estimate, in bits, of the serialized header size.
///
/// Flat (`method == 0`) and one/two-symbol histograms have fixed estimates;
/// the general case charges 5 bits per symbol on top of 12 fixed bits.
fn estimate_header_cost(&self) -> f32 {
    if self.method == 0 {
        return 2.0 + 8.0;
    }
    match self.num_symbols {
        0 | 1 => 3.0 + 8.0,
        2 => 3.0 + 16.0 + 12.0,
        _ => 4.0 + 8.0 + self.alphabet_size as f32 * 5.0,
    }
}
/// Serializes the histogram.
///
/// * `method == 0`: flat encoding (bits `0`, `1`, then `alphabet_size - 1`).
/// * at most two used symbols: simple encoding (bit `1`, the symbol count
///   minus one, the symbol indices, and for two symbols the 12-bit count of
///   the first).
/// * otherwise: delegated to `write_general`.
///
/// # Errors
/// Propagates errors from the writer and from `write_general`.
pub fn write(&self, writer: &mut BitWriter) -> Result<()> {
    if self.method == 0 {
        writer.write(1, 0)?;
        writer.write(1, 1)?;
        write_var_len_uint8(writer, (self.alphabet_size - 1) as u8)?;
        return Ok(());
    }
    if self.num_symbols > 2 {
        return self.write_general(writer);
    }

    writer.write(1, 1)?;
    match self.num_symbols {
        0 => {
            writer.write(1, 0)?;
            write_var_len_uint8(writer, 0)?;
        }
        used => {
            writer.write(1, (used - 1) as u64)?;
            for &symbol in &self.symbols[..used] {
                write_var_len_uint8(writer, symbol as u8)?;
            }
            if used == 2 {
                writer.write(
                    ANS_LOG_TAB_SIZE as usize,
                    self.counts[self.symbols[0]] as u64,
                )?;
            }
        }
    }
    Ok(())
}
/// Serializes a general (non-flat, more than two symbols) histogram.
///
/// Layout: two `0` bits, the shift (`method - 1`) as a unary-prefixed value,
/// `alphabet_size - 3` as a variable-length byte, one prefix-coded logcount
/// per symbol (runs of equal counts are collapsed with the RLE marker), and
/// finally the mantissa bits of every count that is neither the omitted
/// symbol nor part of a collapsed run.
///
/// # Errors
/// Returns `Error::InvalidHistogram` when the alphabet has fewer than three
/// symbols. The check runs before any bit is written, so a failure leaves
/// `writer` untouched. Writer errors are propagated.
fn write_general(&self, writer: &mut BitWriter) -> Result<()> {
    if self.alphabet_size < 3 {
        return Err(Error::InvalidHistogram(
            "General histogram needs at least 3 symbols".to_string(),
        ));
    }

    writer.write(1, 0)?;
    writer.write(1, 0)?;

    // Shift header: `len` one-bits, a terminating zero unless len == 3, then
    // the low `len` bits of `shift + 1`.
    let shift = (self.method - 1) as i32;
    let shift_val = (shift + 1) as u32;
    let mut len = 0u32;
    while len < 3 && shift_val >= (1u32 << (len + 1)) {
        len += 1;
    }
    for _ in 0..len {
        writer.write(1, 1)?;
    }
    if len < 3 {
        writer.write(1, 0)?;
    }
    if len > 0 {
        let suffix = shift_val - (1u32 << len);
        writer.write(len as usize, suffix as u64)?;
    }

    write_var_len_uint8(writer, (self.alphabet_size - 3) as u8)?;

    // logcount = floor(log2(count)) + 1, or 0 for unused symbols.
    let logcounts: Vec<u32> = (0..self.alphabet_size)
        .map(|i| {
            let count = self.counts[i];
            if count <= 0 {
                0
            } else {
                floor_log2(count as u32) + 1
            }
        })
        .collect();

    // same[i]: number of symbols directly after `i` that share its count,
    // stopping before the omitted symbol. Stays 0 for the omitted symbol.
    let mut same = vec![0usize; self.alphabet_size];
    // Indexing is clearer here because the inner scan walks forward from `i`.
    #[allow(clippy::needless_range_loop)]
    for i in 0..self.alphabet_size {
        if i == self.omit_pos {
            continue;
        }
        let mut run = 0;
        let mut j = i + 1;
        while j < self.alphabet_size && self.counts[j] == self.counts[i] {
            if j == self.omit_pos {
                break;
            }
            run += 1;
            j += 1;
        }
        same[i] = run;
    }

    const MIN_REPS: usize = 4;

    // Emit the logcounts, collapsing long enough runs with the RLE marker.
    let mut i = 0;
    while i < self.alphabet_size {
        let (nbits, code) = LOGCOUNT_PREFIX_CODE[logcounts[i] as usize];
        writer.write(nbits as usize, code as u64)?;
        if same[i] >= MIN_REPS && i != self.omit_pos {
            let (rle_nbits, rle_code) = LOGCOUNT_PREFIX_CODE[RLE_MARKER_SYM as usize];
            writer.write(rle_nbits as usize, rle_code as u64)?;
            write_var_len_uint8(writer, (same[i] - MIN_REPS) as u8)?;
            i += same[i];
        }
        i += 1;
    }

    // Mark the symbols that were covered by a run and need no mantissa bits.
    let mut rle_covered = vec![false; self.alphabet_size];
    {
        let mut i = 0;
        while i < self.alphabet_size {
            if same[i] >= MIN_REPS && i != self.omit_pos {
                for item in rle_covered.iter_mut().take(i + same[i] + 1).skip(i + 1) {
                    *item = true;
                }
                i += same[i];
            }
            i += 1;
        }
    }

    // Emit the mantissa bits kept at this shift.
    for i in 0..self.alphabet_size {
        if i == self.omit_pos || rle_covered[i] {
            continue;
        }
        let count = self.counts[i];
        if count <= 0 {
            continue;
        }
        let logcount = logcounts[i];
        if logcount <= 1 {
            continue;
        }
        let zeros = (logcount - 1) as i32;
        let bitcount = (shift - ((ANS_LOG_TAB_SIZE as i32 - zeros) >> 1)).clamp(0, zeros);
        if bitcount > 0 {
            let base = 1i32 << zeros;
            let extra = ((count - base) >> (zeros - bitcount)) as u32;
            writer.write(bitcount as usize, extra as u64)?;
        }
    }
    Ok(())
}
}
impl Default for ANSEncodingHistogram {
fn default() -> Self {
Self::new()
}
}
/// Encodes `tokens` (pairs of `(context, value)`) with ANS and writes the
/// result to `writer`.
///
/// Tokens are processed in reverse order. A context without an entry in
/// `context_map` falls back to distribution 0.
///
/// NOTE(review): a token whose distribution or symbol cannot be found is
/// skipped silently — confirm callers never rely on such tokens being coded.
///
/// # Errors
/// Propagates errors from `AnsEncoder::finalize`.
pub fn encode_tokens_ans(
    tokens: &[(u32, u32)],
    distributions: &[AnsDistribution],
    context_map: &[usize],
    writer: &mut BitWriter,
) -> Result<()> {
    let mut encoder = AnsEncoder::new();
    tokens
        .iter()
        .rev()
        .filter_map(|&(context, value)| {
            let dist_idx = context_map.get(context as usize).copied().unwrap_or(0);
            distributions
                .get(dist_idx)
                .and_then(|dist| dist.get(value as usize))
        })
        .for_each(|info| encoder.put_symbol(info));
    encoder.finalize(writer)
}
// Unit tests for histogram normalization, distribution construction and the encoder.
#[cfg(test)]
mod tests {
use super::*;
use crate::entropy_coding::histogram::Histogram;
// A histogram with one used symbol gives that symbol the whole table.
#[test]
fn test_ans_encoding_histogram_single_symbol() {
let h = Histogram::from_counts(&[100, 0, 0, 0]);
let encoded = ANSEncodingHistogram::from_histogram(&h, ANSHistogramStrategy::Fast).unwrap();
assert_eq!(encoded.num_symbols, 1);
assert_eq!(encoded.method, 1); assert_eq!(encoded.counts[0], ANS_TAB_SIZE as i32);
assert!(encoded.cost < 100.0); }
// Two used symbols are normalized directly and both keep a non-zero count.
#[test]
fn test_ans_encoding_histogram_two_symbols() {
let h = Histogram::from_counts(&[100, 100, 0, 0]);
let encoded = ANSEncodingHistogram::from_histogram(&h, ANSHistogramStrategy::Fast).unwrap();
assert_eq!(encoded.num_symbols, 2);
assert_eq!(encoded.method, 1); let sum: i32 = encoded.counts.iter().sum();
assert_eq!(sum, ANS_TAB_SIZE as i32);
assert!(encoded.counts[0] > 0);
assert!(encoded.counts[1] > 0);
}
// General case: counts sum to the table size and no used symbol is normalized to zero.
#[test]
fn test_ans_encoding_histogram_general() {
let h = Histogram::from_counts(&[100, 50, 25, 10, 5, 3, 2, 1]);
let encoded = ANSEncodingHistogram::from_histogram(&h, ANSHistogramStrategy::Fast).unwrap();
assert!(encoded.method >= 2 || encoded.method == 0);
let sum: i32 = encoded.counts.iter().sum();
assert_eq!(sum, ANS_TAB_SIZE as i32);
for (i, &orig) in h.counts.iter().enumerate() {
if orig > 0 {
assert!(
encoded.counts.get(i).copied().unwrap_or(0) > 0,
"Symbol {} had count {} but normalized to 0",
i,
orig
);
}
}
}
// An empty histogram yields a zero-cost result with method 0.
#[test]
fn test_ans_encoding_histogram_empty() {
let h = Histogram::new();
let encoded = ANSEncodingHistogram::from_histogram(&h, ANSHistogramStrategy::Fast).unwrap();
assert_eq!(encoded.cost, 0.0);
assert_eq!(encoded.method, 0); }
// Spot checks of the precision formula.
#[test]
fn test_get_population_count_precision() {
assert_eq!(get_population_count_precision(0, 12), 0);
assert_eq!(get_population_count_precision(12, 12), 12);
assert_eq!(get_population_count_precision(6, 6), 3);
assert_eq!(get_population_count_precision(1, 0), 0);
}
// Writing a single-symbol histogram produces some output.
#[test]
fn test_ans_encoding_histogram_write() {
let h = Histogram::from_counts(&[100, 0, 0, 0]);
let encoded = ANSEncodingHistogram::from_histogram(&h, ANSHistogramStrategy::Fast).unwrap();
let mut writer = BitWriter::new();
encoded.write(&mut writer).unwrap();
let bytes = writer.finish_with_padding();
assert!(!bytes.is_empty());
}
// A flat distribution over 16 symbols gives each symbol 4096 / 16 = 256.
#[test]
fn test_flat_distribution() {
let dist = AnsDistribution::flat(16).unwrap();
assert_eq!(dist.alphabet_size(), 16);
for sym in &dist.symbols {
assert_eq!(sym.freq, 256);
}
}
// Normalized frequencies always sum to the table size.
#[test]
fn test_from_frequencies() {
let freqs = vec![100, 200, 300, 400];
let dist = AnsDistribution::from_frequencies(&freqs).unwrap();
assert_eq!(dist.alphabet_size(), 4);
let total: u32 = dist.symbols.iter().map(|s| s.freq as u32).sum();
assert_eq!(total, ANS_TAB_SIZE);
}
// Encoding symbols changes the encoder state.
#[test]
fn test_ans_encoder_basic() {
let dist = AnsDistribution::flat(4).unwrap();
let mut encoder = AnsEncoder::new();
encoder.put_symbol(&dist.symbols[0]);
encoder.put_symbol(&dist.symbols[1]);
encoder.put_symbol(&dist.symbols[2]);
assert_ne!(encoder.state(), ANS_SIGNATURE << 16);
}
// The reverse maps partition the table: every slot is covered exactly once.
#[test]
fn test_reverse_map() {
let dist = AnsDistribution::flat(4).unwrap();
for sym in &dist.symbols {
assert_eq!(sym.reverse_map.len(), sym.freq as usize);
}
let mut covered = vec![false; ANS_TAB_SIZE as usize];
for sym in &dist.symbols {
for &pos in &sym.reverse_map {
assert!(!covered[pos as usize], "position {} covered twice", pos);
covered[pos as usize] = true;
}
}
assert!(covered.iter().all(|&c| c), "not all positions covered");
}
// Writing a flat distribution produces some output.
#[test]
fn test_write_distribution() {
let dist = AnsDistribution::flat(16).unwrap();
let mut writer = BitWriter::new();
dist.write(&mut writer).unwrap();
let bytes = writer.finish_with_padding();
assert!(!bytes.is_empty());
}
// Encodes one symbol and decodes it by hand to check that the state round-trips.
#[test]
fn test_ans_roundtrip_manual() {
let dist = AnsDistribution::flat(2).unwrap();
println!("Distribution: {} symbols", dist.alphabet_size());
for (i, sym) in dist.symbols.iter().enumerate() {
println!(" Symbol {}: freq={}", i, sym.freq);
}
let mut encoder = AnsEncoder::new();
let initial_state = encoder.state();
println!("\nInitial state: 0x{:08x}", initial_state);
assert_eq!(initial_state, 0x130000, "Initial state should be 0x130000");
let info = &dist.symbols[0];
encoder.put_symbol(info);
let encoded_state = encoder.state();
println!("After encoding symbol 0: state=0x{:08x}", encoded_state);
// Manual decode: the low 12 bits select the table slot.
let idx = encoded_state & 0xFFF;
println!("Decode: idx = {}", idx);
let decoded_symbol = if idx < 2048 { 0 } else { 1 };
let offset_in_symbol = if idx < 2048 { idx } else { idx - 2048 };
let freq = 2048u32;
println!("Decoded symbol: {}", decoded_symbol);
println!("Offset in symbol: {}", offset_in_symbol);
let quotient = encoded_state >> 12;
let next_state = quotient * freq + offset_in_symbol;
println!(
"next_state = {} * {} + {} = 0x{:08x}",
quotient, freq, offset_in_symbol, next_state
);
assert_eq!(next_state, 0x130000, "Decoded state should be 0x130000");
assert_eq!(decoded_symbol, 0, "Decoded symbol should be 0");
}
// Full round trip through the project's decoder for a four-symbol uniform distribution.
#[test]
fn test_ans_roundtrip_multiple_symbols() {
use crate::bit_writer::BitWriter;
use crate::entropy_coding::ans_decode::{AnsHistogram, AnsReader, BitReader};
let counts = [1024i32, 1024, 1024, 1024];
let dist = AnsDistribution::from_normalized_counts(&counts).unwrap();
let symbols_to_encode: Vec<usize> = vec![0, 1, 2, 3, 0, 1];
println!(
"Encoding {} symbols: {:?}",
symbols_to_encode.len(),
symbols_to_encode
);
// Symbols are encoded in reverse so that the decoder reads them in order.
let mut encoder = AnsEncoder::new();
for &sym in symbols_to_encode.iter().rev() {
encoder.put_symbol(&dist.symbols[sym]);
}
println!("Final state after encoding: 0x{:08x}", encoder.state());
let mut writer = BitWriter::new();
encoder.finalize(&mut writer).unwrap();
let encoded_bytes = writer.finish_with_padding();
println!("Encoded bytes: {:02x?}", encoded_bytes);
// Serialize the histogram and read it back with the decoder.
let ans_histo = ANSEncodingHistogram::from_histogram(
&Histogram::from_counts(&counts),
ANSHistogramStrategy::Precise,
)
.unwrap();
let mut hist_writer = BitWriter::new();
ans_histo.write(&mut hist_writer).unwrap();
let hist_bytes = hist_writer.finish_with_padding();
let mut hist_br = BitReader::new(&hist_bytes);
let decoded_hist = AnsHistogram::decode(&mut hist_br, 6).unwrap();
println!(
"Decoded histogram frequencies: {:?}",
&decoded_hist.frequencies[..4]
);
let mut br = BitReader::new(&encoded_bytes);
let mut ans_reader = AnsReader::init(&mut br).unwrap();
println!("Decoding:");
let mut decoded = Vec::new();
for i in 0..symbols_to_encode.len() {
let symbol = decoded_hist.read(&mut br, &mut ans_reader.0) as usize;
println!(
" step {}: symbol={}, state=0x{:08x}",
i, symbol, ans_reader.0
);
decoded.push(symbol);
}
println!("Original: {:?}", symbols_to_encode);
println!("Decoded: {:?}", decoded);
println!("Final state: 0x{:08x}", ans_reader.0);
assert_eq!(
decoded, symbols_to_encode,
"Decoded symbols should match original"
);
assert!(
ans_reader.check_final_state().is_ok(),
"Final state should be 0x130000, got 0x{:08x}",
ans_reader.0
);
}
// NOTE(review): despite the name, this test only normalizes and writes the histogram;
// it never decodes the bytes again.
#[test]
fn test_ans_histogram_write_decode_roundtrip() {
use crate::bit_writer::BitWriter;
use crate::entropy_coding::histogram::Histogram;
let histo = Histogram::from_counts(&[100, 50, 25, 10]);
let encoded =
ANSEncodingHistogram::from_histogram(&histo, ANSHistogramStrategy::Precise).unwrap();
println!("Histogram: {:?}", histo.counts);
println!("Encoded counts: {:?}", encoded.counts);
println!(
"Method: {}, alphabet_size: {}, omit_pos: {}",
encoded.method, encoded.alphabet_size, encoded.omit_pos
);
let sum: i32 = encoded.counts.iter().sum();
assert_eq!(sum, ANS_TAB_SIZE as i32, "Sum should be 4096");
let mut writer = BitWriter::new();
encoded.write(&mut writer).unwrap();
let bytes = writer.finish_with_padding();
println!("Encoded histogram: {} bytes", bytes.len());
println!("Bytes: {:02x?}", bytes);
}
}