/// Number of fractional bits in the scaled cumulative values that
/// `ArithmeticSymbolModel` stores (its refresh shifts by `31 - SYMBOL_LENGTH_SHIFT`).
pub const SYMBOL_LENGTH_SHIFT: u32 = 15;
/// Total-count threshold above which `ArithmeticSymbolModel` halves its
/// per-symbol counts during a refresh.
pub const SYMBOL_MAX_COUNT: u32 = 1 << SYMBOL_LENGTH_SHIFT;
/// Number of fractional bits in the scaled zero-bit probability that
/// `ArithmeticBitModel` stores (its refresh shifts by `31 - BIT_LENGTH_SHIFT`).
pub const BIT_LENGTH_SHIFT: u32 = 13;
/// Total-count threshold above which `ArithmeticBitModel` halves its counts
/// during a refresh.
pub const BIT_MAX_COUNT: u32 = 1 << BIT_LENGTH_SHIFT;
/// Adaptive probability estimate for a single binary decision.
///
/// The model counts observed zero bits against all observed bits and, at
/// growing intervals, converts that ratio into a fixed-point probability with
/// [`BIT_LENGTH_SHIFT`] fractional bits.
#[derive(Debug, Clone)]
pub struct ArithmeticBitModel {
    /// Zero bits seen so far (starts at 1).
    zero_count: u32,
    /// Bits accounted for as of the most recent refresh (starts at 2).
    total_count: u32,
    /// Scaled probability of a zero bit, as computed by the most recent refresh.
    zero_probability: u32,
    /// Observations left before the next refresh.
    updates_until_refresh: u32,
    /// Length, in observations, of the current refresh interval.
    update_cycle: u32,
}

impl Default for ArithmeticBitModel {
    /// Builds a model that considers both bit values equally likely and
    /// refreshes for the first time after four observations.
    fn default() -> Self {
        const FIRST_CYCLE: u32 = 4;
        Self {
            zero_count: 1,
            total_count: 2,
            // One half in fixed point with BIT_LENGTH_SHIFT fractional bits.
            zero_probability: 1 << (BIT_LENGTH_SHIFT - 1),
            updates_until_refresh: FIRST_CYCLE,
            update_cycle: FIRST_CYCLE,
        }
    }
}

impl ArithmeticBitModel {
    /// Creates a model in its default (even-odds) state.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the scaled zero-bit probability from the most recent refresh.
    pub fn zero_probability(&self) -> u32 {
        self.zero_probability
    }

    /// Records one observed bit and refreshes the probability estimate once
    /// the current interval has elapsed.
    pub fn observe_bit(&mut self, bit: bool) {
        // Only `false` bits contribute to the zero tally.
        self.zero_count += u32::from(!bit);
        self.updates_until_refresh -= 1;
        if self.updates_until_refresh == 0 {
            self.refresh();
        }
    }

    /// Recomputes the scaled zero probability and schedules the next refresh.
    fn refresh(&mut self) {
        const LONGEST_CYCLE: u32 = 64;

        // Every observation of the interval that just ended joins the total.
        self.total_count += self.update_cycle;
        if self.total_count > BIT_MAX_COUNT {
            // Halve both tallies (rounding up), then make sure the total
            // still strictly exceeds the zero tally.
            self.total_count = (self.total_count + 1) / 2;
            self.zero_count = (self.zero_count + 1) / 2;
            self.total_count = self.total_count.max(self.zero_count + 1);
        }

        let scale = 0x8000_0000u32 / self.total_count;
        self.zero_probability = (self.zero_count * scale) >> (31 - BIT_LENGTH_SHIFT);

        // Grow the interval by a quarter, capped at LONGEST_CYCLE observations.
        self.update_cycle = (self.update_cycle * 5 / 4).min(LONGEST_CYCLE);
        self.updates_until_refresh = self.update_cycle;
    }
}
/// Adaptive frequency model over a fixed alphabet of `symbols` symbols.
///
/// The model keeps one occurrence count per symbol and, at growing intervals,
/// rebuilds a table of scaled cumulative start values with
/// [`SYMBOL_LENGTH_SHIFT`] fractional bits.
#[derive(Debug, Clone)]
pub struct ArithmeticSymbolModel {
    /// Alphabet size.
    symbols: u32,
    /// Index of the final symbol (`symbols - 1`).
    last_symbol: u32,
    /// Occurrence count per symbol; every entry starts at 1.
    counts: Vec<u32>,
    /// Scaled cumulative start value per symbol, as of the most recent refresh.
    cdf: Vec<u32>,
    /// Sum of the counts as of the most recent refresh.
    total_count: u32,
    /// Length, in observations, of the current refresh interval.
    update_cycle: u32,
    /// Observations left before the next refresh.
    updates_until_refresh: u32,
}

impl ArithmeticSymbolModel {
    /// Creates a model in which all `symbols` symbols are equally likely.
    ///
    /// # Panics
    ///
    /// Panics with "invalid symbol count" unless `symbols` lies in `2..=2048`.
    pub fn new(symbols: u32) -> Self {
        const LARGEST_ALPHABET: u32 = 1 << 11;
        assert!(
            (2..=LARGEST_ALPHABET).contains(&symbols),
            "invalid symbol count"
        );

        let slots = symbols as usize;
        let steady_cycle = (symbols + 6) >> 1;
        let mut model = Self {
            symbols,
            last_symbol: symbols - 1,
            counts: vec![1; slots],
            cdf: vec![0; slots],
            total_count: 0,
            // The first refresh adds this to `total_count`, which then equals
            // the sum of the all-ones counts.
            update_cycle: symbols,
            updates_until_refresh: steady_cycle,
        };
        model.refresh();

        // Replace the interval the initial refresh derived with the
        // post-initialisation one.
        model.update_cycle = steady_cycle;
        model.updates_until_refresh = steady_cycle;
        model
    }

    /// Returns the alphabet size.
    pub fn symbols(&self) -> u32 {
        self.symbols
    }

    /// Returns the index of the final symbol.
    pub fn last_symbol(&self) -> u32 {
        self.last_symbol
    }

    /// Returns the scaled cumulative start value of `symbol`.
    ///
    /// # Panics
    ///
    /// Panics if `symbol` is not below [`Self::symbols`].
    pub fn cdf_at(&self, symbol: u32) -> u32 {
        self.cdf[symbol as usize]
    }

    /// Returns the highest symbol whose cumulative start value does not
    /// exceed `v`.
    pub fn symbol_for_scaled_value(&self, v: u32) -> u32 {
        // The table never decreases, so the entries `<= v` form a prefix.
        let entries_at_or_below = self.cdf.partition_point(|&start| start <= v);
        // The first entry is always 0, so the prefix is never empty; the
        // saturating form just keeps the expression total. The index is below
        // `symbols`, so it fits in `u32`.
        entries_at_or_below.saturating_sub(1) as u32
    }

    /// Records one occurrence of `symbol` and rebuilds the cumulative table
    /// once the current interval has elapsed.
    ///
    /// # Panics
    ///
    /// Panics if `symbol` is not below [`Self::symbols`].
    pub fn observe_symbol(&mut self, symbol: u32) {
        let tally = &mut self.counts[symbol as usize];
        *tally += 1;

        self.updates_until_refresh -= 1;
        if self.updates_until_refresh == 0 {
            self.refresh();
        }
    }

    /// Rebuilds the cumulative table from the counts and schedules the next
    /// refresh.
    fn refresh(&mut self) {
        // Every observation of the interval that just ended joins the total.
        self.total_count += self.update_cycle;
        if self.total_count > SYMBOL_MAX_COUNT {
            // Halve every count (rounding up) and re-derive the total from them.
            for tally in self.counts.iter_mut() {
                *tally = (*tally + 1) / 2;
            }
            self.total_count = self.counts.iter().sum();
        }

        let scale = 0x8000_0000u32 / self.total_count;
        let shift = 31 - SYMBOL_LENGTH_SHIFT;
        let mut mass_below = 0u32;
        for (start, &tally) in self.cdf.iter_mut().zip(&self.counts) {
            *start = (scale * mass_below) >> shift;
            mass_below += tally;
        }

        // Grow the interval by a quarter, capped relative to the alphabet size.
        let longest_cycle = (self.symbols + 6) << 3;
        self.update_cycle = ((5 * self.update_cycle) >> 2).min(longest_cycle);
        self.updates_until_refresh = self.update_cycle;
    }
}
#[cfg(test)]
mod tests {
    use super::{ArithmeticBitModel, ArithmeticSymbolModel};

    /// A run of 128 zero bits must raise the zero probability above the
    /// value a fresh model reports.
    #[test]
    fn bit_model_adapts_towards_observed_zeros() {
        let mut model = ArithmeticBitModel::new();
        let p0_before = model.zero_probability();
        (0..128).for_each(|_| model.observe_bit(false));
        assert!(model.zero_probability() > p0_before);
    }

    /// After 500 observations cycling through symbols 0..7, the cumulative
    /// table of a 32-symbol model must never decrease from one symbol to the
    /// next.
    #[test]
    fn symbol_model_returns_monotonic_cdf() {
        let mut model = ArithmeticSymbolModel::new(32);
        for symbol in (0..7u32).cycle().take(500) {
            model.observe_symbol(symbol);
        }

        let mut prev = 0u32;
        for cur in (0..model.symbols()).map(|s| model.cdf_at(s)) {
            assert!(cur >= prev);
            prev = cur;
        }
    }
}