use std::cmp::max;
use fenwick::array::update;
use crate::Model;
/// Selects which symbol index, if any, a built model reports as its end-of-file (EOF) symbol.
///
/// The variants are interpreted by `Builder::build`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EOFKind {
    /// Use the given symbol index; `build` asserts that it is below the number of symbols.
    Specify(u32),
    /// Use index 0, the first symbol.
    Start,
    /// Use the index of the last symbol.
    End,
    /// Append one extra symbol with a count of 1 and use its index.
    EndAddOne,
    /// Use the number of symbols as the index, i.e. one past the last symbol.
    None,
}
/// Collects the options from which `Builder::build` creates a model.
///
/// Every option starts out unset (`None` / `false`) through the derived `Default`.
#[derive(Default)]
pub struct Builder {
    /// Explicit per-symbol counts; when set, `build` uses them ahead of every other source.
    counts: Option<Vec<u32>>,
    /// Number of symbols, each starting with a count of one.
    num_symbols: Option<u32>,
    /// Bit width; `build` creates `2^num_bits` symbols, each with a count of one.
    num_bits: Option<u32>,
    /// How the EOF symbol index is chosen.
    eof: Option<EOFKind>,
    /// Per-symbol probabilities that `build` scales into integer counts.
    pdf: Option<Vec<f32>>,
    /// Multiplier applied to the `pdf` entries; the setter never stores less than 10.
    scale: Option<u32>,
    /// Set by `Builder::binary`; `build` does not read it.
    binary: bool,
}
impl Builder {
    /// Creates a builder with nothing configured.
    ///
    /// Calling [`Builder::build`] on it directly yields a two-symbol model with a count of
    /// one per symbol.
    pub fn new() -> Self {
        Self::default()
    }

    /// Requests `count` symbols, each with an initial count of one.
    ///
    /// [`Builder::build`] only uses this when no counts, PDF or bit width were supplied.
    pub fn num_symbols(&mut self, count: u32) -> &mut Self {
        self.num_symbols = Some(count);
        self
    }

    /// Requests `2^size` symbols, each with an initial count of one.
    ///
    /// [`Builder::build`] only uses this when neither counts nor a PDF were supplied.
    pub fn num_bits(&mut self, size: u32) -> &mut Self {
        self.num_bits = Some(size);
        self
    }

    /// Supplies explicit per-symbol counts; these take precedence over every other source.
    pub fn counts(&mut self, counts: Vec<u32>) -> &mut Self {
        self.counts = Some(counts);
        self
    }

    /// Chooses how the EOF symbol index is determined (see [`EOFKind`]).
    pub fn eof(&mut self, eof: EOFKind) -> &mut Self {
        self.eof = Some(eof);
        self
    }

    /// Sets the factor by which PDF entries are multiplied to obtain counts.
    ///
    /// Values below 10 are raised to 10.
    pub fn scale(&mut self, scale: u32) -> &mut Self {
        self.scale = Some(scale.max(10));
        self
    }

    /// Supplies per-symbol probabilities that [`Builder::build`] converts into counts.
    ///
    /// Only used when no explicit counts were supplied.
    pub fn pdf(&mut self, pdf: Vec<f32>) -> &mut Self {
        self.pdf = Some(pdf);
        self
    }

    /// Marks the builder as binary.
    ///
    /// [`Builder::build`] does not read this flag; a two-symbol model is already what it
    /// produces when no counts, PDF, bit width or symbol count is configured.
    pub fn binary(&mut self) -> &mut Self {
        self.binary = true;
        self
    }

    /// Builds a [`Model`] from the configured options.
    ///
    /// The initial counts come from the first of these that was set:
    ///
    /// 1. explicit counts;
    /// 2. a PDF: each probability is multiplied by the scale (default: the larger of the
    ///    PDF length and 10), truncated to an integer and raised to at least 1;
    /// 3. a bit width: `2^num_bits` symbols with a count of 1;
    /// 4. a symbol count: that many symbols with a count of 1;
    /// 5. otherwise two symbols with a count of 1.
    ///
    /// The EOF index is then derived from the configured [`EOFKind`]; when none was set it
    /// is the number of symbols, exactly as for `EOFKind::None`.
    ///
    /// # Panics
    ///
    /// * if `num_bits` is so large that `2^num_bits` does not fit in `usize`;
    /// * if `EOFKind::Specify` names an index that is not below the number of symbols;
    /// * if `EOFKind::End` is requested while there are no symbols.
    pub fn build(&self) -> Model {
        let mut counts: Vec<u32> = if let Some(counts) = &self.counts {
            counts.clone()
        } else if let Some(pdf) = &self.pdf {
            let scale = self.scale.unwrap_or_else(|| max(pdf.len() as u32, 10)) as f32;
            pdf.iter()
                // Truncate towards zero, but never let a symbol end up with a zero count.
                .map(|p| max((p * scale) as i32, 1) as u32)
                .collect()
        } else if let Some(num_bits) = self.num_bits {
            // A plain `1 << num_bits` silently wraps the shift amount in release builds.
            let symbols = 1usize
                .checked_shl(num_bits)
                .expect("num_bits is too large: 2^num_bits symbols do not fit in usize");
            vec![1; symbols]
        } else if let Some(num_symbols) = self.num_symbols {
            vec![1; num_symbols as usize]
        } else {
            vec![1, 1]
        };

        let eof = match &self.eof {
            None | Some(EOFKind::None) => counts.len() as u32,
            Some(EOFKind::Specify(index)) => {
                assert!(
                    (*index as usize) < counts.len(),
                    "EOF index {} is out of range for {} symbols",
                    index,
                    counts.len()
                );
                *index
            }
            Some(EOFKind::Start) => 0,
            Some(EOFKind::End) => {
                // Without this check `0 - 1` would wrap to `u32::MAX` in release builds.
                assert!(
                    !counts.is_empty(),
                    "EOFKind::End requires at least one symbol"
                );
                counts.len() as u32 - 1
            }
            Some(EOFKind::EndAddOne) => {
                counts.push(1);
                counts.len() as u32 - 1
            }
        };

        // Feed every count into a Fenwick array of the same length.
        let mut fenwick_counts = vec![0u32; counts.len()];
        for (i, &count) in counts.iter().enumerate() {
            update(&mut fenwick_counts, i, count);
        }

        let total_count = counts.iter().sum();
        Model::from_values(counts, fenwick_counts, total_count, eof)
    }
}
#[cfg(test)]
mod tests {
    use super::{EOFKind, Model};

    /// Asserts, field by field, that `actual` matches `expected`.
    fn model_eq(expected: &Model, actual: &Model) {
        assert_eq!(expected.eof(), actual.eof(), "EOF not equal");
        assert_eq!(expected.counts(), actual.counts(), "Counts not equal");
        assert_eq!(
            expected.fenwick_counts(),
            actual.fenwick_counts(),
            "fenwicks not equal"
        );
        assert_eq!(
            expected.total_count(),
            actual.total_count(),
            "total not equal"
        );
    }

    /// Reference model: four symbols with a count of one each and the given EOF index.
    fn uniform_four(eof: u32) -> Model {
        Model::from_values(vec![1; 4], vec![1, 2, 1, 4], 4, eof)
    }

    /// Reference model: two symbols with a count of one each, EOF index 2.
    fn uniform_two() -> Model {
        Model::from_values(vec![1; 2], vec![1, 2], 2, 2)
    }

    #[test]
    fn num_symbols() {
        let actual = Model::builder().num_symbols(4).build();
        model_eq(&uniform_four(4), &actual);
    }

    #[test]
    fn num_bits() {
        let actual = Model::builder().num_bits(2).build();
        model_eq(&uniform_four(4), &actual);
    }

    #[test]
    fn counts() {
        let expected = Model::from_values(vec![4, 1, 3, 1], vec![4, 5, 3, 9], 9, 4);
        let actual = Model::builder().counts(vec![4, 1, 3, 1]).build();
        model_eq(&expected, &actual);
    }

    #[test]
    fn pdf() {
        let expected = Model::from_values(vec![4, 2, 3, 1], vec![4, 6, 3, 10], 10, 4);
        let actual = Model::builder().pdf(vec![0.4, 0.2, 0.3, 0.1]).build();
        model_eq(&expected, &actual);
    }

    #[test]
    fn pdf_scale() {
        let expected = Model::from_values(vec![8, 4, 6, 2], vec![8, 12, 6, 20], 20, 4);
        let actual = Model::builder()
            .pdf(vec![0.4, 0.2, 0.3, 0.1])
            .scale(20)
            .build();
        model_eq(&expected, &actual);
    }

    #[test]
    fn pdf_scale_defaults_length() {
        // Fifteen entries and no explicit scale.
        let probabilities = vec![
            0.4, 0.2, 0.3, 0.1, 0.4, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.2, 0.3, 0.0, 0.0,
        ];
        let expected = Model::from_values(
            vec![6, 3, 4, 1, 6, 3, 4, 6, 3, 4, 6, 3, 4, 1, 1],
            vec![6, 9, 4, 14, 6, 9, 4, 33, 3, 7, 6, 16, 4, 5, 1],
            55,
            15,
        );
        let actual = Model::builder().pdf(probabilities).build();
        model_eq(&expected, &actual);
    }

    #[test]
    fn binary() {
        let actual = Model::builder().binary().build();
        model_eq(&uniform_two(), &actual);
    }

    #[test]
    fn default_binary() {
        let actual = Model::builder().build();
        model_eq(&uniform_two(), &actual);
    }

    #[test]
    fn eof_end() {
        let actual = Model::builder().num_symbols(4).eof(EOFKind::End).build();
        model_eq(&uniform_four(3), &actual);
    }

    #[test]
    fn eof_end_add() {
        let expected = Model::from_values(vec![1, 1, 1, 1, 1], vec![1, 2, 1, 4, 1], 5, 4);
        let actual = Model::builder()
            .num_symbols(4)
            .eof(EOFKind::EndAddOne)
            .build();
        model_eq(&expected, &actual);
    }

    #[test]
    fn eof_start() {
        let actual = Model::builder().num_symbols(4).eof(EOFKind::Start).build();
        model_eq(&uniform_four(0), &actual);
    }

    #[test]
    fn eof_specify() {
        let actual = Model::builder()
            .num_symbols(4)
            .eof(EOFKind::Specify(2))
            .build();
        model_eq(&uniform_four(2), &actual);
    }

    #[test]
    fn eof_none() {
        let actual = Model::builder().num_symbols(4).eof(EOFKind::None).build();
        model_eq(&uniform_four(4), &actual);
    }

    #[test]
    fn eof_default() {
        let actual = Model::builder().num_symbols(4).build();
        model_eq(&uniform_four(4), &actual);
    }
}