Skip to main content

datasynth_runtime/
scenario_engine.rs

1//! Scenario engine orchestrator for paired baseline/counterfactual generation.
2
3use crate::causal_engine::{CausalPropagationEngine, PropagationError};
4use crate::config_mutator::{ConfigMutator, MutationError};
5use crate::intervention_manager::{InterventionError, InterventionManager};
6use datasynth_config::{GeneratorConfig, ScenarioSchemaConfig};
7use datasynth_core::causal_dag::{CausalDAG, CausalDAGError};
8use datasynth_core::{
9    Intervention, InterventionTiming, InterventionType, OnsetType, ScenarioConstraints,
10};
11use serde::{Deserialize, Serialize};
12use std::path::{Path, PathBuf};
13use thiserror::Error;
14use uuid::Uuid;
15
16/// Errors from the scenario engine.
17#[derive(Debug, Error)]
18pub enum ScenarioError {
19    #[error("intervention error: {0}")]
20    Intervention(#[from] InterventionError),
21    #[error("propagation error: {0}")]
22    Propagation(#[from] PropagationError),
23    #[error("mutation error: {0}")]
24    Mutation(#[from] MutationError),
25    #[error("DAG error: {0}")]
26    Dag(#[from] CausalDAGError),
27    #[error("generation error: {0}")]
28    Generation(String),
29    #[error("IO error: {0}")]
30    Io(#[from] std::io::Error),
31    #[error("serialization error: {0}")]
32    Serialization(String),
33}
34
35/// Result of generating a single scenario.
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct ScenarioResult {
38    pub scenario_name: String,
39    pub baseline_path: PathBuf,
40    pub counterfactual_path: PathBuf,
41    pub interventions_applied: usize,
42    pub months_affected: usize,
43}
44
45/// Orchestrates paired scenario generation.
46pub struct ScenarioEngine {
47    base_config: GeneratorConfig,
48    causal_dag: CausalDAG,
49}
50
51impl ScenarioEngine {
52    /// Create a new ScenarioEngine, loading the causal DAG from config.
53    pub fn new(config: GeneratorConfig) -> Result<Self, ScenarioError> {
54        let causal_dag = Self::load_causal_dag(&config)?;
55        Ok(Self {
56            base_config: config,
57            causal_dag,
58        })
59    }
60
61    /// Load the causal DAG from config presets or custom definition.
62    fn load_causal_dag(config: &GeneratorConfig) -> Result<CausalDAG, ScenarioError> {
63        let causal_config = &config.scenarios.causal_model;
64        let mut dag: CausalDAG = match causal_config.preset.as_str() {
65            "default" | "" => {
66                let yaml = include_str!("causal_dag_default.yaml");
67                serde_yaml::from_str(yaml).map_err(|e| {
68                    ScenarioError::Serialization(format!(
69                        "failed to parse default causal DAG: {}",
70                        e
71                    ))
72                })?
73            }
74            "minimal" => {
75                use datasynth_core::causal_dag::{
76                    CausalEdge, CausalNode, NodeCategory, TransferFunction,
77                };
78                // Minimal DAG: macro → operational → outcome (3 nodes, 2 edges)
79                CausalDAG {
80                    nodes: vec![
81                        CausalNode {
82                            id: "gdp_growth".to_string(),
83                            label: "GDP Growth".to_string(),
84                            category: NodeCategory::Macro,
85                            baseline_value: 0.025,
86                            bounds: Some((-0.10, 0.15)),
87                            interventionable: true,
88                            config_bindings: vec![],
89                        },
90                        CausalNode {
91                            id: "transaction_volume".to_string(),
92                            label: "Transaction Volume".to_string(),
93                            category: NodeCategory::Operational,
94                            baseline_value: 1.0,
95                            bounds: Some((0.2, 3.0)),
96                            interventionable: true,
97                            config_bindings: vec!["transactions.volume_multiplier".to_string()],
98                        },
99                        CausalNode {
100                            id: "error_rate".to_string(),
101                            label: "Error Rate".to_string(),
102                            category: NodeCategory::Outcome,
103                            baseline_value: 0.02,
104                            bounds: Some((0.0, 0.30)),
105                            interventionable: false,
106                            config_bindings: vec!["anomaly_injection.base_rate".to_string()],
107                        },
108                    ],
109                    edges: vec![
110                        CausalEdge {
111                            from: "gdp_growth".to_string(),
112                            to: "transaction_volume".to_string(),
113                            transfer: TransferFunction::Linear {
114                                coefficient: 0.8,
115                                intercept: 1.0,
116                            },
117                            lag_months: 1,
118                            strength: 1.0,
119                            mechanism: Some("GDP growth drives transaction volume".to_string()),
120                        },
121                        CausalEdge {
122                            from: "transaction_volume".to_string(),
123                            to: "error_rate".to_string(),
124                            transfer: TransferFunction::Linear {
125                                coefficient: 0.01,
126                                intercept: 0.0,
127                            },
128                            lag_months: 0,
129                            strength: 1.0,
130                            mechanism: Some("Higher volume increases error rate".to_string()),
131                        },
132                    ],
133                    topological_order: vec![],
134                }
135            }
136            other => {
137                return Err(ScenarioError::Serialization(format!(
138                    "unknown causal DAG preset: '{}'",
139                    other
140                )));
141            }
142        };
143
144        dag.validate()?;
145        Ok(dag)
146    }
147
148    /// Get a reference to the loaded causal DAG.
149    pub fn causal_dag(&self) -> &CausalDAG {
150        &self.causal_dag
151    }
152
153    /// Get a reference to the base config.
154    pub fn base_config(&self) -> &GeneratorConfig {
155        &self.base_config
156    }
157
158    /// Generate all scenarios defined in config.
159    pub fn generate_all(&self, output_root: &Path) -> Result<Vec<ScenarioResult>, ScenarioError> {
160        let scenarios = &self.base_config.scenarios.scenarios;
161        let mut results = Vec::with_capacity(scenarios.len());
162
163        // Create baseline directory
164        let baseline_path = output_root.join("baseline");
165        std::fs::create_dir_all(&baseline_path)?;
166
167        // Generate each scenario
168        for scenario in scenarios {
169            let result = self.generate_scenario(scenario, &baseline_path, output_root)?;
170            results.push(result);
171        }
172
173        Ok(results)
174    }
175
176    /// Generate a single scenario: validate, propagate, mutate, produce output.
177    pub fn generate_scenario(
178        &self,
179        scenario: &ScenarioSchemaConfig,
180        baseline_path: &Path,
181        output_root: &Path,
182    ) -> Result<ScenarioResult, ScenarioError> {
183        // 1. Convert schema config to core interventions
184        let interventions = Self::convert_interventions(&scenario.interventions)?;
185
186        // 2. Validate interventions
187        let validated = InterventionManager::validate(&interventions, &self.base_config)?;
188
189        // 3. Propagate through causal DAG
190        let engine = CausalPropagationEngine::new(&self.causal_dag);
191        let propagated = engine.propagate(&validated, self.base_config.global.period_months)?;
192
193        // 4. Build constraints
194        let constraints = ScenarioConstraints {
195            preserve_accounting_identity: scenario.constraints.preserve_accounting_identity,
196            preserve_document_chains: scenario.constraints.preserve_document_chains,
197            preserve_period_close: scenario.constraints.preserve_period_close,
198            preserve_balance_coherence: scenario.constraints.preserve_balance_coherence,
199            custom: vec![],
200        };
201
202        // 5. Apply to config (creates mutated copy)
203        let _mutated_config = ConfigMutator::apply(&self.base_config, &propagated, &constraints)?;
204
205        // 6. Create scenario output directory
206        let scenario_path = output_root
207            .join("scenarios")
208            .join(&scenario.name)
209            .join("data");
210        std::fs::create_dir_all(&scenario_path)?;
211
212        // 7. Write scenario manifest
213        let manifest = ScenarioManifest {
214            scenario_name: scenario.name.clone(),
215            description: scenario.description.clone(),
216            interventions_count: interventions.len(),
217            months_affected: propagated.changes_by_month.len(),
218            config_paths_changed: propagated
219                .changes_by_month
220                .values()
221                .flat_map(|changes| changes.iter().map(|c| c.path.clone()))
222                .collect::<std::collections::HashSet<_>>()
223                .into_iter()
224                .collect(),
225        };
226
227        let manifest_path = output_root
228            .join("scenarios")
229            .join(&scenario.name)
230            .join("scenario_manifest.yaml");
231        let manifest_yaml = serde_yaml::to_string(&manifest)
232            .map_err(|e| ScenarioError::Serialization(e.to_string()))?;
233        std::fs::write(&manifest_path, manifest_yaml)?;
234
235        Ok(ScenarioResult {
236            scenario_name: scenario.name.clone(),
237            baseline_path: baseline_path.to_path_buf(),
238            counterfactual_path: scenario_path,
239            interventions_applied: interventions.len(),
240            months_affected: propagated.changes_by_month.len(),
241        })
242    }
243
244    /// Convert schema-level intervention configs to core Intervention structs.
245    fn convert_interventions(
246        schema_interventions: &[datasynth_config::InterventionSchemaConfig],
247    ) -> Result<Vec<Intervention>, ScenarioError> {
248        let mut interventions = Vec::new();
249
250        for schema in schema_interventions {
251            let intervention_type: InterventionType =
252                serde_json::from_value(schema.intervention_type.clone()).map_err(|e| {
253                    ScenarioError::Serialization(format!(
254                        "failed to parse intervention type: {}",
255                        e
256                    ))
257                })?;
258
259            let onset = match schema.timing.onset.to_lowercase().as_str() {
260                "sudden" => OnsetType::Sudden,
261                "gradual" => OnsetType::Gradual,
262                "oscillating" => OnsetType::Oscillating,
263                _ => OnsetType::Sudden,
264            };
265
266            interventions.push(Intervention {
267                id: Uuid::new_v4(),
268                intervention_type,
269                timing: InterventionTiming {
270                    start_month: schema.timing.start_month,
271                    duration_months: schema.timing.duration_months,
272                    onset,
273                    ramp_months: schema.timing.ramp_months,
274                },
275                label: schema.label.clone(),
276                priority: schema.priority,
277            });
278        }
279
280        Ok(interventions)
281    }
282
283    /// List all scenarios in the config.
284    pub fn list_scenarios(&self) -> Vec<ScenarioSummary> {
285        self.base_config
286            .scenarios
287            .scenarios
288            .iter()
289            .map(|s| ScenarioSummary {
290                name: s.name.clone(),
291                description: s.description.clone(),
292                tags: s.tags.clone(),
293                intervention_count: s.interventions.len(),
294                probability_weight: s.probability_weight,
295            })
296            .collect()
297    }
298
299    /// Validate all scenarios without generating.
300    pub fn validate_all(&self) -> Vec<ScenarioValidationResult> {
301        self.base_config
302            .scenarios
303            .scenarios
304            .iter()
305            .map(|s| {
306                let result = self.validate_scenario(s);
307                ScenarioValidationResult {
308                    name: s.name.clone(),
309                    valid: result.is_ok(),
310                    error: result.err().map(|e| e.to_string()),
311                }
312            })
313            .collect()
314    }
315
316    /// Validate a single scenario.
317    fn validate_scenario(&self, scenario: &ScenarioSchemaConfig) -> Result<(), ScenarioError> {
318        let interventions = Self::convert_interventions(&scenario.interventions)?;
319        let validated = InterventionManager::validate(&interventions, &self.base_config)?;
320        let engine = CausalPropagationEngine::new(&self.causal_dag);
321        let propagated = engine.propagate(&validated, self.base_config.global.period_months)?;
322
323        let constraints = ScenarioConstraints {
324            preserve_accounting_identity: scenario.constraints.preserve_accounting_identity,
325            preserve_document_chains: scenario.constraints.preserve_document_chains,
326            preserve_period_close: scenario.constraints.preserve_period_close,
327            preserve_balance_coherence: scenario.constraints.preserve_balance_coherence,
328            custom: vec![],
329        };
330
331        let _mutated = ConfigMutator::apply(&self.base_config, &propagated, &constraints)?;
332        Ok(())
333    }
334}
335
336/// Summary info for listing scenarios.
337#[derive(Debug, Clone, Serialize, Deserialize)]
338pub struct ScenarioSummary {
339    pub name: String,
340    pub description: String,
341    pub tags: Vec<String>,
342    pub intervention_count: usize,
343    pub probability_weight: Option<f64>,
344}
345
346/// Result of validating a scenario.
347#[derive(Debug, Clone, Serialize, Deserialize)]
348pub struct ScenarioValidationResult {
349    pub name: String,
350    pub valid: bool,
351    pub error: Option<String>,
352}
353
354/// Manifest written alongside scenario output.
355#[derive(Debug, Clone, Serialize, Deserialize)]
356struct ScenarioManifest {
357    scenario_name: String,
358    description: String,
359    interventions_count: usize,
360    months_affected: usize,
361    config_paths_changed: Vec<String>,
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use datasynth_config::{
        CausalModelSchemaConfig, InterventionSchemaConfig, InterventionTimingSchemaConfig,
        ScenarioConstraintsSchemaConfig, ScenarioOutputSchemaConfig, ScenariosConfig,
    };
    use datasynth_test_utils::fixtures::minimal_config;
    use tempfile::TempDir;

    /// Build a scenario with a single intervention that starts suddenly in month 1 with no
    /// fixed duration. Tags are empty and no probability weight is set.
    fn shift_scenario(
        name: &str,
        description: &str,
        intervention_type: serde_json::Value,
        label: &str,
    ) -> ScenarioSchemaConfig {
        ScenarioSchemaConfig {
            name: name.to_string(),
            description: description.to_string(),
            tags: vec![],
            base: None,
            probability_weight: None,
            interventions: vec![InterventionSchemaConfig {
                intervention_type,
                timing: InterventionTimingSchemaConfig {
                    start_month: 1,
                    duration_months: None,
                    onset: "sudden".to_string(),
                    ramp_months: None,
                },
                label: Some(label.to_string()),
                priority: 0,
            }],
            constraints: ScenarioConstraintsSchemaConfig::default(),
            output: ScenarioOutputSchemaConfig::default(),
            metadata: Default::default(),
        }
    }

    /// Enable `scenarios` on top of `minimal_config()` with the given causal-model settings.
    fn config_with(
        scenarios: Vec<ScenarioSchemaConfig>,
        causal_model: CausalModelSchemaConfig,
    ) -> GeneratorConfig {
        let mut config = minimal_config();
        config.scenarios = ScenariosConfig {
            enabled: true,
            scenarios,
            causal_model,
            defaults: Default::default(),
        };
        config
    }

    /// Config with a single `test_recession` scenario on the default causal DAG preset.
    fn config_with_scenario() -> GeneratorConfig {
        let mut scenario = shift_scenario(
            "test_recession",
            "Test recession scenario",
            serde_json::json!({
                "type": "parameter_shift",
                "target": "global.period_months",
                "to": 3,
                "interpolation": "linear"
            }),
            "Test shift",
        );
        scenario.tags = vec!["test".to_string()];
        scenario.probability_weight = Some(0.3);
        config_with(vec![scenario], CausalModelSchemaConfig::default())
    }

    #[test]
    fn test_scenario_engine_new_default_dag() {
        let engine = ScenarioEngine::new(config_with_scenario()).expect("should create engine");
        assert!(!engine.causal_dag().nodes.is_empty());
        assert!(!engine.causal_dag().edges.is_empty());
    }

    #[test]
    fn test_unknown_dag_preset_rejected() {
        let config = config_with(
            vec![],
            CausalModelSchemaConfig {
                preset: "does_not_exist".to_string(),
                ..Default::default()
            },
        );
        let err = ScenarioEngine::new(config)
            .err()
            .expect("unknown preset should be rejected");
        assert!(
            matches!(err, ScenarioError::Serialization(_)),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn test_scenario_engine_list_scenarios() {
        let engine = ScenarioEngine::new(config_with_scenario()).expect("should create engine");
        let scenarios = engine.list_scenarios();
        assert_eq!(scenarios.len(), 1);
        assert_eq!(scenarios[0].name, "test_recession");
        assert_eq!(scenarios[0].intervention_count, 1);
        assert_eq!(scenarios[0].tags, vec!["test".to_string()]);
        assert_eq!(scenarios[0].probability_weight, Some(0.3));
    }

    #[test]
    fn test_scenario_engine_validate_all() {
        let engine = ScenarioEngine::new(config_with_scenario()).expect("should create engine");
        let results = engine.validate_all();
        assert_eq!(results.len(), 1);
        assert!(results[0].valid, "validation error: {:?}", results[0].error);
        assert!(results[0].error.is_none());
    }

    #[test]
    fn test_scenario_engine_converts_schema_to_interventions() {
        let config = config_with_scenario();
        let interventions =
            ScenarioEngine::convert_interventions(&config.scenarios.scenarios[0].interventions)
                .expect("should convert");
        assert_eq!(interventions.len(), 1);
        assert!(matches!(
            interventions[0].intervention_type,
            InterventionType::ParameterShift(_)
        ));
    }

    #[test]
    fn test_minimal_dag_preset_valid() {
        let scenario = shift_scenario(
            "minimal_test",
            "Test with minimal DAG",
            serde_json::json!({
                "type": "parameter_shift",
                "target": "transactions.volume_multiplier",
                "to": 2.0,
                "interpolation": "linear"
            }),
            "Volume increase",
        );
        let config = config_with(
            vec![scenario],
            CausalModelSchemaConfig {
                preset: "minimal".to_string(),
                ..Default::default()
            },
        );

        let engine = ScenarioEngine::new(config).expect("should create engine with minimal DAG");
        assert_eq!(engine.causal_dag().nodes.len(), 3);
        assert_eq!(engine.causal_dag().edges.len(), 2);

        // Validate all scenarios pass
        let results = engine.validate_all();
        assert_eq!(results.len(), 1);
        assert!(results[0].valid, "validation error: {:?}", results[0].error);
    }

    #[test]
    fn test_scenario_engine_generates_output() {
        let engine = ScenarioEngine::new(config_with_scenario()).expect("should create engine");
        let tmpdir = TempDir::new().expect("should create tmpdir");
        let results = engine.generate_all(tmpdir.path()).expect("should generate");
        assert_eq!(results.len(), 1);
        let result = &results[0];
        assert_eq!(result.scenario_name, "test_recession");
        assert_eq!(result.interventions_applied, 1);

        // Directory layout: <root>/baseline and <root>/scenarios/<name>/data.
        let baseline_dir = tmpdir.path().join("baseline");
        assert!(baseline_dir.is_dir());
        assert_eq!(result.baseline_path, baseline_dir);
        let scenario_dir = tmpdir.path().join("scenarios").join("test_recession");
        assert_eq!(result.counterfactual_path, scenario_dir.join("data"));
        assert!(result.counterfactual_path.is_dir());

        // Manifest exists, parses, and agrees with the returned result.
        let manifest_path = scenario_dir.join("scenario_manifest.yaml");
        assert!(manifest_path.exists());
        let manifest_yaml =
            std::fs::read_to_string(&manifest_path).expect("manifest should be readable");
        let manifest: ScenarioManifest =
            serde_yaml::from_str(&manifest_yaml).expect("manifest should parse");
        assert_eq!(manifest.scenario_name, "test_recession");
        assert_eq!(manifest.interventions_count, result.interventions_applied);
        assert_eq!(manifest.months_affected, result.months_affected);
    }
}