Skip to main content

datasynth_runtime/
scenario_engine.rs

1//! Scenario engine orchestrator for paired baseline/counterfactual generation.
2
3use crate::causal_engine::{CausalPropagationEngine, PropagationError};
4use crate::config_mutator::{ConfigMutator, MutationError};
5use crate::intervention_manager::{InterventionError, InterventionManager};
6use datasynth_config::{GeneratorConfig, ScenarioSchemaConfig};
7use datasynth_core::causal_dag::{CausalDAG, CausalDAGError};
8use datasynth_core::{
9    Intervention, InterventionTiming, InterventionType, OnsetType, ScenarioConstraints,
10};
11use serde::{Deserialize, Serialize};
12use std::path::{Path, PathBuf};
13use thiserror::Error;
14use uuid::Uuid;
15
16/// Errors from the scenario engine.
17#[derive(Debug, Error)]
18pub enum ScenarioError {
19    #[error("intervention error: {0}")]
20    Intervention(#[from] InterventionError),
21    #[error("propagation error: {0}")]
22    Propagation(#[from] PropagationError),
23    #[error("mutation error: {0}")]
24    Mutation(#[from] MutationError),
25    #[error("DAG error: {0}")]
26    Dag(#[from] CausalDAGError),
27    #[error("generation error: {0}")]
28    Generation(String),
29    #[error("IO error: {0}")]
30    Io(#[from] std::io::Error),
31    #[error("serialization error: {0}")]
32    Serialization(String),
33}
34
35/// Result of generating a single scenario.
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct ScenarioResult {
38    pub scenario_name: String,
39    pub baseline_path: PathBuf,
40    pub counterfactual_path: PathBuf,
41    pub interventions_applied: usize,
42    pub months_affected: usize,
43}
44
45/// Orchestrates paired scenario generation.
46pub struct ScenarioEngine {
47    base_config: GeneratorConfig,
48    causal_dag: CausalDAG,
49}
50
51impl ScenarioEngine {
52    /// Create a new ScenarioEngine, loading the causal DAG from config.
53    pub fn new(config: GeneratorConfig) -> Result<Self, ScenarioError> {
54        let causal_dag = Self::load_causal_dag(&config)?;
55        Ok(Self {
56            base_config: config,
57            causal_dag,
58        })
59    }
60
61    /// Load the causal DAG from config presets or custom definition.
62    fn load_causal_dag(config: &GeneratorConfig) -> Result<CausalDAG, ScenarioError> {
63        let causal_config = &config.scenarios.causal_model;
64        let mut dag: CausalDAG = match causal_config.preset.as_str() {
65            "default" | "" => {
66                let yaml =
67                    include_str!("causal_dag_default.yaml");
68                serde_yaml::from_str(yaml).map_err(|e| {
69                    ScenarioError::Serialization(format!(
70                        "failed to parse default causal DAG: {}",
71                        e
72                    ))
73                })?
74            }
75            "minimal" => {
76                use datasynth_core::causal_dag::{
77                    CausalEdge, CausalNode, NodeCategory, TransferFunction,
78                };
79                // Minimal DAG: macro → operational → outcome (3 nodes, 2 edges)
80                CausalDAG {
81                    nodes: vec![
82                        CausalNode {
83                            id: "gdp_growth".to_string(),
84                            label: "GDP Growth".to_string(),
85                            category: NodeCategory::Macro,
86                            baseline_value: 0.025,
87                            bounds: Some((-0.10, 0.15)),
88                            interventionable: true,
89                            config_bindings: vec![],
90                        },
91                        CausalNode {
92                            id: "transaction_volume".to_string(),
93                            label: "Transaction Volume".to_string(),
94                            category: NodeCategory::Operational,
95                            baseline_value: 1.0,
96                            bounds: Some((0.2, 3.0)),
97                            interventionable: true,
98                            config_bindings: vec!["transactions.volume_multiplier".to_string()],
99                        },
100                        CausalNode {
101                            id: "error_rate".to_string(),
102                            label: "Error Rate".to_string(),
103                            category: NodeCategory::Outcome,
104                            baseline_value: 0.02,
105                            bounds: Some((0.0, 0.30)),
106                            interventionable: false,
107                            config_bindings: vec!["anomaly_injection.base_rate".to_string()],
108                        },
109                    ],
110                    edges: vec![
111                        CausalEdge {
112                            from: "gdp_growth".to_string(),
113                            to: "transaction_volume".to_string(),
114                            transfer: TransferFunction::Linear {
115                                coefficient: 0.8,
116                                intercept: 1.0,
117                            },
118                            lag_months: 1,
119                            strength: 1.0,
120                            mechanism: Some("GDP growth drives transaction volume".to_string()),
121                        },
122                        CausalEdge {
123                            from: "transaction_volume".to_string(),
124                            to: "error_rate".to_string(),
125                            transfer: TransferFunction::Linear {
126                                coefficient: 0.01,
127                                intercept: 0.0,
128                            },
129                            lag_months: 0,
130                            strength: 1.0,
131                            mechanism: Some("Higher volume increases error rate".to_string()),
132                        },
133                    ],
134                    topological_order: vec![],
135                }
136            }
137            other => {
138                return Err(ScenarioError::Serialization(format!(
139                    "unknown causal DAG preset: '{}'",
140                    other
141                )));
142            }
143        };
144
145        dag.validate()?;
146        Ok(dag)
147    }
148
149    /// Get a reference to the loaded causal DAG.
150    pub fn causal_dag(&self) -> &CausalDAG {
151        &self.causal_dag
152    }
153
154    /// Get a reference to the base config.
155    pub fn base_config(&self) -> &GeneratorConfig {
156        &self.base_config
157    }
158
159    /// Generate all scenarios defined in config.
160    pub fn generate_all(&self, output_root: &Path) -> Result<Vec<ScenarioResult>, ScenarioError> {
161        let scenarios = &self.base_config.scenarios.scenarios;
162        let mut results = Vec::with_capacity(scenarios.len());
163
164        // Create baseline directory
165        let baseline_path = output_root.join("baseline");
166        std::fs::create_dir_all(&baseline_path)?;
167
168        // Generate each scenario
169        for scenario in scenarios {
170            let result = self.generate_scenario(scenario, &baseline_path, output_root)?;
171            results.push(result);
172        }
173
174        Ok(results)
175    }
176
177    /// Generate a single scenario: validate, propagate, mutate, produce output.
178    pub fn generate_scenario(
179        &self,
180        scenario: &ScenarioSchemaConfig,
181        baseline_path: &Path,
182        output_root: &Path,
183    ) -> Result<ScenarioResult, ScenarioError> {
184        // 1. Convert schema config to core interventions
185        let interventions = Self::convert_interventions(&scenario.interventions)?;
186
187        // 2. Validate interventions
188        let validated = InterventionManager::validate(&interventions, &self.base_config)?;
189
190        // 3. Propagate through causal DAG
191        let engine = CausalPropagationEngine::new(&self.causal_dag);
192        let propagated = engine.propagate(&validated, self.base_config.global.period_months)?;
193
194        // 4. Build constraints
195        let constraints = ScenarioConstraints {
196            preserve_accounting_identity: scenario.constraints.preserve_accounting_identity,
197            preserve_document_chains: scenario.constraints.preserve_document_chains,
198            preserve_period_close: scenario.constraints.preserve_period_close,
199            preserve_balance_coherence: scenario.constraints.preserve_balance_coherence,
200            custom: vec![],
201        };
202
203        // 5. Apply to config (creates mutated copy)
204        let _mutated_config = ConfigMutator::apply(&self.base_config, &propagated, &constraints)?;
205
206        // 6. Create scenario output directory
207        let scenario_path = output_root
208            .join("scenarios")
209            .join(&scenario.name)
210            .join("data");
211        std::fs::create_dir_all(&scenario_path)?;
212
213        // 7. Write scenario manifest
214        let manifest = ScenarioManifest {
215            scenario_name: scenario.name.clone(),
216            description: scenario.description.clone(),
217            interventions_count: interventions.len(),
218            months_affected: propagated.changes_by_month.len(),
219            config_paths_changed: propagated
220                .changes_by_month
221                .values()
222                .flat_map(|changes| changes.iter().map(|c| c.path.clone()))
223                .collect::<std::collections::HashSet<_>>()
224                .into_iter()
225                .collect(),
226        };
227
228        let manifest_path = output_root
229            .join("scenarios")
230            .join(&scenario.name)
231            .join("scenario_manifest.yaml");
232        let manifest_yaml = serde_yaml::to_string(&manifest)
233            .map_err(|e| ScenarioError::Serialization(e.to_string()))?;
234        std::fs::write(&manifest_path, manifest_yaml)?;
235
236        Ok(ScenarioResult {
237            scenario_name: scenario.name.clone(),
238            baseline_path: baseline_path.to_path_buf(),
239            counterfactual_path: scenario_path,
240            interventions_applied: interventions.len(),
241            months_affected: propagated.changes_by_month.len(),
242        })
243    }
244
245    /// Convert schema-level intervention configs to core Intervention structs.
246    fn convert_interventions(
247        schema_interventions: &[datasynth_config::InterventionSchemaConfig],
248    ) -> Result<Vec<Intervention>, ScenarioError> {
249        let mut interventions = Vec::new();
250
251        for schema in schema_interventions {
252            let intervention_type: InterventionType =
253                serde_json::from_value(schema.intervention_type.clone()).map_err(|e| {
254                    ScenarioError::Serialization(format!(
255                        "failed to parse intervention type: {}",
256                        e
257                    ))
258                })?;
259
260            let onset = match schema.timing.onset.to_lowercase().as_str() {
261                "sudden" => OnsetType::Sudden,
262                "gradual" => OnsetType::Gradual,
263                "oscillating" => OnsetType::Oscillating,
264                _ => OnsetType::Sudden,
265            };
266
267            interventions.push(Intervention {
268                id: Uuid::new_v4(),
269                intervention_type,
270                timing: InterventionTiming {
271                    start_month: schema.timing.start_month,
272                    duration_months: schema.timing.duration_months,
273                    onset,
274                    ramp_months: schema.timing.ramp_months,
275                },
276                label: schema.label.clone(),
277                priority: schema.priority,
278            });
279        }
280
281        Ok(interventions)
282    }
283
284    /// List all scenarios in the config.
285    pub fn list_scenarios(&self) -> Vec<ScenarioSummary> {
286        self.base_config
287            .scenarios
288            .scenarios
289            .iter()
290            .map(|s| ScenarioSummary {
291                name: s.name.clone(),
292                description: s.description.clone(),
293                tags: s.tags.clone(),
294                intervention_count: s.interventions.len(),
295                probability_weight: s.probability_weight,
296            })
297            .collect()
298    }
299
300    /// Validate all scenarios without generating.
301    pub fn validate_all(&self) -> Vec<ScenarioValidationResult> {
302        self.base_config
303            .scenarios
304            .scenarios
305            .iter()
306            .map(|s| {
307                let result = self.validate_scenario(s);
308                ScenarioValidationResult {
309                    name: s.name.clone(),
310                    valid: result.is_ok(),
311                    error: result.err().map(|e| e.to_string()),
312                }
313            })
314            .collect()
315    }
316
317    /// Validate a single scenario.
318    fn validate_scenario(&self, scenario: &ScenarioSchemaConfig) -> Result<(), ScenarioError> {
319        let interventions = Self::convert_interventions(&scenario.interventions)?;
320        let validated = InterventionManager::validate(&interventions, &self.base_config)?;
321        let engine = CausalPropagationEngine::new(&self.causal_dag);
322        let propagated = engine.propagate(&validated, self.base_config.global.period_months)?;
323
324        let constraints = ScenarioConstraints {
325            preserve_accounting_identity: scenario.constraints.preserve_accounting_identity,
326            preserve_document_chains: scenario.constraints.preserve_document_chains,
327            preserve_period_close: scenario.constraints.preserve_period_close,
328            preserve_balance_coherence: scenario.constraints.preserve_balance_coherence,
329            custom: vec![],
330        };
331
332        let _mutated = ConfigMutator::apply(&self.base_config, &propagated, &constraints)?;
333        Ok(())
334    }
335}
336
337/// Summary info for listing scenarios.
338#[derive(Debug, Clone, Serialize, Deserialize)]
339pub struct ScenarioSummary {
340    pub name: String,
341    pub description: String,
342    pub tags: Vec<String>,
343    pub intervention_count: usize,
344    pub probability_weight: Option<f64>,
345}
346
347/// Result of validating a scenario.
348#[derive(Debug, Clone, Serialize, Deserialize)]
349pub struct ScenarioValidationResult {
350    pub name: String,
351    pub valid: bool,
352    pub error: Option<String>,
353}
354
355/// Manifest written alongside scenario output.
356#[derive(Debug, Clone, Serialize, Deserialize)]
357struct ScenarioManifest {
358    scenario_name: String,
359    description: String,
360    interventions_count: usize,
361    months_affected: usize,
362    config_paths_changed: Vec<String>,
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use datasynth_config::{
        InterventionSchemaConfig, InterventionTimingSchemaConfig, ScenarioConstraintsSchemaConfig,
        ScenarioOutputSchemaConfig, ScenariosConfig,
    };
    use datasynth_test_utils::fixtures::minimal_config;
    use tempfile::TempDir;

    /// Build a linear `parameter_shift` intervention that starts suddenly in
    /// month 1 with no duration or ramp, at priority 0.
    fn parameter_shift(
        target: &str,
        to: serde_json::Value,
        label: &str,
    ) -> InterventionSchemaConfig {
        InterventionSchemaConfig {
            intervention_type: serde_json::json!({
                "type": "parameter_shift",
                "target": target,
                "to": to,
                "interpolation": "linear"
            }),
            timing: InterventionTimingSchemaConfig {
                start_month: 1,
                duration_months: None,
                onset: "sudden".to_string(),
                ramp_months: None,
            },
            label: Some(label.to_string()),
            priority: 0,
        }
    }

    /// Build a scenario with a single intervention and default constraints,
    /// output settings and metadata.
    fn single_intervention_scenario(
        name: &str,
        description: &str,
        tags: Vec<String>,
        probability_weight: Option<f64>,
        intervention: InterventionSchemaConfig,
    ) -> ScenarioSchemaConfig {
        ScenarioSchemaConfig {
            name: name.to_string(),
            description: description.to_string(),
            tags,
            base: None,
            probability_weight,
            interventions: vec![intervention],
            constraints: ScenarioConstraintsSchemaConfig::default(),
            output: ScenarioOutputSchemaConfig::default(),
            metadata: Default::default(),
        }
    }

    /// Minimal config plus one `test_recession` scenario on the default DAG.
    fn config_with_scenario() -> GeneratorConfig {
        let mut config = minimal_config();
        config.scenarios = ScenariosConfig {
            enabled: true,
            scenarios: vec![single_intervention_scenario(
                "test_recession",
                "Test recession scenario",
                vec!["test".to_string()],
                Some(0.3),
                parameter_shift("global.period_months", serde_json::json!(3), "Test shift"),
            )],
            causal_model: Default::default(),
            defaults: Default::default(),
        };
        config
    }

    #[test]
    fn test_scenario_engine_new_default_dag() {
        let config = config_with_scenario();
        let engine = ScenarioEngine::new(config).expect("should create engine");
        assert!(!engine.causal_dag().nodes.is_empty());
        assert!(!engine.causal_dag().edges.is_empty());
    }

    #[test]
    fn test_scenario_engine_list_scenarios() {
        let config = config_with_scenario();
        let engine = ScenarioEngine::new(config).expect("should create engine");
        let scenarios = engine.list_scenarios();
        assert_eq!(scenarios.len(), 1);
        assert_eq!(scenarios[0].name, "test_recession");
        assert_eq!(scenarios[0].intervention_count, 1);
    }

    #[test]
    fn test_scenario_engine_validate_all() {
        let config = config_with_scenario();
        let engine = ScenarioEngine::new(config).expect("should create engine");
        let results = engine.validate_all();
        assert_eq!(results.len(), 1);
        assert!(results[0].valid, "validation error: {:?}", results[0].error);
    }

    #[test]
    fn test_scenario_engine_converts_schema_to_interventions() {
        let config = config_with_scenario();
        let interventions =
            ScenarioEngine::convert_interventions(&config.scenarios.scenarios[0].interventions)
                .expect("should convert");
        assert_eq!(interventions.len(), 1);
        assert!(matches!(
            interventions[0].intervention_type,
            InterventionType::ParameterShift(_)
        ));
    }

    #[test]
    fn test_minimal_dag_preset_valid() {
        let mut config = minimal_config();
        config.scenarios = ScenariosConfig {
            enabled: true,
            scenarios: vec![single_intervention_scenario(
                "minimal_test",
                "Test with minimal DAG",
                vec![],
                None,
                parameter_shift(
                    "transactions.volume_multiplier",
                    serde_json::json!(2.0),
                    "Volume increase",
                ),
            )],
            causal_model: datasynth_config::CausalModelSchemaConfig {
                preset: "minimal".to_string(),
                ..Default::default()
            },
            defaults: Default::default(),
        };

        let engine = ScenarioEngine::new(config).expect("should create engine with minimal DAG");
        assert_eq!(engine.causal_dag().nodes.len(), 3);
        assert_eq!(engine.causal_dag().edges.len(), 2);

        // Validate all scenarios pass
        let results = engine.validate_all();
        assert_eq!(results.len(), 1);
        assert!(results[0].valid, "validation error: {:?}", results[0].error);
    }

    #[test]
    fn test_scenario_engine_generates_output() {
        let config = config_with_scenario();
        let engine = ScenarioEngine::new(config).expect("should create engine");
        let tmpdir = TempDir::new().expect("should create tmpdir");
        let results = engine.generate_all(tmpdir.path()).expect("should generate");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].scenario_name, "test_recession");
        assert_eq!(results[0].interventions_applied, 1);

        // Both the shared baseline directory and the scenario data directory
        // are created by generation.
        assert!(results[0].baseline_path.is_dir());
        assert!(results[0].counterfactual_path.is_dir());

        // Manifest should exist
        let manifest_path = tmpdir
            .path()
            .join("scenarios")
            .join("test_recession")
            .join("scenario_manifest.yaml");
        assert!(manifest_path.exists());
    }
}
515}