// gmarkov_lib/lib.rs
//! A library that provides Markov chain data structures.
//!
//! The `MarkovChain` structure allows you to feed in several sequences of items
//! and get out a sequence that looks very similar, but is randomly generated.
//!
//! This is the library to my CLI app [`gmarkov`].
//!
//! # Example
//! ```
//! extern crate gmarkov_lib;
//!
//! use std::fs::File;
//! use std::io::{Result, BufRead, BufReader};
//! use gmarkov_lib::MarkovChain;
//!
//! fn main() -> Result<()> {
//!     let mut chain = MarkovChain::with_order::<char>(2);
//!
//!     let reader = BufReader::new(File::open("examples/dictionary_sample.txt")?);
//!     for line in reader.lines() {
//!         chain.feed(line?.chars());
//!     }
//!
//!     println!("New word: {}", chain.into_iter().collect::<String>());
//!
//!     Ok(())
//! }
//! ```
//!
//! The short program above will create a Markov chain of order 2, then feed it
//! 100 random words from the dictionary (in `examples/dictionary_sample.txt`),
//! then print out one new random word.
//!
//! [`gmarkov`]: https://crates.io/crates/gmarkov

36use std::collections::HashMap;
37use std::iter;
38use std::hash::Hash;
39
/// Trait alias for `Eq + Hash + Clone`
///
/// All markov chain items must have this trait.
pub trait ChainItem: Eq + Hash + Clone {}
// Blanket impl: any type meeting the three bounds is automatically a ChainItem.
impl<T: Eq + Hash + Clone> ChainItem for T {}
45
46
47mod weighted_set;
48use weighted_set::WeightedSet;
49
// Internal symbol wrapper: a fed sequence is modelled as
// `Start, Mid(..), ..., Mid(..), End` so the chain can learn where sequences
// begin and end, not just which items follow which.
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
enum MarkovItem<I: ChainItem> {
    // Virtual symbol preceding every fed sequence.
    Start,
    // An actual item from the fed input.
    Mid(I),
    // Virtual symbol terminating every fed sequence.
    End,
}
56
// Maps a context window (up to `order` previous symbols) to the weighted set
// of symbols that were observed to follow it.
type MarkovSymbolTable<I> = HashMap<Vec<MarkovItem<I>>, WeightedSet<MarkovItem<I>>>;
58
/// An iterator that produces the values of a Markov chain
///
/// This `struct` is created by the [`into_iter`] function on [`MarkovChain`].
/// See its documentation for more information.
///
/// [`into_iter`]: struct.MarkovChain.html#method.into_iter
/// [`MarkovChain`]: struct.MarkovChain.html
pub struct MarkovIterator<I: ChainItem> {
    // How many previous symbols the lookup window keeps.
    order: usize,
    // Sliding window of the most recently emitted symbols; used as the lookup
    // key into `symbols`.
    prev_symbols: Vec<MarkovItem<I>>,
    // Transition table moved out of the consumed `MarkovChain`.
    symbols: MarkovSymbolTable<I>,
}
71
72impl<I: ChainItem> Iterator for MarkovIterator<I> {
73 type Item = I;
74
75 fn next(&mut self) -> Option<I> {
76 self.prev_symbols
77 .push(self.symbols[&self.prev_symbols].choose().cloned()?);
78
79 if self.prev_symbols.len() > self.order {
80 self.prev_symbols.remove(0);
81 }
82
83 if let MarkovItem::Mid(symbol) = self.prev_symbols.iter().last().cloned()? {
84 Some(symbol)
85 } else {
86 None
87 }
88 }
89}
90
/// The Markov Chain data structure
///
/// A Markov chain is a statistical model that is used to predict random
/// sequences based on the probability of one symbol coming after another. For a
/// basic introduction you can read [this article]. For a more technical
/// overview you can read the [wikipedia article] on the subject.
///
/// Items in markov chains in `gmarkov-lib` must implement the `Eq`, `Hash`, and
/// `Clone` traits for use.
///
/// [this article]: https://towardsdatascience.com/introduction-to-markov-chains-50da3645a50d
/// [wikipedia article]: https://en.wikipedia.org/wiki/Markov_chain
#[derive(Clone, Debug)]
pub struct MarkovChain<I: ChainItem> {
    // How many previous symbols each transition is conditioned on.
    order: usize,
    // Maps each context window to the weighted set of observed successors.
    symbols: MarkovSymbolTable<I>,
}
108
109impl<I: ChainItem> MarkovChain<I> {
110 /// Create a Markov chain with order `order`.
111 ///
112 /// The order specifies how many steps back in the sequence the chain keeps
113 /// track of. A higher order allows for better results but also requires a
114 /// much larger dataset.
115 pub fn with_order<T>(order: usize) -> MarkovChain<I> {
116 MarkovChain {
117 order,
118 symbols: HashMap::new(),
119 }
120 }
121
122 fn update(&mut self, item: &[MarkovItem<I>], next_item: &MarkovItem<I>) {
123 if let Some(set) = self.symbols.get_mut(item) {
124 // Can't use entry here because it is not generic
125 set.count(next_item);
126 } else {
127 let mut set = WeightedSet::default();
128 set.count(next_item);
129 self.symbols.insert(item.to_vec(), set);
130 }
131 }
132
133 /// Train the Markov chain by "feeding" it the iterator `input`.
134 ///
135 /// This adds the input to the current Markov chain. The generated output
136 /// will resemble this input.
137 pub fn feed(&mut self, input: impl Iterator<Item = I>) {
138 let items = iter::once(MarkovItem::Start)
139 .chain(input.map(move |i| MarkovItem::Mid(i.clone())))
140 .chain(iter::once(MarkovItem::End))
141 .collect::<Vec<_>>();
142
143 (1..self.order)
144 .map(|i| &items[..=i])
145 .chain(items.windows(self.order + 1))
146 .for_each(|window| {
147 let (present, past) = window.split_last().unwrap();
148 self.update(past, present);
149 });
150 }
151}
152
153impl<I: ChainItem> IntoIterator for MarkovChain<I> {
154 type Item = I;
155 type IntoIter = MarkovIterator<I>;
156
157 /// Transform the Markov chain into an iterator.
158 ///
159 /// You will usually want to clone the Markov chain before turning it into
160 /// an iterator so that you don't have to retrain it every time.
161 fn into_iter(self) -> Self::IntoIter {
162 let mut prev_symbols = Vec::with_capacity(2);
163 prev_symbols.push(MarkovItem::Start);
164
165 MarkovIterator {
166 order: self.order,
167 prev_symbols,
168 symbols: self.symbols,
169 }
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    // Feed an order-1 chain two short sequences and check that generation
    // runs to completion and only emits symbols it was trained on. The old
    // version merely printed the output and asserted nothing.
    #[test]
    fn chain() {
        let mut chain = MarkovChain::with_order::<char>(1);

        chain.feed("hello world".chars());
        chain.feed("foo bar baz".chars());

        let output = chain.into_iter().collect::<String>();
        assert!(
            output.chars().all(|c| "hello worldfoo bar baz".contains(c)),
            "generated symbol outside training alphabet: {:?}",
            output
        );
    }
}
186}