// narrative_engine/core/markov.rs
//! Markov chain phrase generator — training, serialization, and generation.
2use rand::distributions::WeightedIndex;
3use rand::prelude::Distribution;
4use rand::rngs::StdRng;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use thiserror::Error;
8
/// Transition table mapping n-gram prefixes to weighted next-token options.
///
/// Key: the `n - 1` tokens preceding a position; value: every observed next
/// token paired with the number of times it followed that prefix in training.
type TransitionTable = HashMap<Vec<String>, Vec<(String, u32)>>;
11
/// Errors returned by Markov model training, generation, and file I/O.
#[derive(Debug, Error)]
pub enum MarkovError {
    /// The model (or the requested tag's table) has no transitions to sample.
    #[error("no data for generation (model is empty or tag has no data)")]
    NoData,
    /// Generation produced no tokens before hitting a dead end or iteration limit.
    #[error("no sentence start found")]
    NoSentenceStart,
    /// Underlying filesystem error while saving or loading a model.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// The model file could not be parsed as RON.
    #[error("RON deserialization error: {0}")]
    Ron(#[from] ron::error::SpannedError),
}
23
/// Special token marking sentence start (used to pad n-gram state windows).
const SENTENCE_START: &str = "<S>";
/// Special token marking sentence end (appended after each training sentence).
const SENTENCE_END: &str = "</S>";

/// Punctuation that terminates a sentence (subset of `PUNCTUATION`).
const SENTENCE_ENDERS: &[char] = &['.', '!', '?'];
/// Punctuation characters that are tokenized as separate single-char tokens.
const PUNCTUATION: &[char] = &['.', '!', '?', ',', ';', ':', '"', '\''];
32
/// A trained Markov model storing n-gram probability tables.
///
/// Serializable with serde so models can round-trip through RON files via
/// [`save_model`] / [`load_model`].
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct MarkovModel {
    /// N-gram depth (e.g., 2 for bigrams, 3 for trigrams).
    pub n: usize,
    /// Transition table: n-gram prefix → [(next_token, count)].
    pub transitions: TransitionTable,
    /// Tag-specific transition tables, keyed by tag names from the corpus.
    pub tagged_transitions: HashMap<String, TransitionTable>,
}
43
44impl MarkovModel {
45    /// Generate text from this model.
46    ///
47    /// Starts from a sentence-start state, walks the chain selecting next
48    /// tokens by weighted probability, and stops at a sentence boundary
49    /// within the word count range.
50    pub fn generate(
51        &self,
52        rng: &mut StdRng,
53        tag: Option<&str>,
54        min_words: usize,
55        max_words: usize,
56    ) -> Result<String, MarkovError> {
57        let transitions = if let Some(tag) = tag {
58            self.tagged_transitions
59                .get(tag)
60                .ok_or(MarkovError::NoData)?
61        } else {
62            &self.transitions
63        };
64
65        if transitions.is_empty() {
66            return Err(MarkovError::NoData);
67        }
68
69        let mut result_tokens: Vec<String> = Vec::new();
70        let mut state: Vec<String> = vec![SENTENCE_START.to_string(); self.n - 1];
71        let mut word_count = 0;
72        let mut last_sentence_end = 0;
73
74        for _ in 0..(max_words * 3) {
75            // safety limit on iterations
76            let next = match pick_next(transitions, &state, rng) {
77                Some(tok) => tok,
78                None => break,
79            };
80
81            if next == SENTENCE_END {
82                // Record sentence boundary position
83                last_sentence_end = result_tokens.len();
84
85                if word_count >= min_words {
86                    break;
87                }
88
89                // Start a new sentence
90                state = vec![SENTENCE_START.to_string(); self.n - 1];
91                continue;
92            }
93
94            // Count actual words (not punctuation)
95            if !PUNCTUATION.contains(&next.chars().next().unwrap_or(' ')) {
96                word_count += 1;
97            }
98
99            result_tokens.push(next.clone());
100
101            // Slide state window
102            state.push(next);
103            if state.len() > self.n - 1 {
104                state.remove(0);
105            }
106
107            if word_count >= max_words {
108                // Truncate at last complete sentence
109                if last_sentence_end > 0 {
110                    result_tokens.truncate(last_sentence_end);
111                }
112                break;
113            }
114        }
115
116        if result_tokens.is_empty() {
117            return Err(MarkovError::NoSentenceStart);
118        }
119
120        Ok(reassemble_tokens(&result_tokens))
121    }
122}
123
124/// Pick the next token from transitions given a state prefix.
125fn pick_next(transitions: &TransitionTable, state: &[String], rng: &mut StdRng) -> Option<String> {
126    let options = transitions.get(state)?;
127    if options.is_empty() {
128        return None;
129    }
130
131    let weights: Vec<u32> = options.iter().map(|(_, count)| *count).collect();
132    let dist = WeightedIndex::new(&weights).ok()?;
133    Some(options[dist.sample(rng)].0.clone())
134}
135
136/// Reassemble tokens into natural text (attach punctuation to previous word).
137fn reassemble_tokens(tokens: &[String]) -> String {
138    let mut result = String::new();
139    for (i, tok) in tokens.iter().enumerate() {
140        let is_punct = tok.len() == 1 && PUNCTUATION.contains(&tok.chars().next().unwrap());
141        if i > 0 && !is_punct {
142            result.push(' ');
143        }
144        result.push_str(tok);
145    }
146    result
147}
148
/// Trains Markov models from raw text.
///
/// Stateless: `train` is an associated function; the struct exists only as a
/// namespace.
pub struct MarkovTrainer;
151
152impl MarkovTrainer {
153    /// Train a Markov model from raw text with the given n-gram depth.
154    ///
155    /// Supports tagged regions: lines prefixed with `[tag]` apply that tag
156    /// to subsequent text until the next tag or end of file.
157    pub fn train(text: &str, n: usize) -> MarkovModel {
158        assert!((2..=4).contains(&n), "n-gram depth must be 2-4");
159
160        let mut transitions: TransitionTable = HashMap::new();
161        let mut tagged_transitions: HashMap<String, TransitionTable> = HashMap::new();
162
163        let mut current_tag: Option<String> = None;
164
165        for line in text.lines() {
166            let trimmed = line.trim();
167
168            // Check for tag markers: [tagname]
169            if trimmed.starts_with('[') && trimmed.ends_with(']') && trimmed.len() > 2 {
170                let tag = &trimmed[1..trimmed.len() - 1];
171                current_tag = Some(tag.to_string());
172                continue;
173            }
174
175            if trimmed.is_empty() {
176                continue;
177            }
178
179            let tokens = tokenize(trimmed);
180            let sentences = split_into_sentences(&tokens);
181
182            for sentence in &sentences {
183                // Build n-gram chain for this sentence
184                let mut padded = vec![SENTENCE_START.to_string(); n - 1];
185                padded.extend(sentence.iter().cloned());
186                padded.push(SENTENCE_END.to_string());
187
188                for window in padded.windows(n) {
189                    let prefix: Vec<String> = window[..n - 1].to_vec();
190                    let next = window[n - 1].clone();
191
192                    // Add to global transitions
193                    add_transition(&mut transitions, prefix.clone(), next.clone());
194
195                    // Add to tagged transitions if we have a tag
196                    if let Some(ref tag) = current_tag {
197                        let tag_table = tagged_transitions.entry(tag.clone()).or_default();
198                        add_transition(tag_table, prefix, next);
199                    }
200                }
201            }
202        }
203
204        MarkovModel {
205            n,
206            transitions,
207            tagged_transitions,
208        }
209    }
210}
211
212/// Add a transition to a transition table, incrementing the count.
213fn add_transition(table: &mut TransitionTable, prefix: Vec<String>, next: String) {
214    let entries = table.entry(prefix).or_default();
215    if let Some(entry) = entries.iter_mut().find(|(tok, _)| tok == &next) {
216        entry.1 += 1;
217    } else {
218        entries.push((next, 1));
219    }
220}
221
222/// Tokenize text: split on whitespace, separate punctuation as individual tokens.
223fn tokenize(text: &str) -> Vec<String> {
224    let mut tokens = Vec::new();
225    for word in text.split_whitespace() {
226        let mut remaining = word;
227        while !remaining.is_empty() {
228            // Check if starts with punctuation
229            let first = remaining.chars().next().unwrap();
230            if PUNCTUATION.contains(&first) {
231                tokens.push(first.to_string());
232                remaining = &remaining[first.len_utf8()..];
233                continue;
234            }
235
236            // Find end of word (before punctuation)
237            if let Some(pos) = remaining.find(|c: char| PUNCTUATION.contains(&c)) {
238                tokens.push(remaining[..pos].to_string());
239                remaining = &remaining[pos..];
240            } else {
241                tokens.push(remaining.to_string());
242                break;
243            }
244        }
245    }
246    tokens
247}
248
249/// Split a token sequence into sentences at sentence-ending punctuation.
250fn split_into_sentences(tokens: &[String]) -> Vec<Vec<String>> {
251    let mut sentences = Vec::new();
252    let mut current = Vec::new();
253
254    for tok in tokens {
255        current.push(tok.clone());
256        if tok.len() == 1
257            && SENTENCE_ENDERS.contains(&tok.chars().next().unwrap())
258            && !current.is_empty()
259        {
260            sentences.push(current.clone());
261            current.clear();
262        }
263    }
264
265    // Don't discard trailing tokens without sentence ender
266    if !current.is_empty() {
267        sentences.push(current);
268    }
269
270    sentences
271}
272
/// Blends output from multiple Markov models with configurable weights.
///
/// Stateless namespace struct; see [`MarkovBlender::generate`].
pub struct MarkovBlender;
275
276impl MarkovBlender {
277    /// Generate text by blending multiple models at each step.
278    pub fn generate(
279        models: &[(&MarkovModel, f32)],
280        rng: &mut StdRng,
281        tag: Option<&str>,
282        min_words: usize,
283        max_words: usize,
284    ) -> Result<String, MarkovError> {
285        if models.is_empty() {
286            return Err(MarkovError::NoData);
287        }
288
289        // All models must have the same n
290        let n = models[0].0.n;
291
292        let mut result_tokens: Vec<String> = Vec::new();
293        let mut state: Vec<String> = vec![SENTENCE_START.to_string(); n - 1];
294        let mut word_count = 0;
295        let mut last_sentence_end = 0;
296
297        for _ in 0..(max_words * 3) {
298            // Blend transition probabilities from all models
299            let next = match pick_next_blended(models, &state, tag, rng) {
300                Some(tok) => tok,
301                None => break,
302            };
303
304            if next == SENTENCE_END {
305                last_sentence_end = result_tokens.len();
306                if word_count >= min_words {
307                    break;
308                }
309                state = vec![SENTENCE_START.to_string(); n - 1];
310                continue;
311            }
312
313            if !PUNCTUATION.contains(&next.chars().next().unwrap_or(' ')) {
314                word_count += 1;
315            }
316
317            result_tokens.push(next.clone());
318            state.push(next);
319            if state.len() > n - 1 {
320                state.remove(0);
321            }
322
323            if word_count >= max_words {
324                if last_sentence_end > 0 {
325                    result_tokens.truncate(last_sentence_end);
326                }
327                break;
328            }
329        }
330
331        if result_tokens.is_empty() {
332            return Err(MarkovError::NoSentenceStart);
333        }
334
335        Ok(reassemble_tokens(&result_tokens))
336    }
337}
338
339/// Pick next token by blending transition probabilities from multiple models.
340fn pick_next_blended(
341    models: &[(&MarkovModel, f32)],
342    state: &[String],
343    tag: Option<&str>,
344    rng: &mut StdRng,
345) -> Option<String> {
346    let mut combined: HashMap<String, f64> = HashMap::new();
347
348    for (model, blend_weight) in models {
349        let transitions = if let Some(tag) = tag {
350            model
351                .tagged_transitions
352                .get(tag)
353                .unwrap_or(&model.transitions)
354        } else {
355            &model.transitions
356        };
357
358        if let Some(options) = transitions.get(state) {
359            let total: u32 = options.iter().map(|(_, c)| c).sum();
360            if total == 0 {
361                continue;
362            }
363            for (tok, count) in options {
364                let prob = (*count as f64) / (total as f64) * (*blend_weight as f64);
365                *combined.entry(tok.clone()).or_default() += prob;
366            }
367        }
368    }
369
370    if combined.is_empty() {
371        return None;
372    }
373
374    let tokens: Vec<String> = combined.keys().cloned().collect();
375    let weights: Vec<f64> = tokens.iter().map(|t| combined[t]).collect();
376    let dist = WeightedIndex::new(&weights).ok()?;
377    Some(tokens[dist.sample(rng)].clone())
378}
379
380/// Save a MarkovModel to a RON file.
381pub fn save_model(model: &MarkovModel, path: &std::path::Path) -> Result<(), MarkovError> {
382    let serialized = ron::ser::to_string_pretty(model, ron::ser::PrettyConfig::default())
383        .map_err(|e| std::io::Error::other(e.to_string()))?;
384    std::fs::write(path, serialized)?;
385    Ok(())
386}
387
388/// Load a MarkovModel from a RON file.
389pub fn load_model(path: &std::path::Path) -> Result<MarkovModel, MarkovError> {
390    let contents = std::fs::read_to_string(path)?;
391    let model: MarkovModel = ron::from_str(&contents)?;
392    Ok(model)
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;

    // Helper: train a bigram model from the shared fixture corpus.
    // NOTE(review): assumes `tests/fixtures/test_corpus.txt` exists relative
    // to the crate root and contains [neutral]/[tense]/[warm] tagged regions.
    fn train_test_corpus() -> MarkovModel {
        let corpus = std::fs::read_to_string("tests/fixtures/test_corpus.txt").unwrap();
        MarkovTrainer::train(&corpus, 2)
    }

    #[test]
    fn tokenize_basic() {
        // Punctuation is split off as standalone single-char tokens.
        let tokens = tokenize("Hello, world.");
        assert_eq!(tokens, vec!["Hello", ",", "world", "."]);
    }

    #[test]
    fn tokenize_complex() {
        // Quotes and mid-word punctuation also become standalone tokens.
        let tokens = tokenize("She said, \"What?\" He replied.");
        assert!(tokens.contains(&"She".to_string()));
        assert!(tokens.contains(&",".to_string()));
        assert!(tokens.contains(&"?".to_string()));
        assert!(tokens.contains(&".".to_string()));
    }

    #[test]
    fn train_creates_transitions() {
        let model = train_test_corpus();
        assert_eq!(model.n, 2);
        assert!(!model.transitions.is_empty());
    }

    #[test]
    fn train_creates_tagged_transitions() {
        // The fixture corpus declares these three tagged regions.
        let model = train_test_corpus();
        assert!(model.tagged_transitions.contains_key("neutral"));
        assert!(model.tagged_transitions.contains_key("tense"));
        assert!(model.tagged_transitions.contains_key("warm"));
    }

    #[test]
    fn generate_deterministic() {
        // The same seed must yield identical output (reproducible generation).
        let model = train_test_corpus();
        let mut rng1 = StdRng::seed_from_u64(42);
        let mut rng2 = StdRng::seed_from_u64(42);

        let result1 = model.generate(&mut rng1, None, 3, 20).unwrap();
        let result2 = model.generate(&mut rng2, None, 3, 20).unwrap();
        assert_eq!(result1, result2);
    }

    #[test]
    fn generate_produces_output() {
        let model = train_test_corpus();
        let mut rng = StdRng::seed_from_u64(42);

        let result = model.generate(&mut rng, None, 3, 20).unwrap();
        assert!(!result.is_empty());
        // min_words = 3, so the reassembled text must contain >= 3 words.
        let word_count = result.split_whitespace().count();
        assert!(
            word_count >= 3,
            "Expected at least 3 words, got: {}",
            word_count
        );
    }

    #[test]
    fn generate_respects_sentence_boundaries() {
        let model = train_test_corpus();
        let mut rng = StdRng::seed_from_u64(42);

        let result = model.generate(&mut rng, None, 3, 20).unwrap();
        // Result should end with sentence-ending punctuation or the last token
        // (a mid-sentence fragment is allowed when no boundary was reached).
        let trimmed = result.trim();
        let last_char = trimmed.chars().last().unwrap();
        assert!(
            SENTENCE_ENDERS.contains(&last_char) || last_char.is_alphanumeric(),
            "Expected sentence boundary or word end, got: '{}'",
            last_char
        );
    }

    #[test]
    fn generate_with_tag() {
        let model = train_test_corpus();
        let mut rng = StdRng::seed_from_u64(42);

        let result = model.generate(&mut rng, Some("tense"), 3, 20).unwrap();
        assert!(!result.is_empty());
    }

    #[test]
    fn tag_filtering_changes_output() {
        let model = train_test_corpus();

        // Generate multiple outputs with different tags and check they differ
        // at least once across many seeds (tags may coincide for some seeds).
        let mut found_different = false;
        for seed in 0..50 {
            let mut rng1 = StdRng::seed_from_u64(seed);
            let mut rng2 = StdRng::seed_from_u64(seed);

            let neutral = model.generate(&mut rng1, Some("neutral"), 3, 15);
            let tense = model.generate(&mut rng2, Some("tense"), 3, 15);

            if let (Ok(n), Ok(t)) = (neutral, tense) {
                if n != t {
                    found_different = true;
                    break;
                }
            }
        }
        assert!(
            found_different,
            "Tagged generation should produce different output"
        );
    }

    #[test]
    fn generate_invalid_tag_returns_error() {
        // An unknown tag has no table, so generation must fail (NoData).
        let model = train_test_corpus();
        let mut rng = StdRng::seed_from_u64(42);

        let result = model.generate(&mut rng, Some("nonexistent_tag"), 3, 20);
        assert!(result.is_err());
    }

    #[test]
    fn ron_round_trip() {
        // Serialize to RON text and back; structure must survive intact.
        let model = train_test_corpus();

        let serialized = ron::to_string(&model).unwrap();
        let deserialized: MarkovModel = ron::from_str(&serialized).unwrap();

        assert_eq!(deserialized.n, model.n);
        assert_eq!(deserialized.transitions.len(), model.transitions.len());
    }

    #[test]
    fn save_and_load_model() {
        // Same round trip, but through an actual file on disk.
        let model = train_test_corpus();
        let path = std::path::PathBuf::from("target/test_markov_model.ron");

        save_model(&model, &path).unwrap();
        let loaded = load_model(&path).unwrap();

        assert_eq!(loaded.n, model.n);
        assert_eq!(loaded.transitions.len(), model.transitions.len());

        // Cleanup
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn blending_produces_output() {
        // A single model with weight 1.0 is the degenerate blend case.
        let model = train_test_corpus();
        let mut rng = StdRng::seed_from_u64(42);

        let result = MarkovBlender::generate(&[(&model, 1.0)], &mut rng, None, 3, 20).unwrap();
        assert!(!result.is_empty());
    }

    #[test]
    fn trigram_model() {
        // n = 3 exercises the wider state window path.
        let corpus = std::fs::read_to_string("tests/fixtures/test_corpus.txt").unwrap();
        let model = MarkovTrainer::train(&corpus, 3);
        assert_eq!(model.n, 3);

        let mut rng = StdRng::seed_from_u64(42);
        let result = model.generate(&mut rng, None, 3, 20).unwrap();
        assert!(!result.is_empty());
    }

    #[test]
    fn reassemble_attaches_punctuation() {
        let tokens = vec![
            "Hello".to_string(),
            ",".to_string(),
            "world".to_string(),
            ".".to_string(),
        ];
        let result = reassemble_tokens(&tokens);
        assert_eq!(result, "Hello, world.");
    }
}