1use datasynth_core::causal_dag::{CausalDAG, CausalDAGError};
7use datasynth_core::{Intervention, InterventionTiming, InterventionType, OnsetType};
8use serde::{Deserialize, Serialize};
9use std::collections::{BTreeMap, HashMap};
10use thiserror::Error;
11
12#[derive(Debug, Clone)]
14pub struct ValidatedIntervention {
15 pub intervention: Intervention,
16 pub affected_config_paths: Vec<String>,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize, Default)]
21pub struct PropagatedInterventions {
22 pub changes_by_month: BTreeMap<u32, Vec<ConfigChange>>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct ConfigChange {
28 pub path: String,
30 pub value: serde_json::Value,
32 pub source_node: String,
34 pub is_direct: bool,
36}
37
38#[derive(Debug, Error)]
40pub enum PropagationError {
41 #[error("DAG validation failed: {0}")]
42 DagValidation(#[from] CausalDAGError),
43 #[error("no causal node mapping for intervention target: {0}")]
44 NoNodeMapping(String),
45}
46
47pub struct CausalPropagationEngine<'a> {
49 dag: &'a CausalDAG,
50}
51
52impl<'a> CausalPropagationEngine<'a> {
53 pub fn new(dag: &'a CausalDAG) -> Self {
54 Self { dag }
55 }
56
57 pub fn propagate(
59 &self,
60 interventions: &[ValidatedIntervention],
61 period_months: u32,
62 ) -> Result<PropagatedInterventions, PropagationError> {
63 let mut result = PropagatedInterventions::default();
64
65 for month in 1..=period_months {
66 let direct = self.compute_direct_effects(interventions, month);
68
69 if direct.is_empty() {
70 continue;
71 }
72
73 let propagated_values = self.dag.propagate(&direct, month);
75
76 let mut changes = Vec::new();
78 for (node_id, value) in &propagated_values {
79 if let Some(node) = self.dag.find_node(node_id) {
80 if (value - node.baseline_value).abs() < f64::EPSILON {
82 continue;
83 }
84
85 let is_direct = direct.contains_key(node_id);
86 for binding in &node.config_bindings {
87 changes.push(ConfigChange {
88 path: binding.clone(),
89 value: serde_json::Value::from(*value),
90 source_node: node_id.clone(),
91 is_direct,
92 });
93 }
94 }
95 }
96
97 if !changes.is_empty() {
98 result.changes_by_month.insert(month, changes);
99 }
100 }
101
102 Ok(result)
103 }
104
105 fn compute_direct_effects(
107 &self,
108 interventions: &[ValidatedIntervention],
109 month: u32,
110 ) -> HashMap<String, f64> {
111 let mut effects = HashMap::new();
112
113 for validated in interventions {
114 let timing = &validated.intervention.timing;
115
116 if !Self::is_active(timing, month) {
118 continue;
119 }
120
121 let onset_factor = Self::compute_onset_factor(timing, month);
123
124 self.map_intervention_to_nodes(
126 &validated.intervention.intervention_type,
127 onset_factor,
128 &mut effects,
129 );
130 }
131
132 effects
133 }
134
135 fn is_active(timing: &InterventionTiming, month: u32) -> bool {
137 if month < timing.start_month {
138 return false;
139 }
140 if let Some(duration) = timing.duration_months {
141 if month >= timing.start_month + duration {
142 return false;
143 }
144 }
145 true
146 }
147
148 fn compute_onset_factor(timing: &InterventionTiming, month: u32) -> f64 {
150 let months_active = month - timing.start_month;
151
152 match &timing.onset {
153 OnsetType::Sudden => 1.0,
154 OnsetType::Gradual => {
155 let ramp = timing.ramp_months.unwrap_or(1).max(1);
156 if months_active >= ramp {
157 1.0
158 } else {
159 months_active as f64 / ramp as f64
160 }
161 }
162 OnsetType::Oscillating => {
163 let ramp = timing.ramp_months.unwrap_or(4).max(1) as f64;
164 let phase = months_active as f64 / ramp;
165 0.5 * (1.0 - (std::f64::consts::PI * phase).cos())
167 }
168 OnsetType::Custom { .. } => {
169 let ramp = timing.ramp_months.unwrap_or(1).max(1);
171 if months_active >= ramp {
172 1.0
173 } else {
174 months_active as f64 / ramp as f64
175 }
176 }
177 }
178 }
179
180 fn map_intervention_to_nodes(
182 &self,
183 intervention_type: &InterventionType,
184 onset_factor: f64,
185 effects: &mut HashMap<String, f64>,
186 ) {
187 match intervention_type {
188 InterventionType::ParameterShift(ps) => {
189 for node in &self.dag.nodes {
191 if node.config_bindings.contains(&ps.target) {
192 if let Some(to_val) = ps.to.as_f64() {
193 let from_val = ps
194 .from
195 .as_ref()
196 .and_then(|v| v.as_f64())
197 .unwrap_or(node.baseline_value);
198 let interpolated = from_val + (to_val - from_val) * onset_factor;
199 effects.insert(node.id.clone(), interpolated);
200 }
201 }
202 }
203 }
204 InterventionType::MacroShock(ms) => {
205 use datasynth_core::MacroShockType;
207 let severity = ms.severity * onset_factor;
208 match ms.subtype {
209 MacroShockType::Recession => {
210 if let Some(node) = self.dag.find_node("gdp_growth") {
211 let shock = ms.overrides.get("gdp_growth").copied().unwrap_or(-0.02);
212 effects.insert(
213 "gdp_growth".to_string(),
214 node.baseline_value + shock * severity,
215 );
216 }
217 if let Some(node) = self.dag.find_node("unemployment_rate") {
218 let shock = ms
219 .overrides
220 .get("unemployment_rate")
221 .copied()
222 .unwrap_or(0.03);
223 effects.insert(
224 "unemployment_rate".to_string(),
225 node.baseline_value + shock * severity,
226 );
227 }
228 }
229 MacroShockType::InflationSpike => {
230 if let Some(node) = self.dag.find_node("inflation_rate") {
231 let shock = ms.overrides.get("inflation_rate").copied().unwrap_or(0.05);
232 effects.insert(
233 "inflation_rate".to_string(),
234 node.baseline_value + shock * severity,
235 );
236 }
237 }
238 MacroShockType::InterestRateShock => {
239 if let Some(node) = self.dag.find_node("interest_rate") {
240 let shock = ms.overrides.get("interest_rate").copied().unwrap_or(0.03);
241 effects.insert(
242 "interest_rate".to_string(),
243 node.baseline_value + shock * severity,
244 );
245 }
246 }
247 _ => {
248 if let Some(node) = self.dag.find_node("gdp_growth") {
250 effects.insert(
251 "gdp_growth".to_string(),
252 node.baseline_value * (1.0 - 0.1 * severity),
253 );
254 }
255 }
256 }
257 }
258 InterventionType::ControlFailure(cf) => {
259 if let Some(node) = self.dag.find_node("control_effectiveness") {
260 let new_effectiveness = node.baseline_value * cf.severity * onset_factor
261 + node.baseline_value * (1.0 - onset_factor);
262 effects.insert("control_effectiveness".to_string(), new_effectiveness);
263 }
264 }
265 InterventionType::EntityEvent(ee) => {
266 use datasynth_core::InterventionEntityEvent;
267 let rate_increase = ee
268 .parameters
269 .get("rate_increase")
270 .and_then(|v| v.as_f64())
271 .unwrap_or(0.05);
272 match ee.subtype {
273 InterventionEntityEvent::VendorDefault => {
274 if let Some(node) = self.dag.find_node("vendor_default_rate") {
275 effects.insert(
276 "vendor_default_rate".to_string(),
277 node.baseline_value + rate_increase * onset_factor,
278 );
279 }
280 }
281 InterventionEntityEvent::CustomerChurn => {
282 if let Some(node) = self.dag.find_node("customer_churn_rate") {
283 effects.insert(
284 "customer_churn_rate".to_string(),
285 node.baseline_value + rate_increase * onset_factor,
286 );
287 }
288 }
289 InterventionEntityEvent::EmployeeDeparture
290 | InterventionEntityEvent::KeyPersonRisk => {
291 if let Some(node) = self.dag.find_node("processing_lag") {
293 effects.insert(
294 "processing_lag".to_string(),
295 node.baseline_value * (1.0 + 0.2 * onset_factor),
296 );
297 }
298 if let Some(node) = self.dag.find_node("error_rate") {
299 effects.insert(
300 "error_rate".to_string(),
301 node.baseline_value * (1.0 + 0.15 * onset_factor),
302 );
303 }
304 }
305 InterventionEntityEvent::NewVendorOnboarding => {
306 if let Some(node) = self.dag.find_node("transaction_volume") {
308 effects.insert(
309 "transaction_volume".to_string(),
310 node.baseline_value * (1.0 + 0.1 * onset_factor),
311 );
312 }
313 }
314 InterventionEntityEvent::MergerAcquisition => {
315 if let Some(node) = self.dag.find_node("transaction_volume") {
317 effects.insert(
318 "transaction_volume".to_string(),
319 node.baseline_value * (1.0 + 0.5 * onset_factor),
320 );
321 }
322 if let Some(node) = self.dag.find_node("error_rate") {
323 effects.insert(
324 "error_rate".to_string(),
325 node.baseline_value * (1.0 + 0.3 * onset_factor),
326 );
327 }
328 }
329 InterventionEntityEvent::VendorCollusion => {
330 if let Some(node) = self.dag.find_node("misstatement_risk") {
332 effects.insert(
333 "misstatement_risk".to_string(),
334 (node.baseline_value + 0.15 * onset_factor).min(1.0),
335 );
336 }
337 if let Some(node) = self.dag.find_node("control_effectiveness") {
338 effects.insert(
339 "control_effectiveness".to_string(),
340 node.baseline_value * (1.0 - 0.2 * onset_factor),
341 );
342 }
343 }
344 InterventionEntityEvent::CustomerConsolidation => {
345 if let Some(node) = self.dag.find_node("customer_churn_rate") {
347 effects.insert(
348 "customer_churn_rate".to_string(),
349 node.baseline_value + rate_increase * onset_factor,
350 );
351 }
352 }
353 }
354 }
355 InterventionType::Custom(ci) => {
356 for (path, value) in &ci.config_overrides {
358 for node in &self.dag.nodes {
359 if node.config_bindings.contains(path) {
360 if let Some(v) = value.as_f64() {
361 let interpolated =
362 node.baseline_value + (v - node.baseline_value) * onset_factor;
363 effects.insert(node.id.clone(), interpolated);
364 }
365 }
366 }
367 }
368 }
369 InterventionType::ProcessChange(pc) => {
370 use datasynth_core::ProcessChangeType;
371 match pc.subtype {
372 ProcessChangeType::ProcessAutomation => {
373 if let Some(node) = self.dag.find_node("processing_lag") {
375 effects.insert(
376 "processing_lag".to_string(),
377 node.baseline_value * (1.0 - 0.3 * onset_factor),
378 );
379 }
380 if let Some(node) = self.dag.find_node("error_rate") {
381 effects.insert(
382 "error_rate".to_string(),
383 node.baseline_value * (1.0 - 0.2 * onset_factor),
384 );
385 }
386 }
387 ProcessChangeType::ApprovalThresholdChange
388 | ProcessChangeType::NewApprovalLevel => {
389 if let Some(node) = self.dag.find_node("control_effectiveness") {
391 effects.insert(
392 "control_effectiveness".to_string(),
393 (node.baseline_value + 0.1 * onset_factor).min(1.0),
394 );
395 }
396 }
397 ProcessChangeType::PolicyChange => {
398 if let Some(node) = self.dag.find_node("sod_compliance") {
399 effects.insert(
400 "sod_compliance".to_string(),
401 (node.baseline_value + 0.05 * onset_factor).min(1.0),
402 );
403 }
404 }
405 ProcessChangeType::SystemMigration
406 | ProcessChangeType::OutsourcingTransition
407 | ProcessChangeType::ReorganizationRestructuring => {
408 if let Some(node) = self.dag.find_node("processing_lag") {
410 effects.insert(
411 "processing_lag".to_string(),
412 node.baseline_value * (1.0 + 0.15 * onset_factor),
413 );
414 }
415 if let Some(node) = self.dag.find_node("error_rate") {
416 effects.insert(
417 "error_rate".to_string(),
418 node.baseline_value * (1.0 + 0.1 * onset_factor),
419 );
420 }
421 }
422 }
423 }
424 InterventionType::RegulatoryChange(rc) => {
425 let severity = rc
426 .parameters
427 .get("severity")
428 .and_then(|v| v.as_f64())
429 .unwrap_or(0.5);
430 if let Some(node) = self.dag.find_node("sod_compliance") {
432 effects.insert(
433 "sod_compliance".to_string(),
434 (node.baseline_value + severity * 0.2 * onset_factor).min(1.0),
435 );
436 }
437 if let Some(node) = self.dag.find_node("control_effectiveness") {
438 effects.insert(
439 "control_effectiveness".to_string(),
440 (node.baseline_value + severity * 0.15 * onset_factor).min(1.0),
441 );
442 }
443 if let Some(node) = self.dag.find_node("misstatement_risk") {
444 effects.insert(
445 "misstatement_risk".to_string(),
446 node.baseline_value * (1.0 - severity * 0.1 * onset_factor),
447 );
448 }
449 }
450 InterventionType::Composite(comp) => {
451 for child in &comp.children {
452 self.map_intervention_to_nodes(child, onset_factor, effects);
453 }
454 }
455 }
456 }
457}
458
#[cfg(test)]
mod tests {
    use super::*;
    use datasynth_core::causal_dag::{CausalEdge, CausalNode, NodeCategory, TransferFunction};
    use datasynth_core::{MacroShockIntervention, MacroShockType};
    use uuid::Uuid;

    /// Three-node chain `gdp_growth -> transaction_volume -> error_rate`.
    /// Only the two downstream nodes carry config bindings.
    fn make_simple_dag() -> CausalDAG {
        let mut dag = CausalDAG {
            nodes: vec![
                CausalNode {
                    id: "gdp_growth".to_string(),
                    label: "GDP Growth".to_string(),
                    category: NodeCategory::Macro,
                    baseline_value: 0.025,
                    bounds: Some((-0.10, 0.15)),
                    interventionable: true,
                    config_bindings: vec![],
                },
                CausalNode {
                    id: "transaction_volume".to_string(),
                    label: "Transaction Volume".to_string(),
                    category: NodeCategory::Operational,
                    baseline_value: 1.0,
                    bounds: Some((0.2, 3.0)),
                    interventionable: true,
                    config_bindings: vec!["transactions.volume_multiplier".to_string()],
                },
                CausalNode {
                    id: "error_rate".to_string(),
                    label: "Error Rate".to_string(),
                    category: NodeCategory::Outcome,
                    baseline_value: 0.02,
                    bounds: Some((0.0, 0.30)),
                    interventionable: false,
                    config_bindings: vec!["anomaly_injection.base_rate".to_string()],
                },
            ],
            edges: vec![
                CausalEdge {
                    from: "gdp_growth".to_string(),
                    to: "transaction_volume".to_string(),
                    transfer: TransferFunction::Linear {
                        coefficient: 0.8,
                        intercept: 0.0,
                    },
                    lag_months: 0,
                    strength: 1.0,
                    mechanism: Some("GDP drives volume".to_string()),
                },
                CausalEdge {
                    from: "transaction_volume".to_string(),
                    to: "error_rate".to_string(),
                    transfer: TransferFunction::Linear {
                        coefficient: 0.01,
                        intercept: 0.0,
                    },
                    lag_months: 0,
                    strength: 1.0,
                    mechanism: Some("Volume increases errors".to_string()),
                },
            ],
            topological_order: vec![],
        };
        dag.validate().expect("DAG should be valid");
        dag
    }

    /// Builds an open-ended intervention starting at `start_month` with a 3-month ramp.
    fn make_intervention(
        intervention_type: InterventionType,
        start_month: u32,
        onset: OnsetType,
    ) -> Intervention {
        Intervention {
            id: Uuid::new_v4(),
            intervention_type,
            timing: InterventionTiming {
                start_month,
                duration_months: None,
                onset,
                ramp_months: Some(3),
            },
            label: None,
            priority: 0,
        }
    }

    /// Recession macro shock with the given severity and `gdp_growth` override.
    fn recession(severity: f64, gdp_shock: f64) -> InterventionType {
        let mut overrides = HashMap::new();
        overrides.insert("gdp_growth".to_string(), gdp_shock);
        InterventionType::MacroShock(MacroShockIntervention {
            subtype: MacroShockType::Recession,
            severity,
            preset: None,
            overrides,
        })
    }

    /// Wraps a single intervention as a validated-intervention list.
    fn validated(intervention: Intervention, affected: &[&str]) -> Vec<ValidatedIntervention> {
        vec![ValidatedIntervention {
            intervention,
            affected_config_paths: affected.iter().map(|p| p.to_string()).collect(),
        }]
    }

    #[test]
    fn test_propagation_no_interventions() {
        let dag = make_simple_dag();
        let engine = CausalPropagationEngine::new(&dag);
        let result = engine.propagate(&[], 12).unwrap();
        assert!(result.changes_by_month.is_empty());
    }

    #[test]
    fn test_propagation_sudden_onset() {
        let dag = make_simple_dag();
        let engine = CausalPropagationEngine::new(&dag);

        let intervention = make_intervention(recession(1.0, -0.02), 3, OnsetType::Sudden);
        let result = engine
            .propagate(&validated(intervention, &["gdp_growth"]), 6)
            .unwrap();

        // Active from month 3 onward; nothing is emitted before the start month.
        assert!(result.changes_by_month.contains_key(&3));
        assert!(!result.changes_by_month.contains_key(&1));
        assert!(!result.changes_by_month.contains_key(&2));
    }

    #[test]
    fn test_propagation_gradual_onset() {
        let dag = make_simple_dag();
        let engine = CausalPropagationEngine::new(&dag);

        let intervention = make_intervention(recession(1.0, -0.02), 1, OnsetType::Gradual);
        let result = engine
            .propagate(&validated(intervention, &["gdp_growth"]), 6)
            .unwrap();

        // Ramp factor is 1/3 in month 2 and 1.0 from month 4 (3-month ramp).
        assert!(result.changes_by_month.contains_key(&2));
        assert!(result.changes_by_month.contains_key(&4));
    }

    #[test]
    fn test_propagation_chain_through_dag() {
        let dag = make_simple_dag();
        let engine = CausalPropagationEngine::new(&dag);

        let intervention = make_intervention(recession(1.0, -0.05), 1, OnsetType::Sudden);
        let result = engine
            .propagate(&validated(intervention, &["gdp_growth"]), 3)
            .unwrap();

        // gdp_growth has no bindings, so any month-1 change must come from a
        // downstream node reached through the DAG.
        if let Some(changes) = result.changes_by_month.get(&1) {
            let paths: Vec<&str> = changes.iter().map(|c| c.path.as_str()).collect();
            assert!(
                paths.contains(&"transactions.volume_multiplier")
                    || paths.contains(&"anomaly_injection.base_rate")
            );
        }
    }

    /// A Custom intervention with no config overrides has no direct effects, so
    /// propagation is skipped every month and nothing is emitted.
    ///
    /// NOTE(review): despite its name, this test does not exercise `lag_months`.
    /// Node `a` has no config bindings, so no intervention can target it here.
    #[test]
    fn test_propagation_lag_respected() {
        let mut dag = CausalDAG {
            nodes: vec![
                CausalNode {
                    id: "a".to_string(),
                    label: "A".to_string(),
                    category: NodeCategory::Macro,
                    baseline_value: 1.0,
                    bounds: None,
                    interventionable: true,
                    config_bindings: vec![],
                },
                CausalNode {
                    id: "b".to_string(),
                    label: "B".to_string(),
                    category: NodeCategory::Operational,
                    baseline_value: 0.0,
                    bounds: None,
                    interventionable: false,
                    config_bindings: vec!["test.path".to_string()],
                },
            ],
            edges: vec![CausalEdge {
                from: "a".to_string(),
                to: "b".to_string(),
                transfer: TransferFunction::Linear {
                    coefficient: 1.0,
                    intercept: 0.0,
                },
                lag_months: 3,
                strength: 1.0,
                mechanism: None,
            }],
            topological_order: vec![],
        };
        dag.validate().expect("DAG should be valid");

        let engine = CausalPropagationEngine::new(&dag);

        let intervention_type = InterventionType::Custom(datasynth_core::CustomIntervention {
            name: "test".to_string(),
            config_overrides: HashMap::new(),
            downstream_triggers: vec![],
        });

        let intervention = Intervention {
            id: Uuid::new_v4(),
            intervention_type,
            timing: InterventionTiming {
                start_month: 1,
                duration_months: None,
                onset: OnsetType::Sudden,
                ramp_months: None,
            },
            label: None,
            priority: 0,
        };

        let result = engine.propagate(&validated(intervention, &[]), 6).unwrap();
        assert!(
            result.changes_by_month.is_empty(),
            "an intervention with no overrides must not produce changes"
        );
    }

    /// An extreme shock should still produce changes. The clamped values
    /// themselves are not asserted here.
    #[test]
    fn test_propagation_node_bounds_clamped() {
        let dag = make_simple_dag();
        let engine = CausalPropagationEngine::new(&dag);

        let intervention = make_intervention(recession(5.0, -0.20), 1, OnsetType::Sudden);
        let result = engine
            .propagate(&validated(intervention, &["gdp_growth"]), 3)
            .unwrap();
        assert!(!result.changes_by_month.is_empty());
    }

    /// Five-node DAG with the operational/control node ids used by process,
    /// regulatory and entity-event interventions.
    fn make_dag_with_operational_nodes() -> CausalDAG {
        let mut dag = CausalDAG {
            nodes: vec![
                CausalNode {
                    id: "processing_lag".to_string(),
                    label: "Processing Lag".to_string(),
                    category: NodeCategory::Operational,
                    baseline_value: 2.0,
                    bounds: Some((0.5, 10.0)),
                    interventionable: true,
                    config_bindings: vec!["temporal_patterns.processing_lags.base_mu".to_string()],
                },
                CausalNode {
                    id: "error_rate".to_string(),
                    label: "Error Rate".to_string(),
                    category: NodeCategory::Outcome,
                    baseline_value: 0.02,
                    bounds: Some((0.0, 0.30)),
                    interventionable: false,
                    config_bindings: vec!["anomaly_injection.base_rate".to_string()],
                },
                CausalNode {
                    id: "control_effectiveness".to_string(),
                    label: "Control Effectiveness".to_string(),
                    category: NodeCategory::Operational,
                    baseline_value: 0.85,
                    bounds: Some((0.0, 1.0)),
                    interventionable: true,
                    config_bindings: vec!["internal_controls.exception_rate".to_string()],
                },
                CausalNode {
                    id: "sod_compliance".to_string(),
                    label: "SoD Compliance".to_string(),
                    category: NodeCategory::Operational,
                    baseline_value: 0.90,
                    bounds: Some((0.0, 1.0)),
                    interventionable: true,
                    config_bindings: vec!["internal_controls.sod_violation_rate".to_string()],
                },
                CausalNode {
                    id: "misstatement_risk".to_string(),
                    label: "Misstatement Risk".to_string(),
                    category: NodeCategory::Outcome,
                    baseline_value: 0.05,
                    bounds: Some((0.0, 1.0)),
                    interventionable: false,
                    config_bindings: vec!["fraud.fraud_rate".to_string()],
                },
            ],
            edges: vec![CausalEdge {
                from: "processing_lag".to_string(),
                to: "error_rate".to_string(),
                transfer: TransferFunction::Linear {
                    coefficient: 0.01,
                    intercept: 0.0,
                },
                lag_months: 0,
                strength: 1.0,
                mechanism: Some("Lag increases errors".to_string()),
            }],
            topological_order: vec![],
        };
        dag.validate().expect("DAG should be valid");
        dag
    }

    #[test]
    fn test_propagation_process_change_automation() {
        let dag = make_dag_with_operational_nodes();
        let engine = CausalPropagationEngine::new(&dag);

        let intervention = make_intervention(
            InterventionType::ProcessChange(datasynth_core::ProcessChangeIntervention {
                subtype: datasynth_core::ProcessChangeType::ProcessAutomation,
                parameters: HashMap::new(),
            }),
            1,
            OnsetType::Sudden,
        );

        let result = engine.propagate(&validated(intervention, &[]), 3).unwrap();
        assert!(!result.changes_by_month.is_empty());

        if let Some(changes) = result.changes_by_month.get(&1) {
            let lag_change = changes.iter().find(|c| c.source_node == "processing_lag");
            assert!(lag_change.is_some(), "Should have processing_lag change");
        }
    }

    #[test]
    fn test_propagation_regulatory_change() {
        let dag = make_dag_with_operational_nodes();
        let engine = CausalPropagationEngine::new(&dag);

        let mut params = HashMap::new();
        params.insert("severity".to_string(), serde_json::json!(0.8));

        let intervention = make_intervention(
            InterventionType::RegulatoryChange(datasynth_core::RegulatoryChangeIntervention {
                subtype: datasynth_core::RegulatoryChangeType::NewStandardAdoption,
                parameters: params,
            }),
            1,
            OnsetType::Sudden,
        );

        let result = engine.propagate(&validated(intervention, &[]), 3).unwrap();
        assert!(!result.changes_by_month.is_empty());
    }

    #[test]
    fn test_propagation_entity_event_employee_departure() {
        let dag = make_dag_with_operational_nodes();
        let engine = CausalPropagationEngine::new(&dag);

        let intervention = make_intervention(
            InterventionType::EntityEvent(datasynth_core::EntityEventIntervention {
                subtype: datasynth_core::InterventionEntityEvent::EmployeeDeparture,
                target: datasynth_core::EntityTarget {
                    cluster: None,
                    entity_ids: None,
                    filter: None,
                    count: Some(3),
                    fraction: None,
                },
                parameters: HashMap::new(),
            }),
            1,
            OnsetType::Sudden,
        );

        let result = engine.propagate(&validated(intervention, &[]), 2).unwrap();
        assert!(!result.changes_by_month.is_empty());
    }

    #[test]
    fn test_propagation_process_change_system_migration() {
        let dag = make_dag_with_operational_nodes();
        let engine = CausalPropagationEngine::new(&dag);

        let intervention = make_intervention(
            InterventionType::ProcessChange(datasynth_core::ProcessChangeIntervention {
                subtype: datasynth_core::ProcessChangeType::SystemMigration,
                parameters: HashMap::new(),
            }),
            1,
            OnsetType::Sudden,
        );

        let result = engine.propagate(&validated(intervention, &[]), 2).unwrap();
        assert!(!result.changes_by_month.is_empty());

        if let Some(changes) = result.changes_by_month.get(&1) {
            let lag_change = changes.iter().find(|c| c.source_node == "processing_lag");
            assert!(lag_change.is_some(), "Should have processing_lag change");
        }
    }
}
937}