/// Number of low-order bits dropped from a 15-bit probability before it is
/// multiplied by the coder's range (see `encode_q15`).
const EC_PROB_SHIFT: u32 = 6;
/// Range units added to a boundary once per symbol that follows it, which
/// keeps neighbouring boundaries at least this far apart.
const EC_MIN_PROB: u32 = 4;
/// Top of the probability scale (`1 << 15`); `encode_symbol` expects the last
/// entry of every CDF to equal this value.
const CDF_PROB_TOP: u32 = 1 << 15;
/// State of a range encoder that writes symbols against 15-bit CDFs.
///
/// Output is accumulated as 16-bit "pre-carry" words; `finish` folds the
/// carries into the preceding bytes and returns the final byte stream.
#[derive(Debug, Clone)]
pub struct SymbolEncoder {
    /// Low end of the current coding interval, including bits that have not
    /// been moved to `precarry` yet.
    low: u64,
    /// Width of the current coding interval; renormalised after every symbol
    /// so that its top bit (bit 15) is set.
    rng: u32,
    /// Bit counter that decides when `normalize` emits words; starts at `-9`.
    cnt: i32,
    /// Emitted words, each holding one output byte plus any carry bits that
    /// still have to be propagated towards the front.
    precarry: Vec<u16>,
}
impl Default for SymbolEncoder {
fn default() -> Self {
Self::new()
}
}
impl SymbolEncoder {
    /// Creates an encoder with no output yet, a full-width range of
    /// `CDF_PROB_TOP` and the bit counter at its initial value of `-9`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            precarry: Vec::new(),
            cnt: -9,
            rng: CDF_PROB_TOP,
            low: 0,
        }
    }

    /// Encodes `symbol` using the cumulative distribution `cdf`.
    ///
    /// `cdf[i]` is the cumulative probability of symbols `0..=i` on a scale
    /// whose top is `CDF_PROB_TOP`; the last entry must equal
    /// `CDF_PROB_TOP` (checked in debug builds only).
    ///
    /// # Panics
    ///
    /// Panics if `symbol` is not a valid index into `cdf`.
    pub fn encode_symbol(&mut self, symbol: usize, cdf: &[u16]) {
        let nsyms = cdf.len();
        debug_assert!(symbol < nsyms);
        debug_assert_eq!(u32::from(cdf[nsyms - 1]), CDF_PROB_TOP);

        // The coder works on inverse cumulative values (top minus CDF entry).
        let inverse = |index: usize| CDF_PROB_TOP - u32::from(cdf[index]);
        let fl = match symbol.checked_sub(1) {
            Some(previous) => inverse(previous),
            // The first symbol has nothing below it: its upper bound is the top.
            None => CDF_PROB_TOP,
        };
        let fh = inverse(symbol);

        self.encode_q15(fl, fh, symbol as u32, nsyms as u32);
    }

    /// Encodes `symbol` with `cdf`, then adapts `cdf` and `count` towards the
    /// symbol that was just coded (see `update_cdf`).
    pub fn encode_symbol_adapt(&mut self, symbol: usize, cdf: &mut [u16], count: &mut u16) {
        // Adaptation must happen after coding so a decoder can mirror it.
        self.encode_symbol(symbol, cdf);
        update_cdf(cdf, symbol, count);
    }

    /// Encodes the low `n` bits of `value`, most significant bit first, each
    /// as a binary symbol with a fixed half/half distribution.
    pub fn encode_literal(&mut self, value: u32, n: u32) {
        const BOOL_CDF: [u16; 2] = [0x4000, 0x8000];

        let mut bit_index = n;
        while bit_index > 0 {
            bit_index -= 1;
            let bit = (value >> bit_index) & 1;
            self.encode_symbol(bit as usize, &BOOL_CDF);
        }
    }

    /// Narrows the coding interval to the sub-range of symbol `s`.
    ///
    /// `fl` and `fh` are the inverse cumulative values just below and at the
    /// symbol, `nsyms` is the alphabet size. `fl == CDF_PROB_TOP` marks the
    /// first symbol, whose sub-range starts at the current `low`.
    fn encode_q15(&mut self, fl: u32, fh: u32, s: u32, nsyms: u32) {
        let range = self.rng;
        debug_assert!(range >= CDF_PROB_TOP);
        let last = nsyms - 1;

        // Range scaled by a probability, at reduced precision on both sides.
        let scaled = |f: u32| ((range >> 8) * (f >> EC_PROB_SHIFT)) >> (7 - EC_PROB_SHIFT);

        let (low, range) = if fl >= CDF_PROB_TOP {
            let v = scaled(fh) + EC_MIN_PROB * (last - s);
            debug_assert!(v < range);
            (self.low, range - v)
        } else {
            let u = scaled(fl) + EC_MIN_PROB * (last - (s - 1));
            let v = scaled(fh) + EC_MIN_PROB * (last - s);
            debug_assert!(u <= range && v < u);
            (self.low + u64::from(range - u), u - v)
        };

        self.normalize(low, range);
    }

    /// Rescales the interval so `rng` has bit 15 set again, moving completed
    /// bits of `low` into `precarry` (one or two words per call) as needed,
    /// and stores the new state.
    fn normalize(&mut self, low: u64, rng: u32) {
        // Left shift that brings the top set bit of `rng` up to bit 15.
        let shift = rng.leading_zeros() - 16;
        let pending = self.cnt + shift as i32;

        let mut window = low;
        let mut next_cnt = pending;
        if pending >= 0 {
            let mut pos = self.cnt + 16;
            if pending >= 8 {
                // Two words are due: emit the upper one first.
                self.precarry.push((window >> pos) as u16);
                window &= (1u64 << pos) - 1;
                pos -= 8;
            }
            self.precarry.push((window >> pos) as u16);
            window &= (1u64 << pos) - 1;
            next_cnt = pos + shift as i32 - 24;
        }

        self.low = window << shift;
        self.rng = rng << shift;
        self.cnt = next_cnt;
    }

    /// Terminates the stream and returns the encoded bytes.
    ///
    /// The final `low` is rounded up to a multiple of `0x4000` with bit 14
    /// set, the remaining words are flushed, and carries stored in the upper
    /// bits of each pre-carry word are propagated towards the front.
    #[must_use]
    pub fn finish(mut self) -> Vec<u8> {
        const ROUND_MASK: u64 = 0x3FFF;

        let mut tail = ((self.low + ROUND_MASK) & !ROUND_MASK) | (ROUND_MASK + 1);
        let mut remaining = 10 + self.cnt;
        if remaining > 0 {
            let mut pos = self.cnt + 16;
            let mut keep = (1u64 << pos) - 1;
            while remaining > 0 {
                self.precarry.push((tail >> pos) as u16);
                tail &= keep;
                remaining -= 8;
                pos -= 8;
                keep >>= 8;
            }
        }

        // Walk back to front so each carry lands in the preceding byte.
        let mut carry = 0u32;
        let mut bytes: Vec<u8> = self
            .precarry
            .iter()
            .rev()
            .map(|&word| {
                let sum = u32::from(word) + carry;
                carry = sum >> 8;
                sum as u8
            })
            .collect();
        bytes.reverse();
        bytes
    }
}
/// Moves `cdf` towards a distribution that favours `symbol`.
///
/// Entries below `symbol` decay towards 0 and entries at or above it (except
/// the final one, which is left untouched) grow towards `1 << 15`. The step is
/// a right shift by a rate that increases once `count` exceeds 15 and again
/// once it exceeds 31, and with the alphabet size (floor(log2(len)), capped at
/// 2). `count` is incremented until it saturates at 32.
///
/// # Panics
///
/// Panics if `cdf` is empty or if `symbol` is greater than `cdf.len() - 1`.
fn update_cdf(cdf: &mut [u16], symbol: usize, count: &mut u16) {
    let seen = *count;
    let size_term = (31 - (cdf.len() as u32).leading_zeros()).min(2);
    let rate = 3 + u32::from(seen > 15) + u32::from(seen > 31) + size_term;

    let (_top, body) = cdf.split_last_mut().expect("a CDF has at least one entry");

    body[..symbol]
        .iter_mut()
        .for_each(|entry| *entry -= *entry >> rate);
    body[symbol..]
        .iter_mut()
        .for_each(|entry| *entry += (0x8000u16 - *entry) >> rate);

    *count += u16::from(seen < 32);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference decoder used by the tests to read back `SymbolEncoder` output.
    struct SymbolDecoder<'a> {
        /// Encoded bytes being read.
        data: &'a [u8],
        /// Index of the next bit to read (most significant bit of each byte first).
        bit_pos: usize,
        /// Current decoder window.
        value: u32,
        /// Current interval width.
        range: u32,
        /// Number of real bits still available; may go negative once the
        /// decoder runs past the end of `data`.
        max_bits: i64,
    }

    impl<'a> SymbolDecoder<'a> {
        /// Reads `n` bits, most significant first; bits past the end of the
        /// data read as zero.
        fn read_f(&mut self, n: u32) -> u32 {
            (0..n).fold(0u32, |acc, _| {
                let byte = self.data.get(self.bit_pos >> 3).copied().unwrap_or(0);
                let bit = (byte >> (7 - (self.bit_pos & 7))) & 1;
                self.bit_pos += 1;
                (acc << 1) | u32::from(bit)
            })
        }

        /// Primes the decoder with up to 15 bits from the start of `data`.
        fn new(data: &'a [u8]) -> Self {
            let sz = data.len();
            let mut decoder = Self {
                data,
                bit_pos: 0,
                value: 0,
                range: 1 << 15,
                max_bits: 8 * sz as i64 - 15,
            };
            let num_bits = (sz * 8).min(15) as u32;
            let head = decoder.read_f(num_bits);
            let padded = head << (15 - num_bits);
            decoder.value = ((1 << 15) - 1) ^ padded;
            decoder
        }

        /// Decodes one symbol against `cdf` and renormalises the decoder.
        fn read_symbol(&mut self, cdf: &[u16]) -> usize {
            let nsyms = cdf.len() as u32;
            let coarse = self.range >> 8;

            // `upper` is the boundary of the previous candidate (the whole
            // range for symbol 0); `lower` is the boundary of the match.
            let mut upper = self.range;
            let mut symbol = 0usize;
            let lower = loop {
                let f = (1u32 << 15) - u32::from(cdf[symbol]);
                let mut bound = (coarse * (f >> EC_PROB_SHIFT)) >> (7 - EC_PROB_SHIFT);
                bound += EC_MIN_PROB * (nsyms - symbol as u32 - 1);
                if self.value >= bound {
                    break bound;
                }
                upper = bound;
                symbol += 1;
            };
            self.range = upper - lower;
            self.value -= lower;

            // Shift the range back up to 16 bits and pull in fresh data.
            let bits = 15 - (31 - self.range.leading_zeros());
            self.range <<= bits;
            let num_bits = i64::from(bits).min(self.max_bits.max(0)) as u32;
            let fresh = self.read_f(num_bits);
            let padded = fresh << (bits - num_bits);
            self.value = padded ^ (((self.value + 1) << bits) - 1);
            self.max_bits -= i64::from(bits);
            symbol
        }

        /// Decodes one symbol and applies the same CDF adaptation the
        /// encoder performs.
        fn read_symbol_adapt(&mut self, cdf: &mut [u16], count: &mut u16) -> usize {
            let decoded = self.read_symbol(cdf);
            update_cdf(cdf, decoded, count);
            decoded
        }

        /// Reads `n` equiprobable bits, most significant first.
        fn read_literal(&mut self, n: u32) -> u32 {
            const BOOL_CDF: [u16; 2] = [0x4000, 0x8000];
            (0..n).fold(0u32, |acc, _| {
                (acc << 1) | self.read_symbol(&BOOL_CDF) as u32
            })
        }
    }

    /// Small deterministic linear congruential generator for test input.
    struct Lcg(u64);

    impl Lcg {
        /// Advances the state and returns its upper 32 bits.
        fn next_u32(&mut self) -> u32 {
            let advanced = self
                .0
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            self.0 = advanced;
            (advanced >> 32) as u32
        }

        /// Returns a value in `0..bound` (plain modulo reduction).
        fn below(&mut self, bound: u32) -> u32 {
            let raw = self.next_u32();
            raw % bound
        }
    }

    /// Builds a strictly increasing CDF with `nsyms` entries ending in 32768.
    fn random_cdf(rng: &mut Lcg, nsyms: usize) -> Vec<u16> {
        let mut cuts: Vec<u16> = Vec::new();
        let wanted = nsyms - 1;
        while cuts.len() < wanted {
            let candidate = 1 + rng.below(32767) as u16;
            // Duplicate cut points are rejected so entries stay distinct.
            if cuts.iter().all(|&existing| existing != candidate) {
                cuts.push(candidate);
            }
        }
        cuts.sort_unstable();
        cuts.push(32768);
        cuts
    }

    /// Finishing an untouched encoder yields bytes the decoder can open.
    #[test]
    fn empty_stream_roundtrips() {
        let bytes = SymbolEncoder::new().finish();
        let _ = SymbolDecoder::new(&bytes);
    }

    /// Every symbol of uniform alphabets of size 2..=12 survives on its own.
    #[test]
    fn single_symbol_streams_roundtrip() {
        for nsyms in 2..=12usize {
            let cdf: Vec<u16> = (1..=nsyms)
                .map(|i| (i * 32768 / nsyms) as u16)
                .collect();
            for s in 0..nsyms {
                let mut enc = SymbolEncoder::new();
                enc.encode_symbol(s, &cdf);
                let bytes = enc.finish();
                let mut dec = SymbolDecoder::new(&bytes);
                assert_eq!(dec.read_symbol(&cdf), s, "nsyms={nsyms} s={s}");
            }
        }
    }

    /// 20 000 symbols drawn from randomly chosen static CDFs round-trip.
    #[test]
    fn long_random_symbol_stream_roundtrips() {
        let mut rng = Lcg(0x1234_5678_9abc_def0);
        let cdfs: Vec<Vec<u16>> = (2..=14).map(|n| random_cdf(&mut rng, n)).collect();

        let mut enc = SymbolEncoder::new();
        let mut events: Vec<(usize, usize)> = Vec::new();
        for _ in 0..20_000 {
            let which = rng.below(cdfs.len() as u32) as usize;
            let cdf = &cdfs[which];
            let s = rng.below(cdf.len() as u32) as usize;
            enc.encode_symbol(s, cdf);
            events.push((which, s));
        }

        let bytes = enc.finish();
        let mut dec = SymbolDecoder::new(&bytes);
        for (i, &(which, s)) in events.iter().enumerate() {
            assert_eq!(dec.read_symbol(&cdfs[which]), s, "event {i}");
        }
    }

    /// Literals of 1 to 16 bits round-trip.
    #[test]
    fn literals_roundtrip() {
        let mut rng = Lcg(0xdead_beef_0bad_f00d);
        let mut enc = SymbolEncoder::new();
        let mut written: Vec<(u32, u32)> = Vec::new();
        for _ in 0..5000 {
            let width = 1 + rng.below(16);
            let value = rng.next_u32() & ((1u32 << width) - 1);
            enc.encode_literal(value, width);
            written.push((value, width));
        }

        let bytes = enc.finish();
        let mut dec = SymbolDecoder::new(&bytes);
        for (value, width) in written {
            assert_eq!(dec.read_literal(width), value);
        }
    }

    /// Symbols and 8-bit literals interleaved in one stream round-trip.
    #[test]
    fn mixed_symbols_and_literals_roundtrip() {
        enum Event {
            Symbol(u32),
            Literal(u32),
        }

        let mut rng = Lcg(0x0f0f_0f0f_1234_9999);
        let cdf = random_cdf(&mut rng, 8);
        let mut enc = SymbolEncoder::new();
        let mut events: Vec<Event> = Vec::new();
        for _ in 0..8000 {
            if rng.next_u32() & 1 == 0 {
                let s = rng.below(cdf.len() as u32);
                enc.encode_symbol(s as usize, &cdf);
                events.push(Event::Symbol(s));
            } else {
                let v = rng.next_u32() & 0xff;
                enc.encode_literal(v, 8);
                events.push(Event::Literal(v));
            }
        }

        let bytes = enc.finish();
        let mut dec = SymbolDecoder::new(&bytes);
        for event in events {
            match event {
                Event::Literal(expected) => assert_eq!(dec.read_literal(8), expected),
                Event::Symbol(expected) => assert_eq!(dec.read_symbol(&cdf) as u32, expected),
            }
        }
    }

    /// Pins `update_cdf` against hand-computed results for several rates.
    #[test]
    fn update_cdf_matches_spec_formula() {
        fn upd(cdf: &[u16], symbol: usize, count: u16) -> (Vec<u16>, u16) {
            let mut adapted = cdf.to_vec();
            let mut counter = count;
            update_cdf(&mut adapted, symbol, &mut counter);
            (adapted, counter)
        }

        let half = [16384u16, 32768];
        assert_eq!(upd(&half, 0, 0), (vec![17408, 32768], 1));
        assert_eq!(upd(&half, 1, 0), (vec![15360, 32768], 1));
        assert_eq!(upd(&half, 0, 15), (vec![17408, 32768], 16));
        assert_eq!(upd(&half, 0, 16), (vec![16896, 32768], 17));
        assert_eq!(upd(&half, 0, 31), (vec![16896, 32768], 32));
        assert_eq!(upd(&half, 0, 32), (vec![16640, 32768], 32));
        assert_eq!(
            upd(&[10000, 20000, 32768], 1, 20),
            (vec![9688, 20399, 32768], 21)
        );
        assert_eq!(
            upd(
                &[4096, 8192, 12288, 16384, 20480, 24576, 28672, 32768],
                3,
                0
            ),
            (
                vec![3968, 7936, 11904, 16896, 20864, 24832, 28800, 32768],
                1
            )
        );
    }

    /// One adaptive CDF stays in lock-step between encoder and decoder.
    #[test]
    fn adaptive_single_cdf_roundtrips() {
        let mut rng = Lcg(0xa1b2_c3d4_e5f6_0719);
        let init = random_cdf(&mut rng, 6);

        let mut enc = SymbolEncoder::new();
        let (mut ecdf, mut ecount) = (init.clone(), 0u16);
        let mut syms = Vec::new();
        for _ in 0..10_000 {
            // Multiplying by 0 or 1 makes symbol 0 far more common than the rest.
            let s = (rng.below(6) * rng.below(2)) as usize;
            enc.encode_symbol_adapt(s, &mut ecdf, &mut ecount);
            syms.push(s);
        }

        let bytes = enc.finish();
        let mut dec = SymbolDecoder::new(&bytes);
        let (mut dcdf, mut dcount) = (init.clone(), 0u16);
        for (i, &s) in syms.iter().enumerate() {
            assert_eq!(
                dec.read_symbol_adapt(&mut dcdf, &mut dcount),
                s,
                "event {i}"
            );
        }
        assert_eq!(ecdf, dcdf, "encoder/decoder CDFs diverged");
        assert_eq!(ecount, dcount);
        assert_ne!(
            ecdf, init,
            "CDF should have adapted away from its initial state"
        );
    }

    /// Several independently adapting contexts stay in lock-step.
    #[test]
    fn adaptive_multi_context_roundtrips() {
        let mut rng = Lcg(0x0011_2233_4455_6677);
        let inits: Vec<Vec<u16>> = (2..=10).map(|n| random_cdf(&mut rng, n)).collect();
        let contexts = inits.len();

        let mut enc = SymbolEncoder::new();
        let mut ecdfs = inits.clone();
        let mut ecounts = vec![0u16; contexts];
        let mut events = Vec::new();
        for _ in 0..15_000 {
            let ctx = rng.below(contexts as u32) as usize;
            let s = rng.below(ecdfs[ctx].len() as u32) as usize;
            enc.encode_symbol_adapt(s, &mut ecdfs[ctx], &mut ecounts[ctx]);
            events.push((ctx, s));
        }

        let bytes = enc.finish();
        let mut dec = SymbolDecoder::new(&bytes);
        let mut dcdfs = inits.clone();
        let mut dcounts = vec![0u16; contexts];
        for (i, &(ctx, s)) in events.iter().enumerate() {
            assert_eq!(
                dec.read_symbol_adapt(&mut dcdfs[ctx], &mut dcounts[ctx]),
                s,
                "event {i} ctx {ctx}"
            );
        }
        assert_eq!(ecdfs, dcdfs);
        assert_eq!(ecounts, dcounts);
    }
}