use std::collections::HashMap;
use super::graph::{CausalGraph, CausalMechanism};
/// Aggregated outcome of validating a sample set against a causal graph.
#[derive(Debug, Clone)]
pub struct CausalValidationReport {
    /// `true` only when every entry in `checks` passed.
    pub valid: bool,
    /// Result of each individual check, in the order the checks were run.
    pub checks: Vec<CausalCheck>,
    /// The `details` text of every check that did not pass.
    pub violations: Vec<String>,
}
/// Outcome of a single validation check.
#[derive(Debug, Clone)]
pub struct CausalCheck {
    /// Identifier of the check, e.g. `"edge_correlation_signs"`.
    pub name: String,
    /// Whether the check succeeded.
    pub passed: bool,
    /// Human-readable summary of what the check measured or why it failed.
    pub details: String,
}
/// Stateless namespace for the causal-structure validation routines.
///
/// All functionality is exposed through associated functions; the type carries
/// no data.
pub struct CausalValidator;
impl CausalValidator {
    /// Validates `samples` against the structure declared by `graph`.
    ///
    /// Three checks are run, in this order: `edge_correlation_signs`,
    /// `non_edge_weakness` and `topological_consistency`.
    ///
    /// # Returns
    ///
    /// A [`CausalValidationReport`] whose `checks` holds every check result,
    /// whose `violations` holds the `details` of each failed check, and whose
    /// `valid` flag is `true` only when no check failed.
    pub fn validate_causal_structure(
        samples: &[HashMap<String, f64>],
        graph: &CausalGraph,
    ) -> CausalValidationReport {
        let checks = vec![
            Self::check_edge_correlation_signs(samples, graph),
            Self::check_non_edge_weakness(samples, graph),
            Self::check_topological_consistency(samples, graph),
        ];
        let violations: Vec<String> = checks
            .iter()
            .filter(|check| !check.passed)
            .map(|check| check.details.clone())
            .collect();

        CausalValidationReport {
            // A violation is recorded for exactly the failed checks, so an
            // empty list is equivalent to "all checks passed".
            valid: violations.is_empty(),
            checks,
            violations,
        }
    }

    /// Checks that the parent/child correlation of each edge has the sign
    /// implied by the edge's mechanism.
    ///
    /// Edges whose mechanism has no sign (`mechanism_sign` returns 0) and
    /// `Threshold` edges are skipped. A small tolerance around zero is allowed
    /// so that weak correlations are not reported as sign mismatches.
    fn check_edge_correlation_signs(
        samples: &[HashMap<String, f64>],
        graph: &CausalGraph,
    ) -> CausalCheck {
        /// Correlations within this distance of zero are accepted for either sign.
        const SIGN_TOLERANCE: f64 = 0.05;

        let mut total_edges = 0u32;
        let mut correct_signs = 0u32;
        let mut mismatches = Vec::new();

        for edge in &graph.edges {
            let expected_sign = Self::mechanism_sign(&edge.mechanism);
            if expected_sign == 0 || matches!(edge.mechanism, CausalMechanism::Threshold { .. }) {
                continue;
            }
            total_edges += 1;

            // Only samples carrying both variables are used, so the i-th parent
            // value and the i-th child value always come from the same sample.
            let (parent_vals, child_vals) =
                Self::paired_values(samples, edge.from.as_str(), edge.to.as_str());
            let corr = pearson_correlation(&parent_vals, &child_vals);

            // `expected_sign` is non-zero here. A NaN correlation fails both
            // comparisons and is therefore reported as a mismatch.
            let sign_ok = if expected_sign > 0 {
                corr > -SIGN_TOLERANCE
            } else {
                corr < SIGN_TOLERANCE
            };
            if sign_ok {
                correct_signs += 1;
            } else {
                mismatches.push(format!(
                    "{} -> {}: expected sign {}, got correlation {:.4}",
                    edge.from, edge.to, expected_sign, corr
                ));
            }
        }

        let passed = mismatches.is_empty();
        let details = if passed {
            format!("All {correct_signs}/{total_edges} edges have correct correlation signs")
        } else {
            format!(
                "{}/{} edges have incorrect signs: {}",
                mismatches.len(),
                total_edges,
                mismatches.join("; ")
            )
        };

        CausalCheck {
            name: "edge_correlation_signs".to_string(),
            passed,
            details,
        }
    }

    /// Checks that variable pairs without an edge are, on average, not more
    /// strongly correlated than pairs connected by an edge.
    ///
    /// The check passes when there are no non-edge pairs with a finite
    /// correlation, or when the mean absolute non-edge correlation does not
    /// exceed the mean absolute edge correlation by more than a fixed margin.
    fn check_non_edge_weakness(
        samples: &[HashMap<String, f64>],
        graph: &CausalGraph,
    ) -> CausalCheck {
        /// Slack granted to non-edge pairs before the check fails.
        const NON_EDGE_MARGIN: f64 = 0.1;

        let var_names = graph.variable_names();

        let edge_corrs: Vec<f64> = graph
            .edges
            .iter()
            .map(|edge| {
                let (parent_vals, child_vals) =
                    Self::paired_values(samples, edge.from.as_str(), edge.to.as_str());
                pearson_correlation(&parent_vals, &child_vals).abs()
            })
            .filter(|corr| corr.is_finite())
            .collect();

        let edge_pairs: std::collections::HashSet<(&str, &str)> = graph
            .edges
            .iter()
            .map(|e| (e.from.as_str(), e.to.as_str()))
            .collect();

        let mut non_edge_corrs = Vec::new();
        for (i, &vi) in var_names.iter().enumerate() {
            for &vj in var_names.iter().skip(i + 1) {
                // Edge direction is irrelevant for deciding adjacency.
                if edge_pairs.contains(&(vi, vj)) || edge_pairs.contains(&(vj, vi)) {
                    continue;
                }
                let (vals_i, vals_j) = Self::paired_values(samples, vi, vj);
                let corr = pearson_correlation(&vals_i, &vals_j).abs();
                if corr.is_finite() {
                    non_edge_corrs.push(corr);
                }
            }
        }

        let avg_edge = Self::mean_or_zero(&edge_corrs);
        let avg_non_edge = Self::mean_or_zero(&non_edge_corrs);
        let passed = non_edge_corrs.is_empty() || avg_non_edge <= avg_edge + NON_EDGE_MARGIN;
        let details = format!(
            "Avg edge correlation: {avg_edge:.4}, avg non-edge correlation: {avg_non_edge:.4}"
        );

        CausalCheck {
            name: "non_edge_weakness".to_string(),
            passed,
            details,
        }
    }

    /// Checks that, for each signed edge, the mean child value among samples
    /// with a high parent value compares to the mean among samples with a low
    /// parent value in the direction implied by the mechanism.
    ///
    /// Samples are split at the (upper) median of the parent values. Only
    /// samples that contain both the parent and the child variable take part.
    /// Edges for which either half is empty are not counted. The check passes
    /// when no edge could be checked or when at least half of the checked
    /// edges are consistent.
    fn check_topological_consistency(
        samples: &[HashMap<String, f64>],
        graph: &CausalGraph,
    ) -> CausalCheck {
        /// Mean differences smaller than this are treated as "no difference".
        const MEAN_EPSILON: f64 = 1e-10;

        let mut total_checked = 0usize;
        let mut consistent = 0usize;

        for edge in &graph.edges {
            let expected_sign = Self::mechanism_sign(&edge.mechanism);
            if expected_sign == 0 {
                continue;
            }

            let (parent_vals, child_vals) =
                Self::paired_values(samples, edge.from.as_str(), edge.to.as_str());
            if parent_vals.is_empty() {
                continue;
            }

            // `total_cmp` is a total order even when NaN is present, which
            // `sort_by` requires of its comparator.
            let mut sorted_parents = parent_vals.clone();
            sorted_parents.sort_by(f64::total_cmp);
            let median = sorted_parents[sorted_parents.len() / 2];

            let (mut low_sum, mut low_count) = (0.0_f64, 0usize);
            let (mut high_sum, mut high_count) = (0.0_f64, 0usize);
            for (&parent, &child) in parent_vals.iter().zip(&child_vals) {
                // A NaN parent satisfies neither comparison and is left out.
                if parent <= median {
                    low_sum += child;
                    low_count += 1;
                } else if parent > median {
                    high_sum += child;
                    high_count += 1;
                }
            }
            if low_count == 0 || high_count == 0 {
                continue;
            }

            let mean_low = low_sum / low_count as f64;
            let mean_high = high_sum / high_count as f64;
            total_checked += 1;

            let actual_sign = if mean_high > mean_low + MEAN_EPSILON {
                1
            } else if mean_high < mean_low - MEAN_EPSILON {
                -1
            } else {
                0
            };
            // Indistinguishable means are not counted against the edge.
            if actual_sign == expected_sign || actual_sign == 0 {
                consistent += 1;
            }
        }

        // Compare without integer division so that "at least half" is not
        // rounded down for odd counts (e.g. 0 of 1 must not pass).
        let passed = total_checked == 0 || 2 * consistent >= total_checked;
        let details =
            format!("{consistent}/{total_checked} edges show consistent conditional mean ordering");

        CausalCheck {
            name: "topological_consistency".to_string(),
            passed,
            details,
        }
    }

    /// Returns the expected direction of influence of `mechanism`:
    /// `1` for positive, `-1` for negative, `0` when no direction can be derived.
    ///
    /// * `Linear` — sign of the coefficient.
    /// * `Threshold` — always `1`.
    /// * `Logistic` — sign of the scale.
    /// * `Polynomial` — sign of the last coefficient in the list that has a
    ///   sign; `0` if there is none.
    fn mechanism_sign(mechanism: &CausalMechanism) -> i32 {
        match mechanism {
            CausalMechanism::Linear { coefficient } => Self::sign_of(*coefficient),
            CausalMechanism::Threshold { .. } => 1,
            CausalMechanism::Logistic { scale, .. } => Self::sign_of(*scale),
            CausalMechanism::Polynomial { coefficients } => coefficients
                .iter()
                .rev()
                .map(|coeff| Self::sign_of(*coeff))
                .find(|sign| *sign != 0)
                .unwrap_or(0),
        }
    }

    /// Collects the values of `first` and `second` from every sample that
    /// contains both variables.
    ///
    /// The two returned vectors have equal length and are index-aligned: the
    /// values at position `i` originate from the same sample.
    fn paired_values(
        samples: &[HashMap<String, f64>],
        first: &str,
        second: &str,
    ) -> (Vec<f64>, Vec<f64>) {
        samples
            .iter()
            .filter_map(|sample| Some((*sample.get(first)?, *sample.get(second)?)))
            .unzip()
    }

    /// Arithmetic mean of `values`, or `0.0` for an empty slice.
    fn mean_or_zero(values: &[f64]) -> f64 {
        if values.is_empty() {
            0.0
        } else {
            values.iter().sum::<f64>() / values.len() as f64
        }
    }

    /// Sign of `value` as `1`, `-1` or `0`; zero and NaN both map to `0`.
    fn sign_of(value: f64) -> i32 {
        if value > 0.0 {
            1
        } else if value < 0.0 {
            -1
        } else {
            0
        }
    }
}
/// Computes the Pearson correlation coefficient of `x` and `y`.
///
/// When the slices differ in length, only the first `min(x.len(), y.len())`
/// elements of each are used.
///
/// # Returns
///
/// * `0.0` when fewer than two paired observations are available or when the
///   product of the two standard-deviation terms is (numerically) zero, e.g.
///   because one input is constant.
/// * Otherwise the coefficient, clamped to `[-1.0, 1.0]`.
/// * NaN when the inputs contain NaN; callers that need a usable number should
///   test the result with `is_finite`.
fn pearson_correlation(x: &[f64], y: &[f64]) -> f64 {
    let n = x.len().min(y.len());
    if n < 2 {
        return 0.0;
    }
    let (x, y) = (&x[..n], &y[..n]);
    let count = n as f64;
    let mean_x = x.iter().sum::<f64>() / count;
    let mean_y = y.iter().sum::<f64>() / count;

    let mut sum_xy = 0.0;
    let mut sum_x2 = 0.0;
    let mut sum_y2 = 0.0;
    for (&xi, &yi) in x.iter().zip(y) {
        let dx = xi - mean_x;
        let dy = yi - mean_y;
        sum_xy += dx * dy;
        sum_x2 += dx * dx;
        sum_y2 += dy * dy;
    }

    // Take the square roots before multiplying: the product of the two sums of
    // squares can overflow to infinity even though each sum is representable,
    // which would turn a strong correlation into 0.
    let denom = sum_x2.sqrt() * sum_y2.sqrt();
    if denom < 1e-15 {
        0.0
    } else {
        // Rounding can push the ratio marginally outside the valid range.
        (sum_xy / denom).clamp(-1.0, 1.0)
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::causal::graph::CausalGraph;
    use crate::causal::scm::StructuralCausalModel;

    /// Samples drawn from the model built on the fraud-detection graph must
    /// pass every check and produce no violations.
    #[test]
    fn test_causal_validation_passes_on_correct_data() {
        let graph = CausalGraph::fraud_detection_template();
        let model = StructuralCausalModel::new(graph.clone()).unwrap();
        let samples = model.generate(1000, 42).unwrap();

        let report = CausalValidator::validate_causal_structure(&samples, &graph);

        assert!(
            report.valid,
            "Validation should pass on correctly generated data. Violations: {:?}",
            report.violations
        );
        assert_eq!(report.checks.len(), 3);
        assert!(report.violations.is_empty());
    }

    /// Rotating one column by half the sample count detaches it from the
    /// other columns of its row; at least one check has to notice.
    #[test]
    fn test_causal_validation_detects_shuffled_columns() {
        let graph = CausalGraph::fraud_detection_template();
        let model = StructuralCausalModel::new(graph.clone()).unwrap();
        let mut samples = model.generate(2000, 42).unwrap();

        let n = samples.len();
        let original_column: Vec<f64> = samples
            .iter()
            .filter_map(|row| row.get("fraud_probability").copied())
            .collect();
        for (idx, row) in samples.iter_mut().enumerate() {
            let rotated = original_column[(idx + n / 2) % n];
            row.insert("fraud_probability".to_string(), rotated);
        }

        let report = CausalValidator::validate_causal_structure(&samples, &graph);

        assert!(
            report.checks.iter().any(|check| !check.passed),
            "Validation should detect broken causal structure. Checks: {:?}",
            report.checks
        );
    }

    /// `y = 2x` is perfectly positively correlated with `x`.
    #[test]
    fn test_causal_pearson_correlation_perfect_positive() {
        let xs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = [2.0, 4.0, 6.0, 8.0, 10.0];
        let corr = pearson_correlation(&xs, &ys);
        assert!(
            (corr - 1.0).abs() < 1e-10,
            "Perfect positive correlation expected, got {}",
            corr
        );
    }

    /// A strictly decreasing linear relation yields a correlation of -1.
    #[test]
    fn test_causal_pearson_correlation_perfect_negative() {
        let xs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = [10.0, 8.0, 6.0, 4.0, 2.0];
        let corr = pearson_correlation(&xs, &ys);
        assert!(
            (corr + 1.0).abs() < 1e-10,
            "Perfect negative correlation expected, got {}",
            corr
        );
    }

    /// A constant input has zero variance; the function reports 0 for it.
    #[test]
    fn test_causal_pearson_correlation_constant() {
        let xs = [1.0, 1.0, 1.0, 1.0];
        let ys = [2.0, 4.0, 6.0, 8.0];
        let corr = pearson_correlation(&xs, &ys);
        assert!(
            corr.abs() < 1e-10,
            "Correlation with constant should be 0, got {}",
            corr
        );
    }

    /// The report lists the three checks in a fixed order, each with details.
    #[test]
    fn test_causal_validation_report_structure() {
        let graph = CausalGraph::fraud_detection_template();
        let model = StructuralCausalModel::new(graph.clone()).unwrap();
        let samples = model.generate(200, 42).unwrap();

        let report = CausalValidator::validate_causal_structure(&samples, &graph);

        assert_eq!(report.checks.len(), 3);
        let expected_names = [
            "edge_correlation_signs",
            "non_edge_weakness",
            "topological_consistency",
        ];
        for (check, expected) in report.checks.iter().zip(expected_names) {
            assert_eq!(check.name, expected);
        }
        for check in &report.checks {
            assert!(!check.details.is_empty());
        }
    }

    /// On the revenue-cycle graph at least two of the three checks must pass.
    #[test]
    fn test_causal_validation_revenue_cycle() {
        let graph = CausalGraph::revenue_cycle_template();
        let model = StructuralCausalModel::new(graph.clone()).unwrap();
        let samples = model.generate(1000, 99).unwrap();

        let report = CausalValidator::validate_causal_structure(&samples, &graph);

        let passing = report.checks.iter().filter(|check| check.passed).count();
        assert!(
            passing >= 2,
            "At least 2 of 3 checks should pass. Checks: {:?}",
            report.checks
        );
    }
}