extern crate rand;
use std::cell::RefCell;
use std::collections::HashMap;
use std::collections::hash_map::Entry;
use rand::Rng;
/// Two consecutive words. This is the state of a `MarkovChain`: the chain
/// maps each bigram to the words that were seen directly after it.
pub type Bigram<'a> = (&'a str, &'a str);
/// A second-order (word-bigram) Markov chain over borrowed text.
///
/// Every word stored in the chain borrows from the sentences passed to
/// `learn`, so those sentences must outlive the chain (lifetime `'a`).
/// `R` supplies all randomness used while generating text.
pub struct MarkovChain<'a, R>
where
    R: Rng,
{
    /// For each learned bigram, every word observed directly after it.
    /// Repeats are kept, so frequent successors are picked more often.
    map: HashMap<Bigram<'a>, Vec<&'a str>>,
    /// The keys of `map`, kept sorted so that random selection does not
    /// depend on `HashMap` iteration order.
    keys: Vec<Bigram<'a>>,
    /// Random number generator used for every choice the chain makes.
    rng: R,
}
impl<'a> MarkovChain<'a, rand::ThreadRng> {
    /// Creates an empty chain that draws its randomness from
    /// `rand::thread_rng()`.
    ///
    /// Use `MarkovChain::new_with_rng` to supply another generator, for
    /// example a seeded one for reproducible output.
    pub fn new() -> MarkovChain<'a, rand::ThreadRng> {
        MarkovChain::new_with_rng(rand::thread_rng())
    }
}

// An argument-less `new` should be mirrored by `Default` so the chain works
// with `Default`-based APIs (e.g. `unwrap_or_default`, struct update syntax).
impl<'a> Default for MarkovChain<'a, rand::ThreadRng> {
    /// Equivalent to `MarkovChain::new()`.
    fn default() -> MarkovChain<'a, rand::ThreadRng> {
        MarkovChain::new()
    }
}
impl<'a, R: Rng> MarkovChain<'a, R> {
pub fn new_with_rng(rng: R) -> MarkovChain<'a, R> {
MarkovChain {
map: HashMap::new(),
keys: Vec::new(),
rng: rng,
}
}
pub fn learn(&mut self, sentence: &'a str) {
let words = sentence.split_whitespace().collect::<Vec<&str>>();
for window in words.windows(3) {
let (a, b, c) = (window[0], window[1], window[2]);
self.map.entry((a, b)).or_insert_with(Vec::new).push(c);
}
self.keys = self.map.keys().cloned().collect();
self.keys.sort();
}
#[inline]
pub fn len(&self) -> usize {
self.map.len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn words(&self, state: Bigram<'a>) -> Option<&Vec<&str>> {
self.map.get(&state)
}
pub fn generate(&mut self, n: usize) -> String {
join_words(self.iter().take(n))
}
pub fn generate_from(&mut self, n: usize, from: Bigram<'a>) -> String {
join_words(self.iter_from(from).take(n))
}
pub fn iter(&mut self) -> Words {
let state = if self.is_empty() {
("", "")
} else {
*choose(&mut self.rng, &self.keys).unwrap()
};
Words {
map: &self.map,
rng: &mut self.rng,
keys: &self.keys,
state: state,
}
}
pub fn iter_from(&mut self, from: Bigram<'a>) -> Words {
Words {
map: &self.map,
rng: &mut self.rng,
keys: &self.keys,
state: from,
}
}
}
/// Iterator over words produced by a `MarkovChain`.
///
/// Obtained from `MarkovChain::iter` or `MarkovChain::iter_from`. It borrows
/// the chain's successor table, its key list and its random number
/// generator. Unless the chain is empty, it keeps yielding words forever.
pub struct Words<'a> {
    /// Current bigram; its first word is the next one to be yielded.
    state: Bigram<'a>,
    /// Bigram -> successor-words table borrowed from the chain.
    map: &'a HashMap<Bigram<'a>, Vec<&'a str>>,
    /// All learned bigrams, used to restart when `state` is a dead end.
    keys: &'a Vec<Bigram<'a>>,
    /// Random number generator borrowed from the chain.
    rng: &'a mut rand::Rng,
}
impl<'a> Iterator for Words<'a> {
    type Item = &'a str;

    /// Yields the first word of the current bigram, then advances the state.
    ///
    /// Returns `None` only when the chain has learned nothing. If the current
    /// bigram has no recorded successors, the state is replaced by randomly
    /// chosen learned bigrams until one has successors. The word that is
    /// yielded is still taken from the bigram held *before* the restart.
    fn next(&mut self) -> Option<&'a str> {
        if self.map.is_empty() {
            return None;
        }
        let emitted = self.state.0;
        let successors = loop {
            match self.map.get(&self.state) {
                Some(found) => break found,
                None => self.state = *choose(self.rng, self.keys).unwrap(),
            }
        };
        let follower = *choose(self.rng, successors).unwrap();
        self.state = (self.state.1, follower);
        Some(emitted)
    }
}
/// Picks one element of `values` using a single `next_f32` draw from `rng`.
///
/// Returns `None` without touching `rng` when `values` is empty. Otherwise
/// the float from `next_f32` (expected to lie in `[0, 1)`) is scaled by the
/// slice length and truncated to an index. Indexing panics if `rng` returns
/// a value large enough to push the index past the end.
fn choose<'a, T>(rng: &mut Rng, values: &'a [T]) -> Option<&'a T> {
    match values.len() {
        0 => None,
        len => {
            let fraction = rng.next_f32();
            let index = (len as f32 * fraction) as usize;
            Some(&values[index])
        }
    }
}
/// Concatenates `words` with a single space between consecutive items.
///
/// Returns an empty string for an empty iterator. There is no leading or
/// trailing separator.
fn join_words<'a, I: Iterator<Item = &'a str>>(words: I) -> String {
    let mut sentence = String::new();
    for (position, word) in words.enumerate() {
        if position > 0 {
            sentence.push(' ');
        }
        sentence.push_str(word);
    }
    sentence
}
/// Contents of `lorem-ipsum.txt` (presumably the standard "Lorem ipsum"
/// passage), embedded at compile time.
pub const LOREM_IPSUM: &str = include_str!("lorem-ipsum.txt");
/// Contents of `liber-primus.txt`, embedded at compile time. It is used
/// together with `LOREM_IPSUM` to train the chain behind `lipsum`.
pub const LIBER_PRIMUS: &str = include_str!("liber-primus.txt");
thread_local! {
    // Per-thread chain trained on LOREM_IPSUM and then LIBER_PRIMUS. It is
    // built on first access from each thread and uses that thread's
    // `ThreadRng` (via `MarkovChain::new`).
    static LOREM_IPSUM_CHAIN: RefCell<MarkovChain<'static, rand::ThreadRng>> = {
        let mut chain = MarkovChain::new();
        for &text in [LOREM_IPSUM, LIBER_PRIMUS].iter() {
            chain.learn(text);
        }
        RefCell::new(chain)
    }
}
/// Generates `n` words of filler text, separated by single spaces.
///
/// Generation starts from the bigram `("Lorem", "ipsum")`, so the first word
/// is always `"Lorem"`. The remaining words come from a per-thread Markov
/// chain trained on `LOREM_IPSUM` and `LIBER_PRIMUS`. `lipsum(0)` returns an
/// empty string.
pub fn lipsum(n: usize) -> String {
    LOREM_IPSUM_CHAIN.with(|chain| chain.borrow_mut().generate_from(n, ("Lorem", "ipsum")))
}
#[cfg(test)]
mod tests {
    use super::*;

    // `lipsum` seeds generation with the ("Lorem", "ipsum") bigram.
    #[test]
    fn starts_with_lorem_ipsum() {
        assert_eq!(&lipsum(10)[..11], "Lorem ipsum");
    }

    // Asking for zero words must yield an empty string.
    #[test]
    fn generate_zero_words() {
        assert_eq!(lipsum(0).split_whitespace().count(), 0);
    }

    #[test]
    fn generate_one_word() {
        assert_eq!(lipsum(1).split_whitespace().count(), 1);
    }

    #[test]
    fn generate_two_words() {
        assert_eq!(lipsum(2).split_whitespace().count(), 2);
    }

    // A chain that learned nothing generates nothing, and does not panic on
    // an empty key list.
    #[test]
    fn empty_chain() {
        let mut chain = MarkovChain::new();
        assert_eq!(chain.generate(10), "");
    }

    // Every bigram along this path has exactly one successor, so the first
    // five words are deterministic regardless of the RNG.
    #[test]
    fn generate_from() {
        let mut chain = MarkovChain::new();
        chain.learn("red orange yellow green blue indigo violet");
        assert_eq!(chain.generate_from(5, ("orange", "yellow")),
                   "orange yellow green blue indigo");
    }

    // The final bigram ("yyy", "zzz") has no successor, so the iterator has to
    // restart from a random learned bigram. The third word therefore cannot
    // be "zzz".
    #[test]
    fn generate_last_bigram() {
        let mut chain = MarkovChain::new();
        chain.learn("xxx yyy zzz");
        assert!(chain.generate_from(3, ("xxx", "yyy")) != "xxx yyy zzz");
    }

    // Starting from a bigram that was never learned must not panic.
    #[test]
    fn generate_from_no_panic() {
        let mut chain = MarkovChain::new();
        chain.learn("foo bar baz");
        chain.generate_from(3, ("xxx", "yyy"));
    }

    // Four words give two trigram windows, i.e. two bigram -> successor entries.
    #[test]
    fn chain_map() {
        let mut chain = MarkovChain::new();
        chain.learn("foo bar baz quuz");
        let map = &chain.map;
        assert_eq!(map.len(), 2);
        assert_eq!(map[&("foo", "bar")], vec!["baz"]);
        assert_eq!(map[&("bar", "baz")], vec!["quuz"]);
    }

    // Pins the exact output for a fixed seed. Any change to the number or
    // order of RNG draws (in `choose`, `iter` or `Words::next`) or to the
    // ordering of `keys` will change this string.
    #[test]
    fn new_with_rng() {
        extern crate rand;
        use rand::SeedableRng;
        let rng = rand::XorShiftRng::from_seed([1, 2, 3, 4]);
        let mut chain = MarkovChain::new_with_rng(rng);
        chain.learn("foo bar x y z");
        chain.learn("foo bar a b c");
        assert_eq!(chain.generate(15), "a b b x y b x y x y x y bar x y");
    }
}