pub mod builder;
use crate::builder::MultiMarkovBuilder;
use rand::{Rng, RngCore};
use std::cmp::min;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::fmt;
use std::hash::Hash;
/// A variable-order Markov model over states of type `T`.
///
/// Prediction looks up the most recent states of a sequence in
/// `markov_chain` and samples a next state from the weights stored there.
pub struct MultiMarkov<T>
where
T: Eq + Hash + Clone + std::cmp::Ord,
{
/// Transition table: maps a context (a run of preceding states) to the
/// candidate next states and their sampling weights.
pub markov_chain: HashMap<Vec<T>, BTreeMap<T, f64>>,
/// Set of states; not read in this part of the file.
/// NOTE(review): presumably every state observed during training,
/// filled in by the builder — confirm against `builder`.
pub known_states: HashSet<T>,
/// Upper bound on the context length consulted when choosing a model.
pub order: i32,
/// Random number source used when sampling a next state.
pub rng: Box<dyn RngCore + Send + Sync>,
}
impl<T> MultiMarkov<T>
where
    T: Eq + Hash + Clone + std::cmp::Ord,
{
    /// Default model order.
    /// NOTE(review): not referenced in this part of the file; presumably
    /// consumed by the builder — confirm.
    pub const DEFAULT_ORDER: i32 = 3;

    /// Default prior weight.
    /// NOTE(review): not referenced in this part of the file; presumably
    /// consumed by the builder — confirm.
    pub const DEFAULT_PRIOR: f64 = 0.005;

    /// Returns a fresh [`MultiMarkovBuilder`] for configuring and training a model.
    pub fn builder() -> MultiMarkovBuilder<T> {
        MultiMarkovBuilder::<T>::new()
    }

    /// Draws a random next state for `current_sequence`, weighted by the
    /// transition weights of the best matching model.
    ///
    /// The parameter stays `&Vec<T>` so the public signature is unchanged for
    /// existing callers.
    ///
    /// Returns `None` when no model matches any trailing part of
    /// `current_sequence` (see `best_model`) or when the matching model has no
    /// entries.
    pub fn random_next(&mut self, current_sequence: &Vec<T>) -> Option<T> {
        // Draw first: the model borrowed below keeps `self` immutably borrowed,
        // so the RNG cannot be touched afterwards. This also means one random
        // value is consumed on every call, even when `None` is returned.
        let r: f64 = self.rng.gen();
        let model = self.best_model(current_sequence)?;
        let sum_of_weights: f64 = model.values().sum();

        // Walk the cumulative weights until the roll lands inside a bucket.
        let mut roll = r * sum_of_weights;
        for (state, weight) in model {
            if roll > *weight {
                roll -= weight;
            } else {
                return Some(state.clone());
            }
        }

        // Floating-point rounding can leave the roll marginally above the last
        // weight; such a roll belongs to the last bucket, so return its key
        // instead of reporting "no prediction". Still `None` for an empty model.
        model.keys().next_back().cloned()
    }

    /// Finds the model keyed by the longest trailing slice of
    /// `current_sequence`, trying at most `order` trailing states and
    /// shortening the slice one state at a time.
    ///
    /// Returns `None` when no trailing slice is a key of `markov_chain`, when
    /// the sequence is empty, or when `order` is zero or negative.
    fn best_model(&self, current_sequence: &[T]) -> Option<&BTreeMap<T, f64>> {
        let len = current_sequence.len();
        // Clamp before casting: a negative `order` would otherwise wrap to a
        // huge `usize` and silently act as "unbounded".
        let longest = min(self.order.max(0) as usize, len);
        (1..=longest)
            .rev()
            .find_map(|n| self.markov_chain.get(&current_sequence[len - n..]))
    }
}
/// Compact `Debug` rendering: prints only the model's type name followed by
/// `(trained)`; none of the fields are included in the output.
impl<T> fmt::Debug for MultiMarkov<T>
where
    T: Eq + Hash + Clone + std::cmp::Ord,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let state_type = std::any::type_name::<T>();
        write!(f, "MultiMarkov<{}>(trained)", state_type)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Training corpus shared by the tests: the character sequences of
    /// "a", "ace", "foobar" and "baz".
    fn char_data() -> Vec<Vec<char>> {
        ["a", "ace", "foobar", "baz"]
            .iter()
            .map(|word| word.chars().collect())
            .collect()
    }

    #[test]
    fn test_model_builder_works() {
        let mut model = MultiMarkov::<char>::builder()
            .with_order(2)
            .with_prior(0.015)
            .train(char_data().into_iter())
            .build();

        // A prediction is expected for this sequence...
        let predictable = vec!['a', 'b', 'c'];
        assert!(model.random_next(&predictable).is_some());

        // ...and none is expected for this one.
        let unpredictable = vec!['x', 'y', 'z'];
        assert!(model.random_next(&unpredictable).is_none());
    }

    #[test]
    fn test_debug_implementation() {
        let model = MultiMarkov::<char>::builder()
            .with_order(2)
            .with_prior(0.015)
            .train(char_data().into_iter())
            .build();

        let rendered = format!("{:?}", model);
        assert_eq!(rendered, "MultiMarkov<char>(trained)");
    }

    #[test]
    fn test_model_weights_and_priors_are_correct() {
        let model = MultiMarkov::<char>::builder()
            .with_order(2)
            .with_prior(0.001)
            .train(char_data().into_iter())
            .build();

        // Looks up the stored weight for `next` following `context`,
        // panicking (and so failing the test) if either entry is missing.
        let transitions = &model.markov_chain;
        let weight = |context: &[char], next: char| -> f64 {
            let candidates = transitions.get(context).unwrap();
            *candidates.get(&next).unwrap()
        };

        // 'b' is directly followed by 'a' in both "foobar" and "baz".
        assert_eq!(weight(&['b'], 'a'), 2.0);
        // 'a' is directly followed by 'c' once, in "ace".
        assert_eq!(weight(&['a'], 'c'), 1.0);
        // 'e' never directly follows 'a' in the corpus; the expected weight
        // equals the prior configured above.
        assert_eq!(weight(&['a'], 'e'), 0.001);
    }
}