Skip to main content

datasynth_eval/causal/
mod.rs

//! Causal model evaluator.
//!
//! Validates causal model preservation, including edge correlation sign accuracy,
//! topological consistency (DAG structure), intervention effect direction and
//! magnitude, and the average intervention effect size (Cohen's d).

6use crate::error::EvalResult;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9
/// Causal edge data for validation.
///
/// Describes one directed edge `source -> target` of the causal graph, together
/// with the correlation actually observed between the two variables.
#[derive(Debug, Clone, PartialEq)]
pub struct CausalEdgeData {
    /// Source variable.
    pub source: String,
    /// Target variable.
    pub target: String,
    /// Expected correlation sign: `+1.0` for positive, `-1.0` for negative.
    ///
    /// `0.0` means no correlation is expected; the evaluator then accepts an
    /// observed correlation whose absolute value is below `0.05`.
    pub expected_sign: f64,
    /// Observed correlation between source and target.
    pub observed_correlation: f64,
}

/// Intervention data for validation.
///
/// One intervention on `intervention_variable` and the change it produced in
/// `target_variable`, optionally with raw samples for effect-size computation.
#[derive(Debug, Clone, PartialEq)]
pub struct InterventionData {
    /// Variable intervened upon.
    pub intervention_variable: String,
    /// Expected effect direction on target: `+1.0` for increase, `-1.0` for decrease.
    ///
    /// Only the sign is used: the direction is counted as correct when
    /// `expected_direction * observed_change > 0`, so `0.0` never counts as correct.
    pub expected_direction: f64,
    /// Observed change in target.
    pub observed_change: f64,
    /// Target variable.
    pub target_variable: String,
    /// Expected magnitude of the intervention effect.
    ///
    /// Compared by absolute value against `|observed_change|`; a magnitude of
    /// zero leaves no ratio to compare and is counted as out of bounds.
    pub expected_magnitude: f64,
    /// Pre-intervention sample values (for Cohen's d computation).
    ///
    /// At least two values are needed on each side; otherwise the intervention
    /// is left out of the average effect size.
    pub pre_intervention_values: Vec<f64>,
    /// Post-intervention sample values (for Cohen's d computation).
    pub post_intervention_values: Vec<f64>,
}

42/// Thresholds for causal model evaluation.
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct CausalThresholds {
45    /// Minimum edge correlation sign accuracy.
46    pub min_sign_accuracy: f64,
47    /// Minimum intervention effect accuracy.
48    pub min_intervention_accuracy: f64,
49    /// Minimum intervention magnitude accuracy (fraction within 0.25x-4.0x bounds).
50    pub min_magnitude_accuracy: f64,
51}
52
53impl Default for CausalThresholds {
54    fn default() -> Self {
55        Self {
56            min_sign_accuracy: 0.80,
57            min_intervention_accuracy: 0.70,
58            min_magnitude_accuracy: 0.60,
59        }
60    }
61}
62
63/// Results of causal model evaluation.
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct CausalModelEvaluation {
66    /// Edge correlation sign accuracy: fraction of edges with correct sign.
67    pub edge_correlation_sign_accuracy: f64,
68    /// Whether the graph is topologically consistent (DAG - no cycles).
69    pub topological_consistency: bool,
70    /// Intervention effect accuracy: fraction with correct direction.
71    pub intervention_effect_accuracy: f64,
72    /// Fraction of interventions with observed magnitude within 0.25x to 4.0x of expected.
73    pub intervention_magnitude_accuracy: f64,
74    /// Average effect size (Cohen's d) across interventions.
75    pub avg_effect_size: f64,
76    /// Total edges evaluated.
77    pub total_edges: usize,
78    /// Total interventions evaluated.
79    pub total_interventions: usize,
80    /// Overall pass/fail.
81    pub passes: bool,
82    /// Issues found.
83    pub issues: Vec<String>,
84}
85
86/// Evaluator for causal model preservation.
87pub struct CausalModelEvaluator {
88    thresholds: CausalThresholds,
89}
90
91impl CausalModelEvaluator {
92    /// Create a new evaluator with default thresholds.
93    pub fn new() -> Self {
94        Self {
95            thresholds: CausalThresholds::default(),
96        }
97    }
98
99    /// Create with custom thresholds.
100    pub fn with_thresholds(thresholds: CausalThresholds) -> Self {
101        Self { thresholds }
102    }
103
104    /// Check if the edge set forms a DAG (no cycles) using Kahn's algorithm.
105    fn is_dag(edges: &[CausalEdgeData]) -> bool {
106        let mut in_degree: HashMap<&str, usize> = HashMap::new();
107        let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
108
109        // Initialize all nodes
110        for edge in edges {
111            in_degree.entry(edge.source.as_str()).or_insert(0);
112            in_degree.entry(edge.target.as_str()).or_insert(0);
113            adj.entry(edge.source.as_str()).or_default();
114        }
115
116        // Build adjacency and in-degree
117        for edge in edges {
118            adj.entry(edge.source.as_str())
119                .or_default()
120                .push(edge.target.as_str());
121            *in_degree.entry(edge.target.as_str()).or_insert(0) += 1;
122        }
123
124        // Kahn's algorithm
125        let mut queue: VecDeque<&str> = in_degree
126            .iter()
127            .filter(|(_, &d)| d == 0)
128            .map(|(&n, _)| n)
129            .collect();
130        let mut visited = 0usize;
131
132        while let Some(node) = queue.pop_front() {
133            visited += 1;
134            if let Some(neighbors) = adj.get(node) {
135                for &neighbor in neighbors {
136                    if let Some(d) = in_degree.get_mut(neighbor) {
137                        *d -= 1;
138                        if *d == 0 {
139                            queue.push_back(neighbor);
140                        }
141                    }
142                }
143            }
144        }
145
146        visited == in_degree.len()
147    }
148
149    /// Compute Cohen's d for a single intervention from pre/post samples.
150    ///
151    /// Cohen's d = |mean_diff| / pooled_std
152    /// where pooled_std = sqrt(((n1-1)*s1^2 + (n2-1)*s2^2) / (n1+n2-2))
153    fn cohens_d(pre: &[f64], post: &[f64]) -> Option<f64> {
154        let n1 = pre.len();
155        let n2 = post.len();
156        if n1 < 2 || n2 < 2 {
157            return None;
158        }
159
160        let mean1 = pre.iter().sum::<f64>() / n1 as f64;
161        let mean2 = post.iter().sum::<f64>() / n2 as f64;
162
163        let var1 = pre.iter().map(|x| (x - mean1).powi(2)).sum::<f64>() / (n1 - 1) as f64;
164        let var2 = post.iter().map(|x| (x - mean2).powi(2)).sum::<f64>() / (n2 - 1) as f64;
165
166        let pooled_var = ((n1 - 1) as f64 * var1 + (n2 - 1) as f64 * var2) / (n1 + n2 - 2) as f64;
167        let pooled_std = pooled_var.sqrt();
168
169        if pooled_std < f64::EPSILON {
170            return None;
171        }
172
173        Some((mean2 - mean1).abs() / pooled_std)
174    }
175
176    /// Compute average Cohen's d across all interventions with sample data.
177    fn compute_avg_effect_size(interventions: &[InterventionData]) -> f64 {
178        let effect_sizes: Vec<f64> = interventions
179            .iter()
180            .filter_map(|i| Self::cohens_d(&i.pre_intervention_values, &i.post_intervention_values))
181            .collect();
182
183        if effect_sizes.is_empty() {
184            0.0
185        } else {
186            effect_sizes.iter().sum::<f64>() / effect_sizes.len() as f64
187        }
188    }
189
190    /// Evaluate causal model data.
191    pub fn evaluate(
192        &self,
193        edges: &[CausalEdgeData],
194        interventions: &[InterventionData],
195    ) -> EvalResult<CausalModelEvaluation> {
196        let mut issues = Vec::new();
197
198        // 1. Edge correlation sign accuracy
199        let sign_correct = edges
200            .iter()
201            .filter(|e| {
202                // Signs match: both positive or both negative
203                e.expected_sign * e.observed_correlation > 0.0
204                    || (e.expected_sign.abs() < f64::EPSILON && e.observed_correlation.abs() < 0.05)
205            })
206            .count();
207        let edge_correlation_sign_accuracy = if edges.is_empty() {
208            1.0
209        } else {
210            sign_correct as f64 / edges.len() as f64
211        };
212
213        // 2. Topological consistency (DAG check)
214        let topological_consistency = if edges.is_empty() {
215            true
216        } else {
217            Self::is_dag(edges)
218        };
219
220        // 3. Intervention effect direction
221        let intervention_correct = interventions
222            .iter()
223            .filter(|i| i.expected_direction * i.observed_change > 0.0)
224            .count();
225        let intervention_effect_accuracy = if interventions.is_empty() {
226            1.0
227        } else {
228            intervention_correct as f64 / interventions.len() as f64
229        };
230
231        // 4. Intervention magnitude accuracy
232        let magnitude_within_bounds = interventions
233            .iter()
234            .filter(|i| {
235                if i.expected_magnitude.abs() < f64::EPSILON {
236                    // Cannot compute ratio when expected magnitude is zero
237                    false
238                } else {
239                    let ratio = i.observed_change.abs() / i.expected_magnitude.abs();
240                    (0.25..=4.0).contains(&ratio)
241                }
242            })
243            .count();
244        let intervention_magnitude_accuracy = if interventions.is_empty() {
245            1.0
246        } else {
247            magnitude_within_bounds as f64 / interventions.len() as f64
248        };
249
250        // 5. Average effect size (Cohen's d)
251        let avg_effect_size = Self::compute_avg_effect_size(interventions);
252
253        // Check thresholds
254        if edge_correlation_sign_accuracy < self.thresholds.min_sign_accuracy {
255            issues.push(format!(
256                "Edge sign accuracy {:.3} < {:.3}",
257                edge_correlation_sign_accuracy, self.thresholds.min_sign_accuracy
258            ));
259        }
260        if !topological_consistency {
261            issues.push("Causal graph contains cycles (not a DAG)".to_string());
262        }
263        if intervention_effect_accuracy < self.thresholds.min_intervention_accuracy {
264            issues.push(format!(
265                "Intervention accuracy {:.3} < {:.3}",
266                intervention_effect_accuracy, self.thresholds.min_intervention_accuracy
267            ));
268        }
269        if intervention_magnitude_accuracy < self.thresholds.min_magnitude_accuracy {
270            issues.push(format!(
271                "Intervention magnitude accuracy {:.3} < {:.3}",
272                intervention_magnitude_accuracy, self.thresholds.min_magnitude_accuracy
273            ));
274        }
275
276        let passes = issues.is_empty();
277
278        Ok(CausalModelEvaluation {
279            edge_correlation_sign_accuracy,
280            topological_consistency,
281            intervention_effect_accuracy,
282            intervention_magnitude_accuracy,
283            avg_effect_size,
284            total_edges: edges.len(),
285            total_interventions: interventions.len(),
286            passes,
287            issues,
288        })
289    }
290}
291
292impl Default for CausalModelEvaluator {
293    fn default() -> Self {
294        Self::new()
295    }
296}
297
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Shorthand for a [`CausalEdgeData`] literal.
    fn edge(
        source: &str,
        target: &str,
        expected_sign: f64,
        observed_correlation: f64,
    ) -> CausalEdgeData {
        CausalEdgeData {
            source: source.to_string(),
            target: target.to_string(),
            expected_sign,
            observed_correlation,
        }
    }

    /// Shorthand for an [`InterventionData`] literal.
    ///
    /// The evaluator only reads the numeric fields; the names are for readability.
    fn intervention(
        variable: &str,
        target: &str,
        expected_direction: f64,
        observed_change: f64,
        expected_magnitude: f64,
        pre: &[f64],
        post: &[f64],
    ) -> InterventionData {
        InterventionData {
            intervention_variable: variable.to_string(),
            expected_direction,
            observed_change,
            target_variable: target.to_string(),
            expected_magnitude,
            pre_intervention_values: pre.to_vec(),
            post_intervention_values: post.to_vec(),
        }
    }

    #[test]
    fn test_valid_causal_model() {
        let evaluator = CausalModelEvaluator::new();
        let edges = vec![
            edge("revenue", "profit", 1.0, 0.85),
            edge("cost", "profit", -1.0, -0.70),
        ];
        let interventions = vec![intervention(
            "revenue",
            "profit",
            1.0,
            5000.0,
            5000.0,
            &[100.0, 110.0, 105.0, 95.0, 108.0],
            &[200.0, 210.0, 205.0, 195.0, 208.0],
        )];

        let result = evaluator.evaluate(&edges, &interventions).unwrap();
        assert!(result.passes);
        assert!(result.topological_consistency);
        assert_eq!(result.edge_correlation_sign_accuracy, 1.0);
    }

    #[test]
    fn test_cyclic_graph() {
        let evaluator = CausalModelEvaluator::new();
        let edges = vec![
            edge("A", "B", 1.0, 0.5),
            edge("B", "C", 1.0, 0.5),
            edge("C", "A", 1.0, 0.5), // closes the cycle A -> B -> C -> A
        ];

        let result = evaluator.evaluate(&edges, &[]).unwrap();
        assert!(!result.topological_consistency);
        assert!(!result.passes);
    }

    #[test]
    fn test_self_loop_is_cycle() {
        let evaluator = CausalModelEvaluator::new();
        let edges = vec![edge("A", "A", 1.0, 0.5)];

        let result = evaluator.evaluate(&edges, &[]).unwrap();
        assert!(!result.topological_consistency);
        assert!(!result.passes);
    }

    #[test]
    fn test_wrong_signs() {
        let evaluator = CausalModelEvaluator::new();
        let edges = vec![edge("revenue", "profit", 1.0, -0.5)]; // wrong sign

        let result = evaluator.evaluate(&edges, &[]).unwrap();
        assert!(!result.passes);
        assert_eq!(result.edge_correlation_sign_accuracy, 0.0);
    }

    #[test]
    fn test_zero_expected_sign_accepts_near_zero_correlation() {
        let evaluator = CausalModelEvaluator::new();
        // expected_sign == 0.0 means "no relationship"; |0.01| < 0.05 is accepted.
        let edges = vec![edge("noise", "profit", 0.0, 0.01)];

        let result = evaluator.evaluate(&edges, &[]).unwrap();
        assert_eq!(result.edge_correlation_sign_accuracy, 1.0);
        assert!(result.passes);
    }

    #[test]
    fn test_empty() {
        let evaluator = CausalModelEvaluator::new();
        let result = evaluator.evaluate(&[], &[]).unwrap();
        assert!(result.passes);
    }

    #[test]
    fn test_intervention_magnitude_within_bounds() {
        let evaluator = CausalModelEvaluator::new();
        let edges = vec![edge("price", "demand", -1.0, -0.6)];
        // All interventions have observed magnitude within 0.25x to 4.0x of expected
        let interventions = vec![
            // ratio = 120 / 100 = 1.2
            intervention(
                "price",
                "demand",
                -1.0,
                -120.0,
                100.0,
                &[500.0, 510.0, 490.0, 505.0, 495.0],
                &[380.0, 390.0, 370.0, 385.0, 375.0],
            ),
            // ratio = 200 / 150 ≈ 1.33
            intervention(
                "price",
                "demand",
                -1.0,
                -200.0,
                150.0,
                &[600.0, 610.0, 590.0, 605.0, 595.0],
                &[400.0, 410.0, 390.0, 405.0, 395.0],
            ),
            // ratio = 50 / 60 ≈ 0.83
            intervention(
                "price",
                "demand",
                -1.0,
                -50.0,
                60.0,
                &[300.0, 310.0, 290.0, 305.0, 295.0],
                &[250.0, 260.0, 240.0, 255.0, 245.0],
            ),
        ];

        let result = evaluator.evaluate(&edges, &interventions).unwrap();
        assert_eq!(result.intervention_magnitude_accuracy, 1.0);
        assert!(result.avg_effect_size > 0.0);
        assert!(result.passes);
    }

    #[test]
    fn test_intervention_magnitude_out_of_bounds() {
        let evaluator = CausalModelEvaluator::new();
        let edges = vec![edge("marketing", "sales", 1.0, 0.7)];
        // Most interventions have extreme magnitudes (outside 0.25x to 4.0x)
        let interventions = vec![
            // ratio = 10 / 1000 = 0.01, below 0.25
            intervention(
                "marketing",
                "sales",
                1.0,
                10.0,
                1000.0,
                &[100.0, 105.0, 95.0],
                &[110.0, 115.0, 105.0],
            ),
            // ratio = 50000 / 100 = 500.0, above 4.0
            intervention(
                "marketing",
                "sales",
                1.0,
                50000.0,
                100.0,
                &[200.0, 210.0, 190.0],
                &[50200.0, 50210.0, 50190.0],
            ),
            // ratio = 5 / 500 = 0.01, below 0.25
            intervention(
                "marketing",
                "sales",
                1.0,
                5.0,
                500.0,
                &[100.0, 105.0, 95.0],
                &[105.0, 110.0, 100.0],
            ),
            // ratio = 150 / 100 = 1.5, within bounds (the one pass)
            intervention(
                "marketing",
                "sales",
                1.0,
                150.0,
                100.0,
                &[100.0, 105.0, 95.0],
                &[250.0, 255.0, 245.0],
            ),
        ];

        let result = evaluator.evaluate(&edges, &interventions).unwrap();
        // Only 1 out of 4 is within bounds => 0.25 < 0.60 (default threshold)
        assert_eq!(result.intervention_magnitude_accuracy, 0.25);
        assert!(!result.passes);
        assert!(result
            .issues
            .iter()
            .any(|i| i.contains("magnitude accuracy")));
    }

    #[test]
    fn test_zero_expected_magnitude_is_out_of_bounds() {
        let evaluator = CausalModelEvaluator::new();
        // No ratio can be formed against a zero expected magnitude, so it is a miss.
        let interventions = vec![intervention("a", "b", 1.0, 1.0, 0.0, &[], &[])];

        let result = evaluator.evaluate(&[], &interventions).unwrap();
        assert_eq!(result.intervention_magnitude_accuracy, 0.0);
        assert_eq!(result.intervention_effect_accuracy, 1.0);
        // Fewer than two samples per side: no effect size contributes.
        assert_eq!(result.avg_effect_size, 0.0);
        assert!(!result.passes);
    }

    #[test]
    fn test_effect_size_computation() {
        let evaluator = CausalModelEvaluator::new();
        // pre:  mean = 100, var = (25+0+25+0+0)/4 = 12.5
        // post: mean = 120, var = (25+0+25+0+0)/4 = 12.5
        // pooled_var = (4*12.5 + 4*12.5) / 8 = 12.5, pooled_std = sqrt(12.5) ≈ 3.536
        // Cohen's d = |120 - 100| / 3.536 ≈ 5.657
        let interventions = vec![intervention(
            "treatment",
            "outcome",
            1.0,
            20.0,
            20.0,
            &[95.0, 100.0, 105.0, 100.0, 100.0],
            &[115.0, 120.0, 125.0, 120.0, 120.0],
        )];
        let edges = vec![edge("treatment", "outcome", 1.0, 0.9)];

        let result = evaluator.evaluate(&edges, &interventions).unwrap();
        assert!(result.avg_effect_size > 5.0);
        assert!((result.avg_effect_size - 5.657).abs() < 0.1);

        // Also test with multiple interventions
        let interventions_multi = vec![
            // pre mean = 50, post mean = 60, both variances 4.0
            // => pooled_std = 2.0, d = 10 / 2 = 5.0
            intervention(
                "a",
                "b",
                1.0,
                10.0,
                10.0,
                &[48.0, 50.0, 52.0],
                &[58.0, 60.0, 62.0],
            ),
            // Constant samples: pooled_std = 0, so Cohen's d is undefined and
            // this intervention is excluded from the average.
            intervention("c", "d", 1.0, 0.1, 0.1, &[0.0, 0.0, 0.0], &[0.0, 0.0, 0.0]),
        ];

        let result2 = evaluator.evaluate(&edges, &interventions_multi).unwrap();
        assert!((result2.avg_effect_size - 5.0).abs() < 0.01);
    }
}