gmarkov-lib 0.1.1

A library that provides Markov chain data structures
Documentation
//! A library that provides Markov chain data structures.
//!
//! The `MarkovChain` structure allows you to feed in several sequences of items
//! and get out a sequence that looks very similar, but is randomly generated.
//!
//! This is the library behind my CLI app, [`gmarkov`].
//!
//! # Example
//! ```
//! extern crate gmarkov_lib;
//!
//! use std::fs::File;
//! use std::io::{Result, BufRead, BufReader};
//! use gmarkov_lib::MarkovChain;
//!
//! fn main() -> Result<()> {
//!     let mut chain = MarkovChain::with_order::<char>(2);
//!
//!     let reader = BufReader::new(File::open("examples/dictionary_sample.txt")?);
//!     for line in reader.lines() {
//!         chain.feed(line?.chars());
//!     }
//!
//!     println!("New word: {}", chain.into_iter().collect::<String>());
//!
//!     Ok(())
//! }
//! ```
//!
//! The short program above will create a Markov chain of order 2, then feed it
//! 100 random words from the dictionary (in `examples/dictionary_sample.txt`), then print out
//! one new random word.
//!
//! [`gmarkov`]: https://crates.io/crates/gmarkov

use std::collections::HashMap;
use std::iter;
use std::hash::Hash;

/// Marker trait bundling the bounds `Eq + Hash + Clone`.
///
/// Every item stored in a Markov chain must satisfy this trait. It is
/// implemented automatically for all types that meet the three bounds, so it
/// never has to be implemented by hand.
pub trait ChainItem: Eq + Hash + Clone {}

impl<T> ChainItem for T where T: Eq + Hash + Clone {}


mod weighted_set;
use weighted_set::WeightedSet;

/// One position in a sequence as stored by the chain.
///
/// Real items are wrapped in `Mid`; `Start` and `End` are sentinels that
/// `MarkovChain::feed` places before and after every sequence it is given.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum MarkovItem<I>
where
    I: ChainItem,
{
    /// Sentinel placed before the first real item of a sequence.
    Start,
    /// A real item supplied by the caller.
    Mid(I),
    /// Sentinel placed after the last real item of a sequence.
    End,
}

/// Transition table of a chain: maps a context (the items that came
/// immediately before) to the `WeightedSet` of items recorded as following it.
type MarkovSymbolTable<I> = HashMap<Vec<MarkovItem<I>>, WeightedSet<MarkovItem<I>>>;

/// An iterator that produces the values of a Markov chain.
///
/// This `struct` is created by the [`into_iter`] function on [`MarkovChain`].
/// See its documentation for more information.
///
/// [`into_iter`]: struct.MarkovChain.html#method.into_iter
/// [`MarkovChain`]: struct.MarkovChain.html
pub struct MarkovIterator<I>
where
    I: ChainItem,
{
    /// Maximum number of previous items kept as the lookup context.
    order: usize,
    /// The most recently chosen items, oldest first; used as the table key.
    prev_symbols: Vec<MarkovItem<I>>,
    /// Transition table taken over from the chain this iterator was built from.
    symbols: MarkovSymbolTable<I>,
}

impl<I: ChainItem> Iterator for MarkovIterator<I> {
    type Item = I;

    /// Produce the next item of the generated sequence.
    ///
    /// Looks up the current context in the transition table, lets the
    /// `WeightedSet` stored there choose a successor, and slides the context
    /// forward by one item.
    ///
    /// Returns `None` when the chosen successor is not a real item (the end
    /// marker), when the set yields nothing, or when the current context has
    /// no entry in the table. The last case covers a chain that was never fed
    /// and any call made after the sequence has already ended, so the iterator
    /// keeps returning `None` instead of panicking.
    fn next(&mut self) -> Option<I> {
        // `get` instead of `self.symbols[..]`: indexing a `HashMap` panics when
        // the key is absent, which happens for every context that was never
        // recorded (e.g. one that ends with the end marker).
        let chosen = self
            .symbols
            .get(&self.prev_symbols)?
            .choose()
            .cloned()?;

        // Decide what to hand out before the item is moved into the context.
        let produced = match &chosen {
            MarkovItem::Mid(symbol) => Some(symbol.clone()),
            MarkovItem::Start | MarkovItem::End => None,
        };

        // Slide the context window: append the new item and drop the oldest
        // one once more than `order` items are held.
        self.prev_symbols.push(chosen);
        if self.prev_symbols.len() > self.order {
            self.prev_symbols.remove(0);
        }

        produced
    }
}

/// The Markov chain data structure.
///
/// A Markov chain is a statistical model that is used to predict random
/// sequences based on the probability of one symbol coming after another. For
/// a basic introduction you can read [this article]. For a more technical
/// overview you can read the [wikipedia article] on the subject.
///
/// Items stored in a `gmarkov-lib` Markov chain must implement the `Eq`,
/// `Hash`, and `Clone` traits.
///
/// [this article]: https://towardsdatascience.com/introduction-to-markov-chains-50da3645a50d
/// [wikipedia article]: https://en.wikipedia.org/wiki/Markov_chain
#[derive(Debug, Clone)]
pub struct MarkovChain<I>
where
    I: ChainItem,
{
    /// How many preceding items form the context for choosing the next one.
    order: usize,
    /// Transition table built up by feeding sequences into the chain.
    symbols: MarkovSymbolTable<I>,
}

impl<I: ChainItem> MarkovChain<I> {
    /// Create an empty Markov chain with order `order`.
    ///
    /// The order specifies how many steps back in the sequence the chain keeps
    /// track of. A higher order allows for better results but also requires a
    /// much larger dataset.
    ///
    /// The type parameter `T` is not used by the function body; it is kept so
    /// that existing calls of the form `MarkovChain::with_order::<char>(2)`
    /// continue to compile. The item type `I` is inferred from how the chain
    /// is used afterwards.
    ///
    /// NOTE(review): an order of `0` does not look usable — generation starts
    /// from a one-item context, while an order-0 chain only records empty
    /// contexts. Confirm before relying on it.
    pub fn with_order<T>(order: usize) -> MarkovChain<I> {
        MarkovChain {
            order,
            symbols: HashMap::new(),
        }
    }

    /// Record that `next_item` was seen directly after the context `item`.
    ///
    /// Creates the weighted set for the context on first use.
    fn update(&mut self, item: &[MarkovItem<I>], next_item: &MarkovItem<I>) {
        // A borrowed slice is used for the lookup so that the owned `Vec` key
        // is only allocated when the context is new; the `entry` API would
        // need an owned key on every call.
        if let Some(set) = self.symbols.get_mut(item) {
            set.count(next_item);
        } else {
            let mut set = WeightedSet::default();
            set.count(next_item);
            self.symbols.insert(item.to_vec(), set);
        }
    }

    /// Train the Markov chain by "feeding" it the iterator `input`.
    ///
    /// This adds the input to the current Markov chain. The generated output
    /// will resemble this input.
    ///
    /// Inputs of any length are accepted, including empty ones and ones that
    /// are shorter than the chain's order.
    pub fn feed(&mut self, input: impl Iterator<Item = I>) {
        // Wrap the input in the start/end sentinels.
        let items: Vec<MarkovItem<I>> = iter::once(MarkovItem::Start)
            .chain(input.map(MarkovItem::Mid))
            .chain(iter::once(MarkovItem::End))
            .collect();

        // The first `order - 1` transitions have fewer than `order` items of
        // history, so they are recorded from growing prefixes of the sequence.
        // The range is clamped to the sequence length: without the clamp a
        // sequence shorter than the order would be sliced out of bounds.
        let prefix_end = self.order.min(items.len());
        let transitions = (1..prefix_end)
            .map(|end| &items[..=end])
            .chain(items.windows(self.order + 1));

        for window in transitions {
            // Every window holds at least two items: the last one is the
            // successor, everything before it is the context.
            if let Some((present, past)) = window.split_last() {
                self.update(past, present);
            }
        }
    }
}

impl<I: ChainItem> IntoIterator for MarkovChain<I> {
    type Item = I;
    type IntoIter = MarkovIterator<I>;

    /// Transform the Markov chain into an iterator.
    ///
    /// You will usually want to clone the Markov chain before turning it into
    /// an iterator so that you don't have to retrain it every time.
    fn into_iter(self) -> Self::IntoIter {
        let mut prev_symbols = Vec::with_capacity(2);
        prev_symbols.push(MarkovItem::Start);

        MarkovIterator {
            order: self.order,
            prev_symbols,
            symbols: self.symbols,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds two short strings into an order-1 chain and checks that the
    /// generated string is built only from characters that were fed in.
    #[test]
    fn chain() {
        let inputs = ["hello world", "foo bar baz"];

        let mut chain = MarkovChain::with_order::<char>(1);
        for input in &inputs {
            chain.feed(input.chars());
        }

        let generated = chain.into_iter().collect::<String>();
        println!("{}", generated);

        // The exact output is random, but the chain can only emit items it
        // has recorded, so every character must come from the training data.
        assert!(
            generated
                .chars()
                .all(|c| inputs.iter().any(|input| input.contains(c))),
            "generated text {:?} contains a character that was never fed",
            generated
        );
    }
}