use crate::error::{CodecError, CodecResult};
/// Base-2 logarithm of the table size used when the caller does not pick one
/// (`1 << 10` = 1024 slots). It also fixes the encoder's initial state.
const DEFAULT_LOG_TABLE_SIZE: u8 = 10;
/// Number of bits moved between the coder state and the word stream in one
/// renormalization step (one `u16` word).
const RENORM_WORD_BITS: u32 = 16;
/// Frequency table describing how symbols share the `1 << log_table_size`
/// slots of the coder.
///
/// The three vectors are parallel: entry `i` of `frequencies` and the range
/// `cumulative[i]..cumulative[i + 1]` both belong to `symbols[i]`.
#[derive(Clone, Debug)]
pub struct AnsDistribution {
/// Symbol values, in table order.
pub symbols: Vec<u16>,
/// Number of table slots assigned to each symbol (the normalized
/// frequencies when built through `AnsDistribution::new`).
pub frequencies: Vec<u32>,
/// Running totals of `frequencies` with a leading 0, so it holds one more
/// entry than `symbols`.
pub cumulative: Vec<u32>,
/// Base-2 logarithm of the table size.
pub log_table_size: u8,
}
impl AnsDistribution {
    /// Largest `log_table_size` accepted by [`AnsDistribution::new`].
    ///
    /// The coder in this file keeps its state in a `u32` and renormalizes in
    /// 16-bit words, which leaves at most 16 bits for the table index.
    pub const MAX_LOG_TABLE_SIZE: u8 = 16;

    /// Builds a distribution whose frequencies are scaled to sum to exactly
    /// `1 << log_table_size`.
    ///
    /// Symbols with a zero input frequency keep zero slots; every other symbol
    /// receives at least one slot. Rounding error is absorbed by the most
    /// frequent symbols (ties resolved towards the later entry).
    ///
    /// # Errors
    ///
    /// Returns `CodecError::InvalidParameter` when
    /// - `symbols` and `frequencies` differ in length or are empty,
    /// - all frequencies are zero,
    /// - `log_table_size` exceeds [`Self::MAX_LOG_TABLE_SIZE`], or
    /// - more symbols have a non-zero frequency than the table has slots.
    pub fn new(symbols: Vec<u16>, frequencies: Vec<u32>, log_table_size: u8) -> CodecResult<Self> {
        if symbols.len() != frequencies.len() {
            return Err(CodecError::InvalidParameter(
                "Symbol and frequency vectors must have the same length".into(),
            ));
        }
        if symbols.is_empty() {
            return Err(CodecError::InvalidParameter(
                "Distribution must have at least one symbol".into(),
            ));
        }
        // Accumulate in u64: the sum of several large u32 counts does not fit in u32.
        let total: u64 = frequencies.iter().map(|&f| u64::from(f)).sum();
        if total == 0 {
            return Err(CodecError::InvalidParameter(
                "Total frequency must be non-zero".into(),
            ));
        }
        if log_table_size > Self::MAX_LOG_TABLE_SIZE {
            return Err(CodecError::InvalidParameter(format!(
                "log_table_size {log_table_size} exceeds the supported maximum of {}",
                Self::MAX_LOG_TABLE_SIZE
            )));
        }
        let table_size = 1u32 << log_table_size;

        // Every used symbol needs at least one slot, so the table must be large enough.
        let used = frequencies.iter().filter(|&&f| f > 0).count();
        if used as u64 > u64::from(table_size) {
            return Err(CodecError::InvalidParameter(format!(
                "{used} symbols with non-zero frequency do not fit in a table of {table_size} slots"
            )));
        }

        let mut normalized: Vec<u32> = frequencies
            .iter()
            .map(|&f| {
                if f == 0 {
                    0
                } else {
                    // f <= total, so the quotient is at most table_size and fits in u32.
                    let scaled = (u64::from(f) * u64::from(table_size) / total) as u32;
                    scaled.max(1)
                }
            })
            .collect();

        // Bounded by table_size + used <= 2 * 65536, so a u32 sum cannot overflow.
        let assigned: u32 = normalized.iter().sum();
        if assigned < table_size {
            // Rounding down left slots unassigned: give them to the most frequent symbol.
            let deficit = table_size - assigned;
            if let Some(largest) = normalized
                .iter_mut()
                .filter(|f| **f > 0)
                .max_by_key(|f| **f)
            {
                *largest += deficit;
            }
        } else if assigned > table_size {
            // Bumping rare symbols up to one slot overshot the table. Take the
            // excess back from the most frequent symbols, never going below one
            // slot. `used <= table_size` guarantees enough slack exists.
            let mut excess = assigned - table_size;
            let mut donors: Vec<usize> = (0..normalized.len())
                .filter(|&i| normalized[i] > 1)
                .collect();
            donors.sort_unstable_by(|&a, &b| normalized[b].cmp(&normalized[a]).then(b.cmp(&a)));
            for i in donors {
                if excess == 0 {
                    break;
                }
                let take = excess.min(normalized[i] - 1);
                normalized[i] -= take;
                excess -= take;
            }
            debug_assert_eq!(excess, 0, "normalized table could not be balanced");
        }

        // cumulative[i] is the first slot of symbol i; the final entry is the table size.
        let cumulative: Vec<u32> = std::iter::once(0)
            .chain(normalized.iter().scan(0u32, |running, &f| {
                *running += f;
                Some(*running)
            }))
            .collect();

        Ok(Self {
            symbols,
            frequencies: normalized,
            cumulative,
            log_table_size,
        })
    }

    /// Number of slots in the table, `1 << log_table_size`.
    pub fn table_size(&self) -> u32 {
        1u32 << self.log_table_size
    }

    /// Number of symbols in the distribution, including zero-frequency ones.
    pub fn num_symbols(&self) -> usize {
        self.symbols.len()
    }

    /// Finds the symbol owning table slot `value`.
    ///
    /// Returns `(index, start, frequency)` where `index` addresses `symbols`,
    /// `start` is the symbol's first slot and `frequency` its slot count.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::InvalidBitstream` when no symbol covers `value`.
    pub fn lookup(&self, value: u32) -> CodecResult<(usize, u32, u32)> {
        // cumulative[i + 1] is the exclusive upper bound of symbol i, so the owner
        // of `value` is the first symbol whose upper bound exceeds it.
        let upper_bounds = self.cumulative.get(1..).unwrap_or(&[]);
        let idx = upper_bounds.partition_point(|&end| end <= value);
        let entry = if idx < self.symbols.len() {
            self.cumulative.get(idx).zip(self.frequencies.get(idx))
        } else {
            None
        };
        match entry {
            Some((&start, &freq)) => Ok((idx, start, freq)),
            None => Err(CodecError::InvalidBitstream(format!(
                "ANS lookup failed: value {value} out of range"
            ))),
        }
    }

    /// Returns the index of `symbol` in `symbols` (first match).
    fn find_symbol(&self, symbol: u16) -> CodecResult<usize> {
        self.symbols
            .iter()
            .position(|&s| s == symbol)
            .ok_or_else(|| {
                CodecError::InvalidParameter(format!("Symbol {symbol} not found in distribution"))
            })
    }
}
/// Builds a distribution over the symbols `0..n` in which every symbol has the
/// same weight, using the default table size.
///
/// # Errors
///
/// Returns `CodecError::InvalidParameter` when `n` is 0, and forwards any
/// error reported by `AnsDistribution::new`.
///
/// NOTE(review): an `n` above `1 << DEFAULT_LOG_TABLE_SIZE` cannot give every
/// symbol its own slot; what happens then is decided by `AnsDistribution::new`.
pub fn uniform_distribution(n: u16) -> CodecResult<AnsDistribution> {
    match n {
        0 => Err(CodecError::InvalidParameter(
            "Cannot create uniform distribution with 0 symbols".into(),
        )),
        count => {
            let weights = vec![1u32; usize::from(count)];
            let alphabet = (0..count).collect::<Vec<u16>>();
            AnsDistribution::new(alphabet, weights, DEFAULT_LOG_TABLE_SIZE)
        }
    }
}
pub fn distribution_from_counts(
counts: &[u32],
log_table_size: u8,
) -> CodecResult<AnsDistribution> {
let mut symbols = Vec::new();
let mut frequencies = Vec::new();
for (i, &count) in counts.iter().enumerate() {
if count > 0 {
symbols.push(i as u16);
frequencies.push(count);
}
}
if symbols.is_empty() {
symbols.push(0);
frequencies.push(1);
}
AnsDistribution::new(symbols, frequencies, log_table_size)
}
/// Decoder over a borrowed byte buffer: an 8-byte header followed by 16-bit
/// little-endian words.
pub struct AnsDecoder<'a> {
/// Current coder state; initialized from the first four header bytes.
state: u32,
/// Encoded buffer the decoder reads from, header included.
data: &'a [u8],
/// Byte offset into `data` of the next word to read; starts just past the
/// 8-byte header.
word_pos: usize,
}
impl<'a> AnsDecoder<'a> {
    /// Creates a decoder over `data`.
    ///
    /// The buffer starts with an 8-byte header: the coder state (`u32`, little
    /// endian) followed by the number of 16-bit payload words (`u32`, little
    /// endian). Bytes after the declared payload are ignored.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::InvalidBitstream` when `data` is shorter than the
    /// header or shorter than the payload the header declares.
    pub fn new(data: &'a [u8]) -> CodecResult<Self> {
        if data.len() < 8 {
            return Err(CodecError::InvalidBitstream("ANS data too short".into()));
        }
        let state = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
        let word_count = u32::from_le_bytes([data[4], data[5], data[6], data[7]]);

        // Computed in u64 so a hostile word count cannot overflow the offset.
        let payload_end = 8 + 2 * u64::from(word_count);
        if payload_end > data.len() as u64 {
            return Err(CodecError::InvalidBitstream(format!(
                "ANS data truncated: header declares {word_count} words but only {} payload bytes follow",
                data.len() - 8
            )));
        }

        Ok(Self {
            state,
            // payload_end <= data.len(), so the cast is lossless. Restricting the
            // view keeps unrelated trailing bytes from being read as payload.
            data: &data[..payload_end as usize],
            word_pos: 8,
        })
    }

    /// Reads the next payload word, or 0 once the payload is exhausted.
    ///
    /// The zero fill is deliberate: the encoder starts from
    /// `1 << DEFAULT_LOG_TABLE_SIZE`, so with a larger table the decoder can
    /// ask for one extra word after the final symbol of a stream. Truncated
    /// payloads are rejected up front in [`AnsDecoder::new`].
    fn read_word(&mut self) -> u16 {
        match self.data.get(self.word_pos..self.word_pos + 2) {
            Some(bytes) => {
                self.word_pos += 2;
                u16::from_le_bytes([bytes[0], bytes[1]])
            }
            None => 0,
        }
    }

    /// Decodes the next symbol using `dist`.
    ///
    /// Symbols come out in the reverse of the order in which they were passed
    /// to the encoder, and each call must use the distribution that was used to
    /// encode that symbol.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::InvalidParameter` when the distribution's table is
    /// too large for a 32-bit state, and `CodecError::InvalidBitstream` when
    /// the current state does not map to a symbol or the distribution is
    /// internally inconsistent.
    pub fn decode_symbol(&mut self, dist: &AnsDistribution) -> CodecResult<u16> {
        // The state must have room for one full word above the table index,
        // otherwise the shifts below would lose bits (or panic on shift overflow).
        if u32::from(dist.log_table_size) > u32::BITS - RENORM_WORD_BITS {
            return Err(CodecError::InvalidParameter(format!(
                "Distribution log_table_size {} is too large for a 32-bit ANS state",
                dist.log_table_size
            )));
        }
        let table_size = dist.table_size();
        let slot = self.state & (table_size - 1);
        let (idx, start, freq) = dist.lookup(slot)?;
        let symbol = dist.symbols[idx];

        // Undo the encoder step: state = freq * (state / table_size) + (slot - start).
        // Checked arithmetic only triggers for hand-built, inconsistent tables.
        let quotient = self.state >> dist.log_table_size;
        let previous = slot
            .checked_sub(start)
            .and_then(|offset| freq.checked_mul(quotient)?.checked_add(offset))
            .ok_or_else(|| {
                CodecError::InvalidBitstream(
                    "ANS state overflow: distribution is inconsistent".into(),
                )
            })?;

        // A state below the table size means the encoder shifted one word out
        // before encoding this symbol; pull it back in.
        self.state = if previous < table_size {
            let word = u32::from(self.read_word());
            (previous << RENORM_WORD_BITS) | word
        } else {
            previous
        };
        Ok(symbol)
    }
}
/// Encoder producing a byte stream of an 8-byte header followed by 16-bit
/// little-endian words.
///
/// The coder is last-in-first-out: the decoder returns symbols in the reverse
/// of the order in which they were passed to [`AnsEncoder::encode_symbol`].
#[derive(Clone, Debug)]
pub struct AnsEncoder {
    /// Current coder state.
    state: u32,
    /// Renormalization words in the order they were emitted.
    words: Vec<u16>,
    /// log2 of the table size the state was last normalized against.
    log_table_size: u8,
    /// Whether at least one symbol has been encoded.
    started: bool,
}

impl AnsEncoder {
    /// Creates an encoder whose state starts at `1 << DEFAULT_LOG_TABLE_SIZE`.
    pub fn new() -> Self {
        Self {
            state: 1u32 << DEFAULT_LOG_TABLE_SIZE,
            words: Vec::new(),
            log_table_size: DEFAULT_LOG_TABLE_SIZE,
            started: false,
        }
    }

    /// Encodes `symbol` using `dist`.
    ///
    /// The state must stay inside `table_size..(table_size << 16)` for the
    /// decoder to know when to read a word. If the very first distribution has
    /// a table larger than the initial state, the state is raised to that table
    /// size before encoding. Later distributions whose interval does not
    /// contain the current state are rejected, because the resulting stream
    /// could not be decoded.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::InvalidParameter` when the symbol is unknown or has
    /// zero frequency, when the table is too large for a 32-bit state, or when
    /// the distribution is incompatible with the current state. The encoder is
    /// left untouched on error.
    pub fn encode_symbol(&mut self, symbol: u16, dist: &AnsDistribution) -> CodecResult<()> {
        let idx = dist.find_symbol(symbol)?;
        let start = dist.cumulative[idx];
        let freq = dist.frequencies[idx];
        if freq == 0 {
            return Err(CodecError::InvalidParameter(format!(
                "Symbol {symbol} has zero frequency"
            )));
        }
        // A 32-bit state minus one 16-bit word leaves 16 bits for the table index.
        if u32::from(dist.log_table_size) > u32::BITS - RENORM_WORD_BITS {
            return Err(CodecError::InvalidParameter(format!(
                "Distribution log_table_size {} is too large for a 32-bit ANS state",
                dist.log_table_size
            )));
        }
        let table_size = dist.table_size();

        // Nothing is encoded in the state yet, so it may be raised freely.
        let mut state = self.state;
        if !self.started && state < table_size {
            state = table_size;
        }

        let lower = u64::from(table_size);
        let upper = lower << RENORM_WORD_BITS;
        if u64::from(state) < lower || u64::from(state) >= upper {
            return Err(CodecError::InvalidParameter(format!(
                "Distribution table size {table_size} is incompatible with encoder state {state} \
                 (last normalized for table size {})",
                1u64 << self.log_table_size
            )));
        }

        // Compared in u64: `freq << 16` wraps to 0 in u32 when freq is 65536.
        // One word is always enough because state >> 16 < table_size <= 65536.
        let emit_threshold = u64::from(freq) << RENORM_WORD_BITS;
        let mut emitted = None;
        if u64::from(state) >= emit_threshold {
            emitted = Some((state & 0xFFFF) as u16);
            state >>= RENORM_WORD_BITS;
        }

        let next = u64::from(table_size) * u64::from(state / freq)
            + u64::from(state % freq)
            + u64::from(start);
        if next > u64::from(u32::MAX) {
            return Err(CodecError::InvalidParameter(
                "ANS state overflow: distribution is inconsistent".into(),
            ));
        }

        // All checks passed: commit.
        if let Some(word) = emitted {
            self.words.push(word);
        }
        self.state = next as u32;
        self.log_table_size = dist.log_table_size;
        self.started = true;
        Ok(())
    }

    /// Consumes the encoder and returns the encoded bytes: the final state
    /// (`u32`), the word count (`u32`), then the words in reverse emission
    /// order, all little endian.
    pub fn finish(self) -> Vec<u8> {
        // The header stores the count as u32; a stream of more than u32::MAX
        // words (8 GiB of payload) is not representable in this format.
        let word_count = self.words.len() as u32;
        let mut output = Vec::with_capacity(8 + self.words.len() * 2);
        output.extend_from_slice(&self.state.to_le_bytes());
        output.extend_from_slice(&word_count.to_le_bytes());
        // The decoder consumes words in the opposite order to their emission.
        for &word in self.words.iter().rev() {
            output.extend_from_slice(&word.to_le_bytes());
        }
        output
    }
}
impl Default for AnsEncoder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    // These tests were all marked `#[ignore]` without a stated reason, which
    // meant a plain `cargo test` exercised none of the codec. They are fast and
    // deterministic, so they run by default.
    use super::*;

    #[test]
    fn test_uniform_distribution() {
        let dist = uniform_distribution(4).expect("ok");
        assert_eq!(dist.num_symbols(), 4);
        assert_eq!(dist.table_size(), 1 << DEFAULT_LOG_TABLE_SIZE);
        let expected = dist.table_size() / 4;
        for &f in &dist.frequencies {
            assert!((f as i64 - expected as i64).unsigned_abs() <= 1);
        }
    }

    #[test]
    fn test_distribution_from_counts() {
        let counts = [10u32, 20, 30, 0, 40];
        let dist = distribution_from_counts(&counts, 10).expect("ok");
        assert_eq!(dist.num_symbols(), 4);
        assert_eq!(dist.symbols, vec![0, 1, 2, 4]);
    }

    #[test]
    fn test_distribution_cumulative() {
        let symbols = vec![0, 1, 2];
        let freqs = vec![256, 512, 256];
        let dist = AnsDistribution::new(symbols, freqs, 10).expect("ok");
        assert_eq!(dist.cumulative[0], 0);
        assert_eq!(
            *dist.cumulative.last().expect("has last"),
            dist.table_size()
        );
    }

    #[test]
    fn test_distribution_lookup() {
        let symbols = vec![0, 1];
        let freqs = vec![512, 512];
        let dist = AnsDistribution::new(symbols, freqs, 10).expect("ok");
        let (idx, start, freq) = dist.lookup(0).expect("ok");
        assert_eq!(idx, 0);
        assert_eq!(start, 0);
        assert!(freq > 0);
        let (idx, _start, _freq) = dist.lookup(dist.table_size() - 1).expect("ok");
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_ans_roundtrip_single_symbol() {
        let dist = uniform_distribution(4).expect("ok");
        let mut encoder = AnsEncoder::new();
        encoder.encode_symbol(2, &dist).expect("ok");
        let encoded = encoder.finish();
        let mut decoder = AnsDecoder::new(&encoded).expect("ok");
        let decoded = decoder.decode_symbol(&dist).expect("ok");
        assert_eq!(decoded, 2);
    }

    #[test]
    fn test_ans_roundtrip_sequence() {
        let dist = uniform_distribution(8).expect("ok");
        let symbols_to_encode: Vec<u16> = vec![0, 3, 7, 1, 5, 2, 6, 4];
        let mut encoder = AnsEncoder::new();
        // The coder is last-in-first-out, so encode in reverse to decode in order.
        for &sym in symbols_to_encode.iter().rev() {
            encoder.encode_symbol(sym, &dist).expect("ok");
        }
        let encoded = encoder.finish();
        let mut decoder = AnsDecoder::new(&encoded).expect("ok");
        for &expected in &symbols_to_encode {
            let decoded = decoder.decode_symbol(&dist).expect("ok");
            assert_eq!(decoded, expected, "ANS roundtrip mismatch");
        }
    }

    #[test]
    fn test_ans_roundtrip_skewed_distribution() {
        let symbols = vec![0, 1, 2, 3];
        let freqs = vec![700, 200, 80, 20];
        let dist = AnsDistribution::new(symbols, freqs, 10).expect("ok");
        let test_seq: Vec<u16> = vec![0, 0, 0, 1, 0, 2, 0, 0, 3, 0, 1];
        let mut encoder = AnsEncoder::new();
        for &sym in test_seq.iter().rev() {
            encoder.encode_symbol(sym, &dist).expect("ok");
        }
        let encoded = encoder.finish();
        let mut decoder = AnsDecoder::new(&encoded).expect("ok");
        for &expected in &test_seq {
            let decoded = decoder.decode_symbol(&dist).expect("ok");
            assert_eq!(decoded, expected);
        }
    }

    #[test]
    fn test_ans_roundtrip_repeated_symbol() {
        let dist = uniform_distribution(4).expect("ok");
        let symbols: Vec<u16> = vec![1, 1, 1, 1, 1];
        let mut encoder = AnsEncoder::new();
        for &sym in symbols.iter().rev() {
            encoder.encode_symbol(sym, &dist).expect("ok");
        }
        let encoded = encoder.finish();
        let mut decoder = AnsDecoder::new(&encoded).expect("ok");
        for &expected in &symbols {
            let decoded = decoder.decode_symbol(&dist).expect("ok");
            assert_eq!(decoded, expected);
        }
    }

    #[test]
    fn test_ans_roundtrip_long_sequence() {
        let dist = uniform_distribution(16).expect("ok");
        let symbols: Vec<u16> = (0..100).map(|i| (i % 16) as u16).collect();
        let mut encoder = AnsEncoder::new();
        for &sym in symbols.iter().rev() {
            encoder.encode_symbol(sym, &dist).expect("ok");
        }
        let encoded = encoder.finish();
        let mut decoder = AnsDecoder::new(&encoded).expect("ok");
        for (i, &expected) in symbols.iter().enumerate() {
            let decoded = decoder.decode_symbol(&dist).expect("ok");
            assert_eq!(decoded, expected, "Mismatch at position {i}");
        }
    }

    #[test]
    fn test_empty_distribution_error() {
        assert!(AnsDistribution::new(vec![], vec![], 10).is_err());
    }

    #[test]
    fn test_zero_symbol_uniform_error() {
        assert!(uniform_distribution(0).is_err());
    }
}