#[cfg(test)]
mod tests;
use std::collections::HashMap;
#[cfg(feature = "serialization")]
use std::fmt;
use rand::seq::SliceRandom;
#[cfg(feature = "serialization")]
use serde::{de::Visitor, Deserialize, Deserializer, Serialize, Serializer};
/// Token written in place of an empty (`None`) slot when a `ChainKey` is
/// encoded as a string, and recognised as such when decoding.
///
/// It is a whitespace character; `Chain::train` splits its input on
/// whitespace, so a trained word can never be equal to this token.
#[cfg(feature = "serialization")]
const KEY_NO_WORD: &str = "\n";
/// Delimiter placed between consecutive slots when a `ChainKey` is encoded
/// as a string, and split on when decoding.
///
/// NOTE(review): a word that itself contains this delimiter (possible only
/// for keys built by hand through `ChainKey::from_vec`) would not survive a
/// round trip.
#[cfg(feature = "serialization")]
const KEY_SEPARATOR: &str = " ";
/// Fixed-length window of words used as a lookup key.
///
/// Every slot is either `Some(word)` or `None` (a slot without a word).
#[derive(Debug, Hash, Clone, PartialEq, Eq)]
pub struct ChainKey(Vec<Option<String>>);

impl ChainKey {
    /// Creates a key made of `order` empty (`None`) slots.
    pub fn blank(order: usize) -> Self {
        ChainKey(vec![None; order])
    }

    /// Wraps an existing vector of slots without copying it.
    pub fn from_vec(vec: Vec<Option<String>>) -> Self {
        ChainKey(vec)
    }

    /// Consumes the key and returns its slots.
    pub fn to_vec(self) -> Vec<Option<String>> {
        self.0
    }

    /// Slides the window forward by one slot: the oldest slot is dropped and
    /// a copy of `next_word` becomes the newest one. The number of slots
    /// never changes.
    ///
    /// A key with zero slots is left untouched (there is nothing to slide),
    /// instead of panicking.
    pub fn advance(&mut self, next_word: &Option<String>) {
        let slots = &mut self.0;
        if slots.is_empty() {
            return;
        }
        // Rotating in place reuses the existing allocation: the oldest slot
        // ends up in the last position, where it is overwritten.
        slots.rotate_left(1);
        if let Some(newest) = slots.last_mut() {
            *newest = next_word.clone();
        }
    }

    /// Encodes the key as one string: slots joined by `KEY_SEPARATOR`, with
    /// `KEY_NO_WORD` standing in for every `None` slot. A key with zero
    /// slots encodes as the empty string.
    #[cfg(feature = "serialization")]
    fn to_string(&self) -> String {
        self.0
            .iter()
            .map(|slot| slot.as_ref().map_or(KEY_NO_WORD, String::as_str))
            .collect::<Vec<_>>()
            .join(KEY_SEPARATOR)
    }

    /// Decodes a string produced by `to_string` back into a key.
    ///
    /// The empty string decodes to a key with zero slots, mirroring what
    /// `to_string` emits for such a key.
    #[cfg(feature = "serialization")]
    fn from_str(string: &str) -> Self {
        // `split` yields a single empty piece for "", which would otherwise
        // turn a zero-slot key into a one-slot key holding an empty word.
        if string.is_empty() {
            return ChainKey(Vec::new());
        }
        let slots = string
            .split(KEY_SEPARATOR)
            .map(|piece| {
                if piece == KEY_NO_WORD {
                    None
                } else {
                    Some(piece.to_string())
                }
            })
            .collect();
        ChainKey(slots)
    }
}
/// Serializes a key as a single string (see `ChainKey::to_string`) rather
/// than as a sequence of optional words.
// NOTE(review): presumably so that the key can be used as a JSON object key
// when a `Chain` is written with `serde_json` — confirm.
#[cfg(feature = "serialization")]
impl Serialize for ChainKey {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let encoded = self.to_string();
        serializer.serialize_str(&encoded)
    }
}
/// Serde visitor that turns the string form of a key back into a `ChainKey`.
#[cfg(feature = "serialization")]
struct ChainKeyVisitor;

#[cfg(feature = "serialization")]
impl<'de> Visitor<'de> for ChainKeyVisitor {
    type Value = ChainKey;

    /// Describes the input this visitor accepts, for serde's error messages.
    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "a string")
    }

    /// Decodes the visited string with `ChainKey::from_str`; never fails.
    fn visit_str<E: serde::de::Error>(self, value: &str) -> Result<ChainKey, E> {
        let key = ChainKey::from_str(value);
        Ok(key)
    }
}
/// Deserializes a key from its single-string form by asking the deserializer
/// for a string and handing it to `ChainKeyVisitor`.
#[cfg(feature = "serialization")]
impl<'de> Deserialize<'de> for ChainKey {
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let visitor = ChainKeyVisitor;
        deserializer.deserialize_str(visitor)
    }
}
/// Word-level Markov chain: remembers which word followed every run of
/// `order` consecutive words in the training text and replays those
/// transitions at random to produce new text.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
pub struct Chain {
    /// For each key seen during training, every successor observed after it.
    /// Duplicates are kept, so a successor seen more often is listed more
    /// often. A `None` successor marks the end of a training string.
    map: HashMap<ChainKey, Vec<Option<String>>>,
    /// Number of preceding words that make up a key.
    order: usize,
}

impl Chain {
    /// Creates an empty, untrained chain whose keys span `order` words.
    pub fn new(order: usize) -> Self {
        Chain {
            map: HashMap::new(),
            order,
        }
    }

    /// Feeds one piece of text into the chain.
    ///
    /// `string` is split on whitespace. The word list is padded with `order`
    /// leading `None`s (so the first real word follows the blank key) and one
    /// trailing `None` (the end marker); every window of `order + 1` entries
    /// then records "key -> successor".
    pub fn train(&mut self, string: &str) {
        let mut words: Vec<Option<String>> = vec![None; self.order];
        words.extend(
            string
                .split_whitespace()
                .map(|word| Some(word.to_string())),
        );
        words.push(None);

        for window in words.windows(self.order + 1) {
            // Windows are never empty, so this always matches: the last entry
            // is the successor, everything before it is the key.
            if let Some((successor, key)) = window.split_last() {
                self.map
                    .entry(ChainKey::from_vec(key.to_vec()))
                    .or_default()
                    .push(successor.clone());
            }
        }
    }

    /// Generates text starting from the blank key, i.e. from the beginning
    /// of a training string.
    ///
    /// Returns `None` when the blank key is unknown (for example when the
    /// chain has not been trained).
    pub fn generate(&self) -> Option<String> {
        let seed = ChainKey::blank(self.order);
        self.generate_from_seed(&seed)
    }

    /// Generates text by repeatedly picking a random recorded successor of
    /// the current key, starting at `seed`.
    ///
    /// Returns `None` when `seed` is not a known key. Otherwise returns the
    /// generated words joined by single spaces; the words of `seed` itself
    /// are not part of the output.
    ///
    /// Generation stops at an end marker. It also stops — instead of
    /// panicking — when the current key is unknown or has no recorded
    /// successors, which can only happen for a chain whose data did not come
    /// from `train` (e.g. hand-edited JSON).
    pub fn generate_from_seed(&self, seed: &ChainKey) -> Option<String> {
        if !self.map.contains_key(seed) {
            return None;
        }

        let mut rng = rand::thread_rng();
        let mut result: Vec<String> = Vec::new();
        let mut cursor = seed.clone();

        loop {
            let choice = self
                .map
                .get(&cursor)
                .and_then(|successors| successors.choose(&mut rng));
            let next_word = match choice {
                Some(next_word) => next_word,
                // Dead end in inconsistent data: treat it as end of text.
                None => break,
            };
            match next_word {
                Some(word) => result.push(word.clone()),
                // End marker recorded during training.
                None => break,
            }
            cursor.advance(next_word);
        }

        Some(result.join(" "))
    }

    /// Serializes the whole chain to a JSON string.
    #[cfg(feature = "serialization")]
    pub fn to_json(&self) -> serde_json::Result<String> {
        serde_json::to_string(self)
    }

    /// Rebuilds a chain from JSON previously produced by `to_json`.
    #[cfg(feature = "serialization")]
    pub fn from_json(json: &str) -> serde_json::Result<Self> {
        serde_json::from_str(json)
    }
}