Skip to main content

datasynth_eval/causal/
mod.rs

1//! Causal model evaluator.
2//!
3//! Validates causal model preservation including edge correlation sign accuracy,
4//! topological consistency (DAG structure), and intervention effect direction.
5
6use crate::error::EvalResult;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9
/// Causal edge data for validation.
///
/// One directed edge of the causal graph, pointing from `source` to `target`, together
/// with the correlation actually observed between the two variables.
#[derive(Debug, Clone, PartialEq)]
pub struct CausalEdgeData {
    /// Source variable (the edge points from here to `target`).
    pub source: String,
    /// Target variable.
    pub target: String,
    /// Expected correlation sign: +1.0 for positive, -1.0 for negative, 0.0 for none.
    ///
    /// Only the sign is compared with `observed_correlation`: the edge is correct when
    /// both values are strictly of the same sign. A value of (approximately) 0.0 means
    /// "no correlation expected"; such an edge is correct when the magnitude of the
    /// observed correlation is below 0.05.
    pub expected_sign: f64,
    /// Observed correlation between source and target.
    pub observed_correlation: f64,
}
22
/// Intervention data for validation.
///
/// Records an intervention on `intervention_variable` and the change observed in
/// `target_variable`.
#[derive(Debug, Clone, PartialEq)]
pub struct InterventionData {
    /// Variable intervened upon.
    pub intervention_variable: String,
    /// Expected effect direction on target: +1.0 for increase, -1.0 for decrease.
    ///
    /// Only the sign matters. An intervention is counted as correct when this value and
    /// `observed_change` are strictly of the same sign, so 0.0 (or an observed change of
    /// exactly 0.0) never counts as correct.
    pub expected_direction: f64,
    /// Observed change in target.
    pub observed_change: f64,
    /// Target variable.
    pub target_variable: String,
}
35
36/// Thresholds for causal model evaluation.
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct CausalThresholds {
39    /// Minimum edge correlation sign accuracy.
40    pub min_sign_accuracy: f64,
41    /// Minimum intervention effect accuracy.
42    pub min_intervention_accuracy: f64,
43}
44
45impl Default for CausalThresholds {
46    fn default() -> Self {
47        Self {
48            min_sign_accuracy: 0.80,
49            min_intervention_accuracy: 0.70,
50        }
51    }
52}
53
54/// Results of causal model evaluation.
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct CausalModelEvaluation {
57    /// Edge correlation sign accuracy: fraction of edges with correct sign.
58    pub edge_correlation_sign_accuracy: f64,
59    /// Whether the graph is topologically consistent (DAG - no cycles).
60    pub topological_consistency: bool,
61    /// Intervention effect accuracy: fraction with correct direction.
62    pub intervention_effect_accuracy: f64,
63    /// Total edges evaluated.
64    pub total_edges: usize,
65    /// Total interventions evaluated.
66    pub total_interventions: usize,
67    /// Overall pass/fail.
68    pub passes: bool,
69    /// Issues found.
70    pub issues: Vec<String>,
71}
72
73/// Evaluator for causal model preservation.
74pub struct CausalModelEvaluator {
75    thresholds: CausalThresholds,
76}
77
78impl CausalModelEvaluator {
79    /// Create a new evaluator with default thresholds.
80    pub fn new() -> Self {
81        Self {
82            thresholds: CausalThresholds::default(),
83        }
84    }
85
86    /// Create with custom thresholds.
87    pub fn with_thresholds(thresholds: CausalThresholds) -> Self {
88        Self { thresholds }
89    }
90
91    /// Check if the edge set forms a DAG (no cycles) using Kahn's algorithm.
92    fn is_dag(edges: &[CausalEdgeData]) -> bool {
93        let mut in_degree: HashMap<&str, usize> = HashMap::new();
94        let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
95
96        // Initialize all nodes
97        for edge in edges {
98            in_degree.entry(edge.source.as_str()).or_insert(0);
99            in_degree.entry(edge.target.as_str()).or_insert(0);
100            adj.entry(edge.source.as_str()).or_default();
101        }
102
103        // Build adjacency and in-degree
104        for edge in edges {
105            adj.entry(edge.source.as_str())
106                .or_default()
107                .push(edge.target.as_str());
108            *in_degree.entry(edge.target.as_str()).or_insert(0) += 1;
109        }
110
111        // Kahn's algorithm
112        let mut queue: VecDeque<&str> = in_degree
113            .iter()
114            .filter(|(_, &d)| d == 0)
115            .map(|(&n, _)| n)
116            .collect();
117        let mut visited = 0usize;
118
119        while let Some(node) = queue.pop_front() {
120            visited += 1;
121            if let Some(neighbors) = adj.get(node) {
122                for &neighbor in neighbors {
123                    if let Some(d) = in_degree.get_mut(neighbor) {
124                        *d -= 1;
125                        if *d == 0 {
126                            queue.push_back(neighbor);
127                        }
128                    }
129                }
130            }
131        }
132
133        visited == in_degree.len()
134    }
135
136    /// Evaluate causal model data.
137    pub fn evaluate(
138        &self,
139        edges: &[CausalEdgeData],
140        interventions: &[InterventionData],
141    ) -> EvalResult<CausalModelEvaluation> {
142        let mut issues = Vec::new();
143
144        // 1. Edge correlation sign accuracy
145        let sign_correct = edges
146            .iter()
147            .filter(|e| {
148                // Signs match: both positive or both negative
149                e.expected_sign * e.observed_correlation > 0.0
150                    || (e.expected_sign.abs() < f64::EPSILON && e.observed_correlation.abs() < 0.05)
151            })
152            .count();
153        let edge_correlation_sign_accuracy = if edges.is_empty() {
154            1.0
155        } else {
156            sign_correct as f64 / edges.len() as f64
157        };
158
159        // 2. Topological consistency (DAG check)
160        let topological_consistency = if edges.is_empty() {
161            true
162        } else {
163            Self::is_dag(edges)
164        };
165
166        // 3. Intervention effect direction
167        let intervention_correct = interventions
168            .iter()
169            .filter(|i| i.expected_direction * i.observed_change > 0.0)
170            .count();
171        let intervention_effect_accuracy = if interventions.is_empty() {
172            1.0
173        } else {
174            intervention_correct as f64 / interventions.len() as f64
175        };
176
177        // Check thresholds
178        if edge_correlation_sign_accuracy < self.thresholds.min_sign_accuracy {
179            issues.push(format!(
180                "Edge sign accuracy {:.3} < {:.3}",
181                edge_correlation_sign_accuracy, self.thresholds.min_sign_accuracy
182            ));
183        }
184        if !topological_consistency {
185            issues.push("Causal graph contains cycles (not a DAG)".to_string());
186        }
187        if intervention_effect_accuracy < self.thresholds.min_intervention_accuracy {
188            issues.push(format!(
189                "Intervention accuracy {:.3} < {:.3}",
190                intervention_effect_accuracy, self.thresholds.min_intervention_accuracy
191            ));
192        }
193
194        let passes = issues.is_empty();
195
196        Ok(CausalModelEvaluation {
197            edge_correlation_sign_accuracy,
198            topological_consistency,
199            intervention_effect_accuracy,
200            total_edges: edges.len(),
201            total_interventions: interventions.len(),
202            passes,
203            issues,
204        })
205    }
206}
207
208impl Default for CausalModelEvaluator {
209    fn default() -> Self {
210        Self::new()
211    }
212}
213
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Build an edge from string slices to keep fixtures compact.
    fn edge(
        source: &str,
        target: &str,
        expected_sign: f64,
        observed_correlation: f64,
    ) -> CausalEdgeData {
        CausalEdgeData {
            source: source.to_string(),
            target: target.to_string(),
            expected_sign,
            observed_correlation,
        }
    }

    /// Build an intervention on `x` targeting `y`; only direction and change matter.
    fn intervention(expected_direction: f64, observed_change: f64) -> InterventionData {
        InterventionData {
            intervention_variable: "x".to_string(),
            expected_direction,
            observed_change,
            target_variable: "y".to_string(),
        }
    }

    #[test]
    fn test_valid_causal_model() {
        let evaluator = CausalModelEvaluator::new();
        let edges = vec![
            edge("revenue", "profit", 1.0, 0.85),
            edge("cost", "profit", -1.0, -0.70),
        ];
        let interventions = vec![InterventionData {
            intervention_variable: "revenue".to_string(),
            expected_direction: 1.0,
            observed_change: 5000.0,
            target_variable: "profit".to_string(),
        }];

        let result = evaluator.evaluate(&edges, &interventions).unwrap();
        assert!(result.passes);
        assert!(result.topological_consistency);
        assert_eq!(result.edge_correlation_sign_accuracy, 1.0);
    }

    #[test]
    fn test_cyclic_graph() {
        let evaluator = CausalModelEvaluator::new();
        let edges = vec![
            edge("A", "B", 1.0, 0.5),
            edge("B", "C", 1.0, 0.5),
            edge("C", "A", 1.0, 0.5), // Cycle!
        ];

        let result = evaluator.evaluate(&edges, &[]).unwrap();
        assert!(!result.topological_consistency);
        assert!(!result.passes);
    }

    #[test]
    fn test_self_loop_is_cycle() {
        let evaluator = CausalModelEvaluator::new();
        let result = evaluator
            .evaluate(&[edge("A", "A", 1.0, 0.5)], &[])
            .unwrap();
        assert!(!result.topological_consistency);
        assert_eq!(
            result.issues,
            vec!["Causal graph contains cycles (not a DAG)".to_string()]
        );
    }

    #[test]
    fn test_diamond_with_duplicate_edge_is_dag() {
        let evaluator = CausalModelEvaluator::new();
        let edges = vec![
            edge("A", "B", 1.0, 0.3),
            edge("A", "C", 1.0, 0.3),
            edge("B", "D", 1.0, 0.3),
            edge("C", "D", 1.0, 0.3),
            edge("A", "B", 1.0, 0.3), // duplicate edge
        ];
        let result = evaluator.evaluate(&edges, &[]).unwrap();
        assert!(result.topological_consistency);
        assert!(result.passes);
        assert_eq!(result.total_edges, 5);
    }

    #[test]
    fn test_wrong_signs() {
        let evaluator = CausalModelEvaluator::new();
        let edges = vec![edge("revenue", "profit", 1.0, -0.5)]; // Wrong sign

        let result = evaluator.evaluate(&edges, &[]).unwrap();
        assert!(!result.passes);
        assert_eq!(result.edge_correlation_sign_accuracy, 0.0);
    }

    #[test]
    fn test_zero_expected_sign_tolerance() {
        let evaluator = CausalModelEvaluator::new();
        let within = evaluator
            .evaluate(&[edge("A", "B", 0.0, 0.04), edge("C", "D", 0.0, -0.04)], &[])
            .unwrap();
        assert_eq!(within.edge_correlation_sign_accuracy, 1.0);

        let outside = evaluator
            .evaluate(&[edge("A", "B", 0.0, 0.06)], &[])
            .unwrap();
        assert_eq!(outside.edge_correlation_sign_accuracy, 0.0);
        assert_eq!(
            outside.issues,
            vec!["Edge sign accuracy 0.000 < 0.800".to_string()]
        );
    }

    #[test]
    fn test_threshold_issue_messages() {
        let evaluator = CausalModelEvaluator::new();
        let edges = vec![edge("A", "B", 1.0, 0.5), edge("B", "C", -1.0, 0.5)];
        let interventions = vec![
            intervention(1.0, 10.0),
            intervention(-1.0, 10.0),
            intervention(1.0, 0.0), // zero change never matches
        ];

        let result = evaluator.evaluate(&edges, &interventions).unwrap();
        assert!(!result.passes);
        assert_eq!(result.total_interventions, 3);
        assert_eq!(
            result.issues,
            vec![
                "Edge sign accuracy 0.500 < 0.800".to_string(),
                "Intervention accuracy 0.333 < 0.700".to_string(),
            ]
        );
    }

    #[test]
    fn test_custom_thresholds_boundary() {
        // Accuracy exactly equal to the threshold passes (strict `<` comparison).
        let evaluator = CausalModelEvaluator::with_thresholds(CausalThresholds {
            min_sign_accuracy: 0.5,
            min_intervention_accuracy: 0.0,
        });
        let edges = vec![edge("A", "B", 1.0, 0.5), edge("B", "C", -1.0, 0.5)];
        let result = evaluator
            .evaluate(&edges, &[intervention(1.0, -1.0)])
            .unwrap();
        assert!(result.passes);
    }

    #[test]
    fn test_empty() {
        let evaluator = CausalModelEvaluator::new();
        let result = evaluator.evaluate(&[], &[]).unwrap();
        assert!(result.passes);
    }
}