Skip to main content

narrative_engine/core/
voice.rs

//! Voice system — persona/tone bundles that shape generated text.
2use rustc_hash::FxHashSet;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6use crate::schema::entity::VoiceId;
7
8/// A voice definition that shapes how text sounds for a specific
9/// speaker, narrator, or document type.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct Voice {
12    pub id: VoiceId,
13    pub name: String,
14    pub parent: Option<VoiceId>,
15    #[serde(default)]
16    pub grammar_weights: HashMap<String, f32>,
17    #[serde(default)]
18    pub vocabulary: VocabularyPool,
19    #[serde(default)]
20    pub markov_bindings: Vec<MarkovBinding>,
21    #[serde(default)]
22    pub structure_prefs: StructurePrefs,
23    #[serde(default)]
24    pub quirks: Vec<Quirk>,
25}
26
27/// Preferred and avoided words for a voice.
28#[derive(Debug, Clone, Serialize, Deserialize, Default)]
29pub struct VocabularyPool {
30    #[serde(default)]
31    pub preferred: FxHashSet<String>,
32    #[serde(default)]
33    pub avoided: FxHashSet<String>,
34}
35
36/// Binding a voice to a Markov corpus with weight and tags.
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct MarkovBinding {
39    pub corpus_id: String,
40    pub weight: f32,
41    #[serde(default)]
42    pub tags: Vec<String>,
43}
44
45/// Structural preferences for text generation.
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct StructurePrefs {
48    /// (min, max) word count range for sentences.
49    pub avg_sentence_length: (u32, u32),
50    /// 0.0 = simple, 1.0 = complex clause structure.
51    pub clause_complexity: f32,
52    /// 0.0..1.0 probability of generating questions.
53    pub question_frequency: f32,
54}
55
56impl Default for StructurePrefs {
57    fn default() -> Self {
58        Self {
59            avg_sentence_length: (8, 18),
60            clause_complexity: 0.5,
61            question_frequency: 0.1,
62        }
63    }
64}
65
66/// A verbal tic or recurring phrase that gets occasionally inserted.
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct Quirk {
69    pub pattern: String,
70    /// Probability of injecting per passage (0.0..1.0).
71    pub frequency: f32,
72}
73
74/// A fully resolved voice with inheritance chain merged.
75#[derive(Debug, Clone)]
76pub struct ResolvedVoice {
77    pub id: VoiceId,
78    pub name: String,
79    pub grammar_weights: HashMap<String, f32>,
80    pub vocabulary: VocabularyPool,
81    pub markov_bindings: Vec<MarkovBinding>,
82    pub structure_prefs: StructurePrefs,
83    pub quirks: Vec<Quirk>,
84}
85
86/// Registry of all loaded voices with inheritance resolution.
87#[derive(Debug, Clone, Default)]
88pub struct VoiceRegistry {
89    voices: HashMap<VoiceId, Voice>,
90}
91
92impl VoiceRegistry {
93    pub fn new() -> Self {
94        Self {
95            voices: HashMap::new(),
96        }
97    }
98
99    pub fn register(&mut self, voice: Voice) {
100        self.voices.insert(voice.id, voice);
101    }
102
103    pub fn get(&self, id: VoiceId) -> Option<&Voice> {
104        self.voices.get(&id)
105    }
106
107    /// Resolve a voice by walking its inheritance chain and merging properties.
108    ///
109    /// Child grammar_weights override parent, vocabulary pools union,
110    /// markov_bindings concatenate, structure_prefs take child values
111    /// (falling back to parent), quirks concatenate.
112    pub fn resolve(&self, id: VoiceId) -> Option<ResolvedVoice> {
113        let voice = self.voices.get(&id)?;
114
115        // Build the inheritance chain (child first, ancestors after)
116        let mut chain = vec![voice];
117        let mut current = voice;
118        while let Some(parent_id) = current.parent {
119            if let Some(parent) = self.voices.get(&parent_id) {
120                chain.push(parent);
121                current = parent;
122            } else {
123                break;
124            }
125        }
126
127        // Resolve from root ancestor to child (so child overrides parent)
128        let mut grammar_weights = HashMap::new();
129        let mut preferred = FxHashSet::default();
130        let mut avoided = FxHashSet::default();
131        let mut markov_bindings = Vec::new();
132        let mut structure_prefs = StructurePrefs::default();
133        let mut quirks = Vec::new();
134
135        for ancestor in chain.iter().rev() {
136            // Grammar weights: child overrides parent
137            for (k, v) in &ancestor.grammar_weights {
138                grammar_weights.insert(k.clone(), *v);
139            }
140
141            // Vocabulary: union
142            preferred.extend(ancestor.vocabulary.preferred.iter().cloned());
143            avoided.extend(ancestor.vocabulary.avoided.iter().cloned());
144
145            // Markov bindings: concatenate
146            markov_bindings.extend(ancestor.markov_bindings.iter().cloned());
147
148            // Structure prefs: child takes precedence (last write wins)
149            structure_prefs = ancestor.structure_prefs.clone();
150
151            // Quirks: concatenate
152            quirks.extend(ancestor.quirks.iter().cloned());
153        }
154
155        Some(ResolvedVoice {
156            id: voice.id,
157            name: voice.name.clone(),
158            grammar_weights,
159            vocabulary: VocabularyPool { preferred, avoided },
160            markov_bindings,
161            structure_prefs,
162            quirks,
163        })
164    }
165
166    /// Load voices from a RON file. The file should contain a list of Voice definitions.
167    pub fn load_from_ron(&mut self, path: &std::path::Path) -> Result<(), VoiceError> {
168        let contents = std::fs::read_to_string(path)?;
169        let voices: Vec<Voice> = ron::from_str(&contents)?;
170        for voice in voices {
171            self.register(voice);
172        }
173        Ok(())
174    }
175}
176
177#[derive(Debug, thiserror::Error)]
178pub enum VoiceError {
179    #[error("IO error: {0}")]
180    Io(#[from] std::io::Error),
181    #[error("RON deserialization error: {0}")]
182    Ron(#[from] ron::error::SpannedError),
183    #[error("voice not found: {0:?}")]
184    NotFound(VoiceId),
185}
186
187#[cfg(test)]
188mod tests {
189    use super::*;
190
191    fn make_parent_voice() -> Voice {
192        Voice {
193            id: VoiceId(1),
194            name: "military".to_string(),
195            parent: None,
196            grammar_weights: HashMap::from([
197                ("greeting".to_string(), 0.5),
198                ("action_detail".to_string(), 2.0),
199            ]),
200            vocabulary: VocabularyPool {
201                preferred: ["sir".to_string(), "affirmative".to_string()]
202                    .into_iter()
203                    .collect(),
204                avoided: ["hello".to_string()].into_iter().collect(),
205            },
206            markov_bindings: vec![MarkovBinding {
207                corpus_id: "military_prose".to_string(),
208                weight: 1.0,
209                tags: vec!["formal".to_string()],
210            }],
211            structure_prefs: StructurePrefs {
212                avg_sentence_length: (5, 12),
213                clause_complexity: 0.3,
214                question_frequency: 0.05,
215            },
216            quirks: vec![Quirk {
217                pattern: "if you will".to_string(),
218                frequency: 0.1,
219            }],
220        }
221    }
222
223    fn make_child_voice() -> Voice {
224        Voice {
225            id: VoiceId(2),
226            name: "ship_captain".to_string(),
227            parent: Some(VoiceId(1)),
228            grammar_weights: HashMap::from([
229                ("greeting".to_string(), 0.8), // overrides parent's 0.5
230                ("nautical_detail".to_string(), 3.0),
231            ]),
232            vocabulary: VocabularyPool {
233                preferred: ["aye".to_string(), "starboard".to_string()]
234                    .into_iter()
235                    .collect(),
236                avoided: FxHashSet::default(),
237            },
238            markov_bindings: vec![MarkovBinding {
239                corpus_id: "nautical_prose".to_string(),
240                weight: 1.5,
241                tags: vec!["sea".to_string()],
242            }],
243            structure_prefs: StructurePrefs {
244                avg_sentence_length: (6, 15),
245                clause_complexity: 0.4,
246                question_frequency: 0.08,
247            },
248            quirks: vec![Quirk {
249                pattern: "by the bow".to_string(),
250                frequency: 0.15,
251            }],
252        }
253    }
254
255    #[test]
256    fn voice_registry_register_and_get() {
257        let mut registry = VoiceRegistry::new();
258        let voice = make_parent_voice();
259        registry.register(voice);
260        assert!(registry.get(VoiceId(1)).is_some());
261        assert!(registry.get(VoiceId(99)).is_none());
262    }
263
264    #[test]
265    fn resolve_single_voice() {
266        let mut registry = VoiceRegistry::new();
267        registry.register(make_parent_voice());
268
269        let resolved = registry.resolve(VoiceId(1)).unwrap();
270        assert_eq!(resolved.name, "military");
271        assert_eq!(resolved.grammar_weights.get("greeting"), Some(&0.5));
272        assert!(resolved.vocabulary.preferred.contains("sir"));
273        assert!(resolved.vocabulary.avoided.contains("hello"));
274        assert_eq!(resolved.markov_bindings.len(), 1);
275        assert_eq!(resolved.quirks.len(), 1);
276    }
277
278    #[test]
279    fn resolve_inheritance_chain() {
280        let mut registry = VoiceRegistry::new();
281        registry.register(make_parent_voice());
282        registry.register(make_child_voice());
283
284        let resolved = registry.resolve(VoiceId(2)).unwrap();
285        assert_eq!(resolved.name, "ship_captain");
286
287        // Grammar weights: child overrides parent for "greeting"
288        assert_eq!(resolved.grammar_weights.get("greeting"), Some(&0.8));
289        // Parent-only weight preserved
290        assert_eq!(resolved.grammar_weights.get("action_detail"), Some(&2.0));
291        // Child-only weight present
292        assert_eq!(resolved.grammar_weights.get("nautical_detail"), Some(&3.0));
293
294        // Vocabulary: union of both
295        assert!(resolved.vocabulary.preferred.contains("sir")); // from parent
296        assert!(resolved.vocabulary.preferred.contains("aye")); // from child
297        assert!(resolved.vocabulary.preferred.contains("starboard")); // from child
298        assert!(resolved.vocabulary.avoided.contains("hello")); // from parent
299
300        // Markov bindings: concatenated
301        assert_eq!(resolved.markov_bindings.len(), 2);
302
303        // Structure prefs: child takes precedence
304        assert_eq!(resolved.structure_prefs.avg_sentence_length, (6, 15));
305
306        // Quirks: concatenated
307        assert_eq!(resolved.quirks.len(), 2);
308    }
309
310    #[test]
311    fn resolve_missing_voice() {
312        let registry = VoiceRegistry::new();
313        assert!(registry.resolve(VoiceId(99)).is_none());
314    }
315
316    #[test]
317    fn resolve_missing_parent_graceful() {
318        let mut registry = VoiceRegistry::new();
319        // Register child without its parent
320        registry.register(make_child_voice());
321
322        let resolved = registry.resolve(VoiceId(2)).unwrap();
323        // Should resolve with just the child's properties
324        assert_eq!(resolved.name, "ship_captain");
325        assert_eq!(resolved.grammar_weights.get("greeting"), Some(&0.8));
326    }
327
328    #[test]
329    fn ron_round_trip() {
330        let voice = make_parent_voice();
331        let serialized = ron::to_string(&voice).unwrap();
332        let deserialized: Voice = ron::from_str(&serialized).unwrap();
333        assert_eq!(deserialized.name, "military");
334        assert_eq!(deserialized.id, VoiceId(1));
335        assert_eq!(deserialized.grammar_weights.get("greeting"), Some(&0.5));
336    }
337
338    #[test]
339    fn voice_grammar_weight_integration() {
340        use crate::core::grammar::{GrammarSet, SelectionContext};
341        use rand::rngs::StdRng;
342        use rand::SeedableRng;
343
344        // Build a simple grammar with two alternatives
345        let grammar_ron = r#"{
346            "test_rule": Rule(
347                requires: [],
348                excludes: [],
349                alternatives: [
350                    (weight: 1, text: "option_a"),
351                    (weight: 1, text: "option_b"),
352                ],
353            ),
354        }"#;
355        let gs = GrammarSet::parse_ron(grammar_ron).unwrap();
356
357        // Without voice weights: roughly 50/50
358        let mut count_a_no_voice = 0;
359        for seed in 0..1000 {
360            let mut ctx = SelectionContext::new();
361            let mut rng = StdRng::seed_from_u64(seed);
362            let result = gs.expand("test_rule", &mut ctx, &mut rng).unwrap();
363            if result == "option_a" {
364                count_a_no_voice += 1;
365            }
366        }
367
368        // Voice weights multiply all alternatives equally for a given rule,
369        // so with equal-weight alternatives the ratio stays the same.
370        // The count should still be roughly 50/50.
371        assert!(
372            count_a_no_voice > 400 && count_a_no_voice < 600,
373            "Expected roughly 50/50 distribution, got option_a: {}/1000",
374            count_a_no_voice
375        );
376    }
377
378    #[test]
379    fn load_test_voices_from_ron() {
380        let path = std::path::PathBuf::from("tests/fixtures/test_voices.ron");
381        let mut registry = VoiceRegistry::new();
382        registry.load_from_ron(&path).unwrap();
383
384        assert!(registry.get(VoiceId(1)).is_some());
385        assert!(registry.get(VoiceId(2)).is_some());
386
387        let resolved = registry.resolve(VoiceId(2)).unwrap();
388        assert_eq!(resolved.name, "gossip");
389        // Should inherit from host
390        assert!(resolved.vocabulary.preferred.contains("indeed")); // from parent
391        assert!(resolved.vocabulary.preferred.contains("apparently")); // from child
392    }
393}