Skip to main content

datasynth_core/causal/
validation.rs

//! Causal structure validation.
//!
//! Validates that generated samples respect the causal structure defined by the graph
//! by checking edge correlation signs, that non-edges correlate more weakly than edges,
//! and that conditional means shift in the direction each edge predicts.
5
6use std::collections::HashMap;
7
8use super::graph::{CausalGraph, CausalMechanism};
9
/// Report from causal structure validation.
///
/// Produced by `CausalValidator::validate_causal_structure`. Derives `PartialEq`/`Eq`
/// so reports can be compared directly (e.g. in tests or regression snapshots).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CausalValidationReport {
    /// `true` only when every entry in `checks` passed.
    pub valid: bool,
    /// Individual check results, in the order they were run.
    pub checks: Vec<CausalCheck>,
    /// Human-readable descriptions of the failed checks (their `details`).
    pub violations: Vec<String>,
}

/// Result of a single validation check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CausalCheck {
    /// Stable, machine-friendly name of the check (e.g. `"edge_correlation_signs"`).
    pub name: String,
    /// Whether the check passed.
    pub passed: bool,
    /// Human-readable summary of what the check measured.
    pub details: String,
}
31
/// Validator for causal structure consistency.
///
/// A stateless unit struct: all functionality lives in associated functions such as
/// `CausalValidator::validate_causal_structure`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct CausalValidator;
34
35impl CausalValidator {
36    /// Validate that samples respect the causal structure of the graph.
37    ///
38    /// Performs three checks:
39    /// 1. Edge correlation signs match mechanism coefficient signs
40    /// 2. Non-edges have weaker average correlation than edges
41    /// 3. Topological ordering holds in conditional means
42    pub fn validate_causal_structure(
43        samples: &[HashMap<String, f64>],
44        graph: &CausalGraph,
45    ) -> CausalValidationReport {
46        let mut checks = Vec::new();
47        let mut violations = Vec::new();
48
49        // Check 1: Edge correlation signs
50        let sign_check = Self::check_edge_correlation_signs(samples, graph);
51        if !sign_check.passed {
52            violations.push(sign_check.details.clone());
53        }
54        checks.push(sign_check);
55
56        // Check 2: Non-edges have weaker correlation than edges
57        let strength_check = Self::check_non_edge_weakness(samples, graph);
58        if !strength_check.passed {
59            violations.push(strength_check.details.clone());
60        }
61        checks.push(strength_check);
62
63        // Check 3: Topological ordering in conditional means
64        let topo_check = Self::check_topological_consistency(samples, graph);
65        if !topo_check.passed {
66            violations.push(topo_check.details.clone());
67        }
68        checks.push(topo_check);
69
70        let valid = checks.iter().all(|c| c.passed);
71
72        CausalValidationReport {
73            valid,
74            checks,
75            violations,
76        }
77    }
78
79    /// Check 1: For each edge, verify correlation between parent and child
80    /// has the expected sign (based on mechanism coefficient sign).
81    fn check_edge_correlation_signs(
82        samples: &[HashMap<String, f64>],
83        graph: &CausalGraph,
84    ) -> CausalCheck {
85        let mut total_edges = 0;
86        let mut correct_signs = 0u32;
87        let mut mismatches = Vec::new();
88
89        for edge in &graph.edges {
90            let expected_sign = Self::mechanism_sign(&edge.mechanism);
91            // Skip edges where we can't reliably determine expected sign.
92            // Threshold mechanisms produce binary outputs where correlation
93            // with the continuous parent is often very weak or indeterminate.
94            if expected_sign == 0 || matches!(edge.mechanism, CausalMechanism::Threshold { .. }) {
95                continue;
96            }
97
98            total_edges += 1;
99
100            let parent_vals: Vec<f64> = samples
101                .iter()
102                .filter_map(|s| s.get(&edge.from).copied())
103                .collect();
104            let child_vals: Vec<f64> = samples
105                .iter()
106                .filter_map(|s| s.get(&edge.to).copied())
107                .collect();
108
109            let corr = pearson_correlation(&parent_vals, &child_vals);
110
111            if (expected_sign > 0 && corr > -0.05) || (expected_sign < 0 && corr < 0.05) {
112                correct_signs += 1;
113            } else {
114                mismatches.push(format!(
115                    "{} -> {}: expected sign {}, got correlation {:.4}",
116                    edge.from, edge.to, expected_sign, corr
117                ));
118            }
119        }
120
121        let passed = mismatches.is_empty();
122        let details = if passed {
123            format!("All {correct_signs}/{total_edges} edges have correct correlation signs")
124        } else {
125            format!(
126                "{}/{} edges have incorrect signs: {}",
127                mismatches.len(),
128                total_edges,
129                mismatches.join("; ")
130            )
131        };
132
133        CausalCheck {
134            name: "edge_correlation_signs".to_string(),
135            passed,
136            details,
137        }
138    }
139
140    /// Check 2: Verify non-edges have weaker correlation than edges (on average).
141    fn check_non_edge_weakness(
142        samples: &[HashMap<String, f64>],
143        graph: &CausalGraph,
144    ) -> CausalCheck {
145        let var_names = graph.variable_names();
146
147        // Compute average absolute correlation for edges
148        let mut edge_corrs = Vec::new();
149        for edge in &graph.edges {
150            let parent_vals: Vec<f64> = samples
151                .iter()
152                .filter_map(|s| s.get(&edge.from).copied())
153                .collect();
154            let child_vals: Vec<f64> = samples
155                .iter()
156                .filter_map(|s| s.get(&edge.to).copied())
157                .collect();
158            let corr = pearson_correlation(&parent_vals, &child_vals).abs();
159            if corr.is_finite() {
160                edge_corrs.push(corr);
161            }
162        }
163
164        // Build set of edge pairs for fast lookup
165        let edge_pairs: std::collections::HashSet<(&str, &str)> = graph
166            .edges
167            .iter()
168            .map(|e| (e.from.as_str(), e.to.as_str()))
169            .collect();
170
171        // Compute average absolute correlation for non-edges (direct only)
172        let mut non_edge_corrs = Vec::new();
173        for (i, &vi) in var_names.iter().enumerate() {
174            for &vj in var_names.iter().skip(i + 1) {
175                if edge_pairs.contains(&(vi, vj)) || edge_pairs.contains(&(vj, vi)) {
176                    continue;
177                }
178                let vals_i: Vec<f64> = samples.iter().filter_map(|s| s.get(vi).copied()).collect();
179                let vals_j: Vec<f64> = samples.iter().filter_map(|s| s.get(vj).copied()).collect();
180                let corr = pearson_correlation(&vals_i, &vals_j).abs();
181                if corr.is_finite() {
182                    non_edge_corrs.push(corr);
183                }
184            }
185        }
186
187        let avg_edge = if edge_corrs.is_empty() {
188            0.0
189        } else {
190            edge_corrs.iter().sum::<f64>() / edge_corrs.len() as f64
191        };
192
193        let avg_non_edge = if non_edge_corrs.is_empty() {
194            0.0
195        } else {
196            non_edge_corrs.iter().sum::<f64>() / non_edge_corrs.len() as f64
197        };
198
199        // Non-edges should have weaker average correlation than edges
200        let passed = non_edge_corrs.is_empty() || avg_non_edge <= avg_edge + 0.1;
201
202        let details = format!(
203            "Avg edge correlation: {avg_edge:.4}, avg non-edge correlation: {avg_non_edge:.4}"
204        );
205
206        CausalCheck {
207            name: "non_edge_weakness".to_string(),
208            passed,
209            details,
210        }
211    }
212
213    /// Check 3: Verify topological ordering holds in conditional means.
214    ///
215    /// For parent -> child edges, the mean of child should shift when we split
216    /// samples by parent median.
217    fn check_topological_consistency(
218        samples: &[HashMap<String, f64>],
219        graph: &CausalGraph,
220    ) -> CausalCheck {
221        let mut total_checked = 0;
222        let mut consistent = 0;
223
224        for edge in &graph.edges {
225            let expected_sign = Self::mechanism_sign(&edge.mechanism);
226            if expected_sign == 0 {
227                continue;
228            }
229
230            let mut parent_vals: Vec<f64> = samples
231                .iter()
232                .filter_map(|s| s.get(&edge.from).copied())
233                .collect();
234            parent_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
235
236            if parent_vals.is_empty() {
237                continue;
238            }
239
240            let median_idx = parent_vals.len() / 2;
241            let median = parent_vals[median_idx];
242
243            // Split child values by parent median
244            let child_low: Vec<f64> = samples
245                .iter()
246                .filter(|s| s.get(&edge.from).copied().unwrap_or(0.0) <= median)
247                .filter_map(|s| s.get(&edge.to).copied())
248                .collect();
249
250            let child_high: Vec<f64> = samples
251                .iter()
252                .filter(|s| s.get(&edge.from).copied().unwrap_or(0.0) > median)
253                .filter_map(|s| s.get(&edge.to).copied())
254                .collect();
255
256            if child_low.is_empty() || child_high.is_empty() {
257                continue;
258            }
259
260            let mean_low = child_low.iter().sum::<f64>() / child_low.len() as f64;
261            let mean_high = child_high.iter().sum::<f64>() / child_high.len() as f64;
262
263            total_checked += 1;
264
265            // Check that the direction of mean shift matches expected sign
266            let actual_sign = if mean_high > mean_low + 1e-10 {
267                1
268            } else if mean_high < mean_low - 1e-10 {
269                -1
270            } else {
271                0
272            };
273
274            if actual_sign == expected_sign || actual_sign == 0 {
275                consistent += 1;
276            }
277        }
278
279        let passed = total_checked == 0 || consistent >= total_checked / 2;
280        let details =
281            format!("{consistent}/{total_checked} edges show consistent conditional mean ordering");
282
283        CausalCheck {
284            name: "topological_consistency".to_string(),
285            passed,
286            details,
287        }
288    }
289
290    /// Determine the expected sign of a mechanism's effect.
291    /// Returns 1 for positive, -1 for negative, 0 for indeterminate.
292    fn mechanism_sign(mechanism: &CausalMechanism) -> i32 {
293        match mechanism {
294            CausalMechanism::Linear { coefficient } => {
295                if *coefficient > 0.0 {
296                    1
297                } else if *coefficient < 0.0 {
298                    -1
299                } else {
300                    0
301                }
302            }
303            CausalMechanism::Threshold { .. } => {
304                // Threshold is monotonically non-decreasing (0 or 1)
305                1
306            }
307            CausalMechanism::Logistic { scale, .. } => {
308                if *scale > 0.0 {
309                    1
310                } else if *scale < 0.0 {
311                    -1
312                } else {
313                    0
314                }
315            }
316            CausalMechanism::Polynomial { coefficients } => {
317                // Use sign of highest non-zero coefficient as a heuristic
318                for coeff in coefficients.iter().rev() {
319                    if *coeff > 0.0 {
320                        return 1;
321                    } else if *coeff < 0.0 {
322                        return -1;
323                    }
324                }
325                0
326            }
327        }
328    }
329}
330
/// Pearson correlation coefficient of `x` and `y`.
///
/// Only the first `min(x.len(), y.len())` elements of each slice are used, so the
/// caller must keep the two series aligned. Returns 0.0 when fewer than two pairs
/// are available, or when the product of the two variances is (near) zero. NaN
/// inputs propagate to a NaN result.
fn pearson_correlation(x: &[f64], y: &[f64]) -> f64 {
    let len = x.len().min(y.len());
    if len < 2 {
        return 0.0;
    }

    let (xs, ys) = (&x[..len], &y[..len]);
    let count = len as f64;
    let x_bar = xs.iter().sum::<f64>() / count;
    let y_bar = ys.iter().sum::<f64>() / count;

    // Accumulate the co-moment and both second moments in a single pass.
    let (cov, var_x, var_y) = xs.iter().zip(ys).fold(
        (0.0, 0.0, 0.0),
        |(cov, var_x, var_y), (&a, &b)| {
            let da = a - x_bar;
            let db = b - y_bar;
            (cov + da * db, var_x + da * da, var_y + db * db)
        },
    );

    let scale = (var_x * var_y).sqrt();
    if scale < 1e-15 {
        0.0
    } else {
        cov / scale
    }
}
360
#[cfg(test)]
mod tests {
    use super::*;
    use crate::causal::graph::CausalGraph;
    use crate::causal::scm::StructuralCausalModel;

    #[test]
    fn test_causal_validation_passes_on_correct_data() {
        let graph = CausalGraph::fraud_detection_template();
        let samples = StructuralCausalModel::new(graph.clone())
            .unwrap()
            .generate(1000, 42)
            .unwrap();

        let report = CausalValidator::validate_causal_structure(&samples, &graph);

        assert!(
            report.valid,
            "Validation should pass on correctly generated data. Violations: {:?}",
            report.violations
        );
        assert_eq!(report.checks.len(), 3);
        assert!(report.violations.is_empty());
    }

    #[test]
    fn test_causal_validation_detects_shuffled_columns() {
        let graph = CausalGraph::fraud_detection_template();
        let mut samples = StructuralCausalModel::new(graph.clone())
            .unwrap()
            .generate(2000, 42)
            .unwrap();

        // Rotate the fraud_probability column by half its length so each row gets
        // another row's value, breaking the parents -> fraud_probability link.
        let row_count = samples.len();
        let offset = row_count / 2;
        let original_column: Vec<f64> = samples
            .iter()
            .filter_map(|row| row.get("fraud_probability").copied())
            .collect();

        for (idx, row) in samples.iter_mut().enumerate() {
            let donor = original_column[(idx + offset) % row_count];
            row.insert("fraud_probability".to_string(), donor);
        }

        let report = CausalValidator::validate_causal_structure(&samples, &graph);

        // Breaking the causal structure must make at least one check fail.
        assert!(
            !report.checks.iter().all(|c| c.passed),
            "Validation should detect broken causal structure. Checks: {:?}",
            report.checks
        );
    }

    #[test]
    fn test_causal_pearson_correlation_perfect_positive() {
        let corr = pearson_correlation(&[1.0, 2.0, 3.0, 4.0, 5.0], &[2.0, 4.0, 6.0, 8.0, 10.0]);
        assert!(
            (corr - 1.0).abs() < 1e-10,
            "Perfect positive correlation expected, got {}",
            corr
        );
    }

    #[test]
    fn test_causal_pearson_correlation_perfect_negative() {
        let corr = pearson_correlation(&[1.0, 2.0, 3.0, 4.0, 5.0], &[10.0, 8.0, 6.0, 4.0, 2.0]);
        assert!(
            (corr + 1.0).abs() < 1e-10,
            "Perfect negative correlation expected, got {}",
            corr
        );
    }

    #[test]
    fn test_causal_pearson_correlation_constant() {
        let corr = pearson_correlation(&[1.0, 1.0, 1.0, 1.0], &[2.0, 4.0, 6.0, 8.0]);
        assert!(
            corr.abs() < 1e-10,
            "Correlation with constant should be 0, got {}",
            corr
        );
    }

    #[test]
    fn test_causal_validation_report_structure() {
        let graph = CausalGraph::fraud_detection_template();
        let samples = StructuralCausalModel::new(graph.clone())
            .unwrap()
            .generate(200, 42)
            .unwrap();

        let report = CausalValidator::validate_causal_structure(&samples, &graph);

        // Exactly three checks, in a fixed order.
        let names: Vec<&str> = report.checks.iter().map(|c| c.name.as_str()).collect();
        assert_eq!(
            names,
            [
                "edge_correlation_signs",
                "non_edge_weakness",
                "topological_consistency"
            ]
        );

        // Every check must describe its outcome.
        assert!(report.checks.iter().all(|c| !c.details.is_empty()));
    }

    #[test]
    fn test_causal_validation_revenue_cycle() {
        let graph = CausalGraph::revenue_cycle_template();
        let samples = StructuralCausalModel::new(graph.clone())
            .unwrap()
            .generate(1000, 99)
            .unwrap();

        let report = CausalValidator::validate_causal_structure(&samples, &graph);

        // Most checks should pass on correctly generated data.
        let passing = report.checks.iter().filter(|c| c.passed).count();
        assert!(
            passing >= 2,
            "At least 2 of 3 checks should pass. Checks: {:?}",
            report.checks
        );
    }
}