1use datasynth_core::causal_dag::{CausalDAG, CausalDAGError};
7use datasynth_core::{Intervention, InterventionTiming, InterventionType, OnsetType};
8use serde::{Deserialize, Serialize};
9use std::collections::{BTreeMap, HashMap};
10use thiserror::Error;
11
/// An intervention paired with the config paths it was determined to affect.
///
/// NOTE(review): `CausalPropagationEngine` in this file reads only `intervention`;
/// `affected_config_paths` is presumably filled by an upstream validation step — confirm.
#[derive(Debug, Clone)]
pub struct ValidatedIntervention {
    /// The intervention to apply (its type and timing drive propagation).
    pub intervention: Intervention,
    /// Config paths affected by the intervention; not consulted by the propagation engine.
    pub affected_config_paths: Vec<String>,
}
18
19#[derive(Debug, Clone, Serialize, Deserialize, Default)]
21pub struct PropagatedInterventions {
22 pub changes_by_month: BTreeMap<u32, Vec<ConfigChange>>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct ConfigChange {
28 pub path: String,
30 pub value: serde_json::Value,
32 pub source_node: String,
34 pub is_direct: bool,
36}
37
/// Errors that can occur while propagating interventions through the causal DAG.
///
/// NOTE(review): `CausalPropagationEngine::propagate` in this file never constructs either
/// variant; they appear reserved for DAG validation and target-mapping failures — confirm.
#[derive(Debug, Error)]
pub enum PropagationError {
    /// The causal DAG failed validation; wraps the underlying [`CausalDAGError`].
    #[error("DAG validation failed: {0}")]
    DagValidation(#[from] CausalDAGError),
    /// An intervention target could not be mapped to a causal node; holds the target.
    #[error("no causal node mapping for intervention target: {0}")]
    NoNodeMapping(String),
}
46
/// Maps interventions onto causal DAG nodes and propagates their effects month by month,
/// turning changed node values into config overrides.
pub struct CausalPropagationEngine<'a> {
    /// Causal graph whose nodes, baselines, config bindings and edges drive propagation.
    dag: &'a CausalDAG,
}
51
52impl<'a> CausalPropagationEngine<'a> {
53 pub fn new(dag: &'a CausalDAG) -> Self {
54 Self { dag }
55 }
56
57 pub fn propagate(
59 &self,
60 interventions: &[ValidatedIntervention],
61 period_months: u32,
62 ) -> Result<PropagatedInterventions, PropagationError> {
63 let mut result = PropagatedInterventions::default();
64
65 for month in 1..=period_months {
66 let direct = self.compute_direct_effects(interventions, month);
68
69 if direct.is_empty() {
70 continue;
71 }
72
73 let propagated_values = self.dag.propagate(&direct, month);
75
76 let mut changes = Vec::new();
78 for (node_id, value) in &propagated_values {
79 if let Some(node) = self.dag.find_node(node_id) {
80 if (value - node.baseline_value).abs() < f64::EPSILON {
82 continue;
83 }
84
85 let is_direct = direct.contains_key(node_id);
86 for binding in &node.config_bindings {
87 changes.push(ConfigChange {
88 path: binding.clone(),
89 value: serde_json::Value::from(*value),
90 source_node: node_id.clone(),
91 is_direct,
92 });
93 }
94 }
95 }
96
97 if !changes.is_empty() {
98 result.changes_by_month.insert(month, changes);
99 }
100 }
101
102 Ok(result)
103 }
104
105 fn compute_direct_effects(
107 &self,
108 interventions: &[ValidatedIntervention],
109 month: u32,
110 ) -> HashMap<String, f64> {
111 let mut effects = HashMap::new();
112
113 for validated in interventions {
114 let timing = &validated.intervention.timing;
115
116 if !Self::is_active(timing, month) {
118 continue;
119 }
120
121 let onset_factor = Self::compute_onset_factor(timing, month);
123
124 self.map_intervention_to_nodes(
126 &validated.intervention.intervention_type,
127 onset_factor,
128 &mut effects,
129 );
130 }
131
132 effects
133 }
134
135 fn is_active(timing: &InterventionTiming, month: u32) -> bool {
137 if month < timing.start_month {
138 return false;
139 }
140 if let Some(duration) = timing.duration_months {
141 if month >= timing.start_month + duration {
142 return false;
143 }
144 }
145 true
146 }
147
148 fn compute_onset_factor(timing: &InterventionTiming, month: u32) -> f64 {
150 let months_active = month - timing.start_month;
151
152 match &timing.onset {
153 OnsetType::Sudden => 1.0,
154 OnsetType::Gradual => {
155 let ramp = timing.ramp_months.unwrap_or(1).max(1);
156 if months_active >= ramp {
157 1.0
158 } else {
159 months_active as f64 / ramp as f64
160 }
161 }
162 OnsetType::Oscillating => {
163 let ramp = timing.ramp_months.unwrap_or(4).max(1) as f64;
164 let phase = months_active as f64 / ramp;
165 0.5 * (1.0 - (std::f64::consts::PI * phase).cos())
167 }
168 OnsetType::Custom { .. } => {
169 let ramp = timing.ramp_months.unwrap_or(1).max(1);
171 if months_active >= ramp {
172 1.0
173 } else {
174 months_active as f64 / ramp as f64
175 }
176 }
177 }
178 }
179
180 fn map_intervention_to_nodes(
182 &self,
183 intervention_type: &InterventionType,
184 onset_factor: f64,
185 effects: &mut HashMap<String, f64>,
186 ) {
187 match intervention_type {
188 InterventionType::ParameterShift(ps) => {
189 for node in &self.dag.nodes {
191 if node.config_bindings.contains(&ps.target) {
192 if let Some(to_val) = ps.to.as_f64() {
193 let from_val = ps
194 .from
195 .as_ref()
196 .and_then(serde_json::Value::as_f64)
197 .unwrap_or(node.baseline_value);
198 let interpolated = from_val + (to_val - from_val) * onset_factor;
199 effects.insert(node.id.clone(), interpolated);
200 }
201 }
202 }
203 }
204 InterventionType::MacroShock(ms) => {
205 use datasynth_core::MacroShockType;
207 let severity = ms.severity * onset_factor;
208 match ms.subtype {
209 MacroShockType::Recession => {
210 if let Some(node) = self.dag.find_node("gdp_growth") {
211 let shock = ms.overrides.get("gdp_growth").copied().unwrap_or(-0.02);
212 effects.insert(
213 "gdp_growth".to_string(),
214 node.baseline_value + shock * severity,
215 );
216 }
217 if let Some(node) = self.dag.find_node("unemployment_rate") {
218 let shock = ms
219 .overrides
220 .get("unemployment_rate")
221 .copied()
222 .unwrap_or(0.03);
223 effects.insert(
224 "unemployment_rate".to_string(),
225 node.baseline_value + shock * severity,
226 );
227 }
228 }
229 MacroShockType::InflationSpike => {
230 if let Some(node) = self.dag.find_node("inflation_rate") {
231 let shock = ms.overrides.get("inflation_rate").copied().unwrap_or(0.05);
232 effects.insert(
233 "inflation_rate".to_string(),
234 node.baseline_value + shock * severity,
235 );
236 }
237 }
238 MacroShockType::InterestRateShock => {
239 if let Some(node) = self.dag.find_node("interest_rate") {
240 let shock = ms.overrides.get("interest_rate").copied().unwrap_or(0.03);
241 effects.insert(
242 "interest_rate".to_string(),
243 node.baseline_value + shock * severity,
244 );
245 }
246 }
247 MacroShockType::CreditCrunch => {
248 if let Some(node) = self.dag.find_node("gdp_growth") {
249 effects.insert(
250 "gdp_growth".to_string(),
251 node.baseline_value * (1.0 - 0.1 * severity),
252 );
253 }
254 if let Some(node) = self.dag.find_node("ecl_provision_rate") {
255 effects.insert(
256 "ecl_provision_rate".to_string(),
257 node.baseline_value + severity * 0.5,
258 );
259 }
260 if let Some(node) = self.dag.find_node("going_concern_risk") {
261 effects.insert(
262 "going_concern_risk".to_string(),
263 node.baseline_value + severity * 0.3,
264 );
265 }
266 if let Some(node) = self.dag.find_node("debt_ratio") {
267 effects.insert(
268 "debt_ratio".to_string(),
269 node.baseline_value + severity * 0.4,
270 );
271 }
272 }
273 _ => {
274 if let Some(node) = self.dag.find_node("gdp_growth") {
276 effects.insert(
277 "gdp_growth".to_string(),
278 node.baseline_value * (1.0 - 0.1 * severity),
279 );
280 }
281 }
282 }
283 }
284 InterventionType::ControlFailure(cf) => {
285 if let Some(node) = self.dag.find_node("control_effectiveness") {
286 let new_effectiveness = node.baseline_value * cf.severity * onset_factor
287 + node.baseline_value * (1.0 - onset_factor);
288 effects.insert("control_effectiveness".to_string(), new_effectiveness);
289 }
290 }
291 InterventionType::EntityEvent(ee) => {
292 use datasynth_core::InterventionEntityEvent;
293 let rate_increase = ee
294 .parameters
295 .get("rate_increase")
296 .and_then(serde_json::Value::as_f64)
297 .unwrap_or(0.05);
298 match ee.subtype {
299 InterventionEntityEvent::VendorDefault => {
300 if let Some(node) = self.dag.find_node("vendor_default_rate") {
301 effects.insert(
302 "vendor_default_rate".to_string(),
303 node.baseline_value + rate_increase * onset_factor,
304 );
305 }
306 }
307 InterventionEntityEvent::CustomerChurn => {
308 if let Some(node) = self.dag.find_node("customer_churn_rate") {
309 effects.insert(
310 "customer_churn_rate".to_string(),
311 node.baseline_value + rate_increase * onset_factor,
312 );
313 }
314 }
315 InterventionEntityEvent::EmployeeDeparture
316 | InterventionEntityEvent::KeyPersonRisk => {
317 if let Some(node) = self.dag.find_node("processing_lag") {
319 effects.insert(
320 "processing_lag".to_string(),
321 node.baseline_value * (1.0 + 0.2 * onset_factor),
322 );
323 }
324 if let Some(node) = self.dag.find_node("error_rate") {
325 effects.insert(
326 "error_rate".to_string(),
327 node.baseline_value * (1.0 + 0.15 * onset_factor),
328 );
329 }
330 }
331 InterventionEntityEvent::NewVendorOnboarding => {
332 if let Some(node) = self.dag.find_node("transaction_volume") {
334 effects.insert(
335 "transaction_volume".to_string(),
336 node.baseline_value * (1.0 + 0.1 * onset_factor),
337 );
338 }
339 }
340 InterventionEntityEvent::MergerAcquisition => {
341 if let Some(node) = self.dag.find_node("transaction_volume") {
343 effects.insert(
344 "transaction_volume".to_string(),
345 node.baseline_value * (1.0 + 0.5 * onset_factor),
346 );
347 }
348 if let Some(node) = self.dag.find_node("error_rate") {
349 effects.insert(
350 "error_rate".to_string(),
351 node.baseline_value * (1.0 + 0.3 * onset_factor),
352 );
353 }
354 }
355 InterventionEntityEvent::VendorCollusion => {
356 if let Some(node) = self.dag.find_node("misstatement_risk") {
358 effects.insert(
359 "misstatement_risk".to_string(),
360 (node.baseline_value + 0.15 * onset_factor).min(1.0),
361 );
362 }
363 if let Some(node) = self.dag.find_node("control_effectiveness") {
364 effects.insert(
365 "control_effectiveness".to_string(),
366 node.baseline_value * (1.0 - 0.2 * onset_factor),
367 );
368 }
369 }
370 InterventionEntityEvent::CustomerConsolidation => {
371 if let Some(node) = self.dag.find_node("customer_churn_rate") {
373 effects.insert(
374 "customer_churn_rate".to_string(),
375 node.baseline_value + rate_increase * onset_factor,
376 );
377 }
378 }
379 }
380 }
381 InterventionType::Custom(ci) => {
382 for (path, value) in &ci.config_overrides {
384 for node in &self.dag.nodes {
385 if node.config_bindings.contains(path) {
386 if let Some(v) = value.as_f64() {
387 let interpolated =
388 node.baseline_value + (v - node.baseline_value) * onset_factor;
389 effects.insert(node.id.clone(), interpolated);
390 }
391 }
392 }
393 }
394 }
395 InterventionType::ProcessChange(pc) => {
396 use datasynth_core::ProcessChangeType;
397 match pc.subtype {
398 ProcessChangeType::ProcessAutomation => {
399 if let Some(node) = self.dag.find_node("processing_lag") {
401 effects.insert(
402 "processing_lag".to_string(),
403 node.baseline_value * (1.0 - 0.3 * onset_factor),
404 );
405 }
406 if let Some(node) = self.dag.find_node("error_rate") {
407 effects.insert(
408 "error_rate".to_string(),
409 node.baseline_value * (1.0 - 0.2 * onset_factor),
410 );
411 }
412 }
413 ProcessChangeType::ApprovalThresholdChange
414 | ProcessChangeType::NewApprovalLevel => {
415 if let Some(node) = self.dag.find_node("control_effectiveness") {
417 effects.insert(
418 "control_effectiveness".to_string(),
419 (node.baseline_value + 0.1 * onset_factor).min(1.0),
420 );
421 }
422 }
423 ProcessChangeType::PolicyChange => {
424 if let Some(node) = self.dag.find_node("sod_compliance") {
425 effects.insert(
426 "sod_compliance".to_string(),
427 (node.baseline_value + 0.05 * onset_factor).min(1.0),
428 );
429 }
430 }
431 ProcessChangeType::SystemMigration
432 | ProcessChangeType::OutsourcingTransition
433 | ProcessChangeType::ReorganizationRestructuring => {
434 if let Some(node) = self.dag.find_node("processing_lag") {
436 effects.insert(
437 "processing_lag".to_string(),
438 node.baseline_value * (1.0 + 0.15 * onset_factor),
439 );
440 }
441 if let Some(node) = self.dag.find_node("error_rate") {
442 effects.insert(
443 "error_rate".to_string(),
444 node.baseline_value * (1.0 + 0.1 * onset_factor),
445 );
446 }
447 }
448 }
449 }
450 InterventionType::RegulatoryChange(rc) => {
451 use datasynth_core::RegulatoryChangeType;
452 let severity = rc
453 .parameters
454 .get("severity")
455 .and_then(serde_json::Value::as_f64)
456 .unwrap_or(0.5);
457 let magnitude = severity * onset_factor;
458 match rc.subtype {
459 RegulatoryChangeType::MaterialityThresholdChange => {
460 if let Some(node) = self.dag.find_node("materiality_threshold") {
461 effects.insert(
462 "materiality_threshold".to_string(),
463 node.baseline_value + magnitude,
464 );
465 }
466 if let Some(node) = self.dag.find_node("sample_size_factor") {
467 effects.insert(
468 "sample_size_factor".to_string(),
469 node.baseline_value + magnitude * 0.5,
470 );
471 }
472 }
473 RegulatoryChangeType::AuditStandardChange => {
474 if let Some(node) = self.dag.find_node("inherent_risk") {
475 effects.insert(
476 "inherent_risk".to_string(),
477 node.baseline_value + magnitude * 0.3,
478 );
479 }
480 if let Some(node) = self.dag.find_node("sample_size_factor") {
481 effects.insert(
482 "sample_size_factor".to_string(),
483 node.baseline_value + magnitude * 0.4,
484 );
485 }
486 }
487 RegulatoryChangeType::TaxRateChange => {
488 if let Some(node) = self.dag.find_node("tax_rate") {
489 effects.insert("tax_rate".to_string(), node.baseline_value + magnitude);
490 }
491 }
492 _ => {
493 if let Some(node) = self.dag.find_node("sod_compliance") {
495 effects.insert(
496 "sod_compliance".to_string(),
497 (node.baseline_value + severity * 0.2 * onset_factor).min(1.0),
498 );
499 }
500 if let Some(node) = self.dag.find_node("control_effectiveness") {
501 effects.insert(
502 "control_effectiveness".to_string(),
503 (node.baseline_value + severity * 0.15 * onset_factor).min(1.0),
504 );
505 }
506 if let Some(node) = self.dag.find_node("misstatement_risk") {
507 effects.insert(
508 "misstatement_risk".to_string(),
509 node.baseline_value * (1.0 - severity * 0.1 * onset_factor),
510 );
511 }
512 }
513 }
514 }
515 InterventionType::Composite(comp) => {
516 for child in &comp.children {
517 self.map_intervention_to_nodes(child, onset_factor, effects);
518 }
519 }
520 }
521 }
522}
523
524#[cfg(test)]
525mod tests {
526 use super::*;
527 use datasynth_core::causal_dag::{CausalEdge, CausalNode, NodeCategory, TransferFunction};
528 use datasynth_core::{MacroShockIntervention, MacroShockType};
529 use uuid::Uuid;
530
531 fn make_simple_dag() -> CausalDAG {
532 let mut dag = CausalDAG {
533 nodes: vec![
534 CausalNode {
535 id: "gdp_growth".to_string(),
536 label: "GDP Growth".to_string(),
537 category: NodeCategory::Macro,
538 baseline_value: 0.025,
539 bounds: Some((-0.10, 0.15)),
540 interventionable: true,
541 config_bindings: vec![],
542 },
543 CausalNode {
544 id: "transaction_volume".to_string(),
545 label: "Transaction Volume".to_string(),
546 category: NodeCategory::Operational,
547 baseline_value: 1.0,
548 bounds: Some((0.2, 3.0)),
549 interventionable: true,
550 config_bindings: vec!["transactions.volume_multiplier".to_string()],
551 },
552 CausalNode {
553 id: "error_rate".to_string(),
554 label: "Error Rate".to_string(),
555 category: NodeCategory::Outcome,
556 baseline_value: 0.02,
557 bounds: Some((0.0, 0.30)),
558 interventionable: false,
559 config_bindings: vec!["anomaly_injection.base_rate".to_string()],
560 },
561 ],
562 edges: vec![
563 CausalEdge {
564 from: "gdp_growth".to_string(),
565 to: "transaction_volume".to_string(),
566 transfer: TransferFunction::Linear {
567 coefficient: 0.8,
568 intercept: 0.0,
569 },
570 lag_months: 0,
571 strength: 1.0,
572 mechanism: Some("GDP drives volume".to_string()),
573 },
574 CausalEdge {
575 from: "transaction_volume".to_string(),
576 to: "error_rate".to_string(),
577 transfer: TransferFunction::Linear {
578 coefficient: 0.01,
579 intercept: 0.0,
580 },
581 lag_months: 0,
582 strength: 1.0,
583 mechanism: Some("Volume increases errors".to_string()),
584 },
585 ],
586 topological_order: vec![],
587 };
588 dag.validate().expect("DAG should be valid");
589 dag
590 }
591
592 fn make_intervention(
593 intervention_type: InterventionType,
594 start_month: u32,
595 onset: OnsetType,
596 ) -> Intervention {
597 Intervention {
598 id: Uuid::new_v4(),
599 intervention_type,
600 timing: InterventionTiming {
601 start_month,
602 duration_months: None,
603 onset,
604 ramp_months: Some(3),
605 },
606 label: None,
607 priority: 0,
608 }
609 }
610
611 #[test]
612 fn test_propagation_no_interventions() {
613 let dag = make_simple_dag();
614 let engine = CausalPropagationEngine::new(&dag);
615 let result = engine.propagate(&[], 12).unwrap();
616 assert!(result.changes_by_month.is_empty());
617 }
618
619 #[test]
620 fn test_propagation_sudden_onset() {
621 let dag = make_simple_dag();
622 let engine = CausalPropagationEngine::new(&dag);
623
624 let intervention = make_intervention(
625 InterventionType::MacroShock(MacroShockIntervention {
626 subtype: MacroShockType::Recession,
627 severity: 1.0,
628 preset: None,
629 overrides: {
630 let mut m = HashMap::new();
631 m.insert("gdp_growth".to_string(), -0.02);
632 m
633 },
634 }),
635 3,
636 OnsetType::Sudden,
637 );
638
639 let validated = vec![ValidatedIntervention {
640 intervention,
641 affected_config_paths: vec!["gdp_growth".to_string()],
642 }];
643
644 let result = engine.propagate(&validated, 6).unwrap();
645 assert!(result.changes_by_month.contains_key(&3));
647 assert!(!result.changes_by_month.contains_key(&1));
649 assert!(!result.changes_by_month.contains_key(&2));
650 }
651
652 #[test]
653 fn test_propagation_gradual_onset() {
654 let dag = make_simple_dag();
655 let engine = CausalPropagationEngine::new(&dag);
656
657 let intervention = make_intervention(
658 InterventionType::MacroShock(MacroShockIntervention {
659 subtype: MacroShockType::Recession,
660 severity: 1.0,
661 preset: None,
662 overrides: {
663 let mut m = HashMap::new();
664 m.insert("gdp_growth".to_string(), -0.02);
665 m
666 },
667 }),
668 1,
669 OnsetType::Gradual,
670 );
671
672 let validated = vec![ValidatedIntervention {
673 intervention,
674 affected_config_paths: vec!["gdp_growth".to_string()],
675 }];
676
677 let result = engine.propagate(&validated, 6).unwrap();
678 assert!(result.changes_by_month.contains_key(&2));
682 assert!(result.changes_by_month.contains_key(&4));
683 }
684
685 #[test]
686 fn test_propagation_chain_through_dag() {
687 let dag = make_simple_dag();
688 let engine = CausalPropagationEngine::new(&dag);
689
690 let intervention = make_intervention(
691 InterventionType::MacroShock(MacroShockIntervention {
692 subtype: MacroShockType::Recession,
693 severity: 1.0,
694 preset: None,
695 overrides: {
696 let mut m = HashMap::new();
697 m.insert("gdp_growth".to_string(), -0.05);
698 m
699 },
700 }),
701 1,
702 OnsetType::Sudden,
703 );
704
705 let validated = vec![ValidatedIntervention {
706 intervention,
707 affected_config_paths: vec!["gdp_growth".to_string()],
708 }];
709
710 let result = engine.propagate(&validated, 3).unwrap();
711 if let Some(changes) = result.changes_by_month.get(&1) {
713 let paths: Vec<&str> = changes.iter().map(|c| c.path.as_str()).collect();
714 assert!(
715 paths.contains(&"transactions.volume_multiplier")
716 || paths.contains(&"anomaly_injection.base_rate")
717 );
718 }
719 }
720
721 #[test]
722 fn test_propagation_lag_respected() {
723 let mut dag = CausalDAG {
724 nodes: vec![
725 CausalNode {
726 id: "a".to_string(),
727 label: "A".to_string(),
728 category: NodeCategory::Macro,
729 baseline_value: 1.0,
730 bounds: None,
731 interventionable: true,
732 config_bindings: vec![],
733 },
734 CausalNode {
735 id: "b".to_string(),
736 label: "B".to_string(),
737 category: NodeCategory::Operational,
738 baseline_value: 0.0,
739 bounds: None,
740 interventionable: false,
741 config_bindings: vec!["test.path".to_string()],
742 },
743 ],
744 edges: vec![CausalEdge {
745 from: "a".to_string(),
746 to: "b".to_string(),
747 transfer: TransferFunction::Linear {
748 coefficient: 1.0,
749 intercept: 0.0,
750 },
751 lag_months: 3,
752 strength: 1.0,
753 mechanism: None,
754 }],
755 topological_order: vec![],
756 };
757 dag.validate().expect("DAG should be valid");
758
759 let engine = CausalPropagationEngine::new(&dag);
760
761 let intervention_type = InterventionType::Custom(datasynth_core::CustomIntervention {
762 name: "test".to_string(),
763 config_overrides: HashMap::new(),
764 downstream_triggers: vec![],
765 });
766
767 let intervention = Intervention {
769 id: Uuid::new_v4(),
770 intervention_type,
771 timing: InterventionTiming {
772 start_month: 1,
773 duration_months: None,
774 onset: OnsetType::Sudden,
775 ramp_months: None,
776 },
777 label: None,
778 priority: 0,
779 };
780
781 let validated = vec![ValidatedIntervention {
782 intervention,
783 affected_config_paths: vec![],
784 }];
785
786 let result = engine.propagate(&validated, 6).unwrap();
787 assert!(result.changes_by_month.is_empty() || !result.changes_by_month.is_empty());
790 }
791
792 #[test]
793 fn test_propagation_node_bounds_clamped() {
794 let dag = make_simple_dag();
795 let engine = CausalPropagationEngine::new(&dag);
796
797 let intervention = make_intervention(
798 InterventionType::MacroShock(MacroShockIntervention {
799 subtype: MacroShockType::Recession,
800 severity: 5.0, preset: None,
802 overrides: {
803 let mut m = HashMap::new();
804 m.insert("gdp_growth".to_string(), -0.20);
805 m
806 },
807 }),
808 1,
809 OnsetType::Sudden,
810 );
811
812 let validated = vec![ValidatedIntervention {
813 intervention,
814 affected_config_paths: vec!["gdp_growth".to_string()],
815 }];
816
817 let result = engine.propagate(&validated, 3).unwrap();
818 assert!(!result.changes_by_month.is_empty());
821 }
822
823 fn make_dag_with_operational_nodes() -> CausalDAG {
824 let mut dag = CausalDAG {
825 nodes: vec![
826 CausalNode {
827 id: "processing_lag".to_string(),
828 label: "Processing Lag".to_string(),
829 category: NodeCategory::Operational,
830 baseline_value: 2.0,
831 bounds: Some((0.5, 10.0)),
832 interventionable: true,
833 config_bindings: vec!["temporal_patterns.processing_lags.base_mu".to_string()],
834 },
835 CausalNode {
836 id: "error_rate".to_string(),
837 label: "Error Rate".to_string(),
838 category: NodeCategory::Outcome,
839 baseline_value: 0.02,
840 bounds: Some((0.0, 0.30)),
841 interventionable: false,
842 config_bindings: vec!["anomaly_injection.base_rate".to_string()],
843 },
844 CausalNode {
845 id: "control_effectiveness".to_string(),
846 label: "Control Effectiveness".to_string(),
847 category: NodeCategory::Operational,
848 baseline_value: 0.85,
849 bounds: Some((0.0, 1.0)),
850 interventionable: true,
851 config_bindings: vec!["internal_controls.exception_rate".to_string()],
852 },
853 CausalNode {
854 id: "sod_compliance".to_string(),
855 label: "SoD Compliance".to_string(),
856 category: NodeCategory::Operational,
857 baseline_value: 0.90,
858 bounds: Some((0.0, 1.0)),
859 interventionable: true,
860 config_bindings: vec!["internal_controls.sod_violation_rate".to_string()],
861 },
862 CausalNode {
863 id: "misstatement_risk".to_string(),
864 label: "Misstatement Risk".to_string(),
865 category: NodeCategory::Outcome,
866 baseline_value: 0.05,
867 bounds: Some((0.0, 1.0)),
868 interventionable: false,
869 config_bindings: vec!["fraud.fraud_rate".to_string()],
870 },
871 ],
872 edges: vec![CausalEdge {
873 from: "processing_lag".to_string(),
874 to: "error_rate".to_string(),
875 transfer: TransferFunction::Linear {
876 coefficient: 0.01,
877 intercept: 0.0,
878 },
879 lag_months: 0,
880 strength: 1.0,
881 mechanism: Some("Lag increases errors".to_string()),
882 }],
883 topological_order: vec![],
884 };
885 dag.validate().expect("DAG should be valid");
886 dag
887 }
888
889 #[test]
890 fn test_propagation_process_change_automation() {
891 let dag = make_dag_with_operational_nodes();
892 let engine = CausalPropagationEngine::new(&dag);
893
894 let intervention = make_intervention(
895 InterventionType::ProcessChange(datasynth_core::ProcessChangeIntervention {
896 subtype: datasynth_core::ProcessChangeType::ProcessAutomation,
897 parameters: HashMap::new(),
898 }),
899 1,
900 OnsetType::Sudden,
901 );
902
903 let validated = vec![ValidatedIntervention {
904 intervention,
905 affected_config_paths: vec![],
906 }];
907
908 let result = engine.propagate(&validated, 3).unwrap();
909 assert!(!result.changes_by_month.is_empty());
911 if let Some(changes) = result.changes_by_month.get(&1) {
912 let lag_change = changes.iter().find(|c| c.source_node == "processing_lag");
913 assert!(lag_change.is_some(), "Should have processing_lag change");
914 }
915 }
916
917 #[test]
918 fn test_propagation_regulatory_change() {
919 let dag = make_dag_with_operational_nodes();
920 let engine = CausalPropagationEngine::new(&dag);
921
922 let mut params = HashMap::new();
923 params.insert("severity".to_string(), serde_json::json!(0.8));
924
925 let intervention = make_intervention(
926 InterventionType::RegulatoryChange(datasynth_core::RegulatoryChangeIntervention {
927 subtype: datasynth_core::RegulatoryChangeType::NewStandardAdoption,
928 parameters: params,
929 }),
930 1,
931 OnsetType::Sudden,
932 );
933
934 let validated = vec![ValidatedIntervention {
935 intervention,
936 affected_config_paths: vec![],
937 }];
938
939 let result = engine.propagate(&validated, 3).unwrap();
940 assert!(!result.changes_by_month.is_empty());
942 }
943
944 #[test]
945 fn test_propagation_entity_event_employee_departure() {
946 let dag = make_dag_with_operational_nodes();
947 let engine = CausalPropagationEngine::new(&dag);
948
949 let intervention = make_intervention(
950 InterventionType::EntityEvent(datasynth_core::EntityEventIntervention {
951 subtype: datasynth_core::InterventionEntityEvent::EmployeeDeparture,
952 target: datasynth_core::EntityTarget {
953 cluster: None,
954 entity_ids: None,
955 filter: None,
956 count: Some(3),
957 fraction: None,
958 },
959 parameters: HashMap::new(),
960 }),
961 1,
962 OnsetType::Sudden,
963 );
964
965 let validated = vec![ValidatedIntervention {
966 intervention,
967 affected_config_paths: vec![],
968 }];
969
970 let result = engine.propagate(&validated, 2).unwrap();
971 assert!(!result.changes_by_month.is_empty());
973 }
974
975 #[test]
976 fn test_propagation_process_change_system_migration() {
977 let dag = make_dag_with_operational_nodes();
978 let engine = CausalPropagationEngine::new(&dag);
979
980 let intervention = make_intervention(
981 InterventionType::ProcessChange(datasynth_core::ProcessChangeIntervention {
982 subtype: datasynth_core::ProcessChangeType::SystemMigration,
983 parameters: HashMap::new(),
984 }),
985 1,
986 OnsetType::Sudden,
987 );
988
989 let validated = vec![ValidatedIntervention {
990 intervention,
991 affected_config_paths: vec![],
992 }];
993
994 let result = engine.propagate(&validated, 2).unwrap();
995 assert!(!result.changes_by_month.is_empty());
997 if let Some(changes) = result.changes_by_month.get(&1) {
998 let lag_change = changes.iter().find(|c| c.source_node == "processing_lag");
999 assert!(lag_change.is_some(), "Should have processing_lag change");
1000 }
1001 }
1002}