/// Bits of probability precision: quantized frequencies of one distribution sum to `2^ANS_BITS`.
pub const ANS_BITS: u32 = 15;
/// Sum every quantized CDF table is scaled to (`2^ANS_BITS`).
pub const ANS_TOTAL: u32 = 1u32 << ANS_BITS;
/// Lower bound of the scalar coder state. With 16-bit renormalization and a
/// `2^ANS_BITS` total this coincides with [`ANS_TOTAL`].
pub const ANS_LOW: u32 = ANS_TOTAL;
/// Exclusive upper bound (`2^31`) on coder states.
pub const ANS_HIGH: u32 = 1u32 << 31;
/// Half-open cumulative-frequency interval `[lo, hi)` of a single symbol, drawn
/// from a distribution whose frequencies sum to `total`.
#[derive(Clone, Debug)]
pub struct Cdf {
    /// Inclusive start of the symbol's interval.
    pub lo: u32,
    /// Exclusive end of the symbol's interval.
    pub hi: u32,
    /// Sum of all frequencies of the distribution.
    pub total: u32,
}
impl Cdf {
    /// Builds an interval from raw bounds; no validation is performed.
    #[inline]
    pub fn new(lo: u32, hi: u32, total: u32) -> Self {
        Cdf { total, lo, hi }
    }
    /// Width of the interval, i.e. the symbol's quantized frequency.
    ///
    /// Overflows (panicking in debug builds) if `hi < lo`.
    #[inline]
    pub fn freq(&self) -> u32 {
        let Cdf { lo, hi, .. } = self;
        *hi - *lo
    }
}
/// Quantizes the probabilities in `pdf` into an rANS cumulative table scaled to
/// [`ANS_TOTAL`], returning `pdf.len() + 1` entries.
///
/// Allocates fresh output and scratch buffers on every call; use
/// [`quantize_pdf_to_rans_cdf_with_buffer`] to reuse buffers in hot loops.
pub fn quantize_pdf_to_rans_cdf(pdf: &[f64]) -> Vec<u32> {
    let symbol_count = pdf.len();
    let mut table: Vec<u32> = vec![0; symbol_count + 1];
    let mut scratch: Vec<i64> = vec![0; symbol_count];
    quantize_pdf_to_rans_cdf_with_buffer(pdf, &mut table, &mut scratch);
    table
}
/// Quantizes `pdf` into `cdf_out` with total [`ANS_TOTAL`], using `freq_buf` as
/// scratch space for the parent module's integer quantizer.
///
/// In debug builds this additionally checks that `cdf_out[pdf.len()] == ANS_TOTAL`
/// and that every symbol with positive probability received a non-zero frequency.
/// NOTE(review): required lengths of `cdf_out` / `freq_buf` are defined by
/// `super::quantize_pdf_to_integer_cdf_with_buffer` (presumably `n + 1` and `n`).
pub fn quantize_pdf_to_rans_cdf_with_buffer(
    pdf: &[f64],
    cdf_out: &mut [u32],
    freq_buf: &mut [i64],
) {
    super::quantize_pdf_to_integer_cdf_with_buffer(pdf, ANS_TOTAL, cdf_out, freq_buf);
    let symbol_count = pdf.len();
    debug_assert_eq!(cdf_out[symbol_count], ANS_TOTAL, "CDF total must equal ANS_TOTAL");
    for (index, &probability) in pdf.iter().enumerate() {
        if probability > 0.0 {
            // A zero-width interval would make the symbol unencodable.
            debug_assert!(
                cdf_out[index + 1] > cdf_out[index],
                "Symbol {} with p={} has zero frequency",
                index,
                probability
            );
        }
    }
}
#[inline]
pub fn cdf_for_symbol(cdf: &[u32], sym: usize) -> Cdf {
Cdf::new(cdf[sym], cdf[sym + 1], ANS_TOTAL)
}
/// Scalar rANS encoder with a 32-bit state and 16-bit renormalization.
///
/// rANS is last-in/first-out: feed symbols in the *reverse* of the order in which
/// [`RansDecoder`] should return them. Between calls the state stays in
/// `[ANS_LOW, ANS_HIGH)` for well-formed intervals.
pub struct RansEncoder {
    /// Current coder state.
    state: u32,
    /// Renormalization words, in the order they were emitted.
    output: Vec<u16>,
}
impl RansEncoder {
    /// Creates an empty encoder whose state starts at [`ANS_LOW`].
    pub fn new() -> Self {
        Self {
            state: ANS_LOW,
            output: Vec::new(),
        }
    }
    /// Encodes one symbol given its interval `[cdf.lo, cdf.hi)` out of [`ANS_TOTAL`].
    ///
    /// `cdf.total` is not read; the interval is assumed to be scaled to `ANS_TOTAL`.
    ///
    /// # Panics
    /// Panics if the symbol's frequency is zero. This is checked in release builds
    /// too, because a zero frequency makes the renormalization bound zero. The loop
    /// below would then never terminate and `output` would grow without limit.
    #[inline]
    pub fn encode(&mut self, cdf: &Cdf) {
        let freq = cdf.freq();
        assert!(freq > 0, "Symbol frequency must be > 0");
        // Emit 16-bit words until `state < freq << 16`. Then `state / freq < 2^16`,
        // so the updated state `(q << ANS_BITS) + r + lo` stays below ANS_HIGH.
        let x_max = freq << 16;
        while self.state >= x_max {
            self.output.push(self.state as u16); // low 16 bits; truncation intended
            self.state >>= 16;
        }
        let q = self.state / freq;
        let r = self.state % freq;
        self.state = (q << ANS_BITS) + r + cdf.lo;
    }
    /// Quantizes `pdf` and encodes symbol `sym` under it.
    ///
    /// Rebuilds (and allocates) the quantized table on every call. Prefer
    /// [`quantize_pdf_to_rans_cdf`] + [`Self::encode`] when a distribution is reused.
    ///
    /// # Panics
    /// Panics if `sym >= pdf.len()` or the symbol quantizes to zero frequency.
    pub fn encode_pdf(&mut self, pdf: &[f64], sym: usize) {
        let cdf_table = quantize_pdf_to_rans_cdf(pdf);
        let cdf = cdf_for_symbol(&cdf_table, sym);
        self.encode(&cdf);
    }
    /// Serializes the stream. The layout is the final state as 4 little-endian bytes,
    /// followed by the emitted words in reverse emission order, 2 little-endian bytes
    /// each. That is the order in which [`RansDecoder`] consumes them.
    pub fn finish(self) -> Vec<u8> {
        let mut result = Vec::with_capacity(self.output.len() * 2 + 4);
        result.extend_from_slice(&self.state.to_le_bytes());
        for &word in self.output.iter().rev() {
            result.extend_from_slice(&word.to_le_bytes());
        }
        result
    }
    /// Length in bytes that [`Self::finish`] would return right now.
    pub fn size_estimate(&self) -> usize {
        self.output.len() * 2 + 4
    }
}
impl Default for RansEncoder {
fn default() -> Self {
Self::new()
}
}
pub struct RansDecoder<'a> {
state: u32,
input: &'a [u8],
pos: usize,
}
impl<'a> RansDecoder<'a> {
pub fn new(input: &'a [u8]) -> anyhow::Result<Self> {
if input.len() < 4 {
anyhow::bail!("rANS input too short");
}
let state = u32::from_le_bytes([input[0], input[1], input[2], input[3]]);
Ok(Self {
state,
input,
pos: 4,
})
}
#[inline]
pub fn decode(&mut self, cdf: &[u32]) -> anyhow::Result<usize> {
let slot = self.state & (ANS_TOTAL - 1);
let mut lo = 0usize;
let mut hi = cdf.len() - 1;
while lo + 1 < hi {
let mid = (lo + hi) / 2;
if cdf[mid] <= slot {
lo = mid;
} else {
hi = mid;
}
}
let sym = lo;
let c_lo = cdf[sym];
let c_hi = cdf[sym + 1];
let freq = c_hi - c_lo;
self.state = freq * (self.state >> ANS_BITS) + slot - c_lo;
while self.state < ANS_LOW && self.pos + 1 < self.input.len() {
let word = u16::from_le_bytes([self.input[self.pos], self.input[self.pos + 1]]);
self.state = (self.state << 16) | (word as u32);
self.pos += 2;
}
Ok(sym)
}
pub fn decode_pdf(&mut self, pdf: &[f64]) -> anyhow::Result<usize> {
let cdf = quantize_pdf_to_rans_cdf(pdf);
self.decode(&cdf)
}
}
#[cfg(target_arch = "x86_64")]
mod simd {
use super::*;
pub const RANS_LANES: usize = 8;
pub struct SimdRansEncoder {
states: [u32; RANS_LANES],
outputs: [Vec<u8>; RANS_LANES],
lane: usize,
}
impl SimdRansEncoder {
pub fn new() -> Self {
Self {
states: [ANS_LOW; RANS_LANES],
outputs: Default::default(),
lane: 0,
}
}
pub fn encode(&mut self, cdf: &Cdf) {
let freq = cdf.freq();
let lane = self.lane;
self.lane = (self.lane + 1) % RANS_LANES;
let state = &mut self.states[lane];
let output = &mut self.outputs[lane];
while *state >= (ANS_HIGH / cdf.total) * freq {
output.push(*state as u8);
*state >>= 8;
}
*state = ((*state / freq) * cdf.total) + (*state % freq) + cdf.lo;
}
pub fn finish(self) -> Vec<u8> {
let mut result = Vec::new();
for &s in self.states.iter().take(RANS_LANES) {
result.extend_from_slice(&s.to_le_bytes());
}
let max_len = self.outputs.iter().map(|v| v.len()).max().unwrap_or(0);
for pos in 0..max_len {
for lane in 0..RANS_LANES {
let out = &self.outputs[lane];
if pos < out.len() {
result.push(out[out.len() - 1 - pos]);
} else {
result.push(0);
}
}
}
result
}
}
impl Default for SimdRansEncoder {
fn default() -> Self {
Self::new()
}
}
pub struct SimdRansDecoder<'a> {
states: [u32; RANS_LANES],
input: &'a [u8],
pos: usize,
lane: usize,
}
impl<'a> SimdRansDecoder<'a> {
pub fn new(input: &'a [u8]) -> anyhow::Result<Self> {
if input.len() < RANS_LANES * 4 {
anyhow::bail!("SIMD rANS input too short");
}
let mut states = [0u32; RANS_LANES];
for (i, state) in states.iter_mut().enumerate() {
let offset = i * 4;
*state = u32::from_le_bytes([
input[offset],
input[offset + 1],
input[offset + 2],
input[offset + 3],
]);
}
Ok(Self {
states,
input,
pos: RANS_LANES * 4,
lane: 0,
})
}
pub fn decode(&mut self, cdf: &[u32]) -> anyhow::Result<usize> {
let lane = self.lane;
self.lane = (self.lane + 1) % RANS_LANES;
let state = &mut self.states[lane];
let total = ANS_TOTAL;
let value = *state & (total - 1);
let mut lo = 0usize;
let mut hi = cdf.len() - 1;
while lo + 1 < hi {
let mid = (lo + hi) / 2;
if cdf[mid] <= value {
lo = mid;
} else {
hi = mid;
}
}
let sym = lo;
let c_lo = cdf[sym];
let c_hi = cdf[sym + 1];
let freq = c_hi - c_lo;
*state = freq * (*state >> ANS_BITS) + (*state & (total - 1)) - c_lo;
while *state < ANS_LOW {
let byte_idx = self.pos + lane;
if byte_idx < self.input.len() {
*state = (*state << 8) | (self.input[byte_idx] as u32);
}
self.pos += RANS_LANES;
}
Ok(sym)
}
}
}
// On x86_64, expose the interleaved multi-lane coder from `simd`
// (`RANS_LANES`, `SimdRansEncoder`, `SimdRansDecoder`).
#[cfg(target_arch = "x86_64")]
pub use simd::*;
/// Lane count of the fallback coder: elsewhere `SimdRansEncoder`/`SimdRansDecoder`
/// wrap the single-lane scalar coder.
#[cfg(not(target_arch = "x86_64"))]
pub const RANS_LANES: usize = 1;
/// Single-lane fallback for non-x86_64 targets that delegates to [`RansEncoder`].
///
/// Its serialized format is the scalar one, not the interleaved x86_64 layout.
#[cfg(not(target_arch = "x86_64"))]
pub struct SimdRansEncoder {
    inner: RansEncoder,
}
#[cfg(not(target_arch = "x86_64"))]
impl SimdRansEncoder {
    /// Creates an empty encoder.
    pub fn new() -> Self {
        let inner = RansEncoder::new();
        Self { inner }
    }
    /// Encodes one symbol; feed symbols in reverse decode order.
    pub fn encode(&mut self, cdf: &Cdf) {
        RansEncoder::encode(&mut self.inner, cdf)
    }
    /// Serializes the stream exactly as [`RansEncoder::finish`] does.
    pub fn finish(self) -> Vec<u8> {
        let Self { inner } = self;
        inner.finish()
    }
}
#[cfg(not(target_arch = "x86_64"))]
impl Default for SimdRansEncoder {
    fn default() -> Self {
        SimdRansEncoder::new()
    }
}
/// Single-lane fallback for non-x86_64 targets that delegates to [`RansDecoder`].
#[cfg(not(target_arch = "x86_64"))]
pub struct SimdRansDecoder<'a> {
    inner: RansDecoder<'a>,
}
#[cfg(not(target_arch = "x86_64"))]
impl<'a> SimdRansDecoder<'a> {
    /// Opens a scalar-format stream; errors are those of [`RansDecoder::new`].
    pub fn new(input: &'a [u8]) -> anyhow::Result<Self> {
        RansDecoder::new(input).map(|inner| Self { inner })
    }
    /// Decodes the next symbol via [`RansDecoder::decode`].
    pub fn decode(&mut self, cdf: &[u32]) -> anyhow::Result<usize> {
        RansDecoder::decode(&mut self.inner, cdf)
    }
}
/// Number of symbols per independently encoded block.
pub const BLOCK_SIZE: usize = 128 * 1024;
/// Buffers symbol intervals and encodes them as independent scalar rANS streams of
/// at most [`BLOCK_SIZE`] symbols each.
///
/// Symbols are passed in decode order; each block is reversed internally before
/// encoding, as rANS requires.
pub struct BlockedRansEncoder {
    /// Intervals of the block currently being collected.
    symbols: Vec<Cdf>,
    /// Serialized streams of completed blocks, in order.
    blocks: Vec<Vec<u8>>,
}
impl BlockedRansEncoder {
    /// Creates an encoder with room for one full block of pending symbols.
    pub fn new() -> Self {
        let symbols = Vec::with_capacity(BLOCK_SIZE);
        Self {
            symbols,
            blocks: Vec::new(),
        }
    }
    /// Queues one symbol, encoding the block as soon as it is full.
    pub fn encode(&mut self, cdf: Cdf) {
        self.symbols.push(cdf);
        if self.symbols.len() < BLOCK_SIZE {
            return;
        }
        self.flush_block();
    }
    /// Encodes the pending symbols (if any) into a new block and clears the buffer.
    fn flush_block(&mut self) {
        if !self.symbols.is_empty() {
            let mut encoder = RansEncoder::new();
            self.symbols.iter().rev().for_each(|cdf| encoder.encode(cdf));
            self.blocks.push(encoder.finish());
            self.symbols.clear();
        }
    }
    /// Flushes the final partial block and returns all serialized blocks.
    pub fn finish(mut self) -> Vec<Vec<u8>> {
        self.flush_block();
        self.blocks
    }
}
impl Default for BlockedRansEncoder {
fn default() -> Self {
Self::new()
}
}
pub struct BlockedRansDecoder<'a> {
blocks: Vec<&'a [u8]>,
current_block: usize,
symbols_remaining_in_block: usize,
total_symbols: usize,
decoder: Option<RansDecoder<'a>>,
}
impl<'a> BlockedRansDecoder<'a> {
pub fn new(blocks: Vec<&'a [u8]>, total_symbols: usize) -> anyhow::Result<Self> {
let expected_blocks = if total_symbols == 0 {
0
} else {
total_symbols.div_ceil(BLOCK_SIZE)
};
if blocks.len() != expected_blocks {
anyhow::bail!(
"blocked rANS expected {expected_blocks} blocks for {total_symbols} symbols, got {}",
blocks.len()
);
}
Ok(Self {
blocks,
current_block: 0,
symbols_remaining_in_block: 0,
total_symbols,
decoder: None,
})
}
#[inline]
fn open_block(&mut self, block_index: usize) -> anyhow::Result<()> {
if block_index >= self.blocks.len() {
anyhow::bail!("No more blocks to decode");
}
let consumed = block_index.saturating_mul(BLOCK_SIZE);
let remaining = self.total_symbols.saturating_sub(consumed);
self.current_block = block_index;
self.symbols_remaining_in_block = remaining.min(BLOCK_SIZE);
self.decoder = Some(RansDecoder::new(self.blocks[block_index])?);
Ok(())
}
pub fn decode(&mut self, cdf: &[u32]) -> anyhow::Result<usize> {
if self.symbols_remaining_in_block == 0 {
if self.decoder.is_some() {
self.open_block(self.current_block + 1)?;
} else {
self.open_block(0)?;
}
}
let sym = self
.decoder
.as_mut()
.expect("decoder initialized for current block")
.decode(cdf)?;
self.symbols_remaining_in_block = self.symbols_remaining_in_block.saturating_sub(1);
Ok(sym)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `symbols` with the scalar coder, feeding them backwards as rANS requires.
    fn encode_scalar(cdf_table: &[u32], symbols: &[usize]) -> Vec<u8> {
        let mut enc = RansEncoder::new();
        for &s in symbols.iter().rev() {
            enc.encode(&cdf_for_symbol(cdf_table, s));
        }
        enc.finish()
    }

    #[test]
    fn test_roundtrip_scalar() {
        let pdf = [0.5, 0.3, 0.15, 0.05];
        let symbols = [0, 0, 1, 0, 2, 1, 0, 3, 0, 0, 1, 2];
        let cdf_table = quantize_pdf_to_rans_cdf(&pdf);
        let encoded = encode_scalar(&cdf_table, &symbols);
        let mut dec = RansDecoder::new(&encoded).unwrap();
        for &expected in &symbols {
            assert_eq!(dec.decode(&cdf_table).unwrap(), expected, "Symbol mismatch");
        }
    }

    #[test]
    fn test_cdf_quantization() {
        let cdf = quantize_pdf_to_rans_cdf(&[0.25; 4]);
        assert_eq!(cdf[0], 0);
        assert_eq!(cdf[4], ANS_TOTAL);
        for pair in cdf[..4].windows(2) {
            assert!(pair[1] - pair[0] > 0);
        }
    }

    #[test]
    fn test_extreme_probabilities() {
        let pdf = [0.99, 0.005, 0.003, 0.002];
        let symbols = [0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 3];
        let cdf_table = quantize_pdf_to_rans_cdf(&pdf);
        let encoded = encode_scalar(&cdf_table, &symbols);
        let mut dec = RansDecoder::new(&encoded).unwrap();
        for &expected in &symbols {
            assert_eq!(dec.decode(&cdf_table).unwrap(), expected);
        }
    }

    #[test]
    fn test_blocked_rans_roundtrip_across_block_boundary() {
        let pdf = [0.5, 0.25, 0.125, 0.125];
        let cdf = quantize_pdf_to_rans_cdf(&pdf);
        let symbols: Vec<usize> = (0..BLOCK_SIZE + 17).map(|i| i % pdf.len()).collect();
        let mut enc = BlockedRansEncoder::new();
        symbols
            .iter()
            .for_each(|&sym| enc.encode(cdf_for_symbol(&cdf, sym)));
        let blocks = enc.finish();
        let block_refs: Vec<&[u8]> = blocks.iter().map(Vec::as_slice).collect();
        let mut dec = BlockedRansDecoder::new(block_refs, symbols.len()).unwrap();
        for &expected in &symbols {
            let got = dec.decode(&cdf).unwrap();
            assert_eq!(got, expected, "blocked rANS mismatch at symbol {expected}");
        }
    }
}