1use std::collections::HashMap;
6
7use crate::error::SynthError;
8
9use super::scm::StructuralCausalModel;
10
/// Output of `InterventionEngine::do_intervention`: the raw samples drawn with and
/// without the interventions, plus one effect estimate per causal-graph variable.
#[derive(Debug, Clone)]
pub struct InterventionResult {
    /// Samples drawn from the unmodified model; each map holds one sample keyed by
    /// variable name.
    pub baseline_samples: Vec<HashMap<String, f64>>,
    /// Samples drawn from the model with the interventions applied; each map holds
    /// one sample keyed by variable name.
    pub intervened_samples: Vec<HashMap<String, f64>>,
    /// Effect estimate for every variable in the causal graph, keyed by variable name.
    pub effect_estimates: HashMap<String, EffectEstimate>,
}
21
/// Estimated effect of an intervention on a single variable.
///
/// Derives `PartialEq` so estimates can be compared directly (e.g. in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct EffectEstimate {
    /// Mean of the variable under intervention minus its baseline mean.
    pub average_treatment_effect: f64,
    /// `(lower, upper)` bounds of the approximate 95% interval reported for the effect.
    pub confidence_interval: (f64, f64),
    /// Number of samples compared: the smaller of the baseline and intervened sample
    /// counts that contain the variable (0 when either side has none).
    pub sample_size: usize,
}
32
/// Runs interventions against a `StructuralCausalModel` and estimates their effects
/// by comparing intervened samples with baseline samples.
pub struct InterventionEngine {
    /// Model sampled both without (baseline) and with interventions applied.
    scm: StructuralCausalModel,
}
37
38impl InterventionEngine {
39 pub fn new(scm: StructuralCausalModel) -> Self {
41 Self { scm }
42 }
43
44 pub fn do_intervention(
49 &self,
50 interventions: &[(String, f64)],
51 n_samples: usize,
52 seed: u64,
53 ) -> Result<InterventionResult, SynthError> {
54 if interventions.is_empty() {
55 return Err(SynthError::validation(
56 "At least one intervention must be specified",
57 ));
58 }
59
60 for (var_name, _) in interventions {
62 if self.scm.graph().get_variable(var_name).is_none() {
63 return Err(SynthError::generation(format!(
64 "Intervention variable '{var_name}' not found in causal graph"
65 )));
66 }
67 }
68
69 let baseline_samples = self
71 .scm
72 .generate(n_samples, seed)
73 .map_err(SynthError::generation)?;
74
75 let intervened_seed = seed.wrapping_add(1_000_000);
78 let intervened_samples = self
79 .generate_with_interventions(interventions, n_samples, intervened_seed)
80 .map_err(SynthError::generation)?;
81
82 let var_names = self.scm.graph().variable_names();
84 let mut effect_estimates = HashMap::new();
85
86 for var_name in &var_names {
87 let name = var_name.to_string();
88 let estimate =
89 Self::compute_effect_estimate(&baseline_samples, &intervened_samples, &name);
90 effect_estimates.insert(name, estimate);
91 }
92
93 Ok(InterventionResult {
94 baseline_samples,
95 intervened_samples,
96 effect_estimates,
97 })
98 }
99
100 fn generate_with_interventions(
102 &self,
103 interventions: &[(String, f64)],
104 n_samples: usize,
105 seed: u64,
106 ) -> Result<Vec<HashMap<String, f64>>, String> {
107 if interventions.is_empty() {
108 return self.scm.generate(n_samples, seed);
109 }
110
111 let first = &interventions[0];
113 let mut intervened = self.scm.intervene(&first.0, first.1)?;
114 for (var_name, value) in interventions.iter().skip(1) {
115 intervened = intervened.and_intervene(var_name, *value);
116 }
117 intervened.generate(n_samples, seed)
118 }
119
120 fn compute_effect_estimate(
122 baseline: &[HashMap<String, f64>],
123 intervened: &[HashMap<String, f64>],
124 variable: &str,
125 ) -> EffectEstimate {
126 let baseline_vals: Vec<f64> = baseline
127 .iter()
128 .filter_map(|s| s.get(variable).copied())
129 .collect();
130 let intervened_vals: Vec<f64> = intervened
131 .iter()
132 .filter_map(|s| s.get(variable).copied())
133 .collect();
134
135 let n = baseline_vals.len().min(intervened_vals.len());
136 if n == 0 {
137 return EffectEstimate {
138 average_treatment_effect: 0.0,
139 confidence_interval: (0.0, 0.0),
140 sample_size: 0,
141 };
142 }
143
144 let baseline_mean: f64 = baseline_vals.iter().sum::<f64>() / baseline_vals.len() as f64;
145 let intervened_mean: f64 =
146 intervened_vals.iter().sum::<f64>() / intervened_vals.len() as f64;
147 let ate = intervened_mean - baseline_mean;
148
149 let mut diffs: Vec<f64> = baseline_vals
151 .iter()
152 .zip(intervened_vals.iter())
153 .map(|(b, i)| i - b)
154 .collect();
155 diffs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
156
157 let ci = if diffs.len() >= 2 {
158 let lower_idx = (diffs.len() as f64 * 0.025).floor() as usize;
159 let upper_idx = ((diffs.len() as f64 * 0.975).ceil() as usize).min(diffs.len() - 1);
160 (diffs[lower_idx], diffs[upper_idx])
161 } else {
162 (ate, ate)
163 };
164
165 EffectEstimate {
166 average_treatment_effect: ate,
167 confidence_interval: ci,
168 sample_size: n,
169 }
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use crate::causal::graph::CausalGraph;

    /// Builds an engine over the fraud-detection template graph.
    fn build_engine() -> InterventionEngine {
        let graph = CausalGraph::fraud_detection_template();
        let scm = StructuralCausalModel::new(graph).unwrap();
        InterventionEngine::new(scm)
    }

    /// Forcing a large transaction amount is expected to raise fraud_probability on
    /// average, and every requested sample should contribute to the estimate.
    #[test]
    fn test_causal_intervention_positive_ate() {
        let engine = build_engine();
        let result = engine
            .do_intervention(&[("transaction_amount".to_string(), 50000.0)], 500, 42)
            .unwrap();

        let fp_estimate = result
            .effect_estimates
            .get("fraud_probability")
            .expect("fraud_probability estimate missing");

        assert!(
            fp_estimate.average_treatment_effect > 0.0,
            "ATE for fraud_probability should be positive, got {}",
            fp_estimate.average_treatment_effect
        );
        assert_eq!(fp_estimate.sample_size, 500);
    }

    /// Intervening on fraud_probability is expected not to move transaction_amount.
    /// The tolerance absorbs sampling noise, since baseline and intervened samples use
    /// different seeds.
    #[test]
    fn test_causal_intervention_zero_ate_for_unconnected() {
        let engine = build_engine();
        let result = engine
            .do_intervention(&[("fraud_probability".to_string(), 0.99)], 500, 42)
            .unwrap();

        let amt_estimate = result
            .effect_estimates
            .get("transaction_amount")
            .expect("transaction_amount estimate missing");

        assert!(
            amt_estimate.average_treatment_effect.abs() < 500.0,
            "ATE for unconnected variable should be near zero, got {}",
            amt_estimate.average_treatment_effect
        );
    }

    /// Every intervened sample must carry each intervened variable at exactly its
    /// forced value.
    #[test]
    fn test_causal_intervention_multiple_interventions() {
        let engine = build_engine();
        let result = engine
            .do_intervention(
                &[
                    ("transaction_amount".to_string(), 10000.0),
                    ("merchant_risk".to_string(), 0.9),
                ],
                200,
                99,
            )
            .unwrap();

        for sample in &result.intervened_samples {
            // `expect` (not a 0.0 fallback) so a missing variable is reported as
            // missing rather than as a wrong value.
            let amt = *sample
                .get("transaction_amount")
                .expect("transaction_amount missing from intervened sample");
            let risk = *sample
                .get("merchant_risk")
                .expect("merchant_risk missing from intervened sample");
            assert!(
                (amt - 10000.0).abs() < 1e-10,
                "transaction_amount should be fixed at 10000.0"
            );
            assert!(
                (risk - 0.9).abs() < 1e-10,
                "merchant_risk should be fixed at 0.9"
            );
        }

        assert_eq!(result.baseline_samples.len(), 200);
        assert_eq!(result.intervened_samples.len(), 200);
    }

    /// An empty intervention list is rejected.
    #[test]
    fn test_causal_intervention_empty_returns_error() {
        let engine = build_engine();
        let result = engine.do_intervention(&[], 100, 42);
        assert!(result.is_err());
    }

    /// Intervening on a variable that is not in the graph is rejected.
    #[test]
    fn test_causal_intervention_unknown_variable_returns_error() {
        let engine = build_engine();
        let result = engine.do_intervention(&[("nonexistent_var".to_string(), 1.0)], 100, 42);
        assert!(result.is_err());
    }

    /// The reported interval must be well-ordered and must not start above the ATE.
    #[test]
    fn test_causal_intervention_confidence_interval() {
        let engine = build_engine();
        let result = engine
            .do_intervention(&[("transaction_amount".to_string(), 50000.0)], 500, 42)
            .unwrap();

        let fp_estimate = result
            .effect_estimates
            .get("fraud_probability")
            .expect("fraud_probability estimate missing");

        assert!(
            fp_estimate.confidence_interval.0 <= fp_estimate.average_treatment_effect,
            "CI lower ({}) should be <= ATE ({})",
            fp_estimate.confidence_interval.0,
            fp_estimate.average_treatment_effect
        );
        assert!(
            fp_estimate.confidence_interval.1 >= fp_estimate.confidence_interval.0,
            "CI upper ({}) should be >= CI lower ({})",
            fp_estimate.confidence_interval.1,
            fp_estimate.confidence_interval.0
        );
    }
}