1use std::cmp::max;
2
3use fenwick::array::update;
4
5use crate::Model;
6
/// Selects which symbol index a built model reports as its end-of-file (EOF) symbol.
///
/// The variants are interpreted by `Builder::build`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EOFKind {
    /// Use exactly this index; `build` asserts that it is smaller than the number of symbols.
    Specify(u32),
    /// Use index `0`.
    Start,
    /// Use the index of the last symbol (`len - 1`).
    End,
    /// Append one extra symbol with count 1 and use its index.
    EndAddOne,
    /// Use an index equal to the number of symbols, i.e. one past the last valid symbol.
    /// `build` produces the same index when no EOF kind was set at all.
    None,
}
19
20#[derive(Default)]
42pub struct Builder {
43 counts: Option<Vec<u32>>,
44 num_symbols: Option<u32>,
45 num_bits: Option<u32>,
46 eof: Option<EOFKind>,
47 pdf: Option<Vec<f32>>,
48 scale: Option<u32>,
49 binary: bool,
50}
51
52impl Builder {
53 pub fn new() -> Self {
54 Self::default()
55 }
56
57 pub fn num_symbols(&mut self, count: u32) -> &mut Self {
58 self.num_symbols = Some(count);
59 self
60 }
61
62 pub fn num_bits(&mut self, size: u32) -> &mut Self {
63 self.num_bits = Some(size);
64 self
65 }
66
67 pub fn counts(&mut self, counts: Vec<u32>) -> &mut Self {
70 self.counts = Some(counts);
71 self
72 }
73
74 pub fn eof(&mut self, eof: EOFKind) -> &mut Self {
81 self.eof = Some(eof);
82 self
83 }
84
85 pub fn scale(&mut self, mut scale: u32) -> &mut Self {
90 if scale < 10 {
91 scale = 10;
92 }
93 self.scale = Some(scale);
94 self
95 }
96
97 pub fn pdf(&mut self, pdf: Vec<f32>) -> &mut Self {
106 self.pdf = Some(pdf);
107 self
108 }
109
110 pub fn binary(&mut self) -> &mut Self {
112 self.binary = true;
113 self
114 }
115
116 pub fn build(&self) -> Model {
117 let mut counts = match &self.counts {
118 Some(counts) => counts.clone(),
119 None => match &self.pdf {
120 Some(pdf) => {
121 let scale = self.scale.unwrap_or_else(|| max(pdf.len() as u32, 10));
122 let scale = scale as f32;
123
124 pdf.iter()
125 .map(|p| max((p * scale) as i32, 1))
126 .map(|c| c as u32)
127 .collect()
128 }
129 None => match self.num_bits {
130 Some(num_bits) => vec![1; 1 << num_bits as usize],
131 None => match self.num_symbols {
132 Some(num_symbols) => vec![1; num_symbols as usize],
133 None => vec![1, 1], },
135 },
136 },
137 };
138
139 let eof = match &self.eof {
140 None => counts.len() as u32,
141 Some(eof_kind) => match eof_kind {
142 EOFKind::Specify(index) => {
143 assert!(*index < counts.len() as u32);
144 *index
145 }
146 EOFKind::Start => 0,
147 EOFKind::End => counts.len() as u32 - 1,
148 EOFKind::EndAddOne => {
149 counts.push(1);
150 counts.len() as u32 - 1
151 }
152 EOFKind::None => counts.len() as u32,
153 },
154 };
155
156 let mut fenwick_counts = vec![0u32; counts.len()];
157
158 for (i, count) in counts.iter().enumerate() {
159 update(&mut fenwick_counts, i, *count);
160 }
161
162 let total_count = counts.iter().sum();
163 Model::from_values(counts, fenwick_counts, total_count, eof)
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::{EOFKind, Model};

    /// Asserts that two models agree on EOF index, counts, Fenwick counts and total.
    fn assert_same_model(expected: &Model, actual: &Model) {
        assert_eq!(expected.eof(), actual.eof(), "EOF not equal");
        assert_eq!(expected.counts(), actual.counts(), "Counts not equal");
        assert_eq!(
            expected.fenwick_counts(),
            actual.fenwick_counts(),
            "fenwicks not equal"
        );
        assert_eq!(
            expected.total_count(),
            actual.total_count(),
            "total not equal"
        );
    }

    /// Reference model: four symbols with count 1 each and the given EOF index.
    fn uniform_four(eof: u32) -> Model {
        Model::from_values(vec![1, 1, 1, 1], vec![1, 2, 1, 4], 4, eof)
    }

    /// Reference model: two symbols with count 1 each and EOF index 2.
    fn uniform_two() -> Model {
        Model::from_values(vec![1, 1], vec![1, 2], 2, 2)
    }

    #[test]
    fn num_symbols() {
        let built = Model::builder().num_symbols(4).build();
        assert_same_model(&uniform_four(4), &built);
    }

    #[test]
    fn num_bits() {
        let built = Model::builder().num_bits(2).build();
        assert_same_model(&uniform_four(4), &built);
    }

    #[test]
    fn counts() {
        let built = Model::builder().counts(vec![4, 1, 3, 1]).build();
        let expected = Model::from_values(vec![4, 1, 3, 1], vec![4, 5, 3, 9], 9, 4);
        assert_same_model(&expected, &built);
    }

    #[test]
    fn pdf() {
        let built = Model::builder().pdf(vec![0.4, 0.2, 0.3, 0.1]).build();
        let expected = Model::from_values(vec![4, 2, 3, 1], vec![4, 6, 3, 10], 10, 4);
        assert_same_model(&expected, &built);
    }

    #[test]
    fn pdf_scale() {
        let built = Model::builder()
            .pdf(vec![0.4, 0.2, 0.3, 0.1])
            .scale(20)
            .build();
        let expected = Model::from_values(vec![8, 4, 6, 2], vec![8, 12, 6, 20], 20, 4);
        assert_same_model(&expected, &built);
    }

    #[test]
    fn pdf_scale_defaults_length() {
        // Fifteen entries and no explicit scale: the builder scales by the pdf length.
        let probabilities = vec![
            0.4, 0.2, 0.3, 0.1, 0.4, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.2, 0.3, 0.0, 0.0,
        ];
        let built = Model::builder().pdf(probabilities).build();

        let expected = Model::from_values(
            vec![6, 3, 4, 1, 6, 3, 4, 6, 3, 4, 6, 3, 4, 1, 1],
            vec![6, 9, 4, 14, 6, 9, 4, 33, 3, 7, 6, 16, 4, 5, 1],
            55,
            15,
        );
        assert_same_model(&expected, &built);
    }

    #[test]
    fn binary() {
        let built = Model::builder().binary().build();
        assert_same_model(&uniform_two(), &built);
    }

    #[test]
    fn default_binary() {
        let built = Model::builder().build();
        assert_same_model(&uniform_two(), &built);
    }

    #[test]
    fn eof_end() {
        let built = Model::builder().num_symbols(4).eof(EOFKind::End).build();
        assert_same_model(&uniform_four(3), &built);
    }

    #[test]
    fn eof_end_add() {
        let built = Model::builder()
            .num_symbols(4)
            .eof(EOFKind::EndAddOne)
            .build();
        let expected = Model::from_values(vec![1, 1, 1, 1, 1], vec![1, 2, 1, 4, 1], 5, 4);
        assert_same_model(&expected, &built);
    }

    #[test]
    fn eof_start() {
        let built = Model::builder().num_symbols(4).eof(EOFKind::Start).build();
        assert_same_model(&uniform_four(0), &built);
    }

    #[test]
    fn eof_specify() {
        let built = Model::builder()
            .num_symbols(4)
            .eof(EOFKind::Specify(2))
            .build();
        assert_same_model(&uniform_four(2), &built);
    }

    #[test]
    fn eof_none() {
        let built = Model::builder().num_symbols(4).eof(EOFKind::None).build();
        assert_same_model(&uniform_four(4), &built);
    }

    #[test]
    fn eof_default() {
        let built = Model::builder().num_symbols(4).build();
        assert_same_model(&uniform_four(4), &built);
    }
}