Skip to main content

oxiphysics_core/
causal_inference.rs

1#![allow(clippy::needless_range_loop)]
2// Copyright 2026 COOLJAPAN OU (Team KitaSan)
3// SPDX-License-Identifier: Apache-2.0
4
5//! Causal inference and structural causal models.
6//!
7//! Provides:
8//! - [`CausalGraph`]               — directed acyclic graph, topological sort, d-separation
9//! - [`StructuralCausalModel`]     — linear SCM with noise terms, do-calculus interventions
10//! - [`BackdoorCriterion`]         — backdoor criterion check and adjustment
11//! - [`FrontdoorCriterion`]        — frontdoor adjustment formula
12//! - [`PropensityScoreMatching`]   — propensity score estimation, ATT/ATE
13//! - [`InstrumentalVariables`]     — IV estimation, two-stage least squares (2SLS)
14//! - [`CausalDiscovery`]           — PC algorithm skeleton, orientation rules
15//! - [`CounterfactualQuery`]       — E\[Y|do(X=x), Z=z\] style queries
16
17#![allow(dead_code)]
18#![allow(clippy::too_many_arguments)]
19
20use std::collections::{HashMap, HashSet, VecDeque};
21
22// ---------------------------------------------------------------------------
23// CausalGraph
24// ---------------------------------------------------------------------------
25
/// A directed acyclic graph (DAG) representing causal relationships between variables.
///
/// Nodes are identified by `usize` indices in `0..n_nodes`. An edge `a → b`
/// means `a` is a direct cause (parent) of `b`. Acyclicity is enforced by
/// [`CausalGraph::add_edge`]; code that edits `parents` / `children` directly
/// must keep the two adjacency lists mirror images of each other and must not
/// introduce cycles.
#[derive(Debug, Clone)]
pub struct CausalGraph {
    /// Number of nodes in the graph.
    pub n_nodes: usize,
    /// Adjacency list: `parents[v]` gives all parent nodes of `v`.
    pub parents: Vec<Vec<usize>>,
    /// Adjacency list: `children[v]` gives all children of `v`.
    pub children: Vec<Vec<usize>>,
    /// Variable names; `new` initialises them to `"X0"`, `"X1"`, ….
    pub names: Vec<String>,
}

impl CausalGraph {
    /// Create a new causal graph with `n` nodes and no edges.
    ///
    /// # Arguments
    /// * `n` — number of nodes; nodes are named `X0 … X{n-1}` by default
    pub fn new(n: usize) -> Self {
        Self {
            n_nodes: n,
            parents: vec![vec![]; n],
            children: vec![vec![]; n],
            names: (0..n).map(|i| format!("X{i}")).collect(),
        }
    }

    /// Replace the variable names.
    ///
    /// # Arguments
    /// * `names` — one name per node, in node-index order
    ///
    /// # Panics
    /// Panics if `names.len() != n_nodes`.
    pub fn set_names(&mut self, names: &[&str]) {
        assert_eq!(names.len(), self.n_nodes);
        self.names = names.iter().map(|s| s.to_string()).collect();
    }

    /// Add a directed edge from `from` (parent/cause) to `to` (child/effect).
    ///
    /// Adding an edge that already exists is a no-op.
    ///
    /// # Panics
    /// Panics if either index is out of bounds, or if adding this edge would
    /// create a cycle (including a self-loop).
    pub fn add_edge(&mut self, from: usize, to: usize) {
        assert!(
            from < self.n_nodes && to < self.n_nodes,
            "node index out of bounds"
        );
        assert!(
            !self.creates_cycle(from, to),
            "edge {from}→{to} would create a cycle"
        );
        if !self.children[from].contains(&to) {
            self.children[from].push(to);
            self.parents[to].push(from);
        }
    }

    /// Check whether adding edge `from→to` would create a cycle.
    ///
    /// True when `from` is reachable from `to` along existing edges, and also
    /// when `from == to` (a self-loop).
    pub fn creates_cycle(&self, from: usize, to: usize) -> bool {
        // DFS from `to`: if we can reach `from`, adding from→to closes a cycle.
        let mut visited = vec![false; self.n_nodes];
        let mut stack = vec![to];
        while let Some(node) = stack.pop() {
            if node == from {
                return true;
            }
            if !visited[node] {
                visited[node] = true;
                stack.extend(self.children[node].iter().copied());
            }
        }
        false
    }

    /// Return nodes in topological order (Kahn's algorithm).
    ///
    /// Returns `None` if the graph has a cycle (should not happen if edges are
    /// added via `add_edge`).
    pub fn topological_sort(&self) -> Option<Vec<usize>> {
        let mut in_degree: Vec<usize> = self.parents.iter().map(|p| p.len()).collect();
        let mut queue: VecDeque<usize> =
            (0..self.n_nodes).filter(|&v| in_degree[v] == 0).collect();
        let mut order = Vec::with_capacity(self.n_nodes);
        while let Some(v) = queue.pop_front() {
            order.push(v);
            for &child in &self.children[v] {
                in_degree[child] -= 1;
                if in_degree[child] == 0 {
                    queue.push_back(child);
                }
            }
        }
        (order.len() == self.n_nodes).then_some(order)
    }

    /// Return all strict ancestors of node `v` (nodes from which `v` is reachable).
    ///
    /// `v` itself is not included.
    pub fn ancestors(&self, v: usize) -> HashSet<usize> {
        let mut anc = HashSet::new();
        let mut stack = vec![v];
        while let Some(node) = stack.pop() {
            for &p in &self.parents[node] {
                if anc.insert(p) {
                    stack.push(p);
                }
            }
        }
        anc
    }

    /// Return all strict descendants of node `v` (nodes reachable from `v`).
    ///
    /// `v` itself is not included.
    pub fn descendants(&self, v: usize) -> HashSet<usize> {
        let mut desc = HashSet::new();
        let mut stack = vec![v];
        while let Some(node) = stack.pop() {
            for &c in &self.children[node] {
                if desc.insert(c) {
                    stack.push(c);
                }
            }
        }
        desc
    }

    /// Test d-separation: are node sets `x` and `y` d-separated by conditioning set `z`?
    ///
    /// Uses the Bayes-Ball / "reachable" algorithm and returns `true` if `x ⊥ y | z`.
    /// Nodes that are themselves in `z` are never reported as reachable, so a
    /// target in `y ∩ z` is trivially independent. A start node in `x ∩ z` is
    /// observed and therefore opens no trails.
    pub fn d_separated(&self, x: &[usize], y: &[usize], z: &[usize]) -> bool {
        let z_set: HashSet<usize> = z.iter().copied().collect();
        let y_set: HashSet<usize> = y.iter().copied().collect();

        // Z together with all ancestors of Z: a collider lets the ball through
        // exactly when it lies in this set. One shared traversal for all of Z.
        let mut z_closure: HashSet<usize> = z_set.clone();
        let mut pending: Vec<usize> = z.to_vec();
        while let Some(node) = pending.pop() {
            for &p in &self.parents[node] {
                if z_closure.insert(p) {
                    pending.push(p);
                }
            }
        }

        // State: (node, arrived_from_child). Leaving a start node is modelled
        // as arriving "from a child", which allows moving both up and down.
        let mut visited: HashSet<(usize, bool)> = HashSet::new();
        let mut queue: VecDeque<(usize, bool)> = x.iter().map(|&xv| (xv, true)).collect();

        while let Some((node, from_child)) = queue.pop_front() {
            if !visited.insert((node, from_child)) {
                continue;
            }

            let in_z = z_set.contains(&node);
            if !in_z && y_set.contains(&node) {
                return false; // active trail found → not d-separated
            }

            if from_child {
                // Chain / fork through an unobserved node: pass up and down.
                if !in_z {
                    queue.extend(self.parents[node].iter().map(|&p| (p, true)));
                    queue.extend(self.children[node].iter().map(|&c| (c, false)));
                }
            } else {
                // Arrived from a parent.
                if !in_z {
                    // Chain continues downward through an unobserved node.
                    queue.extend(self.children[node].iter().map(|&c| (c, false)));
                }
                if z_closure.contains(&node) {
                    // Collider activated by conditioning on it or a descendant.
                    queue.extend(self.parents[node].iter().map(|&p| (p, true)));
                }
            }
        }
        true
    }

    /// Return the Markov blanket of node `v`: parents, children, and co-parents
    /// (other parents of `v`'s children). `v` itself is excluded.
    pub fn markov_blanket(&self, v: usize) -> HashSet<usize> {
        let mut blanket: HashSet<usize> = self.parents[v].iter().copied().collect();
        for &c in &self.children[v] {
            blanket.insert(c);
            blanket.extend(self.parents[c].iter().copied().filter(|&cp| cp != v));
        }
        blanket
    }

    /// Check if the graph is acyclic.
    pub fn is_acyclic(&self) -> bool {
        self.topological_sort().is_some()
    }
}
241
242// ---------------------------------------------------------------------------
243// StructuralCausalModel
244// ---------------------------------------------------------------------------
245
246/// A linear structural causal model (SCM).
247///
248/// Each variable `X_i` is defined as:
249/// `X_i = Σ_j (coeff[i][j] * X_j) + noise_std[i] * ε_i`
250///
251/// where `ε_i ~ N(0,1)` and `j` ranges over parents of `i`.
252#[derive(Debug, Clone)]
253pub struct StructuralCausalModel {
254    /// The underlying causal graph.
255    pub graph: CausalGraph,
256    /// Structural coefficients: `coefficients[i][k]` is the coefficient for the
257    /// k-th parent of node `i`.
258    pub coefficients: Vec<Vec<f64>>,
259    /// Standard deviation of the exogenous noise for each variable.
260    pub noise_std: Vec<f64>,
261    /// Intercept terms for each variable.
262    pub intercepts: Vec<f64>,
263}
264
265impl StructuralCausalModel {
266    /// Create a new linear SCM on `n` variables.
267    ///
268    /// All coefficients are zero, noise std = 1.0, intercepts = 0.0 by default.
269    pub fn new(n: usize) -> Self {
270        Self {
271            graph: CausalGraph::new(n),
272            coefficients: vec![vec![]; n],
273            noise_std: vec![1.0; n],
274            intercepts: vec![0.0; n],
275        }
276    }
277
278    /// Add a causal edge with a specified structural coefficient.
279    ///
280    /// # Arguments
281    /// * `from` — parent (cause) node index
282    /// * `to`   — child (effect) node index
283    /// * `coeff` — linear coefficient
284    pub fn add_edge(&mut self, from: usize, to: usize, coeff: f64) {
285        self.graph.add_edge(from, to);
286        // The k-th parent of `to` is now `from`
287        self.coefficients[to].push(coeff);
288    }
289
290    /// Set the noise standard deviation for variable `v`.
291    pub fn set_noise(&mut self, v: usize, std: f64) {
292        self.noise_std[v] = std;
293    }
294
295    /// Set the intercept for variable `v`.
296    pub fn set_intercept(&mut self, v: usize, intercept: f64) {
297        self.intercepts[v] = intercept;
298    }
299
300    /// Sample one observation from the SCM using provided noise values.
301    ///
302    /// # Arguments
303    /// * `noise` — exogenous noise values `ε_i` for each variable (length = n_nodes)
304    ///
305    /// Returns `x[i]` values in topological order.
306    pub fn sample_with_noise(&self, noise: &[f64]) -> Vec<f64> {
307        let n = self.graph.n_nodes;
308        let order = self
309            .graph
310            .topological_sort()
311            .expect("SCM graph must be acyclic");
312        let mut x = vec![0.0_f64; n];
313        for &v in &order {
314            let val: f64 = self.intercepts[v]
315                + self.graph.parents[v]
316                    .iter()
317                    .zip(self.coefficients[v].iter())
318                    .map(|(&p, &c)| c * x[p])
319                    .sum::<f64>()
320                + self.noise_std[v] * noise[v];
321            x[v] = val;
322        }
323        x
324    }
325
326    /// Perform a do-calculus intervention: set variable `target` to value `val`.
327    ///
328    /// Returns the modified SCM where all incoming edges to `target` are removed
329    /// and its value is fixed at `val` (zero noise, intercept = val).
330    pub fn intervene(&self, target: usize, val: f64) -> Self {
331        let mut scm = self.clone();
332        // Remove all parents of target
333        let parents = scm.graph.parents[target].clone();
334        for &p in &parents {
335            scm.graph.children[p].retain(|&c| c != target);
336        }
337        scm.graph.parents[target].clear();
338        scm.coefficients[target].clear();
339        scm.noise_std[target] = 0.0;
340        scm.intercepts[target] = val;
341        scm
342    }
343
344    /// Compute the average causal effect (ACE) of intervention `do(X_cause = val)`
345    /// on variable `effect`, using the provided noise samples.
346    ///
347    /// # Arguments
348    /// * `cause`      — the variable to intervene on
349    /// * `val`        — the intervention value
350    /// * `effect`     — the outcome variable
351    /// * `noise_samples` — matrix of noise samples, shape `[n_samples][n_nodes]`
352    pub fn average_causal_effect(
353        &self,
354        cause: usize,
355        val: f64,
356        effect: usize,
357        noise_samples: &[Vec<f64>],
358    ) -> f64 {
359        let intervened = self.intervene(cause, val);
360        let mean: f64 = noise_samples
361            .iter()
362            .map(|noise| intervened.sample_with_noise(noise)[effect])
363            .sum::<f64>()
364            / noise_samples.len() as f64;
365        mean
366    }
367
368    /// Compute the total causal effect of `cause` on `effect` analytically
369    /// (only valid for linear SCMs).
370    ///
371    /// Sums all directed path contributions.
372    pub fn total_effect_linear(&self, cause: usize, effect: usize) -> f64 {
373        // BFS/DFS to enumerate all directed paths and multiply coefficients
374        let mut total = 0.0_f64;
375        // Stack of (current_node, accumulated_product)
376        let mut stack: Vec<(usize, f64)> = vec![(cause, 1.0)];
377        while let Some((node, prod)) = stack.pop() {
378            if node == effect && node != cause {
379                total += prod;
380            }
381            for (k, &child) in self.graph.children[node].iter().enumerate() {
382                // Find the index of `node` in child's parent list
383                if let Some(idx) = self.graph.parents[child].iter().position(|&p| p == node) {
384                    let coeff = self.coefficients[child][idx];
385                    let _ = k; // suppress unused warning
386                    stack.push((child, prod * coeff));
387                }
388            }
389        }
390        total
391    }
392}
393
394// ---------------------------------------------------------------------------
395// BackdoorCriterion
396// ---------------------------------------------------------------------------
397
398/// Checks whether a set of variables satisfies the backdoor criterion for
399/// identifying the causal effect of `treatment` on `outcome`.
400///
401/// The backdoor criterion holds if:
402/// 1. No variable in `adjustment_set` is a descendant of `treatment`.
403/// 2. `adjustment_set` blocks all backdoor paths from `treatment` to `outcome`.
404#[derive(Debug, Clone)]
405pub struct BackdoorCriterion {
406    /// The causal graph.
407    pub graph: CausalGraph,
408}
409
410impl BackdoorCriterion {
411    /// Create a new backdoor criterion checker.
412    pub fn new(graph: CausalGraph) -> Self {
413        Self { graph }
414    }
415
416    /// Check if `adjustment_set` satisfies the backdoor criterion for the
417    /// causal effect of `treatment` on `outcome`.
418    ///
419    /// Returns `true` if the criterion is satisfied.
420    pub fn check(&self, treatment: usize, outcome: usize, adjustment_set: &[usize]) -> bool {
421        let desc_treatment = self.graph.descendants(treatment);
422
423        // Criterion 1: no variable in Z is a descendant of X
424        for &z in adjustment_set {
425            if desc_treatment.contains(&z) {
426                return false;
427            }
428        }
429
430        // Criterion 2: Z blocks all backdoor paths
431        // A backdoor path is a path from X to Y that starts with an arrow INTO X.
432        // We check this by creating a modified graph where outgoing edges from X
433        // are removed, and then check d-separation.
434        let mut modified = self.graph.clone();
435        // Remove all outgoing edges from treatment in the modified graph
436        let children_of_treatment = modified.graph_children_of(treatment);
437        for &c in &children_of_treatment {
438            modified.parents[c].retain(|&p| p != treatment);
439        }
440        modified.children[treatment].clear();
441
442        modified.d_separated(&[treatment], &[outcome], adjustment_set)
443    }
444
445    /// Compute the backdoor-adjusted causal effect estimate from observational data.
446    ///
447    /// Uses the adjustment formula: E\[Y|do(X=x)\] = Σ_z E\[Y|X=x, Z=z\] * P(Z=z)
448    ///
449    /// This simplified version takes pre-computed conditional means.
450    ///
451    /// # Arguments
452    /// * `data_x`   — treatment values
453    /// * `data_y`   — outcome values
454    /// * `data_z`   — confounder values (single confounder for simplicity)
455    /// * `x_val`    — the intervention value
456    pub fn adjusted_effect(
457        data_x: &[f64],
458        data_y: &[f64],
459        data_z: &[f64],
460        x_val: f64,
461        _tolerance: f64,
462    ) -> f64 {
463        let n = data_x.len();
464        assert_eq!(data_y.len(), n);
465        assert_eq!(data_z.len(), n);
466
467        // Bin Z into quantile strata for simple adjustment
468        let n_strata = 5usize;
469        let mut sorted_z = data_z.to_vec();
470        sorted_z.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
471        let quantiles: Vec<f64> = (1..n_strata)
472            .map(|i| sorted_z[(i * n) / n_strata])
473            .collect();
474
475        let stratum_of = |z: f64| -> usize {
476            quantiles
477                .iter()
478                .position(|&q| z < q)
479                .unwrap_or(n_strata - 1)
480        };
481
482        // For each stratum, estimate E[Y|X≈x_val, Z=stratum] and P(Z=stratum)
483        let mut stratum_sums_y = vec![0.0_f64; n_strata];
484        let mut stratum_counts = vec![0usize; n_strata];
485        let mut stratum_counts_near_x = vec![0usize; n_strata];
486        let mut stratum_y_near_x = vec![0.0_f64; n_strata];
487
488        let bandwidth = {
489            let mut xs = data_x.to_vec();
490            xs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
491            let iqr = xs[3 * n / 4] - xs[n / 4];
492            iqr.max(0.1) * 0.5
493        };
494
495        for i in 0..n {
496            let s = stratum_of(data_z[i]);
497            stratum_sums_y[s] += data_y[i];
498            stratum_counts[s] += 1;
499            if (data_x[i] - x_val).abs() < bandwidth {
500                stratum_y_near_x[s] += data_y[i];
501                stratum_counts_near_x[s] += 1;
502            }
503        }
504
505        let mut total = 0.0_f64;
506        for s in 0..n_strata {
507            if stratum_counts[s] == 0 {
508                continue;
509            }
510            let p_z = stratum_counts[s] as f64 / n as f64;
511            let e_y_xz = if stratum_counts_near_x[s] > 0 {
512                stratum_y_near_x[s] / stratum_counts_near_x[s] as f64
513            } else {
514                stratum_sums_y[s] / stratum_counts[s] as f64
515            };
516            total += e_y_xz * p_z;
517        }
518        total
519    }
520}
521
522// Extension trait for internal use
523trait GraphChildrenOf {
524    fn graph_children_of(&self, v: usize) -> Vec<usize>;
525}
526
527impl GraphChildrenOf for CausalGraph {
528    fn graph_children_of(&self, v: usize) -> Vec<usize> {
529        self.children[v].clone()
530    }
531}
532
533// ---------------------------------------------------------------------------
534// FrontdoorCriterion
535// ---------------------------------------------------------------------------
536
537/// Implements the frontdoor adjustment formula for causal effect identification.
538///
539/// The frontdoor criterion allows identification of causal effects through a
540/// mediator set `M` when direct adjustment is not possible.
541#[derive(Debug, Clone)]
542pub struct FrontdoorCriterion {
543    /// The causal graph.
544    pub graph: CausalGraph,
545}
546
547impl FrontdoorCriterion {
548    /// Create a new frontdoor criterion object.
549    pub fn new(graph: CausalGraph) -> Self {
550        Self { graph }
551    }
552
553    /// Check if `mediator_set` satisfies the frontdoor criterion for identifying
554    /// the effect of `treatment` on `outcome`.
555    ///
556    /// Conditions:
557    /// 1. All directed paths from `treatment` to `outcome` are intercepted by `M`.
558    /// 2. No backdoor path from `treatment` to `M` (or blocked by `treatment`).
559    /// 3. All backdoor paths from `M` to `outcome` are blocked by `treatment`.
560    pub fn check(&self, treatment: usize, outcome: usize, mediator_set: &[usize]) -> bool {
561        let med_set: HashSet<usize> = mediator_set.iter().copied().collect();
562
563        // Condition 1: M intercepts all directed paths from X to Y
564        if !self.intercepts_all_paths(treatment, outcome, &med_set) {
565            return false;
566        }
567
568        // Condition 2: no unblocked backdoor from X to M (blocked by ∅)
569        // i.e., X d-separates from M given ∅ in graph with X's parents cut
570        // Simplified: check no common causes of X and M that aren't through X
571        for &m in mediator_set {
572            if !self.graph.d_separated(&[treatment], &[m], &[treatment]) {
573                // Check via empty set
574                let x_anc = self.graph.ancestors(treatment);
575                let m_anc = self.graph.ancestors(m);
576                // If there is overlap in ancestors excluding X's subtree, problem exists
577                // Simplified check for demo purposes
578                let _ = (x_anc, m_anc);
579            }
580        }
581
582        // Condition 3: all backdoor from M to Y are blocked by X
583        for &m in mediator_set {
584            if !self.graph.d_separated(&[m], &[outcome], &[treatment]) {
585                return false;
586            }
587        }
588
589        true
590    }
591
592    /// Check if `med_set` intercepts all directed paths from `src` to `dst`.
593    fn intercepts_all_paths(&self, src: usize, dst: usize, med_set: &HashSet<usize>) -> bool {
594        // DFS: find if any path from src to dst avoids med_set
595        let mut stack: Vec<(usize, Vec<usize>)> = vec![(src, vec![src])];
596        while let Some((node, path)) = stack.pop() {
597            if node == dst {
598                // Found a path; check if med_set is on it (excluding src)
599                let on_path = path[1..].iter().any(|v| med_set.contains(v));
600                if !on_path {
601                    return false;
602                }
603                continue;
604            }
605            for &child in &self.graph.children[node] {
606                if !path.contains(&child) {
607                    let mut new_path = path.clone();
608                    new_path.push(child);
609                    stack.push((child, new_path));
610                }
611            }
612        }
613        true
614    }
615
616    /// Compute the frontdoor-adjusted causal effect using sample data.
617    ///
618    /// E\[Y|do(X=x)\] = Σ_m P(M=m|X=x) Σ_x' E\[Y|M=m, X=x'\] P(X=x')
619    ///
620    /// Simplified discrete approximation for a single mediator.
621    pub fn adjusted_effect(data_x: &[f64], data_m: &[f64], data_y: &[f64], x_val: f64) -> f64 {
622        let n = data_x.len();
623        let bandwidth = {
624            let mut xs = data_x.to_vec();
625            xs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
626            let iqr = xs[3 * n / 4] - xs[n / 4];
627            iqr.max(0.1) * 0.4
628        };
629
630        // Approximate E[M|X=x] via kernel smoothing
631        let (mut sum_m, mut w_sum) = (0.0_f64, 0.0_f64);
632        for i in 0..n {
633            let w = gaussian_kernel((data_x[i] - x_val) / bandwidth);
634            sum_m += w * data_m[i];
635            w_sum += w;
636        }
637        let e_m_given_x = if w_sum > 1e-12 { sum_m / w_sum } else { 0.0 };
638
639        // Approximate E[Y|M=m] via kernel smoothing on M
640        let bw_m = {
641            let mut ms = data_m.to_vec();
642            ms.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
643            let iqr = ms[3 * n / 4] - ms[n / 4];
644            iqr.max(0.1) * 0.4
645        };
646        let (mut sum_y, mut wy_sum) = (0.0_f64, 0.0_f64);
647        for i in 0..n {
648            let w = gaussian_kernel((data_m[i] - e_m_given_x) / bw_m);
649            sum_y += w * data_y[i];
650            wy_sum += w;
651        }
652        if wy_sum > 1e-12 { sum_y / wy_sum } else { 0.0 }
653    }
654}
655
/// Unnormalised Gaussian kernel `exp(-u²/2)` used as a smoothing weight.
///
/// Equals `1.0` at `u = 0` and is symmetric in `u`; no `1/√(2π)` factor is
/// applied because callers divide by the sum of weights.
#[allow(dead_code)]
fn gaussian_kernel(u: f64) -> f64 {
    let half_square = 0.5 * u * u;
    f64::exp(-half_square)
}
661
662// ---------------------------------------------------------------------------
663// PropensityScoreMatching
664// ---------------------------------------------------------------------------
665
/// Propensity-score based estimation for observational causal inference.
///
/// Estimates the probability of treatment assignment P(T=1|X) with logistic
/// regression. The propensity scores are then used for inverse-probability
/// weighting (ATE) and odds weighting of controls (ATT).
#[derive(Debug, Clone)]
pub struct PropensityScoreMatching {
    /// Logistic regression weights: `weights[0]` is the intercept, followed by
    /// one weight per covariate (length = n_covariates + 1).
    pub weights: Vec<f64>,
    /// Number of covariates.
    pub n_covariates: usize,
}

impl PropensityScoreMatching {
    /// Create a new propensity score model with all weights set to zero
    /// (every propensity score is initially 0.5).
    ///
    /// # Arguments
    /// * `n_covariates` — number of covariate dimensions
    pub fn new(n_covariates: usize) -> Self {
        Self {
            weights: vec![0.0; n_covariates + 1],
            n_covariates,
        }
    }

    /// Fit logistic regression to estimate propensity scores via batch gradient
    /// descent on the mean log-loss.
    ///
    /// # Arguments
    /// * `covariates` — matrix of covariates, shape `[n_obs][n_covariates]`
    /// * `treatment`  — binary treatment indicator (0 or 1), length `n_obs`
    /// * `lr`         — learning rate
    /// * `n_iter`     — number of gradient descent iterations
    ///
    /// An empty sample leaves the weights unchanged.
    ///
    /// # Panics
    /// Panics if `treatment.len() != covariates.len()` or a covariate row is
    /// shorter than `n_covariates`.
    pub fn fit(&mut self, covariates: &[Vec<f64>], treatment: &[f64], lr: f64, n_iter: usize) {
        let n = covariates.len();
        assert_eq!(treatment.len(), n);
        if n == 0 {
            // The mean gradient would be 0/0 = NaN and poison every weight.
            return;
        }
        let n_f = n as f64;
        for _ in 0..n_iter {
            let mut grad = vec![0.0_f64; self.n_covariates + 1];
            for (row, &t) in covariates.iter().zip(treatment) {
                let err = self.predict_one(row) - t;
                grad[0] += err; // intercept
                for j in 0..self.n_covariates {
                    grad[j + 1] += err * row[j];
                }
            }
            for (w, g) in self.weights.iter_mut().zip(&grad) {
                *w -= lr * g / n_f;
            }
        }
    }

    /// Predict propensity score P(T=1|X=x) for a single observation.
    pub fn predict_one(&self, x: &[f64]) -> f64 {
        let logit: f64 = self.weights[0]
            + x.iter()
                .zip(self.weights[1..].iter())
                .map(|(xi, wi)| xi * wi)
                .sum::<f64>();
        sigmoid(logit)
    }

    /// Predict propensity scores for all observations.
    pub fn predict(&self, covariates: &[Vec<f64>]) -> Vec<f64> {
        covariates.iter().map(|x| self.predict_one(x)).collect()
    }

    /// Estimate the Average Treatment Effect (ATE) by inverse probability weighting.
    ///
    /// ATE = E\[Y(1)\] - E\[Y(0)\], with each arm's mean self-normalised:
    /// E\[Y(1)\] ≈ Σ_T=1 Y/e(X) / Σ_T=1 1/e(X) and
    /// E\[Y(0)\] ≈ Σ_T=0 Y/(1-e(X)) / Σ_T=0 1/(1-e(X)).
    /// Propensities are clamped to `[1e-6, 1 − 1e-6]`. Units with
    /// `treatment > 0.5` count as treated. An empty arm contributes a mean of `0.0`.
    pub fn estimate_ate(&self, covariates: &[Vec<f64>], treatment: &[f64], outcome: &[f64]) -> f64 {
        let n = covariates.len();
        let (mut sum1, mut w1, mut sum0, mut w0) = (0.0_f64, 0.0_f64, 0.0_f64, 0.0_f64);
        for i in 0..n {
            let e = self.predict_one(&covariates[i]).clamp(1e-6, 1.0 - 1e-6);
            if treatment[i] > 0.5 {
                sum1 += outcome[i] / e;
                w1 += 1.0 / e;
            } else {
                sum0 += outcome[i] / (1.0 - e);
                w0 += 1.0 / (1.0 - e);
            }
        }
        let ey1 = if w1 > 0.0 { sum1 / w1 } else { 0.0 };
        let ey0 = if w0 > 0.0 { sum0 / w0 } else { 0.0 };
        ey1 - ey0
    }

    /// Estimate the Average Treatment Effect on the Treated (ATT).
    ///
    /// ATT = E\[Y(1)-Y(0)|T=1\]. It is estimated as the plain mean outcome of the
    /// treated minus the mean outcome of controls, where each control is weighted
    /// by its propensity odds e/(1−e) (propensities clamped to `[1e-6, 1 − 1e-6]`).
    /// Returns `0.0` if either group is empty.
    pub fn estimate_att(&self, covariates: &[Vec<f64>], treatment: &[f64], outcome: &[f64]) -> f64 {
        let n = covariates.len();
        let mut treated_y: Vec<f64> = Vec::new();
        let mut control_y: Vec<f64> = Vec::new();
        let mut control_ps: Vec<f64> = Vec::new();

        for i in 0..n {
            let e = self.predict_one(&covariates[i]).clamp(1e-6, 1.0 - 1e-6);
            if treatment[i] > 0.5 {
                treated_y.push(outcome[i]);
            } else {
                control_y.push(outcome[i]);
                control_ps.push(e / (1.0 - e)); // odds
            }
        }

        if treated_y.is_empty() || control_y.is_empty() {
            return 0.0;
        }

        let mean_treated = treated_y.iter().sum::<f64>() / treated_y.len() as f64;
        let total_weight: f64 = control_ps.iter().sum();
        let mean_control = if total_weight > 0.0 {
            control_y
                .iter()
                .zip(control_ps.iter())
                .map(|(y, w)| y * w)
                .sum::<f64>()
                / total_weight
        } else {
            control_y.iter().sum::<f64>() / control_y.len() as f64
        };

        mean_treated - mean_control
    }
}

/// Logistic sigmoid function `1 / (1 + e^{-x})`.
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}
796
797// ---------------------------------------------------------------------------
798// InstrumentalVariables
799// ---------------------------------------------------------------------------
800
/// State of an instrumental-variables (IV) / Two-Stage Least Squares (2SLS) estimator.
///
/// IV estimation is meant for settings where the treatment is endogenous,
/// i.e. correlated with the outcome's error term. A valid instrument `Z`:
/// 1. is correlated with the treatment `D` (relevance);
/// 2. influences the outcome `Y` only through `D` (exclusion restriction);
/// 3. is independent of unobserved confounders (exogeneity).
#[derive(Clone, Debug)]
pub struct InstrumentalVariables {
    /// How many endogenous regressors the model has.
    pub n_endogenous: usize,
    /// How many instruments are supplied.
    pub n_instruments: usize,
    /// First-stage coefficients mapping instruments to the treatment.
    pub first_stage: Vec<f64>,
    /// Second-stage coefficient mapping the (predicted) treatment to the outcome.
    pub second_stage: f64,
}
819
820impl InstrumentalVariables {
821    /// Create a new IV estimator.
822    pub fn new(n_endogenous: usize, n_instruments: usize) -> Self {
823        Self {
824            n_endogenous,
825            n_instruments,
826            first_stage: vec![0.0; n_instruments + 1],
827            second_stage: 0.0,
828        }
829    }
830
831    /// Fit the 2SLS estimator.
832    ///
833    /// Stage 1: regress treatment `d` on instruments `z`.
834    /// Stage 2: regress outcome `y` on predicted treatment `d_hat`.
835    ///
836    /// # Arguments
837    /// * `y` — outcome variable
838    /// * `d` — endogenous treatment variable
839    /// * `z` — instruments matrix, shape `[n_obs][n_instruments]`
840    pub fn fit_2sls(&mut self, y: &[f64], d: &[f64], z: &[Vec<f64>]) {
841        let n = y.len();
842        assert_eq!(d.len(), n);
843        assert_eq!(z.len(), n);
844
845        // Stage 1: OLS of D on Z (including intercept)
846        // Simple single-instrument case for clarity
847        let n_inst = self.n_instruments;
848        let mut d_hat = vec![0.0_f64; n];
849
850        if n_inst == 1 {
851            // Simple IV: β_1 = Cov(Y,Z) / Cov(D,Z)
852            let z_vec: Vec<f64> = z.iter().map(|row| row[0]).collect();
853            let mean_z = z_vec.iter().sum::<f64>() / n as f64;
854            let mean_d = d.iter().sum::<f64>() / n as f64;
855            let mean_y = y.iter().sum::<f64>() / n as f64;
856
857            let cov_dz: f64 = d
858                .iter()
859                .zip(z_vec.iter())
860                .map(|(di, zi)| (di - mean_d) * (zi - mean_z))
861                .sum::<f64>()
862                / n as f64;
863            let cov_yz: f64 = y
864                .iter()
865                .zip(z_vec.iter())
866                .map(|(yi, zi)| (yi - mean_y) * (zi - mean_z))
867                .sum::<f64>()
868                / n as f64;
869            let var_z: f64 = z_vec.iter().map(|zi| (zi - mean_z).powi(2)).sum::<f64>() / n as f64;
870
871            // First stage: d = α0 + α1*z
872            let alpha1 = if var_z.abs() > 1e-12 {
873                cov_dz / var_z
874            } else {
875                0.0
876            };
877            let alpha0 = mean_d - alpha1 * mean_z;
878            self.first_stage[0] = alpha0;
879            self.first_stage[1] = alpha1;
880
881            // IV estimate
882            self.second_stage = if cov_dz.abs() > 1e-12 {
883                cov_yz / cov_dz
884            } else {
885                0.0
886            };
887
888            for i in 0..n {
889                d_hat[i] = alpha0 + alpha1 * z_vec[i];
890            }
891        } else {
892            // Multi-instrument: use OLS for first stage
893            // Simplified: use first instrument only
894            let z0: Vec<f64> = z.iter().map(|row| row[0]).collect();
895            let mean_z0 = z0.iter().sum::<f64>() / n as f64;
896            let mean_d = d.iter().sum::<f64>() / n as f64;
897
898            let cov = z0
899                .iter()
900                .zip(d.iter())
901                .map(|(zi, di)| (zi - mean_z0) * (di - mean_d))
902                .sum::<f64>()
903                / n as f64;
904            let var_z0 = z0.iter().map(|zi| (zi - mean_z0).powi(2)).sum::<f64>() / n as f64;
905
906            let alpha1 = if var_z0 > 1e-12 { cov / var_z0 } else { 0.0 };
907            let alpha0 = mean_d - alpha1 * mean_z0;
908            self.first_stage[0] = alpha0;
909            self.first_stage[1] = alpha1;
910
911            for i in 0..n {
912                d_hat[i] = alpha0 + alpha1 * z0[i];
913            }
914
915            // Stage 2: OLS of Y on D_hat
916            let mean_dhat = d_hat.iter().sum::<f64>() / n as f64;
917            let mean_y = y.iter().sum::<f64>() / n as f64;
918            let cov_ydhat: f64 = y
919                .iter()
920                .zip(d_hat.iter())
921                .map(|(yi, di)| (yi - mean_y) * (di - mean_dhat))
922                .sum::<f64>()
923                / n as f64;
924            let var_dhat: f64 =
925                d_hat.iter().map(|di| (di - mean_dhat).powi(2)).sum::<f64>() / n as f64;
926            self.second_stage = if var_dhat > 1e-12 {
927                cov_ydhat / var_dhat
928            } else {
929                0.0
930            };
931        }
932    }
933
934    /// Compute the first-stage F-statistic (instrument relevance test).
935    ///
936    /// Large F (> 10) indicates strong instruments.
937    pub fn first_stage_f_stat(&self, y: &[f64], d: &[f64], z: &[Vec<f64>]) -> f64 {
938        let n = y.len();
939        let z0: Vec<f64> = z.iter().map(|row| row[0]).collect();
940        let _mean_z0 = z0.iter().sum::<f64>() / n as f64;
941        let mean_d = d.iter().sum::<f64>() / n as f64;
942
943        let d_hat: Vec<f64> = z0
944            .iter()
945            .map(|zi| self.first_stage[0] + self.first_stage[1] * zi)
946            .collect();
947
948        let ss_res: f64 = d
949            .iter()
950            .zip(d_hat.iter())
951            .map(|(di, dh)| (di - dh).powi(2))
952            .sum();
953        let ss_tot: f64 = d.iter().map(|di| (di - mean_d).powi(2)).sum();
954
955        let r2 = 1.0 - ss_res / ss_tot.max(1e-12);
956        let k = 1.0_f64; // number of instruments
957        let n_f = n as f64;
958        (r2 / k) / ((1.0 - r2) / (n_f - k - 1.0)).max(1e-12)
959    }
960
961    /// Predict the causal effect for a new treatment value.
962    pub fn predict(&self, d_val: f64) -> f64 {
963        self.second_stage * d_val
964    }
965}
966
967// ---------------------------------------------------------------------------
968// CausalDiscovery
969// ---------------------------------------------------------------------------
970
971/// Causal discovery via the PC algorithm.
972///
973/// The PC algorithm learns the structure of a DAG from conditional independence
974/// tests on observational data. It produces a Completed Partially Directed
975/// Acyclic Graph (CPDAG) representing the Markov equivalence class.
976#[derive(Debug, Clone)]
977pub struct CausalDiscovery {
978    /// Number of variables.
979    pub n_vars: usize,
980    /// Adjacency matrix of the skeleton (undirected).
981    pub skeleton: Vec<Vec<bool>>,
982    /// Directed adjacency: `directed[i][j] = true` means i → j is oriented.
983    pub directed: Vec<Vec<bool>>,
984    /// Separation sets: `sep_sets[(i,j)]` = the conditioning set that d-separates i and j.
985    pub sep_sets: HashMap<(usize, usize), Vec<usize>>,
986    /// Significance threshold for independence tests.
987    pub alpha: f64,
988}
989
990impl CausalDiscovery {
991    /// Create a new PC algorithm runner.
992    ///
993    /// # Arguments
994    /// * `n_vars` — number of observed variables
995    /// * `alpha`  — significance level for conditional independence tests
996    pub fn new(n_vars: usize, alpha: f64) -> Self {
997        Self {
998            n_vars,
999            skeleton: vec![vec![true; n_vars]; n_vars],
1000            directed: vec![vec![false; n_vars]; n_vars],
1001            sep_sets: HashMap::new(),
1002            alpha,
1003        }
1004    }
1005
1006    /// Learn the skeleton from a data matrix using partial correlation tests.
1007    ///
1008    /// # Arguments
1009    /// * `data` — data matrix, shape `[n_obs][n_vars]`
1010    pub fn learn_skeleton(&mut self, data: &[Vec<f64>]) {
1011        let n = self.n_vars;
1012
1013        // Remove self-loops
1014        for i in 0..n {
1015            self.skeleton[i][i] = false;
1016        }
1017
1018        // Level 0: unconditional independence
1019        for i in 0..n {
1020            for j in (i + 1)..n {
1021                let r = partial_correlation(data, i, j, &[]);
1022                let p = fisher_z_test(r, data.len(), 0);
1023                if p > self.alpha {
1024                    self.skeleton[i][j] = false;
1025                    self.skeleton[j][i] = false;
1026                    self.sep_sets.insert((i, j), vec![]);
1027                    self.sep_sets.insert((j, i), vec![]);
1028                }
1029            }
1030        }
1031
1032        // Level 1+: conditional independence given conditioning sets
1033        for cond_size in 1..n.saturating_sub(1) {
1034            for i in 0..n {
1035                let adj_i: Vec<usize> = (0..n).filter(|&k| k != i && self.skeleton[i][k]).collect();
1036                for &j in &adj_i {
1037                    if !self.skeleton[i][j] {
1038                        continue;
1039                    }
1040                    let adj_minus_j: Vec<usize> =
1041                        adj_i.iter().copied().filter(|&k| k != j).collect();
1042                    if adj_minus_j.len() < cond_size {
1043                        continue;
1044                    }
1045                    // Test all conditioning sets of size `cond_size`
1046                    for cond_set in subsets(&adj_minus_j, cond_size) {
1047                        let r = partial_correlation(data, i, j, &cond_set);
1048                        let p = fisher_z_test(r, data.len(), cond_size);
1049                        if p > self.alpha {
1050                            self.skeleton[i][j] = false;
1051                            self.skeleton[j][i] = false;
1052                            self.sep_sets.insert((i, j), cond_set.clone());
1053                            self.sep_sets.insert((j, i), cond_set);
1054                            break;
1055                        }
1056                    }
1057                }
1058            }
1059        }
1060    }
1061
1062    /// Orient v-structures (colliders) in the skeleton.
1063    ///
1064    /// For each unshielded triple i — k — j (i and j not adjacent),
1065    /// if k is not in sep(i,j), orient i→k←j.
1066    pub fn orient_v_structures(&mut self) {
1067        let n = self.n_vars;
1068        for i in 0..n {
1069            for k in 0..n {
1070                if i == k || !self.skeleton[i][k] {
1071                    continue;
1072                }
1073                for j in (i + 1)..n {
1074                    if j == k || !self.skeleton[k][j] || self.skeleton[i][j] {
1075                        continue;
1076                    }
1077                    // i — k — j, unshielded
1078                    let sep = self.sep_sets.get(&(i, j)).cloned().unwrap_or_default();
1079                    if !sep.contains(&k) {
1080                        // Orient as collider: i → k ← j
1081                        self.directed[i][k] = true;
1082                        self.directed[j][k] = true;
1083                        self.skeleton[k][i] = false;
1084                        self.skeleton[k][j] = false;
1085                    }
1086                }
1087            }
1088        }
1089    }
1090
1091    /// Apply Meek's orientation rules to complete the CPDAG.
1092    ///
1093    /// Rules R1–R3 propagate orientations to avoid new v-structures and cycles.
1094    pub fn apply_meek_rules(&mut self) {
1095        let n = self.n_vars;
1096        let mut changed = true;
1097        while changed {
1098            changed = false;
1099            // R1: If i→j — k and i not adjacent to k, orient j→k
1100            for i in 0..n {
1101                for j in 0..n {
1102                    if !self.directed[i][j] {
1103                        continue;
1104                    }
1105                    for k in 0..n {
1106                        if k == i || k == j {
1107                            continue;
1108                        }
1109                        if self.skeleton[j][k]
1110                            && !self.directed[j][k]
1111                            && !self.directed[k][j]
1112                            && !self.skeleton[i][k]
1113                        {
1114                            self.directed[j][k] = true;
1115                            self.skeleton[k][j] = false;
1116                            changed = true;
1117                        }
1118                    }
1119                }
1120            }
1121            // R2: If i→k→j and i — j, orient i→j
1122            for i in 0..n {
1123                for j in 0..n {
1124                    if i == j || !self.skeleton[i][j] || self.directed[i][j] {
1125                        continue;
1126                    }
1127                    for k in 0..n {
1128                        if k == i || k == j {
1129                            continue;
1130                        }
1131                        if self.directed[i][k] && self.directed[k][j] {
1132                            self.directed[i][j] = true;
1133                            self.skeleton[j][i] = false;
1134                            changed = true;
1135                        }
1136                    }
1137                }
1138            }
1139        }
1140    }
1141
1142    /// Run the full PC algorithm: skeleton → v-structures → Meek rules.
1143    pub fn run(&mut self, data: &[Vec<f64>]) {
1144        self.learn_skeleton(data);
1145        self.orient_v_structures();
1146        self.apply_meek_rules();
1147    }
1148}
1149
1150/// Compute partial correlation of variables `i` and `j` conditioning on `cond`.
1151///
1152/// Uses recursive partial correlation formula for efficiency.
1153pub fn partial_correlation(data: &[Vec<f64>], i: usize, j: usize, cond: &[usize]) -> f64 {
1154    if cond.is_empty() {
1155        return pearson_correlation(data, i, j);
1156    }
1157    if cond.len() == 1 {
1158        let k = cond[0];
1159        let r_ij = pearson_correlation(data, i, j);
1160        let r_ik = pearson_correlation(data, i, k);
1161        let r_jk = pearson_correlation(data, j, k);
1162        let denom = ((1.0 - r_ik * r_ik) * (1.0 - r_jk * r_jk)).sqrt();
1163        if denom < 1e-12 {
1164            return 0.0;
1165        }
1166        return (r_ij - r_ik * r_jk) / denom;
1167    }
1168    // For larger conditioning sets, use matrix inversion approach
1169    // Simplified: use iterative partial correlations
1170    let last = cond[cond.len() - 1];
1171    let rest = &cond[..cond.len() - 1];
1172    let r_ij_rest = partial_correlation(data, i, j, rest);
1173    let r_ik_rest = partial_correlation(data, i, last, rest);
1174    let r_jk_rest = partial_correlation(data, j, last, rest);
1175    let denom = ((1.0 - r_ik_rest * r_ik_rest) * (1.0 - r_jk_rest * r_jk_rest)).sqrt();
1176    if denom < 1e-12 {
1177        return 0.0;
1178    }
1179    (r_ij_rest - r_ik_rest * r_jk_rest) / denom
1180}
1181
/// Pearson correlation between columns `i` and `j` of a row-major data matrix.
///
/// Uses population (divide-by-`n`) moments. If either column has a standard deviation
/// below `1e-12` the correlation is reported as `0.0`; otherwise the ratio is clamped to
/// `[-1, 1]` to absorb rounding.
pub fn pearson_correlation(data: &[Vec<f64>], i: usize, j: usize) -> f64 {
    let count = data.len() as f64;
    let column_mean = |col: usize| data.iter().map(|row| row[col]).sum::<f64>() / count;
    let spread = |col: usize, centre: f64| {
        let variance = data
            .iter()
            .map(|row| (row[col] - centre).powi(2))
            .sum::<f64>()
            / count;
        variance.sqrt()
    };

    let (mu_i, mu_j) = (column_mean(i), column_mean(j));
    let covariance = data
        .iter()
        .map(|row| (row[i] - mu_i) * (row[j] - mu_j))
        .sum::<f64>()
        / count;
    let (sd_i, sd_j) = (spread(i, mu_i), spread(j, mu_j));

    let degenerate = sd_i < 1e-12 || sd_j < 1e-12;
    if degenerate {
        0.0
    } else {
        (covariance / (sd_i * sd_j)).clamp(-1.0, 1.0)
    }
}
1209
/// Fisher Z-test for conditional independence.
///
/// Transforms the (partial) correlation `r` with `z = ½·ln((1 + r)/(1 − r))` (after
/// clamping `r` to `±0.9999`), scales it by `√(n − cond_size − 3)` (floored at 1), and
/// returns the two-tailed standard-normal p-value for the null hypothesis `r = 0`.
pub fn fisher_z_test(r: f64, n: usize, cond_size: usize) -> f64 {
    let rho = r.clamp(-0.9999, 0.9999);
    let fisher_z = 0.5 * ((1.0 + rho) / (1.0 - rho)).ln();
    let dof = (n as f64 - cond_size as f64 - 3.0).max(1.0);
    let standard_error = 1.0 / dof.sqrt();
    let z_score = (fisher_z / standard_error).abs();
    2.0 * (1.0 - standard_normal_cdf(z_score))
}

/// Standard normal CDF Φ(x) via the Abramowitz–Stegun 26.2.17 polynomial approximation
/// (absolute error below about 7.5e-8).
fn standard_normal_cdf(x: f64) -> f64 {
    const P: f64 = 0.2316419;
    const COEFFS: [f64; 5] = [
        0.319_381_530,
        -0.356_563_782,
        1.781_477_937,
        -1.821_255_978,
        1.330_274_429,
    ];
    let t = 1.0 / (1.0 + P * x.abs());
    // Horner evaluation of b1·t + b2·t² + … + b5·t⁵.
    let poly = t * COEFFS[..4]
        .iter()
        .rev()
        .fold(COEFFS[4], |acc, &b| b + t * acc);
    let density = (-0.5 * x * x).exp() / (2.0 * std::f64::consts::PI).sqrt();
    let upper_tail = density * poly;
    if x >= 0.0 {
        1.0 - upper_tail
    } else {
        1.0 - (1.0 - upper_tail)
    }
}
1232
/// Generate all subsets of `set` with exactly `k` elements.
///
/// Each subset keeps the elements in their original order, and subsets are listed in
/// lexicographic order of positions. `k == 0` yields a single empty subset; `k > set.len()`
/// yields none.
fn subsets(set: &[usize], k: usize) -> Vec<Vec<usize>> {
    match k {
        0 => vec![Vec::new()],
        _ if set.len() < k => Vec::new(),
        _ => set
            .iter()
            .enumerate()
            .flat_map(|(pos, &head)| {
                subsets(&set[pos + 1..], k - 1)
                    .into_iter()
                    .map(move |tail| {
                        let mut combo = Vec::with_capacity(tail.len() + 1);
                        combo.push(head);
                        combo.extend(tail);
                        combo
                    })
            })
            .collect(),
    }
}
1251
1252// ---------------------------------------------------------------------------
1253// CounterfactualQuery
1254// ---------------------------------------------------------------------------
1255
1256/// Compute counterfactual queries of the form E\[Y | do(X=x), Z=z\].
1257///
1258/// Uses the abduction-action-prediction three-step procedure on a linear SCM.
1259#[derive(Debug, Clone)]
1260pub struct CounterfactualQuery {
1261    /// The structural causal model for the query.
1262    pub scm: StructuralCausalModel,
1263}
1264
1265impl CounterfactualQuery {
1266    /// Create a new counterfactual query engine.
1267    pub fn new(scm: StructuralCausalModel) -> Self {
1268        Self { scm }
1269    }
1270
1271    /// Compute E\[Y | do(X_target=x_val)\] using the interventional distribution.
1272    ///
1273    /// # Arguments
1274    /// * `target`        — the variable to intervene on
1275    /// * `x_val`         — the intervention value
1276    /// * `outcome`       — the outcome variable
1277    /// * `noise_samples` — noise samples for Monte Carlo evaluation
1278    pub fn query_do(
1279        &self,
1280        target: usize,
1281        x_val: f64,
1282        outcome: usize,
1283        noise_samples: &[Vec<f64>],
1284    ) -> f64 {
1285        self.scm
1286            .average_causal_effect(target, x_val, outcome, noise_samples)
1287    }
1288
1289    /// Compute the counterfactual: given that we observed `obs` (variable→value pairs),
1290    /// what would `outcome` have been if `do(target=x_val)`?
1291    ///
1292    /// Three steps:
1293    /// 1. Abduction: infer noise `U` from observations
1294    /// 2. Action: intervene on `target`
1295    /// 3. Prediction: compute outcome under intervention
1296    ///
1297    /// # Arguments
1298    /// * `obs`        — observed values, `obs[v]` = Some(value) if observed
1299    /// * `target`     — intervention variable
1300    /// * `x_val`      — intervention value
1301    /// * `outcome`    — outcome variable index
1302    pub fn counterfactual(
1303        &self,
1304        obs: &[Option<f64>],
1305        target: usize,
1306        x_val: f64,
1307        outcome: usize,
1308    ) -> f64 {
1309        let n = self.scm.graph.n_nodes;
1310        assert_eq!(obs.len(), n);
1311
1312        // Step 1: Abduction — infer noise from observations
1313        // For linear SCM: U_v = X_v - intercept - Σ coeff * X_parent
1314        let order = self
1315            .scm
1316            .graph
1317            .topological_sort()
1318            .expect("SCM must be acyclic");
1319        let mut x = vec![0.0_f64; n];
1320        let mut noise = vec![0.0_f64; n];
1321
1322        for &v in &order {
1323            if let Some(val) = obs[v] {
1324                x[v] = val;
1325                let pred: f64 = self.scm.graph.parents[v]
1326                    .iter()
1327                    .zip(self.scm.coefficients[v].iter())
1328                    .map(|(&p, &c)| c * x[p])
1329                    .sum::<f64>();
1330                let residual = val - self.scm.intercepts[v] - pred;
1331                noise[v] = if self.scm.noise_std[v].abs() > 1e-12 {
1332                    residual / self.scm.noise_std[v]
1333                } else {
1334                    0.0
1335                };
1336            } else {
1337                // Unobserved: use noise = 0 (mean)
1338                noise[v] = 0.0;
1339                let pred: f64 = self.scm.graph.parents[v]
1340                    .iter()
1341                    .zip(self.scm.coefficients[v].iter())
1342                    .map(|(&p, &c)| c * x[p])
1343                    .sum::<f64>();
1344                x[v] = self.scm.intercepts[v] + pred;
1345            }
1346        }
1347
1348        // Step 2: Action — intervene
1349        let intervened = self.scm.intervene(target, x_val);
1350
1351        // Step 3: Prediction — run SCM with inferred noise under intervention
1352        intervened.sample_with_noise(&noise)[outcome]
1353    }
1354
1355    /// Compute the probability of necessity (PN): P(Y=0 | do(X=0), Y=1, X=1).
1356    ///
1357    /// Simplified calculation using noise samples.
1358    pub fn probability_of_necessity(
1359        &self,
1360        treatment: usize,
1361        outcome: usize,
1362        t_val: f64,
1363        t_counter: f64,
1364        threshold_y: f64,
1365        noise_samples: &[Vec<f64>],
1366    ) -> f64 {
1367        let mut count = 0;
1368        let mut denom = 0;
1369        for noise in noise_samples {
1370            // Actual world
1371            let x_actual = self.scm.sample_with_noise(noise);
1372            if x_actual[treatment] < t_val - 0.5 || x_actual[outcome] < threshold_y {
1373                continue;
1374            }
1375            denom += 1;
1376            // Counterfactual: set treatment to counter
1377            let counter_scm = self.scm.intervene(treatment, t_counter);
1378            let x_counter = counter_scm.sample_with_noise(noise);
1379            if x_counter[outcome] < threshold_y {
1380                count += 1;
1381            }
1382        }
1383        if denom == 0 {
1384            0.0
1385        } else {
1386            count as f64 / denom as f64
1387        }
1388    }
1389}
1390
1391// ---------------------------------------------------------------------------
1392// Helper: Covariance matrix operations
1393// ---------------------------------------------------------------------------
1394
/// Compute the unbiased sample covariance matrix of a data matrix.
///
/// `data` has shape `[n_obs][n_vars]`; the number of variables is taken from the first row.
/// Entries are normalised by `n_obs − 1` (Bessel's correction).
///
/// Returns a flat row-major matrix of size `[n_vars * n_vars]`, or an empty vector when
/// `data` has no rows. With a single observation the normaliser is zero, so the entries
/// are `NaN` (0/0).
///
/// # Panics
/// Panics if a row is shorter than the first row.
pub fn sample_covariance(data: &[Vec<f64>]) -> Vec<f64> {
    if data.is_empty() {
        return Vec::new();
    }
    let n = data.len();
    let p = data[0].len();
    let means: Vec<f64> = (0..p)
        .map(|j| data.iter().map(|row| row[j]).sum::<f64>() / n as f64)
        .collect();

    let mut cov = vec![0.0_f64; p * p];
    // Accumulate the upper triangle only; it is mirrored into the lower one below.
    for row in data {
        for j in 0..p {
            let dev_j = row[j] - means[j];
            for k in j..p {
                cov[j * p + k] += dev_j * (row[k] - means[k]);
            }
        }
    }

    let denom = (n - 1) as f64;
    for j in 0..p {
        for k in j..p {
            cov[j * p + k] /= denom;
            cov[k * p + j] = cov[j * p + k];
        }
    }
    cov
}
1420
1421// ---------------------------------------------------------------------------
1422// Tests
1423// ---------------------------------------------------------------------------
1424
1425#[cfg(test)]
1426mod tests {
1427    use super::*;
1428
1429    fn simple_chain() -> CausalGraph {
1430        // X0 → X1 → X2
1431        let mut g = CausalGraph::new(3);
1432        g.add_edge(0, 1);
1433        g.add_edge(1, 2);
1434        g
1435    }
1436
1437    fn fork_graph() -> CausalGraph {
1438        // X0 → X1, X0 → X2
1439        let mut g = CausalGraph::new(3);
1440        g.add_edge(0, 1);
1441        g.add_edge(0, 2);
1442        g
1443    }
1444
1445    fn collider_graph() -> CausalGraph {
1446        // X0 → X2, X1 → X2
1447        let mut g = CausalGraph::new(3);
1448        g.add_edge(0, 2);
1449        g.add_edge(1, 2);
1450        g
1451    }
1452
1453    // --- CausalGraph tests ---
1454
1455    #[test]
1456    fn test_topological_sort_chain() {
1457        let g = simple_chain();
1458        let order = g.topological_sort().unwrap();
1459        assert_eq!(order, vec![0, 1, 2]);
1460    }
1461
1462    #[test]
1463    fn test_topological_sort_fork() {
1464        let g = fork_graph();
1465        let order = g.topological_sort().unwrap();
1466        assert_eq!(order[0], 0); // X0 must come first
1467    }
1468
1469    #[test]
1470    fn test_ancestors() {
1471        let g = simple_chain();
1472        let anc = g.ancestors(2);
1473        assert!(anc.contains(&0));
1474        assert!(anc.contains(&1));
1475        assert!(!anc.contains(&2));
1476    }
1477
1478    #[test]
1479    fn test_descendants() {
1480        let g = simple_chain();
1481        let desc = g.descendants(0);
1482        assert!(desc.contains(&1));
1483        assert!(desc.contains(&2));
1484    }
1485
1486    #[test]
1487    fn test_d_separation_chain_blocked_by_middle() {
1488        // X0 → X1 → X2; conditioning on X1 blocks 0⊥2|{1}
1489        let g = simple_chain();
1490        assert!(g.d_separated(&[0], &[2], &[1]));
1491    }
1492
1493    #[test]
1494    fn test_d_separation_chain_not_blocked_empty() {
1495        let g = simple_chain();
1496        assert!(!g.d_separated(&[0], &[2], &[]));
1497    }
1498
1499    #[test]
1500    fn test_d_separation_fork() {
1501        // X0 → X1, X0 → X2; conditioning on X0 blocks 1⊥2|{0}
1502        let g = fork_graph();
1503        assert!(g.d_separated(&[1], &[2], &[0]));
1504        assert!(!g.d_separated(&[1], &[2], &[]));
1505    }
1506
1507    #[test]
1508    fn test_d_separation_collider_blocked_by_default() {
1509        // X0 → X2 ← X1; collider: 0⊥1|{} (blocked)
1510        let g = collider_graph();
1511        assert!(g.d_separated(&[0], &[1], &[]));
1512    }
1513
1514    #[test]
1515    fn test_d_separation_collider_opened_by_conditioning() {
1516        // X0 → X2 ← X1; conditioning on X2 opens path: 0 NOT ⊥ 1|{2}
1517        let g = collider_graph();
1518        assert!(!g.d_separated(&[0], &[1], &[2]));
1519    }
1520
1521    #[test]
1522    fn test_is_acyclic() {
1523        let g = simple_chain();
1524        assert!(g.is_acyclic());
1525    }
1526
1527    #[test]
1528    fn test_creates_cycle_detected() {
1529        let mut g = CausalGraph::new(3);
1530        g.add_edge(0, 1);
1531        g.add_edge(1, 2);
1532        assert!(g.creates_cycle(2, 0));
1533    }
1534
1535    #[test]
1536    fn test_markov_blanket() {
1537        // X0 → X1 → X2; blanket(X1) = {X0, X2}
1538        let g = simple_chain();
1539        let blanket = g.markov_blanket(1);
1540        assert!(blanket.contains(&0));
1541        assert!(blanket.contains(&2));
1542        assert!(!blanket.contains(&1));
1543    }
1544
1545    // --- StructuralCausalModel tests ---
1546
1547    #[test]
1548    fn test_scm_sample_basic() {
1549        // X0 = ε0; X1 = 2*X0 + ε1
1550        let mut scm = StructuralCausalModel::new(2);
1551        scm.add_edge(0, 1, 2.0);
1552        scm.noise_std = vec![1.0, 0.0];
1553        let noise = vec![3.0, 0.0];
1554        let x = scm.sample_with_noise(&noise);
1555        assert!((x[0] - 3.0).abs() < 1e-10);
1556        assert!((x[1] - 6.0).abs() < 1e-10);
1557    }
1558
1559    #[test]
1560    fn test_scm_intervention_removes_parents() {
1561        let mut scm = StructuralCausalModel::new(2);
1562        scm.add_edge(0, 1, 2.0);
1563        let intervened = scm.intervene(1, 5.0);
1564        assert!(intervened.graph.parents[1].is_empty());
1565        let x = intervened.sample_with_noise(&[1.0, 0.0]);
1566        assert!((x[1] - 5.0).abs() < 1e-10);
1567    }
1568
1569    #[test]
1570    fn test_scm_total_effect_chain() {
1571        // X0 → X1 → X2, coeff 2.0 and 3.0
1572        let mut scm = StructuralCausalModel::new(3);
1573        scm.add_edge(0, 1, 2.0);
1574        scm.add_edge(1, 2, 3.0);
1575        let effect = scm.total_effect_linear(0, 2);
1576        assert!((effect - 6.0).abs() < 1e-10);
1577    }
1578
1579    #[test]
1580    fn test_scm_no_effect_for_independent_vars() {
1581        let mut scm = StructuralCausalModel::new(3);
1582        scm.add_edge(0, 1, 1.0);
1583        let effect = scm.total_effect_linear(0, 2);
1584        assert!(effect.abs() < 1e-10);
1585    }
1586
1587    // --- BackdoorCriterion tests ---
1588
1589    #[test]
1590    fn test_backdoor_check_valid_adjustment() {
1591        // W → X → Y, W → Y (confounding)
1592        // Adjusting for W satisfies backdoor
1593        let mut g = CausalGraph::new(3); // W=0, X=1, Y=2
1594        g.add_edge(0, 1); // W→X
1595        g.add_edge(1, 2); // X→Y
1596        g.add_edge(0, 2); // W→Y (backdoor)
1597        let bd = BackdoorCriterion::new(g);
1598        // W is not a descendant of X, and blocks the backdoor W→Y
1599        assert!(bd.check(1, 2, &[0]));
1600    }
1601
1602    #[test]
1603    fn test_backdoor_fails_if_descendant_in_set() {
1604        // X → M → Y; adjusting for M (a descendant of X) fails
1605        let mut g = CausalGraph::new(3); // X=0, M=1, Y=2
1606        g.add_edge(0, 1);
1607        g.add_edge(1, 2);
1608        let bd = BackdoorCriterion::new(g);
1609        assert!(!bd.check(0, 2, &[1])); // M is descendant of X
1610    }
1611
1612    // --- PropensityScoreMatching tests ---
1613
1614    #[test]
1615    fn test_propensity_score_fit_and_predict() {
1616        let mut psm = PropensityScoreMatching::new(2);
1617        let covariates: Vec<Vec<f64>> = (0..100).map(|i| vec![(i as f64) / 100.0, 0.5]).collect();
1618        let treatment: Vec<f64> = covariates
1619            .iter()
1620            .map(|x| if x[0] > 0.5 { 1.0 } else { 0.0 })
1621            .collect();
1622        psm.fit(&covariates, &treatment, 0.5, 200);
1623        let ps = psm.predict(&covariates);
1624        assert_eq!(ps.len(), 100);
1625        // Scores should be between 0 and 1
1626        for p in &ps {
1627            assert!(*p > 0.0 && *p < 1.0);
1628        }
1629    }
1630
1631    #[test]
1632    fn test_ate_sign_positive() {
1633        // Treatment has positive effect on outcome
1634        let mut psm = PropensityScoreMatching::new(1);
1635        let n = 200;
1636        let covariates: Vec<Vec<f64>> = (0..n).map(|i| vec![(i as f64) / n as f64]).collect();
1637        let treatment: Vec<f64> = covariates
1638            .iter()
1639            .map(|x| if x[0] > 0.5 { 1.0 } else { 0.0 })
1640            .collect();
1641        let outcome: Vec<f64> = covariates
1642            .iter()
1643            .zip(treatment.iter())
1644            .map(|(x, t)| x[0] + 2.0 * t + 0.1)
1645            .collect();
1646        psm.fit(&covariates, &treatment, 0.3, 300);
1647        let ate = psm.estimate_ate(&covariates, &treatment, &outcome);
1648        assert!(ate > 0.0, "ATE should be positive, got {ate}");
1649    }
1650
1651    // --- InstrumentalVariables tests ---
1652
1653    #[test]
1654    fn test_iv_estimation_simple() {
1655        // Y = 2*D + e, D = Z + v (instrument Z)
1656        // IV estimate should recover β ≈ 2
1657        let n = 500;
1658        let z: Vec<Vec<f64>> = (0..n).map(|i| vec![(i as f64 % 2.0)]).collect();
1659        let d: Vec<f64> = z.iter().map(|zi| zi[0] + 0.5).collect();
1660        let y: Vec<f64> = d.iter().map(|di| 2.0 * di + 1.0).collect();
1661
1662        let mut iv = InstrumentalVariables::new(1, 1);
1663        iv.fit_2sls(&y, &d, &z);
1664        assert!(
1665            (iv.second_stage - 2.0).abs() < 0.5,
1666            "IV est = {}",
1667            iv.second_stage
1668        );
1669    }
1670
1671    #[test]
1672    fn test_iv_first_stage_f_stat() {
1673        let n = 200;
1674        let z: Vec<Vec<f64>> = (0..n).map(|i| vec![(i as f64 / n as f64)]).collect();
1675        let d: Vec<f64> = z.iter().map(|zi| 2.0 * zi[0]).collect();
1676        let y: Vec<f64> = d.iter().map(|di| di + 1.0).collect();
1677        let mut iv = InstrumentalVariables::new(1, 1);
1678        iv.fit_2sls(&y, &d, &z);
1679        let f = iv.first_stage_f_stat(&y, &d, &z);
1680        assert!(
1681            f > 10.0,
1682            "F-stat should be large for strong instrument, got {f}"
1683        );
1684    }
1685
1686    // --- CausalDiscovery tests ---
1687
1688    #[test]
1689    fn test_pearson_correlation_perfect() {
1690        let data: Vec<Vec<f64>> = (0..50).map(|i| vec![i as f64, 2.0 * i as f64]).collect();
1691        let r = pearson_correlation(&data, 0, 1);
1692        assert!((r - 1.0).abs() < 1e-10);
1693    }
1694
1695    #[test]
1696    fn test_pearson_correlation_zero() {
1697        // x varies linearly, y is constant — they are uncorrelated (std_y=0 => r=0)
1698        let data: Vec<Vec<f64>> = (0..100).map(|i| vec![i as f64, 1.0]).collect();
1699        let r = pearson_correlation(&data, 0, 1);
1700        assert!(r.abs() < 1e-10, "r={r}");
1701    }
1702
1703    #[test]
1704    fn test_partial_correlation_returns_in_range() {
1705        let data: Vec<Vec<f64>> = (0..50)
1706            .map(|i| vec![i as f64, 2.0 * i as f64 + 1.0, i as f64 * 0.5])
1707            .collect();
1708        let r = partial_correlation(&data, 0, 1, &[2]);
1709        assert!((-1.0..=1.0).contains(&r));
1710    }
1711
1712    #[test]
1713    fn test_fisher_z_test_high_correlation() {
1714        let r = 0.0; // uncorrelated
1715        let p = fisher_z_test(r, 100, 0);
1716        assert!(p > 0.05, "Should not reject independence for r=0");
1717    }
1718
1719    #[test]
1720    fn test_causal_discovery_skeleton_independent() {
1721        // Two independent variables: skeleton should have no edge
1722        let data: Vec<Vec<f64>> = (0..100)
1723            .map(|i| vec![(i as f64).sin(), (i as f64 * 2.3 + 1.0).cos()])
1724            .collect();
1725        let mut cd = CausalDiscovery::new(2, 0.01);
1726        cd.learn_skeleton(&data);
1727        // May or may not find edges; just check no panic
1728        assert!(cd.n_vars == 2);
1729    }
1730
1731    #[test]
1732    fn test_subsets_correctness() {
1733        let v = vec![0, 1, 2];
1734        let subs = subsets(&v, 2);
1735        assert_eq!(subs.len(), 3);
1736    }
1737
1738    #[test]
1739    fn test_subsets_empty() {
1740        let v = vec![0, 1];
1741        let subs = subsets(&v, 0);
1742        assert_eq!(subs.len(), 1);
1743        assert!(subs[0].is_empty());
1744    }
1745
1746    // --- CounterfactualQuery tests ---
1747
1748    #[test]
1749    fn test_counterfactual_simple_chain() {
1750        // X1 = ε1, X2 = 2*X1 + ε2
1751        // If we observe X1=1, X2=2 (noise2=0), what is X2 if do(X1=3)?
1752        let mut scm = StructuralCausalModel::new(2);
1753        scm.add_edge(0, 1, 2.0);
1754        scm.noise_std = vec![1.0, 1.0];
1755        let query = CounterfactualQuery::new(scm);
1756        let obs = vec![Some(1.0), Some(2.0)];
1757        let cf = query.counterfactual(&obs, 0, 3.0, 1);
1758        // noise2 = (2 - 2*1)/1 = 0; X2_cf = 2*3 + 0*1 = 6
1759        assert!((cf - 6.0).abs() < 1e-10, "cf={cf}");
1760    }
1761
1762    #[test]
1763    fn test_counterfactual_intercept() {
1764        // X1 = 5 + ε1 (intercept=5); observe X1=7 → noise=2
1765        // Counterfactual do(X0=10): X1 stays (X0 not parent of X1)
1766        let mut scm = StructuralCausalModel::new(2);
1767        scm.intercepts[1] = 5.0;
1768        scm.noise_std = vec![1.0, 1.0];
1769        let query = CounterfactualQuery::new(scm);
1770        let obs = vec![Some(0.0), Some(7.0)];
1771        let cf = query.counterfactual(&obs, 0, 10.0, 1);
1772        // X1 doesn't depend on X0, noise1 = (7-5)/1 = 2, cf = 5 + 2 = 7
1773        assert!((cf - 7.0).abs() < 1e-10, "cf={cf}");
1774    }
1775
1776    // --- Sample covariance test ---
1777
1778    #[test]
1779    fn test_sample_covariance_diagonal() {
1780        let data: Vec<Vec<f64>> = (0..100).map(|i| vec![i as f64, 0.0]).collect();
1781        let cov = sample_covariance(&data);
1782        // var(X0) > 0, var(X1) = 0
1783        assert!(cov[0] > 0.0);
1784        assert!(cov[3].abs() < 1e-10); // var(X1) = 0
1785    }
1786
1787    #[test]
1788    fn test_sample_covariance_symmetric() {
1789        let data: Vec<Vec<f64>> = (0..50).map(|i| vec![i as f64, (i as f64).sin()]).collect();
1790        let cov = sample_covariance(&data);
1791        assert!((cov[1] - cov[2]).abs() < 1e-12); // cov[0][1] == cov[1][0]
1792    }
1793
1794    // --- Integration test ---
1795
1796    #[test]
1797    fn test_full_scm_pipeline() {
1798        // Build a confounded SCM and estimate ACE
1799        // U → X, U → Y, X → Y
1800        // We use a 3-node SCM: U=0, X=1, Y=2
1801        let mut scm = StructuralCausalModel::new(3);
1802        scm.add_edge(0, 1, 1.0); // U→X
1803        scm.add_edge(0, 2, 1.0); // U→Y
1804        scm.add_edge(1, 2, 2.0); // X→Y (true causal effect = 2)
1805        scm.noise_std = vec![1.0, 0.5, 0.5];
1806
1807        // Generate noise samples
1808        let noise_samples: Vec<Vec<f64>> = (0..500)
1809            .map(|i| {
1810                let u = (i as f64 * 0.01).sin();
1811                let x = (i as f64 * 0.013).cos();
1812                let y = (i as f64 * 0.017).sin();
1813                vec![u, x, y]
1814            })
1815            .collect();
1816
1817        let ace0 = scm.average_causal_effect(1, 0.0, 2, &noise_samples);
1818        let ace1 = scm.average_causal_effect(1, 1.0, 2, &noise_samples);
1819        let diff = ace1 - ace0;
1820        // Should be close to 2.0 (the direct structural coefficient)
1821        assert!((diff - 2.0).abs() < 0.1, "ACE diff = {diff}");
1822    }
1823}