Skip to main content

datasynth_runtime/
scenario_engine.rs

1//! Scenario engine orchestrator for paired baseline/counterfactual generation.
2
3use crate::causal_engine::{CausalPropagationEngine, PropagationError};
4use crate::config_mutator::{ConfigMutator, MutationError};
5use crate::intervention_manager::{InterventionError, InterventionManager};
6use datasynth_config::{GeneratorConfig, ScenarioSchemaConfig};
7use datasynth_core::causal_dag::{CausalDAG, CausalDAGError};
8use datasynth_core::{
9    Intervention, InterventionTiming, InterventionType, OnsetType, ScenarioConstraints,
10};
11use serde::{Deserialize, Serialize};
12use std::path::{Path, PathBuf};
13use thiserror::Error;
14use uuid::Uuid;
15
16/// Errors from the scenario engine.
17#[derive(Debug, Error)]
18pub enum ScenarioError {
19    #[error("intervention error: {0}")]
20    Intervention(#[from] InterventionError),
21    #[error("propagation error: {0}")]
22    Propagation(#[from] PropagationError),
23    #[error("mutation error: {0}")]
24    Mutation(#[from] MutationError),
25    #[error("DAG error: {0}")]
26    Dag(#[from] CausalDAGError),
27    #[error("generation error: {0}")]
28    Generation(String),
29    #[error("IO error: {0}")]
30    Io(#[from] std::io::Error),
31    #[error("serialization error: {0}")]
32    Serialization(String),
33}
34
35/// Result of generating a single scenario.
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct ScenarioResult {
38    pub scenario_name: String,
39    pub baseline_path: PathBuf,
40    pub counterfactual_path: PathBuf,
41    pub interventions_applied: usize,
42    pub months_affected: usize,
43}
44
45/// Orchestrates paired scenario generation.
46pub struct ScenarioEngine {
47    base_config: GeneratorConfig,
48    causal_dag: CausalDAG,
49}
50
51impl ScenarioEngine {
52    /// Create a new ScenarioEngine, loading the causal DAG from config.
53    pub fn new(config: GeneratorConfig) -> Result<Self, ScenarioError> {
54        let causal_dag = Self::load_causal_dag(&config)?;
55        Ok(Self {
56            base_config: config,
57            causal_dag,
58        })
59    }
60
61    /// Load the causal DAG from config presets or custom definition.
62    ///
63    /// Presets: `default` (financial-process default), `manufacturing`
64    /// (supply-chain propagation), `retail` (O2C / seasonality),
65    /// `financial_services` (correspondent banking + AML), `minimal` (3-node
66    /// smoke-test DAG), and `custom` (user-provided `nodes` + `edges` in the
67    /// config).
68    fn load_causal_dag(config: &GeneratorConfig) -> Result<CausalDAG, ScenarioError> {
69        let causal_config = &config.scenarios.causal_model;
70        let mut dag: CausalDAG = match causal_config.preset.as_str() {
71            "custom" => {
72                if causal_config.nodes.is_empty() || causal_config.edges.is_empty() {
73                    return Err(ScenarioError::Serialization(
74                        "causal_model.preset = \"custom\" requires both `nodes` and `edges` \
75                         to be populated in the config"
76                            .to_string(),
77                    ));
78                }
79                let nodes: Vec<datasynth_core::causal_dag::CausalNode> = causal_config
80                    .nodes
81                    .iter()
82                    .enumerate()
83                    .map(|(i, v)| {
84                        serde_json::from_value(v.clone()).map_err(|e| {
85                            ScenarioError::Serialization(format!("causal_model.nodes[{i}]: {e}"))
86                        })
87                    })
88                    .collect::<Result<_, _>>()?;
89                let edges: Vec<datasynth_core::causal_dag::CausalEdge> = causal_config
90                    .edges
91                    .iter()
92                    .enumerate()
93                    .map(|(i, v)| {
94                        serde_json::from_value(v.clone()).map_err(|e| {
95                            ScenarioError::Serialization(format!("causal_model.edges[{i}]: {e}"))
96                        })
97                    })
98                    .collect::<Result<_, _>>()?;
99                CausalDAG {
100                    nodes,
101                    edges,
102                    topological_order: Vec::new(),
103                }
104            }
105            "default" | "" => {
106                let yaml = include_str!("causal_dag_default.yaml");
107                serde_yaml::from_str(yaml).map_err(|e| {
108                    ScenarioError::Serialization(format!("failed to parse default causal DAG: {e}"))
109                })?
110            }
111            "manufacturing" => {
112                let yaml = include_str!("causal_dag_manufacturing.yaml");
113                serde_yaml::from_str(yaml).map_err(|e| {
114                    ScenarioError::Serialization(format!(
115                        "failed to parse manufacturing causal DAG: {e}"
116                    ))
117                })?
118            }
119            "retail" => {
120                let yaml = include_str!("causal_dag_retail.yaml");
121                serde_yaml::from_str(yaml).map_err(|e| {
122                    ScenarioError::Serialization(format!("failed to parse retail causal DAG: {e}"))
123                })?
124            }
125            "financial_services" => {
126                let yaml = include_str!("causal_dag_financial_services.yaml");
127                serde_yaml::from_str(yaml).map_err(|e| {
128                    ScenarioError::Serialization(format!(
129                        "failed to parse financial_services causal DAG: {e}"
130                    ))
131                })?
132            }
133            "minimal" => {
134                use datasynth_core::causal_dag::{
135                    CausalEdge, CausalNode, NodeCategory, TransferFunction,
136                };
137                // Minimal DAG: macro → operational → outcome (3 nodes, 2 edges)
138                CausalDAG {
139                    nodes: vec![
140                        CausalNode {
141                            id: "gdp_growth".to_string(),
142                            label: "GDP Growth".to_string(),
143                            category: NodeCategory::Macro,
144                            baseline_value: 0.025,
145                            bounds: Some((-0.10, 0.15)),
146                            interventionable: true,
147                            config_bindings: vec![],
148                        },
149                        CausalNode {
150                            id: "transaction_volume".to_string(),
151                            label: "Transaction Volume".to_string(),
152                            category: NodeCategory::Operational,
153                            baseline_value: 1.0,
154                            bounds: Some((0.2, 3.0)),
155                            interventionable: true,
156                            config_bindings: vec!["transactions.volume_multiplier".to_string()],
157                        },
158                        CausalNode {
159                            id: "error_rate".to_string(),
160                            label: "Error Rate".to_string(),
161                            category: NodeCategory::Outcome,
162                            baseline_value: 0.02,
163                            bounds: Some((0.0, 0.30)),
164                            interventionable: false,
165                            config_bindings: vec!["anomaly_injection.base_rate".to_string()],
166                        },
167                    ],
168                    edges: vec![
169                        CausalEdge {
170                            from: "gdp_growth".to_string(),
171                            to: "transaction_volume".to_string(),
172                            transfer: TransferFunction::Linear {
173                                coefficient: 0.8,
174                                intercept: 1.0,
175                            },
176                            lag_months: 1,
177                            strength: 1.0,
178                            mechanism: Some("GDP growth drives transaction volume".to_string()),
179                        },
180                        CausalEdge {
181                            from: "transaction_volume".to_string(),
182                            to: "error_rate".to_string(),
183                            transfer: TransferFunction::Linear {
184                                coefficient: 0.01,
185                                intercept: 0.0,
186                            },
187                            lag_months: 0,
188                            strength: 1.0,
189                            mechanism: Some("Higher volume increases error rate".to_string()),
190                        },
191                    ],
192                    topological_order: vec![],
193                }
194            }
195            other => {
196                return Err(ScenarioError::Serialization(format!(
197                    "unknown causal DAG preset: '{other}'"
198                )));
199            }
200        };
201
202        dag.validate()?;
203        Ok(dag)
204    }
205
206    /// Get a reference to the loaded causal DAG.
207    pub fn causal_dag(&self) -> &CausalDAG {
208        &self.causal_dag
209    }
210
211    /// Get a reference to the base config.
212    pub fn base_config(&self) -> &GeneratorConfig {
213        &self.base_config
214    }
215
216    /// Generate all scenarios defined in config.
217    pub fn generate_all(&self, output_root: &Path) -> Result<Vec<ScenarioResult>, ScenarioError> {
218        let scenarios = &self.base_config.scenarios.scenarios;
219        let mut results = Vec::with_capacity(scenarios.len());
220
221        // Create baseline directory
222        let baseline_path = output_root.join("baseline");
223        std::fs::create_dir_all(&baseline_path)?;
224
225        // Generate each scenario
226        for scenario in scenarios {
227            let result = self.generate_scenario(scenario, &baseline_path, output_root)?;
228            results.push(result);
229        }
230
231        Ok(results)
232    }
233
234    /// Generate a single scenario: validate, propagate, mutate, produce output.
235    pub fn generate_scenario(
236        &self,
237        scenario: &ScenarioSchemaConfig,
238        baseline_path: &Path,
239        output_root: &Path,
240    ) -> Result<ScenarioResult, ScenarioError> {
241        // 1. Convert schema config to core interventions
242        let interventions = Self::convert_interventions(&scenario.interventions)?;
243
244        // 2. Validate interventions
245        let validated = InterventionManager::validate(&interventions, &self.base_config)?;
246
247        // 3. Propagate through causal DAG
248        let engine = CausalPropagationEngine::new(&self.causal_dag);
249        let propagated = engine.propagate(&validated, self.base_config.global.period_months)?;
250
251        // 4. Build constraints
252        let constraints = ScenarioConstraints {
253            preserve_accounting_identity: scenario.constraints.preserve_accounting_identity,
254            preserve_document_chains: scenario.constraints.preserve_document_chains,
255            preserve_period_close: scenario.constraints.preserve_period_close,
256            preserve_balance_coherence: scenario.constraints.preserve_balance_coherence,
257            custom: vec![],
258        };
259
260        // 5. Apply to config (creates mutated copy)
261        let _mutated_config = ConfigMutator::apply(&self.base_config, &propagated, &constraints)?;
262
263        // 6. Create scenario output directory
264        let scenario_path = output_root
265            .join("scenarios")
266            .join(&scenario.name)
267            .join("data");
268        std::fs::create_dir_all(&scenario_path)?;
269
270        // 7. Write scenario manifest
271        let manifest = ScenarioManifest {
272            scenario_name: scenario.name.clone(),
273            description: scenario.description.clone(),
274            interventions_count: interventions.len(),
275            months_affected: propagated.changes_by_month.len(),
276            config_paths_changed: propagated
277                .changes_by_month
278                .values()
279                .flat_map(|changes| changes.iter().map(|c| c.path.clone()))
280                .collect::<std::collections::HashSet<_>>()
281                .into_iter()
282                .collect(),
283        };
284
285        let manifest_path = output_root
286            .join("scenarios")
287            .join(&scenario.name)
288            .join("scenario_manifest.yaml");
289        let manifest_yaml = serde_yaml::to_string(&manifest)
290            .map_err(|e| ScenarioError::Serialization(e.to_string()))?;
291        std::fs::write(&manifest_path, manifest_yaml)?;
292
293        Ok(ScenarioResult {
294            scenario_name: scenario.name.clone(),
295            baseline_path: baseline_path.to_path_buf(),
296            counterfactual_path: scenario_path,
297            interventions_applied: interventions.len(),
298            months_affected: propagated.changes_by_month.len(),
299        })
300    }
301
302    /// Convert schema-level intervention configs to core Intervention structs.
303    fn convert_interventions(
304        schema_interventions: &[datasynth_config::InterventionSchemaConfig],
305    ) -> Result<Vec<Intervention>, ScenarioError> {
306        let mut interventions = Vec::new();
307
308        for schema in schema_interventions {
309            let intervention_type: InterventionType =
310                serde_json::from_value(schema.intervention_type.clone()).map_err(|e| {
311                    ScenarioError::Serialization(format!("failed to parse intervention type: {e}"))
312                })?;
313
314            let onset = match schema.timing.onset.to_lowercase().as_str() {
315                "sudden" => OnsetType::Sudden,
316                "gradual" => OnsetType::Gradual,
317                "oscillating" => OnsetType::Oscillating,
318                _ => OnsetType::Sudden,
319            };
320
321            interventions.push(Intervention {
322                id: Uuid::new_v4(),
323                intervention_type,
324                timing: InterventionTiming {
325                    start_month: schema.timing.start_month,
326                    duration_months: schema.timing.duration_months,
327                    onset,
328                    ramp_months: schema.timing.ramp_months,
329                },
330                label: schema.label.clone(),
331                priority: schema.priority,
332            });
333        }
334
335        Ok(interventions)
336    }
337
338    /// List all scenarios in the config.
339    pub fn list_scenarios(&self) -> Vec<ScenarioSummary> {
340        self.base_config
341            .scenarios
342            .scenarios
343            .iter()
344            .map(|s| ScenarioSummary {
345                name: s.name.clone(),
346                description: s.description.clone(),
347                tags: s.tags.clone(),
348                intervention_count: s.interventions.len(),
349                probability_weight: s.probability_weight,
350            })
351            .collect()
352    }
353
354    /// Validate all scenarios without generating.
355    pub fn validate_all(&self) -> Vec<ScenarioValidationResult> {
356        self.base_config
357            .scenarios
358            .scenarios
359            .iter()
360            .map(|s| {
361                let result = self.validate_scenario(s);
362                ScenarioValidationResult {
363                    name: s.name.clone(),
364                    valid: result.is_ok(),
365                    error: result.err().map(|e| e.to_string()),
366                }
367            })
368            .collect()
369    }
370
371    /// Validate a single scenario.
372    fn validate_scenario(&self, scenario: &ScenarioSchemaConfig) -> Result<(), ScenarioError> {
373        let interventions = Self::convert_interventions(&scenario.interventions)?;
374        let validated = InterventionManager::validate(&interventions, &self.base_config)?;
375        let engine = CausalPropagationEngine::new(&self.causal_dag);
376        let propagated = engine.propagate(&validated, self.base_config.global.period_months)?;
377
378        let constraints = ScenarioConstraints {
379            preserve_accounting_identity: scenario.constraints.preserve_accounting_identity,
380            preserve_document_chains: scenario.constraints.preserve_document_chains,
381            preserve_period_close: scenario.constraints.preserve_period_close,
382            preserve_balance_coherence: scenario.constraints.preserve_balance_coherence,
383            custom: vec![],
384        };
385
386        let _mutated = ConfigMutator::apply(&self.base_config, &propagated, &constraints)?;
387        Ok(())
388    }
389}
390
391/// Summary info for listing scenarios.
392#[derive(Debug, Clone, Serialize, Deserialize)]
393pub struct ScenarioSummary {
394    pub name: String,
395    pub description: String,
396    pub tags: Vec<String>,
397    pub intervention_count: usize,
398    pub probability_weight: Option<f64>,
399}
400
401/// Result of validating a scenario.
402#[derive(Debug, Clone, Serialize, Deserialize)]
403pub struct ScenarioValidationResult {
404    pub name: String,
405    pub valid: bool,
406    pub error: Option<String>,
407}
408
409/// Manifest written alongside scenario output.
410#[derive(Debug, Clone, Serialize, Deserialize)]
411struct ScenarioManifest {
412    scenario_name: String,
413    description: String,
414    interventions_count: usize,
415    months_affected: usize,
416    config_paths_changed: Vec<String>,
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use datasynth_config::{
        CausalModelSchemaConfig, InterventionSchemaConfig, InterventionTimingSchemaConfig,
        ScenarioConstraintsSchemaConfig, ScenarioOutputSchemaConfig, ScenariosConfig,
    };
    use datasynth_test_utils::fixtures::minimal_config;
    use tempfile::TempDir;

    /// A `parameter_shift` intervention that starts suddenly in month 1 and
    /// moves `target` to `to` with linear interpolation.
    fn sudden_shift(target: &str, to: serde_json::Value, label: &str) -> InterventionSchemaConfig {
        InterventionSchemaConfig {
            intervention_type: serde_json::json!({
                "type": "parameter_shift",
                "target": target,
                "to": to,
                "interpolation": "linear"
            }),
            timing: InterventionTimingSchemaConfig {
                start_month: 1,
                duration_months: None,
                onset: "sudden".to_string(),
                ramp_months: None,
            },
            label: Some(label.to_string()),
            priority: 0,
        }
    }

    /// An enabled scenarios section holding exactly one scenario.
    fn single_scenario(
        scenario: ScenarioSchemaConfig,
        causal_model: CausalModelSchemaConfig,
    ) -> ScenariosConfig {
        ScenariosConfig {
            enabled: true,
            scenarios: vec![scenario],
            causal_model,
            defaults: Default::default(),
            generate_counterfactuals: false,
        }
    }

    /// Minimal config plus one `test_recession` scenario on the default DAG.
    fn config_with_scenario() -> GeneratorConfig {
        let recession = ScenarioSchemaConfig {
            name: "test_recession".to_string(),
            description: "Test recession scenario".to_string(),
            tags: vec!["test".to_string()],
            base: None,
            probability_weight: Some(0.3),
            interventions: vec![sudden_shift(
                "global.period_months",
                serde_json::json!(3),
                "Test shift",
            )],
            constraints: ScenarioConstraintsSchemaConfig::default(),
            output: ScenarioOutputSchemaConfig::default(),
            metadata: Default::default(),
        };

        let mut config = minimal_config();
        config.scenarios = single_scenario(recession, CausalModelSchemaConfig::default());
        config
    }

    /// Engine built from [`config_with_scenario`].
    fn default_engine() -> ScenarioEngine {
        ScenarioEngine::new(config_with_scenario()).expect("should create engine")
    }

    #[test]
    fn test_scenario_engine_new_default_dag() {
        let engine = default_engine();
        let dag = engine.causal_dag();
        assert!(!dag.nodes.is_empty());
        assert!(!dag.edges.is_empty());
    }

    #[test]
    fn test_scenario_engine_list_scenarios() {
        let listed = default_engine().list_scenarios();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].name, "test_recession");
        assert_eq!(listed[0].intervention_count, 1);
    }

    #[test]
    fn test_scenario_engine_validate_all() {
        let reports = default_engine().validate_all();
        assert_eq!(reports.len(), 1);
        assert!(reports[0].valid, "validation error: {:?}", reports[0].error);
    }

    #[test]
    fn test_scenario_engine_converts_schema_to_interventions() {
        let config = config_with_scenario();
        let converted =
            ScenarioEngine::convert_interventions(&config.scenarios.scenarios[0].interventions)
                .expect("should convert");
        assert_eq!(converted.len(), 1);
        assert!(matches!(
            converted[0].intervention_type,
            InterventionType::ParameterShift(_)
        ));
    }

    #[test]
    fn test_minimal_dag_preset_valid() {
        let volume_scenario = ScenarioSchemaConfig {
            name: "minimal_test".to_string(),
            description: "Test with minimal DAG".to_string(),
            tags: Vec::new(),
            base: None,
            probability_weight: None,
            interventions: vec![sudden_shift(
                "transactions.volume_multiplier",
                serde_json::json!(2.0),
                "Volume increase",
            )],
            constraints: ScenarioConstraintsSchemaConfig::default(),
            output: ScenarioOutputSchemaConfig::default(),
            metadata: Default::default(),
        };

        let mut config = minimal_config();
        config.scenarios = single_scenario(
            volume_scenario,
            CausalModelSchemaConfig {
                preset: "minimal".to_string(),
                ..Default::default()
            },
        );

        let engine = ScenarioEngine::new(config).expect("should create engine with minimal DAG");
        assert_eq!(engine.causal_dag().nodes.len(), 3);
        assert_eq!(engine.causal_dag().edges.len(), 2);

        // Every configured scenario must pass the dry run.
        let reports = engine.validate_all();
        assert_eq!(reports.len(), 1);
        assert!(reports[0].valid, "validation error: {:?}", reports[0].error);
    }

    #[test]
    fn test_scenario_engine_generates_output() {
        let engine = default_engine();
        let tmpdir = TempDir::new().expect("should create tmpdir");
        let generated = engine.generate_all(tmpdir.path()).expect("should generate");
        assert_eq!(generated.len(), 1);
        assert_eq!(generated[0].scenario_name, "test_recession");

        // The manifest must have been written next to the scenario data dir.
        let manifest_path = tmpdir
            .path()
            .join("scenarios")
            .join("test_recession")
            .join("scenario_manifest.yaml");
        assert!(manifest_path.exists());
    }
}