Skip to main content

narrative_engine/core/
pipeline.rs

//! The main narrative pipeline: Event → Text orchestration.
//!
//! Wires together grammar expansion, voice selection, Markov fill,
//! variety pass, and context checking.
5use rand::rngs::StdRng;
6use rand::SeedableRng;
7use std::collections::HashMap;
8use std::path::Path;
9use thiserror::Error;
10
11use crate::core::context::NarrativeContext;
12use crate::core::grammar::{GrammarError, GrammarSet, SelectionContext};
13use crate::core::markov::{MarkovError, MarkovModel};
14use crate::core::variety::VarietyPass;
15use crate::core::voice::{VoiceError, VoiceRegistry};
16use crate::schema::entity::{Entity, EntityId, VoiceId};
17use crate::schema::event::Event;
18use crate::schema::narrative_fn::NarrativeFunction;
19
20#[derive(Debug, Error)]
21pub enum PipelineError {
22    #[error("grammar error: {0}")]
23    Grammar(#[from] GrammarError),
24    #[error("voice error: {0}")]
25    Voice(#[from] VoiceError),
26    #[error("markov error: {0}")]
27    Markov(#[from] MarkovError),
28    #[error("IO error: {0}")]
29    Io(#[from] std::io::Error),
30    #[error("RON error: {0}")]
31    Ron(#[from] ron::error::SpannedError),
32    #[error("entity not found: {0:?}")]
33    EntityNotFound(EntityId),
34    #[error("no grammar rule found for narrative function: {0}")]
35    NoRuleForFunction(String),
36    #[error("generation failed after {0} retries")]
37    GenerationFailed(u32),
38}
39
40/// World state passed by the game to the narration pipeline.
41pub struct WorldState<'a> {
42    pub entities: &'a HashMap<EntityId, Entity>,
43}
44
45/// Event-type to narrative-function mapping entry.
46#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
47pub struct EventMapping {
48    pub event_type: String,
49    pub narrative_fn: NarrativeFunction,
50}
51
52/// The top-level narrative engine. Built via `NarrativeEngine::builder()`.
53pub struct NarrativeEngine {
54    grammars: GrammarSet,
55    voices: VoiceRegistry,
56    markov_models: HashMap<String, MarkovModel>,
57    mappings: HashMap<String, NarrativeFunction>,
58    context: NarrativeContext,
59    seed: u64,
60    generation_count: u64,
61}
62
63/// Builder for constructing a `NarrativeEngine`.
64pub struct NarrativeEngineBuilder {
65    genre_templates: Vec<String>,
66    grammars_dir: Option<String>,
67    voices_dir: Option<String>,
68    markov_models_dir: Option<String>,
69    mappings_path: Option<String>,
70    seed: u64,
71    /// Directly provided grammars (for testing without files).
72    grammars: Option<GrammarSet>,
73    /// Directly provided voices (for testing without files).
74    voices: Option<VoiceRegistry>,
75    /// Directly provided markov models (for testing without files).
76    markov_models: Option<HashMap<String, MarkovModel>>,
77    /// Directly provided mappings (for testing without files).
78    mappings: Option<HashMap<String, NarrativeFunction>>,
79}
80
81impl NarrativeEngine {
82    pub fn builder() -> NarrativeEngineBuilder {
83        NarrativeEngineBuilder {
84            genre_templates: Vec::new(),
85            grammars_dir: None,
86            voices_dir: None,
87            markov_models_dir: None,
88            mappings_path: None,
89            seed: 0,
90            grammars: None,
91            voices: None,
92            markov_models: None,
93            mappings: None,
94        }
95    }
96
97    /// Generate narration for an event using the first participant's voice.
98    pub fn narrate(
99        &mut self,
100        event: &Event,
101        world: &WorldState<'_>,
102    ) -> Result<String, PipelineError> {
103        // Select voice from first participant
104        let voice_id = self.resolve_voice_id(event, world);
105        self.narrate_with_voice(event, voice_id, world)
106    }
107
108    /// Generate narration for an event using a specific voice.
109    pub fn narrate_as(
110        &mut self,
111        event: &Event,
112        voice_id: VoiceId,
113        world: &WorldState<'_>,
114    ) -> Result<String, PipelineError> {
115        self.narrate_with_voice(event, Some(voice_id), world)
116    }
117
118    /// Generate multiple variants for an event.
119    pub fn narrate_variants(
120        &mut self,
121        event: &Event,
122        count: usize,
123        world: &WorldState<'_>,
124    ) -> Result<Vec<String>, PipelineError> {
125        let mut results = Vec::with_capacity(count);
126        for i in 0..count {
127            // Use different seed offsets for each variant
128            let saved_count = self.generation_count;
129            self.generation_count = saved_count + (i as u64 * 1000);
130            let result = self.narrate(event, world)?;
131            self.generation_count = saved_count + 1;
132            results.push(result);
133        }
134        Ok(results)
135    }
136
137    fn resolve_voice_id(&self, event: &Event, world: &WorldState<'_>) -> Option<VoiceId> {
138        // Use first participant's voice_id
139        for participant in &event.participants {
140            if let Some(entity) = world.entities.get(&participant.entity_id) {
141                if entity.voice_id.is_some() {
142                    return entity.voice_id;
143                }
144            }
145        }
146        None
147    }
148
149    fn narrate_with_voice(
150        &mut self,
151        event: &Event,
152        voice_id: Option<VoiceId>,
153        world: &WorldState<'_>,
154    ) -> Result<String, PipelineError> {
155        let max_retries = 3u32;
156
157        for retry in 0..max_retries {
158            let mut rng = StdRng::seed_from_u64(
159                self.seed
160                    .wrapping_add(self.generation_count)
161                    .wrapping_add(retry as u64 * 7919), // prime offset per retry
162            );
163
164            // 1. Resolve narrative function
165            let narrative_fn = self.resolve_narrative_fn(event);
166
167            // 2. Build SelectionContext
168            let mut ctx = self.build_context(event, world, &narrative_fn);
169
170            // 3-4. Resolve voice
171            let resolved_voice = voice_id.and_then(|id| self.voices.resolve(id));
172            if let Some(ref voice) = resolved_voice {
173                ctx.voice_weights = Some(&voice.grammar_weights);
174            }
175
176            // Add markov model references to context
177            for (corpus_id, model) in &self.markov_models {
178                ctx.markov_models.insert(corpus_id.clone(), model);
179            }
180
181            // 5. Determine entry rule name
182            let rule_name = format!("{}_opening", narrative_fn.name());
183
184            // 6. Expand grammar
185            let expanded = match self.grammars.expand(&rule_name, &mut ctx, &mut rng) {
186                Ok(text) => text,
187                Err(GrammarError::RuleNotFound(_)) => {
188                    // Try without _opening suffix
189                    match self
190                        .grammars
191                        .expand(narrative_fn.name(), &mut ctx, &mut rng)
192                    {
193                        Ok(text) => text,
194                        Err(e) => return Err(PipelineError::Grammar(e)),
195                    }
196                }
197                Err(e) => return Err(PipelineError::Grammar(e)),
198            };
199
200            // 7. Run variety pass
201            let output = if let Some(ref voice) = resolved_voice {
202                VarietyPass::apply(&expanded, voice, &self.context, &mut rng)
203            } else {
204                expanded
205            };
206
207            // 8. Check for repetition
208            let issues = self.context.check_repetition(&output);
209            if issues.is_empty() || retry == max_retries - 1 {
210                // 9. Record and return
211                self.context.record(&output);
212                self.generation_count += 1;
213                return Ok(output);
214            }
215            // Retry with different seed offset
216        }
217
218        Err(PipelineError::GenerationFailed(max_retries))
219    }
220
221    fn resolve_narrative_fn(&self, event: &Event) -> NarrativeFunction {
222        // Event can specify narrative_fn directly
223        // Or look up from mappings table
224        if let Some(mapped) = self.mappings.get(&event.event_type) {
225            mapped.clone()
226        } else {
227            event.narrative_fn.clone()
228        }
229    }
230
231    fn build_context<'a>(
232        &'a self,
233        event: &Event,
234        world: &'a WorldState<'_>,
235        narrative_fn: &NarrativeFunction,
236    ) -> SelectionContext<'a> {
237        let mut ctx = SelectionContext::new();
238
239        // Add mood and stakes as tags
240        ctx.tags.insert(event.mood.tag().to_string());
241        ctx.tags.insert(event.stakes.tag().to_string());
242
243        // Add narrative function as tag
244        ctx.tags.insert(format!("fn:{}", narrative_fn.name()));
245
246        // Add intensity-based tags
247        let intensity = narrative_fn.intensity();
248        if intensity >= 0.7 {
249            ctx.tags.insert("intensity:high".to_string());
250        } else if intensity <= 0.3 {
251            ctx.tags.insert("intensity:low".to_string());
252        }
253
254        // Add participant entity tags and bindings
255        for (i, participant) in event.participants.iter().enumerate() {
256            if let Some(entity) = world.entities.get(&participant.entity_id) {
257                for tag in &entity.tags {
258                    ctx.tags.insert(tag.clone());
259                }
260
261                // Bind by role
262                ctx.entity_bindings.insert(participant.role.clone(), entity);
263
264                // First participant is also "subject" if no explicit subject role
265                if i == 0 && !ctx.entity_bindings.contains_key("subject") {
266                    ctx.entity_bindings.insert("subject".to_string(), entity);
267                }
268            }
269        }
270
271        // Add location entity tags
272        if let Some(ref location) = event.location {
273            if let Some(entity) = world.entities.get(&location.entity_id) {
274                for tag in &entity.tags {
275                    ctx.tags.insert(tag.clone());
276                }
277                ctx.entity_bindings.insert(location.role.clone(), entity);
278            }
279        }
280
281        ctx
282    }
283}
284
285impl NarrativeEngineBuilder {
286    pub fn genre_templates(mut self, templates: &[&str]) -> Self {
287        self.genre_templates = templates.iter().map(|s| s.to_string()).collect();
288        self
289    }
290
291    pub fn grammars_dir(mut self, path: &str) -> Self {
292        self.grammars_dir = Some(path.to_string());
293        self
294    }
295
296    pub fn voices_dir(mut self, path: &str) -> Self {
297        self.voices_dir = Some(path.to_string());
298        self
299    }
300
301    pub fn markov_models_dir(mut self, path: &str) -> Self {
302        self.markov_models_dir = Some(path.to_string());
303        self
304    }
305
306    pub fn mappings(mut self, path: &str) -> Self {
307        self.mappings_path = Some(path.to_string());
308        self
309    }
310
311    pub fn seed(mut self, seed: u64) -> Self {
312        self.seed = seed;
313        self
314    }
315
316    /// Provide grammars directly (for testing without files).
317    pub fn with_grammars(mut self, grammars: GrammarSet) -> Self {
318        self.grammars = Some(grammars);
319        self
320    }
321
322    /// Provide voices directly (for testing without files).
323    pub fn with_voices(mut self, voices: VoiceRegistry) -> Self {
324        self.voices = Some(voices);
325        self
326    }
327
328    /// Provide markov models directly (for testing without files).
329    pub fn with_markov_models(mut self, models: HashMap<String, MarkovModel>) -> Self {
330        self.markov_models = Some(models);
331        self
332    }
333
334    /// Provide mappings directly (for testing without files).
335    pub fn with_mappings(mut self, mappings: HashMap<String, NarrativeFunction>) -> Self {
336        self.mappings = Some(mappings);
337        self
338    }
339
340    pub fn build(self) -> Result<NarrativeEngine, PipelineError> {
341        let mut grammars = self.grammars.unwrap_or_default();
342        let mut voices = self.voices.unwrap_or_default();
343        let mut markov_models = self.markov_models.unwrap_or_default();
344        let mappings = self.mappings.unwrap_or_default();
345
346        // Load genre templates
347        for template_name in &self.genre_templates {
348            let grammar_path = format!("genre_data/{}/grammar.ron", template_name);
349            if Path::new(&grammar_path).exists() {
350                let template_grammars = GrammarSet::load_from_ron(Path::new(&grammar_path))?;
351                grammars.merge(template_grammars);
352            }
353
354            let voices_path = format!("genre_data/{}/voices.ron", template_name);
355            if Path::new(&voices_path).exists() {
356                voices.load_from_ron(Path::new(&voices_path))?;
357            }
358        }
359
360        // Load game-specific grammars (override genre templates)
361        if let Some(ref dir) = self.grammars_dir {
362            if Path::new(dir).exists() {
363                load_ron_files_from_dir(dir, |path| {
364                    let gs = GrammarSet::load_from_ron(path)?;
365                    grammars.merge(gs);
366                    Ok(())
367                })?;
368            }
369        }
370
371        // Load game-specific voices
372        if let Some(ref dir) = self.voices_dir {
373            if Path::new(dir).exists() {
374                load_ron_files_from_dir(dir, |path| {
375                    voices.load_from_ron(path)?;
376                    Ok(())
377                })?;
378            }
379        }
380
381        // Load Markov models
382        if let Some(ref dir) = self.markov_models_dir {
383            if Path::new(dir).exists() {
384                load_ron_files_from_dir(dir, |path| {
385                    let model = crate::core::markov::load_model(path)?;
386                    let name = path
387                        .file_stem()
388                        .and_then(|s| s.to_str())
389                        .unwrap_or("unknown")
390                        .to_string();
391                    markov_models.insert(name, model);
392                    Ok(())
393                })?;
394            }
395        }
396
397        // Load mappings
398        let mappings = if let Some(ref path) = self.mappings_path {
399            if Path::new(path).exists() {
400                let contents = std::fs::read_to_string(path)?;
401                let entries: Vec<EventMapping> = ron::from_str(&contents)?;
402                let mut map = mappings;
403                for entry in entries {
404                    map.insert(entry.event_type, entry.narrative_fn);
405                }
406                map
407            } else {
408                mappings
409            }
410        } else {
411            mappings
412        };
413
414        Ok(NarrativeEngine {
415            grammars,
416            voices,
417            markov_models,
418            mappings,
419            context: NarrativeContext::default(),
420            seed: self.seed,
421            generation_count: 0,
422        })
423    }
424}
425
426/// Load all .ron files from a directory, calling `loader` for each.
427fn load_ron_files_from_dir<F>(dir: &str, mut loader: F) -> Result<(), PipelineError>
428where
429    F: FnMut(&Path) -> Result<(), PipelineError>,
430{
431    let entries = std::fs::read_dir(dir)?;
432    for entry in entries {
433        let entry = entry?;
434        let path = entry.path();
435        if path.extension().and_then(|s| s.to_str()) == Some("ron") {
436            loader(&path)?;
437        }
438    }
439    Ok(())
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::markov::MarkovTrainer;
    use crate::core::voice::Voice;
    use crate::schema::entity::Value;
    use crate::schema::event::{EntityRef, Mood, Stakes};

    /// Engine with seed 42, a three-rule grammar, one voice (`VoiceId(1)`),
    /// and a Markov model trained on the fixture corpus.
    fn build_test_engine() -> NarrativeEngine {
        let grammar_ron = r#"{
            "confrontation_opening": Rule(
                requires: ["mood:tense"],
                excludes: [],
                alternatives: [
                    (weight: 3, text: "{subject} stepped forward. {tense_detail}"),
                    (weight: 2, text: "The tension was palpable. {subject} spoke first."),
                ],
            ),
            "tense_detail": Rule(
                requires: [],
                excludes: [],
                alternatives: [
                    (weight: 2, text: "The air felt heavy with unspoken words."),
                    (weight: 2, text: "No one dared to breathe."),
                    (weight: 1, text: "A silence settled over the room."),
                ],
            ),
            "revelation_opening": Rule(
                requires: [],
                excludes: [],
                alternatives: [
                    (weight: 2, text: "{subject} revealed the truth at last."),
                    (weight: 1, text: "The secret was finally out."),
                ],
            ),
        }"#;
        let grammars = GrammarSet::parse_ron(grammar_ron).unwrap();

        // Single voice with default settings.
        let narrator = Voice {
            id: VoiceId(1),
            name: "narrator".to_string(),
            parent: None,
            grammar_weights: HashMap::new(),
            vocabulary: crate::core::voice::VocabularyPool::default(),
            markov_bindings: Vec::new(),
            structure_prefs: crate::core::voice::StructurePrefs::default(),
            quirks: Vec::new(),
        };
        let mut voices = VoiceRegistry::new();
        voices.register(narrator);

        // Small Markov model trained from the fixture corpus.
        let corpus = std::fs::read_to_string("tests/fixtures/test_corpus.txt").unwrap();
        let markov_models = HashMap::from([(
            "test_corpus".to_string(),
            MarkovTrainer::train(&corpus, 2),
        )]);

        NarrativeEngine::builder()
            .seed(42)
            .with_grammars(grammars)
            .with_voices(voices)
            .with_markov_models(markov_models)
            .build()
            .unwrap()
    }

    /// Two entities (Margaret with a voice, James without) and a tense,
    /// high-stakes "accusation" event with Margaret as subject.
    fn make_test_world() -> (HashMap<EntityId, Entity>, Event) {
        let margaret = Entity {
            id: EntityId(1),
            name: "Margaret".to_string(),
            pronouns: crate::schema::entity::Pronouns::SheHer,
            tags: ["host", "formal"].iter().map(|tag| tag.to_string()).collect(),
            relationships: Vec::new(),
            voice_id: Some(VoiceId(1)),
            properties: HashMap::from([(
                "title".to_string(),
                Value::String("Duchess".to_string()),
            )]),
        };

        let james = Entity {
            id: EntityId(2),
            name: "James".to_string(),
            pronouns: crate::schema::entity::Pronouns::HeHim,
            tags: ["guest"].iter().map(|tag| tag.to_string()).collect(),
            relationships: Vec::new(),
            voice_id: None,
            properties: HashMap::new(),
        };

        let entities = HashMap::from([(EntityId(1), margaret), (EntityId(2), james)]);

        let subject = EntityRef {
            entity_id: EntityId(1),
            role: "subject".to_string(),
        };
        let object = EntityRef {
            entity_id: EntityId(2),
            role: "object".to_string(),
        };

        let event = Event {
            event_type: "accusation".to_string(),
            participants: vec![subject, object],
            location: None,
            mood: Mood::Tense,
            stakes: Stakes::High,
            outcome: None,
            narrative_fn: NarrativeFunction::Confrontation,
            metadata: HashMap::new(),
        };

        (entities, event)
    }

    #[test]
    fn narrate_produces_output() {
        let (entities, event) = make_test_world();
        let world = WorldState {
            entities: &entities,
        };
        let mut engine = build_test_engine();

        let result = engine.narrate(&event, &world).unwrap();
        assert!(!result.is_empty(), "Expected non-empty narration");
        assert!(
            result.len() > 10,
            "Expected substantial text, got: {}",
            result
        );
    }

    #[test]
    fn narrate_deterministic_same_seed() {
        let (entities, event) = make_test_world();
        let world = WorldState {
            entities: &entities,
        };

        // Two freshly built engines share seed and state, so their first
        // narration must match.
        let first = build_test_engine().narrate(&event, &world).unwrap();
        let second = build_test_engine().narrate(&event, &world).unwrap();

        assert_eq!(first, second);
    }

    #[test]
    fn narrate_different_with_different_seed() {
        let (entities, event) = make_test_world();
        let world = WorldState {
            entities: &entities,
        };

        // Baseline narration with seed 1.
        let baseline = NarrativeEngine::builder()
            .seed(1)
            .with_grammars(build_test_engine().grammars.clone())
            .build()
            .unwrap()
            .narrate(&event, &world)
            .unwrap();

        // Stop at the first seed whose output differs from the baseline.
        let found_different = (2..50).any(|seed| {
            let grammars_ron = r#"{
                "confrontation_opening": Rule(
                    requires: ["mood:tense"],
                    excludes: [],
                    alternatives: [
                        (weight: 3, text: "{subject} stepped forward. {tense_detail}"),
                        (weight: 2, text: "The tension was palpable. {subject} spoke first."),
                    ],
                ),
                "tense_detail": Rule(
                    requires: [],
                    excludes: [],
                    alternatives: [
                        (weight: 2, text: "The air felt heavy with unspoken words."),
                        (weight: 2, text: "No one dared to breathe."),
                        (weight: 1, text: "A silence settled over the room."),
                    ],
                ),
            }"#;
            let mut candidate = NarrativeEngine::builder()
                .seed(seed)
                .with_grammars(GrammarSet::parse_ron(grammars_ron).unwrap())
                .build()
                .unwrap();
            candidate.narrate(&event, &world).unwrap() != baseline
        });

        assert!(
            found_different,
            "Expected different output with different seeds"
        );
    }

    #[test]
    fn narrate_as_with_specific_voice() {
        let (entities, event) = make_test_world();
        let world = WorldState {
            entities: &entities,
        };
        let mut engine = build_test_engine();

        let result = engine.narrate_as(&event, VoiceId(1), &world).unwrap();
        assert!(!result.is_empty());
    }

    #[test]
    fn narrate_variants_produces_multiple() {
        let (entities, event) = make_test_world();
        let world = WorldState {
            entities: &entities,
        };
        let mut engine = build_test_engine();

        let variants = engine.narrate_variants(&event, 3, &world).unwrap();
        assert_eq!(variants.len(), 3);
        for variant in &variants {
            assert!(!variant.is_empty());
        }
    }

    #[test]
    fn narrate_contains_entity_name() {
        let (entities, event) = make_test_world();
        let world = WorldState {
            entities: &entities,
        };
        let mut engine = build_test_engine();

        // Narrate up to ten times on the same engine (each call advances the
        // generation counter); at least one result should mention Margaret.
        let found_name = (0..10).any(|_| {
            engine
                .narrate(&event, &world)
                .unwrap()
                .contains("Margaret")
        });
        assert!(found_name, "Expected entity name in at least one narration");
    }

    #[test]
    fn builder_with_seed() {
        let engine = NarrativeEngine::builder().seed(12345).build().unwrap();
        assert_eq!(engine.seed, 12345);
    }
}