1use crate::causal_engine::{CausalPropagationEngine, PropagationError};
4use crate::config_mutator::{ConfigMutator, MutationError};
5use crate::intervention_manager::{InterventionError, InterventionManager};
6use datasynth_config::{GeneratorConfig, ScenarioSchemaConfig};
7use datasynth_core::causal_dag::{CausalDAG, CausalDAGError};
8use datasynth_core::{
9 Intervention, InterventionTiming, InterventionType, OnsetType, ScenarioConstraints,
10};
11use serde::{Deserialize, Serialize};
12use std::path::{Path, PathBuf};
13use thiserror::Error;
14use uuid::Uuid;
15
/// Errors produced while loading a causal model or running a scenario through
/// the intervention -> propagation -> config-mutation pipeline.
#[derive(Debug, Error)]
pub enum ScenarioError {
    /// Intervention validation failed (e.g. in `InterventionManager::validate`).
    #[error("intervention error: {0}")]
    Intervention(#[from] InterventionError),
    /// Causal propagation of validated interventions failed.
    #[error("propagation error: {0}")]
    Propagation(#[from] PropagationError),
    /// Applying propagated changes to the generator config failed.
    #[error("mutation error: {0}")]
    Mutation(#[from] MutationError),
    /// The loaded causal DAG failed validation.
    #[error("DAG error: {0}")]
    Dag(#[from] CausalDAGError),
    /// A generation failure described by a free-form message.
    #[error("generation error: {0}")]
    Generation(String),
    /// Filesystem failure while creating output directories or writing files.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// (De)serialization failure. Also used for invalid causal-model
    /// configuration (unknown preset, incomplete custom DAG).
    #[error("serialization error: {0}")]
    Serialization(String),
}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct ScenarioResult {
38 pub scenario_name: String,
39 pub baseline_path: PathBuf,
40 pub counterfactual_path: PathBuf,
41 pub interventions_applied: usize,
42 pub months_affected: usize,
43}
44
/// Runs configured what-if scenarios against a base generator configuration.
///
/// Holds the unmodified base config together with the causal DAG selected by
/// `scenarios.causal_model`, which is used to propagate interventions.
pub struct ScenarioEngine {
    /// Configuration every scenario starts from; only ever borrowed immutably.
    base_config: GeneratorConfig,
    /// Causal graph (validated at construction) used for propagation.
    causal_dag: CausalDAG,
}
50
51impl ScenarioEngine {
52 pub fn new(config: GeneratorConfig) -> Result<Self, ScenarioError> {
54 let causal_dag = Self::load_causal_dag(&config)?;
55 Ok(Self {
56 base_config: config,
57 causal_dag,
58 })
59 }
60
61 fn load_causal_dag(config: &GeneratorConfig) -> Result<CausalDAG, ScenarioError> {
69 let causal_config = &config.scenarios.causal_model;
70 let mut dag: CausalDAG = match causal_config.preset.as_str() {
71 "custom" => {
72 if causal_config.nodes.is_empty() || causal_config.edges.is_empty() {
73 return Err(ScenarioError::Serialization(
74 "causal_model.preset = \"custom\" requires both `nodes` and `edges` \
75 to be populated in the config"
76 .to_string(),
77 ));
78 }
79 let nodes: Vec<datasynth_core::causal_dag::CausalNode> = causal_config
80 .nodes
81 .iter()
82 .enumerate()
83 .map(|(i, v)| {
84 serde_json::from_value(v.clone()).map_err(|e| {
85 ScenarioError::Serialization(format!("causal_model.nodes[{i}]: {e}"))
86 })
87 })
88 .collect::<Result<_, _>>()?;
89 let edges: Vec<datasynth_core::causal_dag::CausalEdge> = causal_config
90 .edges
91 .iter()
92 .enumerate()
93 .map(|(i, v)| {
94 serde_json::from_value(v.clone()).map_err(|e| {
95 ScenarioError::Serialization(format!("causal_model.edges[{i}]: {e}"))
96 })
97 })
98 .collect::<Result<_, _>>()?;
99 CausalDAG {
100 nodes,
101 edges,
102 topological_order: Vec::new(),
103 }
104 }
105 "default" | "" => {
106 let yaml = include_str!("causal_dag_default.yaml");
107 serde_yaml::from_str(yaml).map_err(|e| {
108 ScenarioError::Serialization(format!("failed to parse default causal DAG: {e}"))
109 })?
110 }
111 "manufacturing" => {
112 let yaml = include_str!("causal_dag_manufacturing.yaml");
113 serde_yaml::from_str(yaml).map_err(|e| {
114 ScenarioError::Serialization(format!(
115 "failed to parse manufacturing causal DAG: {e}"
116 ))
117 })?
118 }
119 "retail" => {
120 let yaml = include_str!("causal_dag_retail.yaml");
121 serde_yaml::from_str(yaml).map_err(|e| {
122 ScenarioError::Serialization(format!("failed to parse retail causal DAG: {e}"))
123 })?
124 }
125 "financial_services" => {
126 let yaml = include_str!("causal_dag_financial_services.yaml");
127 serde_yaml::from_str(yaml).map_err(|e| {
128 ScenarioError::Serialization(format!(
129 "failed to parse financial_services causal DAG: {e}"
130 ))
131 })?
132 }
133 "minimal" => {
134 use datasynth_core::causal_dag::{
135 CausalEdge, CausalNode, NodeCategory, TransferFunction,
136 };
137 CausalDAG {
139 nodes: vec![
140 CausalNode {
141 id: "gdp_growth".to_string(),
142 label: "GDP Growth".to_string(),
143 category: NodeCategory::Macro,
144 baseline_value: 0.025,
145 bounds: Some((-0.10, 0.15)),
146 interventionable: true,
147 config_bindings: vec![],
148 },
149 CausalNode {
150 id: "transaction_volume".to_string(),
151 label: "Transaction Volume".to_string(),
152 category: NodeCategory::Operational,
153 baseline_value: 1.0,
154 bounds: Some((0.2, 3.0)),
155 interventionable: true,
156 config_bindings: vec!["transactions.volume_multiplier".to_string()],
157 },
158 CausalNode {
159 id: "error_rate".to_string(),
160 label: "Error Rate".to_string(),
161 category: NodeCategory::Outcome,
162 baseline_value: 0.02,
163 bounds: Some((0.0, 0.30)),
164 interventionable: false,
165 config_bindings: vec!["anomaly_injection.base_rate".to_string()],
166 },
167 ],
168 edges: vec![
169 CausalEdge {
170 from: "gdp_growth".to_string(),
171 to: "transaction_volume".to_string(),
172 transfer: TransferFunction::Linear {
173 coefficient: 0.8,
174 intercept: 1.0,
175 },
176 lag_months: 1,
177 strength: 1.0,
178 mechanism: Some("GDP growth drives transaction volume".to_string()),
179 },
180 CausalEdge {
181 from: "transaction_volume".to_string(),
182 to: "error_rate".to_string(),
183 transfer: TransferFunction::Linear {
184 coefficient: 0.01,
185 intercept: 0.0,
186 },
187 lag_months: 0,
188 strength: 1.0,
189 mechanism: Some("Higher volume increases error rate".to_string()),
190 },
191 ],
192 topological_order: vec![],
193 }
194 }
195 other => {
196 return Err(ScenarioError::Serialization(format!(
197 "unknown causal DAG preset: '{other}'"
198 )));
199 }
200 };
201
202 dag.validate()?;
203 Ok(dag)
204 }
205
206 pub fn causal_dag(&self) -> &CausalDAG {
208 &self.causal_dag
209 }
210
211 pub fn base_config(&self) -> &GeneratorConfig {
213 &self.base_config
214 }
215
216 pub fn generate_all(&self, output_root: &Path) -> Result<Vec<ScenarioResult>, ScenarioError> {
218 let scenarios = &self.base_config.scenarios.scenarios;
219 let mut results = Vec::with_capacity(scenarios.len());
220
221 let baseline_path = output_root.join("baseline");
223 std::fs::create_dir_all(&baseline_path)?;
224
225 for scenario in scenarios {
227 let result = self.generate_scenario(scenario, &baseline_path, output_root)?;
228 results.push(result);
229 }
230
231 Ok(results)
232 }
233
234 pub fn generate_scenario(
236 &self,
237 scenario: &ScenarioSchemaConfig,
238 baseline_path: &Path,
239 output_root: &Path,
240 ) -> Result<ScenarioResult, ScenarioError> {
241 let interventions = Self::convert_interventions(&scenario.interventions)?;
243
244 let validated = InterventionManager::validate(&interventions, &self.base_config)?;
246
247 let engine = CausalPropagationEngine::new(&self.causal_dag);
249 let propagated = engine.propagate(&validated, self.base_config.global.period_months)?;
250
251 let constraints = ScenarioConstraints {
253 preserve_accounting_identity: scenario.constraints.preserve_accounting_identity,
254 preserve_document_chains: scenario.constraints.preserve_document_chains,
255 preserve_period_close: scenario.constraints.preserve_period_close,
256 preserve_balance_coherence: scenario.constraints.preserve_balance_coherence,
257 custom: vec![],
258 };
259
260 let _mutated_config = ConfigMutator::apply(&self.base_config, &propagated, &constraints)?;
262
263 let scenario_path = output_root
265 .join("scenarios")
266 .join(&scenario.name)
267 .join("data");
268 std::fs::create_dir_all(&scenario_path)?;
269
270 let manifest = ScenarioManifest {
272 scenario_name: scenario.name.clone(),
273 description: scenario.description.clone(),
274 interventions_count: interventions.len(),
275 months_affected: propagated.changes_by_month.len(),
276 config_paths_changed: propagated
277 .changes_by_month
278 .values()
279 .flat_map(|changes| changes.iter().map(|c| c.path.clone()))
280 .collect::<std::collections::HashSet<_>>()
281 .into_iter()
282 .collect(),
283 };
284
285 let manifest_path = output_root
286 .join("scenarios")
287 .join(&scenario.name)
288 .join("scenario_manifest.yaml");
289 let manifest_yaml = serde_yaml::to_string(&manifest)
290 .map_err(|e| ScenarioError::Serialization(e.to_string()))?;
291 std::fs::write(&manifest_path, manifest_yaml)?;
292
293 Ok(ScenarioResult {
294 scenario_name: scenario.name.clone(),
295 baseline_path: baseline_path.to_path_buf(),
296 counterfactual_path: scenario_path,
297 interventions_applied: interventions.len(),
298 months_affected: propagated.changes_by_month.len(),
299 })
300 }
301
302 fn convert_interventions(
304 schema_interventions: &[datasynth_config::InterventionSchemaConfig],
305 ) -> Result<Vec<Intervention>, ScenarioError> {
306 let mut interventions = Vec::new();
307
308 for schema in schema_interventions {
309 let intervention_type: InterventionType =
310 serde_json::from_value(schema.intervention_type.clone()).map_err(|e| {
311 ScenarioError::Serialization(format!("failed to parse intervention type: {e}"))
312 })?;
313
314 let onset = match schema.timing.onset.to_lowercase().as_str() {
315 "sudden" => OnsetType::Sudden,
316 "gradual" => OnsetType::Gradual,
317 "oscillating" => OnsetType::Oscillating,
318 _ => OnsetType::Sudden,
319 };
320
321 interventions.push(Intervention {
322 id: Uuid::new_v4(),
323 intervention_type,
324 timing: InterventionTiming {
325 start_month: schema.timing.start_month,
326 duration_months: schema.timing.duration_months,
327 onset,
328 ramp_months: schema.timing.ramp_months,
329 },
330 label: schema.label.clone(),
331 priority: schema.priority,
332 });
333 }
334
335 Ok(interventions)
336 }
337
338 pub fn list_scenarios(&self) -> Vec<ScenarioSummary> {
340 self.base_config
341 .scenarios
342 .scenarios
343 .iter()
344 .map(|s| ScenarioSummary {
345 name: s.name.clone(),
346 description: s.description.clone(),
347 tags: s.tags.clone(),
348 intervention_count: s.interventions.len(),
349 probability_weight: s.probability_weight,
350 })
351 .collect()
352 }
353
354 pub fn validate_all(&self) -> Vec<ScenarioValidationResult> {
356 self.base_config
357 .scenarios
358 .scenarios
359 .iter()
360 .map(|s| {
361 let result = self.validate_scenario(s);
362 ScenarioValidationResult {
363 name: s.name.clone(),
364 valid: result.is_ok(),
365 error: result.err().map(|e| e.to_string()),
366 }
367 })
368 .collect()
369 }
370
371 fn validate_scenario(&self, scenario: &ScenarioSchemaConfig) -> Result<(), ScenarioError> {
373 let interventions = Self::convert_interventions(&scenario.interventions)?;
374 let validated = InterventionManager::validate(&interventions, &self.base_config)?;
375 let engine = CausalPropagationEngine::new(&self.causal_dag);
376 let propagated = engine.propagate(&validated, self.base_config.global.period_months)?;
377
378 let constraints = ScenarioConstraints {
379 preserve_accounting_identity: scenario.constraints.preserve_accounting_identity,
380 preserve_document_chains: scenario.constraints.preserve_document_chains,
381 preserve_period_close: scenario.constraints.preserve_period_close,
382 preserve_balance_coherence: scenario.constraints.preserve_balance_coherence,
383 custom: vec![],
384 };
385
386 let _mutated = ConfigMutator::apply(&self.base_config, &propagated, &constraints)?;
387 Ok(())
388 }
389}
390
391#[derive(Debug, Clone, Serialize, Deserialize)]
393pub struct ScenarioSummary {
394 pub name: String,
395 pub description: String,
396 pub tags: Vec<String>,
397 pub intervention_count: usize,
398 pub probability_weight: Option<f64>,
399}
400
401#[derive(Debug, Clone, Serialize, Deserialize)]
403pub struct ScenarioValidationResult {
404 pub name: String,
405 pub valid: bool,
406 pub error: Option<String>,
407}
408
/// Summary serialized to `scenarios/<name>/scenario_manifest.yaml` for each
/// generated scenario.
///
/// Serde emits fields in declaration order, so field order here is the YAML
/// key order; keep it stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ScenarioManifest {
    /// Scenario name (also used as its output directory name).
    scenario_name: String,
    /// Description copied from the scenario config.
    description: String,
    /// Number of interventions declared by the scenario.
    interventions_count: usize,
    /// Number of entries in the propagation result's `changes_by_month`.
    months_affected: usize,
    /// Deduplicated config paths touched by propagated changes.
    config_paths_changed: Vec<String>,
}
418
#[cfg(test)]
mod tests {
    use super::*;
    use datasynth_config::{
        CausalModelSchemaConfig, InterventionSchemaConfig, InterventionTimingSchemaConfig,
        ScenarioConstraintsSchemaConfig, ScenarioOutputSchemaConfig, ScenariosConfig,
    };
    use datasynth_test_utils::fixtures::minimal_config;
    use tempfile::TempDir;

    /// Builds a scenario with exactly one intervention.
    ///
    /// The intervention starts suddenly in month 1, with no explicit duration
    /// or ramp, and has priority 0.
    fn single_intervention_scenario(
        name: &str,
        description: &str,
        tags: &[&str],
        probability_weight: Option<f64>,
        intervention_type: serde_json::Value,
        label: &str,
    ) -> ScenarioSchemaConfig {
        let timing = InterventionTimingSchemaConfig {
            start_month: 1,
            duration_months: None,
            onset: "sudden".to_string(),
            ramp_months: None,
        };
        let intervention = InterventionSchemaConfig {
            intervention_type,
            timing,
            label: Some(label.to_string()),
            priority: 0,
        };
        ScenarioSchemaConfig {
            name: name.to_string(),
            description: description.to_string(),
            tags: tags.iter().map(|tag| tag.to_string()).collect(),
            base: None,
            probability_weight,
            interventions: vec![intervention],
            constraints: ScenarioConstraintsSchemaConfig::default(),
            output: ScenarioOutputSchemaConfig::default(),
            metadata: Default::default(),
        }
    }

    /// Installs `scenario` as the only enabled scenario on top of the minimal
    /// fixture config, using `causal_model`.
    fn config_with(
        scenario: ScenarioSchemaConfig,
        causal_model: CausalModelSchemaConfig,
    ) -> GeneratorConfig {
        let mut config = minimal_config();
        config.scenarios = ScenariosConfig {
            enabled: true,
            scenarios: vec![scenario],
            causal_model,
            defaults: Default::default(),
            generate_counterfactuals: false,
        };
        config
    }

    /// Default causal model plus a single "test_recession" scenario.
    fn config_with_scenario() -> GeneratorConfig {
        let recession = single_intervention_scenario(
            "test_recession",
            "Test recession scenario",
            &["test"],
            Some(0.3),
            serde_json::json!({
                "type": "parameter_shift",
                "target": "global.period_months",
                "to": 3,
                "interpolation": "linear"
            }),
            "Test shift",
        );
        config_with(recession, Default::default())
    }

    #[test]
    fn test_scenario_engine_new_default_dag() {
        let engine = ScenarioEngine::new(config_with_scenario()).expect("should create engine");
        let dag = engine.causal_dag();
        assert!(!dag.nodes.is_empty());
        assert!(!dag.edges.is_empty());
    }

    #[test]
    fn test_scenario_engine_list_scenarios() {
        let engine = ScenarioEngine::new(config_with_scenario()).expect("should create engine");
        let listed = engine.list_scenarios();
        assert_eq!(listed.len(), 1);
        let only = &listed[0];
        assert_eq!(only.name, "test_recession");
        assert_eq!(only.intervention_count, 1);
    }

    #[test]
    fn test_scenario_engine_validate_all() {
        let engine = ScenarioEngine::new(config_with_scenario()).expect("should create engine");
        let outcomes = engine.validate_all();
        assert_eq!(outcomes.len(), 1);
        assert!(outcomes[0].valid, "validation error: {:?}", outcomes[0].error);
    }

    #[test]
    fn test_scenario_engine_converts_schema_to_interventions() {
        let config = config_with_scenario();
        let schema = &config.scenarios.scenarios[0].interventions;
        let converted = ScenarioEngine::convert_interventions(schema).expect("should convert");
        assert_eq!(converted.len(), 1);
        assert!(matches!(
            converted[0].intervention_type,
            InterventionType::ParameterShift(_)
        ));
    }

    #[test]
    fn test_minimal_dag_preset_valid() {
        let volume_bump = single_intervention_scenario(
            "minimal_test",
            "Test with minimal DAG",
            &[],
            None,
            serde_json::json!({
                "type": "parameter_shift",
                "target": "transactions.volume_multiplier",
                "to": 2.0,
                "interpolation": "linear"
            }),
            "Volume increase",
        );
        let minimal_model = CausalModelSchemaConfig {
            preset: "minimal".to_string(),
            ..Default::default()
        };
        let config = config_with(volume_bump, minimal_model);

        let engine = ScenarioEngine::new(config).expect("should create engine with minimal DAG");
        assert_eq!(engine.causal_dag().nodes.len(), 3);
        assert_eq!(engine.causal_dag().edges.len(), 2);

        let outcomes = engine.validate_all();
        assert_eq!(outcomes.len(), 1);
        assert!(outcomes[0].valid, "validation error: {:?}", outcomes[0].error);
    }

    #[test]
    fn test_scenario_engine_generates_output() {
        let engine = ScenarioEngine::new(config_with_scenario()).expect("should create engine");
        let tmpdir = TempDir::new().expect("should create tmpdir");
        let generated = engine.generate_all(tmpdir.path()).expect("should generate");
        assert_eq!(generated.len(), 1);
        assert_eq!(generated[0].scenario_name, "test_recession");

        let manifest_path = ["scenarios", "test_recession", "scenario_manifest.yaml"]
            .iter()
            .fold(tmpdir.path().to_path_buf(), |acc, part| acc.join(part));
        assert!(manifest_path.exists());
    }
}