1use std::collections::HashMap;
7
8use super::graph::{CausalGraph, CausalMechanism};
9
10#[derive(Debug, Clone)]
12pub struct CausalValidationReport {
13 pub valid: bool,
15 pub checks: Vec<CausalCheck>,
17 pub violations: Vec<String>,
19}
20
/// Result of a single named validation check.
#[derive(Clone, Debug)]
pub struct CausalCheck {
    /// Stable identifier of the check, e.g. `"edge_correlation_signs"`.
    pub name: String,
    /// Whether the check succeeded.
    pub passed: bool,
    /// Human-readable summary of what the check observed.
    pub details: String,
}
31
/// Stateless entry point for checking that generated samples respect the
/// structure of a causal graph; all functionality is exposed through
/// associated functions.
pub struct CausalValidator;
34
35impl CausalValidator {
36 pub fn validate_causal_structure(
43 samples: &[HashMap<String, f64>],
44 graph: &CausalGraph,
45 ) -> CausalValidationReport {
46 let mut checks = Vec::new();
47 let mut violations = Vec::new();
48
49 let sign_check = Self::check_edge_correlation_signs(samples, graph);
51 if !sign_check.passed {
52 violations.push(sign_check.details.clone());
53 }
54 checks.push(sign_check);
55
56 let strength_check = Self::check_non_edge_weakness(samples, graph);
58 if !strength_check.passed {
59 violations.push(strength_check.details.clone());
60 }
61 checks.push(strength_check);
62
63 let topo_check = Self::check_topological_consistency(samples, graph);
65 if !topo_check.passed {
66 violations.push(topo_check.details.clone());
67 }
68 checks.push(topo_check);
69
70 let valid = checks.iter().all(|c| c.passed);
71
72 CausalValidationReport {
73 valid,
74 checks,
75 violations,
76 }
77 }
78
79 fn check_edge_correlation_signs(
82 samples: &[HashMap<String, f64>],
83 graph: &CausalGraph,
84 ) -> CausalCheck {
85 let mut total_edges = 0;
86 let mut _correct_signs = 0u32;
87 let mut mismatches = Vec::new();
88
89 for edge in &graph.edges {
90 let expected_sign = Self::mechanism_sign(&edge.mechanism);
91 if expected_sign == 0 || matches!(edge.mechanism, CausalMechanism::Threshold { .. }) {
95 continue;
96 }
97
98 total_edges += 1;
99
100 let parent_vals: Vec<f64> = samples
101 .iter()
102 .filter_map(|s| s.get(&edge.from).copied())
103 .collect();
104 let child_vals: Vec<f64> = samples
105 .iter()
106 .filter_map(|s| s.get(&edge.to).copied())
107 .collect();
108
109 let corr = pearson_correlation(&parent_vals, &child_vals);
110
111 if (expected_sign > 0 && corr > -0.05) || (expected_sign < 0 && corr < 0.05) {
112 _correct_signs += 1;
113 } else {
114 mismatches.push(format!(
115 "{} -> {}: expected sign {}, got correlation {:.4}",
116 edge.from, edge.to, expected_sign, corr
117 ));
118 }
119 }
120
121 let passed = mismatches.is_empty();
122 let details = if passed {
123 format!("All {} edges have correct correlation signs", total_edges)
124 } else {
125 format!(
126 "{}/{} edges have incorrect signs: {}",
127 mismatches.len(),
128 total_edges,
129 mismatches.join("; ")
130 )
131 };
132
133 CausalCheck {
134 name: "edge_correlation_signs".to_string(),
135 passed,
136 details,
137 }
138 }
139
140 fn check_non_edge_weakness(
142 samples: &[HashMap<String, f64>],
143 graph: &CausalGraph,
144 ) -> CausalCheck {
145 let var_names = graph.variable_names();
146
147 let mut edge_corrs = Vec::new();
149 for edge in &graph.edges {
150 let parent_vals: Vec<f64> = samples
151 .iter()
152 .filter_map(|s| s.get(&edge.from).copied())
153 .collect();
154 let child_vals: Vec<f64> = samples
155 .iter()
156 .filter_map(|s| s.get(&edge.to).copied())
157 .collect();
158 let corr = pearson_correlation(&parent_vals, &child_vals).abs();
159 if corr.is_finite() {
160 edge_corrs.push(corr);
161 }
162 }
163
164 let edge_pairs: std::collections::HashSet<(&str, &str)> = graph
166 .edges
167 .iter()
168 .map(|e| (e.from.as_str(), e.to.as_str()))
169 .collect();
170
171 let mut non_edge_corrs = Vec::new();
173 for (i, &vi) in var_names.iter().enumerate() {
174 for &vj in var_names.iter().skip(i + 1) {
175 if edge_pairs.contains(&(vi, vj)) || edge_pairs.contains(&(vj, vi)) {
176 continue;
177 }
178 let vals_i: Vec<f64> = samples.iter().filter_map(|s| s.get(vi).copied()).collect();
179 let vals_j: Vec<f64> = samples.iter().filter_map(|s| s.get(vj).copied()).collect();
180 let corr = pearson_correlation(&vals_i, &vals_j).abs();
181 if corr.is_finite() {
182 non_edge_corrs.push(corr);
183 }
184 }
185 }
186
187 let avg_edge = if edge_corrs.is_empty() {
188 0.0
189 } else {
190 edge_corrs.iter().sum::<f64>() / edge_corrs.len() as f64
191 };
192
193 let avg_non_edge = if non_edge_corrs.is_empty() {
194 0.0
195 } else {
196 non_edge_corrs.iter().sum::<f64>() / non_edge_corrs.len() as f64
197 };
198
199 let passed = non_edge_corrs.is_empty() || avg_non_edge <= avg_edge + 0.1;
201
202 let details = format!(
203 "Avg edge correlation: {:.4}, avg non-edge correlation: {:.4}",
204 avg_edge, avg_non_edge
205 );
206
207 CausalCheck {
208 name: "non_edge_weakness".to_string(),
209 passed,
210 details,
211 }
212 }
213
214 fn check_topological_consistency(
219 samples: &[HashMap<String, f64>],
220 graph: &CausalGraph,
221 ) -> CausalCheck {
222 let mut total_checked = 0;
223 let mut consistent = 0;
224
225 for edge in &graph.edges {
226 let expected_sign = Self::mechanism_sign(&edge.mechanism);
227 if expected_sign == 0 {
228 continue;
229 }
230
231 let mut parent_vals: Vec<f64> = samples
232 .iter()
233 .filter_map(|s| s.get(&edge.from).copied())
234 .collect();
235 parent_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
236
237 if parent_vals.is_empty() {
238 continue;
239 }
240
241 let median_idx = parent_vals.len() / 2;
242 let median = parent_vals[median_idx];
243
244 let child_low: Vec<f64> = samples
246 .iter()
247 .filter(|s| s.get(&edge.from).copied().unwrap_or(0.0) <= median)
248 .filter_map(|s| s.get(&edge.to).copied())
249 .collect();
250
251 let child_high: Vec<f64> = samples
252 .iter()
253 .filter(|s| s.get(&edge.from).copied().unwrap_or(0.0) > median)
254 .filter_map(|s| s.get(&edge.to).copied())
255 .collect();
256
257 if child_low.is_empty() || child_high.is_empty() {
258 continue;
259 }
260
261 let mean_low = child_low.iter().sum::<f64>() / child_low.len() as f64;
262 let mean_high = child_high.iter().sum::<f64>() / child_high.len() as f64;
263
264 total_checked += 1;
265
266 let actual_sign = if mean_high > mean_low + 1e-10 {
268 1
269 } else if mean_high < mean_low - 1e-10 {
270 -1
271 } else {
272 0
273 };
274
275 if actual_sign == expected_sign || actual_sign == 0 {
276 consistent += 1;
277 }
278 }
279
280 let passed = total_checked == 0 || consistent >= total_checked / 2;
281 let details = format!(
282 "{}/{} edges show consistent conditional mean ordering",
283 consistent, total_checked
284 );
285
286 CausalCheck {
287 name: "topological_consistency".to_string(),
288 passed,
289 details,
290 }
291 }
292
293 fn mechanism_sign(mechanism: &CausalMechanism) -> i32 {
296 match mechanism {
297 CausalMechanism::Linear { coefficient } => {
298 if *coefficient > 0.0 {
299 1
300 } else if *coefficient < 0.0 {
301 -1
302 } else {
303 0
304 }
305 }
306 CausalMechanism::Threshold { .. } => {
307 1
309 }
310 CausalMechanism::Logistic { scale, .. } => {
311 if *scale > 0.0 {
312 1
313 } else if *scale < 0.0 {
314 -1
315 } else {
316 0
317 }
318 }
319 CausalMechanism::Polynomial { coefficients } => {
320 for coeff in coefficients.iter().rev() {
322 if *coeff > 0.0 {
323 return 1;
324 } else if *coeff < 0.0 {
325 return -1;
326 }
327 }
328 0
329 }
330 }
331 }
332}
333
/// Computes the Pearson correlation coefficient of two series.
///
/// Only the first `min(x.len(), y.len())` elements of each slice are used.
/// Returns `0.0` when fewer than two paired values are available or when the
/// product of the two sums of squared deviations is effectively zero (for
/// example, when one series is constant).
fn pearson_correlation(x: &[f64], y: &[f64]) -> f64 {
    let len = x.len().min(y.len());
    if len < 2 {
        return 0.0;
    }

    // Restrict both series to the common prefix so they can be walked in lockstep.
    let (xs, ys) = (&x[..len], &y[..len]);
    let count = len as f64;
    let x_bar = xs.iter().sum::<f64>() / count;
    let y_bar = ys.iter().sum::<f64>() / count;

    // Accumulate the cross product and both sums of squared deviations in one pass.
    let (cross, spread_x, spread_y) = xs.iter().zip(ys).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(sxy, sxx, syy), (&a, &b)| {
            let da = a - x_bar;
            let db = b - y_bar;
            (sxy + da * db, sxx + da * da, syy + db * db)
        },
    );

    let scale = (spread_x * spread_y).sqrt();
    if scale < 1e-15 {
        0.0
    } else {
        cross / scale
    }
}
363
#[cfg(test)]
// A failed setup step should fail the test immediately, so `unwrap` is fine here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::causal::graph::CausalGraph;
    use crate::causal::scm::StructuralCausalModel;

    /// Data generated from the fraud-detection template must pass every check.
    #[test]
    fn test_causal_validation_passes_on_correct_data() {
        let template = CausalGraph::fraud_detection_template();
        let model = StructuralCausalModel::new(template.clone()).unwrap();
        let data = model.generate(1000, 42).unwrap();

        let outcome = CausalValidator::validate_causal_structure(&data, &template);

        assert!(
            outcome.valid,
            "Validation should pass on correctly generated data. Violations: {:?}",
            outcome.violations
        );
        assert_eq!(outcome.checks.len(), 3);
        assert!(outcome.violations.is_empty());
    }

    /// Rotating one column by half the sample count should make at least one
    /// check fail.
    #[test]
    fn test_causal_validation_detects_shuffled_columns() {
        let template = CausalGraph::fraud_detection_template();
        let model = StructuralCausalModel::new(template.clone()).unwrap();
        let mut data = model.generate(500, 42).unwrap();

        let total = data.len();
        // Snapshot the column before overwriting it in place.
        let original_column: Vec<f64> = data
            .iter()
            .filter_map(|row| row.get("fraud_probability").copied())
            .collect();

        for (idx, row) in data.iter_mut().enumerate() {
            let source = (idx + total / 2) % total;
            row.insert("fraud_probability".to_string(), original_column[source]);
        }

        let outcome = CausalValidator::validate_causal_structure(&data, &template);

        let all_passed = outcome.checks.iter().all(|check| check.passed);
        assert!(
            !all_passed,
            "Validation should detect broken causal structure. Checks: {:?}",
            outcome.checks
        );
    }

    /// A perfectly increasing linear relation has correlation +1.
    #[test]
    fn test_causal_pearson_correlation_perfect_positive() {
        let xs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = [2.0, 4.0, 6.0, 8.0, 10.0];
        let r = pearson_correlation(&xs, &ys);
        assert!(
            (r - 1.0).abs() < 1e-10,
            "Perfect positive correlation expected, got {}",
            r
        );
    }

    /// A perfectly decreasing linear relation has correlation -1.
    #[test]
    fn test_causal_pearson_correlation_perfect_negative() {
        let xs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = [10.0, 8.0, 6.0, 4.0, 2.0];
        let r = pearson_correlation(&xs, &ys);
        assert!(
            (r - (-1.0)).abs() < 1e-10,
            "Perfect negative correlation expected, got {}",
            r
        );
    }

    /// A constant series has no variance, so the correlation is reported as 0.
    #[test]
    fn test_causal_pearson_correlation_constant() {
        let xs = [1.0, 1.0, 1.0, 1.0];
        let ys = [2.0, 4.0, 6.0, 8.0];
        let r = pearson_correlation(&xs, &ys);
        assert!(
            r.abs() < 1e-10,
            "Correlation with constant should be 0, got {}",
            r
        );
    }

    /// The report lists the three checks in a fixed order, each with details.
    #[test]
    fn test_causal_validation_report_structure() {
        let template = CausalGraph::fraud_detection_template();
        let model = StructuralCausalModel::new(template.clone()).unwrap();
        let data = model.generate(200, 42).unwrap();

        let outcome = CausalValidator::validate_causal_structure(&data, &template);

        let expected_names = [
            "edge_correlation_signs",
            "non_edge_weakness",
            "topological_consistency",
        ];
        assert_eq!(outcome.checks.len(), 3);
        for (check, expected) in outcome.checks.iter().zip(expected_names.iter()) {
            assert_eq!(check.name, *expected);
            assert!(!check.details.is_empty());
        }
    }

    /// Data from the revenue-cycle template should satisfy most checks.
    #[test]
    fn test_causal_validation_revenue_cycle() {
        let template = CausalGraph::revenue_cycle_template();
        let model = StructuralCausalModel::new(template.clone()).unwrap();
        let data = model.generate(1000, 99).unwrap();

        let outcome = CausalValidator::validate_causal_structure(&data, &template);

        let passed_count = outcome.checks.iter().filter(|check| check.passed).count();
        assert!(
            passed_count >= 2,
            "At least 2 of 3 checks should pass. Checks: {:?}",
            outcome.checks
        );
    }
}