1use crate::causal_engine::{CausalPropagationEngine, PropagationError};
4use crate::config_mutator::{ConfigMutator, MutationError};
5use crate::intervention_manager::{InterventionError, InterventionManager};
6use datasynth_config::{GeneratorConfig, ScenarioSchemaConfig};
7use datasynth_core::causal_dag::{CausalDAG, CausalDAGError};
8use datasynth_core::{
9 Intervention, InterventionTiming, InterventionType, OnsetType, ScenarioConstraints,
10};
11use serde::{Deserialize, Serialize};
12use std::path::{Path, PathBuf};
13use thiserror::Error;
14use uuid::Uuid;
15
16#[derive(Debug, Error)]
18pub enum ScenarioError {
19 #[error("intervention error: {0}")]
20 Intervention(#[from] InterventionError),
21 #[error("propagation error: {0}")]
22 Propagation(#[from] PropagationError),
23 #[error("mutation error: {0}")]
24 Mutation(#[from] MutationError),
25 #[error("DAG error: {0}")]
26 Dag(#[from] CausalDAGError),
27 #[error("generation error: {0}")]
28 Generation(String),
29 #[error("IO error: {0}")]
30 Io(#[from] std::io::Error),
31 #[error("serialization error: {0}")]
32 Serialization(String),
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct ScenarioResult {
38 pub scenario_name: String,
39 pub baseline_path: PathBuf,
40 pub counterfactual_path: PathBuf,
41 pub interventions_applied: usize,
42 pub months_affected: usize,
43}
44
45pub struct ScenarioEngine {
47 base_config: GeneratorConfig,
48 causal_dag: CausalDAG,
49}
50
51impl ScenarioEngine {
52 pub fn new(config: GeneratorConfig) -> Result<Self, ScenarioError> {
54 let causal_dag = Self::load_causal_dag(&config)?;
55 Ok(Self {
56 base_config: config,
57 causal_dag,
58 })
59 }
60
61 fn load_causal_dag(config: &GeneratorConfig) -> Result<CausalDAG, ScenarioError> {
63 let causal_config = &config.scenarios.causal_model;
64 let mut dag: CausalDAG = match causal_config.preset.as_str() {
65 "default" | "" => {
66 let yaml = include_str!("causal_dag_default.yaml");
67 serde_yaml::from_str(yaml).map_err(|e| {
68 ScenarioError::Serialization(format!(
69 "failed to parse default causal DAG: {}",
70 e
71 ))
72 })?
73 }
74 "minimal" => {
75 use datasynth_core::causal_dag::{
76 CausalEdge, CausalNode, NodeCategory, TransferFunction,
77 };
78 CausalDAG {
80 nodes: vec![
81 CausalNode {
82 id: "gdp_growth".to_string(),
83 label: "GDP Growth".to_string(),
84 category: NodeCategory::Macro,
85 baseline_value: 0.025,
86 bounds: Some((-0.10, 0.15)),
87 interventionable: true,
88 config_bindings: vec![],
89 },
90 CausalNode {
91 id: "transaction_volume".to_string(),
92 label: "Transaction Volume".to_string(),
93 category: NodeCategory::Operational,
94 baseline_value: 1.0,
95 bounds: Some((0.2, 3.0)),
96 interventionable: true,
97 config_bindings: vec!["transactions.volume_multiplier".to_string()],
98 },
99 CausalNode {
100 id: "error_rate".to_string(),
101 label: "Error Rate".to_string(),
102 category: NodeCategory::Outcome,
103 baseline_value: 0.02,
104 bounds: Some((0.0, 0.30)),
105 interventionable: false,
106 config_bindings: vec!["anomaly_injection.base_rate".to_string()],
107 },
108 ],
109 edges: vec![
110 CausalEdge {
111 from: "gdp_growth".to_string(),
112 to: "transaction_volume".to_string(),
113 transfer: TransferFunction::Linear {
114 coefficient: 0.8,
115 intercept: 1.0,
116 },
117 lag_months: 1,
118 strength: 1.0,
119 mechanism: Some("GDP growth drives transaction volume".to_string()),
120 },
121 CausalEdge {
122 from: "transaction_volume".to_string(),
123 to: "error_rate".to_string(),
124 transfer: TransferFunction::Linear {
125 coefficient: 0.01,
126 intercept: 0.0,
127 },
128 lag_months: 0,
129 strength: 1.0,
130 mechanism: Some("Higher volume increases error rate".to_string()),
131 },
132 ],
133 topological_order: vec![],
134 }
135 }
136 other => {
137 return Err(ScenarioError::Serialization(format!(
138 "unknown causal DAG preset: '{}'",
139 other
140 )));
141 }
142 };
143
144 dag.validate()?;
145 Ok(dag)
146 }
147
148 pub fn causal_dag(&self) -> &CausalDAG {
150 &self.causal_dag
151 }
152
153 pub fn base_config(&self) -> &GeneratorConfig {
155 &self.base_config
156 }
157
158 pub fn generate_all(&self, output_root: &Path) -> Result<Vec<ScenarioResult>, ScenarioError> {
160 let scenarios = &self.base_config.scenarios.scenarios;
161 let mut results = Vec::with_capacity(scenarios.len());
162
163 let baseline_path = output_root.join("baseline");
165 std::fs::create_dir_all(&baseline_path)?;
166
167 for scenario in scenarios {
169 let result = self.generate_scenario(scenario, &baseline_path, output_root)?;
170 results.push(result);
171 }
172
173 Ok(results)
174 }
175
176 pub fn generate_scenario(
178 &self,
179 scenario: &ScenarioSchemaConfig,
180 baseline_path: &Path,
181 output_root: &Path,
182 ) -> Result<ScenarioResult, ScenarioError> {
183 let interventions = Self::convert_interventions(&scenario.interventions)?;
185
186 let validated = InterventionManager::validate(&interventions, &self.base_config)?;
188
189 let engine = CausalPropagationEngine::new(&self.causal_dag);
191 let propagated = engine.propagate(&validated, self.base_config.global.period_months)?;
192
193 let constraints = ScenarioConstraints {
195 preserve_accounting_identity: scenario.constraints.preserve_accounting_identity,
196 preserve_document_chains: scenario.constraints.preserve_document_chains,
197 preserve_period_close: scenario.constraints.preserve_period_close,
198 preserve_balance_coherence: scenario.constraints.preserve_balance_coherence,
199 custom: vec![],
200 };
201
202 let _mutated_config = ConfigMutator::apply(&self.base_config, &propagated, &constraints)?;
204
205 let scenario_path = output_root
207 .join("scenarios")
208 .join(&scenario.name)
209 .join("data");
210 std::fs::create_dir_all(&scenario_path)?;
211
212 let manifest = ScenarioManifest {
214 scenario_name: scenario.name.clone(),
215 description: scenario.description.clone(),
216 interventions_count: interventions.len(),
217 months_affected: propagated.changes_by_month.len(),
218 config_paths_changed: propagated
219 .changes_by_month
220 .values()
221 .flat_map(|changes| changes.iter().map(|c| c.path.clone()))
222 .collect::<std::collections::HashSet<_>>()
223 .into_iter()
224 .collect(),
225 };
226
227 let manifest_path = output_root
228 .join("scenarios")
229 .join(&scenario.name)
230 .join("scenario_manifest.yaml");
231 let manifest_yaml = serde_yaml::to_string(&manifest)
232 .map_err(|e| ScenarioError::Serialization(e.to_string()))?;
233 std::fs::write(&manifest_path, manifest_yaml)?;
234
235 Ok(ScenarioResult {
236 scenario_name: scenario.name.clone(),
237 baseline_path: baseline_path.to_path_buf(),
238 counterfactual_path: scenario_path,
239 interventions_applied: interventions.len(),
240 months_affected: propagated.changes_by_month.len(),
241 })
242 }
243
244 fn convert_interventions(
246 schema_interventions: &[datasynth_config::InterventionSchemaConfig],
247 ) -> Result<Vec<Intervention>, ScenarioError> {
248 let mut interventions = Vec::new();
249
250 for schema in schema_interventions {
251 let intervention_type: InterventionType =
252 serde_json::from_value(schema.intervention_type.clone()).map_err(|e| {
253 ScenarioError::Serialization(format!(
254 "failed to parse intervention type: {}",
255 e
256 ))
257 })?;
258
259 let onset = match schema.timing.onset.to_lowercase().as_str() {
260 "sudden" => OnsetType::Sudden,
261 "gradual" => OnsetType::Gradual,
262 "oscillating" => OnsetType::Oscillating,
263 _ => OnsetType::Sudden,
264 };
265
266 interventions.push(Intervention {
267 id: Uuid::new_v4(),
268 intervention_type,
269 timing: InterventionTiming {
270 start_month: schema.timing.start_month,
271 duration_months: schema.timing.duration_months,
272 onset,
273 ramp_months: schema.timing.ramp_months,
274 },
275 label: schema.label.clone(),
276 priority: schema.priority,
277 });
278 }
279
280 Ok(interventions)
281 }
282
283 pub fn list_scenarios(&self) -> Vec<ScenarioSummary> {
285 self.base_config
286 .scenarios
287 .scenarios
288 .iter()
289 .map(|s| ScenarioSummary {
290 name: s.name.clone(),
291 description: s.description.clone(),
292 tags: s.tags.clone(),
293 intervention_count: s.interventions.len(),
294 probability_weight: s.probability_weight,
295 })
296 .collect()
297 }
298
299 pub fn validate_all(&self) -> Vec<ScenarioValidationResult> {
301 self.base_config
302 .scenarios
303 .scenarios
304 .iter()
305 .map(|s| {
306 let result = self.validate_scenario(s);
307 ScenarioValidationResult {
308 name: s.name.clone(),
309 valid: result.is_ok(),
310 error: result.err().map(|e| e.to_string()),
311 }
312 })
313 .collect()
314 }
315
316 fn validate_scenario(&self, scenario: &ScenarioSchemaConfig) -> Result<(), ScenarioError> {
318 let interventions = Self::convert_interventions(&scenario.interventions)?;
319 let validated = InterventionManager::validate(&interventions, &self.base_config)?;
320 let engine = CausalPropagationEngine::new(&self.causal_dag);
321 let propagated = engine.propagate(&validated, self.base_config.global.period_months)?;
322
323 let constraints = ScenarioConstraints {
324 preserve_accounting_identity: scenario.constraints.preserve_accounting_identity,
325 preserve_document_chains: scenario.constraints.preserve_document_chains,
326 preserve_period_close: scenario.constraints.preserve_period_close,
327 preserve_balance_coherence: scenario.constraints.preserve_balance_coherence,
328 custom: vec![],
329 };
330
331 let _mutated = ConfigMutator::apply(&self.base_config, &propagated, &constraints)?;
332 Ok(())
333 }
334}
335
336#[derive(Debug, Clone, Serialize, Deserialize)]
338pub struct ScenarioSummary {
339 pub name: String,
340 pub description: String,
341 pub tags: Vec<String>,
342 pub intervention_count: usize,
343 pub probability_weight: Option<f64>,
344}
345
346#[derive(Debug, Clone, Serialize, Deserialize)]
348pub struct ScenarioValidationResult {
349 pub name: String,
350 pub valid: bool,
351 pub error: Option<String>,
352}
353
354#[derive(Debug, Clone, Serialize, Deserialize)]
356struct ScenarioManifest {
357 scenario_name: String,
358 description: String,
359 interventions_count: usize,
360 months_affected: usize,
361 config_paths_changed: Vec<String>,
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use datasynth_config::{
        InterventionSchemaConfig, InterventionTimingSchemaConfig, ScenarioConstraintsSchemaConfig,
        ScenarioOutputSchemaConfig, ScenariosConfig,
    };
    use datasynth_test_utils::fixtures::minimal_config;
    use tempfile::TempDir;

    /// Linear `parameter_shift` of `target` to `to`, starting in month 1 with
    /// a sudden onset and no explicit end.
    fn parameter_shift(
        target: &str,
        to: serde_json::Value,
        label: &str,
    ) -> InterventionSchemaConfig {
        InterventionSchemaConfig {
            intervention_type: serde_json::json!({
                "type": "parameter_shift",
                "target": target,
                "to": to,
                "interpolation": "linear"
            }),
            timing: InterventionTimingSchemaConfig {
                start_month: 1,
                duration_months: None,
                onset: "sudden".to_string(),
                ramp_months: None,
            },
            label: Some(label.to_string()),
            priority: 0,
        }
    }

    /// `minimal_config()` with scenarios enabled and exactly one `scenario`.
    fn config_with(
        scenario: ScenarioSchemaConfig,
        causal_model: datasynth_config::CausalModelSchemaConfig,
    ) -> GeneratorConfig {
        let mut config = minimal_config();
        config.scenarios = ScenariosConfig {
            enabled: true,
            scenarios: vec![scenario],
            causal_model,
            defaults: Default::default(),
        };
        config
    }

    /// Single-intervention recession scenario on the default DAG preset.
    fn config_with_scenario() -> GeneratorConfig {
        let recession = ScenarioSchemaConfig {
            name: "test_recession".to_string(),
            description: "Test recession scenario".to_string(),
            tags: vec!["test".to_string()],
            base: None,
            probability_weight: Some(0.3),
            interventions: vec![parameter_shift(
                "global.period_months",
                serde_json::json!(3),
                "Test shift",
            )],
            constraints: ScenarioConstraintsSchemaConfig::default(),
            output: ScenarioOutputSchemaConfig::default(),
            metadata: Default::default(),
        };
        config_with(recession, Default::default())
    }

    #[test]
    fn test_scenario_engine_new_default_dag() {
        let engine = ScenarioEngine::new(config_with_scenario()).expect("should create engine");
        let dag = engine.causal_dag();
        assert!(!dag.nodes.is_empty());
        assert!(!dag.edges.is_empty());
    }

    #[test]
    fn test_scenario_engine_list_scenarios() {
        let engine = ScenarioEngine::new(config_with_scenario()).expect("should create engine");
        let summaries = engine.list_scenarios();
        assert_eq!(summaries.len(), 1);
        let summary = &summaries[0];
        assert_eq!(summary.name, "test_recession");
        assert_eq!(summary.intervention_count, 1);
    }

    #[test]
    fn test_scenario_engine_validate_all() {
        let engine = ScenarioEngine::new(config_with_scenario()).expect("should create engine");
        let outcomes = engine.validate_all();
        assert_eq!(outcomes.len(), 1);
        assert!(outcomes[0].valid, "validation error: {:?}", outcomes[0].error);
    }

    #[test]
    fn test_scenario_engine_converts_schema_to_interventions() {
        let config = config_with_scenario();
        let schema = &config.scenarios.scenarios[0].interventions;
        let converted = ScenarioEngine::convert_interventions(schema).expect("should convert");
        assert_eq!(converted.len(), 1);
        assert!(matches!(
            converted[0].intervention_type,
            InterventionType::ParameterShift(_)
        ));
    }

    #[test]
    fn test_minimal_dag_preset_valid() {
        let scenario = ScenarioSchemaConfig {
            name: "minimal_test".to_string(),
            description: "Test with minimal DAG".to_string(),
            tags: vec![],
            base: None,
            probability_weight: None,
            interventions: vec![parameter_shift(
                "transactions.volume_multiplier",
                serde_json::json!(2.0),
                "Volume increase",
            )],
            constraints: ScenarioConstraintsSchemaConfig::default(),
            output: ScenarioOutputSchemaConfig::default(),
            metadata: Default::default(),
        };
        let causal_model = datasynth_config::CausalModelSchemaConfig {
            preset: "minimal".to_string(),
            ..Default::default()
        };

        let engine = ScenarioEngine::new(config_with(scenario, causal_model))
            .expect("should create engine with minimal DAG");
        assert_eq!(engine.causal_dag().nodes.len(), 3);
        assert_eq!(engine.causal_dag().edges.len(), 2);

        let outcomes = engine.validate_all();
        assert_eq!(outcomes.len(), 1);
        assert!(outcomes[0].valid, "validation error: {:?}", outcomes[0].error);
    }

    #[test]
    fn test_scenario_engine_generates_output() {
        let engine = ScenarioEngine::new(config_with_scenario()).expect("should create engine");
        let tmpdir = TempDir::new().expect("should create tmpdir");

        let results = engine.generate_all(tmpdir.path()).expect("should generate");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].scenario_name, "test_recession");

        let manifest_path = tmpdir
            .path()
            .join("scenarios")
            .join("test_recession")
            .join("scenario_manifest.yaml");
        assert!(manifest_path.exists());
    }
}