//! A library that provides Markov chain data structures.
//!
//! The `MarkovChain` structure allows you to feed in several sequences of items
//! and get out a sequence that looks very similar, but is randomly generated.
//!
//! This is the library behind my CLI app [`gmarkov`].
//!
//! # Example
//! ```
//! extern crate gmarkov_lib;
//!
//! use std::fs::File;
//! use std::io::{Result, BufRead, BufReader};
//! use gmarkov_lib::MarkovChain;
//!
//! fn main() -> Result<()> {
//!     let mut chain = MarkovChain::<char>::with_order(2);
//!
//!     let reader = BufReader::new(File::open("examples/dictionary_sample.txt")?);
//!     for line in reader.lines() {
//!         chain.feed(line?.chars());
//!     }
//!
//!     println!("New word: {}", chain.into_iter().collect::<String>());
//!
//!     Ok(())
//! }
//! ```
//!
//! The short program above creates a Markov chain of order 2, feeds it 100
//! random words from the dictionary (in `examples/dictionary_sample.txt`), and
//! prints out one new random word.
//!
//! [`gmarkov`]: https://crates.io/crates/gmarkov

use std::collections::HashMap;
use std::iter;
use std::hash::Hash;

/// Trait alias for `Eq + Hash + Clone`
///
/// All Markov chain items must implement this trait; the blanket impl below
/// provides it for every qualifying type automatically.
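///
/// # Example
///
/// Any `Eq + Hash + Clone` type qualifies automatically; the `Token` type
/// below is only an illustration, not part of the library:
///
/// ```
/// use gmarkov_lib::ChainItem;
///
/// #[derive(PartialEq, Eq, Hash, Clone)]
/// struct Token(String);
///
/// fn requires_chain_item<T: ChainItem>() {}
/// requires_chain_item::<char>();
/// requires_chain_item::<Token>();
/// ```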
pub trait ChainItem: Eq + Hash + Clone {}
impl<T: Eq + Hash + Clone> ChainItem for T {}


mod weighted_set;
use weighted_set::WeightedSet;

#[derive(PartialEq, Eq, Hash, Clone, Debug)]
enum MarkovItem<I: ChainItem> {
    Start,
    Mid(I),
    End,
}

type MarkovSymbolTable<I> = HashMap<Vec<MarkovItem<I>>, WeightedSet<MarkovItem<I>>>;

/// An iterator that produces the values of a Markov chain
///
/// This `struct` is created by the [`into_iter`] method on [`MarkovChain`]. See
/// its documentation for more information.
///
/// [`into_iter`]: struct.MarkovChain.html#method.into_iter
/// [`MarkovChain`]: struct.MarkovChain.html
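///
/// # Example
///
/// `MarkovIterator` is an ordinary `Iterator`, so the usual adapters apply,
/// for example capping the output length with `take`:
///
/// ```
/// use gmarkov_lib::MarkovChain;
///
/// let mut chain = MarkovChain::with_order(1);
/// chain.feed("abracadabra".chars());
///
/// // Never produce more than 20 characters, however the chain wanders.
/// let word: String = chain.into_iter().take(20).collect();
/// assert!(word.chars().count() <= 20);
/// ```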
pub struct MarkovIterator<I: ChainItem> {
    order: usize,
    prev_symbols: Vec<MarkovItem<I>>,
    symbols: MarkovSymbolTable<I>,
}

impl<I: ChainItem> Iterator for MarkovIterator<I> {
    type Item = I;

    fn next(&mut self) -> Option<I> {
        // Pick the next symbol from the distribution observed after the current
        // context; stop cleanly if the context was never seen (e.g. the chain
        // was never fed).
        let next = self.symbols.get(&self.prev_symbols)?.choose().cloned()?;
        self.prev_symbols.push(next);

        // Keep only the most recent `order` symbols as the next context.
        if self.prev_symbols.len() > self.order {
            self.prev_symbols.remove(0);
        }

        // `Start` never reappears, so anything that isn't `Mid` is the `End`
        // marker and terminates the iteration.
        if let MarkovItem::Mid(symbol) = self.prev_symbols.last()? {
            Some(symbol.clone())
        } else {
            None
        }
    }
}

/// The Markov Chain data structure
///
/// A Markov chain is a statistical model that is used to predict random
/// sequences based on the probability of one symbol coming after another. For a
/// basic introduction you can read [this article]. For a more technical
/// overview you can read the [wikipedia article] on the subject.
///
/// Items stored in a `gmarkov-lib` Markov chain must implement the `Eq`,
/// `Hash`, and `Clone` traits (see the [`ChainItem`] trait alias).
///
/// [this article]: https://towardsdatascience.com/introduction-to-markov-chains-50da3645a50d
/// [wikipedia article]: https://en.wikipedia.org/wiki/Markov_chain
/// [`ChainItem`]: trait.ChainItem.html
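///
/// # Example
///
/// Items can be any [`ChainItem`] type; here the chain is built over whole
/// words rather than characters:
///
/// ```
/// use gmarkov_lib::MarkovChain;
///
/// let mut chain = MarkovChain::with_order(1);
/// chain.feed("the cat sat on the mat".split_whitespace().map(String::from));
///
/// let sentence: Vec<String> = chain.into_iter().collect();
/// assert!(!sentence.is_empty());
/// ```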
#[derive(Clone, Debug)]
pub struct MarkovChain<I: ChainItem> {
    order: usize,
    symbols: MarkovSymbolTable<I>,
}

impl<I: ChainItem> MarkovChain<I> {
    /// Create a Markov chain with order `order`.
    ///
    /// The order specifies how many steps back in the sequence the chain keeps
    /// track of. A higher order produces output that follows the input more
    /// closely, but it also requires a much larger training dataset.
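    ///
    /// # Example
    ///
    /// With a single training sequence every context has exactly one observed
    /// successor, so generation simply reproduces the input:
    ///
    /// ```
    /// use gmarkov_lib::MarkovChain;
    ///
    /// let mut chain = MarkovChain::with_order(2);
    /// chain.feed("abcd".chars());
    /// assert_eq!(chain.into_iter().collect::<String>(), "abcd");
    /// ```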
    pub fn with_order(order: usize) -> MarkovChain<I> {
        MarkovChain {
            order,
            symbols: HashMap::new(),
        }
    }

    fn update(&mut self, context: &[MarkovItem<I>], next: &MarkovItem<I>) {
        // The entry API would need an owned `Vec` key up front, so only clone
        // the context slice when it is actually missing from the table.
        if let Some(set) = self.symbols.get_mut(context) {
            set.count(next);
        } else {
            let mut set = WeightedSet::default();
            set.count(next);
            self.symbols.insert(context.to_vec(), set);
        }
    }

    /// Train the Markov chain by "feeding" it the iterator `input`.
    ///
    /// This adds the input to the current Markov chain. The generated output
    /// will resemble this input.
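    ///
    /// # Example
    ///
    /// Feeding several sequences accumulates their statistics, and generation
    /// only ever produces symbols that were observed in the training data:
    ///
    /// ```
    /// use gmarkov_lib::MarkovChain;
    ///
    /// let mut chain = MarkovChain::with_order(1);
    /// chain.feed("aab".chars());
    /// chain.feed("abb".chars());
    ///
    /// let word: String = chain.into_iter().collect();
    /// assert!(word.chars().all(|c| c == 'a' || c == 'b'));
    /// ```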
    pub fn feed(&mut self, input: impl Iterator<Item = I>) {
        // Wrap the sequence in Start/End markers so the chain also learns where
        // sequences begin and end.
        let items = iter::once(MarkovItem::Start)
            .chain(input.map(MarkovItem::Mid))
            .chain(iter::once(MarkovItem::End))
            .collect::<Vec<_>>();

        // The leading prefixes (shorter than `order`) let generation start from
        // the lone `Start` context; cap them at the sequence length so inputs
        // shorter than the order don't index out of bounds.
        (1..self.order.min(items.len()))
            .map(|i| &items[..=i])
            .chain(items.windows(self.order + 1))
            .for_each(|window| {
                let (next, context) = window.split_last().unwrap();
                self.update(context, next);
            });
    }
}

impl<I: ChainItem> IntoIterator for MarkovChain<I> {
    type Item = I;
    type IntoIter = MarkovIterator<I>;

    /// Transform the Markov chain into an iterator.
    ///
    /// You will usually want to clone the Markov chain before turning it into
    /// an iterator so that you don't have to retrain it every time.
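    ///
    /// # Example
    ///
    /// Cloning lets you generate more than one sequence from a single trained
    /// chain:
    ///
    /// ```
    /// use gmarkov_lib::MarkovChain;
    ///
    /// let mut chain = MarkovChain::with_order(1);
    /// chain.feed("abc".chars());
    ///
    /// // Clone before consuming so the trained chain can be reused.
    /// let first: String = chain.clone().into_iter().collect();
    /// let second: String = chain.into_iter().collect();
    /// assert_eq!(first, second);
    /// ```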
    fn into_iter(self) -> Self::IntoIter {
        // The context briefly grows to `order + 1` items before being trimmed.
        let mut prev_symbols = Vec::with_capacity(self.order + 1);
        prev_symbols.push(MarkovItem::Start);

        MarkovIterator {
            order: self.order,
            prev_symbols,
            symbols: self.symbols,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn chain() {
        let mut chain = MarkovChain::<char>::with_order(1);

        chain.feed("hello world".chars());
        chain.feed("foo bar baz".chars());

        let output: String = chain.into_iter().collect();
        // Every generated character must have been observed in the training data.
        assert!(!output.is_empty());
        assert!(output.chars().all(|c| "hello worldfoo bar baz".contains(c)));
    }
}