fenwick_model/
context_switching.rs1use arithmetic_coding_core::Model;
4
5use super::Weights;
6use crate::ValueError;
7
8#[derive(Debug, Clone)]
9pub struct FenwickModel {
10 contexts: Vec<Weights>,
11 current_context: usize,
12 max_denominator: u64,
13}
14
15impl FenwickModel {
16 #[must_use]
17 pub fn with_symbols(symbols: usize, max_denominator: u64) -> Self {
18 let mut contexts = Vec::with_capacity(symbols + 1);
19
20 for _ in 0..=symbols {
21 contexts.push(Weights::new(symbols));
22 }
23
24 Self {
25 contexts,
26 current_context: 1,
27 max_denominator,
28 }
29 }
30
31 fn context(&self) -> &Weights {
32 &self.contexts[self.current_context]
33 }
34
35 fn context_mut(&mut self) -> &mut Weights {
36 &mut self.contexts[self.current_context]
37 }
38}
39
40impl Model for FenwickModel {
41 type B = u64;
42 type Symbol = usize;
43 type ValueError = ValueError;
44
45 fn probability(&self, symbol: Option<&usize>) -> Result<std::ops::Range<u64>, ValueError> {
46 Ok(self.context().range(symbol.copied()))
47 }
48
49 fn denominator(&self) -> u64 {
50 self.context().total
51 }
52
53 fn max_denominator(&self) -> u64 {
54 self.max_denominator
55 }
56
57 fn symbol(&self, value: u64) -> Option<usize> {
58 self.context().symbol(value)
59 }
60
61 fn update(&mut self, symbol: Option<&usize>) {
62 debug_assert!(
63 self.denominator() < self.max_denominator,
64 "hit max denominator!"
65 );
66 if self.denominator() < self.max_denominator {
67 self.context_mut().update(symbol.copied(), 1);
68 }
69 self.current_context = symbol.map(|x| x + 1).unwrap_or_default();
70 }
71}