Skip to main content

datasynth_core/causal/
validation.rs

1//! Causal structure validation.
2//!
3//! Validates that generated samples respect the causal structure defined by the graph,
4//! checking correlation signs, edge strength, and topological consistency.
5
6use std::collections::HashMap;
7
8use super::graph::{CausalGraph, CausalMechanism};
9
10/// Report from causal structure validation.
11#[derive(Debug, Clone)]
12pub struct CausalValidationReport {
13    /// Whether all checks passed.
14    pub valid: bool,
15    /// Individual check results.
16    pub checks: Vec<CausalCheck>,
17    /// Human-readable violation descriptions.
18    pub violations: Vec<String>,
19}
20
/// Outcome of one named check run by [`CausalValidator`].
#[derive(Clone, Debug)]
pub struct CausalCheck {
    /// Identifier of the check, e.g. `"edge_correlation_signs"`.
    pub name: String,
    /// `true` if the check succeeded.
    pub passed: bool,
    /// Human-readable summary of what the check measured.
    pub details: String,
}
31
32/// Validator for causal structure consistency.
33pub struct CausalValidator;
34
35impl CausalValidator {
36    /// Validate that samples respect the causal structure of the graph.
37    ///
38    /// Performs three checks:
39    /// 1. Edge correlation signs match mechanism coefficient signs
40    /// 2. Non-edges have weaker average correlation than edges
41    /// 3. Topological ordering holds in conditional means
42    pub fn validate_causal_structure(
43        samples: &[HashMap<String, f64>],
44        graph: &CausalGraph,
45    ) -> CausalValidationReport {
46        let mut checks = Vec::new();
47        let mut violations = Vec::new();
48
49        // Check 1: Edge correlation signs
50        let sign_check = Self::check_edge_correlation_signs(samples, graph);
51        if !sign_check.passed {
52            violations.push(sign_check.details.clone());
53        }
54        checks.push(sign_check);
55
56        // Check 2: Non-edges have weaker correlation than edges
57        let strength_check = Self::check_non_edge_weakness(samples, graph);
58        if !strength_check.passed {
59            violations.push(strength_check.details.clone());
60        }
61        checks.push(strength_check);
62
63        // Check 3: Topological ordering in conditional means
64        let topo_check = Self::check_topological_consistency(samples, graph);
65        if !topo_check.passed {
66            violations.push(topo_check.details.clone());
67        }
68        checks.push(topo_check);
69
70        let valid = checks.iter().all(|c| c.passed);
71
72        CausalValidationReport {
73            valid,
74            checks,
75            violations,
76        }
77    }
78
79    /// Check 1: For each edge, verify correlation between parent and child
80    /// has the expected sign (based on mechanism coefficient sign).
81    fn check_edge_correlation_signs(
82        samples: &[HashMap<String, f64>],
83        graph: &CausalGraph,
84    ) -> CausalCheck {
85        let mut total_edges = 0;
86        let mut correct_signs = 0u32;
87        let mut mismatches = Vec::new();
88
89        for edge in &graph.edges {
90            let expected_sign = Self::mechanism_sign(&edge.mechanism);
91            // Skip edges where we can't reliably determine expected sign.
92            // Threshold mechanisms produce binary outputs where correlation
93            // with the continuous parent is often very weak or indeterminate.
94            if expected_sign == 0 || matches!(edge.mechanism, CausalMechanism::Threshold { .. }) {
95                continue;
96            }
97
98            total_edges += 1;
99
100            let parent_vals: Vec<f64> = samples
101                .iter()
102                .filter_map(|s| s.get(&edge.from).copied())
103                .collect();
104            let child_vals: Vec<f64> = samples
105                .iter()
106                .filter_map(|s| s.get(&edge.to).copied())
107                .collect();
108
109            let corr = pearson_correlation(&parent_vals, &child_vals);
110
111            if (expected_sign > 0 && corr > -0.05) || (expected_sign < 0 && corr < 0.05) {
112                correct_signs += 1;
113            } else {
114                mismatches.push(format!(
115                    "{} -> {}: expected sign {}, got correlation {:.4}",
116                    edge.from, edge.to, expected_sign, corr
117                ));
118            }
119        }
120
121        let passed = mismatches.is_empty();
122        let details = if passed {
123            format!(
124                "All {}/{} edges have correct correlation signs",
125                correct_signs, total_edges
126            )
127        } else {
128            format!(
129                "{}/{} edges have incorrect signs: {}",
130                mismatches.len(),
131                total_edges,
132                mismatches.join("; ")
133            )
134        };
135
136        CausalCheck {
137            name: "edge_correlation_signs".to_string(),
138            passed,
139            details,
140        }
141    }
142
143    /// Check 2: Verify non-edges have weaker correlation than edges (on average).
144    fn check_non_edge_weakness(
145        samples: &[HashMap<String, f64>],
146        graph: &CausalGraph,
147    ) -> CausalCheck {
148        let var_names = graph.variable_names();
149
150        // Compute average absolute correlation for edges
151        let mut edge_corrs = Vec::new();
152        for edge in &graph.edges {
153            let parent_vals: Vec<f64> = samples
154                .iter()
155                .filter_map(|s| s.get(&edge.from).copied())
156                .collect();
157            let child_vals: Vec<f64> = samples
158                .iter()
159                .filter_map(|s| s.get(&edge.to).copied())
160                .collect();
161            let corr = pearson_correlation(&parent_vals, &child_vals).abs();
162            if corr.is_finite() {
163                edge_corrs.push(corr);
164            }
165        }
166
167        // Build set of edge pairs for fast lookup
168        let edge_pairs: std::collections::HashSet<(&str, &str)> = graph
169            .edges
170            .iter()
171            .map(|e| (e.from.as_str(), e.to.as_str()))
172            .collect();
173
174        // Compute average absolute correlation for non-edges (direct only)
175        let mut non_edge_corrs = Vec::new();
176        for (i, &vi) in var_names.iter().enumerate() {
177            for &vj in var_names.iter().skip(i + 1) {
178                if edge_pairs.contains(&(vi, vj)) || edge_pairs.contains(&(vj, vi)) {
179                    continue;
180                }
181                let vals_i: Vec<f64> = samples.iter().filter_map(|s| s.get(vi).copied()).collect();
182                let vals_j: Vec<f64> = samples.iter().filter_map(|s| s.get(vj).copied()).collect();
183                let corr = pearson_correlation(&vals_i, &vals_j).abs();
184                if corr.is_finite() {
185                    non_edge_corrs.push(corr);
186                }
187            }
188        }
189
190        let avg_edge = if edge_corrs.is_empty() {
191            0.0
192        } else {
193            edge_corrs.iter().sum::<f64>() / edge_corrs.len() as f64
194        };
195
196        let avg_non_edge = if non_edge_corrs.is_empty() {
197            0.0
198        } else {
199            non_edge_corrs.iter().sum::<f64>() / non_edge_corrs.len() as f64
200        };
201
202        // Non-edges should have weaker average correlation than edges
203        let passed = non_edge_corrs.is_empty() || avg_non_edge <= avg_edge + 0.1;
204
205        let details = format!(
206            "Avg edge correlation: {:.4}, avg non-edge correlation: {:.4}",
207            avg_edge, avg_non_edge
208        );
209
210        CausalCheck {
211            name: "non_edge_weakness".to_string(),
212            passed,
213            details,
214        }
215    }
216
217    /// Check 3: Verify topological ordering holds in conditional means.
218    ///
219    /// For parent -> child edges, the mean of child should shift when we split
220    /// samples by parent median.
221    fn check_topological_consistency(
222        samples: &[HashMap<String, f64>],
223        graph: &CausalGraph,
224    ) -> CausalCheck {
225        let mut total_checked = 0;
226        let mut consistent = 0;
227
228        for edge in &graph.edges {
229            let expected_sign = Self::mechanism_sign(&edge.mechanism);
230            if expected_sign == 0 {
231                continue;
232            }
233
234            let mut parent_vals: Vec<f64> = samples
235                .iter()
236                .filter_map(|s| s.get(&edge.from).copied())
237                .collect();
238            parent_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
239
240            if parent_vals.is_empty() {
241                continue;
242            }
243
244            let median_idx = parent_vals.len() / 2;
245            let median = parent_vals[median_idx];
246
247            // Split child values by parent median
248            let child_low: Vec<f64> = samples
249                .iter()
250                .filter(|s| s.get(&edge.from).copied().unwrap_or(0.0) <= median)
251                .filter_map(|s| s.get(&edge.to).copied())
252                .collect();
253
254            let child_high: Vec<f64> = samples
255                .iter()
256                .filter(|s| s.get(&edge.from).copied().unwrap_or(0.0) > median)
257                .filter_map(|s| s.get(&edge.to).copied())
258                .collect();
259
260            if child_low.is_empty() || child_high.is_empty() {
261                continue;
262            }
263
264            let mean_low = child_low.iter().sum::<f64>() / child_low.len() as f64;
265            let mean_high = child_high.iter().sum::<f64>() / child_high.len() as f64;
266
267            total_checked += 1;
268
269            // Check that the direction of mean shift matches expected sign
270            let actual_sign = if mean_high > mean_low + 1e-10 {
271                1
272            } else if mean_high < mean_low - 1e-10 {
273                -1
274            } else {
275                0
276            };
277
278            if actual_sign == expected_sign || actual_sign == 0 {
279                consistent += 1;
280            }
281        }
282
283        let passed = total_checked == 0 || consistent >= total_checked / 2;
284        let details = format!(
285            "{}/{} edges show consistent conditional mean ordering",
286            consistent, total_checked
287        );
288
289        CausalCheck {
290            name: "topological_consistency".to_string(),
291            passed,
292            details,
293        }
294    }
295
296    /// Determine the expected sign of a mechanism's effect.
297    /// Returns 1 for positive, -1 for negative, 0 for indeterminate.
298    fn mechanism_sign(mechanism: &CausalMechanism) -> i32 {
299        match mechanism {
300            CausalMechanism::Linear { coefficient } => {
301                if *coefficient > 0.0 {
302                    1
303                } else if *coefficient < 0.0 {
304                    -1
305                } else {
306                    0
307                }
308            }
309            CausalMechanism::Threshold { .. } => {
310                // Threshold is monotonically non-decreasing (0 or 1)
311                1
312            }
313            CausalMechanism::Logistic { scale, .. } => {
314                if *scale > 0.0 {
315                    1
316                } else if *scale < 0.0 {
317                    -1
318                } else {
319                    0
320                }
321            }
322            CausalMechanism::Polynomial { coefficients } => {
323                // Use sign of highest non-zero coefficient as a heuristic
324                for coeff in coefficients.iter().rev() {
325                    if *coeff > 0.0 {
326                        return 1;
327                    } else if *coeff < 0.0 {
328                        return -1;
329                    }
330                }
331                0
332            }
333        }
334    }
335}
336
/// Compute the Pearson correlation coefficient between two series.
///
/// Only the first `min(x.len(), y.len())` elements are used, and elements at the same
/// index form one observation.
///
/// Returns `0.0` when fewer than two observations are available, when either series
/// is exactly constant (correlation is undefined), or when the deviations underflow
/// to zero. Otherwise a `NaN` in either input yields `NaN`, and any finite result is
/// clamped to `[-1.0, 1.0]`.
fn pearson_correlation(x: &[f64], y: &[f64]) -> f64 {
    let n = x.len().min(y.len());
    if n < 2 {
        return 0.0;
    }
    let (x, y) = (&x[..n], &y[..n]);

    // Detect constant series directly. Rounding in the mean can leave a tiny non-zero
    // sum of squares for them, so an absolute threshold is scale-dependent.
    let is_constant = |v: &[f64]| v.iter().all(|&a| a == v[0]);
    if is_constant(x) || is_constant(y) {
        return 0.0;
    }

    let mean_x = x.iter().sum::<f64>() / n as f64;
    let mean_y = y.iter().sum::<f64>() / n as f64;

    let (sum_xy, sum_x2, sum_y2) = x.iter().zip(y).fold(
        (0.0, 0.0, 0.0),
        |(sxy, sx2, sy2), (&xi, &yi)| {
            let dx = xi - mean_x;
            let dy = yi - mean_y;
            (sxy + dx * dy, sx2 + dx * dx, sy2 + dy * dy)
        },
    );

    // Take square roots separately. The product of the two sums can overflow to
    // infinity (or underflow to zero) even when each sum is representable.
    let denom = sum_x2.sqrt() * sum_y2.sqrt();
    if denom == 0.0 {
        return 0.0;
    }

    // Rounding can push the ratio marginally outside [-1, 1]. NaN passes through clamp.
    (sum_xy / denom).clamp(-1.0, 1.0)
}
366
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::causal::graph::CausalGraph;
    use crate::causal::scm::StructuralCausalModel;

    #[test]
    fn test_causal_validation_passes_on_correct_data() {
        // Data drawn straight from the SCM should satisfy every structural check.
        let graph = CausalGraph::fraud_detection_template();
        let model = StructuralCausalModel::new(graph.clone()).unwrap();
        let data = model.generate(1000, 42).unwrap();

        let report = CausalValidator::validate_causal_structure(&data, &graph);

        assert!(
            report.valid,
            "Validation should pass on correctly generated data. Violations: {:?}",
            report.violations
        );
        assert_eq!(report.checks.len(), 3);
        assert!(report.violations.is_empty());
    }

    #[test]
    fn test_causal_validation_detects_shuffled_columns() {
        let graph = CausalGraph::fraud_detection_template();
        let model = StructuralCausalModel::new(graph.clone()).unwrap();
        let mut data = model.generate(2000, 42).unwrap();

        // Rotate the fraud_probability column by half the sample count. Each row then
        // carries a value generated for a different row, severing the causal link to
        // its parents.
        let column: Vec<f64> = data
            .iter()
            .filter_map(|row| row.get("fraud_probability").copied())
            .collect();
        let len = data.len();
        let offset = len / 2;
        for (row_idx, row) in data.iter_mut().enumerate() {
            let donor = (row_idx + offset) % len;
            row.insert("fraud_probability".to_string(), column[donor]);
        }

        let report = CausalValidator::validate_causal_structure(&data, &graph);

        // Breaking the structure must trip at least one check.
        let failing = report.checks.iter().filter(|check| !check.passed).count();
        assert!(
            failing > 0,
            "Validation should detect broken causal structure. Checks: {:?}",
            report.checks
        );
    }

    #[test]
    fn test_causal_pearson_correlation_perfect_positive() {
        let xs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = [2.0, 4.0, 6.0, 8.0, 10.0];
        let r = pearson_correlation(&xs, &ys);
        assert!(
            (r - 1.0).abs() < 1e-10,
            "Perfect positive correlation expected, got {}",
            r
        );
    }

    #[test]
    fn test_causal_pearson_correlation_perfect_negative() {
        let xs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = [10.0, 8.0, 6.0, 4.0, 2.0];
        let r = pearson_correlation(&xs, &ys);
        assert!(
            (r + 1.0).abs() < 1e-10,
            "Perfect negative correlation expected, got {}",
            r
        );
    }

    #[test]
    fn test_causal_pearson_correlation_constant() {
        let xs = [1.0, 1.0, 1.0, 1.0];
        let ys = [2.0, 4.0, 6.0, 8.0];
        let r = pearson_correlation(&xs, &ys);
        assert!(
            r.abs() < 1e-10,
            "Correlation with constant should be 0, got {}",
            r
        );
    }

    #[test]
    fn test_causal_validation_report_structure() {
        let graph = CausalGraph::fraud_detection_template();
        let model = StructuralCausalModel::new(graph.clone()).unwrap();
        let data = model.generate(200, 42).unwrap();

        let report = CausalValidator::validate_causal_structure(&data, &graph);

        // Exactly three checks, always reported in a fixed order.
        assert_eq!(report.checks.len(), 3);
        let expected_names = [
            "edge_correlation_signs",
            "non_edge_weakness",
            "topological_consistency",
        ];
        for (check, expected) in report.checks.iter().zip(expected_names) {
            assert_eq!(check.name, expected);
        }

        // Every check explains its outcome.
        for check in &report.checks {
            assert!(!check.details.is_empty());
        }
    }

    #[test]
    fn test_causal_validation_revenue_cycle() {
        let graph = CausalGraph::revenue_cycle_template();
        let model = StructuralCausalModel::new(graph.clone()).unwrap();
        let data = model.generate(1000, 99).unwrap();

        let report = CausalValidator::validate_causal_structure(&data, &graph);

        // Correctly generated data should pass the large majority of checks.
        let passed_count = report.checks.iter().filter(|check| check.passed).count();
        assert!(
            passed_count >= 2,
            "At least 2 of 3 checks should pass. Checks: {:?}",
            report.checks
        );
    }
}