Skip to main content

datasynth_runtime/
scenario_engine.rs

1//! Scenario engine orchestrator for paired baseline/counterfactual generation.
2
3use crate::causal_engine::{CausalPropagationEngine, PropagationError};
4use crate::config_mutator::{ConfigMutator, MutationError};
5use crate::intervention_manager::{InterventionError, InterventionManager};
6use datasynth_config::{GeneratorConfig, ScenarioSchemaConfig};
7use datasynth_core::causal_dag::{CausalDAG, CausalDAGError};
8use datasynth_core::{
9    Intervention, InterventionTiming, InterventionType, OnsetType, ScenarioConstraints,
10};
11use serde::{Deserialize, Serialize};
12use std::path::{Path, PathBuf};
13use thiserror::Error;
14use uuid::Uuid;
15
16/// Errors from the scenario engine.
17#[derive(Debug, Error)]
18pub enum ScenarioError {
19    #[error("intervention error: {0}")]
20    Intervention(#[from] InterventionError),
21    #[error("propagation error: {0}")]
22    Propagation(#[from] PropagationError),
23    #[error("mutation error: {0}")]
24    Mutation(#[from] MutationError),
25    #[error("DAG error: {0}")]
26    Dag(#[from] CausalDAGError),
27    #[error("generation error: {0}")]
28    Generation(String),
29    #[error("IO error: {0}")]
30    Io(#[from] std::io::Error),
31    #[error("serialization error: {0}")]
32    Serialization(String),
33}
34
35/// Result of generating a single scenario.
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct ScenarioResult {
38    pub scenario_name: String,
39    pub baseline_path: PathBuf,
40    pub counterfactual_path: PathBuf,
41    pub interventions_applied: usize,
42    pub months_affected: usize,
43}
44
45/// Orchestrates paired scenario generation.
46pub struct ScenarioEngine {
47    base_config: GeneratorConfig,
48    causal_dag: CausalDAG,
49}
50
51impl ScenarioEngine {
52    /// Create a new ScenarioEngine, loading the causal DAG from config.
53    pub fn new(config: GeneratorConfig) -> Result<Self, ScenarioError> {
54        let causal_dag = Self::load_causal_dag(&config)?;
55        Ok(Self {
56            base_config: config,
57            causal_dag,
58        })
59    }
60
61    /// Load the causal DAG from config presets or custom definition.
62    fn load_causal_dag(config: &GeneratorConfig) -> Result<CausalDAG, ScenarioError> {
63        let causal_config = &config.scenarios.causal_model;
64        let mut dag: CausalDAG = match causal_config.preset.as_str() {
65            "default" | "" => {
66                let yaml = include_str!("causal_dag_default.yaml");
67                serde_yaml::from_str(yaml).map_err(|e| {
68                    ScenarioError::Serialization(format!("failed to parse default causal DAG: {e}"))
69                })?
70            }
71            "minimal" => {
72                use datasynth_core::causal_dag::{
73                    CausalEdge, CausalNode, NodeCategory, TransferFunction,
74                };
75                // Minimal DAG: macro → operational → outcome (3 nodes, 2 edges)
76                CausalDAG {
77                    nodes: vec![
78                        CausalNode {
79                            id: "gdp_growth".to_string(),
80                            label: "GDP Growth".to_string(),
81                            category: NodeCategory::Macro,
82                            baseline_value: 0.025,
83                            bounds: Some((-0.10, 0.15)),
84                            interventionable: true,
85                            config_bindings: vec![],
86                        },
87                        CausalNode {
88                            id: "transaction_volume".to_string(),
89                            label: "Transaction Volume".to_string(),
90                            category: NodeCategory::Operational,
91                            baseline_value: 1.0,
92                            bounds: Some((0.2, 3.0)),
93                            interventionable: true,
94                            config_bindings: vec!["transactions.volume_multiplier".to_string()],
95                        },
96                        CausalNode {
97                            id: "error_rate".to_string(),
98                            label: "Error Rate".to_string(),
99                            category: NodeCategory::Outcome,
100                            baseline_value: 0.02,
101                            bounds: Some((0.0, 0.30)),
102                            interventionable: false,
103                            config_bindings: vec!["anomaly_injection.base_rate".to_string()],
104                        },
105                    ],
106                    edges: vec![
107                        CausalEdge {
108                            from: "gdp_growth".to_string(),
109                            to: "transaction_volume".to_string(),
110                            transfer: TransferFunction::Linear {
111                                coefficient: 0.8,
112                                intercept: 1.0,
113                            },
114                            lag_months: 1,
115                            strength: 1.0,
116                            mechanism: Some("GDP growth drives transaction volume".to_string()),
117                        },
118                        CausalEdge {
119                            from: "transaction_volume".to_string(),
120                            to: "error_rate".to_string(),
121                            transfer: TransferFunction::Linear {
122                                coefficient: 0.01,
123                                intercept: 0.0,
124                            },
125                            lag_months: 0,
126                            strength: 1.0,
127                            mechanism: Some("Higher volume increases error rate".to_string()),
128                        },
129                    ],
130                    topological_order: vec![],
131                }
132            }
133            other => {
134                return Err(ScenarioError::Serialization(format!(
135                    "unknown causal DAG preset: '{other}'"
136                )));
137            }
138        };
139
140        dag.validate()?;
141        Ok(dag)
142    }
143
144    /// Get a reference to the loaded causal DAG.
145    pub fn causal_dag(&self) -> &CausalDAG {
146        &self.causal_dag
147    }
148
149    /// Get a reference to the base config.
150    pub fn base_config(&self) -> &GeneratorConfig {
151        &self.base_config
152    }
153
154    /// Generate all scenarios defined in config.
155    pub fn generate_all(&self, output_root: &Path) -> Result<Vec<ScenarioResult>, ScenarioError> {
156        let scenarios = &self.base_config.scenarios.scenarios;
157        let mut results = Vec::with_capacity(scenarios.len());
158
159        // Create baseline directory
160        let baseline_path = output_root.join("baseline");
161        std::fs::create_dir_all(&baseline_path)?;
162
163        // Generate each scenario
164        for scenario in scenarios {
165            let result = self.generate_scenario(scenario, &baseline_path, output_root)?;
166            results.push(result);
167        }
168
169        Ok(results)
170    }
171
172    /// Generate a single scenario: validate, propagate, mutate, produce output.
173    pub fn generate_scenario(
174        &self,
175        scenario: &ScenarioSchemaConfig,
176        baseline_path: &Path,
177        output_root: &Path,
178    ) -> Result<ScenarioResult, ScenarioError> {
179        // 1. Convert schema config to core interventions
180        let interventions = Self::convert_interventions(&scenario.interventions)?;
181
182        // 2. Validate interventions
183        let validated = InterventionManager::validate(&interventions, &self.base_config)?;
184
185        // 3. Propagate through causal DAG
186        let engine = CausalPropagationEngine::new(&self.causal_dag);
187        let propagated = engine.propagate(&validated, self.base_config.global.period_months)?;
188
189        // 4. Build constraints
190        let constraints = ScenarioConstraints {
191            preserve_accounting_identity: scenario.constraints.preserve_accounting_identity,
192            preserve_document_chains: scenario.constraints.preserve_document_chains,
193            preserve_period_close: scenario.constraints.preserve_period_close,
194            preserve_balance_coherence: scenario.constraints.preserve_balance_coherence,
195            custom: vec![],
196        };
197
198        // 5. Apply to config (creates mutated copy)
199        let _mutated_config = ConfigMutator::apply(&self.base_config, &propagated, &constraints)?;
200
201        // 6. Create scenario output directory
202        let scenario_path = output_root
203            .join("scenarios")
204            .join(&scenario.name)
205            .join("data");
206        std::fs::create_dir_all(&scenario_path)?;
207
208        // 7. Write scenario manifest
209        let manifest = ScenarioManifest {
210            scenario_name: scenario.name.clone(),
211            description: scenario.description.clone(),
212            interventions_count: interventions.len(),
213            months_affected: propagated.changes_by_month.len(),
214            config_paths_changed: propagated
215                .changes_by_month
216                .values()
217                .flat_map(|changes| changes.iter().map(|c| c.path.clone()))
218                .collect::<std::collections::HashSet<_>>()
219                .into_iter()
220                .collect(),
221        };
222
223        let manifest_path = output_root
224            .join("scenarios")
225            .join(&scenario.name)
226            .join("scenario_manifest.yaml");
227        let manifest_yaml = serde_yaml::to_string(&manifest)
228            .map_err(|e| ScenarioError::Serialization(e.to_string()))?;
229        std::fs::write(&manifest_path, manifest_yaml)?;
230
231        Ok(ScenarioResult {
232            scenario_name: scenario.name.clone(),
233            baseline_path: baseline_path.to_path_buf(),
234            counterfactual_path: scenario_path,
235            interventions_applied: interventions.len(),
236            months_affected: propagated.changes_by_month.len(),
237        })
238    }
239
240    /// Convert schema-level intervention configs to core Intervention structs.
241    fn convert_interventions(
242        schema_interventions: &[datasynth_config::InterventionSchemaConfig],
243    ) -> Result<Vec<Intervention>, ScenarioError> {
244        let mut interventions = Vec::new();
245
246        for schema in schema_interventions {
247            let intervention_type: InterventionType =
248                serde_json::from_value(schema.intervention_type.clone()).map_err(|e| {
249                    ScenarioError::Serialization(format!("failed to parse intervention type: {e}"))
250                })?;
251
252            let onset = match schema.timing.onset.to_lowercase().as_str() {
253                "sudden" => OnsetType::Sudden,
254                "gradual" => OnsetType::Gradual,
255                "oscillating" => OnsetType::Oscillating,
256                _ => OnsetType::Sudden,
257            };
258
259            interventions.push(Intervention {
260                id: Uuid::new_v4(),
261                intervention_type,
262                timing: InterventionTiming {
263                    start_month: schema.timing.start_month,
264                    duration_months: schema.timing.duration_months,
265                    onset,
266                    ramp_months: schema.timing.ramp_months,
267                },
268                label: schema.label.clone(),
269                priority: schema.priority,
270            });
271        }
272
273        Ok(interventions)
274    }
275
276    /// List all scenarios in the config.
277    pub fn list_scenarios(&self) -> Vec<ScenarioSummary> {
278        self.base_config
279            .scenarios
280            .scenarios
281            .iter()
282            .map(|s| ScenarioSummary {
283                name: s.name.clone(),
284                description: s.description.clone(),
285                tags: s.tags.clone(),
286                intervention_count: s.interventions.len(),
287                probability_weight: s.probability_weight,
288            })
289            .collect()
290    }
291
292    /// Validate all scenarios without generating.
293    pub fn validate_all(&self) -> Vec<ScenarioValidationResult> {
294        self.base_config
295            .scenarios
296            .scenarios
297            .iter()
298            .map(|s| {
299                let result = self.validate_scenario(s);
300                ScenarioValidationResult {
301                    name: s.name.clone(),
302                    valid: result.is_ok(),
303                    error: result.err().map(|e| e.to_string()),
304                }
305            })
306            .collect()
307    }
308
309    /// Validate a single scenario.
310    fn validate_scenario(&self, scenario: &ScenarioSchemaConfig) -> Result<(), ScenarioError> {
311        let interventions = Self::convert_interventions(&scenario.interventions)?;
312        let validated = InterventionManager::validate(&interventions, &self.base_config)?;
313        let engine = CausalPropagationEngine::new(&self.causal_dag);
314        let propagated = engine.propagate(&validated, self.base_config.global.period_months)?;
315
316        let constraints = ScenarioConstraints {
317            preserve_accounting_identity: scenario.constraints.preserve_accounting_identity,
318            preserve_document_chains: scenario.constraints.preserve_document_chains,
319            preserve_period_close: scenario.constraints.preserve_period_close,
320            preserve_balance_coherence: scenario.constraints.preserve_balance_coherence,
321            custom: vec![],
322        };
323
324        let _mutated = ConfigMutator::apply(&self.base_config, &propagated, &constraints)?;
325        Ok(())
326    }
327}
328
329/// Summary info for listing scenarios.
330#[derive(Debug, Clone, Serialize, Deserialize)]
331pub struct ScenarioSummary {
332    pub name: String,
333    pub description: String,
334    pub tags: Vec<String>,
335    pub intervention_count: usize,
336    pub probability_weight: Option<f64>,
337}
338
339/// Result of validating a scenario.
340#[derive(Debug, Clone, Serialize, Deserialize)]
341pub struct ScenarioValidationResult {
342    pub name: String,
343    pub valid: bool,
344    pub error: Option<String>,
345}
346
347/// Manifest written alongside scenario output.
348#[derive(Debug, Clone, Serialize, Deserialize)]
349struct ScenarioManifest {
350    scenario_name: String,
351    description: String,
352    interventions_count: usize,
353    months_affected: usize,
354    config_paths_changed: Vec<String>,
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use datasynth_config::{
        InterventionSchemaConfig, InterventionTimingSchemaConfig, ScenarioConstraintsSchemaConfig,
        ScenarioOutputSchemaConfig, ScenariosConfig,
    };
    use datasynth_test_utils::fixtures::minimal_config;
    use tempfile::TempDir;

    /// Builds a scenario holding one sudden, priority-0 intervention starting in
    /// month 1. Tags and weight are empty; override with struct-update syntax.
    fn single_intervention_scenario(
        name: &str,
        description: &str,
        intervention_type: serde_json::Value,
        label: &str,
    ) -> ScenarioSchemaConfig {
        let timing = InterventionTimingSchemaConfig {
            start_month: 1,
            duration_months: None,
            onset: "sudden".to_string(),
            ramp_months: None,
        };
        ScenarioSchemaConfig {
            name: name.to_string(),
            description: description.to_string(),
            tags: Vec::new(),
            base: None,
            probability_weight: None,
            interventions: vec![InterventionSchemaConfig {
                intervention_type,
                timing,
                label: Some(label.to_string()),
                priority: 0,
            }],
            constraints: ScenarioConstraintsSchemaConfig::default(),
            output: ScenarioOutputSchemaConfig::default(),
            metadata: Default::default(),
        }
    }

    /// Installs `scenario` as the only entry of an enabled `ScenariosConfig`
    /// on top of the `minimal_config()` fixture.
    fn config_with(
        scenario: ScenarioSchemaConfig,
        causal_model: datasynth_config::CausalModelSchemaConfig,
    ) -> GeneratorConfig {
        let mut config = minimal_config();
        config.scenarios = ScenariosConfig {
            enabled: true,
            scenarios: vec![scenario],
            causal_model,
            defaults: Default::default(),
        };
        config
    }

    /// The default-DAG fixture: a single "test_recession" scenario.
    fn config_with_scenario() -> GeneratorConfig {
        let recession = ScenarioSchemaConfig {
            tags: vec!["test".to_string()],
            probability_weight: Some(0.3),
            ..single_intervention_scenario(
                "test_recession",
                "Test recession scenario",
                serde_json::json!({
                    "type": "parameter_shift",
                    "target": "global.period_months",
                    "to": 3,
                    "interpolation": "linear"
                }),
                "Test shift",
            )
        };
        config_with(recession, Default::default())
    }

    #[test]
    fn test_scenario_engine_new_default_dag() {
        let engine = ScenarioEngine::new(config_with_scenario()).expect("should create engine");
        let dag = engine.causal_dag();
        assert!(!dag.nodes.is_empty());
        assert!(!dag.edges.is_empty());
    }

    #[test]
    fn test_scenario_engine_list_scenarios() {
        let engine = ScenarioEngine::new(config_with_scenario()).expect("should create engine");
        let listed = engine.list_scenarios();
        assert_eq!(listed.len(), 1);
        let only = &listed[0];
        assert_eq!(only.name, "test_recession");
        assert_eq!(only.intervention_count, 1);
    }

    #[test]
    fn test_scenario_engine_validate_all() {
        let engine = ScenarioEngine::new(config_with_scenario()).expect("should create engine");
        let outcomes = engine.validate_all();
        assert_eq!(outcomes.len(), 1);
        assert!(outcomes[0].valid, "validation error: {:?}", outcomes[0].error);
    }

    #[test]
    fn test_scenario_engine_converts_schema_to_interventions() {
        let config = config_with_scenario();
        let schema = &config.scenarios.scenarios[0].interventions;
        let converted = ScenarioEngine::convert_interventions(schema).expect("should convert");
        assert_eq!(converted.len(), 1);
        assert!(matches!(
            converted[0].intervention_type,
            InterventionType::ParameterShift(_)
        ));
    }

    #[test]
    fn test_minimal_dag_preset_valid() {
        let scenario = single_intervention_scenario(
            "minimal_test",
            "Test with minimal DAG",
            serde_json::json!({
                "type": "parameter_shift",
                "target": "transactions.volume_multiplier",
                "to": 2.0,
                "interpolation": "linear"
            }),
            "Volume increase",
        );
        let causal_model = datasynth_config::CausalModelSchemaConfig {
            preset: "minimal".to_string(),
            ..Default::default()
        };

        let engine = ScenarioEngine::new(config_with(scenario, causal_model))
            .expect("should create engine with minimal DAG");
        assert_eq!(engine.causal_dag().nodes.len(), 3);
        assert_eq!(engine.causal_dag().edges.len(), 2);

        // The single scenario must validate against the minimal DAG too.
        let outcomes = engine.validate_all();
        assert_eq!(outcomes.len(), 1);
        assert!(outcomes[0].valid, "validation error: {:?}", outcomes[0].error);
    }

    #[test]
    fn test_scenario_engine_generates_output() {
        let engine = ScenarioEngine::new(config_with_scenario()).expect("should create engine");
        let tmpdir = TempDir::new().expect("should create tmpdir");
        let results = engine.generate_all(tmpdir.path()).expect("should generate");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].scenario_name, "test_recession");

        // The manifest lands next to the scenario's data directory.
        let manifest_path = tmpdir
            .path()
            .join("scenarios")
            .join("test_recession")
            .join("scenario_manifest.yaml");
        assert!(manifest_path.exists());
    }
}
507}