1use crate::error::EvalResult;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9
/// One directed edge of a hypothesised causal graph, paired with the
/// correlation actually observed between its endpoints.
#[derive(Debug, Clone)]
pub struct CausalEdgeData {
    /// Name of the cause node; edges sharing a name share a graph node.
    pub source: String,
    /// Name of the effect node.
    pub target: String,
    /// Expected sign of the effect. Positive or negative values are matched
    /// against the sign of `observed_correlation`; a value within
    /// `f64::EPSILON` of zero means "no relationship expected".
    /// NOTE(review): presumably one of -1.0 / 0.0 / 1.0 — confirm with callers.
    pub expected_sign: f64,
    /// Correlation measured in the data between `source` and `target`.
    /// NOTE(review): assumed to be a normalised coefficient (roughly [-1, 1]),
    /// since the evaluator's near-zero tolerance is 0.05 — confirm.
    pub observed_correlation: f64,
}
22
/// Outcome of one intervention: a variable was changed and the resulting
/// change in a target variable was recorded.
#[derive(Debug, Clone)]
pub struct InterventionData {
    /// Variable that was intervened on. Not read by the evaluator's scoring;
    /// kept to identify the intervention.
    pub intervention_variable: String,
    /// Expected direction of the target's change; only its sign matters, since
    /// scoring checks `expected_direction * observed_change > 0.0`.
    pub expected_direction: f64,
    /// Observed change in the target, in the target's own (unnormalised) units.
    pub observed_change: f64,
    /// Variable whose change was measured. Not read by the evaluator's scoring.
    pub target_variable: String,
}
35
/// Minimum accuracies a causal model must reach to pass evaluation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CausalThresholds {
    /// Minimum fraction of edges whose observed correlation agrees with the
    /// expected sign. An accuracy strictly below this value is reported as an issue.
    pub min_sign_accuracy: f64,
    /// Minimum fraction of interventions whose observed change moves in the
    /// expected direction. An accuracy strictly below this value is reported as an issue.
    pub min_intervention_accuracy: f64,
}
44
45impl Default for CausalThresholds {
46 fn default() -> Self {
47 Self {
48 min_sign_accuracy: 0.80,
49 min_intervention_accuracy: 0.70,
50 }
51 }
52}
53
/// Result of [`CausalModelEvaluator::evaluate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CausalModelEvaluation {
    /// Fraction of edges whose observed correlation matches the expected sign
    /// (`1.0` when there are no edges).
    pub edge_correlation_sign_accuracy: f64,
    /// `true` when the edge set forms a directed acyclic graph.
    pub topological_consistency: bool,
    /// Fraction of interventions whose observed change has the expected sign
    /// (`1.0` when there are no interventions).
    pub intervention_effect_accuracy: f64,
    /// Number of edges that were evaluated.
    pub total_edges: usize,
    /// Number of interventions that were evaluated.
    pub total_interventions: usize,
    /// `true` exactly when `issues` is empty.
    pub passes: bool,
    /// Human-readable description of every failed check.
    pub issues: Vec<String>,
}
72
73pub struct CausalModelEvaluator {
75 thresholds: CausalThresholds,
76}
77
78impl CausalModelEvaluator {
79 pub fn new() -> Self {
81 Self {
82 thresholds: CausalThresholds::default(),
83 }
84 }
85
86 pub fn with_thresholds(thresholds: CausalThresholds) -> Self {
88 Self { thresholds }
89 }
90
91 fn is_dag(edges: &[CausalEdgeData]) -> bool {
93 let mut in_degree: HashMap<&str, usize> = HashMap::new();
94 let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
95
96 for edge in edges {
98 in_degree.entry(edge.source.as_str()).or_insert(0);
99 in_degree.entry(edge.target.as_str()).or_insert(0);
100 adj.entry(edge.source.as_str()).or_default();
101 }
102
103 for edge in edges {
105 adj.entry(edge.source.as_str())
106 .or_default()
107 .push(edge.target.as_str());
108 *in_degree.entry(edge.target.as_str()).or_insert(0) += 1;
109 }
110
111 let mut queue: VecDeque<&str> = in_degree
113 .iter()
114 .filter(|(_, &d)| d == 0)
115 .map(|(&n, _)| n)
116 .collect();
117 let mut visited = 0usize;
118
119 while let Some(node) = queue.pop_front() {
120 visited += 1;
121 if let Some(neighbors) = adj.get(node) {
122 for &neighbor in neighbors {
123 if let Some(d) = in_degree.get_mut(neighbor) {
124 *d -= 1;
125 if *d == 0 {
126 queue.push_back(neighbor);
127 }
128 }
129 }
130 }
131 }
132
133 visited == in_degree.len()
134 }
135
136 pub fn evaluate(
138 &self,
139 edges: &[CausalEdgeData],
140 interventions: &[InterventionData],
141 ) -> EvalResult<CausalModelEvaluation> {
142 let mut issues = Vec::new();
143
144 let sign_correct = edges
146 .iter()
147 .filter(|e| {
148 e.expected_sign * e.observed_correlation > 0.0
150 || (e.expected_sign.abs() < f64::EPSILON && e.observed_correlation.abs() < 0.05)
151 })
152 .count();
153 let edge_correlation_sign_accuracy = if edges.is_empty() {
154 1.0
155 } else {
156 sign_correct as f64 / edges.len() as f64
157 };
158
159 let topological_consistency = if edges.is_empty() {
161 true
162 } else {
163 Self::is_dag(edges)
164 };
165
166 let intervention_correct = interventions
168 .iter()
169 .filter(|i| i.expected_direction * i.observed_change > 0.0)
170 .count();
171 let intervention_effect_accuracy = if interventions.is_empty() {
172 1.0
173 } else {
174 intervention_correct as f64 / interventions.len() as f64
175 };
176
177 if edge_correlation_sign_accuracy < self.thresholds.min_sign_accuracy {
179 issues.push(format!(
180 "Edge sign accuracy {:.3} < {:.3}",
181 edge_correlation_sign_accuracy, self.thresholds.min_sign_accuracy
182 ));
183 }
184 if !topological_consistency {
185 issues.push("Causal graph contains cycles (not a DAG)".to_string());
186 }
187 if intervention_effect_accuracy < self.thresholds.min_intervention_accuracy {
188 issues.push(format!(
189 "Intervention accuracy {:.3} < {:.3}",
190 intervention_effect_accuracy, self.thresholds.min_intervention_accuracy
191 ));
192 }
193
194 let passes = issues.is_empty();
195
196 Ok(CausalModelEvaluation {
197 edge_correlation_sign_accuracy,
198 topological_consistency,
199 intervention_effect_accuracy,
200 total_edges: edges.len(),
201 total_interventions: interventions.len(),
202 passes,
203 issues,
204 })
205 }
206}
207
208impl Default for CausalModelEvaluator {
209 fn default() -> Self {
210 Self::new()
211 }
212}
213
#[cfg(test)]
// Unwrapping is acceptable in tests: a failure should abort the test loudly.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds an edge fixture without repeating the `to_string` boilerplate.
    fn edge(
        source: &str,
        target: &str,
        expected_sign: f64,
        observed_correlation: f64,
    ) -> CausalEdgeData {
        CausalEdgeData {
            source: source.to_string(),
            target: target.to_string(),
            expected_sign,
            observed_correlation,
        }
    }

    #[test]
    fn test_valid_causal_model() {
        let evaluator = CausalModelEvaluator::new();
        let edges = vec![
            edge("revenue", "profit", 1.0, 0.85),
            edge("cost", "profit", -1.0, -0.70),
        ];
        let interventions = vec![InterventionData {
            intervention_variable: "revenue".to_string(),
            expected_direction: 1.0,
            observed_change: 5000.0,
            target_variable: "profit".to_string(),
        }];

        let result = evaluator.evaluate(&edges, &interventions).unwrap();
        assert!(result.passes);
        assert!(result.topological_consistency);
        assert_eq!(result.edge_correlation_sign_accuracy, 1.0);
    }

    #[test]
    fn test_cyclic_graph() {
        let evaluator = CausalModelEvaluator::new();
        let edges = vec![
            edge("A", "B", 1.0, 0.5),
            edge("B", "C", 1.0, 0.5),
            edge("C", "A", 1.0, 0.5),
        ];

        let result = evaluator.evaluate(&edges, &[]).unwrap();
        assert!(!result.topological_consistency);
        assert!(!result.passes);
    }

    #[test]
    fn test_self_loop_is_cycle() {
        let evaluator = CausalModelEvaluator::new();
        let result = evaluator.evaluate(&[edge("A", "A", 1.0, 0.5)], &[]).unwrap();
        assert!(!result.topological_consistency);
        assert!(!result.passes);
    }

    #[test]
    fn test_wrong_signs() {
        let evaluator = CausalModelEvaluator::new();
        let edges = vec![edge("revenue", "profit", 1.0, -0.5)];

        let result = evaluator.evaluate(&edges, &[]).unwrap();
        assert!(!result.passes);
        assert_eq!(result.edge_correlation_sign_accuracy, 0.0);
    }

    #[test]
    fn test_zero_expected_sign_tolerance() {
        let evaluator = CausalModelEvaluator::new();

        // Near-zero correlation satisfies a "no relationship" expectation...
        let near_zero = evaluator.evaluate(&[edge("a", "b", 0.0, 0.01)], &[]).unwrap();
        assert_eq!(near_zero.edge_correlation_sign_accuracy, 1.0);

        // ...but a clear correlation does not.
        let too_strong = evaluator.evaluate(&[edge("a", "b", 0.0, 0.2)], &[]).unwrap();
        assert_eq!(too_strong.edge_correlation_sign_accuracy, 0.0);
        assert!(!too_strong.passes);
    }

    #[test]
    fn test_wrong_intervention_direction() {
        let evaluator = CausalModelEvaluator::new();
        let interventions = vec![InterventionData {
            intervention_variable: "cost".to_string(),
            expected_direction: -1.0,
            observed_change: 250.0,
            target_variable: "profit".to_string(),
        }];

        let result = evaluator.evaluate(&[], &interventions).unwrap();
        assert_eq!(result.intervention_effect_accuracy, 0.0);
        assert!(!result.passes);
        assert_eq!(result.issues.len(), 1);
        assert!(result.issues[0].starts_with("Intervention accuracy"));
    }

    #[test]
    fn test_custom_thresholds_boundary_passes() {
        // The threshold check is strict (<), so hitting it exactly still passes.
        let evaluator = CausalModelEvaluator::with_thresholds(CausalThresholds {
            min_sign_accuracy: 0.5,
            min_intervention_accuracy: 0.0,
        });
        let edges = vec![edge("a", "b", 1.0, 0.3), edge("b", "c", 1.0, -0.3)];

        let result = evaluator.evaluate(&edges, &[]).unwrap();
        assert_eq!(result.edge_correlation_sign_accuracy, 0.5);
        assert!(result.passes);
    }

    #[test]
    fn test_empty() {
        let evaluator = CausalModelEvaluator::new();
        let result = evaluator.evaluate(&[], &[]).unwrap();
        assert!(result.passes);
        assert_eq!(result.total_edges, 0);
        assert_eq!(result.total_interventions, 0);
    }
}