1use crate::errors::{MarkovError, Result};
4use fxhash::FxHashMap;
5use rand::Rng;
6use serde::{Deserialize, Serialize};
7
/// Sentinel word used to left-pad every run so the first real word has a full-size state.
pub const BEGIN: &str = "___BEGIN__";
/// Sentinel word appended to every run; generation stops when it is drawn.
pub const END: &str = "___END__";
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct CompiledNext {
16 pub words: Vec<String>,
17 pub cumulative_weights: Vec<usize>,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct Chain {
26 state_size: usize,
27 model: FxHashMap<Vec<String>, FxHashMap<String, usize>>,
28 compiled: bool,
29 #[serde(skip)]
30 compiled_model: FxHashMap<Vec<String>, CompiledNext>,
31 #[serde(skip)]
32 begin_choices: Option<Vec<String>>,
33 #[serde(skip)]
34 begin_cumdist: Option<Vec<usize>>,
35}
36
37impl Chain {
38 pub fn new(corpus: &[Vec<String>], state_size: usize) -> Self {
44 let model = Self::build(corpus, state_size);
45 let mut chain = Chain {
46 state_size,
47 model,
48 compiled: false,
49 compiled_model: FxHashMap::default(),
50 begin_choices: None,
51 begin_cumdist: None,
52 };
53 chain.precompute_begin_state();
54 chain
55 }
56
57 fn build(
59 corpus: &[Vec<String>],
60 state_size: usize,
61 ) -> FxHashMap<Vec<String>, FxHashMap<String, usize>> {
62 let mut model: FxHashMap<Vec<String>, FxHashMap<String, usize>> = FxHashMap::default();
63
64 for run in corpus {
65 let mut items: Vec<String> = vec![BEGIN.to_string(); state_size];
66 items.extend(run.iter().cloned());
67 items.push(END.to_string());
68
69 for i in 0..=run.len() {
70 let state: Vec<String> = items[i..i + state_size].to_vec();
71 let follow = items[i + state_size].clone();
72
73 let next_dict = model.entry(state).or_default();
74 *next_dict.entry(follow).or_insert(0) += 1;
75 }
76 }
77
78 model
79 }
80
81 fn precompute_begin_state(&mut self) {
83 let begin_state: Vec<String> = vec![BEGIN.to_string(); self.state_size];
84 if let Some(next_dict) = self.model.get(&begin_state) {
85 let (choices, cumdist) = Self::compile_next_dict(next_dict);
86 self.begin_choices = Some(choices);
87 self.begin_cumdist = Some(cumdist);
88 }
89 }
90
91 fn compile_next_dict(next_dict: &FxHashMap<String, usize>) -> (Vec<String>, Vec<usize>) {
93 let mut words = Vec::with_capacity(next_dict.len());
94 let mut cumulative_weights = Vec::with_capacity(next_dict.len());
95 let mut cumsum = 0;
96
97 for (word, &weight) in next_dict.iter() {
98 words.push(word.clone());
99 cumsum += weight;
100 cumulative_weights.push(cumsum);
101 }
102
103 (words, cumulative_weights)
104 }
105
106 pub fn compile(&self) -> Self {
111 let mut compiled_model: FxHashMap<Vec<String>, CompiledNext> = FxHashMap::default();
112
113 for (state, next_dict) in &self.model {
114 let (words, cumulative_weights) = Self::compile_next_dict(next_dict);
115 compiled_model.insert(
116 state.clone(),
117 CompiledNext {
118 words,
119 cumulative_weights,
120 },
121 );
122 }
123
124 Chain {
125 state_size: self.state_size,
126 model: self.model.clone(),
127 compiled: true,
128 compiled_model,
129 begin_choices: self.begin_choices.clone(),
130 begin_cumdist: self.begin_cumdist.clone(),
131 }
132 }
133
134 fn move_state(&self, state: &[String]) -> Option<String> {
136 let (choices, cumdist) = if self.compiled {
137 if let Some(compiled) = self.compiled_model.get(state) {
138 (&compiled.words, &compiled.cumulative_weights)
139 } else {
140 return None;
141 }
142 } else if state.iter().all(|s| s == BEGIN) {
143 if let (Some(choices), Some(cumdist)) = (&self.begin_choices, &self.begin_cumdist) {
144 (choices, cumdist)
145 } else {
146 return None;
147 }
148 } else {
149 if let Some(next_dict) = self.model.get(state) {
150 let (choices, cumdist) = Self::compile_next_dict(next_dict);
151 return Self::select_random(&choices, &cumdist);
153 } else {
154 return None;
155 }
156 };
157
158 if cumdist.is_empty() {
159 return None;
160 }
161
162 Self::select_random(choices, cumdist)
163 }
164
165 fn select_random(choices: &[String], cumdist: &[usize]) -> Option<String> {
167 if cumdist.is_empty() {
168 return None;
169 }
170
171 let mut rng = rand::thread_rng();
172 let r = rng.gen_range(0..cumdist[cumdist.len() - 1]);
173
174 let idx = cumdist.partition_point(|&x| x <= r);
176
177 if idx < choices.len() {
178 Some(choices[idx].clone())
179 } else {
180 Some(choices[choices.len() - 1].clone())
181 }
182 }
183
184 pub fn gen(&self, init_state: Option<&[String]>) -> ChainGenerator<'_> {
188 let state = init_state
189 .map(|s| s.to_vec())
190 .unwrap_or_else(|| vec![BEGIN.to_string(); self.state_size]);
191
192 ChainGenerator {
193 chain: self,
194 state,
195 done: false,
196 }
197 }
198
199 pub fn walk(&self, init_state: Option<&[String]>) -> Vec<String> {
203 self.gen(init_state).collect()
204 }
205
206 pub fn state_size(&self) -> usize {
208 self.state_size
209 }
210
211 pub fn model(&self) -> &FxHashMap<Vec<String>, FxHashMap<String, usize>> {
213 &self.model
214 }
215
216 pub fn is_compiled(&self) -> bool {
218 self.compiled
219 }
220
221 pub fn to_json(&self) -> Result<String> {
223 let items: Vec<(Vec<String>, FxHashMap<String, usize>)> = self
224 .model
225 .iter()
226 .map(|(k, v)| (k.clone(), v.clone()))
227 .collect();
228 Ok(serde_json::to_string(&items)?)
229 }
230
231 pub fn from_json(json_str: &str) -> Result<Self> {
233 let items: Vec<(Vec<String>, FxHashMap<String, usize>)> = serde_json::from_str(json_str)?;
234
235 if items.is_empty() {
236 return Err(MarkovError::ModelFormatError("Empty model".to_string()));
237 }
238
239 let state_size = items[0].0.len();
240 let model: FxHashMap<Vec<String>, FxHashMap<String, usize>> = items.into_iter().collect();
241
242 let mut chain = Chain {
243 state_size,
244 model,
245 compiled: false,
246 compiled_model: FxHashMap::default(),
247 begin_choices: None,
248 begin_cumdist: None,
249 };
250 chain.precompute_begin_state();
251 Ok(chain)
252 }
253
254 pub fn from_combined_model(
256 model: FxHashMap<Vec<String>, FxHashMap<String, usize>>,
257 state_size: usize,
258 ) -> Self {
259 let mut chain = Chain {
260 state_size,
261 model,
262 compiled: false,
263 compiled_model: FxHashMap::default(),
264 begin_choices: None,
265 begin_cumdist: None,
266 };
267 chain.precompute_begin_state();
268 chain
269 }
270}
271
272pub struct ChainGenerator<'a> {
274 chain: &'a Chain,
275 state: Vec<String>,
276 done: bool,
277}
278
279impl<'a> Iterator for ChainGenerator<'a> {
280 type Item = String;
281
282 fn next(&mut self) -> Option<Self::Item> {
283 if self.done {
284 return None;
285 }
286
287 if let Some(next_word) = self.chain.move_state(&self.state) {
288 if next_word == END {
289 self.done = true;
290 return None;
291 }
292
293 self.state.remove(0);
295 self.state.push(next_word.clone());
296 Some(next_word)
297 } else {
298 self.done = true;
299 None
300 }
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned words a corpus run expects.
    fn words(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|w| w.to_string()).collect()
    }

    #[test]
    fn test_chain_creation() {
        let corpus = vec![words(&["hello", "world"]), words(&["hello", "rust"])];
        let chain = Chain::new(&corpus, 1);
        assert_eq!(chain.state_size(), 1);
    }

    #[test]
    fn test_chain_walk() {
        let corpus = vec![words(&["the", "cat", "sat"])];
        let generated = Chain::new(&corpus, 1).walk(None);
        assert!(!generated.is_empty());
    }

    #[test]
    fn test_chain_json_serialization() {
        let corpus = vec![words(&["hello", "world"])];
        let original = Chain::new(&corpus, 1);
        let encoded = original.to_json().unwrap();
        let decoded = Chain::from_json(&encoded).unwrap();
        assert_eq!(original.state_size(), decoded.state_size());
    }
}