1use rand::rngs::StdRng;
6use rand::SeedableRng;
7use std::collections::HashMap;
8use std::path::Path;
9use thiserror::Error;
10
11use crate::core::context::NarrativeContext;
12use crate::core::grammar::{GrammarError, GrammarSet, SelectionContext};
13use crate::core::markov::{MarkovError, MarkovModel};
14use crate::core::variety::VarietyPass;
15use crate::core::voice::{VoiceError, VoiceRegistry};
16use crate::schema::entity::{Entity, EntityId, VoiceId};
17use crate::schema::event::Event;
18use crate::schema::narrative_fn::NarrativeFunction;
19
/// Errors produced while building a [`NarrativeEngine`] or generating narration.
#[derive(Debug, Error)]
pub enum PipelineError {
    /// An error reported by the grammar subsystem (e.g. rule expansion during narration).
    #[error("grammar error: {0}")]
    Grammar(#[from] GrammarError),
    /// An error reported by the voice subsystem.
    #[error("voice error: {0}")]
    Voice(#[from] VoiceError),
    /// An error reported by the Markov model subsystem.
    #[error("markov error: {0}")]
    Markov(#[from] MarkovError),
    /// A filesystem error, e.g. while listing a data directory or reading the mappings file.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// A RON document (such as the event mappings file) failed to parse.
    #[error("RON error: {0}")]
    Ron(#[from] ron::error::SpannedError),
    /// An entity id could not be resolved.
    /// NOTE(review): not constructed anywhere in this module; unresolved participants are
    /// currently skipped silently.
    #[error("entity not found: {0:?}")]
    EntityNotFound(EntityId),
    /// No grammar rule exists for the given narrative function name.
    /// NOTE(review): not constructed anywhere in this module; a missing rule currently
    /// surfaces as `Grammar(GrammarError::RuleNotFound(..))`.
    #[error("no grammar rule found for narrative function: {0}")]
    NoRuleForFunction(String),
    /// Generation gave up after the given number of attempts.
    /// The retry loop accepts its final attempt unconditionally, so with a non-zero retry
    /// count this is only a defensive fallback.
    #[error("generation failed after {0} retries")]
    GenerationFailed(u32),
}
39
/// Read-only view of the world that narration calls resolve entities against.
pub struct WorldState<'a> {
    /// All known entities, keyed by id. Event participants and locations are looked up here;
    /// ids missing from this map are ignored during narration.
    pub entities: &'a HashMap<EntityId, Entity>,
}
44
/// One entry of the event-mappings file, which the builder deserializes as a
/// `Vec<EventMapping>` (RON).
///
/// When an event's `event_type` matches, `narrative_fn` is used instead of the event's own
/// narrative function.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct EventMapping {
    /// The `Event::event_type` string this entry applies to.
    pub event_type: String,
    /// Narrative function to use for events of that type.
    pub narrative_fn: NarrativeFunction,
}
51
/// Turns [`Event`]s into prose using grammars, voices and Markov models.
///
/// Construct one with [`NarrativeEngine::builder`].
pub struct NarrativeEngine {
    /// Grammar rules used for expansion.
    grammars: GrammarSet,
    /// Voices that bias rule selection and drive the variety pass.
    voices: VoiceRegistry,
    /// Markov models keyed by corpus id; all of them are exposed to every expansion.
    markov_models: HashMap<String, MarkovModel>,
    /// `event_type` -> narrative function overrides.
    mappings: HashMap<String, NarrativeFunction>,
    /// Record of previously emitted text, used for repetition checks.
    context: NarrativeContext,
    /// Base RNG seed.
    seed: u64,
    /// Number of narrations accepted so far; mixed into each RNG seed.
    generation_count: u64,
}
62
/// Configures and loads the data sources for a [`NarrativeEngine`].
pub struct NarrativeEngineBuilder {
    /// Genre template names; each is looked up under `genre_data/<name>/`.
    genre_templates: Vec<String>,
    /// Directory whose `.ron` files are merged into the grammar set.
    grammars_dir: Option<String>,
    /// Directory whose `.ron` files are loaded into the voice registry.
    voices_dir: Option<String>,
    /// Directory whose `.ron` files are loaded as Markov models, keyed by file stem.
    markov_models_dir: Option<String>,
    /// Path of a RON file containing a list of [`EventMapping`]s.
    mappings_path: Option<String>,
    /// Base RNG seed for the built engine.
    seed: u64,
    /// Grammar set supplied directly; file-based grammars are merged into it.
    grammars: Option<GrammarSet>,
    /// Voice registry supplied directly; file-based voices are loaded into it.
    voices: Option<VoiceRegistry>,
    /// Markov models supplied directly; file-based models are inserted alongside them.
    markov_models: Option<HashMap<String, MarkovModel>>,
    /// Mappings supplied directly; entries from the mappings file are inserted on top.
    mappings: Option<HashMap<String, NarrativeFunction>>,
}
80
81impl NarrativeEngine {
82 pub fn builder() -> NarrativeEngineBuilder {
83 NarrativeEngineBuilder {
84 genre_templates: Vec::new(),
85 grammars_dir: None,
86 voices_dir: None,
87 markov_models_dir: None,
88 mappings_path: None,
89 seed: 0,
90 grammars: None,
91 voices: None,
92 markov_models: None,
93 mappings: None,
94 }
95 }
96
97 pub fn narrate(
99 &mut self,
100 event: &Event,
101 world: &WorldState<'_>,
102 ) -> Result<String, PipelineError> {
103 let voice_id = self.resolve_voice_id(event, world);
105 self.narrate_with_voice(event, voice_id, world)
106 }
107
108 pub fn narrate_as(
110 &mut self,
111 event: &Event,
112 voice_id: VoiceId,
113 world: &WorldState<'_>,
114 ) -> Result<String, PipelineError> {
115 self.narrate_with_voice(event, Some(voice_id), world)
116 }
117
118 pub fn narrate_variants(
120 &mut self,
121 event: &Event,
122 count: usize,
123 world: &WorldState<'_>,
124 ) -> Result<Vec<String>, PipelineError> {
125 let mut results = Vec::with_capacity(count);
126 for i in 0..count {
127 let saved_count = self.generation_count;
129 self.generation_count = saved_count + (i as u64 * 1000);
130 let result = self.narrate(event, world)?;
131 self.generation_count = saved_count + 1;
132 results.push(result);
133 }
134 Ok(results)
135 }
136
137 fn resolve_voice_id(&self, event: &Event, world: &WorldState<'_>) -> Option<VoiceId> {
138 for participant in &event.participants {
140 if let Some(entity) = world.entities.get(&participant.entity_id) {
141 if entity.voice_id.is_some() {
142 return entity.voice_id;
143 }
144 }
145 }
146 None
147 }
148
149 fn narrate_with_voice(
150 &mut self,
151 event: &Event,
152 voice_id: Option<VoiceId>,
153 world: &WorldState<'_>,
154 ) -> Result<String, PipelineError> {
155 let max_retries = 3u32;
156
157 for retry in 0..max_retries {
158 let mut rng = StdRng::seed_from_u64(
159 self.seed
160 .wrapping_add(self.generation_count)
161 .wrapping_add(retry as u64 * 7919), );
163
164 let narrative_fn = self.resolve_narrative_fn(event);
166
167 let mut ctx = self.build_context(event, world, &narrative_fn);
169
170 let resolved_voice = voice_id.and_then(|id| self.voices.resolve(id));
172 if let Some(ref voice) = resolved_voice {
173 ctx.voice_weights = Some(&voice.grammar_weights);
174 }
175
176 for (corpus_id, model) in &self.markov_models {
178 ctx.markov_models.insert(corpus_id.clone(), model);
179 }
180
181 let rule_name = format!("{}_opening", narrative_fn.name());
183
184 let expanded = match self.grammars.expand(&rule_name, &mut ctx, &mut rng) {
186 Ok(text) => text,
187 Err(GrammarError::RuleNotFound(_)) => {
188 match self
190 .grammars
191 .expand(narrative_fn.name(), &mut ctx, &mut rng)
192 {
193 Ok(text) => text,
194 Err(e) => return Err(PipelineError::Grammar(e)),
195 }
196 }
197 Err(e) => return Err(PipelineError::Grammar(e)),
198 };
199
200 let output = if let Some(ref voice) = resolved_voice {
202 VarietyPass::apply(&expanded, voice, &self.context, &mut rng)
203 } else {
204 expanded
205 };
206
207 let issues = self.context.check_repetition(&output);
209 if issues.is_empty() || retry == max_retries - 1 {
210 self.context.record(&output);
212 self.generation_count += 1;
213 return Ok(output);
214 }
215 }
217
218 Err(PipelineError::GenerationFailed(max_retries))
219 }
220
221 fn resolve_narrative_fn(&self, event: &Event) -> NarrativeFunction {
222 if let Some(mapped) = self.mappings.get(&event.event_type) {
225 mapped.clone()
226 } else {
227 event.narrative_fn.clone()
228 }
229 }
230
231 fn build_context<'a>(
232 &'a self,
233 event: &Event,
234 world: &'a WorldState<'_>,
235 narrative_fn: &NarrativeFunction,
236 ) -> SelectionContext<'a> {
237 let mut ctx = SelectionContext::new();
238
239 ctx.tags.insert(event.mood.tag().to_string());
241 ctx.tags.insert(event.stakes.tag().to_string());
242
243 ctx.tags.insert(format!("fn:{}", narrative_fn.name()));
245
246 let intensity = narrative_fn.intensity();
248 if intensity >= 0.7 {
249 ctx.tags.insert("intensity:high".to_string());
250 } else if intensity <= 0.3 {
251 ctx.tags.insert("intensity:low".to_string());
252 }
253
254 for (i, participant) in event.participants.iter().enumerate() {
256 if let Some(entity) = world.entities.get(&participant.entity_id) {
257 for tag in &entity.tags {
258 ctx.tags.insert(tag.clone());
259 }
260
261 ctx.entity_bindings.insert(participant.role.clone(), entity);
263
264 if i == 0 && !ctx.entity_bindings.contains_key("subject") {
266 ctx.entity_bindings.insert("subject".to_string(), entity);
267 }
268 }
269 }
270
271 if let Some(ref location) = event.location {
273 if let Some(entity) = world.entities.get(&location.entity_id) {
274 for tag in &entity.tags {
275 ctx.tags.insert(tag.clone());
276 }
277 ctx.entity_bindings.insert(location.role.clone(), entity);
278 }
279 }
280
281 ctx
282 }
283}
284
285impl NarrativeEngineBuilder {
286 pub fn genre_templates(mut self, templates: &[&str]) -> Self {
287 self.genre_templates = templates.iter().map(|s| s.to_string()).collect();
288 self
289 }
290
291 pub fn grammars_dir(mut self, path: &str) -> Self {
292 self.grammars_dir = Some(path.to_string());
293 self
294 }
295
296 pub fn voices_dir(mut self, path: &str) -> Self {
297 self.voices_dir = Some(path.to_string());
298 self
299 }
300
301 pub fn markov_models_dir(mut self, path: &str) -> Self {
302 self.markov_models_dir = Some(path.to_string());
303 self
304 }
305
306 pub fn mappings(mut self, path: &str) -> Self {
307 self.mappings_path = Some(path.to_string());
308 self
309 }
310
311 pub fn seed(mut self, seed: u64) -> Self {
312 self.seed = seed;
313 self
314 }
315
316 pub fn with_grammars(mut self, grammars: GrammarSet) -> Self {
318 self.grammars = Some(grammars);
319 self
320 }
321
322 pub fn with_voices(mut self, voices: VoiceRegistry) -> Self {
324 self.voices = Some(voices);
325 self
326 }
327
328 pub fn with_markov_models(mut self, models: HashMap<String, MarkovModel>) -> Self {
330 self.markov_models = Some(models);
331 self
332 }
333
334 pub fn with_mappings(mut self, mappings: HashMap<String, NarrativeFunction>) -> Self {
336 self.mappings = Some(mappings);
337 self
338 }
339
340 pub fn build(self) -> Result<NarrativeEngine, PipelineError> {
341 let mut grammars = self.grammars.unwrap_or_default();
342 let mut voices = self.voices.unwrap_or_default();
343 let mut markov_models = self.markov_models.unwrap_or_default();
344 let mappings = self.mappings.unwrap_or_default();
345
346 for template_name in &self.genre_templates {
348 let grammar_path = format!("genre_data/{}/grammar.ron", template_name);
349 if Path::new(&grammar_path).exists() {
350 let template_grammars = GrammarSet::load_from_ron(Path::new(&grammar_path))?;
351 grammars.merge(template_grammars);
352 }
353
354 let voices_path = format!("genre_data/{}/voices.ron", template_name);
355 if Path::new(&voices_path).exists() {
356 voices.load_from_ron(Path::new(&voices_path))?;
357 }
358 }
359
360 if let Some(ref dir) = self.grammars_dir {
362 if Path::new(dir).exists() {
363 load_ron_files_from_dir(dir, |path| {
364 let gs = GrammarSet::load_from_ron(path)?;
365 grammars.merge(gs);
366 Ok(())
367 })?;
368 }
369 }
370
371 if let Some(ref dir) = self.voices_dir {
373 if Path::new(dir).exists() {
374 load_ron_files_from_dir(dir, |path| {
375 voices.load_from_ron(path)?;
376 Ok(())
377 })?;
378 }
379 }
380
381 if let Some(ref dir) = self.markov_models_dir {
383 if Path::new(dir).exists() {
384 load_ron_files_from_dir(dir, |path| {
385 let model = crate::core::markov::load_model(path)?;
386 let name = path
387 .file_stem()
388 .and_then(|s| s.to_str())
389 .unwrap_or("unknown")
390 .to_string();
391 markov_models.insert(name, model);
392 Ok(())
393 })?;
394 }
395 }
396
397 let mappings = if let Some(ref path) = self.mappings_path {
399 if Path::new(path).exists() {
400 let contents = std::fs::read_to_string(path)?;
401 let entries: Vec<EventMapping> = ron::from_str(&contents)?;
402 let mut map = mappings;
403 for entry in entries {
404 map.insert(entry.event_type, entry.narrative_fn);
405 }
406 map
407 } else {
408 mappings
409 }
410 } else {
411 mappings
412 };
413
414 Ok(NarrativeEngine {
415 grammars,
416 voices,
417 markov_models,
418 mappings,
419 context: NarrativeContext::default(),
420 seed: self.seed,
421 generation_count: 0,
422 })
423 }
424}
425
426fn load_ron_files_from_dir<F>(dir: &str, mut loader: F) -> Result<(), PipelineError>
428where
429 F: FnMut(&Path) -> Result<(), PipelineError>,
430{
431 let entries = std::fs::read_dir(dir)?;
432 for entry in entries {
433 let entry = entry?;
434 let path = entry.path();
435 if path.extension().and_then(|s| s.to_str()) == Some("ron") {
436 loader(&path)?;
437 }
438 }
439 Ok(())
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::markov::MarkovTrainer;
    use crate::core::voice::Voice;
    use crate::schema::entity::Value;
    use crate::schema::event::{EntityRef, Mood, Stakes};

    /// Grammar shared by every engine built in these tests.
    const TEST_GRAMMAR_RON: &str = r#"{
    "confrontation_opening": Rule(
        requires: ["mood:tense"],
        excludes: [],
        alternatives: [
            (weight: 3, text: "{subject} stepped forward. {tense_detail}"),
            (weight: 2, text: "The tension was palpable. {subject} spoke first."),
        ],
    ),
    "tense_detail": Rule(
        requires: [],
        excludes: [],
        alternatives: [
            (weight: 2, text: "The air felt heavy with unspoken words."),
            (weight: 2, text: "No one dared to breathe."),
            (weight: 1, text: "A silence settled over the room."),
        ],
    ),
    "revelation_opening": Rule(
        requires: [],
        excludes: [],
        alternatives: [
            (weight: 2, text: "{subject} revealed the truth at last."),
            (weight: 1, text: "The secret was finally out."),
        ],
    ),
}"#;

    /// Parses [`TEST_GRAMMAR_RON`] into a fresh grammar set.
    fn test_grammars() -> GrammarSet {
        GrammarSet::parse_ron(TEST_GRAMMAR_RON).unwrap()
    }

    /// Engine with seed 42, the test grammar, one voice (id 1) and a Markov model trained on
    /// `tests/fixtures/test_corpus.txt`.
    fn build_test_engine() -> NarrativeEngine {
        let mut voices = VoiceRegistry::new();
        voices.register(Voice {
            id: VoiceId(1),
            name: "narrator".to_string(),
            parent: None,
            grammar_weights: HashMap::new(),
            vocabulary: crate::core::voice::VocabularyPool::default(),
            markov_bindings: Vec::new(),
            structure_prefs: crate::core::voice::StructurePrefs::default(),
            quirks: Vec::new(),
        });

        let corpus = std::fs::read_to_string("tests/fixtures/test_corpus.txt").unwrap();
        let markov_model = MarkovTrainer::train(&corpus, 2);

        let mut markov_models = HashMap::new();
        markov_models.insert("test_corpus".to_string(), markov_model);

        NarrativeEngine::builder()
            .seed(42)
            .with_grammars(test_grammars())
            .with_voices(voices)
            .with_markov_models(markov_models)
            .build()
            .unwrap()
    }

    /// Two entities (Margaret with voice 1, James without) and a tense confrontation
    /// event in which Margaret is the subject and James the object.
    fn make_test_world() -> (HashMap<EntityId, Entity>, Event) {
        let mut entities = HashMap::new();

        let margaret = Entity {
            id: EntityId(1),
            name: "Margaret".to_string(),
            pronouns: crate::schema::entity::Pronouns::SheHer,
            tags: ["host".to_string(), "formal".to_string()]
                .into_iter()
                .collect(),
            relationships: Vec::new(),
            voice_id: Some(VoiceId(1)),
            properties: HashMap::from([(
                "title".to_string(),
                Value::String("Duchess".to_string()),
            )]),
        };

        let james = Entity {
            id: EntityId(2),
            name: "James".to_string(),
            pronouns: crate::schema::entity::Pronouns::HeHim,
            tags: ["guest".to_string()].into_iter().collect(),
            relationships: Vec::new(),
            voice_id: None,
            properties: HashMap::new(),
        };

        entities.insert(EntityId(1), margaret);
        entities.insert(EntityId(2), james);

        let event = Event {
            event_type: "accusation".to_string(),
            participants: vec![
                EntityRef {
                    entity_id: EntityId(1),
                    role: "subject".to_string(),
                },
                EntityRef {
                    entity_id: EntityId(2),
                    role: "object".to_string(),
                },
            ],
            location: None,
            mood: Mood::Tense,
            stakes: Stakes::High,
            outcome: None,
            narrative_fn: NarrativeFunction::Confrontation,
            metadata: HashMap::new(),
        };

        (entities, event)
    }

    #[test]
    fn narrate_produces_output() {
        let mut engine = build_test_engine();
        let (entities, event) = make_test_world();
        let world = WorldState {
            entities: &entities,
        };

        let result = engine.narrate(&event, &world).unwrap();
        assert!(!result.is_empty(), "Expected non-empty narration");
        assert!(
            result.len() > 10,
            "Expected substantial text, got: {}",
            result
        );
    }

    #[test]
    fn narrate_deterministic_same_seed() {
        let (entities, event) = make_test_world();
        let world = WorldState {
            entities: &entities,
        };

        let result1 = build_test_engine().narrate(&event, &world).unwrap();
        let result2 = build_test_engine().narrate(&event, &world).unwrap();

        assert_eq!(result1, result2);
    }

    #[test]
    fn narrate_different_with_different_seed() {
        let (entities, event) = make_test_world();
        let world = WorldState {
            entities: &entities,
        };

        let engine_with_seed = |seed: u64| {
            NarrativeEngine::builder()
                .seed(seed)
                .with_grammars(test_grammars())
                .build()
                .unwrap()
        };

        let baseline = engine_with_seed(1).narrate(&event, &world).unwrap();
        let found_different = (2..50).any(|seed| {
            engine_with_seed(seed).narrate(&event, &world).unwrap() != baseline
        });

        assert!(
            found_different,
            "Expected different output with different seeds"
        );
    }

    #[test]
    fn narrate_as_with_specific_voice() {
        let mut engine = build_test_engine();
        let (entities, event) = make_test_world();
        let world = WorldState {
            entities: &entities,
        };

        let result = engine.narrate_as(&event, VoiceId(1), &world).unwrap();
        assert!(!result.is_empty());
    }

    #[test]
    fn narrate_variants_produces_multiple() {
        let mut engine = build_test_engine();
        let (entities, event) = make_test_world();
        let world = WorldState {
            entities: &entities,
        };

        let variants = engine.narrate_variants(&event, 3, &world).unwrap();
        assert_eq!(variants.len(), 3);
        for v in &variants {
            assert!(!v.is_empty());
        }
    }

    #[test]
    fn narrate_variants_advances_generation_count_once_per_variant() {
        let mut engine = build_test_engine();
        let (entities, event) = make_test_world();
        let world = WorldState {
            entities: &entities,
        };

        engine.narrate_variants(&event, 3, &world).unwrap();
        assert_eq!(engine.generation_count, 3);
    }

    #[test]
    fn narrate_contains_entity_name() {
        let mut engine = build_test_engine();
        let (entities, event) = make_test_world();
        let world = WorldState {
            entities: &entities,
        };

        // Several attempts because alternative selection is random; Margaret appears in
        // every `confrontation_opening` alternative via `{subject}`.
        let found_name = (0..10).any(|_| {
            engine
                .narrate(&event, &world)
                .unwrap()
                .contains("Margaret")
        });
        assert!(found_name, "Expected entity name in at least one narration");
    }

    #[test]
    fn builder_with_seed() {
        let engine = NarrativeEngine::builder().seed(12345).build().unwrap();
        assert_eq!(engine.seed, 12345);
    }
}