#![warn(missing_docs)]
use rand::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
/// A single piece of source text fed to [`Markov::add_to_corpus`].
#[derive(Eq, PartialEq, Clone, Debug, Serialize, Deserialize)]
pub struct InputData {
    /// The text itself; it is split on single spaces (`' '`) into words.
    pub text: String,
    /// Optional caller-defined metadata stored alongside the text; it is not used to
    /// build the chain.
    pub meta: Option<String>,
}
impl From<String> for InputData {
fn from(text: String) -> Self {
InputData { text, meta: None }
}
}
/// Tunable generator options, persisted both with the generator and in its export.
#[derive(Debug)]
#[derive(Serialize, Deserialize)]
struct DataMember {
    /// Number of consecutive words that form one chain state.
    state_size: usize,
}
/// A sentence produced by [`Markov::generate`], with information about how it was built.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct MarkovResult {
    /// The generated sentence (chain states joined with spaces, then trimmed).
    pub text: String,
    /// Branching score: for every step of the walk, the number of continuations that
    /// were available from the current state minus one, summed over the walk.
    pub score: u16,
    /// Indices of the stored inputs (see [`Markov::get_input_ref`]) that the chosen
    /// continuations were recorded from. Each index appears once; order is unspecified.
    pub refs: Vec<usize>,
    /// The 1-based attempt number that produced this result.
    pub tries: u16,
}
/// Maps a space-joined group of words to the indices of the stored input texts it was
/// recorded from.
type Fragments = std::collections::HashMap<String, Vec<usize>>;
/// Serializable snapshot of a trained [`Markov`] generator.
///
/// Produced by [`Markov::export`] and consumed by [`Markov::from_export`]. It contains the
/// stored inputs, the options and the chain tables; the result filter and the maximum
/// number of tries are not part of the snapshot.
#[derive(Serialize, Deserialize, Debug)]
pub struct ImportExport {
    /// Every stored input, in insertion order.
    data: Vec<InputData>,
    /// Transition table: state -> possible next states -> source input indices.
    corpus: Corpus,
    /// First `state_size` words of each indexed text -> source input indices.
    start_words: Fragments,
    /// Last `state_size` words of each indexed text -> source input indices.
    end_words: Fragments,
    /// Generator options (the state size).
    options: DataMember,
}
/// Errors returned by [`Markov`] operations.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum ErrorType {
    /// `generate` was called while the transition table is empty.
    CorpusEmpty,
    /// `set_state_size` was called after text had already been indexed.
    CorpusNotEmpty,
    /// `generate` used all of its attempts without producing an acceptable sentence.
    TriesExceeded,
}

impl std::fmt::Display for ErrorType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            ErrorType::CorpusEmpty => "the corpus is empty",
            ErrorType::CorpusNotEmpty => "the corpus is not empty",
            ErrorType::TriesExceeded => "maximum number of tries exceeded",
        };
        f.write_str(message)
    }
}

/// Lets `ErrorType` be boxed as `dyn Error` and propagated with `?`.
impl std::error::Error for ErrorType {}
type MarkovResultFilter = fn(&MarkovResult) -> bool;
type Corpus = HashMap<String, Fragments>;
#[derive(Serialize, Deserialize)]
pub struct Markov {
data: Vec<InputData>,
options: DataMember,
start_words: Fragments,
end_words: Fragments,
corpus: HashMap<String, Fragments>,
#[serde(skip)]
filter: Option<fn(&MarkovResult) -> bool>,
#[serde(skip)]
max_tries: u16,
}
impl Markov {
    /// Attempt budget used by [`Markov::new`] and [`Markov::from_export`].
    const DEFAULT_MAX_TRIES: u16 = 100;

    /// Creates an empty generator with a state size of 2, no filter and a budget of
    /// 100 tries.
    pub fn new() -> Markov {
        Markov {
            data: Vec::new(),
            options: DataMember { state_size: 2 },
            start_words: HashMap::new(),
            end_words: HashMap::new(),
            corpus: HashMap::new(),
            filter: None,
            max_tries: Self::DEFAULT_MAX_TRIES,
        }
    }

    /// Rebuilds a generator from a snapshot produced by [`Markov::export`].
    ///
    /// The snapshot carries neither a filter nor an attempt budget, so the filter starts
    /// unset and `max_tries` gets the same default as [`Markov::new`].
    pub fn from_export(export: ImportExport) -> Markov {
        Markov {
            data: export.data,
            options: export.options,
            start_words: export.start_words,
            end_words: export.end_words,
            corpus: export.corpus,
            filter: None,
            max_tries: Self::DEFAULT_MAX_TRIES,
        }
    }

    /// Sets how many consecutive words make up one chain state.
    ///
    /// # Errors
    ///
    /// Returns [`ErrorType::CorpusNotEmpty`] once any start state has been recorded,
    /// because existing chain entries were built with the previous size.
    pub fn set_state_size(&mut self, size: usize) -> Result<&mut Self, ErrorType> {
        if !self.start_words.is_empty() {
            return Err(ErrorType::CorpusNotEmpty);
        }
        self.options.state_size = size;
        Ok(self)
    }

    /// Stores `data` and indexes every text into the chain.
    ///
    /// Each text is split on single spaces (`' '`). Texts with fewer words than the state
    /// size are stored, which keeps indices aligned with `data`, but contribute nothing
    /// to the chain. For every other text:
    /// * its first and last `state_size` words are recorded as start and end states;
    /// * the run of `state_size` words beginning at each word position is linked to
    ///   the run of `state_size` words that immediately follows it.
    pub fn add_to_corpus(&mut self, data: Vec<InputData>) {
        let state_size = self.options.state_size;
        // Item `offset` of `data` ends up at `self.data[base + offset]` once appended below.
        let base = self.data.len();
        for (offset, item) in data.iter().enumerate() {
            let pos = base + offset;
            let words: Vec<&str> = item.text.split(' ').collect();
            let count = words.len();
            if count < state_size {
                continue;
            }
            self.start_words
                .entry(words[..state_size].join(" "))
                .or_default()
                .push(pos);
            self.end_words
                .entry(words[count - state_size..].join(" "))
                .or_default()
                .push(pos);
            for i in 0..count {
                let curr_end = (i + state_size).min(count);
                let next_end = (curr_end + state_size).min(count);
                let next = words[curr_end..next_end].join(" ");
                // Only link to a follow-up run that is a full state long.
                if next.is_empty() || next.split(' ').count() < state_size {
                    continue;
                }
                let curr = words[i..curr_end].join(" ");
                // `or_default` so each sighting records `pos` exactly once.
                self.corpus
                    .entry(curr)
                    .or_default()
                    .entry(next)
                    .or_default()
                    .push(pos);
            }
        }
        self.data.extend(data);
    }

    /// Installs a predicate that a completed sentence must pass to be returned by
    /// [`Markov::generate`]; a rejected sentence counts as a failed attempt.
    pub fn set_filter(&mut self, f: MarkovResultFilter) -> &mut Self {
        self.filter = Some(f);
        self
    }

    /// Removes any filter installed with [`Markov::set_filter`].
    pub fn unset_filter(&mut self) -> &mut Self {
        self.filter = None;
        self
    }

    /// Sets how many attempts [`Markov::generate`] makes before giving up. The same value
    /// also caps how many transitions a single attempt may follow.
    pub fn set_max_tries(&mut self, tries: u16) -> &mut Self {
        self.max_tries = tries;
        self
    }

    /// Generates a sentence by walking the chain from a random start state.
    ///
    /// An attempt succeeds when the walk reaches a recorded end state and the filter (if
    /// any) accepts the result. Otherwise a new attempt starts, up to `max_tries` attempts.
    ///
    /// # Errors
    ///
    /// * [`ErrorType::CorpusEmpty`] if no transitions have been indexed.
    /// * [`ErrorType::TriesExceeded`] if no attempt succeeded.
    pub fn generate(&self) -> Result<MarkovResult, ErrorType> {
        if self.corpus.is_empty() {
            return Err(ErrorType::CorpusEmpty);
        }
        let max_tries = self.max_tries;
        let mut rng = thread_rng();
        for tries in 1..=max_tries {
            // Every indexed text has a start state, so this only fails for an
            // inconsistent imported snapshot.
            let start = self
                .start_words
                .iter()
                .choose(&mut rng)
                .ok_or(ErrorType::CorpusEmpty)?;
            let mut current: &String = start.0;
            let mut chain: Vec<&str> = vec![current.as_str()];
            let mut references: HashSet<usize> = HashSet::new();
            let mut score: u16 = 0;
            let mut ended = false;
            // Cap the walk at `max_tries` steps so a cyclic chain cannot loop forever.
            for _ in 0..max_tries {
                let fragments = match self.corpus.get(current) {
                    Some(fragments) => fragments,
                    None => break,
                };
                let (next, refs) = match fragments.iter().choose(&mut rng) {
                    Some(state) => state,
                    None => break,
                };
                chain.push(next.as_str());
                references.extend(refs.iter().copied());
                // Clamp before narrowing and saturate so large fan-outs cannot overflow.
                let branches =
                    fragments.len().saturating_sub(1).min(usize::from(u16::MAX)) as u16;
                score = score.saturating_add(branches);
                if self.end_words.contains_key(next) {
                    ended = true;
                    break;
                }
                current = next;
            }
            let result = MarkovResult {
                text: chain.join(" ").trim().to_string(),
                score,
                refs: references.into_iter().collect(),
                tries,
            };
            // Only completed walks are offered to the filter.
            if ended && self.filter.map_or(true, |accept| accept(&result)) {
                return Ok(result);
            }
        }
        Err(ErrorType::TriesExceeded)
    }

    /// Returns the stored input at `index` (as listed in [`MarkovResult::refs`]), or
    /// `None` if the index is out of range.
    pub fn get_input_ref(&self, index: usize) -> Option<&InputData> {
        self.data.get(index)
    }

    /// Consumes the generator and returns a serializable snapshot of its inputs, options
    /// and chain tables. The filter and `max_tries` are not included.
    pub fn export(self) -> ImportExport {
        ImportExport {
            data: self.data,
            options: self.options,
            corpus: self.corpus,
            start_words: self.start_words,
            end_words: self.end_words,
        }
    }
}

/// Same as [`Markov::new`].
impl Default for Markov {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Eight sample sentences shared by the tests below.
    fn get_example_data() -> Vec<InputData> {
        [
            "Lorem ipsum dolor sit amet",
            "Lorem ipsum duplicate start words",
            "Consectetur adipiscing elit",
            "Quisque tempor, erat vel lacinia imperdiet",
            "Justo nisi fringilla dui",
            "Egestas bibendum eros nisi ut lacus",
            "fringilla dui avait annoncé une rupture avec le erat vel: il n'en est rien…",
            "Fusce tincidunt tempor, erat vel lacinia vel ex pharetra pretium lacinia imperdiet",
        ]
        .iter()
        .map(|line| InputData::from(line.to_string()))
        .collect()
    }

    /// A default generator already fed with the example data.
    fn trained() -> Markov {
        let mut markov = Markov::new();
        markov.add_to_corpus(get_example_data());
        markov
    }

    #[test]
    fn constructor_has_default_state_size() {
        assert_eq!(Markov::new().options.state_size, 2);
    }

    #[test]
    fn set_state_size_works() {
        let mut markov = Markov::new();
        markov.set_state_size(3).unwrap();
        assert_eq!(markov.options.state_size, 3);
    }

    #[test]
    fn add_to_corpus_works() {
        let mut markov = Markov::new();
        assert!(markov.corpus.is_empty());
        markov.add_to_corpus(get_example_data());
        assert_eq!(markov.corpus.len(), 28);
    }

    #[test]
    fn start_words_should_have_the_right_length() {
        assert_eq!(trained().start_words.len(), 7);
    }

    #[test]
    fn start_words_should_contain_the_right_values() {
        let markov = trained();
        let expected = &[
            "Lorem ipsum",
            "Consectetur adipiscing",
            "Quisque tempor,",
            "Justo nisi",
            "Egestas bibendum",
            "fringilla dui",
            "Fusce tincidunt",
        ];
        for key in expected {
            assert!(markov.start_words.contains_key(*key), "missing start words: {}", key);
        }
    }

    #[test]
    fn end_words_should_have_the_right_length() {
        assert_eq!(trained().end_words.len(), 7);
    }

    #[test]
    fn end_words_should_contain_the_right_values() {
        let markov = trained();
        let expected = &[
            "sit amet",
            "start words",
            "adipiscing elit",
            "fringilla dui",
            "ut lacus",
            "est rien…",
        ];
        for key in expected {
            assert!(markov.end_words.contains_key(*key), "missing end words: {}", key);
        }
    }

    #[test]
    fn corpus_should_have_the_right_values_for_the_right_keys() {
        let markov = trained();
        let lorem = markov.corpus.get("Lorem ipsum").unwrap();
        assert!(lorem.contains_key("dolor sit"));
        assert!(lorem.contains_key("duplicate start"));
        let tempor = markov.corpus.get("tempor, erat").unwrap();
        assert!(tempor.contains_key("vel lacinia"));
    }

    #[test]
    fn generator_should_return_err_if_the_corpus_is_not_build() {
        assert_eq!(Markov::new().generate().unwrap_err(), ErrorType::CorpusEmpty);
    }

    #[test]
    fn generator_should_return_a_result_under_the_tries_limit() {
        let markov = trained();
        for _ in 0..10 {
            assert!(markov.generate().unwrap().tries < 20);
        }
    }

    #[test]
    fn generator_should_return_error() {
        let mut markov = trained();
        let outcome = markov.set_filter(|_| false).generate();
        assert_eq!(outcome.unwrap_err(), ErrorType::TriesExceeded);
    }

    #[test]
    fn result_should_end_with_an_endwords_item() {
        let markov = trained();
        for _ in 0..10 {
            let generated = markov.generate().unwrap();
            let words: Vec<&str> = generated.text.split(' ').collect();
            let tail = words[words.len() - 2..].join(" ");
            assert!(markov.end_words.contains_key(&tail));
        }
    }

    #[test]
    fn input_data_from_string() {
        let input = InputData::from("foo".to_owned());
        assert_eq!(input.text, "foo");
        assert_eq!(input.meta, None);
        let texts = vec!["foo".to_string()];
        let mut markov = Markov::new();
        markov.add_to_corpus(texts.into_iter().map(InputData::from).collect());
    }
}