use std::collections::HashMap;
use crate::tokens::Token;
/// A token model: maps each distinct token to its [`Stats`] (occurrence count
/// and probability).
///
/// Instances are built by the free functions `from` (counting tokens from an
/// iterator) and `with_frequencies` (from precomputed counts).
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Model<T: Token>(HashMap<T, Stats>);
/// Per-token statistics stored inside a `Model`.
#[derive(Clone, Debug, Default, PartialEq)]
struct Stats {
// Frequency: how many times the token was counted.
f: u64,
// Probability: the constructors set this to `f` divided by the sum of all
// frequencies in the model.
p: f64,
}
impl<T: Token> Model<T> {
pub fn frequency(&self, t: &T) -> u64 {
match self.0.get(t) {
Some(s) => s.f,
None => 0,
}
}
pub fn probability(&self, t: &T) -> f64 {
match self.0.get(t) {
Some(s) => s.p,
None => 0.0,
}
}
pub fn len(&self) -> usize {
self.0.len()
}
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
pub fn tokens_sorted(&self) -> Vec<T> {
let mut keys = Vec::with_capacity(self.0.len());
for k in self.0.keys() {
keys.push((*k).clone());
}
keys.sort_unstable_by(|x, y| self.frequency(y).cmp(&self.frequency(x)));
keys
}
}
pub fn from<K, KS>(ts: KS) -> Model<K>
where
K: Token,
KS: std::iter::IntoIterator<Item = K>,
{
let mut m = Model::<K>(HashMap::new());
let mut d: i64 = 0;
for t in ts {
let s = m.0.entry(t).or_insert(Stats { f: 0, p: 0.0 });
(*s).f += 1;
d += 1;
}
for s in m.0.values_mut() {
(*s).p = ((*s).f as f64) / (d as f64);
}
m
}
/// Builds a [`Model`] from precomputed `(token, frequency)` pairs.
///
/// If the same token appears more than once in `fs`, the last pair wins and
/// the earlier ones are ignored, both for the stored frequency and for the
/// total used to compute probabilities.
///
/// Each token's probability is its frequency divided by the sum of all
/// retained frequencies. When that sum is zero (every frequency is `0`), the
/// probabilities are stored as `0.0` rather than the `NaN` that `0 / 0` would
/// produce; `NaN` would also make a model compare unequal to its own clone.
/// An empty slice produces an empty model.
pub fn with_frequencies<K: Token>(fs: &[(K, u64)]) -> Model<K> {
    // Collapse duplicate tokens first so the total only counts the frequency
    // that is actually kept for each token.
    let counts: HashMap<K, u64> = fs.iter().cloned().collect();
    let total = counts.values().sum::<u64>() as f64;

    let stats = counts
        .into_iter()
        .map(|(token, f)| {
            let p = if total > 0.0 { f as f64 / total } else { 0.0 };
            (token, Stats { f, p })
        })
        .collect();

    Model(stats)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokens::test_utils::I32Token;

    /// Checks the lookups that both constructors must agree on for the shared
    /// fixture: token 2 counted twice; tokens 1, 3, 5 and 11 counted once each.
    fn check_fixture(model: &Model<I32Token>) {
        assert_eq!(model.frequency(&I32Token(1)), 1);
        assert_eq!(model.frequency(&I32Token(2)), 2);
        // A token that was never counted reports a frequency of zero.
        assert_eq!(model.frequency(&I32Token(13)), 0);
        // One occurrence out of six tokens is roughly 0.1667.
        assert!(model.probability(&I32Token(5)) > 0.166);
        assert!(model.probability(&I32Token(5)) < 0.167);
    }

    /// `from` counts tokens straight from an iterator.
    #[test]
    fn basic() {
        let stream = vec![2, 3, 1, 2, 5, 11].into_iter().map(I32Token);
        let model = from(stream);
        check_fixture(&model);
    }

    /// `with_frequencies` accepts the same distribution as precomputed counts.
    #[test]
    fn with_frequencies() {
        // `super::` is required because this test function shadows the name.
        let counts = [
            (I32Token(2), 2),
            (I32Token(3), 1),
            (I32Token(1), 1),
            (I32Token(5), 1),
            (I32Token(11), 1),
        ];
        let model = super::with_frequencies(&counts);
        check_fixture(&model);
    }
}