1use fenwick::array::{prefix_sum, update};
2
3mod builder;
4pub use builder::{Builder, EOFKind};
5
/// Symbol-frequency model that maps each symbol to a cumulative probability
/// interval.
///
/// Per-symbol counts are kept in `counts`, and a Fenwick tree over the same
/// data is kept alongside so cumulative counts can be queried without summing
/// the whole table.
#[derive(Debug, Clone)]
pub struct Model {
    /// Occurrence count of each symbol, indexed by symbol value.
    counts: Vec<u32>,
    /// Fenwick tree queried for cumulative counts; updated together with
    /// `counts` whenever a symbol is observed.
    fenwick_counts: Vec<u32>,
    /// Denominator of every probability; incremented once per observed symbol.
    total_count: u32,
    /// Symbol value that marks end-of-file.
    eof: u32,
    /// Number of symbols in the model, recorded when the model is created.
    num_symbols: u32,
}
17
18impl Model {
19 pub fn builder() -> Builder {
20 Builder::new()
21 }
22
23 pub fn from_values(
27 counts: Vec<u32>,
28 fenwick_counts: Vec<u32>,
29 total_count: u32,
30 eof: u32,
31 ) -> Self {
32 Self {
33 num_symbols: counts.len() as u32,
34 counts,
35 fenwick_counts,
36 total_count,
37 eof,
38 }
39 }
40
41 pub fn update_symbol(&mut self, symbol: u32) {
42 self.total_count += 1;
43 self.counts[symbol as usize] += 1;
44 update(&mut self.fenwick_counts, symbol as usize, 1);
45 }
46
47 pub const fn num_symbols(&self) -> u32 {
48 self.num_symbols
49 }
50
51 pub fn high(&self, index: u32) -> f64 {
52 let high = fenwick::array::prefix_sum(&self.fenwick_counts, index as usize);
53 f64::from(high) / f64::from(self.total_count)
54 }
55
56 pub fn low(&self, index: u32) -> f64 {
57 let low = fenwick::array::prefix_sum(&self.fenwick_counts, index as usize)
58 - self.counts[index as usize];
59 f64::from(low) / f64::from(self.total_count)
60 }
61
62 pub fn probability(&self, symbol: u32) -> (f64, f64) {
63 let total = f64::from(self.total_count);
64
65 let high = prefix_sum(&self.fenwick_counts, symbol as usize);
66 let low = high - self.counts[symbol as usize];
67
68 (f64::from(low) / total, f64::from(high) / total)
69 }
70
71 pub const fn eof(&self) -> u32 {
72 self.eof
73 }
74
75 pub const fn counts(&self) -> &Vec<u32> {
76 &self.counts
77 }
78
79 pub const fn fenwick_counts(&self) -> &Vec<u32> {
80 &self.fenwick_counts
81 }
82
83 pub const fn total_count(&self) -> u32 {
84 self.total_count
85 }
86}
87
#[cfg(test)]
mod tests {
    use super::{EOFKind, Model};

    /// Intervals expected for four equally likely symbols.
    const UNIFORM_QUARTERS: [(f64, f64); 4] =
        [(0.0, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)];

    /// Asserts that `model.probability(i)` equals `expected[i]` for every
    /// entry of `expected`, in order.
    #[track_caller]
    fn assert_intervals(model: &Model, expected: &[(f64, f64)]) {
        for (symbol, interval) in (0u32..).zip(expected) {
            assert_eq!(model.probability(symbol), *interval);
        }
    }

    /// Asserts that both models report the same interval for symbols 0 to 3.
    #[track_caller]
    fn assert_same_intervals(left: &Model, right: &Model) {
        for symbol in 0..4 {
            assert_eq!(left.probability(symbol), right.probability(symbol));
        }
    }

    /// Feeds every symbol in `symbols` to `model`, in order.
    fn observe(model: &mut Model, symbols: &[u32]) {
        for &symbol in symbols {
            model.update_symbol(symbol);
        }
    }

    /// Four symbols with the EOF marker placed on the last one.
    #[test]
    fn constructor() {
        let model = Model::builder().num_symbols(4).eof(EOFKind::End).build();

        assert_eq!(3, model.eof());
        assert_intervals(&model, &UNIFORM_QUARTERS);
    }

    /// Four symbols built without choosing an EOF kind explicitly.
    #[test]
    fn constructor_new() {
        let model = Model::builder().num_symbols(4).build();

        assert_eq!(4, model.eof());
        assert_intervals(&model, &UNIFORM_QUARTERS);
    }

    /// The `binary()` shortcut must match an explicit two-symbol model.
    #[test]
    fn constructor_binary() {
        let binary = Model::builder().binary().build();
        let model = Model::builder().num_symbols(2).build();

        assert_eq!(binary.eof(), model.eof());
        for symbol in 0..2 {
            assert_eq!(binary.probability(symbol), model.probability(symbol));
        }
    }

    /// A model built from explicit counts must match one that reached the
    /// same counts through updates.
    #[test]
    fn constructor_from_counts() {
        let mut model = Model::builder().num_symbols(4).eof(EOFKind::End).build();

        let uniform = Model::builder()
            .counts(vec![1; 4])
            .eof(EOFKind::End)
            .build();

        assert_eq!(3, model.eof());
        assert_same_intervals(&model, &uniform);

        observe(&mut model, &[0, 0, 0, 2, 2]);

        let skewed = Model::builder()
            .counts(vec![4, 1, 3, 1])
            .eof(EOFKind::End)
            .build();
        assert_same_intervals(&model, &skewed);
    }

    /// A model built from a probability distribution must match one that
    /// reached the same proportions through updates.
    #[test]
    fn constructor_from_pdf() {
        let mut model = Model::builder().num_symbols(4).eof(EOFKind::End).build();

        let uniform = Model::builder()
            .pdf(vec![0.25f32; 4])
            .eof(EOFKind::End)
            .build();

        assert_eq!(3, model.eof());
        assert_same_intervals(&model, &uniform);

        observe(&mut model, &[0, 0, 0, 1, 2, 2]);

        let skewed = Model::builder()
            .pdf(vec![0.4, 0.2, 0.3, 0.1])
            .eof(EOFKind::End)
            .build();

        assert_same_intervals(&model, &skewed);
    }

    /// `probability` agrees with `low`/`high` for the first symbol.
    #[test]
    fn probability_min() {
        let model = Model::builder().num_symbols(2315).build();

        let expected = (model.low(0), model.high(0));
        assert_eq!(model.probability(0), expected);
    }

    /// `probability` agrees with `low`/`high` for the highest symbol index.
    #[test]
    fn probability_high() {
        let last = 1_000;

        let model = Model::builder().num_symbols(last + 1).build();

        let expected = (model.low(last), model.high(last));
        assert_eq!(model.probability(last), expected);
    }

    /// Updates shift the intervals in proportion to the observed counts.
    #[test]
    fn update_symbols() {
        let mut model = Model::builder().num_symbols(4).eof(EOFKind::End).build();

        observe(&mut model, &[2, 2, 2, 3, 1, 3]);

        assert_intervals(&model, &[(0.0, 0.1), (0.1, 0.3), (0.3, 0.7), (0.7, 1.0)]);
    }
}