1use crate::causal_engine::{CausalPropagationEngine, PropagationError};
4use crate::config_mutator::{ConfigMutator, MutationError};
5use crate::intervention_manager::{InterventionError, InterventionManager};
6use datasynth_config::{GeneratorConfig, ScenarioSchemaConfig};
7use datasynth_core::causal_dag::{CausalDAG, CausalDAGError};
8use datasynth_core::{
9 Intervention, InterventionTiming, InterventionType, OnsetType, ScenarioConstraints,
10};
11use serde::{Deserialize, Serialize};
12use std::path::{Path, PathBuf};
13use thiserror::Error;
14use uuid::Uuid;
15
/// Errors produced while loading the causal model or running scenarios.
///
/// Variants marked `#[from]` wrap the error of one pipeline stage, so `?`
/// converts those stage errors into a `ScenarioError` automatically.
#[derive(Debug, Error)]
pub enum ScenarioError {
    /// Error reported by the intervention layer (`InterventionError`).
    #[error("intervention error: {0}")]
    Intervention(#[from] InterventionError),
    /// Error reported while propagating interventions through the causal DAG.
    #[error("propagation error: {0}")]
    Propagation(#[from] PropagationError),
    /// Error reported while applying propagated changes to a config.
    #[error("mutation error: {0}")]
    Mutation(#[from] MutationError),
    /// Error reported by the causal DAG itself (`CausalDAGError`).
    #[error("DAG error: {0}")]
    Dag(#[from] CausalDAGError),
    /// Free-form failure during scenario generation, carrying a message.
    #[error("generation error: {0}")]
    Generation(String),
    /// Filesystem failure, e.g. creating output directories or writing a manifest.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// (De)serialization failure, carrying a message. In this module it is also
    /// used for an unknown causal-DAG preset.
    #[error("serialization error: {0}")]
    Serialization(String),
}
34
/// Summary of one generated scenario, returned by `ScenarioEngine::generate_scenario`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScenarioResult {
    /// Scenario name, copied from its configuration.
    pub scenario_name: String,
    /// Baseline output directory (the same path for every scenario in one
    /// `generate_all` run).
    pub baseline_path: PathBuf,
    /// Directory created for this scenario's counterfactual data.
    pub counterfactual_path: PathBuf,
    /// Number of interventions defined for the scenario.
    pub interventions_applied: usize,
    /// Number of entries in the propagated `changes_by_month` collection
    /// (presumably the count of months with at least one change).
    pub months_affected: usize,
}
44
/// Runs the counterfactual scenarios declared in a [`GeneratorConfig`] against
/// a causal DAG.
pub struct ScenarioEngine {
    /// Configuration every scenario starts from; only borrowed immutably by the
    /// engine's methods.
    base_config: GeneratorConfig,
    /// Causal model selected by `base_config.scenarios.causal_model.preset`.
    causal_dag: CausalDAG,
}
50
51impl ScenarioEngine {
52 pub fn new(config: GeneratorConfig) -> Result<Self, ScenarioError> {
54 let causal_dag = Self::load_causal_dag(&config)?;
55 Ok(Self {
56 base_config: config,
57 causal_dag,
58 })
59 }
60
61 fn load_causal_dag(config: &GeneratorConfig) -> Result<CausalDAG, ScenarioError> {
63 let causal_config = &config.scenarios.causal_model;
64 let mut dag: CausalDAG = match causal_config.preset.as_str() {
65 "default" | "" => {
66 let yaml = include_str!("causal_dag_default.yaml");
67 serde_yaml::from_str(yaml).map_err(|e| {
68 ScenarioError::Serialization(format!("failed to parse default causal DAG: {e}"))
69 })?
70 }
71 "minimal" => {
72 use datasynth_core::causal_dag::{
73 CausalEdge, CausalNode, NodeCategory, TransferFunction,
74 };
75 CausalDAG {
77 nodes: vec![
78 CausalNode {
79 id: "gdp_growth".to_string(),
80 label: "GDP Growth".to_string(),
81 category: NodeCategory::Macro,
82 baseline_value: 0.025,
83 bounds: Some((-0.10, 0.15)),
84 interventionable: true,
85 config_bindings: vec![],
86 },
87 CausalNode {
88 id: "transaction_volume".to_string(),
89 label: "Transaction Volume".to_string(),
90 category: NodeCategory::Operational,
91 baseline_value: 1.0,
92 bounds: Some((0.2, 3.0)),
93 interventionable: true,
94 config_bindings: vec!["transactions.volume_multiplier".to_string()],
95 },
96 CausalNode {
97 id: "error_rate".to_string(),
98 label: "Error Rate".to_string(),
99 category: NodeCategory::Outcome,
100 baseline_value: 0.02,
101 bounds: Some((0.0, 0.30)),
102 interventionable: false,
103 config_bindings: vec!["anomaly_injection.base_rate".to_string()],
104 },
105 ],
106 edges: vec![
107 CausalEdge {
108 from: "gdp_growth".to_string(),
109 to: "transaction_volume".to_string(),
110 transfer: TransferFunction::Linear {
111 coefficient: 0.8,
112 intercept: 1.0,
113 },
114 lag_months: 1,
115 strength: 1.0,
116 mechanism: Some("GDP growth drives transaction volume".to_string()),
117 },
118 CausalEdge {
119 from: "transaction_volume".to_string(),
120 to: "error_rate".to_string(),
121 transfer: TransferFunction::Linear {
122 coefficient: 0.01,
123 intercept: 0.0,
124 },
125 lag_months: 0,
126 strength: 1.0,
127 mechanism: Some("Higher volume increases error rate".to_string()),
128 },
129 ],
130 topological_order: vec![],
131 }
132 }
133 other => {
134 return Err(ScenarioError::Serialization(format!(
135 "unknown causal DAG preset: '{other}'"
136 )));
137 }
138 };
139
140 dag.validate()?;
141 Ok(dag)
142 }
143
144 pub fn causal_dag(&self) -> &CausalDAG {
146 &self.causal_dag
147 }
148
149 pub fn base_config(&self) -> &GeneratorConfig {
151 &self.base_config
152 }
153
154 pub fn generate_all(&self, output_root: &Path) -> Result<Vec<ScenarioResult>, ScenarioError> {
156 let scenarios = &self.base_config.scenarios.scenarios;
157 let mut results = Vec::with_capacity(scenarios.len());
158
159 let baseline_path = output_root.join("baseline");
161 std::fs::create_dir_all(&baseline_path)?;
162
163 for scenario in scenarios {
165 let result = self.generate_scenario(scenario, &baseline_path, output_root)?;
166 results.push(result);
167 }
168
169 Ok(results)
170 }
171
172 pub fn generate_scenario(
174 &self,
175 scenario: &ScenarioSchemaConfig,
176 baseline_path: &Path,
177 output_root: &Path,
178 ) -> Result<ScenarioResult, ScenarioError> {
179 let interventions = Self::convert_interventions(&scenario.interventions)?;
181
182 let validated = InterventionManager::validate(&interventions, &self.base_config)?;
184
185 let engine = CausalPropagationEngine::new(&self.causal_dag);
187 let propagated = engine.propagate(&validated, self.base_config.global.period_months)?;
188
189 let constraints = ScenarioConstraints {
191 preserve_accounting_identity: scenario.constraints.preserve_accounting_identity,
192 preserve_document_chains: scenario.constraints.preserve_document_chains,
193 preserve_period_close: scenario.constraints.preserve_period_close,
194 preserve_balance_coherence: scenario.constraints.preserve_balance_coherence,
195 custom: vec![],
196 };
197
198 let _mutated_config = ConfigMutator::apply(&self.base_config, &propagated, &constraints)?;
200
201 let scenario_path = output_root
203 .join("scenarios")
204 .join(&scenario.name)
205 .join("data");
206 std::fs::create_dir_all(&scenario_path)?;
207
208 let manifest = ScenarioManifest {
210 scenario_name: scenario.name.clone(),
211 description: scenario.description.clone(),
212 interventions_count: interventions.len(),
213 months_affected: propagated.changes_by_month.len(),
214 config_paths_changed: propagated
215 .changes_by_month
216 .values()
217 .flat_map(|changes| changes.iter().map(|c| c.path.clone()))
218 .collect::<std::collections::HashSet<_>>()
219 .into_iter()
220 .collect(),
221 };
222
223 let manifest_path = output_root
224 .join("scenarios")
225 .join(&scenario.name)
226 .join("scenario_manifest.yaml");
227 let manifest_yaml = serde_yaml::to_string(&manifest)
228 .map_err(|e| ScenarioError::Serialization(e.to_string()))?;
229 std::fs::write(&manifest_path, manifest_yaml)?;
230
231 Ok(ScenarioResult {
232 scenario_name: scenario.name.clone(),
233 baseline_path: baseline_path.to_path_buf(),
234 counterfactual_path: scenario_path,
235 interventions_applied: interventions.len(),
236 months_affected: propagated.changes_by_month.len(),
237 })
238 }
239
240 fn convert_interventions(
242 schema_interventions: &[datasynth_config::InterventionSchemaConfig],
243 ) -> Result<Vec<Intervention>, ScenarioError> {
244 let mut interventions = Vec::new();
245
246 for schema in schema_interventions {
247 let intervention_type: InterventionType =
248 serde_json::from_value(schema.intervention_type.clone()).map_err(|e| {
249 ScenarioError::Serialization(format!("failed to parse intervention type: {e}"))
250 })?;
251
252 let onset = match schema.timing.onset.to_lowercase().as_str() {
253 "sudden" => OnsetType::Sudden,
254 "gradual" => OnsetType::Gradual,
255 "oscillating" => OnsetType::Oscillating,
256 _ => OnsetType::Sudden,
257 };
258
259 interventions.push(Intervention {
260 id: Uuid::new_v4(),
261 intervention_type,
262 timing: InterventionTiming {
263 start_month: schema.timing.start_month,
264 duration_months: schema.timing.duration_months,
265 onset,
266 ramp_months: schema.timing.ramp_months,
267 },
268 label: schema.label.clone(),
269 priority: schema.priority,
270 });
271 }
272
273 Ok(interventions)
274 }
275
276 pub fn list_scenarios(&self) -> Vec<ScenarioSummary> {
278 self.base_config
279 .scenarios
280 .scenarios
281 .iter()
282 .map(|s| ScenarioSummary {
283 name: s.name.clone(),
284 description: s.description.clone(),
285 tags: s.tags.clone(),
286 intervention_count: s.interventions.len(),
287 probability_weight: s.probability_weight,
288 })
289 .collect()
290 }
291
292 pub fn validate_all(&self) -> Vec<ScenarioValidationResult> {
294 self.base_config
295 .scenarios
296 .scenarios
297 .iter()
298 .map(|s| {
299 let result = self.validate_scenario(s);
300 ScenarioValidationResult {
301 name: s.name.clone(),
302 valid: result.is_ok(),
303 error: result.err().map(|e| e.to_string()),
304 }
305 })
306 .collect()
307 }
308
309 fn validate_scenario(&self, scenario: &ScenarioSchemaConfig) -> Result<(), ScenarioError> {
311 let interventions = Self::convert_interventions(&scenario.interventions)?;
312 let validated = InterventionManager::validate(&interventions, &self.base_config)?;
313 let engine = CausalPropagationEngine::new(&self.causal_dag);
314 let propagated = engine.propagate(&validated, self.base_config.global.period_months)?;
315
316 let constraints = ScenarioConstraints {
317 preserve_accounting_identity: scenario.constraints.preserve_accounting_identity,
318 preserve_document_chains: scenario.constraints.preserve_document_chains,
319 preserve_period_close: scenario.constraints.preserve_period_close,
320 preserve_balance_coherence: scenario.constraints.preserve_balance_coherence,
321 custom: vec![],
322 };
323
324 let _mutated = ConfigMutator::apply(&self.base_config, &propagated, &constraints)?;
325 Ok(())
326 }
327}
328
/// Lightweight description of a configured scenario, as returned by
/// `ScenarioEngine::list_scenarios`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScenarioSummary {
    /// Scenario name from the configuration.
    pub name: String,
    /// Free-text description from the configuration.
    pub description: String,
    /// Tags from the configuration.
    pub tags: Vec<String>,
    /// Number of interventions the scenario defines.
    pub intervention_count: usize,
    /// Optional weight copied from the scenario's `probability_weight`.
    pub probability_weight: Option<f64>,
}
338
/// Outcome of dry-running one scenario, as returned by `ScenarioEngine::validate_all`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScenarioValidationResult {
    /// Scenario name from the configuration.
    pub name: String,
    /// `true` when the scenario passed validation.
    pub valid: bool,
    /// Display text of the validation error; `None` when `valid` is `true`.
    pub error: Option<String>,
}
346
/// Contents of the `scenario_manifest.yaml` written for each generated scenario.
/// Field order determines the key order in the serialized file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ScenarioManifest {
    /// Scenario name from the configuration.
    scenario_name: String,
    /// Free-text description from the configuration.
    description: String,
    /// Number of interventions the scenario defines.
    interventions_count: usize,
    /// Number of entries in the propagated per-month change collection.
    months_affected: usize,
    /// Distinct config paths touched by any propagated change.
    config_paths_changed: Vec<String>,
}
356
#[cfg(test)]
mod tests {
    use super::*;
    use datasynth_config::{
        CausalModelSchemaConfig, InterventionSchemaConfig, InterventionTimingSchemaConfig,
        ScenarioConstraintsSchemaConfig, ScenarioOutputSchemaConfig, ScenariosConfig,
    };
    use datasynth_test_utils::fixtures::minimal_config;
    use tempfile::TempDir;

    /// Builds a scenario holding exactly one linear `parameter_shift` that
    /// starts suddenly in month 1 with no fixed duration or ramp.
    fn single_shift_scenario(
        name: &str,
        description: &str,
        tags: &[&str],
        probability_weight: Option<f64>,
        target: &str,
        to: serde_json::Value,
        label: &str,
    ) -> ScenarioSchemaConfig {
        let shift = InterventionSchemaConfig {
            intervention_type: serde_json::json!({
                "type": "parameter_shift",
                "target": target,
                "to": to,
                "interpolation": "linear"
            }),
            timing: InterventionTimingSchemaConfig {
                start_month: 1,
                duration_months: None,
                onset: "sudden".to_string(),
                ramp_months: None,
            },
            label: Some(label.to_string()),
            priority: 0,
        };

        ScenarioSchemaConfig {
            name: name.to_string(),
            description: description.to_string(),
            tags: tags.iter().map(|tag| tag.to_string()).collect(),
            base: None,
            probability_weight,
            interventions: vec![shift],
            constraints: ScenarioConstraintsSchemaConfig::default(),
            output: ScenarioOutputSchemaConfig::default(),
            metadata: Default::default(),
        }
    }

    /// Wraps a single scenario into an enabled `ScenariosConfig`.
    fn enabled_with(
        scenario: ScenarioSchemaConfig,
        causal_model: CausalModelSchemaConfig,
    ) -> ScenariosConfig {
        ScenariosConfig {
            enabled: true,
            scenarios: vec![scenario],
            causal_model,
            defaults: Default::default(),
        }
    }

    /// Minimal generator config with one recession scenario on the default DAG.
    fn config_with_scenario() -> GeneratorConfig {
        let recession = single_shift_scenario(
            "test_recession",
            "Test recession scenario",
            &["test"],
            Some(0.3),
            "global.period_months",
            serde_json::json!(3),
            "Test shift",
        );
        let mut config = minimal_config();
        config.scenarios = enabled_with(recession, CausalModelSchemaConfig::default());
        config
    }

    /// Engine over [`config_with_scenario`].
    fn default_engine() -> ScenarioEngine {
        ScenarioEngine::new(config_with_scenario()).expect("should create engine")
    }

    #[test]
    fn test_scenario_engine_new_default_dag() {
        let engine = default_engine();
        let dag = engine.causal_dag();
        assert!(!dag.nodes.is_empty());
        assert!(!dag.edges.is_empty());
    }

    #[test]
    fn test_scenario_engine_list_scenarios() {
        let listed = default_engine().list_scenarios();
        assert_eq!(listed.len(), 1);
        let only = &listed[0];
        assert_eq!(only.name, "test_recession");
        assert_eq!(only.intervention_count, 1);
    }

    #[test]
    fn test_scenario_engine_validate_all() {
        let outcomes = default_engine().validate_all();
        assert_eq!(outcomes.len(), 1);
        assert!(outcomes[0].valid, "validation error: {:?}", outcomes[0].error);
    }

    #[test]
    fn test_scenario_engine_converts_schema_to_interventions() {
        let config = config_with_scenario();
        let schema = &config.scenarios.scenarios[0].interventions;
        let converted = ScenarioEngine::convert_interventions(schema).expect("should convert");
        assert_eq!(converted.len(), 1);
        assert!(matches!(
            converted[0].intervention_type,
            InterventionType::ParameterShift(_)
        ));
    }

    #[test]
    fn test_minimal_dag_preset_valid() {
        let volume_bump = single_shift_scenario(
            "minimal_test",
            "Test with minimal DAG",
            &[],
            None,
            "transactions.volume_multiplier",
            serde_json::json!(2.0),
            "Volume increase",
        );
        let minimal_model = CausalModelSchemaConfig {
            preset: "minimal".to_string(),
            ..Default::default()
        };
        let mut config = minimal_config();
        config.scenarios = enabled_with(volume_bump, minimal_model);

        let engine = ScenarioEngine::new(config).expect("should create engine with minimal DAG");
        assert_eq!(engine.causal_dag().nodes.len(), 3);
        assert_eq!(engine.causal_dag().edges.len(), 2);

        let outcomes = engine.validate_all();
        assert_eq!(outcomes.len(), 1);
        assert!(outcomes[0].valid, "validation error: {:?}", outcomes[0].error);
    }

    #[test]
    fn test_scenario_engine_generates_output() {
        let engine = default_engine();
        let workspace = TempDir::new().expect("should create tmpdir");
        let generated = engine.generate_all(workspace.path()).expect("should generate");
        assert_eq!(generated.len(), 1);
        assert_eq!(generated[0].scenario_name, "test_recession");

        let manifest = workspace
            .path()
            .join("scenarios")
            .join("test_recession")
            .join("scenario_manifest.yaml");
        assert!(manifest.exists());
    }
}