#![no_std]
#![warn(missing_docs)]
extern crate alloc;
use alloc::{format, string::String, string::ToString, vec, vec::Vec};
use core::fmt;
/// Lower bound of the 32-bit rANS state (`1 << 23`): encoders start at this value, and
/// decoders reject an initial state below it and refill from the byte stream whenever
/// the state drops below it.
const RANS_L: u32 = 1 << 23;
/// Errors produced while building frequency tables or running the rANS coders.
#[derive(Debug, PartialEq, Eq)]
pub enum AnsError {
/// The requested `precision_bits` is outside the supported range `1..=20`.
InvalidPrecision {
/// The rejected precision value.
precision_bits: u32,
},
/// The count/frequency/probability slice passed to a table constructor was empty.
EmptyAlphabet,
/// A symbol index was not smaller than the table's alphabet size.
InvalidSymbol {
/// The offending symbol index.
symbol: u32,
/// Number of symbols in the table that was used.
alphabet_size: usize,
},
/// An encoder was asked to encode a symbol whose frequency in the table is zero.
ZeroFrequency {
/// The symbol with zero frequency.
symbol: u32,
},
/// A frequency table could not be built; the string describes the reason.
InvalidTable(String),
/// The initial state read by a decoder was below the coder's minimum state.
InvalidState {
/// The state value that was read.
state: u64,
/// The smallest acceptable state.
min_state: u64,
},
/// The input ended before the decoder could read the bytes it needed.
TruncatedInput {
/// Number of bytes reported as available at the point of failure.
available: usize,
/// Number of bytes that were required.
needed: usize,
},
}
impl fmt::Display for AnsError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::InvalidPrecision { precision_bits } => {
write!(
f,
"invalid precision_bits={precision_bits} (must be in 1..=20)"
)
}
Self::EmptyAlphabet => f.write_str("empty frequency table"),
Self::InvalidSymbol {
symbol,
alphabet_size,
} => write!(
f,
"invalid symbol {symbol} for alphabet size {alphabet_size}"
),
Self::ZeroFrequency { symbol } => {
write!(f, "frequency for symbol {symbol} is zero")
}
Self::InvalidTable(msg) => write!(f, "frequency table normalization failed: {msg}"),
Self::InvalidState { state, min_state } => {
write!(
f,
"invalid rANS state 0x{state:x} (expected >= 0x{min_state:x})"
)
}
Self::TruncatedInput { available, needed } => {
write!(
f,
"truncated input ({available} bytes available, need at least {needed})"
)
}
}
}
}
// Empty impl: `AnsError` uses the default `Error` trait methods.
impl core::error::Error for AnsError {}
/// Quantized symbol frequency model consumed by the rANS encoders and decoders.
///
/// Besides the per-symbol frequencies it stores two lookup tables built once at
/// construction time: a cumulative-frequency array and a slot-to-symbol map.
#[derive(Debug, Clone)]
pub struct FrequencyTable {
/// Probability precision in bits; the constructors accept `1..=20`.
precision_bits: u32,
/// `1 << precision_bits`.
total: u32,
/// Frequency of each symbol, indexed by symbol.
freqs: Vec<u32>,
// `cdf[i]` is the sum of `freqs[..i]` and has `freqs.len() + 1` entries.
// `sym_by_slot` has `total` entries; entry `s` is the symbol whose range
// `cdf[sym]..cdf[sym + 1]` contains slot `s`.
cdf: Vec<u32>, sym_by_slot: Vec<u32>, }
impl FrequencyTable {
    /// Validates that `precision_bits` lies in the supported range `1..=20`.
    fn check_precision(precision_bits: u32) -> Result<(), AnsError> {
        if (1..=20).contains(&precision_bits) {
            Ok(())
        } else {
            Err(AnsError::InvalidPrecision { precision_bits })
        }
    }

    /// Builds a table from raw symbol counts, scaled to sum to `1 << precision_bits`.
    ///
    /// Each count is scaled proportionally (rounding down). Symbols with a non-zero
    /// count always receive a frequency of at least 1; symbols with a zero count keep
    /// frequency 0. Any shortfall left by rounding is given to the symbol with the
    /// largest count (the last such symbol on ties). Any excess created by the
    /// minimum-of-1 rule is removed one unit at a time from the currently largest
    /// frequency above 1 (the first such symbol on ties).
    ///
    /// # Errors
    ///
    /// * [`AnsError::InvalidPrecision`] if `precision_bits` is not in `1..=20`.
    /// * [`AnsError::EmptyAlphabet`] if `counts` is empty.
    /// * [`AnsError::InvalidTable`] if every count is zero, or if there are more
    ///   non-zero counts than `1 << precision_bits`.
    pub fn from_counts(counts: &[u32], precision_bits: u32) -> Result<Self, AnsError> {
        Self::check_precision(precision_bits)?;
        if counts.is_empty() {
            return Err(AnsError::EmptyAlphabet);
        }
        let total = 1u32 << precision_bits;
        let sum: u64 = counts.iter().map(|&c| u64::from(c)).sum();
        if sum == 0 {
            return Err(AnsError::InvalidTable("all counts are zero".to_string()));
        }

        // Proportional scaling with a floor of 1 for every symbol that occurs.
        // `c <= sum`, so the quotient is at most `total` and fits in `u32`; the
        // product is below 2^52 and cannot overflow `u64`.
        let mut freqs: Vec<u32> = counts
            .iter()
            .map(|&c| {
                let scaled = (u64::from(c) * u64::from(total) / sum) as u32;
                if c > 0 && scaled == 0 {
                    1
                } else {
                    scaled
                }
            })
            .collect();

        let target = u64::from(total);
        let mut cur_sum: u64 = freqs.iter().map(|&f| u64::from(f)).sum();
        debug_assert!(
            cur_sum > 0,
            "cur_sum should be > 0 after zero-floor correction"
        );

        if cur_sum < target {
            // The recipient of the shortfall depends only on `counts`, so it is
            // found once and topped up in a single step.
            let (idx, _) = counts
                .iter()
                .enumerate()
                .max_by_key(|&(_, &c)| c)
                .expect("counts is non-empty");
            // The deficit is at most `total`, so the cast is lossless.
            freqs[idx] += (target - cur_sum) as u32;
            cur_sum = target;
        }

        while cur_sum > target {
            // Take one unit from the largest frequency that can still spare it.
            let mut best: Option<(usize, u32)> = None;
            for (i, &f) in freqs.iter().enumerate() {
                if f > 1 && best.map_or(true, |(_, bf)| f > bf) {
                    best = Some((i, f));
                }
            }
            let Some((idx, _)) = best else {
                return Err(AnsError::InvalidTable(format!(
                    "cannot reduce total (cur_sum={cur_sum}, target={target}): \
                     all {} symbols have freq<=1",
                    freqs.len()
                )));
            };
            freqs[idx] -= 1;
            cur_sum -= 1;
        }

        debug_assert_eq!(
            freqs.iter().map(|&f| u64::from(f)).sum::<u64>(),
            target,
            "freq sum mismatch after correction loop"
        );
        Ok(Self::build_lookups(freqs, precision_bits, total))
    }

    /// Builds a table from frequencies that already sum to `1 << precision_bits`.
    ///
    /// # Errors
    ///
    /// * [`AnsError::InvalidPrecision`] if `precision_bits` is not in `1..=20`.
    /// * [`AnsError::EmptyAlphabet`] if `freqs` is empty.
    /// * [`AnsError::InvalidTable`] if the frequencies do not sum to
    ///   `1 << precision_bits`.
    pub fn from_normalized(freqs: &[u32], precision_bits: u32) -> Result<Self, AnsError> {
        Self::check_precision(precision_bits)?;
        if freqs.is_empty() {
            return Err(AnsError::EmptyAlphabet);
        }
        let total = 1u32 << precision_bits;
        // Sum in 64 bits: the input is caller-supplied, and a `u32` sum could
        // overflow (panicking in debug builds, or wrapping around to `total` in
        // release builds and accepting an invalid table).
        let sum: u64 = freqs.iter().map(|&f| u64::from(f)).sum();
        if sum != u64::from(total) {
            return Err(AnsError::InvalidTable(format!(
                "frequencies sum to {sum}, expected {total}"
            )));
        }
        Ok(Self::build_lookups(freqs.to_vec(), precision_bits, total))
    }

    /// Builds a table from floating-point weights.
    ///
    /// Negative inputs (and NaN, which `f32::max` replaces with the other operand)
    /// count as zero. The weights are divided by their sum, scaled to integer
    /// counts with 30 bits of resolution (rounding to nearest), and handed to
    /// [`FrequencyTable::from_counts`].
    ///
    /// # Errors
    ///
    /// * [`AnsError::InvalidPrecision`] if `precision_bits` is not in `1..=20`.
    /// * [`AnsError::EmptyAlphabet`] if `probs` is empty.
    /// * [`AnsError::InvalidTable`] if the clamped weights sum to zero, or for any
    ///   reason `from_counts` rejects the derived counts.
    pub fn from_float_probs(probs: &[f32], precision_bits: u32) -> Result<Self, AnsError> {
        Self::check_precision(precision_bits)?;
        if probs.is_empty() {
            return Err(AnsError::EmptyAlphabet);
        }
        let sum: f64 = probs.iter().map(|&p| f64::from(p.max(0.0))).sum();
        if sum == 0.0 {
            return Err(AnsError::InvalidTable(
                "all probabilities are zero or negative".to_string(),
            ));
        }
        let scale = (1u64 << 30) as f64;
        let counts: Vec<u32> = probs
            .iter()
            .map(|&p| {
                let scaled = f64::from(p.max(0.0)) / sum * scale;
                // A NaN result (e.g. from an infinite weight) fails the comparison
                // and maps to 0.
                if scaled > 0.0 {
                    (scaled + 0.5) as u32
                } else {
                    0
                }
            })
            .collect();
        Self::from_counts(&counts, precision_bits)
    }

    /// Derives the cumulative-frequency array and the slot-to-symbol map.
    ///
    /// Callers guarantee that `freqs` sums to `total`.
    fn build_lookups(freqs: Vec<u32>, precision_bits: u32, total: u32) -> Self {
        let mut cdf = Vec::with_capacity(freqs.len() + 1);
        let mut running = 0u32;
        cdf.push(running);
        for &f in &freqs {
            running += f;
            cdf.push(running);
        }
        debug_assert_eq!(running, total, "frequencies must sum to total");

        let mut sym_by_slot = vec![0u32; total as usize];
        for (sym, bounds) in cdf.windows(2).enumerate() {
            let (start, end) = (bounds[0] as usize, bounds[1] as usize);
            sym_by_slot[start..end].fill(sym as u32);
        }

        Self {
            precision_bits,
            total,
            freqs,
            cdf,
            sym_by_slot,
        }
    }

    /// Returns the probability precision in bits.
    #[inline]
    #[must_use]
    pub fn precision_bits(&self) -> u32 {
        self.precision_bits
    }

    /// Returns the sum of all frequencies, `1 << precision_bits`.
    #[inline]
    #[must_use]
    pub fn total(&self) -> u32 {
        self.total
    }

    /// Returns the number of symbols in the table, including zero-frequency ones.
    #[inline]
    #[must_use]
    pub fn alphabet_size(&self) -> usize {
        self.freqs.len()
    }

    /// Returns the frequency of `sym`, or `None` if `sym` is outside the alphabet.
    #[inline]
    #[must_use]
    pub fn freq(&self, sym: u32) -> Option<u32> {
        self.freqs.get(sym as usize).copied()
    }

    /// Returns the sum of the frequencies of all symbols below `sym`.
    ///
    /// `sym == alphabet_size()` is accepted and yields `total()`; anything larger
    /// yields `None`.
    #[inline]
    #[must_use]
    pub fn cum_freq(&self, sym: u32) -> Option<u32> {
        self.cdf.get(sym as usize).copied()
    }

    /// Returns the symbol that owns `slot`, or `None` if `slot >= total()`.
    #[inline]
    #[must_use]
    pub fn symbol_at_slot(&self, slot: u32) -> Option<u32> {
        self.sym_by_slot.get(slot as usize).copied()
    }

    /// Returns all frequencies, indexed by symbol.
    #[inline]
    #[must_use]
    pub fn freqs(&self) -> &[u32] {
        &self.freqs
    }

    /// Returns the cumulative frequencies (`alphabet_size() + 1` entries, starting
    /// at 0 and ending at `total()`).
    #[inline]
    #[must_use]
    pub fn cdf(&self) -> &[u32] {
        &self.cdf
    }
}
/// Streaming rANS encoder with a 32-bit state that emits one byte at a time.
#[derive(Debug, Clone)]
pub struct RansEncoder {
/// Current coder state; starts at `RANS_L`.
state: u32,
/// Bytes emitted so far; `finish` appends the final state to it.
buf: Vec<u8>,
}
impl RansEncoder {
    /// Creates an encoder with an empty output buffer and the initial state.
    #[must_use]
    pub fn new() -> Self {
        let buf = Vec::new();
        Self { state: RANS_L, buf }
    }

    /// Creates an encoder whose output buffer has room for `cap` bytes.
    #[must_use]
    pub fn with_capacity(cap: usize) -> Self {
        let buf = Vec::with_capacity(cap);
        Self { state: RANS_L, buf }
    }

    /// Encodes `sym` under `table`, emitting bytes first if the state is too large.
    ///
    /// # Errors
    ///
    /// * [`AnsError::InvalidSymbol`] if `sym` is outside the table's alphabet.
    /// * [`AnsError::ZeroFrequency`] if `sym` has frequency 0 in `table`.
    ///
    /// The encoder is left untouched when an error is returned.
    pub fn put(&mut self, sym: u32, table: &FrequencyTable) -> Result<(), AnsError> {
        let index = sym as usize;
        let Some(&freq) = table.freqs.get(index) else {
            return Err(AnsError::InvalidSymbol {
                symbol: sym,
                alphabet_size: table.freqs.len(),
            });
        };
        if freq == 0 {
            return Err(AnsError::ZeroFrequency { symbol: sym });
        }
        let start = table.cdf[index];
        let bits = table.precision_bits;

        // Shift bytes out until encoding this symbol keeps the state in range.
        let x_max = ((RANS_L >> bits) << 8) * freq;
        let mut x = self.state;
        while x >= x_max {
            // Truncating cast keeps exactly the low byte.
            self.buf.push(x as u8);
            x >>= 8;
        }

        self.state = ((x / freq) << bits) + (x % freq) + start;
        Ok(())
    }

    /// Appends the final state (4 bytes, little-endian) and returns the stream.
    #[must_use]
    pub fn finish(self) -> Vec<u8> {
        let Self { state, mut buf } = self;
        buf.extend_from_slice(&state.to_le_bytes());
        buf
    }

    /// Returns the current coder state.
    #[inline]
    #[must_use]
    pub fn state(&self) -> u32 {
        self.state
    }
}
// `Default` is the same as `RansEncoder::new()`.
impl Default for RansEncoder {
fn default() -> Self {
Self::new()
}
}
/// Streaming rANS decoder with a 32-bit state that reads a borrowed byte stream
/// from its end towards its start.
#[derive(Debug)]
pub struct RansDecoder<'a> {
/// Current coder state.
state: u32,
/// The complete encoded stream, including the trailing 4-byte state.
bytes: &'a [u8],
/// Number of bytes in front of the current position that are still unread.
cursor: usize,
}
impl<'a> RansDecoder<'a> {
    /// Creates a decoder over `bytes`, reading the initial state from the last
    /// four bytes (little-endian).
    ///
    /// # Errors
    ///
    /// * [`AnsError::TruncatedInput`] if `bytes` holds fewer than 4 bytes.
    /// * [`AnsError::InvalidState`] if the stored state is below `RANS_L`.
    pub fn new(bytes: &'a [u8]) -> Result<Self, AnsError> {
        let Some(cursor) = bytes.len().checked_sub(4) else {
            return Err(AnsError::TruncatedInput {
                available: bytes.len(),
                needed: 4,
            });
        };
        let mut tail = [0u8; 4];
        tail.copy_from_slice(&bytes[cursor..]);
        let state = u32::from_le_bytes(tail);
        if state < RANS_L {
            return Err(AnsError::InvalidState {
                state: u64::from(state),
                min_state: u64::from(RANS_L),
            });
        }
        Ok(Self {
            state,
            bytes,
            cursor,
        })
    }

    /// Decodes and returns the next symbol under `table`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`RansDecoder::advance`].
    pub fn get(&mut self, table: &FrequencyTable) -> Result<u32, AnsError> {
        let sym = self.peek(table);
        self.advance(sym, table)?;
        Ok(sym)
    }

    /// Returns the symbol the current state decodes to under `table`, without
    /// consuming it.
    #[inline]
    #[must_use]
    pub fn peek(&self, table: &FrequencyTable) -> u32 {
        debug_assert!(
            self.state >= RANS_L,
            "peek called with invalid state {} (expected >= {RANS_L})",
            self.state
        );
        // `total` is a power of two, so the mask keeps the slot below `total`.
        let slot = (self.state & (table.total - 1)) as usize;
        debug_assert!(slot < table.sym_by_slot.len(), "slot {slot} out of range");
        table.sym_by_slot[slot]
    }

    /// Consumes `sym` from the state and refills the state from the byte stream.
    ///
    /// `sym` must be the symbol that the current state decodes to under `table`
    /// (what [`RansDecoder::peek`] returns). On error the decoder is left exactly
    /// as it was before the call.
    ///
    /// # Errors
    ///
    /// * [`AnsError::InvalidSymbol`] if `sym` is outside the alphabet, or if the
    ///   current state does not fall inside `sym`'s slot range.
    /// * [`AnsError::ZeroFrequency`] if `sym` has frequency 0 in `table`.
    /// * [`AnsError::TruncatedInput`] if the stream runs out while refilling.
    pub fn advance(&mut self, sym: u32, table: &FrequencyTable) -> Result<(), AnsError> {
        let invalid_symbol = || AnsError::InvalidSymbol {
            symbol: sym,
            alphabet_size: table.freqs.len(),
        };
        let index = sym as usize;
        let Some(&freq) = table.freqs.get(index) else {
            return Err(invalid_symbol());
        };
        if freq == 0 {
            return Err(AnsError::ZeroFrequency { symbol: sym });
        }
        let start = table.cdf[index];
        let slot = self.state & (table.total - 1);
        // A symbol that does not own the current slot would make `slot - start`
        // underflow (or silently corrupt the state); reject it instead.
        if slot < start || slot - start >= freq {
            return Err(invalid_symbol());
        }

        // Work on copies so a truncated stream leaves `self` untouched.
        let mut state = freq * (self.state >> table.precision_bits) + (slot - start);
        let mut cursor = self.cursor;
        while state < RANS_L {
            let Some(previous) = cursor.checked_sub(1) else {
                return Err(AnsError::TruncatedInput {
                    available: 0,
                    needed: 1,
                });
            };
            cursor = previous;
            state = (state << 8) | u32::from(self.bytes[cursor]);
        }
        self.state = state;
        self.cursor = cursor;
        Ok(())
    }

    /// Returns the current coder state.
    #[inline]
    #[must_use]
    pub fn state(&self) -> u32 {
        self.state
    }

    /// Returns how many stream bytes have not been consumed yet.
    #[inline]
    #[must_use]
    pub fn remaining_bytes(&self) -> usize {
        self.cursor
    }
}
/// Encodes `symbols` under `table` with the 32-bit coder and returns the byte stream.
///
/// Symbols are fed to the encoder last-to-first.
///
/// # Errors
///
/// Returns the first error reported by [`RansEncoder::put`].
pub fn encode(symbols: &[u32], table: &FrequencyTable) -> Result<Vec<u8>, AnsError> {
    let mut encoder = RansEncoder::with_capacity(symbols.len());
    symbols
        .iter()
        .rev()
        .try_for_each(|&symbol| encoder.put(symbol, table))?;
    Ok(encoder.finish())
}
/// Decodes `len` symbols from `bytes` under `table` with the 32-bit coder.
///
/// # Errors
///
/// Returns any error from [`RansDecoder::new`] or [`RansDecoder::get`], for
/// example [`AnsError::TruncatedInput`] when `len` exceeds what the stream holds.
pub fn decode(bytes: &[u8], table: &FrequencyTable, len: usize) -> Result<Vec<u32>, AnsError> {
    // `len` is caller-supplied and may be far larger than what the stream can
    // yield. Cap the up-front reservation so an oversized value ends in a decode
    // error rather than a capacity-overflow panic or a giant allocation; the
    // vector still grows on demand beyond the cap.
    const MAX_PREALLOC: usize = 1 << 16;

    let mut dec = RansDecoder::new(bytes)?;
    let mut out = Vec::with_capacity(len.min(MAX_PREALLOC));
    for _ in 0..len {
        out.push(dec.get(table)?);
    }
    Ok(out)
}
/// Lower bound of the 64-bit rANS state (`1 << 31`): 64-bit encoders start at this
/// value, and 64-bit decoders reject an initial state below it and refill with a
/// 32-bit word whenever the state drops below it.
const RANS64_L: u64 = 1 << 31;
/// Streaming rANS encoder with a 64-bit state that emits 32-bit words.
#[derive(Debug, Clone)]
pub struct Rans64Encoder {
/// Current coder state; starts at `RANS64_L`.
state: u64,
/// Bytes emitted so far; `finish` appends the final state to it.
buf: Vec<u8>,
}
impl Rans64Encoder {
    /// Creates an encoder with an empty output buffer and the initial state.
    #[must_use]
    pub fn new() -> Self {
        let buf = Vec::new();
        Self {
            state: RANS64_L,
            buf,
        }
    }

    /// Creates an encoder whose output buffer has room for `cap` bytes.
    #[must_use]
    pub fn with_capacity(cap: usize) -> Self {
        let buf = Vec::with_capacity(cap);
        Self {
            state: RANS64_L,
            buf,
        }
    }

    /// Encodes `sym` under `table`, emitting 32-bit words first if the state is
    /// too large.
    ///
    /// # Errors
    ///
    /// * [`AnsError::InvalidSymbol`] if `sym` is outside the table's alphabet.
    /// * [`AnsError::ZeroFrequency`] if `sym` has frequency 0 in `table`.
    ///
    /// The encoder is left untouched when an error is returned.
    pub fn put(&mut self, sym: u32, table: &FrequencyTable) -> Result<(), AnsError> {
        let index = sym as usize;
        let Some(&freq32) = table.freqs.get(index) else {
            return Err(AnsError::InvalidSymbol {
                symbol: sym,
                alphabet_size: table.freqs.len(),
            });
        };
        let freq = freq32 as u64;
        if freq == 0 {
            return Err(AnsError::ZeroFrequency { symbol: sym });
        }
        let start = table.cdf[index] as u64;
        let bits = table.precision_bits;

        // Shift words out until encoding this symbol keeps the state in range.
        let x_max = ((RANS64_L >> bits) << 32) * freq;
        let mut x = self.state;
        while x >= x_max {
            // Truncating cast keeps exactly the low 32 bits.
            let low_word = x as u32;
            self.buf.extend_from_slice(&low_word.to_le_bytes());
            x >>= 32;
        }

        self.state = ((x / freq) << bits) + (x % freq) + start;
        Ok(())
    }

    /// Appends the final state (8 bytes, little-endian) and returns the stream.
    #[must_use]
    pub fn finish(self) -> Vec<u8> {
        let Self { state, mut buf } = self;
        buf.extend_from_slice(&state.to_le_bytes());
        buf
    }

    /// Returns the current coder state.
    #[inline]
    #[must_use]
    pub fn state(&self) -> u64 {
        self.state
    }
}
// `Default` is the same as `Rans64Encoder::new()`.
impl Default for Rans64Encoder {
fn default() -> Self {
Self::new()
}
}
/// Streaming rANS decoder with a 64-bit state that reads a borrowed byte stream
/// from its end towards its start, one 32-bit word at a time.
#[derive(Debug)]
pub struct Rans64Decoder<'a> {
/// Current coder state.
state: u64,
/// The complete encoded stream, including the trailing 8-byte state.
bytes: &'a [u8],
/// Number of bytes in front of the current position that are still unread.
cursor: usize,
}
impl<'a> Rans64Decoder<'a> {
    /// Creates a decoder over `bytes`, reading the initial state from the last
    /// eight bytes (little-endian).
    ///
    /// # Errors
    ///
    /// * [`AnsError::TruncatedInput`] if `bytes` holds fewer than 8 bytes.
    /// * [`AnsError::InvalidState`] if the stored state is below `RANS64_L`.
    pub fn new(bytes: &'a [u8]) -> Result<Self, AnsError> {
        let Some(cursor) = bytes.len().checked_sub(8) else {
            return Err(AnsError::TruncatedInput {
                available: bytes.len(),
                needed: 8,
            });
        };
        let mut tail = [0u8; 8];
        tail.copy_from_slice(&bytes[cursor..]);
        let state = u64::from_le_bytes(tail);
        if state < RANS64_L {
            return Err(AnsError::InvalidState {
                state,
                min_state: RANS64_L,
            });
        }
        Ok(Self {
            state,
            bytes,
            cursor,
        })
    }

    /// Decodes and returns the next symbol under `table`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Rans64Decoder::advance`].
    pub fn get(&mut self, table: &FrequencyTable) -> Result<u32, AnsError> {
        let sym = self.peek(table);
        self.advance(sym, table)?;
        Ok(sym)
    }

    /// Returns the symbol the current state decodes to under `table`, without
    /// consuming it.
    #[inline]
    #[must_use]
    pub fn peek(&self, table: &FrequencyTable) -> u32 {
        debug_assert!(
            self.state >= RANS64_L,
            "peek called with invalid state {} (expected >= {RANS64_L})",
            self.state
        );
        // `total` is a power of two, so the mask keeps the slot below `total`.
        let slot = (self.state & (table.total as u64 - 1)) as usize;
        debug_assert!(slot < table.sym_by_slot.len(), "slot {slot} out of range");
        table.sym_by_slot[slot]
    }

    /// Consumes `sym` from the state and refills the state from the byte stream.
    ///
    /// `sym` must be the symbol that the current state decodes to under `table`
    /// (what [`Rans64Decoder::peek`] returns). On error the decoder is left
    /// exactly as it was before the call.
    ///
    /// # Errors
    ///
    /// * [`AnsError::InvalidSymbol`] if `sym` is outside the alphabet, or if the
    ///   current state does not fall inside `sym`'s slot range.
    /// * [`AnsError::ZeroFrequency`] if `sym` has frequency 0 in `table`.
    /// * [`AnsError::TruncatedInput`] if fewer than 4 bytes remain while refilling.
    pub fn advance(&mut self, sym: u32, table: &FrequencyTable) -> Result<(), AnsError> {
        let invalid_symbol = || AnsError::InvalidSymbol {
            symbol: sym,
            alphabet_size: table.freqs.len(),
        };
        let index = sym as usize;
        let Some(&freq32) = table.freqs.get(index) else {
            return Err(invalid_symbol());
        };
        if freq32 == 0 {
            return Err(AnsError::ZeroFrequency { symbol: sym });
        }
        let freq = u64::from(freq32);
        let start = u64::from(table.cdf[index]);
        let slot = self.state & (u64::from(table.total) - 1);
        // A symbol that does not own the current slot would make `slot - start`
        // underflow (or silently corrupt the state); reject it instead.
        if slot < start || slot - start >= freq {
            return Err(invalid_symbol());
        }

        // Work on copies so a truncated stream leaves `self` untouched.
        let mut state = freq * (self.state >> table.precision_bits) + (slot - start);
        let mut cursor = self.cursor;
        while state < RANS64_L {
            let Some(word_start) = cursor.checked_sub(4) else {
                return Err(AnsError::TruncatedInput {
                    available: cursor,
                    needed: 4,
                });
            };
            let mut word = [0u8; 4];
            word.copy_from_slice(&self.bytes[word_start..cursor]);
            cursor = word_start;
            state = (state << 32) | u64::from(u32::from_le_bytes(word));
        }
        self.state = state;
        self.cursor = cursor;
        Ok(())
    }

    /// Returns the current coder state.
    #[inline]
    #[must_use]
    pub fn state(&self) -> u64 {
        self.state
    }

    /// Returns how many stream bytes have not been consumed yet.
    #[inline]
    #[must_use]
    pub fn remaining_bytes(&self) -> usize {
        self.cursor
    }
}
/// Encodes `symbols` under `table` with the 64-bit coder and returns the byte stream.
///
/// Symbols are fed to the encoder last-to-first.
///
/// # Errors
///
/// Returns the first error reported by [`Rans64Encoder::put`].
pub fn encode64(symbols: &[u32], table: &FrequencyTable) -> Result<Vec<u8>, AnsError> {
    let mut encoder = Rans64Encoder::with_capacity(symbols.len());
    symbols
        .iter()
        .rev()
        .try_for_each(|&symbol| encoder.put(symbol, table))?;
    Ok(encoder.finish())
}
/// Decodes `len` symbols from `bytes` under `table` with the 64-bit coder.
///
/// # Errors
///
/// Returns any error from [`Rans64Decoder::new`] or [`Rans64Decoder::get`], for
/// example [`AnsError::TruncatedInput`] when `len` exceeds what the stream holds.
pub fn decode64(bytes: &[u8], table: &FrequencyTable, len: usize) -> Result<Vec<u32>, AnsError> {
    // `len` is caller-supplied and may be far larger than what the stream can
    // yield. Cap the up-front reservation so an oversized value ends in a decode
    // error rather than a capacity-overflow panic or a giant allocation; the
    // vector still grows on demand beyond the cap.
    const MAX_PREALLOC: usize = 1 << 16;

    let mut dec = Rans64Decoder::new(bytes)?;
    let mut out = Vec::with_capacity(len.min(MAX_PREALLOC));
    for _ in 0..len {
        out.push(dec.get(table)?);
    }
    Ok(out)
}
// Unit and property tests for the frequency table and for the 32-bit and 64-bit coders.
// The property tests rely on the `proptest` crate.
#[cfg(test)]
mod tests {
use super::*;
use proptest::prelude::*;
// Batch encode followed by batch decode returns the original symbols.
#[test]
fn smoke_roundtrip_small_alphabet() {
let counts = [1u32, 2, 3, 4];
let table = FrequencyTable::from_counts(&counts, 12).unwrap();
let symbols = vec![0u32, 1, 2, 3, 2, 2, 1, 0, 3];
let enc = encode(&symbols, &table).unwrap();
let dec = decode(&enc, &table, symbols.len()).unwrap();
assert_eq!(symbols, dec);
}
// Overwriting the trailing state with zero must be reported as `InvalidState`.
#[test]
fn decode_rejects_invalid_final_state() {
let counts = [1u32, 2, 3, 4];
let table = FrequencyTable::from_counts(&counts, 12).unwrap();
let symbols = vec![0u32, 1, 2, 3, 2, 2, 1, 0, 3];
let mut enc = encode(&symbols, &table).unwrap();
let n = enc.len();
enc[n - 4..n].copy_from_slice(&0u32.to_le_bytes());
let err = decode(&enc, &table, symbols.len()).unwrap_err();
assert!(matches!(err, AnsError::InvalidState { .. }));
}
// A one-symbol alphabet (frequency equal to the total) still round-trips.
#[test]
fn roundtrip_single_symbol_alphabet() {
let counts = [42u32];
let table = FrequencyTable::from_counts(&counts, 10).unwrap();
let symbols = vec![0u32; 50];
let enc = encode(&symbols, &table).unwrap();
let dec = decode(&enc, &table, symbols.len()).unwrap();
assert_eq!(symbols, dec);
}
// Encoding and decoding zero symbols works.
#[test]
fn roundtrip_empty_message() {
let counts = [5u32, 3, 2];
let table = FrequencyTable::from_counts(&counts, 12).unwrap();
let symbols: Vec<u32> = vec![];
let enc = encode(&symbols, &table).unwrap();
let dec = decode(&enc, &table, 0).unwrap();
assert_eq!(symbols, dec);
}
// Round-trip at the smallest and largest supported precisions.
#[test]
fn roundtrip_precision_boundaries() {
for prec in [1, 2, 19, 20] {
let counts = [3u32, 7];
let table = FrequencyTable::from_counts(&counts, prec).unwrap();
let symbols = vec![0u32, 1, 1, 0, 1];
let enc = encode(&symbols, &table).unwrap();
let dec = decode(&enc, &table, symbols.len()).unwrap();
assert_eq!(symbols, dec, "failed at precision_bits={prec}");
}
}
// Precisions just outside `1..=20` are rejected.
#[test]
fn errors_on_invalid_precision() {
assert!(FrequencyTable::from_counts(&[1], 0).is_err());
assert!(FrequencyTable::from_counts(&[1], 21).is_err());
}
// An empty count slice is rejected.
#[test]
fn errors_on_empty_counts() {
assert!(FrequencyTable::from_counts(&[], 12).is_err());
}
// Counts that are all zero are rejected.
#[test]
fn errors_on_all_zero_counts() {
assert!(FrequencyTable::from_counts(&[0, 0, 0], 12).is_err());
}
// Three non-zero symbols cannot fit into the two slots of a 1-bit table.
#[test]
fn errors_when_precision_too_small_for_alphabet() {
let err = FrequencyTable::from_counts(&[1, 1, 1], 1).unwrap_err();
assert!(matches!(err, AnsError::InvalidTable(_)));
}
// Property: batch round-trip for random counts and messages, using the smallest
// precision that has at least one slot per symbol.
proptest! {
#[test]
fn prop_rans_roundtrip(
symbols in prop::collection::vec(0u32..256u32, 0..200),
counts in prop::collection::vec(1u32..100u32, 1..16),
) {
let alphabet = counts.len().max(1);
let min_bits = (alphabet as f64).log2().ceil().max(1.0) as u32;
let precision_bits = min_bits.clamp(1, 20);
let table = FrequencyTable::from_counts(&counts, precision_bits).unwrap();
let symbols: Vec<u32> = symbols.into_iter().map(|s| s % (alphabet as u32)).collect();
let enc = encode(&symbols, &table)?;
let dec = decode(&enc, &table, symbols.len())?;
prop_assert_eq!(symbols, dec);
}
}
// Driving `RansEncoder`/`RansDecoder` by hand round-trips a message.
#[test]
fn streaming_roundtrip() {
let counts = [1u32, 2, 3, 4];
let table = FrequencyTable::from_counts(&counts, 12).unwrap();
let message = vec![0u32, 1, 2, 3, 2, 2, 1, 0, 3];
let mut enc = RansEncoder::new();
for &sym in message.iter().rev() {
enc.put(sym, &table).unwrap();
}
let bytes = enc.finish();
let mut dec = RansDecoder::new(&bytes).unwrap();
let mut decoded = Vec::new();
for _ in 0..message.len() {
decoded.push(dec.get(&table).unwrap());
}
assert_eq!(message, decoded);
}
// Property: feeding `RansEncoder` by hand in reverse order yields the same bytes as
// `encode`. Parameter combinations for which no table can be built are skipped.
proptest! {
#[test]
fn prop_streaming_matches_batch(
precision_bits in 1u32..21,
symbols in prop::collection::vec(0u32..256u32, 0..200),
counts in prop::collection::vec(1u32..100u32, 1..32),
) {
let alphabet = counts.len().max(1);
let table = match FrequencyTable::from_counts(&counts, precision_bits) {
Ok(t) => t,
Err(_) => { return Ok(()); }
};
let symbols: Vec<u32> = symbols.into_iter().map(|s| s % (alphabet as u32)).collect();
let batch_bytes = encode(&symbols, &table)?;
let mut enc = RansEncoder::with_capacity(symbols.len());
for &sym in symbols.iter().rev() {
enc.put(sym, &table)?;
}
let stream_bytes = enc.finish();
prop_assert_eq!(&batch_bytes, &stream_bytes);
}
}
// `peek` agrees with `symbol_at_slot` for the current state, and `advance` consumes it.
#[test]
fn peek_and_advance() {
let counts = [3u32, 7];
let table = FrequencyTable::from_counts(&counts, 12).unwrap();
let message = vec![0u32, 1, 1, 0, 1];
let bytes = encode(&message, &table).unwrap();
let mut dec = RansDecoder::new(&bytes).unwrap();
for &expected in &message {
let sym = dec.peek(&table);
assert_eq!(sym, expected);
let slot = dec.state() & (table.total() - 1);
assert_eq!(table.symbol_at_slot(slot), Some(sym));
dec.advance(sym, &table).unwrap();
}
}
// An encoder that saw no symbols emits only the 4-byte initial state.
#[test]
fn streaming_empty_message() {
let table = FrequencyTable::from_counts(&[5, 3, 2], 12).unwrap();
let enc = RansEncoder::new();
let bytes = enc.finish();
assert_eq!(bytes.len(), 4);
let dec = RansDecoder::new(&bytes).unwrap();
assert_eq!(dec.remaining_bytes(), 0);
assert_eq!(dec.state(), RANS_L);
// The table is built but not otherwise used by this test.
let _ = &table; }
// Streaming round-trip with a one-symbol alphabet.
#[test]
fn streaming_single_symbol() {
let counts = [42u32];
let table = FrequencyTable::from_counts(&counts, 10).unwrap();
let message = vec![0u32; 50];
let mut enc = RansEncoder::new();
for &sym in message.iter().rev() {
enc.put(sym, &table).unwrap();
}
let bytes = enc.finish();
let mut dec = RansDecoder::new(&bytes).unwrap();
let mut decoded = Vec::new();
for _ in 0..message.len() {
decoded.push(dec.get(&table).unwrap());
}
assert_eq!(message, decoded);
}
// Streaming round-trip at the smallest and largest supported precisions.
#[test]
fn streaming_precision_boundaries() {
for prec in [1, 2, 19, 20] {
let counts = [3u32, 7];
let table = FrequencyTable::from_counts(&counts, prec).unwrap();
let message = vec![0u32, 1, 1, 0, 1];
let mut enc = RansEncoder::new();
for &sym in message.iter().rev() {
enc.put(sym, &table).unwrap();
}
let bytes = enc.finish();
let mut dec = RansDecoder::new(&bytes).unwrap();
let mut decoded = Vec::new();
for _ in 0..message.len() {
decoded.push(dec.get(&table).unwrap());
}
assert_eq!(message, decoded, "failed at precision_bits={prec}");
}
}
// `put` rejects a symbol index equal to the alphabet size.
#[test]
fn streaming_encoder_error_invalid_symbol() {
let table = FrequencyTable::from_counts(&[5, 3, 2], 12).unwrap();
let mut enc = RansEncoder::new();
let err = enc.put(3, &table).unwrap_err(); assert!(matches!(err, AnsError::InvalidSymbol { symbol: 3, .. }));
}
// `put` rejects a symbol whose count (and therefore frequency) is zero.
#[test]
fn encode_zero_frequency_symbol() {
let table = FrequencyTable::from_counts(&[5, 0, 3], 12).unwrap();
assert_eq!(table.freq(1).unwrap(), 0);
let mut enc = RansEncoder::new();
let err = enc.put(1, &table).unwrap_err();
assert!(matches!(err, AnsError::ZeroFrequency { symbol: 1 }));
}
// Asking for more symbols than were encoded ends in `TruncatedInput`.
#[test]
fn decode_beyond_message_length() {
let table = FrequencyTable::from_counts(&[3, 7], 12).unwrap();
let message = [0u32, 1];
let bytes = encode(&message, &table).unwrap();
let err = decode(&bytes, &table, 100).unwrap_err();
assert!(matches!(err, AnsError::TruncatedInput { .. }));
}
// Decoding zero symbols from a bare initial state yields an empty vector.
#[test]
fn batch_decode_zero_symbols() {
let table = FrequencyTable::from_counts(&[5, 3, 2], 12).unwrap();
let bytes = RANS_L.to_le_bytes();
let result = decode(&bytes, &table, 0).unwrap();
assert!(result.is_empty());
}
// `advance` rejects a symbol index outside the alphabet.
#[test]
fn advance_rejects_invalid_symbol() {
let table = FrequencyTable::from_counts(&[3, 7], 12).unwrap();
let message = [0u32, 1];
let bytes = encode(&message, &table).unwrap();
let mut dec = RansDecoder::new(&bytes).unwrap();
let err = dec.advance(2, &table).unwrap_err();
assert!(matches!(
err,
AnsError::InvalidSymbol {
symbol: 2,
alphabet_size: 2
}
));
}
// `RansDecoder::new` needs at least four bytes and reports how many it got.
#[test]
fn streaming_decoder_truncated() {
let err = RansDecoder::new(&[0u8, 1, 2]).unwrap_err();
assert!(matches!(
err,
AnsError::TruncatedInput {
available: 3,
needed: 4
}
));
let err = RansDecoder::new(&[]).unwrap_err();
assert!(matches!(
err,
AnsError::TruncatedInput {
available: 0,
needed: 4
}
));
}
// A stored state of zero is below `RANS_L` and is rejected.
#[test]
fn streaming_decoder_corrupted_state() {
let bytes = 0u32.to_le_bytes();
let err = RansDecoder::new(&bytes).unwrap_err();
assert!(matches!(err, AnsError::InvalidState { .. }));
}
// Decoding past the end of the message must end in an error, not a panic;
// the loop stops at the first error.
#[test]
fn decoder_exhaustion_no_panic() {
let table = FrequencyTable::from_counts(&[3, 7], 12).unwrap();
let message = [0u32, 1];
let bytes = encode(&message, &table).unwrap();
let mut dec = RansDecoder::new(&bytes).unwrap();
for _ in 0..2 {
dec.get(&table).unwrap();
}
for _ in 0..100 {
match dec.get(&table) {
Ok(_) => {}
Err(_) => break,
}
}
}
// Round-trip of a 10,000-symbol message over a skewed ten-symbol alphabet.
#[test]
fn stress_large_message() {
let counts = [1u32, 2, 4, 8, 16, 32, 64, 128, 256, 512];
let table = FrequencyTable::from_counts(&counts, 14).unwrap();
let message: Vec<u32> = (0..10_000).map(|i| (i % 10) as u32).collect();
let bytes = encode(&message, &table).unwrap();
let decoded = decode(&bytes, &table, message.len()).unwrap();
assert_eq!(message, decoded);
}
// Symbols decoded from a stream under one model (`prior`) can be re-encoded under
// another (`posterior`) and recovered by decoding with that second model.
#[test]
fn bits_back_cross_model() {
let prior = FrequencyTable::from_counts(&[1, 1], 12).unwrap();
let posterior = FrequencyTable::from_counts(&[8, 2], 12).unwrap();
let seed_model = FrequencyTable::from_counts(&[3, 7], 12).unwrap();
let seed = [1u32, 0, 1, 1, 0, 1, 0, 0];
let seed_bytes = encode(&seed, &seed_model).unwrap();
let mut dec = RansDecoder::new(&seed_bytes).unwrap();
let mut latents = Vec::new();
for _ in 0..3 {
let z = dec.peek(&prior);
dec.advance(z, &prior).unwrap();
latents.push(z);
assert!(z < prior.alphabet_size() as u32);
}
let mut enc = RansEncoder::new();
for &z in latents.iter().rev() {
enc.put(z, &posterior).unwrap();
}
let posterior_bytes = enc.finish();
let mut dec2 = RansDecoder::new(&posterior_bytes).unwrap();
let mut recovered = Vec::new();
for _ in 0..3 {
let z = dec2.peek(&posterior);
dec2.advance(z, &posterior).unwrap();
recovered.push(z);
}
assert_eq!(
latents, recovered,
"posterior decode must recover prior-decoded latents"
);
}
// Structural invariants of a table: monotone CDF, frequencies summing to the total,
// every slot mapping into its symbol's range, and out-of-range accessors returning `None`.
#[test]
fn frequency_table_invariants() {
let counts = [10u32, 20, 30, 40];
let table = FrequencyTable::from_counts(&counts, 14).unwrap();
for i in 0..table.alphabet_size() {
let cdf_i = table.cum_freq(i as u32).unwrap();
let cdf_next = table.cum_freq((i + 1) as u32).unwrap_or(table.total());
assert!(cdf_next >= cdf_i, "CDF not monotone at symbol {i}");
}
let sum: u32 = (0..table.alphabet_size())
.map(|i| table.freq(i as u32).unwrap())
.sum();
assert_eq!(sum, table.total());
for slot in 0..table.total() {
let sym = table.symbol_at_slot(slot).unwrap();
assert!((sym as usize) < table.alphabet_size());
let cdf = table.cum_freq(sym).unwrap();
let freq = table.freq(sym).unwrap();
assert!(
slot >= cdf && slot < cdf + freq,
"slot {slot} not in range [{cdf}, {}) for sym {sym}",
cdf + freq
);
}
assert!(table.freq(table.alphabet_size() as u32).is_none());
assert_eq!(
table.cum_freq(table.alphabet_size() as u32),
Some(table.total())
);
assert!(table.cum_freq((table.alphabet_size() + 1) as u32).is_none());
assert!(table.symbol_at_slot(table.total()).is_none());
}
// Property: for random counts (zeros allowed) the frequencies sum to the total,
// non-zero counts get a frequency of at least 1, zero counts get 0, and CDF
// differences equal the frequencies.
proptest! {
#[test]
fn prop_frequency_table_cdf_invariants(
counts in prop::collection::vec(0u32..100u32, 1..16),
) {
if counts.iter().all(|&c| c == 0) { return Ok(()); }
let min_bits = {
let nonzero = counts.iter().filter(|&&c| c > 0).count();
(nonzero as f64).log2().ceil().max(1.0) as u32
};
let precision_bits = min_bits.clamp(1, 20);
let table = FrequencyTable::from_counts(&counts, precision_bits).unwrap();
let sum: u32 = table.freqs().iter().sum();
prop_assert_eq!(sum, table.total());
for (i, &c) in counts.iter().enumerate() {
let f = table.freq(i as u32).unwrap();
if c > 0 {
prop_assert!(f >= 1, "symbol {} had count {} but freq 0", i, c);
} else {
prop_assert_eq!(f, 0, "symbol {} had count 0 but freq {}", i, f);
}
}
let cdf = table.cdf();
for i in 0..table.alphabet_size() {
prop_assert_eq!(cdf[i + 1] - cdf[i], table.freq(i as u32).unwrap());
}
}
}
// Rebuilding a table from its own normalized frequencies gives an equivalent table.
#[test]
fn from_normalized_roundtrip() {
let table1 = FrequencyTable::from_counts(&[3, 7], 12).unwrap();
let table2 =
FrequencyTable::from_normalized(table1.freqs(), table1.precision_bits()).unwrap();
assert_eq!(table1.freqs(), table2.freqs());
assert_eq!(table1.cdf(), table2.cdf());
let message = [0u32, 1, 1, 0, 1];
let bytes1 = encode(&message, &table1).unwrap();
let bytes2 = encode(&message, &table2).unwrap();
assert_eq!(bytes1, bytes2);
}
// Frequencies that do not sum to `1 << precision_bits` are rejected.
#[test]
fn from_normalized_rejects_wrong_sum() {
let err = FrequencyTable::from_normalized(&[100, 200], 12).unwrap_err();
assert!(matches!(err, AnsError::InvalidTable(_)));
}
// An empty frequency slice is rejected with `EmptyAlphabet`.
#[test]
fn from_normalized_rejects_empty() {
let err = FrequencyTable::from_normalized(&[], 12).unwrap_err();
assert!(matches!(err, AnsError::EmptyAlphabet));
}
// A table built from float probabilities sums to the total and round-trips a message.
#[test]
fn from_float_probs_roundtrip() {
let table = FrequencyTable::from_float_probs(&[0.3, 0.7], 12).unwrap();
assert_eq!(table.alphabet_size(), 2);
assert_eq!(table.freqs().iter().sum::<u32>(), table.total());
let message = [0u32, 1, 1, 0, 1, 0, 1, 1];
let bytes = encode(&message, &table).unwrap();
let decoded = decode(&bytes, &table, message.len()).unwrap();
assert_eq!(decoded, message);
}
// Equal weights produce (approximately) equal frequencies.
#[test]
fn from_float_probs_uniform() {
let table = FrequencyTable::from_float_probs(&[1.0, 1.0, 1.0, 1.0], 12).unwrap();
for i in 0..4 {
let f = table.freq(i).unwrap();
assert!(
(1020..=1028).contains(&f),
"freq[{i}] = {f}, expected ~1024"
);
}
}
// A zero probability keeps frequency 0 while the other symbols stay non-zero.
#[test]
fn from_float_probs_with_zeros() {
let table = FrequencyTable::from_float_probs(&[0.0, 0.5, 0.5], 12).unwrap();
assert_eq!(table.freq(0).unwrap(), 0);
assert!(table.freq(1).unwrap() > 0);
assert!(table.freq(2).unwrap() > 0);
}
// All-zero probabilities are rejected.
#[test]
fn from_float_probs_rejects_all_zero() {
let err = FrequencyTable::from_float_probs(&[0.0, 0.0], 12).unwrap_err();
assert!(matches!(err, AnsError::InvalidTable(_)));
}
// The slice accessors expose consistent frequency and CDF arrays.
#[test]
fn freqs_and_cdf_accessors() {
let table = FrequencyTable::from_counts(&[10, 20, 30], 12).unwrap();
let freqs = table.freqs();
let cdf = table.cdf();
assert_eq!(freqs.len(), 3);
assert_eq!(cdf.len(), 4);
assert_eq!(cdf[0], 0);
assert_eq!(*cdf.last().unwrap(), table.total());
assert_eq!(freqs.iter().sum::<u32>(), table.total());
for i in 0..freqs.len() {
assert_eq!(cdf[i + 1], cdf[i] + freqs[i]);
}
}
// The encoder state never drops below `RANS_L` while encoding.
#[test]
fn encoder_state_always_valid() {
let table = FrequencyTable::from_counts(&[1, 2, 3, 4], 12).unwrap();
let mut enc = RansEncoder::new();
assert!(enc.state() >= RANS_L);
for sym in [0u32, 1, 2, 3, 2, 1, 0] {
enc.put(sym, &table).unwrap();
assert!(
enc.state() >= RANS_L,
"state {} < RANS_L after encoding sym {}",
enc.state(),
sym
);
}
}
// A heavily skewed 1000-symbol message compresses to fewer than 100 bytes.
#[test]
fn compression_ratio_sanity() {
let counts = [990u32, 5, 3, 1, 1];
let table = FrequencyTable::from_counts(&counts, 14).unwrap();
let message: Vec<u32> = (0..1000)
.map(|i| if i % 100 == 0 { 1 } else { 0 })
.collect();
let bytes = encode(&message, &table).unwrap();
assert!(
bytes.len() < 100,
"compressed {} bytes, expected < 100 for 99%-skewed distribution",
bytes.len()
);
}
// `remaining_bytes` never increases while decoding.
#[test]
fn decoder_remaining_bytes_monotone() {
let table = FrequencyTable::from_counts(&[3, 7], 12).unwrap();
let message: Vec<u32> = (0..50).map(|i| (i % 2) as u32).collect();
let bytes = encode(&message, &table).unwrap();
let mut dec = RansDecoder::new(&bytes).unwrap();
let mut prev_remaining = dec.remaining_bytes();
for _ in 0..50 {
dec.get(&table).unwrap();
assert!(
dec.remaining_bytes() <= prev_remaining,
"remaining_bytes increased: {} > {}",
dec.remaining_bytes(),
prev_remaining
);
prev_remaining = dec.remaining_bytes();
}
}
// Batch round-trip with the 64-bit coder.
#[test]
fn rans64_roundtrip() {
let counts = [1u32, 2, 3, 4];
let table = FrequencyTable::from_counts(&counts, 12).unwrap();
let symbols = vec![0u32, 1, 2, 3, 2, 2, 1, 0, 3];
let enc = encode64(&symbols, &table).unwrap();
let dec = decode64(&enc, &table, symbols.len()).unwrap();
assert_eq!(symbols, dec);
}
// Streaming round-trip with the 64-bit coder.
#[test]
fn rans64_streaming_roundtrip() {
let counts = [3u32, 7];
let table = FrequencyTable::from_counts(&counts, 12).unwrap();
let message = vec![0u32, 1, 1, 0, 1];
let mut enc = Rans64Encoder::new();
for &sym in message.iter().rev() {
enc.put(sym, &table).unwrap();
}
let bytes = enc.finish();
let mut dec = Rans64Decoder::new(&bytes).unwrap();
let mut decoded = Vec::new();
for _ in 0..message.len() {
decoded.push(dec.get(&table).unwrap());
}
assert_eq!(message, decoded);
}
// `peek` followed by `advance` walks through the message with the 64-bit decoder.
#[test]
fn rans64_peek_advance() {
let table = FrequencyTable::from_counts(&[3, 7], 12).unwrap();
let message = vec![0u32, 1, 1, 0, 1];
let bytes = encode64(&message, &table).unwrap();
let mut dec = Rans64Decoder::new(&bytes).unwrap();
for &expected in &message {
let sym = dec.peek(&table);
assert_eq!(sym, expected);
dec.advance(sym, &table).unwrap();
}
}
// Round-trip of a 10,000-symbol message with the 64-bit coder.
#[test]
fn rans64_large_message() {
let counts = [1u32, 2, 4, 8, 16, 32, 64, 128, 256, 512];
let table = FrequencyTable::from_counts(&counts, 14).unwrap();
let message: Vec<u32> = (0..10_000).map(|i| (i % 10) as u32).collect();
let bytes = encode64(&message, &table).unwrap();
let decoded = decode64(&bytes, &table, message.len()).unwrap();
assert_eq!(message, decoded);
}
// An empty message encodes to just the 8-byte initial state.
#[test]
fn rans64_empty_message() {
let table = FrequencyTable::from_counts(&[5, 3, 2], 12).unwrap();
let bytes = encode64(&[], &table).unwrap();
assert_eq!(bytes.len(), 8);
let decoded = decode64(&bytes, &table, 0).unwrap();
assert!(decoded.is_empty());
}
// Round-trip with a one-symbol alphabet on the 64-bit coder.
#[test]
fn rans64_single_symbol_alphabet() {
let table = FrequencyTable::from_counts(&[42], 10).unwrap();
let message = vec![0u32; 50];
let bytes = encode64(&message, &table).unwrap();
let decoded = decode64(&bytes, &table, message.len()).unwrap();
assert_eq!(message, decoded);
}
// Property: batch round-trip with the 64-bit coder for random counts and messages.
proptest! {
#[test]
fn prop_rans64_roundtrip(
symbols in prop::collection::vec(0u32..256u32, 0..200),
counts in prop::collection::vec(1u32..100u32, 1..16),
) {
let alphabet = counts.len().max(1);
let min_bits = (alphabet as f64).log2().ceil().max(1.0) as u32;
let precision_bits = min_bits.clamp(1, 20);
let table = FrequencyTable::from_counts(&counts, precision_bits).unwrap();
let symbols: Vec<u32> = symbols.into_iter().map(|s| s % (alphabet as u32)).collect();
let enc = encode64(&symbols, &table)?;
let dec = decode64(&enc, &table, symbols.len())?;
prop_assert_eq!(symbols, dec);
}
}
}