fenwick_model/
context_switching.rs

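//! Fenwick-tree-based context-switching model.
//!
//! The frequency table used to code each symbol is chosen by the symbol that
//! was coded immediately before it.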
2
3use arithmetic_coding_core::Model;
4
5use super::Weights;
6use crate::ValueError;
7
8#[derive(Debug, Clone)]
9pub struct FenwickModel {
10    contexts: Vec<Weights>,
11    current_context: usize,
12    max_denominator: u64,
13}
14
15impl FenwickModel {
16    #[must_use]
17    pub fn with_symbols(symbols: usize, max_denominator: u64) -> Self {
18        let mut contexts = Vec::with_capacity(symbols + 1);
19
20        for _ in 0..=symbols {
21            contexts.push(Weights::new(symbols));
22        }
23
24        Self {
25            contexts,
26            current_context: 1,
27            max_denominator,
28        }
29    }
30
31    fn context(&self) -> &Weights {
32        &self.contexts[self.current_context]
33    }
34
35    fn context_mut(&mut self) -> &mut Weights {
36        &mut self.contexts[self.current_context]
37    }
38}
39
40impl Model for FenwickModel {
41    type B = u64;
42    type Symbol = usize;
43    type ValueError = ValueError;
44
45    fn probability(&self, symbol: Option<&usize>) -> Result<std::ops::Range<u64>, ValueError> {
46        Ok(self.context().range(symbol.copied()))
47    }
48
49    fn denominator(&self) -> u64 {
50        self.context().total
51    }
52
53    fn max_denominator(&self) -> u64 {
54        self.max_denominator
55    }
56
57    fn symbol(&self, value: u64) -> Option<usize> {
58        self.context().symbol(value)
59    }
60
61    fn update(&mut self, symbol: Option<&usize>) {
62        debug_assert!(
63            self.denominator() < self.max_denominator,
64            "hit max denominator!"
65        );
66        if self.denominator() < self.max_denominator {
67            self.context_mut().update(symbol.copied(), 1);
68        }
69        self.current_context = symbol.map(|x| x + 1).unwrap_or_default();
70    }
71}