1use std::collections::HashMap;
6
7use crate::error::SynthError;
8
9use super::scm::StructuralCausalModel;
10
11#[derive(Debug, Clone)]
13pub struct InterventionResult {
14 pub baseline_samples: Vec<HashMap<String, f64>>,
16 pub intervened_samples: Vec<HashMap<String, f64>>,
18 pub effect_estimates: HashMap<String, EffectEstimate>,
20}
21
/// Summary of how a single variable differs between the intervened and the
/// baseline samples.
#[derive(Clone, Debug)]
pub struct EffectEstimate {
    /// Mean of the variable under intervention minus its mean at baseline.
    pub average_treatment_effect: f64,
    /// `(lower, upper)` bounds of the interval reported alongside the effect.
    pub confidence_interval: (f64, f64),
    /// Number of samples the estimate is based on.
    pub sample_size: usize,
}
32
33pub struct InterventionEngine {
35 scm: StructuralCausalModel,
36}
37
38impl InterventionEngine {
39 pub fn new(scm: StructuralCausalModel) -> Self {
41 Self { scm }
42 }
43
44 pub fn do_intervention(
49 &self,
50 interventions: &[(String, f64)],
51 n_samples: usize,
52 seed: u64,
53 ) -> Result<InterventionResult, SynthError> {
54 if interventions.is_empty() {
55 return Err(SynthError::validation(
56 "At least one intervention must be specified",
57 ));
58 }
59
60 for (var_name, _) in interventions {
62 if self.scm.graph().get_variable(var_name).is_none() {
63 return Err(SynthError::generation(format!(
64 "Intervention variable '{}' not found in causal graph",
65 var_name
66 )));
67 }
68 }
69
70 let baseline_samples = self
72 .scm
73 .generate(n_samples, seed)
74 .map_err(SynthError::generation)?;
75
76 let intervened_seed = seed.wrapping_add(1_000_000);
79 let intervened_samples = self
80 .generate_with_interventions(interventions, n_samples, intervened_seed)
81 .map_err(SynthError::generation)?;
82
83 let var_names = self.scm.graph().variable_names();
85 let mut effect_estimates = HashMap::new();
86
87 for var_name in &var_names {
88 let name = var_name.to_string();
89 let estimate =
90 Self::compute_effect_estimate(&baseline_samples, &intervened_samples, &name);
91 effect_estimates.insert(name, estimate);
92 }
93
94 Ok(InterventionResult {
95 baseline_samples,
96 intervened_samples,
97 effect_estimates,
98 })
99 }
100
101 fn generate_with_interventions(
103 &self,
104 interventions: &[(String, f64)],
105 n_samples: usize,
106 seed: u64,
107 ) -> Result<Vec<HashMap<String, f64>>, String> {
108 if interventions.is_empty() {
109 return self.scm.generate(n_samples, seed);
110 }
111
112 let first = &interventions[0];
114 let mut intervened = self.scm.intervene(&first.0, first.1)?;
115 for (var_name, value) in interventions.iter().skip(1) {
116 intervened = intervened.and_intervene(var_name, *value);
117 }
118 intervened.generate(n_samples, seed)
119 }
120
121 fn compute_effect_estimate(
123 baseline: &[HashMap<String, f64>],
124 intervened: &[HashMap<String, f64>],
125 variable: &str,
126 ) -> EffectEstimate {
127 let baseline_vals: Vec<f64> = baseline
128 .iter()
129 .filter_map(|s| s.get(variable).copied())
130 .collect();
131 let intervened_vals: Vec<f64> = intervened
132 .iter()
133 .filter_map(|s| s.get(variable).copied())
134 .collect();
135
136 let n = baseline_vals.len().min(intervened_vals.len());
137 if n == 0 {
138 return EffectEstimate {
139 average_treatment_effect: 0.0,
140 confidence_interval: (0.0, 0.0),
141 sample_size: 0,
142 };
143 }
144
145 let baseline_mean: f64 = baseline_vals.iter().sum::<f64>() / baseline_vals.len() as f64;
146 let intervened_mean: f64 =
147 intervened_vals.iter().sum::<f64>() / intervened_vals.len() as f64;
148 let ate = intervened_mean - baseline_mean;
149
150 let mut diffs: Vec<f64> = baseline_vals
152 .iter()
153 .zip(intervened_vals.iter())
154 .map(|(b, i)| i - b)
155 .collect();
156 diffs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
157
158 let ci = if diffs.len() >= 2 {
159 let lower_idx = (diffs.len() as f64 * 0.025).floor() as usize;
160 let upper_idx = ((diffs.len() as f64 * 0.975).ceil() as usize).min(diffs.len() - 1);
161 (diffs[lower_idx], diffs[upper_idx])
162 } else {
163 (ate, ate)
164 };
165
166 EffectEstimate {
167 average_treatment_effect: ate,
168 confidence_interval: ci,
169 sample_size: n,
170 }
171 }
172}
173
#[cfg(test)]
// Unwrapping is fine in tests: a panic is exactly how a test reports failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::causal::graph::CausalGraph;

    /// Builds an engine over the fraud-detection template graph.
    fn build_engine() -> InterventionEngine {
        let scm = StructuralCausalModel::new(CausalGraph::fraud_detection_template()).unwrap();
        InterventionEngine::new(scm)
    }

    /// Forcing a large `transaction_amount` must yield a positive ATE on
    /// `fraud_probability`, estimated from all 500 samples.
    #[test]
    fn test_causal_intervention_positive_ate() {
        let interventions = [("transaction_amount".to_string(), 50000.0)];
        let outcome = build_engine()
            .do_intervention(&interventions, 500, 42)
            .unwrap();

        let estimate = outcome
            .effect_estimates
            .get("fraud_probability")
            .expect("fraud_probability estimate missing");
        let ate = estimate.average_treatment_effect;

        assert!(
            ate > 0.0,
            "ATE for fraud_probability should be positive, got {}",
            ate
        );
        assert_eq!(estimate.sample_size, 500);
    }

    /// Intervening on `fraud_probability` must leave `transaction_amount`
    /// within a loose tolerance of its baseline mean.
    #[test]
    fn test_causal_intervention_zero_ate_for_unconnected() {
        let interventions = [("fraud_probability".to_string(), 0.99)];
        let outcome = build_engine()
            .do_intervention(&interventions, 500, 42)
            .unwrap();

        let estimate = outcome
            .effect_estimates
            .get("transaction_amount")
            .expect("transaction_amount estimate missing");
        let ate = estimate.average_treatment_effect;

        assert!(
            ate.abs() < 500.0,
            "ATE for unconnected variable should be near zero, got {}",
            ate
        );
    }

    /// With two interventions, every intervened sample must carry both fixed
    /// values, and both sample sets must have the requested size.
    #[test]
    fn test_causal_intervention_multiple_interventions() {
        let interventions = [
            ("transaction_amount".to_string(), 10000.0),
            ("merchant_risk".to_string(), 0.9),
        ];
        let outcome = build_engine()
            .do_intervention(&interventions, 200, 99)
            .unwrap();

        for sample in &outcome.intervened_samples {
            // A missing key reads as 0.0, which fails the assertions below.
            let amount = sample.get("transaction_amount").map_or(0.0, |v| *v);
            let risk = sample.get("merchant_risk").map_or(0.0, |v| *v);
            assert!(
                (amount - 10000.0).abs() < 1e-10,
                "transaction_amount should be fixed at 10000.0"
            );
            assert!(
                (risk - 0.9).abs() < 1e-10,
                "merchant_risk should be fixed at 0.9"
            );
        }

        assert_eq!(outcome.baseline_samples.len(), 200);
        assert_eq!(outcome.intervened_samples.len(), 200);
    }

    /// An empty intervention list is rejected.
    #[test]
    fn test_causal_intervention_empty_returns_error() {
        let engine = build_engine();
        assert!(engine.do_intervention(&[], 100, 42).is_err());
    }

    /// An intervention on a variable absent from the graph is rejected.
    #[test]
    fn test_causal_intervention_unknown_variable_returns_error() {
        let interventions = [("nonexistent_var".to_string(), 1.0)];
        let outcome = build_engine().do_intervention(&interventions, 100, 42);
        assert!(outcome.is_err());
    }

    /// The reported interval must be ordered and its lower bound must not
    /// exceed the ATE.
    #[test]
    fn test_causal_intervention_confidence_interval() {
        let interventions = [("transaction_amount".to_string(), 50000.0)];
        let outcome = build_engine()
            .do_intervention(&interventions, 500, 42)
            .unwrap();

        let estimate = outcome
            .effect_estimates
            .get("fraud_probability")
            .expect("fraud_probability estimate missing");
        let (lower, upper) = estimate.confidence_interval;
        let ate = estimate.average_treatment_effect;

        assert!(
            lower <= ate,
            "CI lower ({}) should be <= ATE ({})",
            lower,
            ate
        );
        assert!(
            upper >= lower,
            "CI upper ({}) should be >= CI lower ({})",
            upper,
            lower
        );
    }
}