// gmarkov_lib/lib.rs

//! A library that provides Markov chain data structures.
//!
//! The `MarkovChain` structure allows you to feed in several sequences of items
//! and get out a sequence that looks very similar, but is randomly generated.
//!
//! This is the library to my CLI app [`gmarkov`].
//!
//! # Example
//! ```
//! extern crate gmarkov_lib;
//!
//! use std::fs::File;
//! use std::io::{Result, BufRead, BufReader};
//! use gmarkov_lib::MarkovChain;
//!
//! fn main() -> Result<()> {
//!     let mut chain = MarkovChain::with_order::<char>(2);
//!
//!     let reader = BufReader::new(File::open("examples/dictionary_sample.txt")?);
//!     for line in reader.lines() {
//!         chain.feed(line?.chars());
//!     }
//!
//!     println!("New word: {}", chain.into_iter().collect::<String>());
//!
//!     Ok(())
//! }
//! ```
//!
//! The short program above will create a Markov chain of order 2, then feed it
//! 100 random words from the dictionary (in `examples/dictionary_sample.txt`),
//! then print out one new random word.
//!
//! [`gmarkov`]: https://crates.io/crates/gmarkov

use std::collections::HashMap;
use std::iter;
use std::hash::Hash;

/// Trait alias for `Eq + Hash + Clone`.
///
/// All Markov chain items must implement this trait. The blanket impl below
/// makes it automatic for any qualifying type.
pub trait ChainItem: Eq + Hash + Clone {}
impl<T: Eq + Hash + Clone> ChainItem for T {}


mod weighted_set;
use weighted_set::WeightedSet;

/// Internal alphabet of the chain: the items of the fed sequences plus two
/// sentinel markers for sequence boundaries.
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
enum MarkovItem<I: ChainItem> {
    /// Marks the beginning of every fed sequence.
    Start,
    /// An actual item from a fed sequence.
    Mid(I),
    /// Marks the end of every fed sequence; generation stops when it is drawn.
    End,
}

/// Maps a state (the most recent symbols, at most `order` of them) to the
/// weighted set of symbols observed to follow that state.
type MarkovSymbolTable<I> = HashMap<Vec<MarkovItem<I>>, WeightedSet<MarkovItem<I>>>;

/// An iterator that produces the values of a Markov chain
///
/// This `struct` is created by the [`into_iter`] function on [`MarkovChain`].
/// See its documentation for more information.
///
/// [`into_iter`]: struct.MarkovChain.html#method.into_iter
/// [`MarkovChain`]: struct.MarkovChain.html
pub struct MarkovIterator<I: ChainItem> {
    /// Maximum number of symbols kept as the lookup state.
    order: usize,
    /// The current state: the most recently produced symbols.
    prev_symbols: Vec<MarkovItem<I>>,
    /// The transition table taken from the originating `MarkovChain`.
    symbols: MarkovSymbolTable<I>,
}

72impl<I: ChainItem> Iterator for MarkovIterator<I> {
73    type Item = I;
74
75    fn next(&mut self) -> Option<I> {
76        self.prev_symbols
77            .push(self.symbols[&self.prev_symbols].choose().cloned()?);
78
79        if self.prev_symbols.len() > self.order {
80            self.prev_symbols.remove(0);
81        }
82
83        if let MarkovItem::Mid(symbol) = self.prev_symbols.iter().last().cloned()? {
84            Some(symbol)
85        } else {
86            None
87        }
88    }
89}
90
/// The Markov Chain data structure
///
/// A Markov chain is a statistical model that is used to predict random
/// sequences based on the probability of one symbol coming after another. For a
/// basic introduction you can read [this article]. For a more technical
/// overview you can read the [wikipedia article] on the subject.
///
/// Items in markov chains in `gmarkov-lib` must implement the `Eq`, `Hash`, and
/// `Clone` traits for use.
///
/// [this article]: https://towardsdatascience.com/introduction-to-markov-chains-50da3645a50d
/// [wikipedia article]: https://en.wikipedia.org/wiki/Markov_chain
#[derive(Clone, Debug)]
pub struct MarkovChain<I: ChainItem> {
    /// How many previous symbols a state consists of.
    order: usize,
    /// Transition table built up by `feed`.
    symbols: MarkovSymbolTable<I>,
}

109impl<I: ChainItem> MarkovChain<I> {
110    /// Create a Markov chain with order `order`.
111    ///
112    /// The order specifies how many steps back in the sequence the chain keeps
113    /// track of. A higher order allows for better results but also requires a
114    /// much larger dataset.
115    pub fn with_order<T>(order: usize) -> MarkovChain<I> {
116        MarkovChain {
117            order,
118            symbols: HashMap::new(),
119        }
120    }
121
122    fn update(&mut self, item: &[MarkovItem<I>], next_item: &MarkovItem<I>) {
123        if let Some(set) = self.symbols.get_mut(item) {
124            // Can't use entry here because it is not generic
125            set.count(next_item);
126        } else {
127            let mut set = WeightedSet::default();
128            set.count(next_item);
129            self.symbols.insert(item.to_vec(), set);
130        }
131    }
132
133    /// Train the Markov chain by "feeding" it the iterator `input`.
134    ///
135    /// This adds the input to the current Markov chain. The generated output
136    /// will resemble this input.
137    pub fn feed(&mut self, input: impl Iterator<Item = I>) {
138        let items = iter::once(MarkovItem::Start)
139            .chain(input.map(move |i| MarkovItem::Mid(i.clone())))
140            .chain(iter::once(MarkovItem::End))
141            .collect::<Vec<_>>();
142
143        (1..self.order)
144            .map(|i| &items[..=i])
145            .chain(items.windows(self.order + 1))
146            .for_each(|window| {
147                let (present, past) = window.split_last().unwrap();
148                self.update(past, present);
149            });
150    }
151}
152
153impl<I: ChainItem> IntoIterator for MarkovChain<I> {
154    type Item = I;
155    type IntoIter = MarkovIterator<I>;
156
157    /// Transform the Markov chain into an iterator.
158    ///
159    /// You will usually want to clone the Markov chain before turning it into
160    /// an iterator so that you don't have to retrain it every time.
161    fn into_iter(self) -> Self::IntoIter {
162        let mut prev_symbols = Vec::with_capacity(2);
163        prev_symbols.push(MarkovItem::Start);
164
165        MarkovIterator {
166            order: self.order,
167            prev_symbols,
168            symbols: self.symbols,
169        }
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: train an order-1 chain on two sentences, then make sure
    /// generation runs to completion without panicking.
    #[test]
    fn chain() {
        let mut chain = MarkovChain::with_order::<char>(1);

        for sentence in &["hello world", "foo bar baz"] {
            chain.feed(sentence.chars());
        }

        println!("{}", chain.into_iter().collect::<String>());
    }
}