use arithmetic_coding_adder_dep::Model;
use super::Weights;
use crate::codec::compressed::fenwick::ValueError;
/// Adaptive probability model built from one or more [`Weights`] tables
/// ("contexts"), exactly one of which is active at any given time.
#[derive(Clone, Debug)]
pub struct FenwickModel {
    /// Every registered context; index 0 is the one created by
    /// [`FenwickModel::with_symbols`].
    contexts: Vec<Weights>,
    /// Index into `contexts` of the table currently used for
    /// probability lookups and updates.
    current_context: usize,
    /// Ceiling on a context's total weight; `update` stops adapting a
    /// context once its denominator is no longer below this value.
    max_denominator: u64,
}
impl FenwickModel {
    /// Creates a model with a single context, `Weights::new(symbols)`, stored at
    /// index 0 and selected as the active context.
    ///
    /// `max_denominator` is the total weight a context may reach before
    /// [`Model::update`] stops incrementing it.
    #[must_use]
    pub fn with_symbols(symbols: usize, max_denominator: u64) -> Self {
        // Leaves room for a few extra contexts before the vector must reallocate.
        const INITIAL_CONTEXT_CAPACITY: usize = 10;

        let mut contexts = Vec::with_capacity(INITIAL_CONTEXT_CAPACITY);
        contexts.push(Weights::new(symbols));

        Self {
            current_context: 0,
            max_denominator,
            contexts,
        }
    }

    /// Appends a new context built with `Weights::new(symbols)`.
    ///
    /// Returns the new context's index, which can be passed to
    /// [`set_context`](Self::set_context), together with a mutable borrow of its
    /// weights so the caller can adjust them. The active context is not changed.
    pub fn push_context(&mut self, symbols: usize) -> (usize, &mut Weights) {
        let index = self.contexts.len();
        self.contexts.push(Weights::new(symbols));
        (index, &mut self.contexts[index])
    }

    /// Appends a caller-supplied `weights` table as a new context and returns
    /// its index. The active context is not changed.
    pub fn push_context_with_weights(&mut self, weights: Weights) -> usize {
        let index = self.contexts.len();
        self.contexts.push(weights);
        index
    }

    /// Selects the context at index `context` as the active one.
    ///
    /// The index is not validated here. If it is out of range, the next
    /// operation that reads or updates the active context panics on indexing.
    pub fn set_context(&mut self, context: usize) {
        self.current_context = context;
    }

    /// Shared borrow of the active context (panics if the index is out of range).
    fn context(&self) -> &Weights {
        let index = self.current_context;
        &self.contexts[index]
    }

    /// Mutable borrow of the active context (panics if the index is out of range).
    fn context_mut(&mut self) -> &mut Weights {
        let index = self.current_context;
        &mut self.contexts[index]
    }
}
impl Model for FenwickModel {
    type B = u64;
    type Symbol = usize;
    type ValueError = ValueError;

    /// Looks up the range for `symbol` in the active context.
    ///
    /// `None` is forwarded to `Weights::range` as-is (presumably the
    /// end-of-stream symbol; confirm against `Weights`). This implementation
    /// never returns `Err`.
    fn probability(&self, symbol: Option<&usize>) -> Result<std::ops::Range<u64>, ValueError> {
        let range = self.context().range(symbol.copied());
        Ok(range)
    }

    /// Total weight of the active context.
    fn denominator(&self) -> u64 {
        self.context().total
    }

    /// The configured ceiling on any context's total weight.
    fn max_denominator(&self) -> u64 {
        self.max_denominator
    }

    /// Maps a cumulative `value` back to a symbol using the active context.
    fn symbol(&self, value: u64) -> Option<usize> {
        self.context().symbol(value)
    }

    /// Adds 1 to the weight of `symbol` in the active context.
    ///
    /// The update is skipped once the active context's total has reached
    /// `max_denominator`. In debug builds, reaching that limit panics via
    /// `debug_assert!`. In release builds, the context silently stops adapting.
    fn update(&mut self, symbol: Option<&usize>) {
        let has_headroom = self.denominator() < self.max_denominator;
        debug_assert!(has_headroom, "hit max denominator!");
        if has_headroom {
            self.context_mut().update(symbol.copied(), 1);
        }
    }
}