1use std::collections::HashMap;
7
8use super::graph::{CausalGraph, CausalMechanism};
9
10#[derive(Debug, Clone)]
12pub struct CausalValidationReport {
13 pub valid: bool,
15 pub checks: Vec<CausalCheck>,
17 pub violations: Vec<String>,
19}
20
/// Result of a single named validation check.
#[derive(Clone, Debug)]
pub struct CausalCheck {
    /// Identifier of the check (for example `"edge_correlation_signs"`).
    pub name: String,
    /// Whether the check succeeded.
    pub passed: bool,
    /// Human-readable summary of what the check observed.
    pub details: String,
}
31
32pub struct CausalValidator;
34
35impl CausalValidator {
36 pub fn validate_causal_structure(
43 samples: &[HashMap<String, f64>],
44 graph: &CausalGraph,
45 ) -> CausalValidationReport {
46 let mut checks = Vec::new();
47 let mut violations = Vec::new();
48
49 let sign_check = Self::check_edge_correlation_signs(samples, graph);
51 if !sign_check.passed {
52 violations.push(sign_check.details.clone());
53 }
54 checks.push(sign_check);
55
56 let strength_check = Self::check_non_edge_weakness(samples, graph);
58 if !strength_check.passed {
59 violations.push(strength_check.details.clone());
60 }
61 checks.push(strength_check);
62
63 let topo_check = Self::check_topological_consistency(samples, graph);
65 if !topo_check.passed {
66 violations.push(topo_check.details.clone());
67 }
68 checks.push(topo_check);
69
70 let valid = checks.iter().all(|c| c.passed);
71
72 CausalValidationReport {
73 valid,
74 checks,
75 violations,
76 }
77 }
78
79 fn check_edge_correlation_signs(
82 samples: &[HashMap<String, f64>],
83 graph: &CausalGraph,
84 ) -> CausalCheck {
85 let mut total_edges = 0;
86 let mut correct_signs = 0u32;
87 let mut mismatches = Vec::new();
88
89 for edge in &graph.edges {
90 let expected_sign = Self::mechanism_sign(&edge.mechanism);
91 if expected_sign == 0 || matches!(edge.mechanism, CausalMechanism::Threshold { .. }) {
95 continue;
96 }
97
98 total_edges += 1;
99
100 let parent_vals: Vec<f64> = samples
101 .iter()
102 .filter_map(|s| s.get(&edge.from).copied())
103 .collect();
104 let child_vals: Vec<f64> = samples
105 .iter()
106 .filter_map(|s| s.get(&edge.to).copied())
107 .collect();
108
109 let corr = pearson_correlation(&parent_vals, &child_vals);
110
111 if (expected_sign > 0 && corr > -0.05) || (expected_sign < 0 && corr < 0.05) {
112 correct_signs += 1;
113 } else {
114 mismatches.push(format!(
115 "{} -> {}: expected sign {}, got correlation {:.4}",
116 edge.from, edge.to, expected_sign, corr
117 ));
118 }
119 }
120
121 let passed = mismatches.is_empty();
122 let details = if passed {
123 format!("All {correct_signs}/{total_edges} edges have correct correlation signs")
124 } else {
125 format!(
126 "{}/{} edges have incorrect signs: {}",
127 mismatches.len(),
128 total_edges,
129 mismatches.join("; ")
130 )
131 };
132
133 CausalCheck {
134 name: "edge_correlation_signs".to_string(),
135 passed,
136 details,
137 }
138 }
139
140 fn check_non_edge_weakness(
142 samples: &[HashMap<String, f64>],
143 graph: &CausalGraph,
144 ) -> CausalCheck {
145 let var_names = graph.variable_names();
146
147 let mut edge_corrs = Vec::new();
149 for edge in &graph.edges {
150 let parent_vals: Vec<f64> = samples
151 .iter()
152 .filter_map(|s| s.get(&edge.from).copied())
153 .collect();
154 let child_vals: Vec<f64> = samples
155 .iter()
156 .filter_map(|s| s.get(&edge.to).copied())
157 .collect();
158 let corr = pearson_correlation(&parent_vals, &child_vals).abs();
159 if corr.is_finite() {
160 edge_corrs.push(corr);
161 }
162 }
163
164 let edge_pairs: std::collections::HashSet<(&str, &str)> = graph
166 .edges
167 .iter()
168 .map(|e| (e.from.as_str(), e.to.as_str()))
169 .collect();
170
171 let mut non_edge_corrs = Vec::new();
173 for (i, &vi) in var_names.iter().enumerate() {
174 for &vj in var_names.iter().skip(i + 1) {
175 if edge_pairs.contains(&(vi, vj)) || edge_pairs.contains(&(vj, vi)) {
176 continue;
177 }
178 let vals_i: Vec<f64> = samples.iter().filter_map(|s| s.get(vi).copied()).collect();
179 let vals_j: Vec<f64> = samples.iter().filter_map(|s| s.get(vj).copied()).collect();
180 let corr = pearson_correlation(&vals_i, &vals_j).abs();
181 if corr.is_finite() {
182 non_edge_corrs.push(corr);
183 }
184 }
185 }
186
187 let avg_edge = if edge_corrs.is_empty() {
188 0.0
189 } else {
190 edge_corrs.iter().sum::<f64>() / edge_corrs.len() as f64
191 };
192
193 let avg_non_edge = if non_edge_corrs.is_empty() {
194 0.0
195 } else {
196 non_edge_corrs.iter().sum::<f64>() / non_edge_corrs.len() as f64
197 };
198
199 let passed = non_edge_corrs.is_empty() || avg_non_edge <= avg_edge + 0.1;
201
202 let details = format!(
203 "Avg edge correlation: {avg_edge:.4}, avg non-edge correlation: {avg_non_edge:.4}"
204 );
205
206 CausalCheck {
207 name: "non_edge_weakness".to_string(),
208 passed,
209 details,
210 }
211 }
212
213 fn check_topological_consistency(
218 samples: &[HashMap<String, f64>],
219 graph: &CausalGraph,
220 ) -> CausalCheck {
221 let mut total_checked = 0;
222 let mut consistent = 0;
223
224 for edge in &graph.edges {
225 let expected_sign = Self::mechanism_sign(&edge.mechanism);
226 if expected_sign == 0 {
227 continue;
228 }
229
230 let mut parent_vals: Vec<f64> = samples
231 .iter()
232 .filter_map(|s| s.get(&edge.from).copied())
233 .collect();
234 parent_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
235
236 if parent_vals.is_empty() {
237 continue;
238 }
239
240 let median_idx = parent_vals.len() / 2;
241 let median = parent_vals[median_idx];
242
243 let child_low: Vec<f64> = samples
245 .iter()
246 .filter(|s| s.get(&edge.from).copied().unwrap_or(0.0) <= median)
247 .filter_map(|s| s.get(&edge.to).copied())
248 .collect();
249
250 let child_high: Vec<f64> = samples
251 .iter()
252 .filter(|s| s.get(&edge.from).copied().unwrap_or(0.0) > median)
253 .filter_map(|s| s.get(&edge.to).copied())
254 .collect();
255
256 if child_low.is_empty() || child_high.is_empty() {
257 continue;
258 }
259
260 let mean_low = child_low.iter().sum::<f64>() / child_low.len() as f64;
261 let mean_high = child_high.iter().sum::<f64>() / child_high.len() as f64;
262
263 total_checked += 1;
264
265 let actual_sign = if mean_high > mean_low + 1e-10 {
267 1
268 } else if mean_high < mean_low - 1e-10 {
269 -1
270 } else {
271 0
272 };
273
274 if actual_sign == expected_sign || actual_sign == 0 {
275 consistent += 1;
276 }
277 }
278
279 let passed = total_checked == 0 || consistent >= total_checked / 2;
280 let details =
281 format!("{consistent}/{total_checked} edges show consistent conditional mean ordering");
282
283 CausalCheck {
284 name: "topological_consistency".to_string(),
285 passed,
286 details,
287 }
288 }
289
290 fn mechanism_sign(mechanism: &CausalMechanism) -> i32 {
293 match mechanism {
294 CausalMechanism::Linear { coefficient } => {
295 if *coefficient > 0.0 {
296 1
297 } else if *coefficient < 0.0 {
298 -1
299 } else {
300 0
301 }
302 }
303 CausalMechanism::Threshold { .. } => {
304 1
306 }
307 CausalMechanism::Logistic { scale, .. } => {
308 if *scale > 0.0 {
309 1
310 } else if *scale < 0.0 {
311 -1
312 } else {
313 0
314 }
315 }
316 CausalMechanism::Polynomial { coefficients } => {
317 for coeff in coefficients.iter().rev() {
319 if *coeff > 0.0 {
320 return 1;
321 } else if *coeff < 0.0 {
322 return -1;
323 }
324 }
325 0
326 }
327 }
328 }
329}
330
/// Computes the Pearson correlation coefficient of two series.
///
/// Only the first `min(x.len(), y.len())` elements of each slice are used.
/// Returns `0.0` when fewer than two paired elements are available or when
/// the product of the two sums of squared deviations is effectively zero
/// (for example, when one series is constant).
fn pearson_correlation(x: &[f64], y: &[f64]) -> f64 {
    let n = x.len().min(y.len());
    if n < 2 {
        return 0.0;
    }

    // Work on equally long prefixes so the two series can be zipped.
    let (xs, ys) = (&x[..n], &y[..n]);
    let count = n as f64;
    let mean_x = xs.iter().sum::<f64>() / count;
    let mean_y = ys.iter().sum::<f64>() / count;

    // Accumulate the cross product and both sums of squared deviations in one pass.
    let (cross, sq_x, sq_y) = xs.iter().zip(ys).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(cross, sq_x, sq_y), (&a, &b)| {
            let dx = a - mean_x;
            let dy = b - mean_y;
            (cross + dx * dy, sq_x + dx * dx, sq_y + dy * dy)
        },
    );

    let denom = (sq_x * sq_y).sqrt();
    if denom < 1e-15 {
        0.0
    } else {
        cross / denom
    }
}
360
#[cfg(test)]
// Tests unwrap freely: a failed unwrap is simply a failed test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::causal::graph::CausalGraph;
    use crate::causal::scm::StructuralCausalModel;

    /// Samples drawn from the graph's own structural model must pass every check.
    #[test]
    fn test_causal_validation_passes_on_correct_data() {
        let template = CausalGraph::fraud_detection_template();
        let model = StructuralCausalModel::new(template.clone()).unwrap();
        let data = model.generate(1000, 42).unwrap();

        let report = CausalValidator::validate_causal_structure(&data, &template);

        assert!(
            report.valid,
            "Validation should pass on correctly generated data. Violations: {:?}",
            report.violations
        );
        assert_eq!(report.checks.len(), 3);
        assert!(report.violations.is_empty());
    }

    /// Rotating one column by half the sample count breaks its link to the
    /// other variables, so at least one check must fail.
    #[test]
    fn test_causal_validation_detects_shuffled_columns() {
        let template = CausalGraph::fraud_detection_template();
        let model = StructuralCausalModel::new(template.clone()).unwrap();
        let mut data = model.generate(2000, 42).unwrap();

        let total = data.len();
        let offset = total / 2;
        let original_fp: Vec<f64> = data
            .iter()
            .filter_map(|row| row.get("fraud_probability").copied())
            .collect();

        for (idx, row) in data.iter_mut().enumerate() {
            let source_idx = (idx + offset) % total;
            row.insert("fraud_probability".to_string(), original_fp[source_idx]);
        }

        let report = CausalValidator::validate_causal_structure(&data, &template);

        assert!(
            report.checks.iter().any(|check| !check.passed),
            "Validation should detect broken causal structure. Checks: {:?}",
            report.checks
        );
    }

    /// A series and its positive multiple correlate at exactly +1.
    #[test]
    fn test_causal_pearson_correlation_perfect_positive() {
        let xs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = [2.0, 4.0, 6.0, 8.0, 10.0];
        let corr = pearson_correlation(&xs, &ys);
        assert!(
            (corr - 1.0).abs() < 1e-10,
            "Perfect positive correlation expected, got {}",
            corr
        );
    }

    /// A series and a decreasing linear function of it correlate at exactly -1.
    #[test]
    fn test_causal_pearson_correlation_perfect_negative() {
        let xs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = [10.0, 8.0, 6.0, 4.0, 2.0];
        let corr = pearson_correlation(&xs, &ys);
        assert!(
            (corr + 1.0).abs() < 1e-10,
            "Perfect negative correlation expected, got {}",
            corr
        );
    }

    /// A constant series has zero variance, which is reported as correlation 0.
    #[test]
    fn test_causal_pearson_correlation_constant() {
        let xs = [1.0, 1.0, 1.0, 1.0];
        let ys = [2.0, 4.0, 6.0, 8.0];
        let corr = pearson_correlation(&xs, &ys);
        assert!(
            corr.abs() < 1e-10,
            "Correlation with constant should be 0, got {}",
            corr
        );
    }

    /// The report lists the three checks in a fixed order, each with details.
    #[test]
    fn test_causal_validation_report_structure() {
        let template = CausalGraph::fraud_detection_template();
        let model = StructuralCausalModel::new(template.clone()).unwrap();
        let data = model.generate(200, 42).unwrap();

        let report = CausalValidator::validate_causal_structure(&data, &template);

        assert_eq!(report.checks.len(), 3);
        let names: Vec<&str> = report.checks.iter().map(|c| c.name.as_str()).collect();
        assert_eq!(
            names,
            [
                "edge_correlation_signs",
                "non_edge_weakness",
                "topological_consistency"
            ]
        );

        assert!(report.checks.iter().all(|c| !c.details.is_empty()));
    }

    /// On the revenue-cycle template at least two of the three checks must pass.
    #[test]
    fn test_causal_validation_revenue_cycle() {
        let template = CausalGraph::revenue_cycle_template();
        let model = StructuralCausalModel::new(template.clone()).unwrap();
        let data = model.generate(1000, 99).unwrap();

        let report = CausalValidator::validate_causal_structure(&data, &template);

        let passing = report.checks.iter().filter(|check| check.passed).count();
        assert!(
            passing >= 2,
            "At least 2 of 3 checks should pass. Checks: {:?}",
            report.checks
        );
    }
}