extern crate alloc;
use alloc::vec;
use alloc::vec::Vec;
/// log2 of the number of buckets in the decoder's symbol search table.
const TABLE_SHIFT: i32 = 7;

/// Adaptive quasi-static frequency model for a range coder.
///
/// Symbol frequencies accumulate in `symf` as symbols are coded, but the
/// cumulative table `cumf` handed to the coder is only rebuilt when the
/// current run of `rescale` symbols is exhausted; `rescale` grows towards the
/// configured period. The cumulative frequencies always end at `1 << bits`.
///
/// An encoder (`compress == true`) and a decoder (`compress == false`) created
/// with the same parameters and fed the same symbol sequence stay in lock-step.
#[derive(Debug, Clone)]
pub struct RCQsModel {
    /// Number of symbols in the alphabet.
    symbols: usize,
    /// log2 of the total frequency `cumf[symbols]`.
    bits: i32,
    /// Symbols that may still be coded before the next call to `update`.
    left: i32,
    /// Length of the pending second run, whose symbols get `incr + 1` each.
    more: i32,
    /// Amount added to a symbol's frequency each time it is coded.
    incr: u32,
    /// Current number of symbols between rescales.
    rescale: i32,
    /// Value that `rescale` grows towards (the `period` passed to `new`).
    target_rescale: i32,
    /// Per-symbol frequencies accumulated since the last rescale.
    symf: Vec<u32>,
    /// Cumulative frequencies; symbol `s` owns `cumf[s]..cumf[s + 1]`.
    cumf: Vec<u32>,
    /// Right shift mapping a cumulative frequency to a `search` bucket.
    search_shift: i32,
    /// Decoder only: `search[b]` and `search[b + 1]` bound the symbols whose
    /// ranges can contain a value falling in bucket `b`.
    search: Option<Vec<u32>>,
}

impl RCQsModel {
    /// Creates a model over `symbols` symbols whose frequencies sum to `1 << bits`.
    ///
    /// * `compress` - `true` for an encoder; `false` additionally builds the
    ///   search table required by [`decode`](Self::decode).
    /// * `period` - upper bound on the number of symbols coded between
    ///   rebuilds of the cumulative table.
    ///
    /// # Panics
    ///
    /// Panics if `bits` is not in `0..=16`, if `period` is not in
    /// `1..(1 << (bits + 1))`, or if `symbols` is not in `1..=(1 << bits)`
    /// (every symbol needs a nonzero initial frequency).
    pub fn new(compress: bool, symbols: usize, bits: i32, period: i32) -> Self {
        assert!(bits <= 16, "bits must be <= 16");
        assert!(bits >= 0, "bits must be >= 0");
        assert!(period < (1 << (bits + 1)), "period too large");
        // period == 0 would divide by zero in `update`; a negative period would
        // keep `left` from ever reaching zero, silently freezing the model.
        assert!(period > 0, "period must be > 0");
        assert!(symbols > 0, "symbols must be > 0");
        // More symbols than total frequency leaves some with zero-width ranges
        // and underflows the redistributed mass during rescaling.
        assert!(symbols <= 1usize << bits, "symbols must be <= 1 << bits");

        let n = symbols;
        let symf = vec![0u32; n + 1];
        let mut cumf = vec![0u32; n + 1];
        cumf[n] = 1u32 << bits;

        let (search, search_shift) = if compress {
            (None, 0)
        } else {
            // With fewer than TABLE_SHIFT bits every bucket covers a single
            // cumulative value; clamping avoids a negative shift amount.
            let shift = (bits - TABLE_SHIFT).max(0);
            (Some(vec![0u32; (1 << TABLE_SHIFT) + 1]), shift)
        };

        let mut model = Self {
            symbols,
            bits,
            left: 0,
            more: 0,
            incr: 0,
            rescale: 0,
            target_rescale: period,
            symf,
            cumf,
            search_shift,
            search,
        };
        model.reset();
        model
    }

    /// Creates a model with 16-bit frequencies and a rescale period of `0x400`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`new`](Self::new).
    pub fn with_defaults(compress: bool, symbols: usize) -> Self {
        Self::new(compress, symbols, 16, 0x400)
    }

    /// Returns the number of symbols in the alphabet.
    pub fn symbols(&self) -> usize {
        self.symbols
    }

    /// Restores near-uniform frequencies and restarts the rescale schedule.
    ///
    /// The total `1 << bits` is split evenly, with the first
    /// `total % symbols` symbols receiving one extra unit. The rescale
    /// interval restarts at `(symbols >> 4) | 2`, after which `update`
    /// rebuilds the cumulative table.
    pub fn reset(&mut self) {
        let n = self.symbols;
        self.rescale = ((n as i32) >> 4) | 2;
        self.more = 0;
        let total = self.cumf[n];
        let share = total / n as u32;
        let extra = (total % n as u32) as usize;
        for (i, f) in self.symf[..n].iter_mut().enumerate() {
            *f = if i < extra { share + 1 } else { share };
        }
        self.update();
    }

    /// Returns `(cumulative frequency, frequency)` of symbol `s`, then adapts
    /// the model. The returned pair reflects the model state before adaptation.
    ///
    /// # Panics
    ///
    /// Panics if `s >= self.symbols()`.
    #[inline]
    pub fn encode(&mut self, s: u32) -> (u32, u32) {
        let cum_freq = self.cumf[s as usize];
        let freq = self.cumf[s as usize + 1] - cum_freq;
        self.update_symbol(s);
        (cum_freq, freq)
    }

    /// Decodes the symbol whose range contains the cumulative frequency `*l`.
    ///
    /// On return `*l` holds the symbol's cumulative frequency and `*r` its
    /// frequency; the model is then adapted exactly as in [`encode`](Self::encode).
    /// `*l` is expected to be below `1 << bits`.
    ///
    /// # Panics
    ///
    /// Panics if the model was created with `compress == true` (it has no
    /// search table).
    pub fn decode(&mut self, l: &mut u32, r: &mut u32) -> u32 {
        let search = self
            .search
            .as_ref()
            .expect("decode requires a model created with compress == false");
        let bucket = (*l >> self.search_shift) as usize;
        let mut lo = search[bucket];
        let mut hi = search[bucket + 1] + 1;
        // Binary search keeping cumf[lo] <= *l < cumf[hi].
        while lo + 1 < hi {
            let mid = (lo + hi) >> 1;
            if *l < self.cumf[mid as usize] {
                hi = mid;
            } else {
                lo = mid;
            }
        }
        *l = self.cumf[lo as usize];
        *r = self.cumf[lo as usize + 1] - *l;
        self.update_symbol(lo);
        lo
    }

    /// Divides the coder range `r` by the total frequency `1 << bits`.
    #[inline]
    pub fn normalize(&self, r: &mut u32) {
        *r >>= self.bits;
    }

    /// Runs when the current run of `left` symbols is exhausted.
    ///
    /// If a second run of `more` symbols is pending, it is started with the
    /// increment bumped by one. Otherwise a real rescale happens:
    /// - `rescale` doubles towards `target_rescale`;
    /// - `cumf` is rebuilt from the accumulated `symf`;
    /// - every frequency is halved and forced odd (hence nonzero);
    /// - the freed mass `count` is spread over the next `rescale` symbols as
    ///   `incr` / `incr + 1` increments.
    ///
    /// The decoder's search table is rebuilt as well.
    fn update(&mut self) {
        if self.more > 0 {
            self.left = self.more;
            self.more = 0;
            self.incr += 1;
            return;
        }
        if self.rescale != self.target_rescale {
            self.rescale = (self.rescale * 2).min(self.target_rescale);
        }

        let n = self.symbols;
        let mut cf = self.cumf[n];
        let mut count = cf;
        for i in (0..n).rev() {
            let sf = self.symf[i];
            cf -= sf;
            self.cumf[i] = cf;
            let halved = (sf >> 1) | 1;
            count -= halved;
            self.symf[i] = halved;
        }
        debug_assert_eq!(cf, 0, "symbol frequencies must sum to the total");

        // `rescale` is positive: it starts at >= 2 and `period > 0` is enforced.
        let rescale = self.rescale as u32;
        self.incr = count / rescale;
        self.more = (count % rescale) as i32;
        self.left = self.rescale - self.more;

        if let Some(search) = self.search.as_mut() {
            // Walking symbols from high to low, each one claims the buckets
            // from its own start up to the previous symbol's start, so every
            // bucket ends up with the lowest symbol reaching it.
            let mut hi = 1usize << TABLE_SHIFT;
            for i in (0..n).rev() {
                let lo = (self.cumf[i] >> self.search_shift) as usize;
                search[lo..=hi].fill(i as u32);
                hi = lo;
            }
        }
    }

    /// Records one occurrence of `s`, first calling `update` if the current
    /// run is exhausted.
    #[inline]
    fn update_symbol(&mut self, s: u32) {
        if self.left == 0 {
            self.update();
        }
        self.left -= 1;
        self.symf[s as usize] += self.incr;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_default_params() {
        let m = RCQsModel::with_defaults(true, 65);
        assert_eq!(m.symbols(), 65);
    }

    #[test]
    #[should_panic(expected = "bits must be <= 16")]
    fn bits_too_large() {
        RCQsModel::new(true, 10, 17, 0x400);
    }

    #[test]
    #[should_panic(expected = "period too large")]
    fn period_too_large() {
        RCQsModel::new(true, 10, 16, 1 << 17);
    }

    #[test]
    fn compress_mode_no_search_table() {
        let m = RCQsModel::new(true, 10, 16, 0x400);
        assert!(m.search.is_none());
    }

    #[test]
    fn decompress_mode_has_search_table() {
        let m = RCQsModel::new(false, 10, 16, 0x400);
        assert!(m.search.is_some());
    }

    #[test]
    fn encode_returns_valid_frequencies() {
        let mut m = RCQsModel::with_defaults(true, 65);
        for s in 0..65u32 {
            let (cum, freq) = m.encode(s);
            assert!(freq > 0, "freq must be > 0 for symbol {s}");
            assert!(
                cum + freq <= (1 << 16),
                "cumulative overflow for symbol {s}"
            );
        }
    }

    #[test]
    fn reset_restores_uniform() {
        let mut m = RCQsModel::with_defaults(true, 10);
        for _ in 0..100 {
            m.encode(0);
        }
        m.reset();
        let (_, f0) = m.encode(0);
        let (_, f5) = m.encode(5);
        let diff = (f0 as i64 - f5 as i64).unsigned_abs();
        assert!(diff <= 2, "frequencies should be roughly equal after reset");
    }

    /// Feeds the same symbol stream to an encoder and a decoder model and
    /// checks that both ends of each symbol's range decode back to it.
    fn assert_round_trip(symbols: usize, bits: i32, period: i32, steps: u32) {
        let mut enc = RCQsModel::new(true, symbols, bits, period);
        let mut dec = RCQsModel::new(false, symbols, bits, period);
        let n = symbols as u32;
        for k in 0..steps {
            // Mostly one hot symbol, with a sweep over the whole alphabet mixed in.
            let s = if k % 3 == 0 { k % n } else { n / 2 };
            let (cum, freq) = enc.encode(s);
            let mut l = if k % 2 == 0 { cum } else { cum + freq - 1 };
            let mut r = 0;
            assert_eq!(dec.decode(&mut l, &mut r), s, "step {k}");
            assert_eq!((l, r), (cum, freq), "step {k}");
        }
    }

    #[test]
    fn encode_decode_round_trip_default_params() {
        assert_round_trip(65, 16, 0x400, 4000);
    }

    #[test]
    fn encode_decode_round_trip_minimum_table_bits() {
        // bits == TABLE_SHIFT: one search bucket per cumulative value.
        assert_round_trip(8, 7, 32, 1000);
    }

    #[test]
    fn cumulative_table_stays_consistent() {
        let mut m = RCQsModel::with_defaults(true, 65);
        for k in 0..5000u32 {
            m.encode(if k % 4 == 0 { k % 65 } else { 3 });
        }
        assert_eq!(m.cumf[0], 0);
        assert_eq!(m.cumf[65], 1 << 16);
        assert!(
            m.cumf.windows(2).all(|w| w[0] < w[1]),
            "every symbol must keep a nonzero range"
        );
    }

    #[test]
    #[should_panic]
    fn decode_on_compress_model_panics() {
        let mut m = RCQsModel::with_defaults(true, 10);
        let (mut l, mut r) = (0u32, 0u32);
        let _ = m.decode(&mut l, &mut r);
    }
}