Skip to main content

datasynth_core/causal/
validation.rs

//! Causal structure validation.
//!
//! Validates that generated samples respect the causal structure defined by the graph
//! by checking edge correlation signs, edge vs. non-edge correlation strength, and the
//! direction of conditional-mean shifts along each edge.

6use std::collections::HashMap;
7
8use super::graph::{CausalGraph, CausalMechanism};
9
/// Report from causal structure validation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CausalValidationReport {
    /// Whether all checks passed.
    pub valid: bool,
    /// Individual check results, one entry per check in the order they ran.
    pub checks: Vec<CausalCheck>,
    /// Human-readable violation descriptions (the `details` of each failing check).
    pub violations: Vec<String>,
}

/// Result of a single validation check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CausalCheck {
    /// Machine-readable name of the check (e.g. `"edge_correlation_signs"`).
    pub name: String,
    /// Whether the check passed.
    pub passed: bool,
    /// Details about the check result.
    pub details: String,
}

/// Validator for causal structure consistency.
///
/// Carries no state; all functionality is provided through associated
/// functions such as `CausalValidator::validate_causal_structure`.
#[derive(Debug, Clone, Copy, Default)]
pub struct CausalValidator;

35impl CausalValidator {
36    /// Validate that samples respect the causal structure of the graph.
37    ///
38    /// Performs three checks:
39    /// 1. Edge correlation signs match mechanism coefficient signs
40    /// 2. Non-edges have weaker average correlation than edges
41    /// 3. Topological ordering holds in conditional means
42    pub fn validate_causal_structure(
43        samples: &[HashMap<String, f64>],
44        graph: &CausalGraph,
45    ) -> CausalValidationReport {
46        let mut checks = Vec::new();
47        let mut violations = Vec::new();
48
49        // Check 1: Edge correlation signs
50        let sign_check = Self::check_edge_correlation_signs(samples, graph);
51        if !sign_check.passed {
52            violations.push(sign_check.details.clone());
53        }
54        checks.push(sign_check);
55
56        // Check 2: Non-edges have weaker correlation than edges
57        let strength_check = Self::check_non_edge_weakness(samples, graph);
58        if !strength_check.passed {
59            violations.push(strength_check.details.clone());
60        }
61        checks.push(strength_check);
62
63        // Check 3: Topological ordering in conditional means
64        let topo_check = Self::check_topological_consistency(samples, graph);
65        if !topo_check.passed {
66            violations.push(topo_check.details.clone());
67        }
68        checks.push(topo_check);
69
70        let valid = checks.iter().all(|c| c.passed);
71
72        CausalValidationReport {
73            valid,
74            checks,
75            violations,
76        }
77    }
78
79    /// Check 1: For each edge, verify correlation between parent and child
80    /// has the expected sign (based on mechanism coefficient sign).
81    fn check_edge_correlation_signs(
82        samples: &[HashMap<String, f64>],
83        graph: &CausalGraph,
84    ) -> CausalCheck {
85        let mut total_edges = 0;
86        let mut _correct_signs = 0u32;
87        let mut mismatches = Vec::new();
88
89        for edge in &graph.edges {
90            let expected_sign = Self::mechanism_sign(&edge.mechanism);
91            // Skip edges where we can't reliably determine expected sign.
92            // Threshold mechanisms produce binary outputs where correlation
93            // with the continuous parent is often very weak or indeterminate.
94            if expected_sign == 0 || matches!(edge.mechanism, CausalMechanism::Threshold { .. }) {
95                continue;
96            }
97
98            total_edges += 1;
99
100            let parent_vals: Vec<f64> = samples
101                .iter()
102                .filter_map(|s| s.get(&edge.from).copied())
103                .collect();
104            let child_vals: Vec<f64> = samples
105                .iter()
106                .filter_map(|s| s.get(&edge.to).copied())
107                .collect();
108
109            let corr = pearson_correlation(&parent_vals, &child_vals);
110
111            if (expected_sign > 0 && corr > -0.05) || (expected_sign < 0 && corr < 0.05) {
112                _correct_signs += 1;
113            } else {
114                mismatches.push(format!(
115                    "{} -> {}: expected sign {}, got correlation {:.4}",
116                    edge.from, edge.to, expected_sign, corr
117                ));
118            }
119        }
120
121        let passed = mismatches.is_empty();
122        let details = if passed {
123            format!("All {} edges have correct correlation signs", total_edges)
124        } else {
125            format!(
126                "{}/{} edges have incorrect signs: {}",
127                mismatches.len(),
128                total_edges,
129                mismatches.join("; ")
130            )
131        };
132
133        CausalCheck {
134            name: "edge_correlation_signs".to_string(),
135            passed,
136            details,
137        }
138    }
139
140    /// Check 2: Verify non-edges have weaker correlation than edges (on average).
141    fn check_non_edge_weakness(
142        samples: &[HashMap<String, f64>],
143        graph: &CausalGraph,
144    ) -> CausalCheck {
145        let var_names = graph.variable_names();
146
147        // Compute average absolute correlation for edges
148        let mut edge_corrs = Vec::new();
149        for edge in &graph.edges {
150            let parent_vals: Vec<f64> = samples
151                .iter()
152                .filter_map(|s| s.get(&edge.from).copied())
153                .collect();
154            let child_vals: Vec<f64> = samples
155                .iter()
156                .filter_map(|s| s.get(&edge.to).copied())
157                .collect();
158            let corr = pearson_correlation(&parent_vals, &child_vals).abs();
159            if corr.is_finite() {
160                edge_corrs.push(corr);
161            }
162        }
163
164        // Build set of edge pairs for fast lookup
165        let edge_pairs: std::collections::HashSet<(&str, &str)> = graph
166            .edges
167            .iter()
168            .map(|e| (e.from.as_str(), e.to.as_str()))
169            .collect();
170
171        // Compute average absolute correlation for non-edges (direct only)
172        let mut non_edge_corrs = Vec::new();
173        for (i, &vi) in var_names.iter().enumerate() {
174            for &vj in var_names.iter().skip(i + 1) {
175                if edge_pairs.contains(&(vi, vj)) || edge_pairs.contains(&(vj, vi)) {
176                    continue;
177                }
178                let vals_i: Vec<f64> = samples.iter().filter_map(|s| s.get(vi).copied()).collect();
179                let vals_j: Vec<f64> = samples.iter().filter_map(|s| s.get(vj).copied()).collect();
180                let corr = pearson_correlation(&vals_i, &vals_j).abs();
181                if corr.is_finite() {
182                    non_edge_corrs.push(corr);
183                }
184            }
185        }
186
187        let avg_edge = if edge_corrs.is_empty() {
188            0.0
189        } else {
190            edge_corrs.iter().sum::<f64>() / edge_corrs.len() as f64
191        };
192
193        let avg_non_edge = if non_edge_corrs.is_empty() {
194            0.0
195        } else {
196            non_edge_corrs.iter().sum::<f64>() / non_edge_corrs.len() as f64
197        };
198
199        // Non-edges should have weaker average correlation than edges
200        let passed = non_edge_corrs.is_empty() || avg_non_edge <= avg_edge + 0.1;
201
202        let details = format!(
203            "Avg edge correlation: {:.4}, avg non-edge correlation: {:.4}",
204            avg_edge, avg_non_edge
205        );
206
207        CausalCheck {
208            name: "non_edge_weakness".to_string(),
209            passed,
210            details,
211        }
212    }
213
214    /// Check 3: Verify topological ordering holds in conditional means.
215    ///
216    /// For parent -> child edges, the mean of child should shift when we split
217    /// samples by parent median.
218    fn check_topological_consistency(
219        samples: &[HashMap<String, f64>],
220        graph: &CausalGraph,
221    ) -> CausalCheck {
222        let mut total_checked = 0;
223        let mut consistent = 0;
224
225        for edge in &graph.edges {
226            let expected_sign = Self::mechanism_sign(&edge.mechanism);
227            if expected_sign == 0 {
228                continue;
229            }
230
231            let mut parent_vals: Vec<f64> = samples
232                .iter()
233                .filter_map(|s| s.get(&edge.from).copied())
234                .collect();
235            parent_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
236
237            if parent_vals.is_empty() {
238                continue;
239            }
240
241            let median_idx = parent_vals.len() / 2;
242            let median = parent_vals[median_idx];
243
244            // Split child values by parent median
245            let child_low: Vec<f64> = samples
246                .iter()
247                .filter(|s| s.get(&edge.from).copied().unwrap_or(0.0) <= median)
248                .filter_map(|s| s.get(&edge.to).copied())
249                .collect();
250
251            let child_high: Vec<f64> = samples
252                .iter()
253                .filter(|s| s.get(&edge.from).copied().unwrap_or(0.0) > median)
254                .filter_map(|s| s.get(&edge.to).copied())
255                .collect();
256
257            if child_low.is_empty() || child_high.is_empty() {
258                continue;
259            }
260
261            let mean_low = child_low.iter().sum::<f64>() / child_low.len() as f64;
262            let mean_high = child_high.iter().sum::<f64>() / child_high.len() as f64;
263
264            total_checked += 1;
265
266            // Check that the direction of mean shift matches expected sign
267            let actual_sign = if mean_high > mean_low + 1e-10 {
268                1
269            } else if mean_high < mean_low - 1e-10 {
270                -1
271            } else {
272                0
273            };
274
275            if actual_sign == expected_sign || actual_sign == 0 {
276                consistent += 1;
277            }
278        }
279
280        let passed = total_checked == 0 || consistent >= total_checked / 2;
281        let details = format!(
282            "{}/{} edges show consistent conditional mean ordering",
283            consistent, total_checked
284        );
285
286        CausalCheck {
287            name: "topological_consistency".to_string(),
288            passed,
289            details,
290        }
291    }
292
293    /// Determine the expected sign of a mechanism's effect.
294    /// Returns 1 for positive, -1 for negative, 0 for indeterminate.
295    fn mechanism_sign(mechanism: &CausalMechanism) -> i32 {
296        match mechanism {
297            CausalMechanism::Linear { coefficient } => {
298                if *coefficient > 0.0 {
299                    1
300                } else if *coefficient < 0.0 {
301                    -1
302                } else {
303                    0
304                }
305            }
306            CausalMechanism::Threshold { .. } => {
307                // Threshold is monotonically non-decreasing (0 or 1)
308                1
309            }
310            CausalMechanism::Logistic { scale, .. } => {
311                if *scale > 0.0 {
312                    1
313                } else if *scale < 0.0 {
314                    -1
315                } else {
316                    0
317                }
318            }
319            CausalMechanism::Polynomial { coefficients } => {
320                // Use sign of highest non-zero coefficient as a heuristic
321                for coeff in coefficients.iter().rev() {
322                    if *coeff > 0.0 {
323                        return 1;
324                    } else if *coeff < 0.0 {
325                        return -1;
326                    }
327                }
328                0
329            }
330        }
331    }
332}
333
/// Pearson correlation coefficient of two series.
///
/// Only the first `min(x.len(), y.len())` elements of each slice are used.
/// Returns `0.0` when fewer than two points are available, or when either
/// series is (numerically) constant. NaN inputs propagate to a NaN result.
fn pearson_correlation(x: &[f64], y: &[f64]) -> f64 {
    let len = x.len().min(y.len());
    if len < 2 {
        return 0.0;
    }

    let (xs, ys) = (&x[..len], &y[..len]);
    let count = len as f64;
    let x_bar = xs.iter().sum::<f64>() / count;
    let y_bar = ys.iter().sum::<f64>() / count;

    // Accumulate co-moment and both second moments in a single pass.
    let (cov, var_x, var_y) = xs.iter().zip(ys).fold(
        (0.0, 0.0, 0.0),
        |(cov, var_x, var_y), (&xi, &yi)| {
            let dx = xi - x_bar;
            let dy = yi - y_bar;
            (cov + dx * dy, var_x + dx * dx, var_y + dy * dy)
        },
    );

    let scale = (var_x * var_y).sqrt();
    // Guard against division by ~0 for constant (or nearly constant) series.
    match scale < 1e-15 {
        true => 0.0,
        false => cov / scale,
    }
}

#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::causal::graph::CausalGraph;
    use crate::causal::scm::StructuralCausalModel;

    #[test]
    fn test_causal_validation_passes_on_correct_data() {
        let graph = CausalGraph::fraud_detection_template();
        let scm = StructuralCausalModel::new(graph.clone()).unwrap();
        let samples = scm.generate(1000, 42).unwrap();

        let report = CausalValidator::validate_causal_structure(&samples, &graph);

        assert!(
            report.valid,
            "Validation should pass on correctly generated data. Violations: {:?}",
            report.violations
        );
        assert_eq!(report.checks.len(), 3);
        assert!(report.violations.is_empty());
    }

    #[test]
    fn test_causal_validation_detects_shuffled_columns() {
        let graph = CausalGraph::fraud_detection_template();
        let scm = StructuralCausalModel::new(graph.clone()).unwrap();
        let mut samples = scm.generate(500, 42).unwrap();

        // Shuffle the fraud_probability column by rotating values.
        // This breaks the causal relationship between parents and fraud_probability.
        let n = samples.len();
        let fp_values: Vec<f64> = samples
            .iter()
            .filter_map(|s| s.get("fraud_probability").copied())
            .collect();
        // The rotation below indexes `fp_values` with sample indices, so every
        // sample must carry the column; fail clearly rather than index out of bounds.
        assert_eq!(
            fp_values.len(),
            n,
            "every generated sample should contain fraud_probability"
        );

        for (i, sample) in samples.iter_mut().enumerate() {
            let shifted_idx = (i + n / 2) % n;
            sample.insert("fraud_probability".to_string(), fp_values[shifted_idx]);
        }

        let report = CausalValidator::validate_causal_structure(&samples, &graph);

        // At least one check should fail when causal structure is broken
        let has_failure = report.checks.iter().any(|c| !c.passed);
        assert!(
            has_failure,
            "Validation should detect broken causal structure. Checks: {:?}",
            report.checks
        );
    }

    #[test]
    fn test_causal_pearson_correlation_perfect_positive() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let corr = pearson_correlation(&x, &y);
        assert!(
            (corr - 1.0).abs() < 1e-10,
            "Perfect positive correlation expected, got {}",
            corr
        );
    }

    #[test]
    fn test_causal_pearson_correlation_perfect_negative() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![10.0, 8.0, 6.0, 4.0, 2.0];
        let corr = pearson_correlation(&x, &y);
        assert!(
            (corr - (-1.0)).abs() < 1e-10,
            "Perfect negative correlation expected, got {}",
            corr
        );
    }

    #[test]
    fn test_causal_pearson_correlation_constant() {
        let x = vec![1.0, 1.0, 1.0, 1.0];
        let y = vec![2.0, 4.0, 6.0, 8.0];
        let corr = pearson_correlation(&x, &y);
        assert!(
            corr.abs() < 1e-10,
            "Correlation with constant should be 0, got {}",
            corr
        );
    }

    #[test]
    fn test_causal_pearson_correlation_too_few_samples() {
        // Fewer than two points: correlation is undefined and reported as 0.
        assert!(pearson_correlation(&[], &[]).abs() < 1e-10);
        assert!(pearson_correlation(&[3.0], &[7.0]).abs() < 1e-10);
    }

    #[test]
    fn test_causal_pearson_correlation_truncates_to_shorter_input() {
        // Only the first min(len) elements are used: [1, 2] vs [1, 2].
        let x = [1.0, 2.0, 3.0, 4.0];
        let y = [1.0, 2.0];
        let corr = pearson_correlation(&x, &y);
        assert!(
            (corr - 1.0).abs() < 1e-10,
            "Truncated perfect correlation expected, got {}",
            corr
        );
    }

    #[test]
    fn test_causal_validation_report_structure() {
        let graph = CausalGraph::fraud_detection_template();
        let scm = StructuralCausalModel::new(graph.clone()).unwrap();
        let samples = scm.generate(200, 42).unwrap();

        let report = CausalValidator::validate_causal_structure(&samples, &graph);

        // Should always produce exactly 3 checks
        assert_eq!(report.checks.len(), 3);
        assert_eq!(report.checks[0].name, "edge_correlation_signs");
        assert_eq!(report.checks[1].name, "non_edge_weakness");
        assert_eq!(report.checks[2].name, "topological_consistency");

        // Each check should have non-empty details
        for check in &report.checks {
            assert!(!check.details.is_empty());
        }
    }

    #[test]
    fn test_causal_validation_revenue_cycle() {
        let graph = CausalGraph::revenue_cycle_template();
        let scm = StructuralCausalModel::new(graph.clone()).unwrap();
        let samples = scm.generate(1000, 99).unwrap();

        let report = CausalValidator::validate_causal_structure(&samples, &graph);

        // Most checks should pass on correctly generated data
        let passing = report.checks.iter().filter(|c| c.passed).count();
        assert!(
            passing >= 2,
            "At least 2 of 3 checks should pass. Checks: {:?}",
            report.checks
        );
    }
}
486            "At least 2 of 3 checks should pass. Checks: {:?}",
487            report.checks
488        );
489    }
490}