use rand::{seq::IndexedRandom, Rng};
use std::collections::HashMap;
use std::fs::File;
use std::io::Read as _;
use substrings::{Interner, Substr, WhitespaceSplitIterator};
mod substrings;
/// Two consecutive words, each held as a [`Substr`]; used as the key of the
/// generator's map and as the current state of the [`Words`] iterator.
type Bigram = (Substr, Substr);
/// Word-level text generator: for every bigram recorded during learning it
/// remembers the words that were observed right after that bigram.
#[derive(Debug, Default)]
pub struct WurstsalatGeneratorPro {
/// The learned text; words stored in `map` and `keys` are sliced out of it.
string: String,
/// For each bigram, the words recorded as following it (one entry per occurrence).
map: HashMap<Bigram, Vec<Substr>>,
/// The keys of `map`, sorted by the text of their two words.
keys: Vec<Bigram>,
}
impl WurstsalatGeneratorPro {
    /// Creates a generator that has learned nothing: every field holds its
    /// `Default` value.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds the bigram table from `string`.
    ///
    /// The text is split on whitespace and every run of three consecutive
    /// words `(a, b, c)` records `c` as a follower of the bigram `(a, b)`.
    ///
    /// `breaks` holds byte offsets into `string`. A triple is not recorded
    /// when a break lies after the start of its first word and at or before
    /// the start of its third word. Breaks at or before the first word's start
    /// are discarded as the scan moves forward.
    fn learn(string: String, mut breaks: &[usize]) -> Self {
        let mut interner = Interner::new();
        let mut map: HashMap<Bigram, Vec<Substr>> = HashMap::new();
        let tokens: Vec<_> = WhitespaceSplitIterator::new(&string).collect();

        for triple in tokens.windows(3) {
            let (first, second, third) = match triple {
                &[a, b, c] => (a, b, c),
                // `windows(3)` only ever yields slices of length three.
                _ => continue,
            };

            // Inspect the leading breaks that are not beyond the third word:
            // drop the ones already behind the first word, and stop as soon
            // as one falls inside the triple.
            let crosses_break = loop {
                match breaks.first() {
                    Some(&offset) if offset <= third.start => {
                        if offset <= first.start {
                            breaks = &breaks[1..];
                        } else {
                            break true;
                        }
                    }
                    _ => break false,
                }
            };
            if crosses_break {
                continue;
            }

            // Intern in the order first, second, third.
            let key = (
                interner.intern(&string, first),
                interner.intern(&string, second),
            );
            let follower = interner.intern(&string, third);
            map.entry(key).or_default().push(follower);
        }

        // Sort the keys by the text of their words rather than by offsets.
        let mut keys: Vec<Bigram> = map.keys().copied().collect();
        keys.sort_unstable_by_key(|&(left, right)| {
            (
                &string[left.start..left.end],
                &string[right.start..right.end],
            )
        });

        Self { string, map, keys }
    }

    /// Reads every file in `files` and learns from their combined contents.
    ///
    /// The files are appended to a single buffer. After each file the current
    /// buffer length is recorded as a break and one space is appended, so the
    /// last word of one file and the first word of the next stay separate
    /// words and no triple is learned across the file boundary.
    ///
    /// # Errors
    ///
    /// Returns the first I/O error raised while opening or reading a file.
    pub fn learn_from_files(files: &[String]) -> Result<Self, std::io::Error> {
        let mut corpus = String::new();
        let mut boundaries = Vec::with_capacity(files.len());
        for path in files {
            File::open(path)?.read_to_string(&mut corpus)?;
            boundaries.push(corpus.len());
            corpus.push(' ');
        }
        Ok(Self::learn(corpus, &boundaries))
    }

    /// Returns an iterator of generated words, starting from a bigram picked
    /// at random from the known keys (or the default bigram when there are
    /// no keys).
    pub fn generate<R: Rng>(&self, mut rng: R) -> Words<'_, R> {
        let start: Bigram = match self.keys.choose(&mut rng) {
            Some(&bigram) => bigram,
            None => Default::default(),
        };
        self.iter_with_rng_from(rng, start)
    }

    /// Creates a [`Words`] iterator over this generator's data that uses
    /// `rng` and begins in state `from`.
    fn iter_with_rng_from<R: Rng>(&self, rng: R, from: Bigram) -> Words<'_, R> {
        Words {
            state: from,
            rng,
            string: self.string.as_str(),
            map: &self.map,
            keys: &self.keys,
        }
    }
}
/// Iterator over generated words; each call to `next` yields the first word
/// of the current bigram and then advances to a following bigram.
#[derive(Clone)]
pub struct Words<'a, R: Rng> {
/// Text that the words of `state` are extracted from.
string: &'a str,
/// Bigram -> recorded follower words; consulted to pick the next word.
map: &'a HashMap<Bigram, Vec<Substr>>,
/// Random source used for every choice the iterator makes.
rng: R,
/// Bigrams to restart from when `state` has no entry in `map`.
keys: &'a [Bigram],
/// Current bigram; its first word is what the next call to `next` returns.
state: Bigram,
}
impl<'a, R: Rng> Iterator for Words<'a, R> {
    type Item = &'a str;

    /// Yields the first word of the current bigram and moves the state on.
    ///
    /// Returns `None` only when the map is empty. When the current bigram has
    /// no entry in the map, the state is replaced by a randomly chosen key
    /// before a follower is picked.
    fn next(&mut self) -> Option<&'a str> {
        if self.map.is_empty() {
            return None;
        }

        // Capture the word to emit before the state is touched.
        let emitted = self.state.0.extract_str(self.string);

        let followers = match self.map.get(&self.state) {
            Some(found) => found,
            None => {
                // Unknown bigram: restart from a random known bigram.
                self.state = *self.keys.choose(&mut self.rng).unwrap();
                &self.map[&self.state]
            }
        };

        let successor = *followers.choose(&mut self.rng).unwrap();
        self.state = (self.state.1, successor);
        Some(emitted)
    }
}
/// Returns `true` when `c` is an ASCII punctuation character, i.e. one of
/// ``!"#$%&'()*+,-./``, `:;<=>?@`, ``[\]^_` `` or `{|}~`.
///
/// Exists as a plain function so it can be passed where a `fn(char) -> bool`
/// pattern is expected.
fn is_ascii_punctuation(c: char) -> bool {
    matches!(c, '!'..='/' | ':'..='@' | '['..='`' | '{'..='~')
}
/// Returns `word` with its first character converted to uppercase and the
/// rest left untouched. An empty `word` yields an empty string.
///
/// The uppercase form of a character may be longer than one character
/// (for example `ß` becomes `SS`).
fn capitalize(word: &str) -> String {
    let mut out = String::with_capacity(word.len());
    let mut rest = word.chars();
    if let Some(initial) = rest.next() {
        out.extend(initial.to_uppercase());
    }
    // After taking the first character, the iterator's remainder is the tail.
    out.push_str(rest.as_str());
    out
}
pub fn join_words<'a, I: Iterator<Item = &'a str>>(mut words: I) -> String {
match words.next() {
None => String::new(),
Some(word) => {
let punctuation: &[char] = &['.', '!', '?'];
let mut sentence = capitalize(word);
let mut needs_cap = sentence.ends_with(punctuation);
for word in words {
sentence.push(' ');
if needs_cap {
sentence.push_str(&capitalize(word));
} else {
sentence.push_str(word);
}
needs_cap = word.ends_with(punctuation);
}
if !sentence.ends_with(punctuation) {
let idx = sentence.trim_end_matches(is_ascii_punctuation).len();
sentence.truncate(idx);
sentence.push('.');
}
sentence
}
}
}
#[cfg(test)]
mod tests {
    use super::super::gobbledygook::GobbledyGook;
    use super::*;

    /// Learning from a path that cannot be opened reports an error.
    #[test]
    fn test_load_error() {
        let missing = ["/does-not-exist".to_string()];
        assert!(WurstsalatGeneratorPro::learn_from_files(&missing).is_err());
    }

    /// Learning from an existing file succeeds.
    /// NOTE(review): the relative path presumably resolves against the crate
    /// root when run under `cargo test` — confirm.
    #[test]
    fn test_load_ok() {
        let readme = ["README.md".to_string()];
        assert!(WurstsalatGeneratorPro::learn_from_files(&readme).is_ok());
    }

    /// With a fixed random source, the first generated word is stable.
    #[test]
    fn test_generate() {
        let corpus = ["tests/data/lorem-ipsum.txt".to_string()];
        let wurstsalat = WurstsalatGeneratorPro::learn_from_files(&corpus).unwrap();

        let mut rng = GobbledyGook::for_url("/test");
        let output = join_words(wurstsalat.generate(&mut rng).take(1));

        assert_eq!(output, "Voluptate.");
    }
}