1use std::collections::HashMap;
7
8use super::graph::{CausalGraph, CausalMechanism};
9
/// Aggregated outcome of validating samples against a causal graph.
///
/// Built by `CausalValidator::validate_causal_structure`.
#[derive(Debug, Clone)]
pub struct CausalValidationReport {
    /// `true` only when every entry in `checks` passed.
    pub valid: bool,
    /// Result of each individual check, in the order the checks were run.
    pub checks: Vec<CausalCheck>,
    /// The `details` text of every check that did not pass.
    pub violations: Vec<String>,
}
20
/// Outcome of a single validation check.
#[derive(Debug, Clone)]
pub struct CausalCheck {
    /// Identifier of the check (for example `"edge_correlation_signs"`).
    pub name: String,
    /// Whether the check succeeded.
    pub passed: bool,
    /// Human-readable summary of what the check measured or why it failed.
    pub details: String,
}
31
/// Stateless namespace for the causal-structure validation routines.
pub struct CausalValidator;
34
35impl CausalValidator {
36 pub fn validate_causal_structure(
43 samples: &[HashMap<String, f64>],
44 graph: &CausalGraph,
45 ) -> CausalValidationReport {
46 let mut checks = Vec::new();
47 let mut violations = Vec::new();
48
49 let sign_check = Self::check_edge_correlation_signs(samples, graph);
51 if !sign_check.passed {
52 violations.push(sign_check.details.clone());
53 }
54 checks.push(sign_check);
55
56 let strength_check = Self::check_non_edge_weakness(samples, graph);
58 if !strength_check.passed {
59 violations.push(strength_check.details.clone());
60 }
61 checks.push(strength_check);
62
63 let topo_check = Self::check_topological_consistency(samples, graph);
65 if !topo_check.passed {
66 violations.push(topo_check.details.clone());
67 }
68 checks.push(topo_check);
69
70 let valid = checks.iter().all(|c| c.passed);
71
72 CausalValidationReport {
73 valid,
74 checks,
75 violations,
76 }
77 }
78
79 fn check_edge_correlation_signs(
82 samples: &[HashMap<String, f64>],
83 graph: &CausalGraph,
84 ) -> CausalCheck {
85 let mut total_edges = 0;
86 let mut correct_signs = 0u32;
87 let mut mismatches = Vec::new();
88
89 for edge in &graph.edges {
90 let expected_sign = Self::mechanism_sign(&edge.mechanism);
91 if expected_sign == 0 || matches!(edge.mechanism, CausalMechanism::Threshold { .. }) {
95 continue;
96 }
97
98 total_edges += 1;
99
100 let parent_vals: Vec<f64> = samples
101 .iter()
102 .filter_map(|s| s.get(&edge.from).copied())
103 .collect();
104 let child_vals: Vec<f64> = samples
105 .iter()
106 .filter_map(|s| s.get(&edge.to).copied())
107 .collect();
108
109 let corr = pearson_correlation(&parent_vals, &child_vals);
110
111 if (expected_sign > 0 && corr > -0.05) || (expected_sign < 0 && corr < 0.05) {
112 correct_signs += 1;
113 } else {
114 mismatches.push(format!(
115 "{} -> {}: expected sign {}, got correlation {:.4}",
116 edge.from, edge.to, expected_sign, corr
117 ));
118 }
119 }
120
121 let passed = mismatches.is_empty();
122 let details = if passed {
123 format!(
124 "All {}/{} edges have correct correlation signs",
125 correct_signs, total_edges
126 )
127 } else {
128 format!(
129 "{}/{} edges have incorrect signs: {}",
130 mismatches.len(),
131 total_edges,
132 mismatches.join("; ")
133 )
134 };
135
136 CausalCheck {
137 name: "edge_correlation_signs".to_string(),
138 passed,
139 details,
140 }
141 }
142
143 fn check_non_edge_weakness(
145 samples: &[HashMap<String, f64>],
146 graph: &CausalGraph,
147 ) -> CausalCheck {
148 let var_names = graph.variable_names();
149
150 let mut edge_corrs = Vec::new();
152 for edge in &graph.edges {
153 let parent_vals: Vec<f64> = samples
154 .iter()
155 .filter_map(|s| s.get(&edge.from).copied())
156 .collect();
157 let child_vals: Vec<f64> = samples
158 .iter()
159 .filter_map(|s| s.get(&edge.to).copied())
160 .collect();
161 let corr = pearson_correlation(&parent_vals, &child_vals).abs();
162 if corr.is_finite() {
163 edge_corrs.push(corr);
164 }
165 }
166
167 let edge_pairs: std::collections::HashSet<(&str, &str)> = graph
169 .edges
170 .iter()
171 .map(|e| (e.from.as_str(), e.to.as_str()))
172 .collect();
173
174 let mut non_edge_corrs = Vec::new();
176 for (i, &vi) in var_names.iter().enumerate() {
177 for &vj in var_names.iter().skip(i + 1) {
178 if edge_pairs.contains(&(vi, vj)) || edge_pairs.contains(&(vj, vi)) {
179 continue;
180 }
181 let vals_i: Vec<f64> = samples.iter().filter_map(|s| s.get(vi).copied()).collect();
182 let vals_j: Vec<f64> = samples.iter().filter_map(|s| s.get(vj).copied()).collect();
183 let corr = pearson_correlation(&vals_i, &vals_j).abs();
184 if corr.is_finite() {
185 non_edge_corrs.push(corr);
186 }
187 }
188 }
189
190 let avg_edge = if edge_corrs.is_empty() {
191 0.0
192 } else {
193 edge_corrs.iter().sum::<f64>() / edge_corrs.len() as f64
194 };
195
196 let avg_non_edge = if non_edge_corrs.is_empty() {
197 0.0
198 } else {
199 non_edge_corrs.iter().sum::<f64>() / non_edge_corrs.len() as f64
200 };
201
202 let passed = non_edge_corrs.is_empty() || avg_non_edge <= avg_edge + 0.1;
204
205 let details = format!(
206 "Avg edge correlation: {:.4}, avg non-edge correlation: {:.4}",
207 avg_edge, avg_non_edge
208 );
209
210 CausalCheck {
211 name: "non_edge_weakness".to_string(),
212 passed,
213 details,
214 }
215 }
216
217 fn check_topological_consistency(
222 samples: &[HashMap<String, f64>],
223 graph: &CausalGraph,
224 ) -> CausalCheck {
225 let mut total_checked = 0;
226 let mut consistent = 0;
227
228 for edge in &graph.edges {
229 let expected_sign = Self::mechanism_sign(&edge.mechanism);
230 if expected_sign == 0 {
231 continue;
232 }
233
234 let mut parent_vals: Vec<f64> = samples
235 .iter()
236 .filter_map(|s| s.get(&edge.from).copied())
237 .collect();
238 parent_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
239
240 if parent_vals.is_empty() {
241 continue;
242 }
243
244 let median_idx = parent_vals.len() / 2;
245 let median = parent_vals[median_idx];
246
247 let child_low: Vec<f64> = samples
249 .iter()
250 .filter(|s| s.get(&edge.from).copied().unwrap_or(0.0) <= median)
251 .filter_map(|s| s.get(&edge.to).copied())
252 .collect();
253
254 let child_high: Vec<f64> = samples
255 .iter()
256 .filter(|s| s.get(&edge.from).copied().unwrap_or(0.0) > median)
257 .filter_map(|s| s.get(&edge.to).copied())
258 .collect();
259
260 if child_low.is_empty() || child_high.is_empty() {
261 continue;
262 }
263
264 let mean_low = child_low.iter().sum::<f64>() / child_low.len() as f64;
265 let mean_high = child_high.iter().sum::<f64>() / child_high.len() as f64;
266
267 total_checked += 1;
268
269 let actual_sign = if mean_high > mean_low + 1e-10 {
271 1
272 } else if mean_high < mean_low - 1e-10 {
273 -1
274 } else {
275 0
276 };
277
278 if actual_sign == expected_sign || actual_sign == 0 {
279 consistent += 1;
280 }
281 }
282
283 let passed = total_checked == 0 || consistent >= total_checked / 2;
284 let details = format!(
285 "{}/{} edges show consistent conditional mean ordering",
286 consistent, total_checked
287 );
288
289 CausalCheck {
290 name: "topological_consistency".to_string(),
291 passed,
292 details,
293 }
294 }
295
296 fn mechanism_sign(mechanism: &CausalMechanism) -> i32 {
299 match mechanism {
300 CausalMechanism::Linear { coefficient } => {
301 if *coefficient > 0.0 {
302 1
303 } else if *coefficient < 0.0 {
304 -1
305 } else {
306 0
307 }
308 }
309 CausalMechanism::Threshold { .. } => {
310 1
312 }
313 CausalMechanism::Logistic { scale, .. } => {
314 if *scale > 0.0 {
315 1
316 } else if *scale < 0.0 {
317 -1
318 } else {
319 0
320 }
321 }
322 CausalMechanism::Polynomial { coefficients } => {
323 for coeff in coefficients.iter().rev() {
325 if *coeff > 0.0 {
326 return 1;
327 } else if *coeff < 0.0 {
328 return -1;
329 }
330 }
331 0
332 }
333 }
334 }
335}
336
/// Computes the Pearson correlation coefficient between `x` and `y`.
///
/// Only the first `min(x.len(), y.len())` elements of each slice are used.
/// Returns `0.0` when fewer than two pairs are available or when either series
/// is constant (up to floating-point rounding, judged relative to the
/// magnitude of the data so the result does not depend on the scale of the
/// inputs). Otherwise the result lies in `[-1.0, 1.0]`; NaN inputs produce NaN.
fn pearson_correlation(x: &[f64], y: &[f64]) -> f64 {
    let n = x.len().min(y.len());
    if n < 2 {
        return 0.0;
    }
    let (x, y) = (&x[..n], &y[..n]);
    let count = n as f64;

    let mean_x = x.iter().sum::<f64>() / count;
    let mean_y = y.iter().sum::<f64>() / count;

    let mut sum_xy = 0.0;
    let mut sum_x2 = 0.0;
    let mut sum_y2 = 0.0;
    for (&xi, &yi) in x.iter().zip(y) {
        let dx = xi - mean_x;
        let dy = yi - mean_y;
        sum_xy += dx * dy;
        sum_x2 += dx * dx;
        sum_y2 += dy * dy;
    }

    // A series counts as constant when its squared deviations are no larger
    // than what rounding error in the mean alone could produce. The bound is
    // relative to the largest magnitude, so tiny-valued but genuinely varying
    // data is not mistaken for a constant.
    let is_constant = |values: &[f64], sum_sq: f64| {
        let max_abs = values.iter().fold(0.0_f64, |acc, v| acc.max(v.abs()));
        let rounding = count * f64::EPSILON * max_abs;
        sum_sq <= count * rounding * rounding
    };
    if is_constant(x, sum_x2) || is_constant(y, sum_y2) {
        return 0.0;
    }

    // Taking the square roots separately keeps the denominator from
    // underflowing or overflowing; the clamp removes rounding overshoot.
    (sum_xy / (sum_x2.sqrt() * sum_y2.sqrt())).clamp(-1.0, 1.0)
}
366
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::causal::graph::CausalGraph;
    use crate::causal::scm::StructuralCausalModel;

    /// Data generated from the graph's own SCM must satisfy every check.
    #[test]
    fn test_causal_validation_passes_on_correct_data() {
        let template = CausalGraph::fraud_detection_template();
        let model = StructuralCausalModel::new(template.clone()).unwrap();
        let data = model.generate(1000, 42).unwrap();

        let outcome = CausalValidator::validate_causal_structure(&data, &template);

        assert!(
            outcome.valid,
            "Validation should pass on correctly generated data. Violations: {:?}",
            outcome.violations
        );
        assert_eq!(outcome.checks.len(), 3);
        assert!(outcome.violations.is_empty());
    }

    /// Rotating one column by half the sample count decouples it from its
    /// parents, so at least one check has to fail.
    #[test]
    fn test_causal_validation_detects_shuffled_columns() {
        let template = CausalGraph::fraud_detection_template();
        let model = StructuralCausalModel::new(template.clone()).unwrap();
        let mut data = model.generate(2000, 42).unwrap();

        let n = data.len();
        let half = n / 2;
        let original_column: Vec<f64> = data
            .iter()
            .filter_map(|row| row.get("fraud_probability").copied())
            .collect();

        for (idx, row) in data.iter_mut().enumerate() {
            let replacement = original_column[(idx + half) % n];
            row.insert("fraud_probability".to_string(), replacement);
        }

        let outcome = CausalValidator::validate_causal_structure(&data, &template);

        let has_failure = !outcome.checks.iter().all(|check| check.passed);
        assert!(
            has_failure,
            "Validation should detect broken causal structure. Checks: {:?}",
            outcome.checks
        );
    }

    /// `y = 2x` is perfectly positively correlated with `x`.
    #[test]
    fn test_causal_pearson_correlation_perfect_positive() {
        let xs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = [2.0, 4.0, 6.0, 8.0, 10.0];
        let corr = pearson_correlation(&xs, &ys);
        assert!(
            (corr - 1.0).abs() < 1e-10,
            "Perfect positive correlation expected, got {}",
            corr
        );
    }

    /// A strictly decreasing linear relation gives correlation -1.
    #[test]
    fn test_causal_pearson_correlation_perfect_negative() {
        let xs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = [10.0, 8.0, 6.0, 4.0, 2.0];
        let corr = pearson_correlation(&xs, &ys);
        assert!(
            (corr - (-1.0)).abs() < 1e-10,
            "Perfect negative correlation expected, got {}",
            corr
        );
    }

    /// A constant series has no variance, so the correlation is reported as 0.
    #[test]
    fn test_causal_pearson_correlation_constant() {
        let xs = [1.0, 1.0, 1.0, 1.0];
        let ys = [2.0, 4.0, 6.0, 8.0];
        let corr = pearson_correlation(&xs, &ys);
        assert!(
            corr.abs() < 1e-10,
            "Correlation with constant should be 0, got {}",
            corr
        );
    }

    /// The report lists the three checks in a fixed order, each with details.
    #[test]
    fn test_causal_validation_report_structure() {
        let template = CausalGraph::fraud_detection_template();
        let model = StructuralCausalModel::new(template.clone()).unwrap();
        let data = model.generate(200, 42).unwrap();

        let outcome = CausalValidator::validate_causal_structure(&data, &template);

        assert_eq!(outcome.checks.len(), 3);
        let names: Vec<&str> = outcome
            .checks
            .iter()
            .map(|check| check.name.as_str())
            .collect();
        assert_eq!(
            names,
            [
                "edge_correlation_signs",
                "non_edge_weakness",
                "topological_consistency"
            ]
        );

        assert!(outcome
            .checks
            .iter()
            .all(|check| !check.details.is_empty()));
    }

    /// On the revenue-cycle template at least two of the three checks pass.
    #[test]
    fn test_causal_validation_revenue_cycle() {
        let template = CausalGraph::revenue_cycle_template();
        let model = StructuralCausalModel::new(template.clone()).unwrap();
        let data = model.generate(1000, 99).unwrap();

        let outcome = CausalValidator::validate_causal_structure(&data, &template);

        let passing: usize = outcome
            .checks
            .iter()
            .map(|check| usize::from(check.passed))
            .sum();
        assert!(
            passing >= 2,
            "At least 2 of 3 checks should pass. Checks: {:?}",
            outcome.checks
        );
    }
}