Skip to main content

datasynth_core/causal/
validation.rs

1//! Causal structure validation.
2//!
3//! Validates that generated samples respect the causal structure defined by the graph,
4//! checking correlation signs, edge strength, and topological consistency.
5
6use std::collections::HashMap;
7
8use super::graph::{CausalGraph, CausalMechanism};
9
10/// Report from causal structure validation.
11#[derive(Debug, Clone)]
12pub struct CausalValidationReport {
13    /// Whether all checks passed.
14    pub valid: bool,
15    /// Individual check results.
16    pub checks: Vec<CausalCheck>,
17    /// Human-readable violation descriptions.
18    pub violations: Vec<String>,
19}
20
/// Outcome of one named validation check.
#[derive(Debug, Clone)]
pub struct CausalCheck {
    /// Stable identifier of the check (e.g. `"edge_correlation_signs"`).
    pub name: String,
    /// `true` when the check succeeded.
    pub passed: bool,
    /// Human-readable summary of what the check measured.
    pub details: String,
}
31
/// Stateless entry point for checking samples against a causal graph.
///
/// All functionality is exposed as associated functions; the type carries no data.
pub struct CausalValidator;
34
35impl CausalValidator {
36    /// Validate that samples respect the causal structure of the graph.
37    ///
38    /// Performs three checks:
39    /// 1. Edge correlation signs match mechanism coefficient signs
40    /// 2. Non-edges have weaker average correlation than edges
41    /// 3. Topological ordering holds in conditional means
42    pub fn validate_causal_structure(
43        samples: &[HashMap<String, f64>],
44        graph: &CausalGraph,
45    ) -> CausalValidationReport {
46        let mut checks = Vec::new();
47        let mut violations = Vec::new();
48
49        // Check 1: Edge correlation signs
50        let sign_check = Self::check_edge_correlation_signs(samples, graph);
51        if !sign_check.passed {
52            violations.push(sign_check.details.clone());
53        }
54        checks.push(sign_check);
55
56        // Check 2: Non-edges have weaker correlation than edges
57        let strength_check = Self::check_non_edge_weakness(samples, graph);
58        if !strength_check.passed {
59            violations.push(strength_check.details.clone());
60        }
61        checks.push(strength_check);
62
63        // Check 3: Topological ordering in conditional means
64        let topo_check = Self::check_topological_consistency(samples, graph);
65        if !topo_check.passed {
66            violations.push(topo_check.details.clone());
67        }
68        checks.push(topo_check);
69
70        let valid = checks.iter().all(|c| c.passed);
71
72        CausalValidationReport {
73            valid,
74            checks,
75            violations,
76        }
77    }
78
79    /// Check 1: For each edge, verify correlation between parent and child
80    /// has the expected sign (based on mechanism coefficient sign).
81    fn check_edge_correlation_signs(
82        samples: &[HashMap<String, f64>],
83        graph: &CausalGraph,
84    ) -> CausalCheck {
85        let mut total_edges = 0;
86        let mut correct_signs = 0u32;
87        let mut mismatches = Vec::new();
88
89        for edge in &graph.edges {
90            let expected_sign = Self::mechanism_sign(&edge.mechanism);
91            // Skip edges where we can't reliably determine expected sign.
92            // Threshold mechanisms produce binary outputs where correlation
93            // with the continuous parent is often very weak or indeterminate.
94            if expected_sign == 0 || matches!(edge.mechanism, CausalMechanism::Threshold { .. }) {
95                continue;
96            }
97
98            total_edges += 1;
99
100            let parent_vals: Vec<f64> = samples
101                .iter()
102                .filter_map(|s| s.get(&edge.from).copied())
103                .collect();
104            let child_vals: Vec<f64> = samples
105                .iter()
106                .filter_map(|s| s.get(&edge.to).copied())
107                .collect();
108
109            let corr = pearson_correlation(&parent_vals, &child_vals);
110
111            if (expected_sign > 0 && corr > -0.05) || (expected_sign < 0 && corr < 0.05) {
112                correct_signs += 1;
113            } else {
114                mismatches.push(format!(
115                    "{} -> {}: expected sign {}, got correlation {:.4}",
116                    edge.from, edge.to, expected_sign, corr
117                ));
118            }
119        }
120
121        let passed = mismatches.is_empty();
122        let details = if passed {
123            format!("All {correct_signs}/{total_edges} edges have correct correlation signs")
124        } else {
125            format!(
126                "{}/{} edges have incorrect signs: {}",
127                mismatches.len(),
128                total_edges,
129                mismatches.join("; ")
130            )
131        };
132
133        CausalCheck {
134            name: "edge_correlation_signs".to_string(),
135            passed,
136            details,
137        }
138    }
139
140    /// Check 2: Verify non-edges have weaker correlation than edges (on average).
141    fn check_non_edge_weakness(
142        samples: &[HashMap<String, f64>],
143        graph: &CausalGraph,
144    ) -> CausalCheck {
145        let var_names = graph.variable_names();
146
147        // Compute average absolute correlation for edges
148        let mut edge_corrs = Vec::new();
149        for edge in &graph.edges {
150            let parent_vals: Vec<f64> = samples
151                .iter()
152                .filter_map(|s| s.get(&edge.from).copied())
153                .collect();
154            let child_vals: Vec<f64> = samples
155                .iter()
156                .filter_map(|s| s.get(&edge.to).copied())
157                .collect();
158            let corr = pearson_correlation(&parent_vals, &child_vals).abs();
159            if corr.is_finite() {
160                edge_corrs.push(corr);
161            }
162        }
163
164        // Build set of edge pairs for fast lookup
165        let edge_pairs: std::collections::HashSet<(&str, &str)> = graph
166            .edges
167            .iter()
168            .map(|e| (e.from.as_str(), e.to.as_str()))
169            .collect();
170
171        // Compute average absolute correlation for non-edges (direct only)
172        let mut non_edge_corrs = Vec::new();
173        for (i, &vi) in var_names.iter().enumerate() {
174            for &vj in var_names.iter().skip(i + 1) {
175                if edge_pairs.contains(&(vi, vj)) || edge_pairs.contains(&(vj, vi)) {
176                    continue;
177                }
178                let vals_i: Vec<f64> = samples.iter().filter_map(|s| s.get(vi).copied()).collect();
179                let vals_j: Vec<f64> = samples.iter().filter_map(|s| s.get(vj).copied()).collect();
180                let corr = pearson_correlation(&vals_i, &vals_j).abs();
181                if corr.is_finite() {
182                    non_edge_corrs.push(corr);
183                }
184            }
185        }
186
187        let avg_edge = if edge_corrs.is_empty() {
188            0.0
189        } else {
190            edge_corrs.iter().sum::<f64>() / edge_corrs.len() as f64
191        };
192
193        let avg_non_edge = if non_edge_corrs.is_empty() {
194            0.0
195        } else {
196            non_edge_corrs.iter().sum::<f64>() / non_edge_corrs.len() as f64
197        };
198
199        // Non-edges should have weaker average correlation than edges
200        let passed = non_edge_corrs.is_empty() || avg_non_edge <= avg_edge + 0.1;
201
202        let details = format!(
203            "Avg edge correlation: {avg_edge:.4}, avg non-edge correlation: {avg_non_edge:.4}"
204        );
205
206        CausalCheck {
207            name: "non_edge_weakness".to_string(),
208            passed,
209            details,
210        }
211    }
212
213    /// Check 3: Verify topological ordering holds in conditional means.
214    ///
215    /// For parent -> child edges, the mean of child should shift when we split
216    /// samples by parent median.
217    fn check_topological_consistency(
218        samples: &[HashMap<String, f64>],
219        graph: &CausalGraph,
220    ) -> CausalCheck {
221        let mut total_checked = 0;
222        let mut consistent = 0;
223
224        for edge in &graph.edges {
225            let expected_sign = Self::mechanism_sign(&edge.mechanism);
226            if expected_sign == 0 {
227                continue;
228            }
229
230            let mut parent_vals: Vec<f64> = samples
231                .iter()
232                .filter_map(|s| s.get(&edge.from).copied())
233                .collect();
234            parent_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
235
236            if parent_vals.is_empty() {
237                continue;
238            }
239
240            let median_idx = parent_vals.len() / 2;
241            let median = parent_vals[median_idx];
242
243            // Split child values by parent median
244            let child_low: Vec<f64> = samples
245                .iter()
246                .filter(|s| s.get(&edge.from).copied().unwrap_or(0.0) <= median)
247                .filter_map(|s| s.get(&edge.to).copied())
248                .collect();
249
250            let child_high: Vec<f64> = samples
251                .iter()
252                .filter(|s| s.get(&edge.from).copied().unwrap_or(0.0) > median)
253                .filter_map(|s| s.get(&edge.to).copied())
254                .collect();
255
256            if child_low.is_empty() || child_high.is_empty() {
257                continue;
258            }
259
260            let mean_low = child_low.iter().sum::<f64>() / child_low.len() as f64;
261            let mean_high = child_high.iter().sum::<f64>() / child_high.len() as f64;
262
263            total_checked += 1;
264
265            // Check that the direction of mean shift matches expected sign
266            let actual_sign = if mean_high > mean_low + 1e-10 {
267                1
268            } else if mean_high < mean_low - 1e-10 {
269                -1
270            } else {
271                0
272            };
273
274            if actual_sign == expected_sign || actual_sign == 0 {
275                consistent += 1;
276            }
277        }
278
279        let passed = total_checked == 0 || consistent >= total_checked / 2;
280        let details =
281            format!("{consistent}/{total_checked} edges show consistent conditional mean ordering");
282
283        CausalCheck {
284            name: "topological_consistency".to_string(),
285            passed,
286            details,
287        }
288    }
289
290    /// Determine the expected sign of a mechanism's effect.
291    /// Returns 1 for positive, -1 for negative, 0 for indeterminate.
292    fn mechanism_sign(mechanism: &CausalMechanism) -> i32 {
293        match mechanism {
294            CausalMechanism::Linear { coefficient } => {
295                if *coefficient > 0.0 {
296                    1
297                } else if *coefficient < 0.0 {
298                    -1
299                } else {
300                    0
301                }
302            }
303            CausalMechanism::Threshold { .. } => {
304                // Threshold is monotonically non-decreasing (0 or 1)
305                1
306            }
307            CausalMechanism::Logistic { scale, .. } => {
308                if *scale > 0.0 {
309                    1
310                } else if *scale < 0.0 {
311                    -1
312                } else {
313                    0
314                }
315            }
316            CausalMechanism::Polynomial { coefficients } => {
317                // Use sign of highest non-zero coefficient as a heuristic
318                for coeff in coefficients.iter().rev() {
319                    if *coeff > 0.0 {
320                        return 1;
321                    } else if *coeff < 0.0 {
322                        return -1;
323                    }
324                }
325                0
326            }
327        }
328    }
329}
330
/// Pearson correlation coefficient of two series.
///
/// Values are paired by index and only the first `min(x.len(), y.len())` pairs are used.
/// Returns `0.0` when fewer than two pairs exist, or when the square root of the product
/// of the centred sums of squares is below `1e-15` (a constant series). A NaN in the
/// inputs propagates to a NaN result.
fn pearson_correlation(x: &[f64], y: &[f64]) -> f64 {
    let len = x.len().min(y.len());
    if len < 2 {
        return 0.0;
    }

    let (xs, ys) = (&x[..len], &y[..len]);
    let count = len as f64;
    let x_bar = xs.iter().sum::<f64>() / count;
    let y_bar = ys.iter().sum::<f64>() / count;

    // Accumulate co-moment and both second moments in a single pass.
    let (co_moment, moment_x, moment_y) =
        xs.iter()
            .zip(ys)
            .fold((0.0, 0.0, 0.0), |(cxy, cxx, cyy), (&a, &b)| {
                let da = a - x_bar;
                let db = b - y_bar;
                (cxy + da * db, cxx + da * da, cyy + db * db)
            });

    let scale = (moment_x * moment_y).sqrt();
    // Written as `< threshold` so a NaN scale falls through to the division (NaN result).
    if scale < 1e-15 {
        0.0
    } else {
        co_moment / scale
    }
}
360
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::causal::graph::CausalGraph;
    use crate::causal::scm::StructuralCausalModel;

    /// Data sampled from the template's own SCM should satisfy every check.
    #[test]
    fn test_causal_validation_passes_on_correct_data() {
        let graph = CausalGraph::fraud_detection_template();
        let model = StructuralCausalModel::new(graph.clone()).unwrap();
        let data = model.generate(1000, 42).unwrap();

        let outcome = CausalValidator::validate_causal_structure(&data, &graph);

        assert!(
            outcome.valid,
            "Validation should pass on correctly generated data. Violations: {:?}",
            outcome.violations
        );
        assert_eq!(outcome.checks.len(), 3);
        assert!(outcome.violations.is_empty());
    }

    /// Rotating one column decouples it from its parents; some check must notice.
    #[test]
    fn test_causal_validation_detects_shuffled_columns() {
        let graph = CausalGraph::fraud_detection_template();
        let model = StructuralCausalModel::new(graph.clone()).unwrap();
        let mut data = model.generate(2000, 42).unwrap();

        // Rotate the fraud_probability column by half the sample count.
        // This breaks the causal relationship between parents and fraud_probability.
        let total = data.len();
        let offset = total / 2;
        let column: Vec<f64> = data
            .iter()
            .filter_map(|row| row.get("fraud_probability").copied())
            .collect();

        for (idx, row) in data.iter_mut().enumerate() {
            let rotated = column[(idx + offset) % total];
            row.insert("fraud_probability".to_string(), rotated);
        }

        let outcome = CausalValidator::validate_causal_structure(&data, &graph);

        // At least one check should fail when causal structure is broken.
        let any_failed = !outcome.checks.iter().all(|check| check.passed);
        assert!(
            any_failed,
            "Validation should detect broken causal structure. Checks: {:?}",
            outcome.checks
        );
    }

    #[test]
    fn test_causal_pearson_correlation_perfect_positive() {
        let xs: [f64; 5] = [1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = xs.map(|v| v * 2.0);
        let r = pearson_correlation(&xs, &ys);
        assert!(
            (r - 1.0).abs() < 1e-10,
            "Perfect positive correlation expected, got {}",
            r
        );
    }

    #[test]
    fn test_causal_pearson_correlation_perfect_negative() {
        let xs: Vec<f64> = (1..=5).map(f64::from).collect();
        let ys: Vec<f64> = (1..=5).rev().map(|k| 2.0 * f64::from(k)).collect();
        let r = pearson_correlation(&xs, &ys);
        assert!(
            (r - (-1.0)).abs() < 1e-10,
            "Perfect negative correlation expected, got {}",
            r
        );
    }

    #[test]
    fn test_causal_pearson_correlation_constant() {
        let xs = [1.0_f64; 4];
        let ys = [2.0, 4.0, 6.0, 8.0];
        let r = pearson_correlation(&xs, &ys);
        assert!(
            r.abs() < 1e-10,
            "Correlation with constant should be 0, got {}",
            r
        );
    }

    /// The report always lists the three checks in a fixed order with non-empty details.
    #[test]
    fn test_causal_validation_report_structure() {
        let graph = CausalGraph::fraud_detection_template();
        let model = StructuralCausalModel::new(graph.clone()).unwrap();
        let data = model.generate(200, 42).unwrap();

        let outcome = CausalValidator::validate_causal_structure(&data, &graph);

        assert_eq!(outcome.checks.len(), 3);
        let names: Vec<&str> = outcome.checks.iter().map(|c| c.name.as_str()).collect();
        assert_eq!(
            names,
            [
                "edge_correlation_signs",
                "non_edge_weakness",
                "topological_consistency"
            ]
        );

        assert!(outcome.checks.iter().all(|c| !c.details.is_empty()));
    }

    /// On the revenue-cycle template, at least two of the three checks should pass.
    #[test]
    fn test_causal_validation_revenue_cycle() {
        let graph = CausalGraph::revenue_cycle_template();
        let model = StructuralCausalModel::new(graph.clone()).unwrap();
        let data = model.generate(1000, 99).unwrap();

        let outcome = CausalValidator::validate_causal_structure(&data, &graph);

        let passing = outcome.checks.iter().filter(|c| c.passed).count();
        assert!(
            passing >= 2,
            "At least 2 of 3 checks should pass. Checks: {:?}",
            outcome.checks
        );
    }
}