Skip to main content

datasynth_core/causal/
intervention.rs

1//! Intervention engine for causal inference using do-calculus.
2//!
3//! Computes average treatment effects by comparing baseline and intervened samples.
4
5use std::collections::HashMap;
6
7use crate::error::SynthError;
8
9use super::scm::StructuralCausalModel;
10
11/// Result of an intervention experiment.
12#[derive(Debug, Clone)]
13pub struct InterventionResult {
14    /// Samples generated without any intervention.
15    pub baseline_samples: Vec<HashMap<String, f64>>,
16    /// Samples generated under the do-calculus intervention.
17    pub intervened_samples: Vec<HashMap<String, f64>>,
18    /// Estimated causal effects for each variable.
19    pub effect_estimates: HashMap<String, EffectEstimate>,
20}
21
/// Estimated causal effect of an intervention on a single variable.
#[derive(Debug, Clone, PartialEq)]
pub struct EffectEstimate {
    /// Average Treatment Effect: mean(intervened) - mean(baseline).
    pub average_treatment_effect: f64,
    /// Two-sided 95% interval `(lower, upper)` around the ATE.
    pub confidence_interval: (f64, f64),
    /// Number of samples the estimate is based on: the smaller of the baseline
    /// and intervened counts of samples that actually contain the variable.
    pub sample_size: usize,
}
32
33/// Engine for running causal interventions and estimating treatment effects.
34pub struct InterventionEngine {
35    scm: StructuralCausalModel,
36}
37
38impl InterventionEngine {
39    /// Create a new intervention engine wrapping a structural causal model.
40    pub fn new(scm: StructuralCausalModel) -> Self {
41        Self { scm }
42    }
43
44    /// Run a do-calculus intervention and estimate causal effects.
45    ///
46    /// Generates baseline samples (no intervention) and intervened samples,
47    /// then computes the average treatment effect for each variable.
48    pub fn do_intervention(
49        &self,
50        interventions: &[(String, f64)],
51        n_samples: usize,
52        seed: u64,
53    ) -> Result<InterventionResult, SynthError> {
54        if interventions.is_empty() {
55            return Err(SynthError::validation(
56                "At least one intervention must be specified",
57            ));
58        }
59
60        // Validate all intervention variables exist
61        for (var_name, _) in interventions {
62            if self.scm.graph().get_variable(var_name).is_none() {
63                return Err(SynthError::generation(format!(
64                    "Intervention variable '{}' not found in causal graph",
65                    var_name
66                )));
67            }
68        }
69
70        // Generate baseline samples (no intervention)
71        let baseline_samples = self
72            .scm
73            .generate(n_samples, seed)
74            .map_err(SynthError::generation)?;
75
76        // Generate intervened samples using do-calculus
77        // Use a different seed offset so baseline and intervened don't share the same RNG state
78        let intervened_seed = seed.wrapping_add(1_000_000);
79        let intervened_samples = self
80            .generate_with_interventions(interventions, n_samples, intervened_seed)
81            .map_err(SynthError::generation)?;
82
83        // Compute effect estimates for each variable
84        let var_names = self.scm.graph().variable_names();
85        let mut effect_estimates = HashMap::new();
86
87        for var_name in &var_names {
88            let name = var_name.to_string();
89            let estimate =
90                Self::compute_effect_estimate(&baseline_samples, &intervened_samples, &name);
91            effect_estimates.insert(name, estimate);
92        }
93
94        Ok(InterventionResult {
95            baseline_samples,
96            intervened_samples,
97            effect_estimates,
98        })
99    }
100
101    /// Generate samples with multiple interventions applied.
102    fn generate_with_interventions(
103        &self,
104        interventions: &[(String, f64)],
105        n_samples: usize,
106        seed: u64,
107    ) -> Result<Vec<HashMap<String, f64>>, String> {
108        if interventions.is_empty() {
109            return self.scm.generate(n_samples, seed);
110        }
111
112        // Build the intervened SCM by chaining interventions
113        let first = &interventions[0];
114        let mut intervened = self.scm.intervene(&first.0, first.1)?;
115        for (var_name, value) in interventions.iter().skip(1) {
116            intervened = intervened.and_intervene(var_name, *value);
117        }
118        intervened.generate(n_samples, seed)
119    }
120
121    /// Compute the effect estimate for a single variable.
122    fn compute_effect_estimate(
123        baseline: &[HashMap<String, f64>],
124        intervened: &[HashMap<String, f64>],
125        variable: &str,
126    ) -> EffectEstimate {
127        let baseline_vals: Vec<f64> = baseline
128            .iter()
129            .filter_map(|s| s.get(variable).copied())
130            .collect();
131        let intervened_vals: Vec<f64> = intervened
132            .iter()
133            .filter_map(|s| s.get(variable).copied())
134            .collect();
135
136        let n = baseline_vals.len().min(intervened_vals.len());
137        if n == 0 {
138            return EffectEstimate {
139                average_treatment_effect: 0.0,
140                confidence_interval: (0.0, 0.0),
141                sample_size: 0,
142            };
143        }
144
145        let baseline_mean: f64 = baseline_vals.iter().sum::<f64>() / baseline_vals.len() as f64;
146        let intervened_mean: f64 =
147            intervened_vals.iter().sum::<f64>() / intervened_vals.len() as f64;
148        let ate = intervened_mean - baseline_mean;
149
150        // Compute percentile-based confidence interval using individual diffs
151        let mut diffs: Vec<f64> = baseline_vals
152            .iter()
153            .zip(intervened_vals.iter())
154            .map(|(b, i)| i - b)
155            .collect();
156        diffs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
157
158        let ci = if diffs.len() >= 2 {
159            let lower_idx = (diffs.len() as f64 * 0.025).floor() as usize;
160            let upper_idx = ((diffs.len() as f64 * 0.975).ceil() as usize).min(diffs.len() - 1);
161            (diffs[lower_idx], diffs[upper_idx])
162        } else {
163            (ate, ate)
164        };
165
166        EffectEstimate {
167            average_treatment_effect: ate,
168            confidence_interval: ci,
169            sample_size: n,
170        }
171    }
172}
173
#[cfg(test)]
// unwrap() is acceptable in tests: a failure here is simply a test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::causal::graph::CausalGraph;

    fn build_engine() -> InterventionEngine {
        let graph = CausalGraph::fraud_detection_template();
        let scm = StructuralCausalModel::new(graph).unwrap();
        InterventionEngine::new(scm)
    }

    /// Build one single-variable sample per value, all keyed by `var`.
    fn samples(var: &str, values: &[f64]) -> Vec<HashMap<String, f64>> {
        values
            .iter()
            .map(|&v| std::iter::once((var.to_string(), v)).collect())
            .collect()
    }

    #[test]
    fn test_causal_intervention_positive_ate() {
        // Increasing transaction_amount should increase fraud_probability
        // because the mechanism is Linear { coefficient: 0.3 } (positive).
        let engine = build_engine();
        let result = engine
            .do_intervention(&[("transaction_amount".to_string(), 50000.0)], 500, 42)
            .unwrap();

        let fp_estimate = result
            .effect_estimates
            .get("fraud_probability")
            .expect("fraud_probability estimate missing");

        // With a very large transaction_amount (50000), the linear mechanism
        // contributes 0.3 * 50000 = 15000, which is much larger than typical
        // baseline values. The ATE should be positive.
        assert!(
            fp_estimate.average_treatment_effect > 0.0,
            "ATE for fraud_probability should be positive, got {}",
            fp_estimate.average_treatment_effect
        );
        assert_eq!(fp_estimate.sample_size, 500);
    }

    #[test]
    fn test_causal_intervention_zero_ate_for_unconnected() {
        // transaction_amount is a root variable. Intervening on fraud_probability
        // should not affect transaction_amount (it has no incoming edge from
        // fraud_probability).
        let engine = build_engine();
        let result = engine
            .do_intervention(&[("fraud_probability".to_string(), 0.99)], 500, 42)
            .unwrap();

        let amt_estimate = result
            .effect_estimates
            .get("transaction_amount")
            .expect("transaction_amount estimate missing");

        // The ATE should be approximately zero (within noise tolerance).
        // The root variables are sampled independently, so only seed differences matter.
        // We use a generous tolerance since the seeds differ.
        assert!(
            amt_estimate.average_treatment_effect.abs() < 500.0,
            "ATE for unconnected variable should be near zero, got {}",
            amt_estimate.average_treatment_effect
        );
    }

    #[test]
    fn test_causal_intervention_multiple_interventions() {
        let engine = build_engine();
        let result = engine
            .do_intervention(
                &[
                    ("transaction_amount".to_string(), 10000.0),
                    ("merchant_risk".to_string(), 0.9),
                ],
                200,
                99,
            )
            .unwrap();

        // Both interventions should be reflected in the intervened samples
        for sample in &result.intervened_samples {
            let amt = sample.get("transaction_amount").copied().unwrap_or(0.0);
            let risk = sample.get("merchant_risk").copied().unwrap_or(0.0);
            assert!(
                (amt - 10000.0).abs() < 1e-10,
                "transaction_amount should be fixed at 10000.0"
            );
            assert!(
                (risk - 0.9).abs() < 1e-10,
                "merchant_risk should be fixed at 0.9"
            );
        }

        assert_eq!(result.baseline_samples.len(), 200);
        assert_eq!(result.intervened_samples.len(), 200);
    }

    #[test]
    fn test_causal_intervention_empty_returns_error() {
        let engine = build_engine();
        let result = engine.do_intervention(&[], 100, 42);
        assert!(result.is_err());
    }

    #[test]
    fn test_causal_intervention_unknown_variable_returns_error() {
        let engine = build_engine();
        let result = engine.do_intervention(&[("nonexistent_var".to_string(), 1.0)], 100, 42);
        assert!(result.is_err());
    }

    #[test]
    fn test_causal_intervention_confidence_interval() {
        let engine = build_engine();
        let result = engine
            .do_intervention(&[("transaction_amount".to_string(), 50000.0)], 500, 42)
            .unwrap();

        let fp_estimate = result
            .effect_estimates
            .get("fraud_probability")
            .expect("fraud_probability estimate missing");

        // CI lower bound should be <= ATE
        assert!(
            fp_estimate.confidence_interval.0 <= fp_estimate.average_treatment_effect,
            "CI lower ({}) should be <= ATE ({})",
            fp_estimate.confidence_interval.0,
            fp_estimate.average_treatment_effect
        );
        // Check that the interval is well-formed (upper bound not below lower bound).
        assert!(
            fp_estimate.confidence_interval.1 >= fp_estimate.confidence_interval.0,
            "CI upper ({}) should be >= CI lower ({})",
            fp_estimate.confidence_interval.1,
            fp_estimate.confidence_interval.0
        );
    }

    #[test]
    fn test_effect_estimate_empty_inputs() {
        let estimate = InterventionEngine::compute_effect_estimate(&[], &[], "x");
        assert_eq!(estimate.sample_size, 0);
        assert!(estimate.average_treatment_effect.abs() < 1e-12);
        assert!(estimate.confidence_interval.0.abs() < 1e-12);
        assert!(estimate.confidence_interval.1.abs() < 1e-12);
    }

    #[test]
    fn test_effect_estimate_difference_of_means() {
        let baseline = samples("x", &[1.0, 2.0, 3.0]);
        let intervened = samples("x", &[4.0, 5.0, 6.0]);
        let estimate = InterventionEngine::compute_effect_estimate(&baseline, &intervened, "x");

        assert_eq!(estimate.sample_size, 3);
        assert!((estimate.average_treatment_effect - 3.0).abs() < 1e-12);
        let (lower, upper) = estimate.confidence_interval;
        assert!(lower <= estimate.average_treatment_effect);
        assert!(estimate.average_treatment_effect <= upper);
    }

    #[test]
    fn test_effect_estimate_skips_samples_missing_variable() {
        let mut baseline = samples("x", &[1.0]);
        baseline.extend(samples("y", &[100.0]));
        let intervened = samples("x", &[3.0]);
        let estimate = InterventionEngine::compute_effect_estimate(&baseline, &intervened, "x");

        // Only one value per arm: the ATE is still defined, but the interval is degenerate.
        assert_eq!(estimate.sample_size, 1);
        assert!((estimate.average_treatment_effect - 2.0).abs() < 1e-12);
        assert!((estimate.confidence_interval.0 - 2.0).abs() < 1e-12);
        assert!((estimate.confidence_interval.1 - 2.0).abs() < 1e-12);
    }
}