1use rustc_hash::FxHashSet;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6use crate::schema::entity::VoiceId;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct Voice {
12 pub id: VoiceId,
13 pub name: String,
14 pub parent: Option<VoiceId>,
15 #[serde(default)]
16 pub grammar_weights: HashMap<String, f32>,
17 #[serde(default)]
18 pub vocabulary: VocabularyPool,
19 #[serde(default)]
20 pub markov_bindings: Vec<MarkovBinding>,
21 #[serde(default)]
22 pub structure_prefs: StructurePrefs,
23 #[serde(default)]
24 pub quirks: Vec<Quirk>,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize, Default)]
29pub struct VocabularyPool {
30 #[serde(default)]
31 pub preferred: FxHashSet<String>,
32 #[serde(default)]
33 pub avoided: FxHashSet<String>,
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct MarkovBinding {
39 pub corpus_id: String,
40 pub weight: f32,
41 #[serde(default)]
42 pub tags: Vec<String>,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct StructurePrefs {
48 pub avg_sentence_length: (u32, u32),
50 pub clause_complexity: f32,
52 pub question_frequency: f32,
54}
55
56impl Default for StructurePrefs {
57 fn default() -> Self {
58 Self {
59 avg_sentence_length: (8, 18),
60 clause_complexity: 0.5,
61 question_frequency: 0.1,
62 }
63 }
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct Quirk {
69 pub pattern: String,
70 pub frequency: f32,
72}
73
74#[derive(Debug, Clone)]
76pub struct ResolvedVoice {
77 pub id: VoiceId,
78 pub name: String,
79 pub grammar_weights: HashMap<String, f32>,
80 pub vocabulary: VocabularyPool,
81 pub markov_bindings: Vec<MarkovBinding>,
82 pub structure_prefs: StructurePrefs,
83 pub quirks: Vec<Quirk>,
84}
85
86#[derive(Debug, Clone, Default)]
88pub struct VoiceRegistry {
89 voices: HashMap<VoiceId, Voice>,
90}
91
92impl VoiceRegistry {
93 pub fn new() -> Self {
94 Self {
95 voices: HashMap::new(),
96 }
97 }
98
99 pub fn register(&mut self, voice: Voice) {
100 self.voices.insert(voice.id, voice);
101 }
102
103 pub fn get(&self, id: VoiceId) -> Option<&Voice> {
104 self.voices.get(&id)
105 }
106
107 pub fn resolve(&self, id: VoiceId) -> Option<ResolvedVoice> {
113 let voice = self.voices.get(&id)?;
114
115 let mut chain = vec![voice];
117 let mut current = voice;
118 while let Some(parent_id) = current.parent {
119 if let Some(parent) = self.voices.get(&parent_id) {
120 chain.push(parent);
121 current = parent;
122 } else {
123 break;
124 }
125 }
126
127 let mut grammar_weights = HashMap::new();
129 let mut preferred = FxHashSet::default();
130 let mut avoided = FxHashSet::default();
131 let mut markov_bindings = Vec::new();
132 let mut structure_prefs = StructurePrefs::default();
133 let mut quirks = Vec::new();
134
135 for ancestor in chain.iter().rev() {
136 for (k, v) in &ancestor.grammar_weights {
138 grammar_weights.insert(k.clone(), *v);
139 }
140
141 preferred.extend(ancestor.vocabulary.preferred.iter().cloned());
143 avoided.extend(ancestor.vocabulary.avoided.iter().cloned());
144
145 markov_bindings.extend(ancestor.markov_bindings.iter().cloned());
147
148 structure_prefs = ancestor.structure_prefs.clone();
150
151 quirks.extend(ancestor.quirks.iter().cloned());
153 }
154
155 Some(ResolvedVoice {
156 id: voice.id,
157 name: voice.name.clone(),
158 grammar_weights,
159 vocabulary: VocabularyPool { preferred, avoided },
160 markov_bindings,
161 structure_prefs,
162 quirks,
163 })
164 }
165
166 pub fn load_from_ron(&mut self, path: &std::path::Path) -> Result<(), VoiceError> {
168 let contents = std::fs::read_to_string(path)?;
169 let voices: Vec<Voice> = ron::from_str(&contents)?;
170 for voice in voices {
171 self.register(voice);
172 }
173 Ok(())
174 }
175}
176
177#[derive(Debug, thiserror::Error)]
178pub enum VoiceError {
179 #[error("IO error: {0}")]
180 Io(#[from] std::io::Error),
181 #[error("RON deserialization error: {0}")]
182 Ron(#[from] ron::error::SpannedError),
183 #[error("voice not found: {0:?}")]
184 NotFound(VoiceId),
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned word set from string literals.
    fn word_set(words: &[&str]) -> FxHashSet<String> {
        words.iter().map(|word| (*word).to_owned()).collect()
    }

    /// Builds an owned weight table from `(key, weight)` pairs.
    fn weight_map(entries: &[(&str, f32)]) -> HashMap<String, f32> {
        entries
            .iter()
            .map(|&(key, weight)| (key.to_owned(), weight))
            .collect()
    }

    /// Root voice (id 1, no parent) used by most tests.
    fn make_parent_voice() -> Voice {
        let corpus = MarkovBinding {
            corpus_id: "military_prose".to_string(),
            weight: 1.0,
            tags: vec!["formal".to_string()],
        };
        let quirk = Quirk {
            pattern: "if you will".to_string(),
            frequency: 0.1,
        };

        Voice {
            id: VoiceId(1),
            name: "military".to_string(),
            parent: None,
            grammar_weights: weight_map(&[("greeting", 0.5), ("action_detail", 2.0)]),
            vocabulary: VocabularyPool {
                preferred: word_set(&["sir", "affirmative"]),
                avoided: word_set(&["hello"]),
            },
            markov_bindings: vec![corpus],
            structure_prefs: StructurePrefs {
                avg_sentence_length: (5, 12),
                clause_complexity: 0.3,
                question_frequency: 0.05,
            },
            quirks: vec![quirk],
        }
    }

    /// Child voice (id 2) whose parent is the voice from `make_parent_voice`.
    fn make_child_voice() -> Voice {
        let corpus = MarkovBinding {
            corpus_id: "nautical_prose".to_string(),
            weight: 1.5,
            tags: vec!["sea".to_string()],
        };
        let quirk = Quirk {
            pattern: "by the bow".to_string(),
            frequency: 0.15,
        };

        Voice {
            id: VoiceId(2),
            name: "ship_captain".to_string(),
            parent: Some(VoiceId(1)),
            grammar_weights: weight_map(&[("greeting", 0.8), ("nautical_detail", 3.0)]),
            vocabulary: VocabularyPool {
                preferred: word_set(&["aye", "starboard"]),
                avoided: FxHashSet::default(),
            },
            markov_bindings: vec![corpus],
            structure_prefs: StructurePrefs {
                avg_sentence_length: (6, 15),
                clause_complexity: 0.4,
                question_frequency: 0.08,
            },
            quirks: vec![quirk],
        }
    }

    #[test]
    fn voice_registry_register_and_get() {
        let mut registry = VoiceRegistry::new();
        registry.register(make_parent_voice());

        assert!(registry.get(VoiceId(1)).is_some());
        assert!(registry.get(VoiceId(99)).is_none());
    }

    #[test]
    fn resolve_single_voice() {
        let mut registry = VoiceRegistry::new();
        registry.register(make_parent_voice());

        let resolved = registry.resolve(VoiceId(1)).unwrap();

        assert_eq!(resolved.name, "military");
        assert_eq!(resolved.grammar_weights.get("greeting"), Some(&0.5));
        assert!(resolved.vocabulary.preferred.contains("sir"));
        assert!(resolved.vocabulary.avoided.contains("hello"));
        assert_eq!(resolved.markov_bindings.len(), 1);
        assert_eq!(resolved.quirks.len(), 1);
    }

    #[test]
    fn resolve_inheritance_chain() {
        let mut registry = VoiceRegistry::new();
        registry.register(make_parent_voice());
        registry.register(make_child_voice());

        let resolved = registry.resolve(VoiceId(2)).unwrap();
        assert_eq!(resolved.name, "ship_captain");

        // The child's weight replaces the parent's for a shared key.
        assert_eq!(resolved.grammar_weights.get("greeting"), Some(&0.8));
        // A key defined only by the parent is inherited.
        assert_eq!(resolved.grammar_weights.get("action_detail"), Some(&2.0));
        // A key defined only by the child is kept.
        assert_eq!(resolved.grammar_weights.get("nautical_detail"), Some(&3.0));

        // Vocabulary is the union of both voices.
        assert!(resolved.vocabulary.preferred.contains("sir"));
        assert!(resolved.vocabulary.preferred.contains("aye"));
        assert!(resolved.vocabulary.preferred.contains("starboard"));
        assert!(resolved.vocabulary.avoided.contains("hello"));

        // Bindings from both voices are present.
        assert_eq!(resolved.markov_bindings.len(), 2);

        // Structure preferences come from the child.
        assert_eq!(resolved.structure_prefs.avg_sentence_length, (6, 15));

        // Quirks from both voices are present.
        assert_eq!(resolved.quirks.len(), 2);
    }

    #[test]
    fn resolve_missing_voice() {
        let registry = VoiceRegistry::new();
        assert!(registry.resolve(VoiceId(99)).is_none());
    }

    #[test]
    fn resolve_missing_parent_graceful() {
        // Only the child is registered; its parent id points at nothing.
        let mut registry = VoiceRegistry::new();
        registry.register(make_child_voice());

        let resolved = registry.resolve(VoiceId(2)).unwrap();

        assert_eq!(resolved.name, "ship_captain");
        assert_eq!(resolved.grammar_weights.get("greeting"), Some(&0.8));
    }

    #[test]
    fn ron_round_trip() {
        let original = make_parent_voice();

        let text = ron::to_string(&original).unwrap();
        let restored: Voice = ron::from_str(&text).unwrap();

        assert_eq!(restored.name, "military");
        assert_eq!(restored.id, VoiceId(1));
        assert_eq!(restored.grammar_weights.get("greeting"), Some(&0.5));
    }

    #[test]
    fn voice_grammar_weight_integration() {
        use crate::core::grammar::{GrammarSet, SelectionContext};
        use rand::rngs::StdRng;
        use rand::SeedableRng;

        let grammar_ron = r#"{
            "test_rule": Rule(
                requires: [],
                excludes: [],
                alternatives: [
                    (weight: 1, text: "option_a"),
                    (weight: 1, text: "option_b"),
                ],
            ),
        }"#;
        let gs = GrammarSet::parse_ron(grammar_ron).unwrap();

        // One expansion per seed, each with a fresh context; count how often
        // the first alternative is picked.
        let count_a_no_voice = (0..1000)
            .filter(|&seed| {
                let mut ctx = SelectionContext::new();
                let mut rng = StdRng::seed_from_u64(seed);
                let expansion = gs.expand("test_rule", &mut ctx, &mut rng).unwrap();
                expansion == "option_a"
            })
            .count();

        assert!(
            count_a_no_voice > 400 && count_a_no_voice < 600,
            "Expected roughly 50/50 distribution, got option_a: {}/1000",
            count_a_no_voice
        );
    }

    #[test]
    fn load_test_voices_from_ron() {
        let fixture = std::path::Path::new("tests/fixtures/test_voices.ron");
        let mut registry = VoiceRegistry::new();
        registry.load_from_ron(fixture).unwrap();

        assert!(registry.get(VoiceId(1)).is_some());
        assert!(registry.get(VoiceId(2)).is_some());

        let resolved = registry.resolve(VoiceId(2)).unwrap();
        assert_eq!(resolved.name, "gossip");
        assert!(resolved.vocabulary.preferred.contains("indeed"));
        assert!(resolved.vocabulary.preferred.contains("apparently"));
    }
}