// Source: datasynth_eval/causal/mod.rs

//! Causal model evaluator.
//!
//! Validates causal model preservation: edge correlation sign accuracy,
//! topological consistency (the edge set must form a DAG), intervention effect
//! direction and magnitude, and the average intervention effect size (Cohen's d).
5
6use crate::error::EvalResult;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9
/// One directed edge of the causal graph together with the correlation that was
/// measured between its two endpoints.
#[derive(Debug, Clone)]
pub struct CausalEdgeData {
    /// Name of the variable the edge starts at.
    pub source: String,
    /// Name of the variable the edge points to.
    pub target: String,
    /// Sign the correlation is expected to have: +1.0 for positive, -1.0 for negative.
    pub expected_sign: f64,
    /// Correlation actually observed between `source` and `target`.
    pub observed_correlation: f64,
}
22
/// Outcome of a single intervention, compared against what the causal model predicts.
#[derive(Debug, Clone)]
pub struct InterventionData {
    /// Name of the variable that was intervened upon.
    pub intervention_variable: String,
    /// Direction the target is expected to move: +1.0 for increase, -1.0 for decrease.
    pub expected_direction: f64,
    /// Change that was actually observed in the target.
    pub observed_change: f64,
    /// Name of the variable whose response is measured.
    pub target_variable: String,
    /// Size of the effect the intervention is expected to have.
    pub expected_magnitude: f64,
    /// Sample values taken before the intervention (input to Cohen's d).
    pub pre_intervention_values: Vec<f64>,
    /// Sample values taken after the intervention (input to Cohen's d).
    pub post_intervention_values: Vec<f64>,
}
41
42/// Thresholds for causal model evaluation.
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct CausalThresholds {
45    /// Minimum edge correlation sign accuracy.
46    pub min_sign_accuracy: f64,
47    /// Minimum intervention effect accuracy.
48    pub min_intervention_accuracy: f64,
49    /// Minimum intervention magnitude accuracy (fraction within 0.25x-4.0x bounds).
50    pub min_magnitude_accuracy: f64,
51}
52
53impl Default for CausalThresholds {
54    fn default() -> Self {
55        Self {
56            min_sign_accuracy: 0.80,
57            min_intervention_accuracy: 0.70,
58            min_magnitude_accuracy: 0.60,
59        }
60    }
61}
62
63/// Results of causal model evaluation.
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct CausalModelEvaluation {
66    /// Edge correlation sign accuracy: fraction of edges with correct sign.
67    pub edge_correlation_sign_accuracy: f64,
68    /// Whether the graph is topologically consistent (DAG - no cycles).
69    pub topological_consistency: bool,
70    /// Intervention effect accuracy: fraction with correct direction.
71    pub intervention_effect_accuracy: f64,
72    /// Fraction of interventions with observed magnitude within 0.25x to 4.0x of expected.
73    pub intervention_magnitude_accuracy: f64,
74    /// Average effect size (Cohen's d) across interventions.
75    pub avg_effect_size: f64,
76    /// Total edges evaluated.
77    pub total_edges: usize,
78    /// Total interventions evaluated.
79    pub total_interventions: usize,
80    /// Overall pass/fail.
81    pub passes: bool,
82    /// Issues found.
83    pub issues: Vec<String>,
84}
85
86/// Evaluator for causal model preservation.
87pub struct CausalModelEvaluator {
88    thresholds: CausalThresholds,
89}
90
91impl CausalModelEvaluator {
92    /// Create a new evaluator with default thresholds.
93    pub fn new() -> Self {
94        Self {
95            thresholds: CausalThresholds::default(),
96        }
97    }
98
99    /// Create with custom thresholds.
100    pub fn with_thresholds(thresholds: CausalThresholds) -> Self {
101        Self { thresholds }
102    }
103
104    /// Check if the edge set forms a DAG (no cycles) using Kahn's algorithm.
105    fn is_dag(edges: &[CausalEdgeData]) -> bool {
106        let mut in_degree: HashMap<&str, usize> = HashMap::new();
107        let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
108
109        // Initialize all nodes
110        for edge in edges {
111            in_degree.entry(edge.source.as_str()).or_insert(0);
112            in_degree.entry(edge.target.as_str()).or_insert(0);
113            adj.entry(edge.source.as_str()).or_default();
114        }
115
116        // Build adjacency and in-degree
117        for edge in edges {
118            adj.entry(edge.source.as_str())
119                .or_default()
120                .push(edge.target.as_str());
121            *in_degree.entry(edge.target.as_str()).or_insert(0) += 1;
122        }
123
124        // Kahn's algorithm
125        let mut queue: VecDeque<&str> = in_degree
126            .iter()
127            .filter(|(_, &d)| d == 0)
128            .map(|(&n, _)| n)
129            .collect();
130        let mut visited = 0usize;
131
132        while let Some(node) = queue.pop_front() {
133            visited += 1;
134            if let Some(neighbors) = adj.get(node) {
135                for &neighbor in neighbors {
136                    if let Some(d) = in_degree.get_mut(neighbor) {
137                        *d -= 1;
138                        if *d == 0 {
139                            queue.push_back(neighbor);
140                        }
141                    }
142                }
143            }
144        }
145
146        visited == in_degree.len()
147    }
148
149    /// Compute Cohen's d for a single intervention from pre/post samples.
150    ///
151    /// Cohen's d = |mean_diff| / pooled_std
152    /// where pooled_std = sqrt(((n1-1)*s1^2 + (n2-1)*s2^2) / (n1+n2-2))
153    fn cohens_d(pre: &[f64], post: &[f64]) -> Option<f64> {
154        let n1 = pre.len();
155        let n2 = post.len();
156        if n1 < 2 || n2 < 2 {
157            return None;
158        }
159
160        let mean1 = pre.iter().sum::<f64>() / n1 as f64;
161        let mean2 = post.iter().sum::<f64>() / n2 as f64;
162
163        let var1 = pre.iter().map(|x| (x - mean1).powi(2)).sum::<f64>() / (n1 - 1) as f64;
164        let var2 = post.iter().map(|x| (x - mean2).powi(2)).sum::<f64>() / (n2 - 1) as f64;
165
166        let pooled_var = ((n1 - 1) as f64 * var1 + (n2 - 1) as f64 * var2) / (n1 + n2 - 2) as f64;
167        let pooled_std = pooled_var.sqrt();
168
169        if pooled_std < f64::EPSILON {
170            return None;
171        }
172
173        Some((mean2 - mean1).abs() / pooled_std)
174    }
175
176    /// Compute average Cohen's d across all interventions with sample data.
177    fn compute_avg_effect_size(interventions: &[InterventionData]) -> f64 {
178        let effect_sizes: Vec<f64> = interventions
179            .iter()
180            .filter_map(|i| Self::cohens_d(&i.pre_intervention_values, &i.post_intervention_values))
181            .collect();
182
183        if effect_sizes.is_empty() {
184            0.0
185        } else {
186            effect_sizes.iter().sum::<f64>() / effect_sizes.len() as f64
187        }
188    }
189
190    /// Evaluate causal model data.
191    pub fn evaluate(
192        &self,
193        edges: &[CausalEdgeData],
194        interventions: &[InterventionData],
195    ) -> EvalResult<CausalModelEvaluation> {
196        let mut issues = Vec::new();
197
198        // 1. Edge correlation sign accuracy
199        let sign_correct = edges
200            .iter()
201            .filter(|e| {
202                // Signs match: both positive or both negative
203                e.expected_sign * e.observed_correlation > 0.0
204                    || (e.expected_sign.abs() < f64::EPSILON && e.observed_correlation.abs() < 0.05)
205            })
206            .count();
207        let edge_correlation_sign_accuracy = if edges.is_empty() {
208            1.0
209        } else {
210            sign_correct as f64 / edges.len() as f64
211        };
212
213        // 2. Topological consistency (DAG check)
214        let topological_consistency = if edges.is_empty() {
215            true
216        } else {
217            Self::is_dag(edges)
218        };
219
220        // 3. Intervention effect direction
221        let intervention_correct = interventions
222            .iter()
223            .filter(|i| i.expected_direction * i.observed_change > 0.0)
224            .count();
225        let intervention_effect_accuracy = if interventions.is_empty() {
226            1.0
227        } else {
228            intervention_correct as f64 / interventions.len() as f64
229        };
230
231        // 4. Intervention magnitude accuracy
232        let magnitude_within_bounds = interventions
233            .iter()
234            .filter(|i| {
235                if i.expected_magnitude.abs() < f64::EPSILON {
236                    // Cannot compute ratio when expected magnitude is zero
237                    false
238                } else {
239                    let ratio = i.observed_change.abs() / i.expected_magnitude.abs();
240                    (0.25..=4.0).contains(&ratio)
241                }
242            })
243            .count();
244        let intervention_magnitude_accuracy = if interventions.is_empty() {
245            1.0
246        } else {
247            magnitude_within_bounds as f64 / interventions.len() as f64
248        };
249
250        // 5. Average effect size (Cohen's d)
251        let avg_effect_size = Self::compute_avg_effect_size(interventions);
252
253        // Check thresholds
254        if edge_correlation_sign_accuracy < self.thresholds.min_sign_accuracy {
255            issues.push(format!(
256                "Edge sign accuracy {:.3} < {:.3}",
257                edge_correlation_sign_accuracy, self.thresholds.min_sign_accuracy
258            ));
259        }
260        if !topological_consistency {
261            issues.push("Causal graph contains cycles (not a DAG)".to_string());
262        }
263        if intervention_effect_accuracy < self.thresholds.min_intervention_accuracy {
264            issues.push(format!(
265                "Intervention accuracy {:.3} < {:.3}",
266                intervention_effect_accuracy, self.thresholds.min_intervention_accuracy
267            ));
268        }
269        if intervention_magnitude_accuracy < self.thresholds.min_magnitude_accuracy {
270            issues.push(format!(
271                "Intervention magnitude accuracy {:.3} < {:.3}",
272                intervention_magnitude_accuracy, self.thresholds.min_magnitude_accuracy
273            ));
274        }
275
276        let passes = issues.is_empty();
277
278        Ok(CausalModelEvaluation {
279            edge_correlation_sign_accuracy,
280            topological_consistency,
281            intervention_effect_accuracy,
282            intervention_magnitude_accuracy,
283            avg_effect_size,
284            total_edges: edges.len(),
285            total_interventions: interventions.len(),
286            passes,
287            issues,
288        })
289    }
290}
291
292impl Default for CausalModelEvaluator {
293    fn default() -> Self {
294        Self::new()
295    }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the edge `source -> target` with the given expected sign and
    /// observed correlation.
    fn edge(
        source: &str,
        target: &str,
        expected_sign: f64,
        observed_correlation: f64,
    ) -> CausalEdgeData {
        CausalEdgeData {
            source: source.to_string(),
            target: target.to_string(),
            expected_sign,
            observed_correlation,
        }
    }

    /// Build an intervention on `variable` measured on `target`.
    ///
    /// The numeric arguments are, in order: expected direction, observed change,
    /// expected magnitude, followed by the pre- and post-intervention samples.
    fn intervention(
        variable: &str,
        target: &str,
        expected_direction: f64,
        observed_change: f64,
        expected_magnitude: f64,
        pre: &[f64],
        post: &[f64],
    ) -> InterventionData {
        InterventionData {
            intervention_variable: variable.to_string(),
            expected_direction,
            observed_change,
            target_variable: target.to_string(),
            expected_magnitude,
            pre_intervention_values: pre.to_vec(),
            post_intervention_values: post.to_vec(),
        }
    }

    #[test]
    fn test_valid_causal_model() {
        let evaluator = CausalModelEvaluator::new();
        let edges = vec![
            edge("revenue", "profit", 1.0, 0.85),
            edge("cost", "profit", -1.0, -0.70),
        ];
        let interventions = vec![intervention(
            "revenue",
            "profit",
            1.0,
            5000.0,
            5000.0,
            &[100.0, 110.0, 105.0, 95.0, 108.0],
            &[200.0, 210.0, 205.0, 195.0, 208.0],
        )];

        let result = evaluator.evaluate(&edges, &interventions).unwrap();
        assert!(result.passes);
        assert!(result.topological_consistency);
        assert_eq!(result.edge_correlation_sign_accuracy, 1.0);
    }

    #[test]
    fn test_cyclic_graph() {
        let evaluator = CausalModelEvaluator::new();
        let edges = vec![
            edge("A", "B", 1.0, 0.5),
            edge("B", "C", 1.0, 0.5),
            edge("C", "A", 1.0, 0.5), // Closes the cycle A -> B -> C -> A
        ];

        let result = evaluator.evaluate(&edges, &[]).unwrap();
        assert!(!result.topological_consistency);
        assert!(!result.passes);
    }

    #[test]
    fn test_self_loop_is_a_cycle() {
        let evaluator = CausalModelEvaluator::new();
        let edges = vec![edge("A", "A", 1.0, 0.5)];

        let result = evaluator.evaluate(&edges, &[]).unwrap();
        assert!(!result.topological_consistency);
        assert!(!result.passes);
    }

    #[test]
    fn test_wrong_signs() {
        let evaluator = CausalModelEvaluator::new();
        // Expected positive, observed negative
        let edges = vec![edge("revenue", "profit", 1.0, -0.5)];

        let result = evaluator.evaluate(&edges, &[]).unwrap();
        assert!(!result.passes);
        assert_eq!(result.edge_correlation_sign_accuracy, 0.0);
    }

    #[test]
    fn test_empty() {
        let evaluator = CausalModelEvaluator::new();
        let result = evaluator.evaluate(&[], &[]).unwrap();
        assert!(result.passes);
    }

    #[test]
    fn test_intervention_magnitude_within_bounds() {
        let evaluator = CausalModelEvaluator::new();
        let edges = vec![edge("price", "demand", -1.0, -0.6)];
        // All interventions have observed magnitude within 0.25x to 4.0x of expected
        let interventions = vec![
            // ratio = 120 / 100 = 1.2
            intervention(
                "price",
                "demand",
                -1.0,
                -120.0,
                100.0,
                &[500.0, 510.0, 490.0, 505.0, 495.0],
                &[380.0, 390.0, 370.0, 385.0, 375.0],
            ),
            // ratio = 200 / 150 = 1.33
            intervention(
                "price",
                "demand",
                -1.0,
                -200.0,
                150.0,
                &[600.0, 610.0, 590.0, 605.0, 595.0],
                &[400.0, 410.0, 390.0, 405.0, 395.0],
            ),
            // ratio = 50 / 60 = 0.83
            intervention(
                "price",
                "demand",
                -1.0,
                -50.0,
                60.0,
                &[300.0, 310.0, 290.0, 305.0, 295.0],
                &[250.0, 260.0, 240.0, 255.0, 245.0],
            ),
        ];

        let result = evaluator.evaluate(&edges, &interventions).unwrap();
        assert_eq!(result.intervention_magnitude_accuracy, 1.0);
        assert!(result.avg_effect_size > 0.0);
        assert!(result.passes);
    }

    #[test]
    fn test_intervention_magnitude_out_of_bounds() {
        let evaluator = CausalModelEvaluator::new();
        let edges = vec![edge("marketing", "sales", 1.0, 0.7)];
        // Most interventions have extreme magnitudes (outside 0.25x to 4.0x)
        let interventions = vec![
            // ratio = 10 / 1000 = 0.01, below 0.25
            intervention(
                "marketing",
                "sales",
                1.0,
                10.0,
                1000.0,
                &[100.0, 105.0, 95.0],
                &[110.0, 115.0, 105.0],
            ),
            // ratio = 50000 / 100 = 500.0, above 4.0
            intervention(
                "marketing",
                "sales",
                1.0,
                50000.0,
                100.0,
                &[200.0, 210.0, 190.0],
                &[50200.0, 50210.0, 50190.0],
            ),
            // ratio = 5 / 500 = 0.01, below 0.25
            intervention(
                "marketing",
                "sales",
                1.0,
                5.0,
                500.0,
                &[100.0, 105.0, 95.0],
                &[105.0, 110.0, 100.0],
            ),
            // ratio = 150 / 100 = 1.5, within bounds (the one pass)
            intervention(
                "marketing",
                "sales",
                1.0,
                150.0,
                100.0,
                &[100.0, 105.0, 95.0],
                &[250.0, 255.0, 245.0],
            ),
        ];

        let result = evaluator.evaluate(&edges, &interventions).unwrap();
        // Only 1 out of 4 is within bounds => 0.25 < 0.60 (default threshold)
        assert_eq!(result.intervention_magnitude_accuracy, 0.25);
        assert!(!result.passes);
        assert!(result
            .issues
            .iter()
            .any(|i| i.contains("magnitude accuracy")));
    }

    #[test]
    fn test_effect_size_computation() {
        let evaluator = CausalModelEvaluator::new();
        let edges = vec![edge("treatment", "outcome", 1.0, 0.9)];

        // Known pre/post samples for verifying Cohen's d by hand:
        // pre:  mean = 100, var = (25 + 0 + 25 + 0 + 0) / 4 = 12.5
        // post: mean = 120, var = (25 + 0 + 25 + 0 + 0) / 4 = 12.5
        // pooled_var = (4 * 12.5 + 4 * 12.5) / 8 = 12.5
        // pooled_std = sqrt(12.5) = 3.536
        // Cohen's d  = |120 - 100| / 3.536 = 5.657
        let interventions = vec![intervention(
            "treatment",
            "outcome",
            1.0,
            20.0,
            20.0,
            &[95.0, 100.0, 105.0, 100.0, 100.0],
            &[115.0, 120.0, 125.0, 120.0, 120.0],
        )];

        let result = evaluator.evaluate(&edges, &interventions).unwrap();
        assert!(result.avg_effect_size > 5.0);
        assert!((result.avg_effect_size - 5.657).abs() < 0.1);

        // Also test with multiple interventions
        let interventions_multi = vec![
            // pre mean = 50, post mean = 60, both variances = 4.0
            // => pooled_std = 2.0, d = 10 / 2 = 5.0
            intervention(
                "a",
                "b",
                1.0,
                10.0,
                10.0,
                &[48.0, 50.0, 52.0],
                &[58.0, 60.0, 62.0],
            ),
            // Constant samples: pooled_std is zero, so Cohen's d is undefined and
            // this intervention is left out of the average entirely.
            intervention(
                "c",
                "d",
                1.0,
                0.1,
                0.1,
                &[0.0, 0.0, 0.0],
                &[0.0, 0.0, 0.0],
            ),
        ];

        let result2 = evaluator.evaluate(&edges, &interventions_multi).unwrap();
        // Only the first intervention contributes, so the average equals its d.
        assert!((result2.avg_effect_size - 5.0).abs() < 0.01);
    }
}