Skip to main content

scirs2_stats/causal/
mod.rs

//! Causal Inference Methods
//!
//! This module provides a comprehensive suite of causal inference estimators:
//!
//! ## Sub-modules
//!
//! | Module | Methods |
//! |--------|---------|
//! | [`instrumental_variables`] | 2SLS, LIML, Hausman test, weak-instrument diagnostics |
//! | [`difference_in_differences`] | DiD (TWFE), synthetic control, event study, staggered DiD |
//! | [`regression_discontinuity`] | Sharp RDD, fuzzy RDD, bandwidth selection, RD plots |
//! | [`propensity_score`] | Logistic PS model, IPW, nearest-neighbour / kernel matching |
//!
//! ## Quick start
//!
//! ```rust
//! use scirs2_stats::causal::instrumental_variables::{IVEstimator, WeakInstrumentTest};
//! use scirs2_stats::causal::propensity_score::{
//!     PropensityScoreModel, IPW, PSMatching, MatchingMethod,
//! };
//! ```
//!
//! ## References
//!
//! - Angrist, J.D. & Pischke, J.-S. (2009). Mostly Harmless Econometrics.
//! - Callaway, B. & Sant'Anna, P.H.C. (2021). Difference-in-Differences with
//!   Multiple Time Periods.
//! - Imbens, G.W. & Kalyanaraman, K. (2012). Optimal Bandwidth Choice for
//!   the Regression Discontinuity Estimator.
//! - Rosenbaum, P.R. & Rubin, D.B. (1983). The Central Role of the Propensity
//!   Score in Observational Studies for Causal Effects.
32
33pub mod difference_in_differences;
34pub mod instrumental_variables;
35pub mod propensity_score;
36pub mod regression_discontinuity;
37
38// ---------------------------------------------------------------------------
39// Re-exports — instrumental variables
40// ---------------------------------------------------------------------------
41
42pub use instrumental_variables::{
43    HausmanResult, HausmanTest, IVEstimator, IVResult, WeakInstrumentResult, WeakInstrumentTest,
44    LIML,
45};
46
47// ---------------------------------------------------------------------------
48// Re-exports — difference-in-differences
49// ---------------------------------------------------------------------------
50
51pub use difference_in_differences::{
52    AttGt, DiD, DiDResult, EventCoefficient, EventStudy, EventStudyResult, StaggeredDiD,
53    StaggeredDiDResult, SyntheticControl,
54};
55
56// ---------------------------------------------------------------------------
57// Re-exports — regression discontinuity
58// ---------------------------------------------------------------------------
59
60pub use regression_discontinuity::{
61    BandwidthMethod, BandwidthSelector, FuzzyRDD, RDDPlot, RDDResult, RDD,
62};
63
64// ---------------------------------------------------------------------------
65// Re-exports — propensity score
66// ---------------------------------------------------------------------------
67
68pub use propensity_score::{
69    MatchingMethod, MatchingResult, OverlapCheck, OverlapResult, PSMatching, PSResult,
70    PropensityScoreModel, TrimMethod, IPW,
71};
72
73/// Convenience function: fit a propensity score model and estimate ATE/ATT/ATC via IPW.
74pub use propensity_score::ps_estimate;
75
76// ---------------------------------------------------------------------------
77// Structural Equation Models
78// ---------------------------------------------------------------------------
79
80pub mod sem;
81
82pub use sem::{satisfies_backdoor, IdentificationResult, LinearEquation, SEMWithIntercepts, SEM};
83
84// ---------------------------------------------------------------------------
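// Linear SEM with ndarray interface, plus the modules declared alongside it
// (conditional independence, PC / FCI, ID / hedge, semi-Markov graphs,
// symbolic probabilities)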
86// ---------------------------------------------------------------------------
87
88pub mod conditional_independence;
89pub mod fci_algorithm;
90pub mod hedge;
91pub mod id_algorithm;
92pub mod linear_sem;
93pub mod pc_algorithm;
94pub mod semi_markov_graph;
95pub mod symbolic_prob;
96
97pub use linear_sem::{LinearSEM, LinearSEMWithIntercepts};
98
99// ---------------------------------------------------------------------------
100// Causal graph types for constraint-based algorithms (PC, FCI, etc.)
101// ---------------------------------------------------------------------------
102
/// Marks on the endpoint of an edge in a mixed graph.
///
/// Used by constraint-based causal discovery algorithms (PC, FCI) to represent
/// different types of edges in CPDAGs and PAGs. Every edge carries one mark
/// per endpoint:
///
/// - `Tail` — definite tail, i.e. a non-arrowhead endpoint (the `i` end of
///   `i → j`, or either end of `i — j`).
/// - `Arrow` — definite arrowhead (the `j` end of `i → j`, or either end of
///   `i ↔ j`).
/// - `Circle` — unknown endpoint (used by FCI for partial ancestral graphs).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EdgeMark {
    /// Definite tail: non-arrowhead endpoint.
    Tail,
    /// Definite arrowhead.
    Arrow,
    /// Unknown endpoint mark (FCI/PAG only).
    Circle,
}
120
121/// Mixed graph for causal discovery algorithms.
122///
123/// Each edge `(i, j)` is stored as a pair `(mark_at_i, mark_at_j)`, where
124/// `mark_at_i` is the mark at node `i` (endpoint facing node `i`) and
125/// `mark_at_j` is the mark at node `j`.
126///
127/// - Directed edge `i → j`: `(Tail, Arrow)` stored at entry `(i, j)`.
128/// - Undirected edge `i — j`: `(Tail, Tail)`.
129/// - Bidirected edge `i ↔ j`: `(Arrow, Arrow)`.
130/// - Circle endpoint `i o→ j`: `(Circle, Arrow)`.
131#[derive(Debug, Clone)]
132pub struct CausalGraph {
133    /// Variable names.
134    pub var_names: Vec<String>,
135    /// Adjacency: `edges[i][j] = Some((mark_at_i_from_j, mark_at_j_from_i))`.
136    /// If `edges[i][j].is_some()` then `edges[j][i].is_some()` as well.
137    edges: Vec<Vec<Option<(EdgeMark, EdgeMark)>>>,
138    /// Separation sets: `sep[i][j]` is the set that d-separates i and j.
139    pub sep_sets: Vec<Vec<Option<Vec<usize>>>>,
140}
141
142impl CausalGraph {
143    /// Create a new graph with the given variable names, initially fully connected.
144    pub fn new(var_names: &[&str]) -> Self {
145        let p = var_names.len();
146        Self {
147            var_names: var_names.iter().map(|s| s.to_string()).collect(),
148            edges: vec![vec![None; p]; p],
149            sep_sets: vec![vec![None; p]; p],
150        }
151    }
152
153    /// Number of variables (nodes).
154    pub fn num_vars(&self) -> usize {
155        self.var_names.len()
156    }
157
158    /// Set or update an edge between `i` and `j`.
159    ///
160    /// `mark_at_i` is the endpoint mark at node `i`; `mark_at_j` is the mark at `j`.
161    /// Setting an edge is symmetric: `edges[i][j]` and `edges[j][i]` are both updated.
162    pub fn set_edge(&mut self, i: usize, j: usize, mark_at_i: EdgeMark, mark_at_j: EdgeMark) {
163        self.edges[i][j] = Some((mark_at_i, mark_at_j));
164        self.edges[j][i] = Some((mark_at_j, mark_at_i));
165    }
166
167    /// Remove an edge between `i` and `j`.
168    pub fn remove_edge(&mut self, i: usize, j: usize) {
169        self.edges[i][j] = None;
170        self.edges[j][i] = None;
171    }
172
173    /// Whether there is any edge between `i` and `j`.
174    pub fn is_adjacent(&self, i: usize, j: usize) -> bool {
175        self.edges[i][j].is_some()
176    }
177
178    /// Get the mark at node `to` on the edge from `from` to `to`.
179    ///
180    /// Returns `None` if the edge doesn't exist.
181    /// The returned mark is the one facing node `to` (i.e., the arrowhead/tail at `to`).
182    pub fn get_mark_at(&self, from: usize, to: usize) -> Option<EdgeMark> {
183        self.edges[from][to].map(|(_, mark_at_to)| mark_at_to)
184    }
185
186    /// Get the mark at node `from` on the edge between `from` and `to`.
187    ///
188    /// Returns `None` if the edge doesn't exist.
189    pub fn get_mark_from(&self, from: usize, to: usize) -> Option<EdgeMark> {
190        self.edges[from][to].map(|(mark_at_from, _)| mark_at_from)
191    }
192
193    /// Whether there is a directed edge `i → j` (tail at `i`, arrow at `j`).
194    pub fn is_directed(&self, i: usize, j: usize) -> bool {
195        matches!(self.edges[i][j], Some((EdgeMark::Tail, EdgeMark::Arrow)))
196    }
197
198    /// Whether there is an undirected edge `i — j` (tail at both ends).
199    pub fn is_undirected(&self, i: usize, j: usize) -> bool {
200        matches!(self.edges[i][j], Some((EdgeMark::Tail, EdgeMark::Tail)))
201    }
202
203    /// Whether there is a bidirected edge `i ↔ j` (arrow at both ends).
204    pub fn is_bidirected(&self, i: usize, j: usize) -> bool {
205        matches!(self.edges[i][j], Some((EdgeMark::Arrow, EdgeMark::Arrow)))
206    }
207
208    /// Return an iterator over the neighbours of node `i`.
209    pub fn neighbors(&self, i: usize) -> impl Iterator<Item = usize> + '_ {
210        (0..self.num_vars()).filter(move |&j| j != i && self.is_adjacent(i, j))
211    }
212
213    /// Return the separation set for nodes `i` and `j`, if any.
214    pub fn get_sep_set(&self, i: usize, j: usize) -> Option<&Vec<usize>> {
215        self.sep_sets[i][j].as_ref()
216    }
217
218    /// Set the separation set for nodes `i` and `j`.
219    pub fn set_sep_set(&mut self, i: usize, j: usize, sep: Vec<usize>) {
220        self.sep_sets[i][j] = Some(sep.clone());
221        self.sep_sets[j][i] = Some(sep);
222    }
223
224    /// Initialize the graph as a complete undirected graph (all edges `i — j`).
225    pub fn make_complete(&mut self) {
226        let p = self.num_vars();
227        for i in 0..p {
228            for j in 0..p {
229                if i != j {
230                    self.edges[i][j] = Some((EdgeMark::Tail, EdgeMark::Tail));
231                }
232            }
233        }
234    }
235}