1#[cfg(test)]
2mod tests;
3
4use std::collections::HashMap;
5
6#[cfg(feature = "serialization")]
7use std::fmt;
8
9use rand::seq::SliceRandom;
10
11#[cfg(feature = "serialization")]
12use serde::{de::Visitor, Deserialize, Deserializer, Serialize, Serializer};
13
/// Token written in place of a `None` ("no word") slot when a [`ChainKey`] is
/// serialized to a string.
#[cfg(feature = "serialization")]
const KEY_NO_WORD: &str = "\n";
/// Separator placed between the slots of a serialized [`ChainKey`].
#[cfg(feature = "serialization")]
const KEY_SEPARATOR: &str = " ";

/// The state of a Markov chain: the most recent words, oldest first.
///
/// A `None` slot stands for "no word", e.g. every slot of a key built with
/// [`ChainKey::blank`].
#[derive(Debug, Hash, Clone, PartialEq, Eq)]
pub struct ChainKey(Vec<Option<String>>);

impl ChainKey {
    /// Creates a key of length `order` whose slots are all `None`.
    pub fn blank(order: usize) -> Self {
        ChainKey(vec![None; order])
    }

    /// Wraps an existing word sequence (oldest first) as a key.
    pub fn from_vec(vec: Vec<Option<String>>) -> Self {
        ChainKey(vec)
    }

    /// Consumes the key and returns its underlying word sequence.
    pub fn to_vec(self) -> Vec<Option<String>> {
        self.0
    }

    /// Shifts the key by one word: drops the oldest slot and appends `next_word`.
    ///
    /// The key keeps its length. A zero-length key (as used by an order-0 chain)
    /// has nothing to shift and is left unchanged instead of panicking.
    pub fn advance(&mut self, next_word: &Option<String>) {
        if self.0.is_empty() {
            return;
        }
        self.0.remove(0);
        self.0.push(next_word.clone());
    }

    /// Encodes the key as one string: slots joined by `KEY_SEPARATOR`, with
    /// `None` slots written as `KEY_NO_WORD`.
    #[cfg(feature = "serialization")]
    fn to_string(&self) -> String {
        self.0
            .iter()
            .map(|word| word.as_deref().unwrap_or(KEY_NO_WORD))
            .collect::<Vec<&str>>()
            .join(KEY_SEPARATOR)
    }

    /// Decodes a string produced by `to_string` back into a key.
    #[cfg(feature = "serialization")]
    fn from_str(string: &str) -> Self {
        // `to_string` renders an empty (order-0) key as "". `split` would turn
        // that into a single empty word, so handle it explicitly to keep the
        // round trip lossless.
        if string.is_empty() {
            return ChainKey(Vec::new());
        }

        ChainKey(
            string
                .split(KEY_SEPARATOR)
                .map(|word| {
                    if word == KEY_NO_WORD {
                        None
                    } else {
                        Some(word.to_string())
                    }
                })
                .collect(),
        )
    }
}
80
/// Serializes a key as a single string, so that maps keyed by [`ChainKey`]
/// can be written by formats that require string map keys (such as JSON).
#[cfg(feature = "serialization")]
impl Serialize for ChainKey {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let encoded = self.to_string();
        serializer.serialize_str(&encoded)
    }
}
89}
90
/// serde visitor that rebuilds a [`ChainKey`] from its string encoding.
#[cfg(feature = "serialization")]
struct ChainKeyVisitor;

#[cfg(feature = "serialization")]
impl<'de> Visitor<'de> for ChainKeyVisitor {
    type Value = ChainKey;

    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("a string")
    }

    fn visit_str<E: serde::de::Error>(self, encoded: &str) -> Result<ChainKey, E> {
        let key = ChainKey::from_str(encoded);
        Ok(key)
    }
}

/// Reads a key back from the single-string form written by its `Serialize` impl.
#[cfg(feature = "serialization")]
impl<'de> Deserialize<'de> for ChainKey {
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let visitor = ChainKeyVisitor;
        deserializer.deserialize_str(visitor)
    }
}
118}
119
120#[derive(Clone, Debug)]
122#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
123pub struct Chain {
124 map: HashMap<ChainKey, Vec<Option<String>>>,
126 order: usize,
127}
128
129impl Chain {
130 pub fn new(order: usize) -> Self {
131 Chain {
132 map: HashMap::new(),
133 order,
134 }
135 }
136
137 pub fn train(&mut self, string: &str) {
138 let mut words = vec![None; self.order];
141
142 for word in string.split_whitespace() {
143 words.push(Some(word.to_string()));
144 }
145
146 words.push(None);
147
148 for window in words.windows(self.order + 1) {
151 let key = &window[..self.order];
152 let word = &window[self.order];
153
154 let map_entry = self
155 .map
156 .entry(ChainKey::from_vec(key.to_vec()))
157 .or_insert(Vec::new());
158 map_entry.push(word.clone());
159 }
160 }
161
162 pub fn generate(&self) -> Option<String> {
164 let seed = ChainKey::blank(self.order);
167
168 self.generate_from_seed(&seed)
169 }
170
171 pub fn generate_from_seed(&self, seed: &ChainKey) -> Option<String> {
175 if !self.map.contains_key(seed) {
176 return None;
177 }
178
179 let mut rng = rand::thread_rng();
180 let mut result: Vec<String> = Vec::new();
181
182 let mut cursor = seed.clone();
183
184 loop {
185 let possible_words = &self.map[&cursor];
186
187 let next_word = possible_words.choose(&mut rng).unwrap();
190
191 if let Some(next_word) = next_word {
192 result.push(next_word.clone());
193 } else {
194 break;
196 }
197
198 cursor.advance(next_word);
201 }
202
203 Some(result.join(" "))
204 }
205
206 #[cfg(feature = "serialization")]
208 pub fn to_json(&self) -> serde_json::Result<String> {
209 serde_json::to_string(self)
210 }
211
212 #[cfg(feature = "serialization")]
214 pub fn from_json(json: &str) -> serde_json::Result<Self> {
215 serde_json::from_str(json)
216 }
217}