1use rand::distributions::WeightedIndex;
3use rand::prelude::Distribution;
4use rand::rngs::StdRng;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use thiserror::Error;
8
/// Maps an (n-1)-token prefix to the tokens observed to follow it,
/// each paired with its occurrence count.
type TransitionTable = HashMap<Vec<String>, Vec<(String, u32)>>;

/// Errors produced by training, generation, and model (de)serialization.
#[derive(Debug, Error)]
pub enum MarkovError {
    // Returned when generation has nothing to sample from: the model is
    // empty, or the requested tag has no recorded transitions.
    #[error("no data for generation (model is empty or tag has no data)")]
    NoData,
    // Returned when the random walk produced zero tokens.
    #[error("no sentence start found")]
    NoSentenceStart,
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    #[error("RON deserialization error: {0}")]
    Ron(#[from] ron::error::SpannedError),
}
23
/// Sentinel token used to pad the start of every training sentence.
const SENTENCE_START: &str = "<S>";
/// Sentinel token appended after every training sentence.
const SENTENCE_END: &str = "</S>";

/// Characters that terminate a sentence when splitting token streams.
const SENTENCE_ENDERS: &[char] = &['.', '!', '?'];
/// Characters emitted as standalone single-character tokens by `tokenize`.
const PUNCTUATION: &[char] = &['.', '!', '?', ',', ';', ':', '"', '\''];
32
/// An n-gram Markov model over word/punctuation tokens.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct MarkovModel {
    /// N-gram depth: prefixes are `n - 1` tokens long (trainer enforces 2-4).
    pub n: usize,
    /// Transitions learned from the whole corpus, regardless of tag.
    pub transitions: TransitionTable,
    /// Per-tag transition tables, keyed by the `[tag]` section headers
    /// seen during training.
    pub tagged_transitions: HashMap<String, TransitionTable>,
}
43
44impl MarkovModel {
45 pub fn generate(
51 &self,
52 rng: &mut StdRng,
53 tag: Option<&str>,
54 min_words: usize,
55 max_words: usize,
56 ) -> Result<String, MarkovError> {
57 let transitions = if let Some(tag) = tag {
58 self.tagged_transitions
59 .get(tag)
60 .ok_or(MarkovError::NoData)?
61 } else {
62 &self.transitions
63 };
64
65 if transitions.is_empty() {
66 return Err(MarkovError::NoData);
67 }
68
69 let mut result_tokens: Vec<String> = Vec::new();
70 let mut state: Vec<String> = vec![SENTENCE_START.to_string(); self.n - 1];
71 let mut word_count = 0;
72 let mut last_sentence_end = 0;
73
74 for _ in 0..(max_words * 3) {
75 let next = match pick_next(transitions, &state, rng) {
77 Some(tok) => tok,
78 None => break,
79 };
80
81 if next == SENTENCE_END {
82 last_sentence_end = result_tokens.len();
84
85 if word_count >= min_words {
86 break;
87 }
88
89 state = vec![SENTENCE_START.to_string(); self.n - 1];
91 continue;
92 }
93
94 if !PUNCTUATION.contains(&next.chars().next().unwrap_or(' ')) {
96 word_count += 1;
97 }
98
99 result_tokens.push(next.clone());
100
101 state.push(next);
103 if state.len() > self.n - 1 {
104 state.remove(0);
105 }
106
107 if word_count >= max_words {
108 if last_sentence_end > 0 {
110 result_tokens.truncate(last_sentence_end);
111 }
112 break;
113 }
114 }
115
116 if result_tokens.is_empty() {
117 return Err(MarkovError::NoSentenceStart);
118 }
119
120 Ok(reassemble_tokens(&result_tokens))
121 }
122}
123
124fn pick_next(transitions: &TransitionTable, state: &[String], rng: &mut StdRng) -> Option<String> {
126 let options = transitions.get(state)?;
127 if options.is_empty() {
128 return None;
129 }
130
131 let weights: Vec<u32> = options.iter().map(|(_, count)| *count).collect();
132 let dist = WeightedIndex::new(&weights).ok()?;
133 Some(options[dist.sample(rng)].0.clone())
134}
135
136fn reassemble_tokens(tokens: &[String]) -> String {
138 let mut result = String::new();
139 for (i, tok) in tokens.iter().enumerate() {
140 let is_punct = tok.len() == 1 && PUNCTUATION.contains(&tok.chars().next().unwrap());
141 if i > 0 && !is_punct {
142 result.push(' ');
143 }
144 result.push_str(tok);
145 }
146 result
147}
148
149pub struct MarkovTrainer;
151
152impl MarkovTrainer {
153 pub fn train(text: &str, n: usize) -> MarkovModel {
158 assert!((2..=4).contains(&n), "n-gram depth must be 2-4");
159
160 let mut transitions: TransitionTable = HashMap::new();
161 let mut tagged_transitions: HashMap<String, TransitionTable> = HashMap::new();
162
163 let mut current_tag: Option<String> = None;
164
165 for line in text.lines() {
166 let trimmed = line.trim();
167
168 if trimmed.starts_with('[') && trimmed.ends_with(']') && trimmed.len() > 2 {
170 let tag = &trimmed[1..trimmed.len() - 1];
171 current_tag = Some(tag.to_string());
172 continue;
173 }
174
175 if trimmed.is_empty() {
176 continue;
177 }
178
179 let tokens = tokenize(trimmed);
180 let sentences = split_into_sentences(&tokens);
181
182 for sentence in &sentences {
183 let mut padded = vec![SENTENCE_START.to_string(); n - 1];
185 padded.extend(sentence.iter().cloned());
186 padded.push(SENTENCE_END.to_string());
187
188 for window in padded.windows(n) {
189 let prefix: Vec<String> = window[..n - 1].to_vec();
190 let next = window[n - 1].clone();
191
192 add_transition(&mut transitions, prefix.clone(), next.clone());
194
195 if let Some(ref tag) = current_tag {
197 let tag_table = tagged_transitions.entry(tag.clone()).or_default();
198 add_transition(tag_table, prefix, next);
199 }
200 }
201 }
202 }
203
204 MarkovModel {
205 n,
206 transitions,
207 tagged_transitions,
208 }
209 }
210}
211
212fn add_transition(table: &mut TransitionTable, prefix: Vec<String>, next: String) {
214 let entries = table.entry(prefix).or_default();
215 if let Some(entry) = entries.iter_mut().find(|(tok, _)| tok == &next) {
216 entry.1 += 1;
217 } else {
218 entries.push((next, 1));
219 }
220}
221
222fn tokenize(text: &str) -> Vec<String> {
224 let mut tokens = Vec::new();
225 for word in text.split_whitespace() {
226 let mut remaining = word;
227 while !remaining.is_empty() {
228 let first = remaining.chars().next().unwrap();
230 if PUNCTUATION.contains(&first) {
231 tokens.push(first.to_string());
232 remaining = &remaining[first.len_utf8()..];
233 continue;
234 }
235
236 if let Some(pos) = remaining.find(|c: char| PUNCTUATION.contains(&c)) {
238 tokens.push(remaining[..pos].to_string());
239 remaining = &remaining[pos..];
240 } else {
241 tokens.push(remaining.to_string());
242 break;
243 }
244 }
245 }
246 tokens
247}
248
249fn split_into_sentences(tokens: &[String]) -> Vec<Vec<String>> {
251 let mut sentences = Vec::new();
252 let mut current = Vec::new();
253
254 for tok in tokens {
255 current.push(tok.clone());
256 if tok.len() == 1
257 && SENTENCE_ENDERS.contains(&tok.chars().next().unwrap())
258 && !current.is_empty()
259 {
260 sentences.push(current.clone());
261 current.clear();
262 }
263 }
264
265 if !current.is_empty() {
267 sentences.push(current);
268 }
269
270 sentences
271}
272
273pub struct MarkovBlender;
275
276impl MarkovBlender {
277 pub fn generate(
279 models: &[(&MarkovModel, f32)],
280 rng: &mut StdRng,
281 tag: Option<&str>,
282 min_words: usize,
283 max_words: usize,
284 ) -> Result<String, MarkovError> {
285 if models.is_empty() {
286 return Err(MarkovError::NoData);
287 }
288
289 let n = models[0].0.n;
291
292 let mut result_tokens: Vec<String> = Vec::new();
293 let mut state: Vec<String> = vec![SENTENCE_START.to_string(); n - 1];
294 let mut word_count = 0;
295 let mut last_sentence_end = 0;
296
297 for _ in 0..(max_words * 3) {
298 let next = match pick_next_blended(models, &state, tag, rng) {
300 Some(tok) => tok,
301 None => break,
302 };
303
304 if next == SENTENCE_END {
305 last_sentence_end = result_tokens.len();
306 if word_count >= min_words {
307 break;
308 }
309 state = vec![SENTENCE_START.to_string(); n - 1];
310 continue;
311 }
312
313 if !PUNCTUATION.contains(&next.chars().next().unwrap_or(' ')) {
314 word_count += 1;
315 }
316
317 result_tokens.push(next.clone());
318 state.push(next);
319 if state.len() > n - 1 {
320 state.remove(0);
321 }
322
323 if word_count >= max_words {
324 if last_sentence_end > 0 {
325 result_tokens.truncate(last_sentence_end);
326 }
327 break;
328 }
329 }
330
331 if result_tokens.is_empty() {
332 return Err(MarkovError::NoSentenceStart);
333 }
334
335 Ok(reassemble_tokens(&result_tokens))
336 }
337}
338
339fn pick_next_blended(
341 models: &[(&MarkovModel, f32)],
342 state: &[String],
343 tag: Option<&str>,
344 rng: &mut StdRng,
345) -> Option<String> {
346 let mut combined: HashMap<String, f64> = HashMap::new();
347
348 for (model, blend_weight) in models {
349 let transitions = if let Some(tag) = tag {
350 model
351 .tagged_transitions
352 .get(tag)
353 .unwrap_or(&model.transitions)
354 } else {
355 &model.transitions
356 };
357
358 if let Some(options) = transitions.get(state) {
359 let total: u32 = options.iter().map(|(_, c)| c).sum();
360 if total == 0 {
361 continue;
362 }
363 for (tok, count) in options {
364 let prob = (*count as f64) / (total as f64) * (*blend_weight as f64);
365 *combined.entry(tok.clone()).or_default() += prob;
366 }
367 }
368 }
369
370 if combined.is_empty() {
371 return None;
372 }
373
374 let tokens: Vec<String> = combined.keys().cloned().collect();
375 let weights: Vec<f64> = tokens.iter().map(|t| combined[t]).collect();
376 let dist = WeightedIndex::new(&weights).ok()?;
377 Some(tokens[dist.sample(rng)].clone())
378}
379
380pub fn save_model(model: &MarkovModel, path: &std::path::Path) -> Result<(), MarkovError> {
382 let serialized = ron::ser::to_string_pretty(model, ron::ser::PrettyConfig::default())
383 .map_err(|e| std::io::Error::other(e.to_string()))?;
384 std::fs::write(path, serialized)?;
385 Ok(())
386}
387
388pub fn load_model(path: &std::path::Path) -> Result<MarkovModel, MarkovError> {
390 let contents = std::fs::read_to_string(path)?;
391 let model: MarkovModel = ron::from_str(&contents)?;
392 Ok(model)
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;

    // Shared fixture: a bigram model trained on the on-disk test corpus.
    // NOTE(review): requires tests/fixtures/test_corpus.txt with [neutral],
    // [tense], and [warm] sections — confirm the fixture is checked in.
    fn train_test_corpus() -> MarkovModel {
        let corpus = std::fs::read_to_string("tests/fixtures/test_corpus.txt").unwrap();
        MarkovTrainer::train(&corpus, 2)
    }

    // Punctuation is split off into standalone single-character tokens.
    #[test]
    fn tokenize_basic() {
        let tokens = tokenize("Hello, world.");
        assert_eq!(tokens, vec!["Hello", ",", "world", "."]);
    }

    // Mixed quotes and terminators all surface as separate tokens.
    #[test]
    fn tokenize_complex() {
        let tokens = tokenize("She said, \"What?\" He replied.");
        assert!(tokens.contains(&"She".to_string()));
        assert!(tokens.contains(&",".to_string()));
        assert!(tokens.contains(&"?".to_string()));
        assert!(tokens.contains(&".".to_string()));
    }

    // Training populates the global transition table.
    #[test]
    fn train_creates_transitions() {
        let model = train_test_corpus();
        assert_eq!(model.n, 2);
        assert!(!model.transitions.is_empty());
    }

    // [tag] section headers in the corpus produce per-tag tables.
    #[test]
    fn train_creates_tagged_transitions() {
        let model = train_test_corpus();
        assert!(model.tagged_transitions.contains_key("neutral"));
        assert!(model.tagged_transitions.contains_key("tense"));
        assert!(model.tagged_transitions.contains_key("warm"));
    }

    // Same seed => same output: generation must be fully driven by the RNG.
    #[test]
    fn generate_deterministic() {
        let model = train_test_corpus();
        let mut rng1 = StdRng::seed_from_u64(42);
        let mut rng2 = StdRng::seed_from_u64(42);

        let result1 = model.generate(&mut rng1, None, 3, 20).unwrap();
        let result2 = model.generate(&mut rng2, None, 3, 20).unwrap();
        assert_eq!(result1, result2);
    }

    // Output is non-empty and honors the min_words floor.
    #[test]
    fn generate_produces_output() {
        let model = train_test_corpus();
        let mut rng = StdRng::seed_from_u64(42);

        let result = model.generate(&mut rng, None, 3, 20).unwrap();
        assert!(!result.is_empty());
        let word_count = result.split_whitespace().count();
        assert!(
            word_count >= 3,
            "Expected at least 3 words, got: {}",
            word_count
        );
    }

    // Generated text should end at a sentence terminator or, if truncated,
    // at a word boundary.
    #[test]
    fn generate_respects_sentence_boundaries() {
        let model = train_test_corpus();
        let mut rng = StdRng::seed_from_u64(42);

        let result = model.generate(&mut rng, None, 3, 20).unwrap();
        let trimmed = result.trim();
        let last_char = trimmed.chars().last().unwrap();
        assert!(
            SENTENCE_ENDERS.contains(&last_char) || last_char.is_alphanumeric(),
            "Expected sentence boundary or word end, got: '{}'",
            last_char
        );
    }

    // A known tag restricts generation to that tag's table and still
    // produces output.
    #[test]
    fn generate_with_tag() {
        let model = train_test_corpus();
        let mut rng = StdRng::seed_from_u64(42);

        let result = model.generate(&mut rng, Some("tense"), 3, 20).unwrap();
        assert!(!result.is_empty());
    }

    // Different tags should eventually diverge; scan seeds because a single
    // seed may coincidentally produce identical text for both tags.
    #[test]
    fn tag_filtering_changes_output() {
        let model = train_test_corpus();

        let mut found_different = false;
        for seed in 0..50 {
            let mut rng1 = StdRng::seed_from_u64(seed);
            let mut rng2 = StdRng::seed_from_u64(seed);

            let neutral = model.generate(&mut rng1, Some("neutral"), 3, 15);
            let tense = model.generate(&mut rng2, Some("tense"), 3, 15);

            if let (Ok(n), Ok(t)) = (neutral, tense) {
                if n != t {
                    found_different = true;
                    break;
                }
            }
        }
        assert!(
            found_different,
            "Tagged generation should produce different output"
        );
    }

    // Unknown tags must error (NoData) rather than fall back silently.
    #[test]
    fn generate_invalid_tag_returns_error() {
        let model = train_test_corpus();
        let mut rng = StdRng::seed_from_u64(42);

        let result = model.generate(&mut rng, Some("nonexistent_tag"), 3, 20);
        assert!(result.is_err());
    }

    // Serialize/deserialize via RON preserves depth and table size.
    #[test]
    fn ron_round_trip() {
        let model = train_test_corpus();

        let serialized = ron::to_string(&model).unwrap();
        let deserialized: MarkovModel = ron::from_str(&serialized).unwrap();

        assert_eq!(deserialized.n, model.n);
        assert_eq!(deserialized.transitions.len(), model.transitions.len());
    }

    // save_model/load_model round-trip through the filesystem.
    #[test]
    fn save_and_load_model() {
        let model = train_test_corpus();
        let path = std::path::PathBuf::from("target/test_markov_model.ron");

        save_model(&model, &path).unwrap();
        let loaded = load_model(&path).unwrap();

        assert_eq!(loaded.n, model.n);
        assert_eq!(loaded.transitions.len(), model.transitions.len());

        // Best-effort cleanup; ignore failure so the assert results stand.
        let _ = std::fs::remove_file(&path);
    }

    // Blending a single model with weight 1.0 degenerates to normal
    // generation and must produce output.
    #[test]
    fn blending_produces_output() {
        let model = train_test_corpus();
        let mut rng = StdRng::seed_from_u64(42);

        let result = MarkovBlender::generate(&[(&model, 1.0)], &mut rng, None, 3, 20).unwrap();
        assert!(!result.is_empty());
    }

    // The trainer and generator also work at trigram depth (n = 3).
    #[test]
    fn trigram_model() {
        let corpus = std::fs::read_to_string("tests/fixtures/test_corpus.txt").unwrap();
        let model = MarkovTrainer::train(&corpus, 3);
        assert_eq!(model.n, 3);

        let mut rng = StdRng::seed_from_u64(42);
        let result = model.generate(&mut rng, None, 3, 20).unwrap();
        assert!(!result.is_empty());
    }

    // Punctuation tokens attach to the preceding word with no space.
    #[test]
    fn reassemble_attaches_punctuation() {
        let tokens = vec![
            "Hello".to_string(),
            ",".to_string(),
            "world".to_string(),
            ".".to_string(),
        ];
        let result = reassemble_tokens(&tokens);
        assert_eq!(result, "Hello, world.");
    }
}