1use crate::causal_engine::{CausalPropagationEngine, PropagationError};
4use crate::config_mutator::{ConfigMutator, MutationError};
5use crate::intervention_manager::{InterventionError, InterventionManager};
6use datasynth_config::{GeneratorConfig, ScenarioSchemaConfig};
7use datasynth_core::causal_dag::{CausalDAG, CausalDAGError};
8use datasynth_core::{
9 Intervention, InterventionTiming, InterventionType, OnsetType, ScenarioConstraints,
10};
11use serde::{Deserialize, Serialize};
12use std::path::{Path, PathBuf};
13use thiserror::Error;
14use uuid::Uuid;
15
16#[derive(Debug, Error)]
18pub enum ScenarioError {
19 #[error("intervention error: {0}")]
20 Intervention(#[from] InterventionError),
21 #[error("propagation error: {0}")]
22 Propagation(#[from] PropagationError),
23 #[error("mutation error: {0}")]
24 Mutation(#[from] MutationError),
25 #[error("DAG error: {0}")]
26 Dag(#[from] CausalDAGError),
27 #[error("generation error: {0}")]
28 Generation(String),
29 #[error("IO error: {0}")]
30 Io(#[from] std::io::Error),
31 #[error("serialization error: {0}")]
32 Serialization(String),
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct ScenarioResult {
38 pub scenario_name: String,
39 pub baseline_path: PathBuf,
40 pub counterfactual_path: PathBuf,
41 pub interventions_applied: usize,
42 pub months_affected: usize,
43}
44
45pub struct ScenarioEngine {
47 base_config: GeneratorConfig,
48 causal_dag: CausalDAG,
49}
50
51impl ScenarioEngine {
52 pub fn new(config: GeneratorConfig) -> Result<Self, ScenarioError> {
54 let causal_dag = Self::load_causal_dag(&config)?;
55 Ok(Self {
56 base_config: config,
57 causal_dag,
58 })
59 }
60
61 fn load_causal_dag(config: &GeneratorConfig) -> Result<CausalDAG, ScenarioError> {
63 let causal_config = &config.scenarios.causal_model;
64 let mut dag: CausalDAG = match causal_config.preset.as_str() {
65 "default" | "" => {
66 let yaml =
67 include_str!("causal_dag_default.yaml");
68 serde_yaml::from_str(yaml).map_err(|e| {
69 ScenarioError::Serialization(format!(
70 "failed to parse default causal DAG: {}",
71 e
72 ))
73 })?
74 }
75 "minimal" => {
76 use datasynth_core::causal_dag::{
77 CausalEdge, CausalNode, NodeCategory, TransferFunction,
78 };
79 CausalDAG {
81 nodes: vec![
82 CausalNode {
83 id: "gdp_growth".to_string(),
84 label: "GDP Growth".to_string(),
85 category: NodeCategory::Macro,
86 baseline_value: 0.025,
87 bounds: Some((-0.10, 0.15)),
88 interventionable: true,
89 config_bindings: vec![],
90 },
91 CausalNode {
92 id: "transaction_volume".to_string(),
93 label: "Transaction Volume".to_string(),
94 category: NodeCategory::Operational,
95 baseline_value: 1.0,
96 bounds: Some((0.2, 3.0)),
97 interventionable: true,
98 config_bindings: vec!["transactions.volume_multiplier".to_string()],
99 },
100 CausalNode {
101 id: "error_rate".to_string(),
102 label: "Error Rate".to_string(),
103 category: NodeCategory::Outcome,
104 baseline_value: 0.02,
105 bounds: Some((0.0, 0.30)),
106 interventionable: false,
107 config_bindings: vec!["anomaly_injection.base_rate".to_string()],
108 },
109 ],
110 edges: vec![
111 CausalEdge {
112 from: "gdp_growth".to_string(),
113 to: "transaction_volume".to_string(),
114 transfer: TransferFunction::Linear {
115 coefficient: 0.8,
116 intercept: 1.0,
117 },
118 lag_months: 1,
119 strength: 1.0,
120 mechanism: Some("GDP growth drives transaction volume".to_string()),
121 },
122 CausalEdge {
123 from: "transaction_volume".to_string(),
124 to: "error_rate".to_string(),
125 transfer: TransferFunction::Linear {
126 coefficient: 0.01,
127 intercept: 0.0,
128 },
129 lag_months: 0,
130 strength: 1.0,
131 mechanism: Some("Higher volume increases error rate".to_string()),
132 },
133 ],
134 topological_order: vec![],
135 }
136 }
137 other => {
138 return Err(ScenarioError::Serialization(format!(
139 "unknown causal DAG preset: '{}'",
140 other
141 )));
142 }
143 };
144
145 dag.validate()?;
146 Ok(dag)
147 }
148
149 pub fn causal_dag(&self) -> &CausalDAG {
151 &self.causal_dag
152 }
153
154 pub fn base_config(&self) -> &GeneratorConfig {
156 &self.base_config
157 }
158
159 pub fn generate_all(&self, output_root: &Path) -> Result<Vec<ScenarioResult>, ScenarioError> {
161 let scenarios = &self.base_config.scenarios.scenarios;
162 let mut results = Vec::with_capacity(scenarios.len());
163
164 let baseline_path = output_root.join("baseline");
166 std::fs::create_dir_all(&baseline_path)?;
167
168 for scenario in scenarios {
170 let result = self.generate_scenario(scenario, &baseline_path, output_root)?;
171 results.push(result);
172 }
173
174 Ok(results)
175 }
176
177 pub fn generate_scenario(
179 &self,
180 scenario: &ScenarioSchemaConfig,
181 baseline_path: &Path,
182 output_root: &Path,
183 ) -> Result<ScenarioResult, ScenarioError> {
184 let interventions = Self::convert_interventions(&scenario.interventions)?;
186
187 let validated = InterventionManager::validate(&interventions, &self.base_config)?;
189
190 let engine = CausalPropagationEngine::new(&self.causal_dag);
192 let propagated = engine.propagate(&validated, self.base_config.global.period_months)?;
193
194 let constraints = ScenarioConstraints {
196 preserve_accounting_identity: scenario.constraints.preserve_accounting_identity,
197 preserve_document_chains: scenario.constraints.preserve_document_chains,
198 preserve_period_close: scenario.constraints.preserve_period_close,
199 preserve_balance_coherence: scenario.constraints.preserve_balance_coherence,
200 custom: vec![],
201 };
202
203 let _mutated_config = ConfigMutator::apply(&self.base_config, &propagated, &constraints)?;
205
206 let scenario_path = output_root
208 .join("scenarios")
209 .join(&scenario.name)
210 .join("data");
211 std::fs::create_dir_all(&scenario_path)?;
212
213 let manifest = ScenarioManifest {
215 scenario_name: scenario.name.clone(),
216 description: scenario.description.clone(),
217 interventions_count: interventions.len(),
218 months_affected: propagated.changes_by_month.len(),
219 config_paths_changed: propagated
220 .changes_by_month
221 .values()
222 .flat_map(|changes| changes.iter().map(|c| c.path.clone()))
223 .collect::<std::collections::HashSet<_>>()
224 .into_iter()
225 .collect(),
226 };
227
228 let manifest_path = output_root
229 .join("scenarios")
230 .join(&scenario.name)
231 .join("scenario_manifest.yaml");
232 let manifest_yaml = serde_yaml::to_string(&manifest)
233 .map_err(|e| ScenarioError::Serialization(e.to_string()))?;
234 std::fs::write(&manifest_path, manifest_yaml)?;
235
236 Ok(ScenarioResult {
237 scenario_name: scenario.name.clone(),
238 baseline_path: baseline_path.to_path_buf(),
239 counterfactual_path: scenario_path,
240 interventions_applied: interventions.len(),
241 months_affected: propagated.changes_by_month.len(),
242 })
243 }
244
245 fn convert_interventions(
247 schema_interventions: &[datasynth_config::InterventionSchemaConfig],
248 ) -> Result<Vec<Intervention>, ScenarioError> {
249 let mut interventions = Vec::new();
250
251 for schema in schema_interventions {
252 let intervention_type: InterventionType =
253 serde_json::from_value(schema.intervention_type.clone()).map_err(|e| {
254 ScenarioError::Serialization(format!(
255 "failed to parse intervention type: {}",
256 e
257 ))
258 })?;
259
260 let onset = match schema.timing.onset.to_lowercase().as_str() {
261 "sudden" => OnsetType::Sudden,
262 "gradual" => OnsetType::Gradual,
263 "oscillating" => OnsetType::Oscillating,
264 _ => OnsetType::Sudden,
265 };
266
267 interventions.push(Intervention {
268 id: Uuid::new_v4(),
269 intervention_type,
270 timing: InterventionTiming {
271 start_month: schema.timing.start_month,
272 duration_months: schema.timing.duration_months,
273 onset,
274 ramp_months: schema.timing.ramp_months,
275 },
276 label: schema.label.clone(),
277 priority: schema.priority,
278 });
279 }
280
281 Ok(interventions)
282 }
283
284 pub fn list_scenarios(&self) -> Vec<ScenarioSummary> {
286 self.base_config
287 .scenarios
288 .scenarios
289 .iter()
290 .map(|s| ScenarioSummary {
291 name: s.name.clone(),
292 description: s.description.clone(),
293 tags: s.tags.clone(),
294 intervention_count: s.interventions.len(),
295 probability_weight: s.probability_weight,
296 })
297 .collect()
298 }
299
300 pub fn validate_all(&self) -> Vec<ScenarioValidationResult> {
302 self.base_config
303 .scenarios
304 .scenarios
305 .iter()
306 .map(|s| {
307 let result = self.validate_scenario(s);
308 ScenarioValidationResult {
309 name: s.name.clone(),
310 valid: result.is_ok(),
311 error: result.err().map(|e| e.to_string()),
312 }
313 })
314 .collect()
315 }
316
317 fn validate_scenario(&self, scenario: &ScenarioSchemaConfig) -> Result<(), ScenarioError> {
319 let interventions = Self::convert_interventions(&scenario.interventions)?;
320 let validated = InterventionManager::validate(&interventions, &self.base_config)?;
321 let engine = CausalPropagationEngine::new(&self.causal_dag);
322 let propagated = engine.propagate(&validated, self.base_config.global.period_months)?;
323
324 let constraints = ScenarioConstraints {
325 preserve_accounting_identity: scenario.constraints.preserve_accounting_identity,
326 preserve_document_chains: scenario.constraints.preserve_document_chains,
327 preserve_period_close: scenario.constraints.preserve_period_close,
328 preserve_balance_coherence: scenario.constraints.preserve_balance_coherence,
329 custom: vec![],
330 };
331
332 let _mutated = ConfigMutator::apply(&self.base_config, &propagated, &constraints)?;
333 Ok(())
334 }
335}
336
337#[derive(Debug, Clone, Serialize, Deserialize)]
339pub struct ScenarioSummary {
340 pub name: String,
341 pub description: String,
342 pub tags: Vec<String>,
343 pub intervention_count: usize,
344 pub probability_weight: Option<f64>,
345}
346
347#[derive(Debug, Clone, Serialize, Deserialize)]
349pub struct ScenarioValidationResult {
350 pub name: String,
351 pub valid: bool,
352 pub error: Option<String>,
353}
354
355#[derive(Debug, Clone, Serialize, Deserialize)]
357struct ScenarioManifest {
358 scenario_name: String,
359 description: String,
360 interventions_count: usize,
361 months_affected: usize,
362 config_paths_changed: Vec<String>,
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use datasynth_config::{
        InterventionSchemaConfig, InterventionTimingSchemaConfig, ScenarioConstraintsSchemaConfig,
        ScenarioOutputSchemaConfig, ScenariosConfig,
    };
    use datasynth_test_utils::fixtures::minimal_config;
    use tempfile::TempDir;

    /// Builds a priority-0 intervention with sudden onset that starts in month 1 and has no
    /// duration or ramp, using `payload` as its type definition.
    fn sudden_intervention(payload: serde_json::Value, label: &str) -> InterventionSchemaConfig {
        InterventionSchemaConfig {
            intervention_type: payload,
            timing: InterventionTimingSchemaConfig {
                start_month: 1,
                duration_months: None,
                onset: "sudden".to_string(),
                ramp_months: None,
            },
            label: Some(label.to_string()),
            priority: 0,
        }
    }

    /// Minimal configuration with one `test_recession` scenario and the default causal model.
    fn config_with_scenario() -> GeneratorConfig {
        let shift = sudden_intervention(
            serde_json::json!({
                "type": "parameter_shift",
                "target": "global.period_months",
                "to": 3,
                "interpolation": "linear"
            }),
            "Test shift",
        );

        let recession = ScenarioSchemaConfig {
            name: "test_recession".to_string(),
            description: "Test recession scenario".to_string(),
            tags: vec!["test".to_string()],
            base: None,
            probability_weight: Some(0.3),
            interventions: vec![shift],
            constraints: ScenarioConstraintsSchemaConfig::default(),
            output: ScenarioOutputSchemaConfig::default(),
            metadata: Default::default(),
        };

        let mut config = minimal_config();
        config.scenarios = ScenariosConfig {
            enabled: true,
            scenarios: vec![recession],
            causal_model: Default::default(),
            defaults: Default::default(),
        };
        config
    }

    #[test]
    fn test_scenario_engine_new_default_dag() {
        let engine = ScenarioEngine::new(config_with_scenario()).expect("should create engine");
        let dag = engine.causal_dag();
        assert!(!dag.nodes.is_empty());
        assert!(!dag.edges.is_empty());
    }

    #[test]
    fn test_scenario_engine_list_scenarios() {
        let engine = ScenarioEngine::new(config_with_scenario()).expect("should create engine");
        let scenarios = engine.list_scenarios();
        assert_eq!(scenarios.len(), 1);
        assert_eq!(scenarios[0].name, "test_recession");
        assert_eq!(scenarios[0].intervention_count, 1);
    }

    #[test]
    fn test_scenario_engine_validate_all() {
        let engine = ScenarioEngine::new(config_with_scenario()).expect("should create engine");
        let results = engine.validate_all();
        assert_eq!(results.len(), 1);
        assert!(results[0].valid, "validation error: {:?}", results[0].error);
    }

    #[test]
    fn test_scenario_engine_converts_schema_to_interventions() {
        let config = config_with_scenario();
        let schema = &config.scenarios.scenarios[0].interventions;
        let interventions = ScenarioEngine::convert_interventions(schema).expect("should convert");
        assert_eq!(interventions.len(), 1);
        assert!(matches!(
            interventions[0].intervention_type,
            InterventionType::ParameterShift(_)
        ));
    }

    #[test]
    fn test_minimal_dag_preset_valid() {
        let volume_increase = sudden_intervention(
            serde_json::json!({
                "type": "parameter_shift",
                "target": "transactions.volume_multiplier",
                "to": 2.0,
                "interpolation": "linear"
            }),
            "Volume increase",
        );

        let mut config = minimal_config();
        config.scenarios = ScenariosConfig {
            enabled: true,
            scenarios: vec![ScenarioSchemaConfig {
                name: "minimal_test".to_string(),
                description: "Test with minimal DAG".to_string(),
                tags: vec![],
                base: None,
                probability_weight: None,
                interventions: vec![volume_increase],
                constraints: ScenarioConstraintsSchemaConfig::default(),
                output: ScenarioOutputSchemaConfig::default(),
                metadata: Default::default(),
            }],
            causal_model: datasynth_config::CausalModelSchemaConfig {
                preset: "minimal".to_string(),
                ..Default::default()
            },
            defaults: Default::default(),
        };

        let engine = ScenarioEngine::new(config).expect("should create engine with minimal DAG");
        assert_eq!(engine.causal_dag().nodes.len(), 3);
        assert_eq!(engine.causal_dag().edges.len(), 2);

        // The scenario must also validate against the minimal graph.
        let results = engine.validate_all();
        assert_eq!(results.len(), 1);
        assert!(results[0].valid, "validation error: {:?}", results[0].error);
    }

    #[test]
    fn test_scenario_engine_generates_output() {
        let engine = ScenarioEngine::new(config_with_scenario()).expect("should create engine");
        let tmpdir = TempDir::new().expect("should create tmpdir");

        let results = engine.generate_all(tmpdir.path()).expect("should generate");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].scenario_name, "test_recession");

        let manifest_path = tmpdir
            .path()
            .join("scenarios")
            .join("test_recession")
            .join("scenario_manifest.yaml");
        assert!(manifest_path.exists());
    }
}