use std::collections::HashMap;
use std::iter;
use std::hash::Hash;
/// Marker trait for values that can be stored in a Markov chain.
///
/// Every type that is `Eq + Hash + Clone` implements it automatically through
/// the blanket impl right after this definition, so it never needs to be
/// implemented by hand.
pub trait ChainItem
where
    Self: Eq + Hash + Clone,
{
}

impl<T> ChainItem for T where T: Eq + Hash + Clone {}
mod weighted_set;
use weighted_set::WeightedSet;
/// A symbol stored in the chain: either a user item or one of the sentinels
/// that frame every fed sequence.
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
enum MarkovItem<I>
where
    I: ChainItem,
{
    /// Sentinel placed before the first item of each fed sequence.
    Start,
    /// An ordinary item taken from the input.
    Mid(I),
    /// Sentinel placed after the last item of each fed sequence.
    End,
}
/// Transition table: maps a context (up to `order` preceding symbols) to the
/// weighted collection of symbols observed immediately after it.
type MarkovSymbolTable<I> = HashMap<
    Vec<MarkovItem<I>>,
    WeightedSet<MarkovItem<I>>,
>;
/// Iterator that samples a sequence from a trained [`MarkovChain`].
///
/// Obtained by calling `into_iter` on a [`MarkovChain`]. It yields items until
/// the end-of-sequence marker is sampled.
///
/// Derives `Clone` and `Debug` like [`MarkovChain`], which holds the same
/// transition table type.
#[derive(Clone, Debug)]
pub struct MarkovIterator<I: ChainItem> {
    /// Maximum number of preceding symbols used as lookup context.
    order: usize,
    /// Most recent symbols, used as the key into `symbols`; trimmed from the
    /// front once it grows past `order`.
    prev_symbols: Vec<MarkovItem<I>>,
    /// Transition table taken over from the chain.
    symbols: MarkovSymbolTable<I>,
}
impl<I: ChainItem> Iterator for MarkovIterator<I> {
    type Item = I;

    /// Samples the successor of the current context and slides the context
    /// window forward by one symbol.
    ///
    /// Returns `None` when the sampled symbol is the end-of-sequence marker, or
    /// when the current context has no recorded successors. That covers a chain
    /// that was never fed, and any call made after the end marker was sampled,
    /// because a context ending in the end marker is never recorded. In those
    /// cases the iterator keeps returning `None` instead of panicking.
    fn next(&mut self) -> Option<I> {
        // `get` instead of indexing: a missing context ends the sequence
        // rather than panicking.
        let successors = self.symbols.get(self.prev_symbols.as_slice())?;
        let chosen = successors.choose()?.clone();

        let output = match &chosen {
            MarkovItem::Mid(symbol) => Some(symbol.clone()),
            MarkovItem::Start | MarkovItem::End => None,
        };

        self.prev_symbols.push(chosen);
        if self.prev_symbols.len() > self.order {
            // Keep only the most recent `order` symbols as the next context.
            self.prev_symbols.remove(0);
        }
        output
    }
}
/// A Markov chain over items of type `I`.
///
/// Train it with `feed`, then sample a sequence by converting it into an
/// iterator.
#[derive(Clone, Debug)]
pub struct MarkovChain<I>
where
    I: ChainItem,
{
    /// Number of preceding symbols that form the context of each transition.
    order: usize,
    /// Maps each observed context to the weighted successors seen after it.
    symbols: MarkovSymbolTable<I>,
}
impl<I: ChainItem> MarkovChain<I> {
    /// Creates an empty chain whose transitions look back `order` symbols.
    ///
    /// The type parameter `T` is unused. It is kept only so existing calls such
    /// as `MarkovChain::with_order::<char>(1)` keep compiling. The item type `I`
    /// is inferred from how the chain is used (e.g. by [`Self::feed`]).
    ///
    /// NOTE(review): `order` is expected to be at least 1. With `order == 0`
    /// only the empty context is recorded, and the iterator's initial `[Start]`
    /// context never matches it.
    pub fn with_order<T>(order: usize) -> MarkovChain<I> {
        MarkovChain {
            order,
            symbols: HashMap::new(),
        }
    }

    /// Records one observation of `next_item` following the context `item`.
    ///
    /// The lookup uses the borrowed slice, so the context is copied into an
    /// owned `Vec` key only the first time it is seen.
    fn update(&mut self, item: &[MarkovItem<I>], next_item: &MarkovItem<I>) {
        if let Some(set) = self.symbols.get_mut(item) {
            set.count(next_item);
        } else {
            let mut set = WeightedSet::default();
            set.count(next_item);
            self.symbols.insert(item.to_vec(), set);
        }
    }

    /// Trains the chain on one sequence of items.
    ///
    /// The sequence is framed with start and end markers. Every item, and the
    /// end marker, is then recorded as a successor of the (up to `order`)
    /// symbols preceding it. Any `IntoIterator` is accepted, so both iterators
    /// and collections such as `Vec<I>` can be passed. Inputs shorter than the
    /// order are handled without panicking.
    pub fn feed(&mut self, input: impl IntoIterator<Item = I>) {
        let items: Vec<MarkovItem<I>> = iter::once(MarkovItem::Start)
            .chain(input.into_iter().map(MarkovItem::Mid))
            .chain(iter::once(MarkovItem::End))
            .collect();

        // Near the start of the sequence, fewer than `order` symbols precede an
        // item, so those contexts are the prefixes `items[..=end]`. Clamp the
        // range to `items.len()` so short inputs don't slice out of bounds.
        let prefix_end = self.order.min(items.len());
        let prefixes = (1..prefix_end).map(|end| &items[..=end]);
        let full_windows = items.windows(self.order + 1);

        for window in prefixes.chain(full_windows) {
            // Prefixes have at least two elements and windows at least one.
            let (present, past) = window
                .split_last()
                .expect("context windows are never empty");
            self.update(past, present);
        }
    }
}
impl<I: ChainItem> IntoIterator for MarkovChain<I> {
type Item = I;
type IntoIter = MarkovIterator<I>;
fn into_iter(self) -> Self::IntoIter {
let mut prev_symbols = Vec::with_capacity(2);
prev_symbols.push(MarkovItem::Start);
MarkovIterator {
order: self.order,
prev_symbols,
symbols: self.symbols,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: sampling a chain trained on two phrases only ever yields
    /// characters that occur in the training text.
    #[test]
    fn chain() {
        let inputs = ["hello world", "foo bar baz"];
        let mut chain = MarkovChain::with_order::<char>(1);
        for text in inputs.iter() {
            chain.feed(text.chars());
        }
        let output = chain.into_iter().collect::<String>();
        println!("{}", output);
        assert!(output
            .chars()
            .all(|c| inputs.iter().any(|text| text.contains(c))));
    }

    /// With one training sequence of distinct items, every context has exactly
    /// one recorded successor. Assuming `WeightedSet::choose` returns the sole
    /// element of a one-element set, sampling must reproduce the input verbatim.
    #[test]
    fn single_path_is_reproduced() {
        for order in 1..=3 {
            let mut chain = MarkovChain::with_order::<char>(order);
            chain.feed("abc".chars());
            assert_eq!(
                chain.into_iter().collect::<String>(),
                "abc",
                "order {}",
                order
            );
        }
    }
}