1#![warn(missing_docs)]
34extern crate serde;
35#[macro_use]
36extern crate serde_derive;
37
38#[macro_use]
39extern crate maplit;
40extern crate rand;
41extern crate regex;
42
43#[macro_use]
44extern crate lazy_static;
45
46use rand::distributions::{Weighted, WeightedChoice, IndependentSample};
47use rand::Rng;
48use regex::Regex;
49use std::collections::HashMap;
50use std::hash::Hash;
51
/// Marker trait for token types usable in a `Chain`.
///
/// Any type that supports equality comparison and hashing qualifies automatically through
/// the blanket implementation below; there is nothing to implement by hand.
pub trait Chainable: Eq + Hash {}

impl<T: Eq + Hash> Chainable for T {}

/// A chain state: a fixed-length window of tokens, where `None` fills positions that fall
/// outside a training sequence.
type Node<T> = Vec<Option<T>>;
/// The successors observed after a state, each mapped to its observation count.
/// A `None` successor marks the end of a sequence.
type Link<T> = HashMap<Option<T>, u32>;
59
60#[derive(Serialize, Deserialize, Clone, PartialEq, Debug)]
82pub struct Chain<T> where T: Clone + Chainable {
83 chain: HashMap<Node<T>, Link<T>>,
84 order: usize,
85}
86
87impl<T> Chain<T> where T: Clone + Chainable {
88 pub fn new(order: usize) -> Self {
95 Chain {
96 chain: HashMap::new(),
97 order,
98 }
99 }
100
101 pub fn order(&self) -> usize {
103 self.order
104 }
105
106 pub fn train(&mut self, string: Vec<T>) -> &mut Self {
116 if string.is_empty() {
117 return self;
118 }
119
120 let order = self.order;
121
122 let mut string = string.into_iter()
123 .map(|x| Some(x))
124 .collect::<Vec<Option<T>>>();
125 while string.len() < order {
126 string.push(None);
127 }
128
129 let mut window = vec!(None; order);
130 self.update_link(&window, &string[0]);
131
132 let mut end = 0;
133 while end < string.len() - 1 {
134 window.remove(0);
135 let next = &string[end + 1];
136 window.push(string[end].clone());
137
138 self.update_link(&window, &next);
139
140 end += 1;
141 }
142 window.remove(0);
143 window.push(string[end].clone());
144 self.update_link(&window, &None);
145 self
146 }
147
148 pub fn merge(&mut self, other: &Self) -> &mut Self {
159 assert_eq!(self.order, other.order, "orders must be equal in order to merge markov chains");
160 if self.chain.is_empty() {
161 self.chain = other.chain.clone();
162 return self;
163 }
164
165 for (node, link) in &other.chain {
166 for (ref next, &weight) in link.iter() {
167 self.update_link_weight(node, next, weight);
168 }
169 }
170 self
171 }
172
173 fn update_link(&mut self, node: &[Option<T>], next: &Option<T>) {
176 self.update_link_weight(node, next, 1);
177 }
178
179 fn update_link_weight(&mut self, node: &[Option<T>], next: &Option<T>, weight: u32) {
182 if self.chain.contains_key(node) {
183 let links = self.chain
184 .get_mut(node)
185 .unwrap();
186 if links.contains_key(next) {
188 let weight = *links.get(next).unwrap() + weight;
189 links.insert(next.clone(), weight);
190 }
191 else {
193 links.insert(next.clone(), weight);
194 }
195 }
196 else {
197 self.chain.insert(Vec::from(node), hashmap!{next.clone() => weight});
198 }
199 }
200
201 pub fn generate(&self) -> Vec<T> {
204 self.generate_limit(-1)
205 }
206
207 pub fn generate_limit(&self, max: isize) -> Vec<T> {
210 if self.chain.is_empty() {
212 return vec![];
213 }
214
215 let mut curs = {
216 let c;
217 loop {
218 if let Some(n) = self.choose_random_node() {
219 c = n.clone();
220 break;
221 }
222 }
223 c
224 };
225
226 if curs.iter().find(|x| x.is_none()).is_some() {
229 return curs.iter()
230 .cloned()
231 .filter_map(|x| x)
232 .collect();
233 }
234
235 let mut result = curs.clone()
236 .into_iter()
237 .map(|x| x.unwrap())
238 .collect::<Vec<T>>();
239
240 loop {
241 let next = self.choose_random_link(&curs);
243 if let Some(next) = next {
244 result.push(next.clone());
245 curs.push(Some(next.clone()));
246 curs.remove(0);
247 }
248 else {
249 break;
250 }
251
252 if result.len() as isize >= max && max > 0 {
253 break;
254 }
255 }
256 result
257 }
258
259 fn choose_random_link(&self, node: &Node<T>) -> Option<&T> {
260 assert_eq!(node.len(), self.order);
261 if let Some(ref link) = self.chain.get(node) {
262 let mut weights = link.iter()
263 .map(|(k, v)| Weighted { weight: *v, item: k.as_ref() })
264 .collect::<Vec<_>>();
265 let chooser = WeightedChoice::new(&mut weights);
266 let mut rng = rand::thread_rng();
267 chooser.ind_sample(&mut rng)
268 }
269 else {
270 None
271 }
272 }
273
274 fn choose_random_node(&self) -> Option<&Node<T>> {
275 if self.chain.is_empty() {
276 None
277 }
278 else {
279 let mut rng = rand::thread_rng();
280 self.chain.keys()
281 .nth(rng.gen_range(0, self.chain.len()))
282 }
283 }
284}
285
/// Tokens that end a sentence: `.`, `?` or `!`, each optionally followed by a closing quote,
/// plus `,"`. Used to split training text into sentences and to stop sentence generation.
///
/// This is a plain constant array; it needs no lazy initialisation.
static BREAK: [&str; 7] = [".", "?", "!", ".\"", "!\"", "?\"", ",\""];
290impl Chain<String> {
293 pub fn train_string(&mut self, sentence: &str) -> &mut Self {
296 lazy_static! {
297 static ref RE: Regex = Regex::new(
298 r#"[^ .!?,\-\n\r\t]+|[.,!?\-"]+"#
299 ).unwrap();
300 };
301 let parts = {
302 let mut parts = Vec::new();
303 let mut words = Vec::new();
304 for mat in RE.find_iter(sentence).map(|m| m.as_str()) {
305 words.push(String::from(mat));
306 if BREAK.contains(&mat) {
307 parts.push(words.clone());
308 words.clear();
309 }
310 }
311 parts
312 };
313 for string in parts {
314 self.train(string);
315 }
316 self
317 }
318
319 pub fn generate_sentence(&self) -> String {
323 if self.chain.is_empty() {
326 return String::new();
327 }
328
329 let mut curs = vec!(None; self.order);
330 let mut result = Vec::new();
331 loop {
332 let next = self.choose_random_link(&curs);
334 if let Some(next) = next {
335 result.push(next.clone());
336 curs.push(Some(next.clone()));
337 curs.remove(0);
338 if BREAK.contains(&next.as_str()) {
339 break;
340 }
341 }
342 else {
343 break;
344 }
345 }
346 let mut result = result.into_iter()
347 .fold(String::new(), |a, b| if BREAK.contains(&b.as_str()) || b == "," { a + b.as_str() } else { a + " " + b.as_str() });
348 result.remove(0); result
350 }
351
352 pub fn generate_paragraph(&self, sentences: usize) -> String {
355 let mut paragraph = Vec::new();
356 for _ in 0 .. sentences {
357 paragraph.push(self.generate_sentence());
358 }
359 paragraph.join(" ")
360 }
361}
362
#[cfg(test)]
mod tests {
    use ::*;

    /// Returns the links recorded for the node built from `key` (each entry wrapped in
    /// `Some`), after asserting the key length equals the chain order and the node exists.
    fn links_for<'a>(chain: &'a Chain<u32>, key: &[u32]) -> &'a Link<u32> {
        let node: Node<u32> = key.iter().map(|&k| Some(k)).collect();
        assert_eq!(node.len(), chain.order);
        assert!(chain.chain.contains_key(&node));
        &chain.chain[&node]
    }

    /// Asserts that `next` is a recorded successor in `links` with exactly `weight`.
    fn assert_weight(links: &Link<u32>, next: Option<u32>, weight: u32) {
        assert!(links.contains_key(&next));
        assert_eq!(links[&next], weight);
    }

    #[cfg(feature = "serde_cbor")]
    #[test]
    fn test_cbor_serialize() {
        let mut chain = Chain::<u32>::new(1);
        chain.train(vec![1, 2, 3]).train(vec![2, 3, 4]).train(vec![1, 3, 4]);
        let encoded = chain.to_cbor();
        assert!(encoded.is_ok());
        let decoded = Chain::from_cbor(&encoded.unwrap());
        assert_eq!(decoded.unwrap(), chain);
    }

    #[cfg(feature = "serde_yaml")]
    #[test]
    fn test_yaml_serialize() {
        let mut chain = Chain::<u32>::new(1);
        chain.train(vec![1, 2, 3]).train(vec![2, 3, 4]).train(vec![1, 3, 4]);
        let encoded = chain.to_yaml();
        assert!(encoded.is_ok());
        let decoded = Chain::from_yaml(&encoded.unwrap());
        assert_eq!(decoded.unwrap(), chain);
    }

    #[test]
    fn test_order1_training() {
        let mut chain = Chain::<u32>::new(1);
        chain.train(vec![1, 2, 3]).train(vec![2, 3, 4]).train(vec![1, 3, 4]);

        let after_one = links_for(&chain, &[1]);
        assert_weight(after_one, Some(2), 1);
        assert_weight(after_one, Some(3), 1);

        assert_weight(links_for(&chain, &[2]), Some(3), 2);

        let after_three = links_for(&chain, &[3]);
        assert_weight(after_three, None, 1);
        assert_weight(after_three, Some(4), 2);

        assert_weight(links_for(&chain, &[4]), None, 2);
    }

    #[test]
    fn test_order2_training() {
        let mut chain = Chain::<u32>::new(2);
        chain.train(vec![1, 2, 3]).train(vec![2, 3, 4]).train(vec![1, 3, 4]);

        assert_weight(links_for(&chain, &[1, 2]), Some(3), 1);

        let after_two_three = links_for(&chain, &[2, 3]);
        assert_weight(after_two_three, None, 1);
        assert_weight(after_two_three, Some(4), 1);

        assert_weight(links_for(&chain, &[3, 4]), None, 2);

        assert_weight(links_for(&chain, &[1, 3]), Some(4), 1);
    }

    #[test]
    fn test_order3_training() {
        let mut chain = Chain::<u32>::new(3);
        chain.train(vec![1, 2, 3, 4, 1, 2, 3, 4]);

        assert_weight(links_for(&chain, &[1, 2, 3]), Some(4), 2);

        let after_two_three_four = links_for(&chain, &[2, 3, 4]);
        assert_weight(after_two_three_four, Some(1), 1);
        assert_weight(after_two_three_four, None, 1);

        assert_weight(links_for(&chain, &[3, 4, 1]), Some(2), 1);

        assert_weight(links_for(&chain, &[4, 1, 2]), Some(3), 1);
    }
}
472}