markov_chain/
lib.rs

1//! A markov chain library for Rust.
2//! 
3//! # Features
4//! * Training sequences of arbitrary types
5//! * Nodes of N order
6//! * Specialized string generation and training
7//! * Serialization via serde
8//! * Generation utility
9//! 
10//! # Examples
11//! In your Cargo.toml file, make sure you have the line `markov_chain = "0.1"`
12//! under the `[dependencies]` section.
13//! 
14//! Markov chains may be created with any type that implements `Clone`, `Hash`,
15//! and `Eq`, and with some order (which is the number of items per node on the
16//! markov chain).
17//! 
18//! ## Creating a basic chain
19//!
20//! ```
21//! use markov_chain::Chain;
22//! 
23//! let mut chain = Chain::new(1); // 1 is the order of the chain
24//! 
25//! // Train the chain on some vectors
26//! chain.train(vec![1, 2, 3, 2, 1, 2, 3, 4, 3, 2, 1])
27//!     .train(vec![5, 4, 3, 2, 1]);
28//! 
29//! // Generate a sequence and print it out
30//! let sequence = chain.generate();
31//! println!("{:?} ", sequence);
32//! ```
33#![warn(missing_docs)]
34extern crate serde;
35#[macro_use]
36extern crate serde_derive;
37
38#[macro_use]
39extern crate maplit;
40extern crate rand;
41extern crate regex;
42
43#[macro_use]
44extern crate lazy_static;
45
46use rand::distributions::{Weighted, WeightedChoice, IndependentSample};
47use rand::Rng;
48use regex::Regex;
49use std::collections::HashMap;
50use std::hash::Hash;
51
52// Stolen from public domain project https://github.com/aatxe/markov
/// Marker trait collecting the restrictions required of chainable items.
///
/// Items must be comparable for equality and hashable because they are used
/// (wrapped in `Option`) as `HashMap` keys inside the chain.
pub trait Chainable: Eq + Hash {}

// Blanket implementation: every `Eq + Hash` type is automatically chainable.
impl<T: Eq + Hash> Chainable for T {}

// A node is a window of items; `None` entries stand for "no item"
// (padding before the start / after the end of a trained sequence).
type Node<T> = Vec<Option<T>>;
// A link maps each possible follower of a node (`None` = end of sequence)
// to the weight recorded for it.
type Link<T> = HashMap<Option<T>, u32>;
59
60// don't add where T: Serialize + DeserializeOwned, see
61// https://github.com/serde-rs/serde/issues/890
62/// A struct representing a markov chain.
63///
64/// A markov chain has an order, which determines how many items
65/// per node are held. The chain itself is a map of vectors, which point to
66/// a map of single elements pointing at a weight.
67///
68/// ```
69/// use markov_chain::Chain;
70/// 
71/// let mut chain = Chain::new(1); // 1 is the order of the chain
72/// 
73/// // Train the chain on some vectors
74/// chain.train(vec![1, 2, 3, 2, 1, 2, 3, 4, 3, 2, 1])
75///     .train(vec![5, 4, 3, 2, 1]);
76/// 
77/// // Generate a sequence and print it out
78/// let sequence = chain.generate();
79/// println!("{:?} ", sequence);
80/// ```
81#[derive(Serialize, Deserialize, Clone, PartialEq, Debug)]
82pub struct Chain<T> where T: Clone + Chainable {
83    chain: HashMap<Node<T>, Link<T>>,
84    order: usize,
85}
86
87impl<T> Chain<T> where T: Clone + Chainable {
88    /// Initializes a new markov chain with a given order.
89    /// # Examples
90    /// ```
91    /// use markov_chain::Chain;
92    /// let chain: Chain<u32> = Chain::new(1);
93    /// ```
94    pub fn new(order: usize) -> Self {
95        Chain {
96            chain: HashMap::new(),
97            order,
98        }
99    } 
100
101    /// Gets the order of the markov chain. This is static from chain to chain.
102    pub fn order(&self) -> usize {
103        self.order
104    }
105
106    /// Trains a sentence on a string of items.
107    /// # Examples
108    /// ```
109    /// use markov_chain::Chain;
110    /// let mut chain = Chain::new(1);
111    /// let data = vec![10, 15, 20];
112    /// chain.train(data)
113    ///     .train(vec![]);
114    /// ```
115    pub fn train(&mut self, string: Vec<T>) -> &mut Self {
116        if string.is_empty() {
117            return self;
118        }
119
120        let order = self.order;
121
122        let mut string = string.into_iter()
123            .map(|x| Some(x))
124            .collect::<Vec<Option<T>>>();
125        while string.len() < order {
126            string.push(None);
127        }
128
129        let mut window = vec!(None; order);
130        self.update_link(&window, &string[0]);
131
132        let mut end = 0;
133        while end < string.len() - 1 {
134            window.remove(0);
135            let next = &string[end + 1];
136            window.push(string[end].clone());
137
138            self.update_link(&window, &next);
139
140            end += 1;
141        }
142        window.remove(0);
143        window.push(string[end].clone());
144        self.update_link(&window, &None);
145        self
146    }
147
148    /// Merges this markov chain with another.
149    /// # Examples
150    /// ```
151    /// use markov_chain::Chain;
152    /// let mut chain1 = Chain::new(1);
153    /// let mut chain2 = chain1.clone();
154    /// chain1.train(vec![1, 2, 3]);
155    /// chain2.train(vec![2, 3, 4, 5, 6])
156    ///     .merge(&chain1);
157    /// ```
158    pub fn merge(&mut self, other: &Self) -> &mut Self {
159        assert_eq!(self.order, other.order, "orders must be equal in order to merge markov chains");
160        if self.chain.is_empty() {
161            self.chain = other.chain.clone();
162            return self;
163        }
164
165        for (node, link) in &other.chain {
166            for (ref next, &weight) in link.iter() {
167                self.update_link_weight(node, next, weight);
168            }
169        }
170        self
171    }
172
173    /// Increments a link from a node by one, or adding it with a weight of 1
174    /// if it doesn't exist.
175    fn update_link(&mut self, node: &[Option<T>], next: &Option<T>) {
176        self.update_link_weight(node, next, 1);
177    }
178
179    /// Increments a link from a node by specified value, or adding it with a
180    /// weight of the specified value if it doesn't exist.
181    fn update_link_weight(&mut self, node: &[Option<T>], next: &Option<T>, weight: u32) {
182        if self.chain.contains_key(node) {
183            let links = self.chain
184                .get_mut(node)
185                .unwrap();
186            // Update the link
187            if links.contains_key(next) {
188                let weight = *links.get(next).unwrap() + weight;
189                links.insert(next.clone(), weight);
190            }
191            // Insert a new link
192            else {
193                links.insert(next.clone(), weight);
194            }
195        }
196        else {
197            self.chain.insert(Vec::from(node), hashmap!{next.clone() => weight});
198        }
199    }
200
201    /// Generates a string of items with no maximum limit.
202    /// This is equivalent to `generate_limit(-1)`.
203    pub fn generate(&self) -> Vec<T> {
204        self.generate_limit(-1)
205    }
206
207    /// Generates a string of items, based on the training, of up to N items.
208    /// Specifying a maximum of -1 allows any arbitrary size of list.
209    pub fn generate_limit(&self, max: isize) -> Vec<T> {
210        // TODO : DRY generate_sentence(1)
211        if self.chain.is_empty() {
212            return vec![];
213        }
214
215        let mut curs = {
216            let c;
217            loop {
218                if let Some(n) = self.choose_random_node() {
219                    c = n.clone();
220                    break;
221                }
222            }
223            c
224        };
225
226        // this takes care of an instance where we have order N and have chosen a node that is
227        // shorter than our order.
228        if curs.iter().find(|x| x.is_none()).is_some() {
229            return curs.iter()
230                .cloned()
231                .filter_map(|x| x)
232                .collect();
233        }
234
235        let mut result = curs.clone()
236            .into_iter()
237            .map(|x| x.unwrap())
238            .collect::<Vec<T>>();
239
240        loop {
241            // Choose the next item
242            let next = self.choose_random_link(&curs);
243            if let Some(next) = next {
244                result.push(next.clone());
245                curs.push(Some(next.clone()));
246                curs.remove(0);
247            }
248            else {
249                break;
250            }
251
252            if result.len() as isize >= max && max > 0 {
253                break;
254            }
255        }
256        result
257    }
258
259    fn choose_random_link(&self, node: &Node<T>) -> Option<&T> {
260        assert_eq!(node.len(), self.order);
261        if let Some(ref link) = self.chain.get(node) {
262            let mut weights = link.iter()
263                .map(|(k, v)| Weighted { weight: *v, item: k.as_ref() })
264                .collect::<Vec<_>>();
265            let chooser = WeightedChoice::new(&mut weights);
266            let mut rng = rand::thread_rng();
267            chooser.ind_sample(&mut rng)
268        }
269        else {
270            None
271        }
272    }
273
274    fn choose_random_node(&self) -> Option<&Node<T>> {
275        if self.chain.is_empty() {
276            None
277        }
278        else {
279            let mut rng = rand::thread_rng();
280            self.chain.keys()
281                .nth(rng.gen_range(0, self.chain.len()))
282        }
283    }
284}
285
/// Symbol combinations to break sentences on.
///
/// This is a compile-time constant, so a plain `static` is sufficient and no
/// lazy initialization is needed.
static BREAK: [&'static str; 7] = [".", "?", "!", ".\"", "!\"", "?\"", ",\""];
290/// String-specific implementation of the chain. Contains some special string-
291/// specific functions.
292impl Chain<String> {
293    /// Trains this chain on a single string. Strings are broken into words,
294    /// which are split by whitespace and punctuation.
295    pub fn train_string(&mut self, sentence: &str) -> &mut Self {
296        lazy_static! {
297            static ref RE: Regex = Regex::new(
298                r#"[^ .!?,\-\n\r\t]+|[.,!?\-"]+"#
299                ).unwrap();
300        };
301        let parts = {
302            let mut parts = Vec::new();
303            let mut words = Vec::new();
304            for mat in RE.find_iter(sentence).map(|m| m.as_str()) {
305                words.push(String::from(mat));
306                if BREAK.contains(&mat) {
307                    parts.push(words.clone());
308                    words.clear();
309                }
310            }
311            parts
312        };
313        for string in parts {
314            self.train(string);
315        }
316        self
317    }
318
319    /// Generates a sentence, which are ended by "break" strings or null links.
320    /// "Break" strings are:
321    /// `.`, `?`, `!`, `."`, `!"`, `?"`, `,"`
322    pub fn generate_sentence(&self) -> String {
323        // TODO : DRY generate_sentence(1)
324        // consider an iterator?
325        if self.chain.is_empty() {
326            return String::new();
327        }
328
329        let mut curs = vec!(None; self.order);
330        let mut result = Vec::new();
331        loop {
332            // Choose the next item
333            let next = self.choose_random_link(&curs);
334            if let Some(next) = next {
335                result.push(next.clone());
336                curs.push(Some(next.clone()));
337                curs.remove(0);
338                if BREAK.contains(&next.as_str()) {
339                    break;
340                }
341            }
342            else {
343                break;
344            }
345        }
346        let mut result = result.into_iter()
347            .fold(String::new(), |a, b| if BREAK.contains(&b.as_str()) || b == "," { a + b.as_str() } else { a + " " + b.as_str() });
348        result.remove(0); // get rid of the leading space character
349        result
350    }
351
352    /// Generates a paragraph of N sentences. Each sentence is broken off by N
353    /// spaces.
354    pub fn generate_paragraph(&self, sentences: usize) -> String {
355        let mut paragraph = Vec::new();
356        for _ in 0 .. sentences {
357            paragraph.push(self.generate_sentence());
358        }
359        paragraph.join(" ")
360    }
361}
362
363#[cfg(test)]
364mod tests {
365    use ::*;
366
367    macro_rules! test_get_link {
368        ($chain:expr, [$($key:expr),+]) => {{
369            let ref map = $chain.chain;
370            let key = vec![$(Some($key),)+];
371            assert_eq!(key.len(), $chain.order);
372            assert!(map.contains_key(&key));
373            map.get(&key)
374                .unwrap()
375        }}; 
376    }
377
378    macro_rules! test_link_weight {
379        ($link:expr, $key:expr, $weight:expr) => {
380            let link = $link;
381            let key = $key;
382            assert!(link.contains_key(&key));
383            assert_eq!(*link.get(&key).unwrap(), $weight);
384        };
385    }
386
387    #[cfg(feature = "serde_cbor")]
388    #[test]
389    fn test_cbor_serialize() {
390        let mut chain = Chain::<u32>::new(1);
391        chain.train(vec![1, 2, 3])
392            .train(vec![2, 3, 4])
393            .train(vec![1, 3, 4]);
394        let cbor_vec = chain.to_cbor();
395        assert!(cbor_vec.is_ok());
396        let de = Chain::from_cbor(&cbor_vec.unwrap());
397        assert_eq!(de.unwrap(), chain);
398    }
399
400    #[cfg(feature = "serde_yaml")]
401    #[test]
402    fn test_yaml_serialize() {
403        let mut chain = Chain::<u32>::new(1);
404        chain.train(vec![1, 2, 3])
405            .train(vec![2, 3, 4])
406            .train(vec![1, 3, 4]);
407        let yaml_str = chain.to_yaml();
408        assert!(yaml_str.is_ok());
409        let de = Chain::from_yaml(&yaml_str.unwrap());
410        assert_eq!(de.unwrap(), chain);
411    }
412
413    #[test]
414    fn test_order1_training() {
415        let mut chain = Chain::<u32>::new(1);
416        chain.train(vec![1, 2, 3])
417            .train(vec![2, 3, 4])
418            .train(vec![1, 3, 4]);
419        let link = test_get_link!(chain, [1u32]);
420        test_link_weight!(link, Some(2u32), 1);
421        test_link_weight!(link, Some(3u32), 1);
422
423        let link = test_get_link!(chain, [2u32]);
424        test_link_weight!(link, Some(3u32), 2);
425
426        let link = test_get_link!(chain, [3u32]);
427        test_link_weight!(link, None, 1);
428        test_link_weight!(link, Some(4u32), 2);
429
430        let link = test_get_link!(chain, [4u32]);
431        test_link_weight!(link, None, 2);
432    }
433
434    #[test]
435    fn test_order2_training() {
436        let mut chain = Chain::<u32>::new(2);
437        chain.train(vec![1, 2, 3])
438            .train(vec![2, 3, 4])
439            .train(vec![1, 3, 4]);
440        let link = test_get_link!(chain, [1u32, 2u32]);
441        test_link_weight!(link, Some(3u32), 1);
442
443        let link = test_get_link!(chain, [2u32, 3u32]);
444        test_link_weight!(link, None, 1);
445        test_link_weight!(link, Some(4u32), 1);
446
447        let link = test_get_link!(chain, [3u32, 4u32]);
448        test_link_weight!(link, None, 2);
449
450        let link = test_get_link!(chain, [1u32, 3u32]);
451        test_link_weight!(link, Some(4u32), 1);
452    }
453
454    #[test]
455    fn test_order3_training() {
456        let mut chain = Chain::<u32>::new(3);
457        chain.train(vec![1, 2, 3, 4, 1, 2, 3, 4]);
458
459        let link = test_get_link!(chain, [1u32, 2u32, 3u32]);
460        test_link_weight!(link, Some(4u32), 2);
461
462        let link = test_get_link!(chain, [2u32, 3u32, 4u32]);
463        test_link_weight!(link, Some(1u32), 1);
464        test_link_weight!(link, None, 1);
465
466        let link = test_get_link!(chain, [3u32, 4u32, 1u32]);
467        test_link_weight!(link, Some(2u32), 1);
468
469        let link = test_get_link!(chain, [4u32, 1u32, 2u32]);
470        test_link_weight!(link, Some(3u32), 1);
471    }
472}