Skip to main content

datasynth_core/causal/
intervention.rs

1//! Intervention engine for causal inference using do-calculus.
2//!
3//! Computes average treatment effects by comparing baseline and intervened samples.
4
5use std::collections::HashMap;
6
7use crate::error::SynthError;
8
9use super::scm::StructuralCausalModel;
10
11/// Result of an intervention experiment.
12#[derive(Debug, Clone)]
13pub struct InterventionResult {
14    /// Samples generated without any intervention.
15    pub baseline_samples: Vec<HashMap<String, f64>>,
16    /// Samples generated under the do-calculus intervention.
17    pub intervened_samples: Vec<HashMap<String, f64>>,
18    /// Estimated causal effects for each variable.
19    pub effect_estimates: HashMap<String, EffectEstimate>,
20}
21
/// Summary of how an intervention shifted one variable.
#[derive(Debug, Clone)]
pub struct EffectEstimate {
    /// Average Treatment Effect, i.e. `mean(intervened) - mean(baseline)`
    /// for this variable.
    pub average_treatment_effect: f64,
    /// `(lower, upper)` bounds read from the sorted differences between
    /// index-matched intervened and baseline values, at roughly the 2.5th and
    /// 97.5th percentile positions. Collapses to `(ate, ate)` when fewer than
    /// two pairs exist and to `(0.0, 0.0)` when there are no samples.
    pub confidence_interval: (f64, f64),
    /// Number of samples behind the estimate: the smaller of the baseline and
    /// intervened counts that contained the variable.
    pub sample_size: usize,
}
32
33/// Engine for running causal interventions and estimating treatment effects.
34pub struct InterventionEngine {
35    scm: StructuralCausalModel,
36}
37
38impl InterventionEngine {
39    /// Create a new intervention engine wrapping a structural causal model.
40    pub fn new(scm: StructuralCausalModel) -> Self {
41        Self { scm }
42    }
43
44    /// Run a do-calculus intervention and estimate causal effects.
45    ///
46    /// Generates baseline samples (no intervention) and intervened samples,
47    /// then computes the average treatment effect for each variable.
48    pub fn do_intervention(
49        &self,
50        interventions: &[(String, f64)],
51        n_samples: usize,
52        seed: u64,
53    ) -> Result<InterventionResult, SynthError> {
54        if interventions.is_empty() {
55            return Err(SynthError::validation(
56                "At least one intervention must be specified",
57            ));
58        }
59
60        // Validate all intervention variables exist
61        for (var_name, _) in interventions {
62            if self.scm.graph().get_variable(var_name).is_none() {
63                return Err(SynthError::generation(format!(
64                    "Intervention variable '{var_name}' not found in causal graph"
65                )));
66            }
67        }
68
69        // Generate baseline samples (no intervention)
70        let baseline_samples = self
71            .scm
72            .generate(n_samples, seed)
73            .map_err(SynthError::generation)?;
74
75        // Generate intervened samples using do-calculus
76        // Use a different seed offset so baseline and intervened don't share the same RNG state
77        let intervened_seed = seed.wrapping_add(1_000_000);
78        let intervened_samples = self
79            .generate_with_interventions(interventions, n_samples, intervened_seed)
80            .map_err(SynthError::generation)?;
81
82        // Compute effect estimates for each variable
83        let var_names = self.scm.graph().variable_names();
84        let mut effect_estimates = HashMap::new();
85
86        for var_name in &var_names {
87            let name = var_name.to_string();
88            let estimate =
89                Self::compute_effect_estimate(&baseline_samples, &intervened_samples, &name);
90            effect_estimates.insert(name, estimate);
91        }
92
93        Ok(InterventionResult {
94            baseline_samples,
95            intervened_samples,
96            effect_estimates,
97        })
98    }
99
100    /// Generate samples with multiple interventions applied.
101    fn generate_with_interventions(
102        &self,
103        interventions: &[(String, f64)],
104        n_samples: usize,
105        seed: u64,
106    ) -> Result<Vec<HashMap<String, f64>>, String> {
107        if interventions.is_empty() {
108            return self.scm.generate(n_samples, seed);
109        }
110
111        // Build the intervened SCM by chaining interventions
112        let first = &interventions[0];
113        let mut intervened = self.scm.intervene(&first.0, first.1)?;
114        for (var_name, value) in interventions.iter().skip(1) {
115            intervened = intervened.and_intervene(var_name, *value);
116        }
117        intervened.generate(n_samples, seed)
118    }
119
120    /// Compute the effect estimate for a single variable.
121    fn compute_effect_estimate(
122        baseline: &[HashMap<String, f64>],
123        intervened: &[HashMap<String, f64>],
124        variable: &str,
125    ) -> EffectEstimate {
126        let baseline_vals: Vec<f64> = baseline
127            .iter()
128            .filter_map(|s| s.get(variable).copied())
129            .collect();
130        let intervened_vals: Vec<f64> = intervened
131            .iter()
132            .filter_map(|s| s.get(variable).copied())
133            .collect();
134
135        let n = baseline_vals.len().min(intervened_vals.len());
136        if n == 0 {
137            return EffectEstimate {
138                average_treatment_effect: 0.0,
139                confidence_interval: (0.0, 0.0),
140                sample_size: 0,
141            };
142        }
143
144        let baseline_mean: f64 = baseline_vals.iter().sum::<f64>() / baseline_vals.len() as f64;
145        let intervened_mean: f64 =
146            intervened_vals.iter().sum::<f64>() / intervened_vals.len() as f64;
147        let ate = intervened_mean - baseline_mean;
148
149        // Compute percentile-based confidence interval using individual diffs
150        let mut diffs: Vec<f64> = baseline_vals
151            .iter()
152            .zip(intervened_vals.iter())
153            .map(|(b, i)| i - b)
154            .collect();
155        diffs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
156
157        let ci = if diffs.len() >= 2 {
158            let lower_idx = (diffs.len() as f64 * 0.025).floor() as usize;
159            let upper_idx = ((diffs.len() as f64 * 0.975).ceil() as usize).min(diffs.len() - 1);
160            (diffs[lower_idx], diffs[upper_idx])
161        } else {
162            (ate, ate)
163        };
164
165        EffectEstimate {
166            average_treatment_effect: ate,
167            confidence_interval: ci,
168            sample_size: n,
169        }
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use crate::causal::graph::CausalGraph;

    /// Builds an engine over the fraud-detection template graph.
    fn build_engine() -> InterventionEngine {
        let scm = StructuralCausalModel::new(CausalGraph::fraud_detection_template()).unwrap();
        InterventionEngine::new(scm)
    }

    #[test]
    fn test_causal_intervention_positive_ate() {
        // The template links transaction_amount to fraud_probability through a
        // positive linear mechanism (coefficient 0.3), so forcing a very large
        // amount (0.3 * 50000 = 15000) must push fraud_probability upwards.
        let interventions = [("transaction_amount".to_string(), 50000.0)];
        let outcome = build_engine()
            .do_intervention(&interventions, 500, 42)
            .unwrap();

        let fraud = outcome
            .effect_estimates
            .get("fraud_probability")
            .expect("fraud_probability estimate missing");

        let ate = fraud.average_treatment_effect;
        assert!(
            ate > 0.0,
            "ATE for fraud_probability should be positive, got {}",
            ate
        );
        assert_eq!(fraud.sample_size, 500);
    }

    #[test]
    fn test_causal_intervention_zero_ate_for_unconnected() {
        // transaction_amount is a root variable with no incoming edge from
        // fraud_probability, so fixing fraud_probability must leave it alone.
        let interventions = [("fraud_probability".to_string(), 0.99)];
        let outcome = build_engine()
            .do_intervention(&interventions, 500, 42)
            .unwrap();

        let amount = outcome
            .effect_estimates
            .get("transaction_amount")
            .expect("transaction_amount estimate missing");

        // Baseline and intervened draws use different seeds, so the two means
        // differ by sampling noise only; hence the generous tolerance.
        let ate = amount.average_treatment_effect;
        assert!(
            ate.abs() < 500.0,
            "ATE for unconnected variable should be near zero, got {}",
            ate
        );
    }

    #[test]
    fn test_causal_intervention_multiple_interventions() {
        let interventions = [
            ("transaction_amount".to_string(), 10000.0),
            ("merchant_risk".to_string(), 0.9),
        ];
        let outcome = build_engine()
            .do_intervention(&interventions, 200, 99)
            .unwrap();

        // Every intervened sample must carry both fixed values.
        for row in &outcome.intervened_samples {
            let amount = row.get("transaction_amount").map_or(0.0, |v| *v);
            let risk = row.get("merchant_risk").map_or(0.0, |v| *v);
            assert!(
                (amount - 10000.0).abs() < 1e-10,
                "transaction_amount should be fixed at 10000.0"
            );
            assert!(
                (risk - 0.9).abs() < 1e-10,
                "merchant_risk should be fixed at 0.9"
            );
        }

        assert_eq!(outcome.baseline_samples.len(), 200);
        assert_eq!(outcome.intervened_samples.len(), 200);
    }

    #[test]
    fn test_causal_intervention_empty_returns_error() {
        // An empty intervention list is rejected outright.
        assert!(build_engine().do_intervention(&[], 100, 42).is_err());
    }

    #[test]
    fn test_causal_intervention_unknown_variable_returns_error() {
        // A variable that is not part of the graph is rejected.
        let interventions = [("nonexistent_var".to_string(), 1.0)];
        assert!(build_engine()
            .do_intervention(&interventions, 100, 42)
            .is_err());
    }

    #[test]
    fn test_causal_intervention_confidence_interval() {
        let interventions = [("transaction_amount".to_string(), 50000.0)];
        let outcome = build_engine()
            .do_intervention(&interventions, 500, 42)
            .unwrap();

        let fraud = outcome
            .effect_estimates
            .get("fraud_probability")
            .expect("fraud_probability estimate missing");
        let (lower, upper) = fraud.confidence_interval;
        let ate = fraud.average_treatment_effect;

        // The lower bound must not exceed the ATE.
        assert!(lower <= ate, "CI lower ({}) should be <= ATE ({})", lower, ate);
        // The bounds are percentiles of individual differences while the ATE
        // is a difference of means, so `ATE <= upper` is not guaranteed; only
        // check that the interval is well-ordered.
        assert!(
            upper >= lower,
            "CI upper ({}) should be >= CI lower ({})",
            upper,
            lower
        );
    }
}
309}