1use std::collections::HashMap;
6
7use crate::error::SynthError;
8
9use super::scm::StructuralCausalModel;
10
11#[derive(Debug, Clone)]
13pub struct InterventionResult {
14 pub baseline_samples: Vec<HashMap<String, f64>>,
16 pub intervened_samples: Vec<HashMap<String, f64>>,
18 pub effect_estimates: HashMap<String, EffectEstimate>,
20}
21
/// Estimated effect of an intervention on a single variable.
///
/// Derives `PartialEq` so estimates can be compared directly (e.g. in tests or
/// when checking determinism across runs with the same seed). `Eq` is not
/// derivable because the fields are floating point.
#[derive(Debug, Clone, PartialEq)]
pub struct EffectEstimate {
    /// Mean of the variable under intervention minus its mean at baseline.
    pub average_treatment_effect: f64,
    /// `(lower, upper)` interval bounds reported alongside the ATE; `lower <= upper`.
    pub confidence_interval: (f64, f64),
    /// Number of samples used: the smaller of the baseline and intervened
    /// counts of samples that contain the variable.
    pub sample_size: usize,
}
32
33pub struct InterventionEngine {
35 scm: StructuralCausalModel,
36}
37
38impl InterventionEngine {
39 pub fn new(scm: StructuralCausalModel) -> Self {
41 Self { scm }
42 }
43
44 pub fn do_intervention(
49 &self,
50 interventions: &[(String, f64)],
51 n_samples: usize,
52 seed: u64,
53 ) -> Result<InterventionResult, SynthError> {
54 if interventions.is_empty() {
55 return Err(SynthError::validation(
56 "At least one intervention must be specified",
57 ));
58 }
59
60 for (var_name, _) in interventions {
62 if self.scm.graph().get_variable(var_name).is_none() {
63 return Err(SynthError::generation(format!(
64 "Intervention variable '{var_name}' not found in causal graph"
65 )));
66 }
67 }
68
69 let baseline_samples = self
71 .scm
72 .generate(n_samples, seed)
73 .map_err(SynthError::generation)?;
74
75 let intervened_seed = seed.wrapping_add(1_000_000);
78 let intervened_samples = self
79 .generate_with_interventions(interventions, n_samples, intervened_seed)
80 .map_err(SynthError::generation)?;
81
82 let var_names = self.scm.graph().variable_names();
84 let mut effect_estimates = HashMap::new();
85
86 for var_name in &var_names {
87 let name = var_name.to_string();
88 let estimate =
89 Self::compute_effect_estimate(&baseline_samples, &intervened_samples, &name);
90 effect_estimates.insert(name, estimate);
91 }
92
93 Ok(InterventionResult {
94 baseline_samples,
95 intervened_samples,
96 effect_estimates,
97 })
98 }
99
100 fn generate_with_interventions(
102 &self,
103 interventions: &[(String, f64)],
104 n_samples: usize,
105 seed: u64,
106 ) -> Result<Vec<HashMap<String, f64>>, String> {
107 if interventions.is_empty() {
108 return self.scm.generate(n_samples, seed);
109 }
110
111 let first = &interventions[0];
113 let mut intervened = self.scm.intervene(&first.0, first.1)?;
114 for (var_name, value) in interventions.iter().skip(1) {
115 intervened = intervened.and_intervene(var_name, *value);
116 }
117 intervened.generate(n_samples, seed)
118 }
119
120 fn compute_effect_estimate(
122 baseline: &[HashMap<String, f64>],
123 intervened: &[HashMap<String, f64>],
124 variable: &str,
125 ) -> EffectEstimate {
126 let baseline_vals: Vec<f64> = baseline
127 .iter()
128 .filter_map(|s| s.get(variable).copied())
129 .collect();
130 let intervened_vals: Vec<f64> = intervened
131 .iter()
132 .filter_map(|s| s.get(variable).copied())
133 .collect();
134
135 let n = baseline_vals.len().min(intervened_vals.len());
136 if n == 0 {
137 return EffectEstimate {
138 average_treatment_effect: 0.0,
139 confidence_interval: (0.0, 0.0),
140 sample_size: 0,
141 };
142 }
143
144 let baseline_mean: f64 = baseline_vals.iter().sum::<f64>() / baseline_vals.len() as f64;
145 let intervened_mean: f64 =
146 intervened_vals.iter().sum::<f64>() / intervened_vals.len() as f64;
147 let ate = intervened_mean - baseline_mean;
148
149 let mut diffs: Vec<f64> = baseline_vals
151 .iter()
152 .zip(intervened_vals.iter())
153 .map(|(b, i)| i - b)
154 .collect();
155 diffs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
156
157 let ci = if diffs.len() >= 2 {
158 let lower_idx = (diffs.len() as f64 * 0.025).floor() as usize;
159 let upper_idx = ((diffs.len() as f64 * 0.975).ceil() as usize).min(diffs.len() - 1);
160 (diffs[lower_idx], diffs[upper_idx])
161 } else {
162 (ate, ate)
163 };
164
165 EffectEstimate {
166 average_treatment_effect: ate,
167 confidence_interval: ci,
168 sample_size: n,
169 }
170 }
171}
172
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::causal::graph::CausalGraph;

    /// Engine over the fraud-detection template graph, shared by every test.
    fn build_engine() -> InterventionEngine {
        InterventionEngine::new(
            StructuralCausalModel::new(CausalGraph::fraud_detection_template()).unwrap(),
        )
    }

    /// Forcing a large transaction amount should raise fraud probability.
    #[test]
    fn test_causal_intervention_positive_ate() {
        let interventions = [("transaction_amount".to_string(), 50000.0)];
        let outcome = build_engine()
            .do_intervention(&interventions, 500, 42)
            .unwrap();

        let estimate = outcome
            .effect_estimates
            .get("fraud_probability")
            .expect("fraud_probability estimate missing");

        assert!(
            estimate.average_treatment_effect > 0.0,
            "ATE for fraud_probability should be positive, got {}",
            estimate.average_treatment_effect
        );
        assert_eq!(estimate.sample_size, 500);
    }

    /// Intervening on `fraud_probability` should barely move
    /// `transaction_amount`, which this test treats as unconnected to it.
    #[test]
    fn test_causal_intervention_zero_ate_for_unconnected() {
        let interventions = [("fraud_probability".to_string(), 0.99)];
        let outcome = build_engine()
            .do_intervention(&interventions, 500, 42)
            .unwrap();

        let estimate = outcome
            .effect_estimates
            .get("transaction_amount")
            .expect("transaction_amount estimate missing");

        assert!(
            estimate.average_treatment_effect.abs() < 500.0,
            "ATE for unconnected variable should be near zero, got {}",
            estimate.average_treatment_effect
        );
    }

    /// Every intervened variable must be pinned in every intervened sample.
    #[test]
    fn test_causal_intervention_multiple_interventions() {
        let interventions = [
            ("transaction_amount".to_string(), 10000.0),
            ("merchant_risk".to_string(), 0.9),
        ];
        let outcome = build_engine()
            .do_intervention(&interventions, 200, 99)
            .unwrap();

        for row in &outcome.intervened_samples {
            let amount = row.get("transaction_amount").copied().unwrap_or(0.0);
            let risk = row.get("merchant_risk").copied().unwrap_or(0.0);
            assert!(
                (amount - 10000.0).abs() < 1e-10,
                "transaction_amount should be fixed at 10000.0"
            );
            assert!(
                (risk - 0.9).abs() < 1e-10,
                "merchant_risk should be fixed at 0.9"
            );
        }

        assert_eq!(outcome.baseline_samples.len(), 200);
        assert_eq!(outcome.intervened_samples.len(), 200);
    }

    /// An empty intervention list is rejected.
    #[test]
    fn test_causal_intervention_empty_returns_error() {
        let outcome = build_engine().do_intervention(&[], 100, 42);
        assert!(outcome.is_err());
    }

    /// Intervening on a variable that is not in the graph is rejected.
    #[test]
    fn test_causal_intervention_unknown_variable_returns_error() {
        let interventions = [("nonexistent_var".to_string(), 1.0)];
        let outcome = build_engine().do_intervention(&interventions, 100, 42);
        assert!(outcome.is_err());
    }

    /// The reported interval must be ordered and its lower bound must not exceed the ATE.
    #[test]
    fn test_causal_intervention_confidence_interval() {
        let interventions = [("transaction_amount".to_string(), 50000.0)];
        let outcome = build_engine()
            .do_intervention(&interventions, 500, 42)
            .unwrap();

        let estimate = outcome
            .effect_estimates
            .get("fraud_probability")
            .expect("fraud_probability estimate missing");

        assert!(
            estimate.confidence_interval.0 <= estimate.average_treatment_effect,
            "CI lower ({}) should be <= ATE ({})",
            estimate.confidence_interval.0,
            estimate.average_treatment_effect
        );
        assert!(
            estimate.confidence_interval.1 >= estimate.confidence_interval.0,
            "CI upper ({}) should be >= CI lower ({})",
            estimate.confidence_interval.1,
            estimate.confidence_interval.0
        );
    }
}