1use crate::error::EvalResult;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9
/// A single hypothesised cause -> effect edge, paired with the correlation
/// actually observed between its two endpoint variables.
#[derive(Debug, Clone, PartialEq)]
pub struct CausalEdgeData {
    /// Name of the cause variable (the edge's tail).
    pub source: String,
    /// Name of the effect variable (the edge's head).
    pub target: String,
    /// Hypothesised direction of the relationship. The evaluator compares its
    /// sign with that of `observed_correlation`; a value within `f64::EPSILON`
    /// of zero is instead satisfied by a near-zero observed correlation.
    pub expected_sign: f64,
    /// Correlation observed between `source` and `target`.
    pub observed_correlation: f64,
}
22
/// Outcome of one intervention experiment: what was expected to happen to a
/// target variable, what was observed, and raw samples taken before and after.
#[derive(Debug, Clone, PartialEq)]
pub struct InterventionData {
    /// Name of the variable that was intervened on (informational; the
    /// evaluator does not read it).
    pub intervention_variable: String,
    /// Expected direction of the change in the target; its sign is compared
    /// with the sign of `observed_change`.
    pub expected_direction: f64,
    /// Observed change in the target. Its sign is checked against
    /// `expected_direction` and its absolute value against `expected_magnitude`.
    pub observed_change: f64,
    /// Name of the variable whose response was measured (informational; the
    /// evaluator does not read it).
    pub target_variable: String,
    /// Expected size of the change; only the absolute value is used. A value
    /// within `f64::EPSILON` of zero never satisfies the magnitude check.
    pub expected_magnitude: f64,
    /// Target samples taken before the intervention. At least two are needed
    /// for an effect size to be computed.
    pub pre_intervention_values: Vec<f64>,
    /// Target samples taken after the intervention. At least two are needed
    /// for an effect size to be computed.
    pub post_intervention_values: Vec<f64>,
}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct CausalThresholds {
45 pub min_sign_accuracy: f64,
47 pub min_intervention_accuracy: f64,
49 pub min_magnitude_accuracy: f64,
51}
52
53impl Default for CausalThresholds {
54 fn default() -> Self {
55 Self {
56 min_sign_accuracy: 0.80,
57 min_intervention_accuracy: 0.70,
58 min_magnitude_accuracy: 0.60,
59 }
60 }
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct CausalModelEvaluation {
66 pub edge_correlation_sign_accuracy: f64,
68 pub topological_consistency: bool,
70 pub intervention_effect_accuracy: f64,
72 pub intervention_magnitude_accuracy: f64,
74 pub avg_effect_size: f64,
76 pub total_edges: usize,
78 pub total_interventions: usize,
80 pub passes: bool,
82 pub issues: Vec<String>,
84}
85
86pub struct CausalModelEvaluator {
88 thresholds: CausalThresholds,
89}
90
91impl CausalModelEvaluator {
92 pub fn new() -> Self {
94 Self {
95 thresholds: CausalThresholds::default(),
96 }
97 }
98
99 pub fn with_thresholds(thresholds: CausalThresholds) -> Self {
101 Self { thresholds }
102 }
103
104 fn is_dag(edges: &[CausalEdgeData]) -> bool {
106 let mut in_degree: HashMap<&str, usize> = HashMap::new();
107 let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
108
109 for edge in edges {
111 in_degree.entry(edge.source.as_str()).or_insert(0);
112 in_degree.entry(edge.target.as_str()).or_insert(0);
113 adj.entry(edge.source.as_str()).or_default();
114 }
115
116 for edge in edges {
118 adj.entry(edge.source.as_str())
119 .or_default()
120 .push(edge.target.as_str());
121 *in_degree.entry(edge.target.as_str()).or_insert(0) += 1;
122 }
123
124 let mut queue: VecDeque<&str> = in_degree
126 .iter()
127 .filter(|(_, &d)| d == 0)
128 .map(|(&n, _)| n)
129 .collect();
130 let mut visited = 0usize;
131
132 while let Some(node) = queue.pop_front() {
133 visited += 1;
134 if let Some(neighbors) = adj.get(node) {
135 for &neighbor in neighbors {
136 if let Some(d) = in_degree.get_mut(neighbor) {
137 *d -= 1;
138 if *d == 0 {
139 queue.push_back(neighbor);
140 }
141 }
142 }
143 }
144 }
145
146 visited == in_degree.len()
147 }
148
149 fn cohens_d(pre: &[f64], post: &[f64]) -> Option<f64> {
154 let n1 = pre.len();
155 let n2 = post.len();
156 if n1 < 2 || n2 < 2 {
157 return None;
158 }
159
160 let mean1 = pre.iter().sum::<f64>() / n1 as f64;
161 let mean2 = post.iter().sum::<f64>() / n2 as f64;
162
163 let var1 = pre.iter().map(|x| (x - mean1).powi(2)).sum::<f64>() / (n1 - 1) as f64;
164 let var2 = post.iter().map(|x| (x - mean2).powi(2)).sum::<f64>() / (n2 - 1) as f64;
165
166 let pooled_var = ((n1 - 1) as f64 * var1 + (n2 - 1) as f64 * var2) / (n1 + n2 - 2) as f64;
167 let pooled_std = pooled_var.sqrt();
168
169 if pooled_std < f64::EPSILON {
170 return None;
171 }
172
173 Some((mean2 - mean1).abs() / pooled_std)
174 }
175
176 fn compute_avg_effect_size(interventions: &[InterventionData]) -> f64 {
178 let effect_sizes: Vec<f64> = interventions
179 .iter()
180 .filter_map(|i| Self::cohens_d(&i.pre_intervention_values, &i.post_intervention_values))
181 .collect();
182
183 if effect_sizes.is_empty() {
184 0.0
185 } else {
186 effect_sizes.iter().sum::<f64>() / effect_sizes.len() as f64
187 }
188 }
189
190 pub fn evaluate(
192 &self,
193 edges: &[CausalEdgeData],
194 interventions: &[InterventionData],
195 ) -> EvalResult<CausalModelEvaluation> {
196 let mut issues = Vec::new();
197
198 let sign_correct = edges
200 .iter()
201 .filter(|e| {
202 e.expected_sign * e.observed_correlation > 0.0
204 || (e.expected_sign.abs() < f64::EPSILON && e.observed_correlation.abs() < 0.05)
205 })
206 .count();
207 let edge_correlation_sign_accuracy = if edges.is_empty() {
208 1.0
209 } else {
210 sign_correct as f64 / edges.len() as f64
211 };
212
213 let topological_consistency = if edges.is_empty() {
215 true
216 } else {
217 Self::is_dag(edges)
218 };
219
220 let intervention_correct = interventions
222 .iter()
223 .filter(|i| i.expected_direction * i.observed_change > 0.0)
224 .count();
225 let intervention_effect_accuracy = if interventions.is_empty() {
226 1.0
227 } else {
228 intervention_correct as f64 / interventions.len() as f64
229 };
230
231 let magnitude_within_bounds = interventions
233 .iter()
234 .filter(|i| {
235 if i.expected_magnitude.abs() < f64::EPSILON {
236 false
238 } else {
239 let ratio = i.observed_change.abs() / i.expected_magnitude.abs();
240 (0.25..=4.0).contains(&ratio)
241 }
242 })
243 .count();
244 let intervention_magnitude_accuracy = if interventions.is_empty() {
245 1.0
246 } else {
247 magnitude_within_bounds as f64 / interventions.len() as f64
248 };
249
250 let avg_effect_size = Self::compute_avg_effect_size(interventions);
252
253 if edge_correlation_sign_accuracy < self.thresholds.min_sign_accuracy {
255 issues.push(format!(
256 "Edge sign accuracy {:.3} < {:.3}",
257 edge_correlation_sign_accuracy, self.thresholds.min_sign_accuracy
258 ));
259 }
260 if !topological_consistency {
261 issues.push("Causal graph contains cycles (not a DAG)".to_string());
262 }
263 if intervention_effect_accuracy < self.thresholds.min_intervention_accuracy {
264 issues.push(format!(
265 "Intervention accuracy {:.3} < {:.3}",
266 intervention_effect_accuracy, self.thresholds.min_intervention_accuracy
267 ));
268 }
269 if intervention_magnitude_accuracy < self.thresholds.min_magnitude_accuracy {
270 issues.push(format!(
271 "Intervention magnitude accuracy {:.3} < {:.3}",
272 intervention_magnitude_accuracy, self.thresholds.min_magnitude_accuracy
273 ));
274 }
275
276 let passes = issues.is_empty();
277
278 Ok(CausalModelEvaluation {
279 edge_correlation_sign_accuracy,
280 topological_consistency,
281 intervention_effect_accuracy,
282 intervention_magnitude_accuracy,
283 avg_effect_size,
284 total_edges: edges.len(),
285 total_interventions: interventions.len(),
286 passes,
287 issues,
288 })
289 }
290}
291
292impl Default for CausalModelEvaluator {
293 fn default() -> Self {
294 Self::new()
295 }
296}
297
#[cfg(test)]
// Tests unwrap `EvalResult`s directly: a panic is the desired failure mode here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Shorthand constructor for an edge.
    fn edge(
        source: &str,
        target: &str,
        expected_sign: f64,
        observed_correlation: f64,
    ) -> CausalEdgeData {
        CausalEdgeData {
            source: source.to_string(),
            target: target.to_string(),
            expected_sign,
            observed_correlation,
        }
    }

    /// Shorthand constructor for an intervention that carries no raw samples.
    fn intervention_without_samples(
        expected_direction: f64,
        observed_change: f64,
        expected_magnitude: f64,
    ) -> InterventionData {
        InterventionData {
            intervention_variable: "x".to_string(),
            expected_direction,
            observed_change,
            target_variable: "y".to_string(),
            expected_magnitude,
            pre_intervention_values: Vec::new(),
            post_intervention_values: Vec::new(),
        }
    }

    #[test]
    fn test_valid_causal_model() {
        let evaluator = CausalModelEvaluator::new();
        let edges = vec![
            CausalEdgeData {
                source: "revenue".to_string(),
                target: "profit".to_string(),
                expected_sign: 1.0,
                observed_correlation: 0.85,
            },
            CausalEdgeData {
                source: "cost".to_string(),
                target: "profit".to_string(),
                expected_sign: -1.0,
                observed_correlation: -0.70,
            },
        ];
        let interventions = vec![InterventionData {
            intervention_variable: "revenue".to_string(),
            expected_direction: 1.0,
            observed_change: 5000.0,
            target_variable: "profit".to_string(),
            expected_magnitude: 5000.0,
            pre_intervention_values: vec![100.0, 110.0, 105.0, 95.0, 108.0],
            post_intervention_values: vec![200.0, 210.0, 205.0, 195.0, 208.0],
        }];

        let result = evaluator.evaluate(&edges, &interventions).unwrap();
        assert!(result.passes);
        assert!(result.topological_consistency);
        assert_eq!(result.edge_correlation_sign_accuracy, 1.0);
    }

    #[test]
    fn test_cyclic_graph() {
        let evaluator = CausalModelEvaluator::new();
        let edges = vec![
            CausalEdgeData {
                source: "A".to_string(),
                target: "B".to_string(),
                expected_sign: 1.0,
                observed_correlation: 0.5,
            },
            CausalEdgeData {
                source: "B".to_string(),
                target: "C".to_string(),
                expected_sign: 1.0,
                observed_correlation: 0.5,
            },
            CausalEdgeData {
                source: "C".to_string(),
                target: "A".to_string(),
                expected_sign: 1.0,
                observed_correlation: 0.5,
            },
        ];

        let result = evaluator.evaluate(&edges, &[]).unwrap();
        assert!(!result.topological_consistency);
        assert!(!result.passes);
    }

    #[test]
    fn test_wrong_signs() {
        let evaluator = CausalModelEvaluator::new();
        let edges = vec![CausalEdgeData {
            source: "revenue".to_string(),
            target: "profit".to_string(),
            expected_sign: 1.0,
            observed_correlation: -0.5,
        }];

        let result = evaluator.evaluate(&edges, &[]).unwrap();
        assert!(!result.passes);
        assert_eq!(result.edge_correlation_sign_accuracy, 0.0);
    }

    #[test]
    fn test_empty() {
        let evaluator = CausalModelEvaluator::new();
        let result = evaluator.evaluate(&[], &[]).unwrap();
        assert!(result.passes);
    }

    #[test]
    fn test_intervention_magnitude_within_bounds() {
        let evaluator = CausalModelEvaluator::new();
        let edges = vec![CausalEdgeData {
            source: "price".to_string(),
            target: "demand".to_string(),
            expected_sign: -1.0,
            observed_correlation: -0.6,
        }];
        // Observed/expected ratios: 1.2, 1.33, 0.83 -- all within [0.25, 4.0].
        let interventions = vec![
            InterventionData {
                intervention_variable: "price".to_string(),
                expected_direction: -1.0,
                observed_change: -120.0,
                target_variable: "demand".to_string(),
                expected_magnitude: 100.0,
                pre_intervention_values: vec![500.0, 510.0, 490.0, 505.0, 495.0],
                post_intervention_values: vec![380.0, 390.0, 370.0, 385.0, 375.0],
            },
            InterventionData {
                intervention_variable: "price".to_string(),
                expected_direction: -1.0,
                observed_change: -200.0,
                target_variable: "demand".to_string(),
                expected_magnitude: 150.0,
                pre_intervention_values: vec![600.0, 610.0, 590.0, 605.0, 595.0],
                post_intervention_values: vec![400.0, 410.0, 390.0, 405.0, 395.0],
            },
            InterventionData {
                intervention_variable: "price".to_string(),
                expected_direction: -1.0,
                observed_change: -50.0,
                target_variable: "demand".to_string(),
                expected_magnitude: 60.0,
                pre_intervention_values: vec![300.0, 310.0, 290.0, 305.0, 295.0],
                post_intervention_values: vec![250.0, 260.0, 240.0, 255.0, 245.0],
            },
        ];

        let result = evaluator.evaluate(&edges, &interventions).unwrap();
        assert_eq!(result.intervention_magnitude_accuracy, 1.0);
        assert!(result.avg_effect_size > 0.0);
        assert!(result.passes);
    }

    #[test]
    fn test_intervention_magnitude_out_of_bounds() {
        let evaluator = CausalModelEvaluator::new();
        let edges = vec![CausalEdgeData {
            source: "marketing".to_string(),
            target: "sales".to_string(),
            expected_sign: 1.0,
            observed_correlation: 0.7,
        }];
        // Observed/expected ratios: 0.01, 500, 0.01, 1.5 -- only the last is in bounds.
        let interventions = vec![
            InterventionData {
                intervention_variable: "marketing".to_string(),
                expected_direction: 1.0,
                observed_change: 10.0,
                target_variable: "sales".to_string(),
                expected_magnitude: 1000.0,
                pre_intervention_values: vec![100.0, 105.0, 95.0],
                post_intervention_values: vec![110.0, 115.0, 105.0],
            },
            InterventionData {
                intervention_variable: "marketing".to_string(),
                expected_direction: 1.0,
                observed_change: 50000.0,
                target_variable: "sales".to_string(),
                expected_magnitude: 100.0,
                pre_intervention_values: vec![200.0, 210.0, 190.0],
                post_intervention_values: vec![50200.0, 50210.0, 50190.0],
            },
            InterventionData {
                intervention_variable: "marketing".to_string(),
                expected_direction: 1.0,
                observed_change: 5.0,
                target_variable: "sales".to_string(),
                expected_magnitude: 500.0,
                pre_intervention_values: vec![100.0, 105.0, 95.0],
                post_intervention_values: vec![105.0, 110.0, 100.0],
            },
            InterventionData {
                intervention_variable: "marketing".to_string(),
                expected_direction: 1.0,
                observed_change: 150.0,
                target_variable: "sales".to_string(),
                expected_magnitude: 100.0,
                pre_intervention_values: vec![100.0, 105.0, 95.0],
                post_intervention_values: vec![250.0, 255.0, 245.0],
            },
        ];

        let result = evaluator.evaluate(&edges, &interventions).unwrap();
        // 1 of 4 within bounds, below the 0.60 default threshold.
        assert_eq!(result.intervention_magnitude_accuracy, 0.25);
        assert!(!result.passes);
        assert!(result
            .issues
            .iter()
            .any(|i| i.contains("magnitude accuracy")));
    }

    #[test]
    fn test_effect_size_computation() {
        let evaluator = CausalModelEvaluator::new();
        // pre: mean 100, sample variance 12.5; post: mean 120, sample variance 12.5.
        // Pooled std = sqrt(12.5) ~= 3.536, so d = 20 / 3.536 ~= 5.657.
        let interventions = vec![InterventionData {
            intervention_variable: "treatment".to_string(),
            expected_direction: 1.0,
            observed_change: 20.0,
            target_variable: "outcome".to_string(),
            expected_magnitude: 20.0,
            pre_intervention_values: vec![95.0, 100.0, 105.0, 100.0, 100.0],
            post_intervention_values: vec![115.0, 120.0, 125.0, 120.0, 120.0],
        }];

        let edges = vec![CausalEdgeData {
            source: "treatment".to_string(),
            target: "outcome".to_string(),
            expected_sign: 1.0,
            observed_correlation: 0.9,
        }];

        let result = evaluator.evaluate(&edges, &interventions).unwrap();
        assert!(result.avg_effect_size > 5.0);
        assert!((result.avg_effect_size - 5.657).abs() < 0.1);

        let interventions_multi = vec![
            // Variances 4 and 4 -> pooled std 2, mean shift 10 -> d = 5.
            InterventionData {
                intervention_variable: "a".to_string(),
                expected_direction: 1.0,
                observed_change: 10.0,
                target_variable: "b".to_string(),
                expected_magnitude: 10.0,
                pre_intervention_values: vec![48.0, 50.0, 52.0],
                post_intervention_values: vec![58.0, 60.0, 62.0],
            },
            // Zero variance: no effect size, so excluded from the average.
            InterventionData {
                intervention_variable: "c".to_string(),
                expected_direction: 1.0,
                observed_change: 0.1,
                target_variable: "d".to_string(),
                expected_magnitude: 0.1,
                pre_intervention_values: vec![0.0, 0.0, 0.0],
                post_intervention_values: vec![0.0, 0.0, 0.0],
            },
        ];

        let result2 = evaluator.evaluate(&edges, &interventions_multi).unwrap();
        assert!((result2.avg_effect_size - 5.0).abs() < 0.01);
    }

    #[test]
    fn test_self_loop_is_cycle() {
        let evaluator = CausalModelEvaluator::new();
        let result = evaluator
            .evaluate(&[edge("A", "A", 1.0, 0.5)], &[])
            .unwrap();
        assert!(!result.topological_consistency);
        assert!(!result.passes);
    }

    #[test]
    fn test_diamond_is_dag() {
        let evaluator = CausalModelEvaluator::new();
        // D is reachable along two paths, which must not be mistaken for a cycle.
        let edges = vec![
            edge("A", "B", 1.0, 0.5),
            edge("A", "C", 1.0, 0.5),
            edge("B", "D", 1.0, 0.5),
            edge("C", "D", 1.0, 0.5),
        ];
        let result = evaluator.evaluate(&edges, &[]).unwrap();
        assert!(result.topological_consistency);
        assert!(result.passes);
    }

    #[test]
    fn test_zero_expected_sign_tolerance() {
        let evaluator = CausalModelEvaluator::new();
        // A zero expected sign is satisfied only by a near-zero correlation.
        let edges = vec![edge("a", "b", 0.0, 0.01), edge("a", "c", 0.0, 0.2)];
        let result = evaluator.evaluate(&edges, &[]).unwrap();
        assert_eq!(result.edge_correlation_sign_accuracy, 0.5);
        assert!(!result.passes);
    }

    #[test]
    fn test_zero_expected_magnitude_never_matches() {
        let evaluator = CausalModelEvaluator::new();
        let interventions = vec![intervention_without_samples(1.0, 5.0, 0.0)];
        let result = evaluator.evaluate(&[], &interventions).unwrap();
        assert_eq!(result.intervention_effect_accuracy, 1.0);
        assert_eq!(result.intervention_magnitude_accuracy, 0.0);
        // Without samples no effect size can be computed.
        assert_eq!(result.avg_effect_size, 0.0);
    }

    #[test]
    fn test_magnitude_ratio_bounds_are_inclusive() {
        let evaluator = CausalModelEvaluator::new();
        let interventions = vec![
            intervention_without_samples(1.0, 25.0, 100.0),  // ratio 0.25 -> in
            intervention_without_samples(1.0, 400.0, 100.0), // ratio 4.0  -> in
            intervention_without_samples(1.0, 24.0, 100.0),  // ratio 0.24 -> out
            intervention_without_samples(1.0, 401.0, 100.0), // ratio 4.01 -> out
        ];
        let result = evaluator.evaluate(&[], &interventions).unwrap();
        assert_eq!(result.intervention_magnitude_accuracy, 0.5);
    }

    #[test]
    fn test_custom_thresholds() {
        let lenient = CausalModelEvaluator::with_thresholds(CausalThresholds {
            min_sign_accuracy: 0.0,
            min_intervention_accuracy: 0.0,
            min_magnitude_accuracy: 0.0,
        });
        let edges = vec![edge("revenue", "profit", 1.0, -0.5)];
        let result = lenient.evaluate(&edges, &[]).unwrap();
        assert_eq!(result.edge_correlation_sign_accuracy, 0.0);
        assert!(result.issues.is_empty());
        assert!(result.passes);
    }
}