Skip to main content

markovify_rs/
chain.rs

1//! Markov chain implementation
2
3use crate::errors::{MarkovError, Result};
4use fxhash::FxHashMap;
5use rand::Rng;
6use serde::{Deserialize, Serialize};
7
/// Sentinel item used to pad the front of every run; a state consisting
/// solely of this item is where generation starts by default.
pub const BEGIN: &str = "___BEGIN__";
/// Sentinel item appended to every run; sampling it ends generation.
pub const END: &str = "___END__";
12
13/// Cumulative frequency data for efficient random selection
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct CompiledNext {
16    pub words: Vec<String>,
17    pub cumulative_weights: Vec<usize>,
18}
19
20/// A Markov chain representing processes that have both beginnings and ends.
21///
22/// The chain is represented as a HashMap where keys are states (tuples of words)
23/// and values are HashMaps of possible next words with their frequencies.
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct Chain {
26    state_size: usize,
27    model: FxHashMap<Vec<String>, FxHashMap<String, usize>>,
28    compiled: bool,
29    #[serde(skip)]
30    compiled_model: FxHashMap<Vec<String>, CompiledNext>,
31    #[serde(skip)]
32    begin_choices: Option<Vec<String>>,
33    #[serde(skip)]
34    begin_cumdist: Option<Vec<usize>>,
35}
36
37impl Chain {
38    /// Create a new Markov chain from a corpus
39    ///
40    /// # Arguments
41    /// * `corpus` - A list of runs, where each run is a sequence of items (e.g., words in a sentence)
42    /// * `state_size` - The number of items used to represent the chain's state
43    pub fn new(corpus: &[Vec<String>], state_size: usize) -> Self {
44        let model = Self::build(corpus, state_size);
45        let mut chain = Chain {
46            state_size,
47            model,
48            compiled: false,
49            compiled_model: FxHashMap::default(),
50            begin_choices: None,
51            begin_cumdist: None,
52        };
53        chain.precompute_begin_state();
54        chain
55    }
56
57    /// Build the Markov model from a corpus
58    fn build(
59        corpus: &[Vec<String>],
60        state_size: usize,
61    ) -> FxHashMap<Vec<String>, FxHashMap<String, usize>> {
62        let mut model: FxHashMap<Vec<String>, FxHashMap<String, usize>> = FxHashMap::default();
63
64        for run in corpus {
65            let mut items: Vec<String> = vec![BEGIN.to_string(); state_size];
66            items.extend(run.iter().cloned());
67            items.push(END.to_string());
68
69            for i in 0..=run.len() {
70                let state: Vec<String> = items[i..i + state_size].to_vec();
71                let follow = items[i + state_size].clone();
72
73                let next_dict = model.entry(state).or_default();
74                *next_dict.entry(follow).or_insert(0) += 1;
75            }
76        }
77
78        model
79    }
80
81    /// Precompute the beginning state for faster sentence generation
82    fn precompute_begin_state(&mut self) {
83        let begin_state: Vec<String> = vec![BEGIN.to_string(); self.state_size];
84        if let Some(next_dict) = self.model.get(&begin_state) {
85            let (choices, cumdist) = Self::compile_next_dict(next_dict);
86            self.begin_choices = Some(choices);
87            self.begin_cumdist = Some(cumdist);
88        }
89    }
90
91    /// Compile a next dictionary for efficient random selection
92    fn compile_next_dict(next_dict: &FxHashMap<String, usize>) -> (Vec<String>, Vec<usize>) {
93        let mut words = Vec::with_capacity(next_dict.len());
94        let mut cumulative_weights = Vec::with_capacity(next_dict.len());
95        let mut cumsum = 0;
96
97        for (word, &weight) in next_dict.iter() {
98            words.push(word.clone());
99            cumsum += weight;
100            cumulative_weights.push(cumsum);
101        }
102
103        (words, cumulative_weights)
104    }
105
106    /// Compile the chain for faster generation
107    ///
108    /// This converts the frequency dictionaries into cumulative frequency lists
109    /// for more efficient random selection.
110    pub fn compile(&self) -> Self {
111        let mut compiled_model: FxHashMap<Vec<String>, CompiledNext> = FxHashMap::default();
112
113        for (state, next_dict) in &self.model {
114            let (words, cumulative_weights) = Self::compile_next_dict(next_dict);
115            compiled_model.insert(
116                state.clone(),
117                CompiledNext {
118                    words,
119                    cumulative_weights,
120                },
121            );
122        }
123
124        Chain {
125            state_size: self.state_size,
126            model: self.model.clone(),
127            compiled: true,
128            compiled_model,
129            begin_choices: self.begin_choices.clone(),
130            begin_cumdist: self.begin_cumdist.clone(),
131        }
132    }
133
134    /// Choose the next item given the current state
135    fn move_state(&self, state: &[String]) -> Option<String> {
136        let (choices, cumdist) = if self.compiled {
137            if let Some(compiled) = self.compiled_model.get(state) {
138                (&compiled.words, &compiled.cumulative_weights)
139            } else {
140                return None;
141            }
142        } else if state.iter().all(|s| s == BEGIN) {
143            if let (Some(choices), Some(cumdist)) = (&self.begin_choices, &self.begin_cumdist) {
144                (choices, cumdist)
145            } else {
146                return None;
147            }
148        } else {
149            if let Some(next_dict) = self.model.get(state) {
150                let (choices, cumdist) = Self::compile_next_dict(next_dict);
151                // For uncompiled chains, we compute on the fly
152                return Self::select_random(&choices, &cumdist);
153            } else {
154                return None;
155            }
156        };
157
158        if cumdist.is_empty() {
159            return None;
160        }
161
162        Self::select_random(choices, cumdist)
163    }
164
165    /// Select a random item based on cumulative weights
166    fn select_random(choices: &[String], cumdist: &[usize]) -> Option<String> {
167        if cumdist.is_empty() {
168            return None;
169        }
170
171        let mut rng = rand::thread_rng();
172        let r = rng.gen_range(0..cumdist[cumdist.len() - 1]);
173
174        // Binary search for the selection
175        let idx = cumdist.partition_point(|&x| x <= r);
176
177        if idx < choices.len() {
178            Some(choices[idx].clone())
179        } else {
180            Some(choices[choices.len() - 1].clone())
181        }
182    }
183
184    /// Generate items from the chain
185    ///
186    /// Returns an iterator that yields successive items until the END state is reached.
187    pub fn gen(&self, init_state: Option<&[String]>) -> ChainGenerator<'_> {
188        let state = init_state
189            .map(|s| s.to_vec())
190            .unwrap_or_else(|| vec![BEGIN.to_string(); self.state_size]);
191
192        ChainGenerator {
193            chain: self,
194            state,
195            done: false,
196        }
197    }
198
199    /// Walk the chain and return a complete sequence
200    ///
201    /// Returns a vector representing a single run of the Markov model.
202    pub fn walk(&self, init_state: Option<&[String]>) -> Vec<String> {
203        self.gen(init_state).collect()
204    }
205
206    /// Get the state size
207    pub fn state_size(&self) -> usize {
208        self.state_size
209    }
210
211    /// Get the model (for inspection or combination)
212    pub fn model(&self) -> &FxHashMap<Vec<String>, FxHashMap<String, usize>> {
213        &self.model
214    }
215
216    /// Check if the chain is compiled
217    pub fn is_compiled(&self) -> bool {
218        self.compiled
219    }
220
221    /// Serialize the chain to JSON
222    pub fn to_json(&self) -> Result<String> {
223        let items: Vec<(Vec<String>, FxHashMap<String, usize>)> = self
224            .model
225            .iter()
226            .map(|(k, v)| (k.clone(), v.clone()))
227            .collect();
228        Ok(serde_json::to_string(&items)?)
229    }
230
231    /// Deserialize a chain from JSON
232    pub fn from_json(json_str: &str) -> Result<Self> {
233        let items: Vec<(Vec<String>, FxHashMap<String, usize>)> = serde_json::from_str(json_str)?;
234
235        if items.is_empty() {
236            return Err(MarkovError::ModelFormatError("Empty model".to_string()));
237        }
238
239        let state_size = items[0].0.len();
240        let model: FxHashMap<Vec<String>, FxHashMap<String, usize>> = items.into_iter().collect();
241
242        let mut chain = Chain {
243            state_size,
244            model,
245            compiled: false,
246            compiled_model: FxHashMap::default(),
247            begin_choices: None,
248            begin_cumdist: None,
249        };
250        chain.precompute_begin_state();
251        Ok(chain)
252    }
253
254    /// Create a chain from a pre-built model (used for combining models)
255    pub fn from_combined_model(
256        model: FxHashMap<Vec<String>, FxHashMap<String, usize>>,
257        state_size: usize,
258    ) -> Self {
259        let mut chain = Chain {
260            state_size,
261            model,
262            compiled: false,
263            compiled_model: FxHashMap::default(),
264            begin_choices: None,
265            begin_cumdist: None,
266        };
267        chain.precompute_begin_state();
268        chain
269    }
270}
271
272/// Iterator for generating sequences from a Markov chain
273pub struct ChainGenerator<'a> {
274    chain: &'a Chain,
275    state: Vec<String>,
276    done: bool,
277}
278
279impl<'a> Iterator for ChainGenerator<'a> {
280    type Item = String;
281
282    fn next(&mut self) -> Option<Self::Item> {
283        if self.done {
284            return None;
285        }
286
287        if let Some(next_word) = self.chain.move_state(&self.state) {
288            if next_word == END {
289                self.done = true;
290                return None;
291            }
292
293            // Update state
294            self.state.remove(0);
295            self.state.push(next_word.clone());
296            Some(next_word)
297        } else {
298            self.done = true;
299            None
300        }
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned run from string literals.
    fn words(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn test_chain_creation() {
        let corpus = vec![words(&["hello", "world"]), words(&["hello", "rust"])];
        let chain = Chain::new(&corpus, 1);
        assert_eq!(chain.state_size(), 1);

        // States: [BEGIN], [hello], [world], [rust].
        assert_eq!(chain.model().len(), 4);

        // Both runs start with "hello", so the begin state counted it twice.
        let begin = chain
            .model()
            .get(&words(&[BEGIN]))
            .expect("begin state present");
        assert_eq!(begin.get("hello"), Some(&2));

        let after_hello = chain
            .model()
            .get(&words(&["hello"]))
            .expect("state [hello] present");
        assert_eq!(after_hello.get("world"), Some(&1));
        assert_eq!(after_hello.get("rust"), Some(&1));
    }

    #[test]
    fn test_chain_walk() {
        let sentence = words(&["the", "cat", "sat"]);
        let chain = Chain::new(&[sentence.clone()], 1);
        // A single run with no repeated items has exactly one path.
        assert_eq!(chain.walk(None), sentence);
    }

    #[test]
    fn test_chain_walk_with_init_state() {
        let chain = Chain::new(&[words(&["the", "cat", "sat"])], 1);
        assert_eq!(chain.walk(Some(&words(&["cat"]))), words(&["sat"]));
    }

    #[test]
    fn test_compiled_chain_walk() {
        let sentence = words(&["the", "cat", "sat"]);
        let chain = Chain::new(&[sentence.clone()], 2).compile();
        assert!(chain.is_compiled());
        assert_eq!(chain.walk(None), sentence);
    }

    #[test]
    fn test_chain_json_serialization() {
        let corpus = vec![words(&["hello", "world"])];
        let chain = Chain::new(&corpus, 1);
        let json = chain.to_json().unwrap();
        let restored = Chain::from_json(&json).unwrap();
        assert_eq!(chain.state_size(), restored.state_size());
        assert_eq!(chain.model(), restored.model());
        assert_eq!(restored.walk(None), words(&["hello", "world"]));
    }

    #[test]
    fn test_from_json_rejects_empty_model() {
        assert!(Chain::from_json("[]").is_err());
    }
}