Skip to main content

datasynth_core/causal/
intervention.rs

//! Intervention engine for causal inference using do-calculus.
//!
//! Computes average treatment effects (ATEs) by comparing samples drawn from the
//! unmodified model with samples drawn under the intervention.
4
5use std::collections::HashMap;
6
7use crate::error::SynthError;
8
9use super::scm::StructuralCausalModel;
10
/// Result of an intervention experiment.
#[derive(Debug, Clone, PartialEq)]
pub struct InterventionResult {
    /// Samples generated without any intervention.
    pub baseline_samples: Vec<HashMap<String, f64>>,
    /// Samples generated under the do-calculus intervention.
    pub intervened_samples: Vec<HashMap<String, f64>>,
    /// Estimated causal effect for each variable, keyed by variable name.
    pub effect_estimates: HashMap<String, EffectEstimate>,
}

/// Estimated causal effect of an intervention on a single variable.
#[derive(Debug, Clone, PartialEq)]
pub struct EffectEstimate {
    /// Average Treatment Effect: mean(intervened) - mean(baseline).
    pub average_treatment_effect: f64,
    /// Interval `(lower, upper)` reported around the ATE; see
    /// `InterventionEngine::compute_effect_estimate` for how it is derived.
    pub confidence_interval: (f64, f64),
    /// The smaller of the baseline and intervened sample counts for this
    /// variable (0 when either group has no observations of it).
    pub sample_size: usize,
}
32
33/// Engine for running causal interventions and estimating treatment effects.
34pub struct InterventionEngine {
35    scm: StructuralCausalModel,
36}
37
38impl InterventionEngine {
39    /// Create a new intervention engine wrapping a structural causal model.
40    pub fn new(scm: StructuralCausalModel) -> Self {
41        Self { scm }
42    }
43
44    /// Run a do-calculus intervention and estimate causal effects.
45    ///
46    /// Generates baseline samples (no intervention) and intervened samples,
47    /// then computes the average treatment effect for each variable.
48    pub fn do_intervention(
49        &self,
50        interventions: &[(String, f64)],
51        n_samples: usize,
52        seed: u64,
53    ) -> Result<InterventionResult, SynthError> {
54        if interventions.is_empty() {
55            return Err(SynthError::validation(
56                "At least one intervention must be specified",
57            ));
58        }
59
60        // Validate all intervention variables exist
61        for (var_name, _) in interventions {
62            if self.scm.graph().get_variable(var_name).is_none() {
63                return Err(SynthError::generation(format!(
64                    "Intervention variable '{var_name}' not found in causal graph"
65                )));
66            }
67        }
68
69        // Generate baseline samples (no intervention)
70        let baseline_samples = self
71            .scm
72            .generate(n_samples, seed)
73            .map_err(SynthError::generation)?;
74
75        // Generate intervened samples using do-calculus
76        // Use a different seed offset so baseline and intervened don't share the same RNG state
77        let intervened_seed = seed.wrapping_add(1_000_000);
78        let intervened_samples = self
79            .generate_with_interventions(interventions, n_samples, intervened_seed)
80            .map_err(SynthError::generation)?;
81
82        // Compute effect estimates for each variable
83        let var_names = self.scm.graph().variable_names();
84        let mut effect_estimates = HashMap::new();
85
86        for var_name in &var_names {
87            let name = var_name.to_string();
88            let estimate =
89                Self::compute_effect_estimate(&baseline_samples, &intervened_samples, &name);
90            effect_estimates.insert(name, estimate);
91        }
92
93        Ok(InterventionResult {
94            baseline_samples,
95            intervened_samples,
96            effect_estimates,
97        })
98    }
99
100    /// Generate samples with multiple interventions applied.
101    fn generate_with_interventions(
102        &self,
103        interventions: &[(String, f64)],
104        n_samples: usize,
105        seed: u64,
106    ) -> Result<Vec<HashMap<String, f64>>, String> {
107        if interventions.is_empty() {
108            return self.scm.generate(n_samples, seed);
109        }
110
111        // Build the intervened SCM by chaining interventions
112        let first = &interventions[0];
113        let mut intervened = self.scm.intervene(&first.0, first.1)?;
114        for (var_name, value) in interventions.iter().skip(1) {
115            intervened = intervened.and_intervene(var_name, *value);
116        }
117        intervened.generate(n_samples, seed)
118    }
119
120    /// Compute the effect estimate for a single variable.
121    fn compute_effect_estimate(
122        baseline: &[HashMap<String, f64>],
123        intervened: &[HashMap<String, f64>],
124        variable: &str,
125    ) -> EffectEstimate {
126        let baseline_vals: Vec<f64> = baseline
127            .iter()
128            .filter_map(|s| s.get(variable).copied())
129            .collect();
130        let intervened_vals: Vec<f64> = intervened
131            .iter()
132            .filter_map(|s| s.get(variable).copied())
133            .collect();
134
135        let n = baseline_vals.len().min(intervened_vals.len());
136        if n == 0 {
137            return EffectEstimate {
138                average_treatment_effect: 0.0,
139                confidence_interval: (0.0, 0.0),
140                sample_size: 0,
141            };
142        }
143
144        let baseline_mean: f64 = baseline_vals.iter().sum::<f64>() / baseline_vals.len() as f64;
145        let intervened_mean: f64 =
146            intervened_vals.iter().sum::<f64>() / intervened_vals.len() as f64;
147        let ate = intervened_mean - baseline_mean;
148
149        // Compute percentile-based confidence interval using individual diffs
150        let mut diffs: Vec<f64> = baseline_vals
151            .iter()
152            .zip(intervened_vals.iter())
153            .map(|(b, i)| i - b)
154            .collect();
155        diffs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
156
157        let ci = if diffs.len() >= 2 {
158            let lower_idx = (diffs.len() as f64 * 0.025).floor() as usize;
159            let upper_idx = ((diffs.len() as f64 * 0.975).ceil() as usize).min(diffs.len() - 1);
160            (diffs[lower_idx], diffs[upper_idx])
161        } else {
162            (ate, ate)
163        };
164
165        EffectEstimate {
166            average_treatment_effect: ate,
167            confidence_interval: ci,
168            sample_size: n,
169        }
170    }
171}
172
#[cfg(test)]
#[allow(clippy::unwrap_used)] // a panic is the intended failure mode inside tests
mod tests {
    use super::*;
    use crate::causal::graph::CausalGraph;

    /// Build an engine over the fraud-detection template graph.
    fn build_engine() -> InterventionEngine {
        let scm = StructuralCausalModel::new(CausalGraph::fraud_detection_template()).unwrap();
        InterventionEngine::new(scm)
    }

    /// Run a single-variable intervention on a freshly built engine.
    fn intervene_on(var: &str, value: f64, n_samples: usize, seed: u64) -> InterventionResult {
        build_engine()
            .do_intervention(&[(var.to_string(), value)], n_samples, seed)
            .unwrap()
    }

    /// Fetch the estimate for `var`, panicking with a descriptive message if absent.
    fn estimate_for<'a>(result: &'a InterventionResult, var: &str) -> &'a EffectEstimate {
        result
            .effect_estimates
            .get(var)
            .unwrap_or_else(|| panic!("{var} estimate missing"))
    }

    #[test]
    fn test_causal_intervention_positive_ate() {
        // fraud_probability depends on transaction_amount via a positive linear
        // mechanism (coefficient 0.3), so forcing the amount to 50000 adds about
        // 0.3 * 50000 = 15000, far above typical baseline values.
        let result = intervene_on("transaction_amount", 50000.0, 500, 42);
        let fp = estimate_for(&result, "fraud_probability");

        assert!(
            fp.average_treatment_effect > 0.0,
            "ATE for fraud_probability should be positive, got {}",
            fp.average_treatment_effect
        );
        assert_eq!(fp.sample_size, 500);
    }

    #[test]
    fn test_causal_intervention_zero_ate_for_unconnected() {
        // transaction_amount is a root variable with no incoming edge from
        // fraud_probability, so fixing fraud_probability cannot move it. Any
        // nonzero ATE comes from the differing seeds, hence the wide tolerance.
        let result = intervene_on("fraud_probability", 0.99, 500, 42);
        let amt = estimate_for(&result, "transaction_amount");

        assert!(
            amt.average_treatment_effect.abs() < 500.0,
            "ATE for unconnected variable should be near zero, got {}",
            amt.average_treatment_effect
        );
    }

    #[test]
    fn test_causal_intervention_multiple_interventions() {
        let interventions = [
            ("transaction_amount".to_string(), 10000.0),
            ("merchant_risk".to_string(), 0.9),
        ];
        let result = build_engine()
            .do_intervention(&interventions, 200, 99)
            .unwrap();

        // Every intervened sample must carry both clamped values.
        for sample in &result.intervened_samples {
            let value_of = |key: &str| sample.get(key).copied().unwrap_or(0.0);
            assert!(
                (value_of("transaction_amount") - 10000.0).abs() < 1e-10,
                "transaction_amount should be fixed at 10000.0"
            );
            assert!(
                (value_of("merchant_risk") - 0.9).abs() < 1e-10,
                "merchant_risk should be fixed at 0.9"
            );
        }

        assert_eq!(result.baseline_samples.len(), 200);
        assert_eq!(result.intervened_samples.len(), 200);
    }

    #[test]
    fn test_causal_intervention_empty_returns_error() {
        assert!(build_engine().do_intervention(&[], 100, 42).is_err());
    }

    #[test]
    fn test_causal_intervention_unknown_variable_returns_error() {
        let unknown = [("nonexistent_var".to_string(), 1.0)];
        assert!(build_engine().do_intervention(&unknown, 100, 42).is_err());
    }

    #[test]
    fn test_causal_intervention_confidence_interval() {
        let result = intervene_on("transaction_amount", 50000.0, 500, 42);
        let fp = estimate_for(&result, "fraud_probability");
        let (lower, upper) = fp.confidence_interval;

        assert!(
            lower <= fp.average_treatment_effect,
            "CI lower ({}) should be <= ATE ({})",
            lower,
            fp.average_treatment_effect
        );
        // On the upper side only the ordering of the two bounds is checked.
        assert!(
            upper >= lower,
            "CI upper ({}) should be >= CI lower ({})",
            upper,
            lower
        );
    }
}