arcode/
model.rs

1use fenwick::array::{prefix_sum, update};
2
3mod builder;
4pub use builder::{Builder, EOFKind};
5
/// Symbol table for the encoder/decoder.
///
/// Stores each symbol's probability as a count of occurrences. Symbol `i`
/// is assigned the interval between the cumulative count of symbols `0..i`
/// and that of symbols `0..=i`, both divided by `total_count`.
///
/// A uniform model gives every symbol the same non-zero count (the builder's
/// uniform models start every symbol at a count of 1). A symbol with a count
/// of 0 gets an empty (zero-width) interval, and an all-zero table would make
/// `total_count` zero and every probability NaN.
#[derive(Debug, Clone)]
pub struct Model {
    /// Occurrence count of each symbol, indexed by symbol.
    counts: Vec<u32>,
    /// Fenwick (binary indexed) tree over `counts`, used to answer
    /// cumulative-count queries without summing `counts` linearly.
    fenwick_counts: Vec<u32>,
    /// Expected to equal the sum of `counts`; the denominator of every
    /// probability.
    total_count: u32,
    /// Index of the end-of-file symbol.
    eof: u32,
    /// Number of symbols in the alphabet.
    num_symbols: u32,
}
17
18impl Model {
19    pub fn builder() -> Builder {
20        Builder::new()
21    }
22
23    /// For loading a saved model. Use the
24    /// [`Builder`] for
25    /// more options.
26    pub fn from_values(
27        counts: Vec<u32>,
28        fenwick_counts: Vec<u32>,
29        total_count: u32,
30        eof: u32,
31    ) -> Self {
32        Self {
33            num_symbols: counts.len() as u32,
34            counts,
35            fenwick_counts,
36            total_count,
37            eof,
38        }
39    }
40
41    pub fn update_symbol(&mut self, symbol: u32) {
42        self.total_count += 1;
43        self.counts[symbol as usize] += 1;
44        update(&mut self.fenwick_counts, symbol as usize, 1);
45    }
46
47    pub const fn num_symbols(&self) -> u32 {
48        self.num_symbols
49    }
50
51    pub fn high(&self, index: u32) -> f64 {
52        let high = fenwick::array::prefix_sum(&self.fenwick_counts, index as usize);
53        f64::from(high) / f64::from(self.total_count)
54    }
55
56    pub fn low(&self, index: u32) -> f64 {
57        let low = fenwick::array::prefix_sum(&self.fenwick_counts, index as usize)
58            - self.counts[index as usize];
59        f64::from(low) / f64::from(self.total_count)
60    }
61
62    pub fn probability(&self, symbol: u32) -> (f64, f64) {
63        let total = f64::from(self.total_count);
64
65        let high = prefix_sum(&self.fenwick_counts, symbol as usize);
66        let low = high - self.counts[symbol as usize];
67
68        (f64::from(low) / total, f64::from(high) / total)
69    }
70
71    pub const fn eof(&self) -> u32 {
72        self.eof
73    }
74
75    pub const fn counts(&self) -> &Vec<u32> {
76        &self.counts
77    }
78
79    pub const fn fenwick_counts(&self) -> &Vec<u32> {
80        &self.fenwick_counts
81    }
82
83    pub const fn total_count(&self) -> u32 {
84        self.total_count
85    }
86}
87
#[cfg(test)]
mod tests {
    use super::{EOFKind, Model};

    /// Intervals a uniform four-symbol model assigns to symbols `0..=3`.
    const UNIFORM_4: [(f64, f64); 4] = [(0.0, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)];

    /// Asserts that `model` assigns `expected[i]` to symbol `i` for every
    /// entry of `expected`.
    fn assert_ranges(model: &Model, expected: &[(f64, f64)]) {
        for (symbol, &range) in (0u32..).zip(expected) {
            assert_eq!(model.probability(symbol), range, "symbol {symbol}");
        }
    }

    /// Asserts that `a` and `b` agree on the interval of every symbol below
    /// `num_symbols`.
    fn assert_same_ranges(a: &Model, b: &Model, num_symbols: u32) {
        for symbol in 0..num_symbols {
            assert_eq!(
                a.probability(symbol),
                b.probability(symbol),
                "symbol {symbol}"
            );
        }
    }

    /// Feeds every symbol in `symbols` to `model.update_symbol`, in order.
    fn update_all(model: &mut Model, symbols: &[u32]) {
        for &symbol in symbols {
            model.update_symbol(symbol);
        }
    }

    #[test]
    fn constructor() {
        let model = Model::builder().num_symbols(4).eof(EOFKind::End).build();

        assert_eq!(3, model.eof());
        assert_ranges(&model, &UNIFORM_4);
    }

    #[test]
    fn constructor_new() {
        let model = Model::builder().num_symbols(4).build();

        assert_eq!(4, model.eof());
        assert_ranges(&model, &UNIFORM_4);
    }

    #[test]
    fn constructor_binary() {
        let binary = Model::builder().binary().build();
        let model = Model::builder().num_symbols(2).build();

        assert_eq!(binary.eof(), model.eof());
        assert_same_ranges(&binary, &model, 2);
    }

    #[test]
    fn constructor_from_counts() {
        let mut model = Model::builder().num_symbols(4).eof(EOFKind::End).build();
        let uniform = Model::builder()
            .counts(vec![1; 4])
            .eof(EOFKind::End)
            .build();

        assert_eq!(3, model.eof());
        assert_same_ranges(&model, &uniform, 4);

        update_all(&mut model, &[0, 0, 0, 2, 2]);

        let skewed = Model::builder()
            .counts(vec![4, 1, 3, 1])
            .eof(EOFKind::End)
            .build();
        assert_same_ranges(&model, &skewed, 4);
    }

    #[test]
    fn constructor_from_pdf() {
        let mut model = Model::builder().num_symbols(4).eof(EOFKind::End).build();
        let uniform = Model::builder()
            .pdf(vec![0.25f32; 4])
            .eof(EOFKind::End)
            .build();

        assert_eq!(3, model.eof());
        assert_same_ranges(&model, &uniform, 4);

        update_all(&mut model, &[0, 0, 0, 1, 2, 2]);

        let skewed = Model::builder()
            .pdf(vec![0.4, 0.2, 0.3, 0.1])
            .eof(EOFKind::End)
            .build();
        assert_same_ranges(&model, &skewed, 4);
    }

    #[test]
    fn probability_min() {
        let model = Model::builder().num_symbols(2315).build();
        assert_eq!(model.probability(0), (model.low(0), model.high(0)));
    }

    #[test]
    fn probability_high() {
        let count = 1_000;

        let model = Model::builder().num_symbols(count + 1).build();

        assert_eq!(
            model.probability(count),
            (model.low(count), model.high(count))
        );
    }

    #[test]
    fn low_high_agree_with_probability_after_updates() {
        let mut model = Model::builder().num_symbols(4).eof(EOFKind::End).build();
        update_all(&mut model, &[2, 0, 2]);

        for symbol in 0..4 {
            assert_eq!(
                model.probability(symbol),
                (model.low(symbol), model.high(symbol)),
                "symbol {symbol}"
            );
        }
    }

    #[test]
    fn update_symbols() {
        let mut model = Model::builder().num_symbols(4).eof(EOFKind::End).build();

        update_all(&mut model, &[2, 2, 2, 3, 1, 3]);

        assert_ranges(
            &model,
            &[(0.0, 0.1), (0.1, 0.3), (0.3, 0.7), (0.7, 1.0)],
        );
    }

    #[test]
    fn from_values_round_trip() {
        let mut model = Model::builder().num_symbols(4).eof(EOFKind::End).build();
        update_all(&mut model, &[1, 3, 3]);

        let restored = Model::from_values(
            model.counts().clone(),
            model.fenwick_counts().clone(),
            model.total_count(),
            model.eof(),
        );

        assert_eq!(restored.counts(), model.counts());
        assert_eq!(restored.total_count(), model.total_count());
        assert_eq!(restored.eof(), model.eof());
        assert_same_ranges(&model, &restored, 4);
    }
}