#![allow(clippy::needless_range_loop)]
#![allow(dead_code)]
#![allow(clippy::too_many_arguments)]
use std::collections::{HashMap, HashSet, VecDeque};
/// A directed acyclic graph over `n_nodes` variables, stored as parallel
/// parent/child adjacency lists plus human-readable node names.
#[derive(Debug, Clone)]
pub struct CausalGraph {
    // Number of nodes; valid node indices are 0..n_nodes.
    pub n_nodes: usize,
    // parents[v] lists the direct causes of v.
    pub parents: Vec<Vec<usize>>,
    // children[v] lists the direct effects of v.
    pub children: Vec<Vec<usize>>,
    // Display labels, defaulting to "X0", "X1", ...
    pub names: Vec<String>,
}

impl CausalGraph {
    /// Creates an edgeless graph with `n` nodes named "X0".."X{n-1}".
    pub fn new(n: usize) -> Self {
        Self {
            n_nodes: n,
            parents: vec![vec![]; n],
            children: vec![vec![]; n],
            names: (0..n).map(|i| format!("X{i}")).collect(),
        }
    }

    /// Replaces all node names; panics unless exactly `n_nodes` names are given.
    pub fn set_names(&mut self, names: &[&str]) {
        assert_eq!(names.len(), self.n_nodes);
        self.names = names.iter().map(|s| s.to_string()).collect();
    }

    /// Adds the directed edge `from -> to`. Duplicate edges are ignored.
    ///
    /// # Panics
    /// If an index is out of bounds or the edge would introduce a cycle.
    pub fn add_edge(&mut self, from: usize, to: usize) {
        assert!(
            from < self.n_nodes && to < self.n_nodes,
            "node index out of bounds"
        );
        assert!(
            !self.creates_cycle(from, to),
            "edge {from}→{to} would create a cycle"
        );
        if !self.children[from].contains(&to) {
            self.children[from].push(to);
            self.parents[to].push(from);
        }
    }

    /// Returns true if adding `from -> to` would close a cycle, i.e. a
    /// directed path `to -> ... -> from` already exists (DFS over children).
    pub fn creates_cycle(&self, from: usize, to: usize) -> bool {
        let mut visited = vec![false; self.n_nodes];
        let mut stack = vec![to];
        while let Some(node) = stack.pop() {
            if node == from {
                return true;
            }
            if !visited[node] {
                visited[node] = true;
                for &child in &self.children[node] {
                    stack.push(child);
                }
            }
        }
        false
    }

    /// Kahn's-algorithm topological order; `None` when a cycle is present
    /// (some node then never reaches in-degree zero and `order` stays short).
    pub fn topological_sort(&self) -> Option<Vec<usize>> {
        let mut in_degree: Vec<usize> = self.parents.iter().map(|p| p.len()).collect();
        let mut queue: VecDeque<usize> = (0..self.n_nodes).filter(|&v| in_degree[v] == 0).collect();
        let mut order = Vec::with_capacity(self.n_nodes);
        while let Some(v) = queue.pop_front() {
            order.push(v);
            for &child in &self.children[v] {
                in_degree[child] -= 1;
                if in_degree[child] == 0 {
                    queue.push_back(child);
                }
            }
        }
        if order.len() == self.n_nodes {
            Some(order)
        } else {
            None
        }
    }

    /// All strict ancestors of `v` (excludes `v` itself in a DAG).
    pub fn ancestors(&self, v: usize) -> HashSet<usize> {
        let mut anc = HashSet::new();
        let mut stack = vec![v];
        while let Some(node) = stack.pop() {
            for &p in &self.parents[node] {
                // `insert` returns true only on first insertion, so each
                // node is expanded at most once.
                if anc.insert(p) {
                    stack.push(p);
                }
            }
        }
        anc
    }

    /// All strict descendants of `v`.
    pub fn descendants(&self, v: usize) -> HashSet<usize> {
        let mut desc = HashSet::new();
        let mut stack = vec![v];
        while let Some(node) = stack.pop() {
            for &c in &self.children[node] {
                if desc.insert(c) {
                    stack.push(c);
                }
            }
        }
        desc
    }

    /// Tests whether `x` and `y` are d-separated given conditioning set `z`
    /// using a Bayes-ball-style reachability search.
    ///
    /// States are `(node, via_child)` pairs: `via_child == true` means the
    /// node was entered from a child (travelling against an edge), `false`
    /// means it was entered from a parent (travelling along an edge).
    /// Returns true when no active path reaches any node of `y`.
    pub fn d_separated(&self, x: &[usize], y: &[usize], z: &[usize]) -> bool {
        let z_set: HashSet<usize> = z.iter().copied().collect();
        let y_set: HashSet<usize> = y.iter().copied().collect();
        // A collider is opened when it is in Z or an ancestor of some
        // member of Z (conditioning on a descendant opens it).
        let mut z_ancestors: HashSet<usize> = z_set.clone();
        for &zv in z {
            z_ancestors.extend(self.ancestors(zv));
        }
        let mut visited: HashSet<(usize, bool)> = HashSet::new();
        let mut queue: VecDeque<(usize, bool)> = VecDeque::new();
        for &xv in x {
            // Start the walk from each source in both directions.
            queue.push_back((xv, true));
            queue.push_back((xv, false));
        }
        while let Some((node, via_child)) = queue.pop_front() {
            if visited.contains(&(node, via_child)) {
                continue;
            }
            visited.insert((node, via_child));
            if y_set.contains(&node) {
                // An active path reached y: not d-separated.
                return false;
            }
            let in_z = z_set.contains(&node);
            let in_z_anc = z_ancestors.contains(&node);
            if via_child && !in_z {
                // Entered from below: an unobserved node passes the ball to
                // all parents (chain) and all children (fork).
                for &p in &self.parents[node] {
                    queue.push_back((p, true));
                }
                for &c in &self.children[node] {
                    queue.push_back((c, false));
                }
            } else if !via_child {
                if !in_z {
                    // Entered from above, unobserved: continue downward.
                    for &c in &self.children[node] {
                        queue.push_back((c, false));
                    }
                }
                if in_z_anc {
                    // Collider opened by conditioning: bounce back upward.
                    for &p in &self.parents[node] {
                        queue.push_back((p, true));
                    }
                }
            }
            // Remaining case (via_child && in_z): an observed non-collider
            // blocks the path, so nothing is enqueued.
        }
        true
    }

    /// Markov blanket of `v`: its parents, children, and the children's
    /// other parents (spouses).
    pub fn markov_blanket(&self, v: usize) -> HashSet<usize> {
        let mut blanket = HashSet::new();
        for &p in &self.parents[v] {
            blanket.insert(p);
        }
        for &c in &self.children[v] {
            blanket.insert(c);
            for &cp in &self.parents[c] {
                if cp != v {
                    blanket.insert(cp);
                }
            }
        }
        blanket
    }

    /// A graph is acyclic iff a full topological order exists.
    pub fn is_acyclic(&self) -> bool {
        self.topological_sort().is_some()
    }
}
/// A linear-Gaussian structural causal model: each variable equals its
/// intercept plus a weighted sum of its parents plus scaled exogenous noise.
#[derive(Debug, Clone)]
pub struct StructuralCausalModel {
    // The causal DAG over the model's variables.
    pub graph: CausalGraph,
    // coefficients[v][k] is the weight on the k-th entry of
    // graph.parents[v] — the two vectors are kept positionally aligned.
    pub coefficients: Vec<Vec<f64>>,
    // Per-variable noise scale (multiplies the exogenous noise term).
    pub noise_std: Vec<f64>,
    // Per-variable additive intercept.
    pub intercepts: Vec<f64>,
}

impl StructuralCausalModel {
    /// Model with `n` variables, no edges, unit noise and zero intercepts.
    pub fn new(n: usize) -> Self {
        Self {
            graph: CausalGraph::new(n),
            coefficients: vec![vec![]; n],
            noise_std: vec![1.0; n],
            intercepts: vec![0.0; n],
        }
    }

    /// Adds the structural edge `from -> to` with weight `coeff`.
    ///
    /// Duplicate edges are ignored entirely. BUG FIX: `CausalGraph::add_edge`
    /// silently skips duplicates, but the previous version pushed the
    /// coefficient unconditionally, desynchronizing the positional pairing
    /// between `graph.parents[to]` and `coefficients[to]`.
    pub fn add_edge(&mut self, from: usize, to: usize, coeff: f64) {
        if self.graph.children[from].contains(&to) {
            return;
        }
        self.graph.add_edge(from, to);
        self.coefficients[to].push(coeff);
    }

    /// Sets the exogenous noise scale of variable `v`.
    pub fn set_noise(&mut self, v: usize, std: f64) {
        self.noise_std[v] = std;
    }

    /// Sets the intercept of variable `v`.
    pub fn set_intercept(&mut self, v: usize, intercept: f64) {
        self.intercepts[v] = intercept;
    }

    /// Evaluates all structural equations in topological order for one
    /// exogenous noise vector (`noise` must have at least `n_nodes` entries;
    /// `noise[v]` is the standardized noise of variable `v`).
    ///
    /// # Panics
    /// If the graph has a cycle or `noise` is too short.
    pub fn sample_with_noise(&self, noise: &[f64]) -> Vec<f64> {
        let n = self.graph.n_nodes;
        let order = self
            .graph
            .topological_sort()
            .expect("SCM graph must be acyclic");
        let mut x = vec![0.0_f64; n];
        for &v in &order {
            // x_v = intercept_v + sum_k coeff_k * x_{parent_k} + sigma_v * eps_v
            let val: f64 = self.intercepts[v]
                + self.graph.parents[v]
                    .iter()
                    .zip(self.coefficients[v].iter())
                    .map(|(&p, &c)| c * x[p])
                    .sum::<f64>()
                + self.noise_std[v] * noise[v];
            x[v] = val;
        }
        x
    }

    /// Returns a copy of the model under the intervention do(target = val):
    /// all incoming edges of `target` are removed and its equation becomes
    /// the constant `val` with zero noise.
    pub fn intervene(&self, target: usize, val: f64) -> Self {
        let mut scm = self.clone();
        let parents = scm.graph.parents[target].clone();
        for &p in &parents {
            scm.graph.children[p].retain(|&c| c != target);
        }
        scm.graph.parents[target].clear();
        scm.coefficients[target].clear();
        scm.noise_std[target] = 0.0;
        scm.intercepts[target] = val;
        scm
    }

    /// Monte-Carlo estimate of E[effect | do(cause = val)] averaged over
    /// the given exogenous noise samples. Returns NaN when `noise_samples`
    /// is empty (0/0), signalling misuse rather than masking it.
    pub fn average_causal_effect(
        &self,
        cause: usize,
        val: f64,
        effect: usize,
        noise_samples: &[Vec<f64>],
    ) -> f64 {
        let intervened = self.intervene(cause, val);
        noise_samples
            .iter()
            .map(|noise| intervened.sample_with_noise(noise)[effect])
            .sum::<f64>()
            / noise_samples.len() as f64
    }

    /// Exact total linear effect of `cause` on `effect`: the sum over all
    /// directed paths of the product of edge coefficients. Returns 0.0 when
    /// no directed path exists (or cause == effect).
    pub fn total_effect_linear(&self, cause: usize, effect: usize) -> f64 {
        let mut total = 0.0_f64;
        let mut stack: Vec<(usize, f64)> = vec![(cause, 1.0)];
        while let Some((node, prod)) = stack.pop() {
            if node == effect && node != cause {
                total += prod;
            }
            for &child in &self.graph.children[node] {
                // Look up the coefficient paired with this parent slot.
                if let Some(idx) = self.graph.parents[child].iter().position(|&p| p == node) {
                    let coeff = self.coefficients[child][idx];
                    stack.push((child, prod * coeff));
                }
            }
        }
        total
    }
}
/// Backdoor-criterion checking plus a stratification-based adjustment
/// estimator.
#[derive(Debug, Clone)]
pub struct BackdoorCriterion {
    // The causal DAG the criterion is evaluated on.
    pub graph: CausalGraph,
}

impl BackdoorCriterion {
    /// Wraps a graph for backdoor analysis.
    pub fn new(graph: CausalGraph) -> Self {
        Self { graph }
    }

    /// Checks Pearl's backdoor criterion: no member of `adjustment_set` is
    /// a descendant of `treatment`, and the set d-separates treatment from
    /// outcome once treatment's outgoing edges are removed (so only
    /// backdoor paths remain).
    pub fn check(&self, treatment: usize, outcome: usize, adjustment_set: &[usize]) -> bool {
        let desc_treatment = self.graph.descendants(treatment);
        if adjustment_set.iter().any(|z| desc_treatment.contains(z)) {
            return false;
        }
        // Cut every edge leaving the treatment node. (The previous version
        // routed this through a one-method helper trait; direct access is
        // clearer and equivalent.)
        let mut modified = self.graph.clone();
        for &c in &self.graph.children[treatment] {
            modified.parents[c].retain(|&p| p != treatment);
        }
        modified.children[treatment].clear();
        modified.d_separated(&[treatment], &[outcome], adjustment_set)
    }

    /// Stratified backdoor-adjustment estimate of E[Y | do(X = x_val)].
    ///
    /// The confounder Z is binned into quintiles; within each stratum the
    /// mean of Y over observations with |x - x_val| < bandwidth is taken
    /// (falling back to the stratum's overall mean of Y when no observation
    /// is near x_val), and the stratum means are combined weighted by the
    /// stratum probabilities. The bandwidth is half the IQR of X, floored
    /// at 0.05. Returns 0.0 for empty input.
    ///
    /// # Panics
    /// If `data_y` or `data_z` differ in length from `data_x`.
    pub fn adjusted_effect(
        data_x: &[f64],
        data_y: &[f64],
        data_z: &[f64],
        x_val: f64,
        _tolerance: f64,
    ) -> f64 {
        let n = data_x.len();
        assert_eq!(data_y.len(), n);
        assert_eq!(data_z.len(), n);
        // Guard: the quantile and bandwidth computations below index into
        // sorted copies and would panic on an empty slice.
        if n == 0 {
            return 0.0;
        }
        let n_strata = 5usize;
        let mut sorted_z = data_z.to_vec();
        sorted_z.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        // Quintile cut points of Z.
        let quantiles: Vec<f64> = (1..n_strata)
            .map(|i| sorted_z[(i * n) / n_strata])
            .collect();
        let stratum_of = |z: f64| -> usize {
            quantiles
                .iter()
                .position(|&q| z < q)
                .unwrap_or(n_strata - 1)
        };
        let mut stratum_sums_y = vec![0.0_f64; n_strata];
        let mut stratum_counts = vec![0usize; n_strata];
        let mut stratum_counts_near_x = vec![0usize; n_strata];
        let mut stratum_y_near_x = vec![0.0_f64; n_strata];
        let bandwidth = {
            let mut xs = data_x.to_vec();
            xs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
            let iqr = xs[3 * n / 4] - xs[n / 4];
            iqr.max(0.1) * 0.5
        };
        for i in 0..n {
            let s = stratum_of(data_z[i]);
            stratum_sums_y[s] += data_y[i];
            stratum_counts[s] += 1;
            if (data_x[i] - x_val).abs() < bandwidth {
                stratum_y_near_x[s] += data_y[i];
                stratum_counts_near_x[s] += 1;
            }
        }
        // Adjustment formula: sum_z E[Y | X≈x, Z=z] * P(Z=z).
        let mut total = 0.0_f64;
        for s in 0..n_strata {
            if stratum_counts[s] == 0 {
                continue;
            }
            let p_z = stratum_counts[s] as f64 / n as f64;
            let e_y_xz = if stratum_counts_near_x[s] > 0 {
                stratum_y_near_x[s] / stratum_counts_near_x[s] as f64
            } else {
                // No observation near x_val in this stratum: fall back to
                // the stratum's unconditional mean of Y.
                stratum_sums_y[s] / stratum_counts[s] as f64
            };
            total += e_y_xz * p_z;
        }
        total
    }
}
/// Small helper trait exposing a node's child list as an owned vector.
trait GraphChildrenOf {
    /// Returns the direct children of node `v`.
    fn graph_children_of(&self, v: usize) -> Vec<usize>;
}

impl GraphChildrenOf for CausalGraph {
    fn graph_children_of(&self, v: usize) -> Vec<usize> {
        self.children[v].iter().copied().collect()
    }
}
/// Front-door criterion checking plus a kernel-regression plug-in
/// estimator.
#[derive(Debug, Clone)]
pub struct FrontdoorCriterion {
    // The causal DAG the criterion is evaluated on.
    pub graph: CausalGraph,
}

impl FrontdoorCriterion {
    /// Wraps a graph for front-door analysis.
    pub fn new(graph: CausalGraph) -> Self {
        Self { graph }
    }

    /// Checks Pearl's front-door criterion for `mediator_set` relative to
    /// (`treatment`, `outcome`):
    ///
    /// (i)   the mediators intercept every directed path from treatment to
    ///       outcome;
    /// (ii)  there is no unblocked backdoor path from treatment to the
    ///       mediators — tested via d-separation (given the empty set) in
    ///       the graph with treatment's outgoing edges removed;
    /// (iii) every backdoor path from the mediators to the outcome is
    ///       blocked by treatment — tested via d-separation (given
    ///       treatment) in the graph with the mediators' outgoing edges
    ///       removed.
    ///
    /// BUG FIX: the previous version computed ancestor sets for condition
    /// (ii) and discarded them (a no-op, so (ii) was never enforced), and
    /// ran the condition-(iii) test on the full graph, where the directed
    /// mediator→outcome path itself d-connects them — wrongly rejecting the
    /// canonical front-door graph X→M→Y with a hidden confounder of X, Y.
    pub fn check(&self, treatment: usize, outcome: usize, mediator_set: &[usize]) -> bool {
        let med_set: HashSet<usize> = mediator_set.iter().copied().collect();
        if !self.intercepts_all_paths(treatment, outcome, &med_set) {
            return false;
        }
        // (ii): with treatment's outgoing edges cut, any remaining
        // connection from treatment to a mediator is a backdoor path.
        let no_out_treatment = Self::without_outgoing_edges(&self.graph, &[treatment]);
        if !no_out_treatment.d_separated(&[treatment], mediator_set, &[]) {
            return false;
        }
        // (iii): with the mediators' outgoing edges cut, any remaining
        // connection from a mediator to the outcome is a backdoor path and
        // must be blocked by the treatment.
        let no_out_mediators = Self::without_outgoing_edges(&self.graph, mediator_set);
        no_out_mediators.d_separated(mediator_set, &[outcome], &[treatment])
    }

    /// Returns a copy of `graph` with every edge leaving a node in
    /// `sources` removed.
    fn without_outgoing_edges(graph: &CausalGraph, sources: &[usize]) -> CausalGraph {
        let mut g = graph.clone();
        for &s in sources {
            for &c in &graph.children[s] {
                g.parents[c].retain(|&p| p != s);
            }
            g.children[s].clear();
        }
        g
    }

    /// True when every directed path `src -> ... -> dst` passes through at
    /// least one node of `med_set`. Exhaustive DFS path enumeration —
    /// exponential in the worst case, acceptable for small graphs.
    fn intercepts_all_paths(&self, src: usize, dst: usize, med_set: &HashSet<usize>) -> bool {
        let mut stack: Vec<(usize, Vec<usize>)> = vec![(src, vec![src])];
        while let Some((node, path)) = stack.pop() {
            if node == dst {
                // The source itself does not count as a mediator.
                let on_path = path[1..].iter().any(|v| med_set.contains(v));
                if !on_path {
                    return false;
                }
                continue;
            }
            for &child in &self.graph.children[node] {
                if !path.contains(&child) {
                    let mut new_path = path.clone();
                    new_path.push(child);
                    stack.push((child, new_path));
                }
            }
        }
        true
    }

    /// Plug-in front-door estimate of E[Y | do(X = x_val)] using two
    /// Nadaraya–Watson kernel regressions: first E[M | X = x_val], then
    /// E[Y | M = E[M | X = x_val]]. Bandwidths are IQR-based with a floor.
    /// Returns 0.0 for empty input (the previous version panicked on the
    /// bandwidth indexing).
    pub fn adjusted_effect(data_x: &[f64], data_m: &[f64], data_y: &[f64], x_val: f64) -> f64 {
        let n = data_x.len();
        if n == 0 {
            return 0.0;
        }
        let bandwidth = {
            let mut xs = data_x.to_vec();
            xs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
            let iqr = xs[3 * n / 4] - xs[n / 4];
            iqr.max(0.1) * 0.4
        };
        // Stage 1: kernel regression of M on X, evaluated at x_val.
        let (mut sum_m, mut w_sum) = (0.0_f64, 0.0_f64);
        for i in 0..n {
            let w = gaussian_kernel((data_x[i] - x_val) / bandwidth);
            sum_m += w * data_m[i];
            w_sum += w;
        }
        let e_m_given_x = if w_sum > 1e-12 { sum_m / w_sum } else { 0.0 };
        let bw_m = {
            let mut ms = data_m.to_vec();
            ms.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
            let iqr = ms[3 * n / 4] - ms[n / 4];
            iqr.max(0.1) * 0.4
        };
        // Stage 2: kernel regression of Y on M, evaluated at E[M | x_val].
        let (mut sum_y, mut wy_sum) = (0.0_f64, 0.0_f64);
        for i in 0..n {
            let w = gaussian_kernel((data_m[i] - e_m_given_x) / bw_m);
            sum_y += w * data_y[i];
            wy_sum += w;
        }
        if wy_sum > 1e-12 { sum_y / wy_sum } else { 0.0 }
    }
}
/// Unnormalized Gaussian kernel: exp(-u^2 / 2).
#[allow(dead_code)]
fn gaussian_kernel(u: f64) -> f64 {
    f64::exp(-(u * u) / 2.0)
}
/// Logistic-regression propensity-score model with inverse-probability
/// weighting estimators for the ATE and the ATT.
#[derive(Debug, Clone)]
pub struct PropensityScoreMatching {
    // Model weights: index 0 is the intercept, 1..=n_covariates the slopes.
    pub weights: Vec<f64>,
    // Number of covariate columns the model expects.
    pub n_covariates: usize,
}

impl PropensityScoreMatching {
    /// Zero-initialized model over `n_covariates` covariates.
    pub fn new(n_covariates: usize) -> Self {
        Self {
            weights: vec![0.0; n_covariates + 1],
            n_covariates,
        }
    }

    /// Fits the model by full-batch gradient descent on the logistic
    /// log-loss: each step moves the weights against the mean gradient.
    pub fn fit(&mut self, covariates: &[Vec<f64>], treatment: &[f64], lr: f64, n_iter: usize) {
        let n = covariates.len();
        assert_eq!(treatment.len(), n);
        for _ in 0..n_iter {
            let mut grad = vec![0.0_f64; self.n_covariates + 1];
            for (row, &t) in covariates.iter().zip(treatment.iter()) {
                // Gradient of the log-loss is (p - t) times the feature.
                let residual = self.predict_one(row) - t;
                grad[0] += residual;
                for j in 0..self.n_covariates {
                    grad[j + 1] += residual * row[j];
                }
            }
            let scale = lr / n as f64;
            for (w, g) in self.weights.iter_mut().zip(grad.iter()) {
                *w -= scale * g;
            }
        }
    }

    /// Propensity score P(treated | x) for one covariate row.
    pub fn predict_one(&self, x: &[f64]) -> f64 {
        let mut logit = self.weights[0];
        for (xi, wi) in x.iter().zip(self.weights[1..].iter()) {
            logit += xi * wi;
        }
        // Logistic link, inlined.
        1.0 / (1.0 + (-logit).exp())
    }

    /// Propensity scores for every row of `covariates`.
    pub fn predict(&self, covariates: &[Vec<f64>]) -> Vec<f64> {
        let mut scores = Vec::with_capacity(covariates.len());
        for row in covariates {
            scores.push(self.predict_one(row));
        }
        scores
    }

    /// Hájek (self-normalized) inverse-probability-weighted ATE estimate.
    /// Scores are clamped away from 0 and 1 to bound the weights; an empty
    /// group contributes 0.0.
    pub fn estimate_ate(&self, covariates: &[Vec<f64>], treatment: &[f64], outcome: &[f64]) -> f64 {
        let mut sum_treated = 0.0_f64;
        let mut weight_treated = 0.0_f64;
        let mut sum_control = 0.0_f64;
        let mut weight_control = 0.0_f64;
        for (i, row) in covariates.iter().enumerate() {
            let e = self.predict_one(row).clamp(1e-6, 1.0 - 1e-6);
            if treatment[i] > 0.5 {
                sum_treated += outcome[i] / e;
                weight_treated += 1.0 / e;
            } else {
                sum_control += outcome[i] / (1.0 - e);
                weight_control += 1.0 / (1.0 - e);
            }
        }
        let mean1 = if weight_treated > 0.0 {
            sum_treated / weight_treated
        } else {
            0.0
        };
        let mean0 = if weight_control > 0.0 {
            sum_control / weight_control
        } else {
            0.0
        };
        mean1 - mean0
    }

    /// ATT estimate via odds weighting: each control is weighted by
    /// e/(1 - e) so the control distribution matches the treated one.
    /// Falls back to the unweighted control mean when all weights vanish;
    /// returns 0.0 when either group is empty.
    pub fn estimate_att(&self, covariates: &[Vec<f64>], treatment: &[f64], outcome: &[f64]) -> f64 {
        let mut treated_y = Vec::new();
        let mut control_y = Vec::new();
        let mut control_w = Vec::new();
        for (i, row) in covariates.iter().enumerate() {
            let e = self.predict_one(row).clamp(1e-6, 1.0 - 1e-6);
            if treatment[i] > 0.5 {
                treated_y.push(outcome[i]);
            } else {
                control_y.push(outcome[i]);
                control_w.push(e / (1.0 - e));
            }
        }
        if treated_y.is_empty() || control_y.is_empty() {
            return 0.0;
        }
        let mean_treated = treated_y.iter().sum::<f64>() / treated_y.len() as f64;
        let total_w: f64 = control_w.iter().sum();
        let mean_control = if total_w > 0.0 {
            let weighted: f64 = control_y
                .iter()
                .zip(control_w.iter())
                .map(|(y, w)| y * w)
                .sum();
            weighted / total_w
        } else {
            control_y.iter().sum::<f64>() / control_y.len() as f64
        };
        mean_treated - mean_control
    }
}
/// Logistic function 1 / (1 + e^{-x}).
fn sigmoid(x: f64) -> f64 {
    let neg_exp = (-x).exp();
    1.0 / (1.0 + neg_exp)
}
/// Two-stage least squares (2SLS) for a single endogenous regressor.
///
/// NOTE(review): even when `n_instruments > 1`, only the first instrument
/// column `z[i][0]` is used in the first stage — a limitation of this
/// implementation, now documented.
#[derive(Debug, Clone)]
pub struct InstrumentalVariables {
    // Number of endogenous regressors (only 1 is supported by fit_2sls).
    pub n_endogenous: usize,
    // Number of instrument columns supplied by the caller.
    pub n_instruments: usize,
    // First-stage coefficients [intercept, slope, ...]; only the first two
    // entries are populated by fit_2sls.
    pub first_stage: Vec<f64>,
    // Second-stage causal-effect estimate of d on y.
    pub second_stage: f64,
}

impl InstrumentalVariables {
    /// Zero-initialized estimator.
    pub fn new(n_endogenous: usize, n_instruments: usize) -> Self {
        Self {
            n_endogenous,
            n_instruments,
            first_stage: vec![0.0; n_instruments + 1],
            second_stage: 0.0,
        }
    }

    /// Fits 2SLS. With one instrument the second stage is the Wald
    /// estimator cov(y, z) / cov(d, z); otherwise y is regressed on the
    /// first-stage fitted values d_hat. The first-stage OLS of d on z is
    /// shared by both branches (the previous version duplicated it and
    /// computed an unused `d_hat` in the single-instrument branch).
    pub fn fit_2sls(&mut self, y: &[f64], d: &[f64], z: &[Vec<f64>]) {
        let n = y.len();
        assert_eq!(d.len(), n);
        assert_eq!(z.len(), n);
        // Only the first instrument column is used (see type-level note).
        let z0: Vec<f64> = z.iter().map(|row| row[0]).collect();
        let nf = n as f64;
        let mean_z = z0.iter().sum::<f64>() / nf;
        let mean_d = d.iter().sum::<f64>() / nf;
        let cov_dz: f64 = d
            .iter()
            .zip(z0.iter())
            .map(|(di, zi)| (di - mean_d) * (zi - mean_z))
            .sum::<f64>()
            / nf;
        let var_z: f64 = z0.iter().map(|zi| (zi - mean_z).powi(2)).sum::<f64>() / nf;
        // First stage: d ~ alpha0 + alpha1 * z.
        let alpha1 = if var_z.abs() > 1e-12 {
            cov_dz / var_z
        } else {
            0.0
        };
        let alpha0 = mean_d - alpha1 * mean_z;
        self.first_stage[0] = alpha0;
        self.first_stage[1] = alpha1;
        let mean_y = y.iter().sum::<f64>() / nf;
        if self.n_instruments == 1 {
            // Wald / ratio form: cov(y, z) / cov(d, z).
            let cov_yz: f64 = y
                .iter()
                .zip(z0.iter())
                .map(|(yi, zi)| (yi - mean_y) * (zi - mean_z))
                .sum::<f64>()
                / nf;
            self.second_stage = if cov_dz.abs() > 1e-12 {
                cov_yz / cov_dz
            } else {
                0.0
            };
        } else {
            // Regress y on the first-stage fitted values d_hat.
            let d_hat: Vec<f64> = z0.iter().map(|zi| alpha0 + alpha1 * zi).collect();
            let mean_dhat = d_hat.iter().sum::<f64>() / nf;
            let cov_ydhat: f64 = y
                .iter()
                .zip(d_hat.iter())
                .map(|(yi, di)| (yi - mean_y) * (di - mean_dhat))
                .sum::<f64>()
                / nf;
            let var_dhat: f64 =
                d_hat.iter().map(|di| (di - mean_dhat).powi(2)).sum::<f64>() / nf;
            self.second_stage = if var_dhat > 1e-12 {
                cov_ydhat / var_dhat
            } else {
                0.0
            };
        }
    }

    /// F statistic of the first-stage regression with k = 1 instrument;
    /// values above ~10 conventionally indicate a strong instrument.
    /// `y` is used only for the sample size.
    pub fn first_stage_f_stat(&self, y: &[f64], d: &[f64], z: &[Vec<f64>]) -> f64 {
        let n = y.len();
        let mean_d = d.iter().sum::<f64>() / n as f64;
        let d_hat: Vec<f64> = z
            .iter()
            .map(|row| self.first_stage[0] + self.first_stage[1] * row[0])
            .collect();
        let ss_res: f64 = d
            .iter()
            .zip(d_hat.iter())
            .map(|(di, dh)| (di - dh).powi(2))
            .sum();
        let ss_tot: f64 = d.iter().map(|di| (di - mean_d).powi(2)).sum();
        let r2 = 1.0 - ss_res / ss_tot.max(1e-12);
        // F = (R^2 / k) / ((1 - R^2) / (n - k - 1)), with k = 1 and the
        // denominator floored to avoid division by zero on a perfect fit.
        let k = 1.0_f64;
        let n_f = n as f64;
        (r2 / k) / ((1.0 - r2) / (n_f - k - 1.0)).max(1e-12)
    }

    /// Predicted causal contribution of `d_val` via the second-stage slope.
    pub fn predict(&self, d_val: f64) -> f64 {
        self.second_stage * d_val
    }
}
/// PC-style constraint-based structure learning: skeleton search with
/// partial-correlation independence tests, v-structure orientation, and
/// Meek-rule propagation.
#[derive(Debug, Clone)]
pub struct CausalDiscovery {
    // Number of observed variables (columns of the data matrix).
    pub n_vars: usize,
    // Undirected adjacency; skeleton[i][j] means i — j. Orientation makes
    // this matrix asymmetric (see orient_v_structures).
    pub skeleton: Vec<Vec<bool>>,
    // directed[i][j] means the edge has been oriented i -> j.
    pub directed: Vec<Vec<bool>>,
    // Separating sets found when removing edges, keyed by ordered pair.
    pub sep_sets: HashMap<(usize, usize), Vec<usize>>,
    // Significance level for the Fisher-z independence tests.
    pub alpha: f64,
}

impl CausalDiscovery {
    /// Starts from the fully connected skeleton with no orientations.
    pub fn new(n_vars: usize, alpha: f64) -> Self {
        Self {
            n_vars,
            skeleton: vec![vec![true; n_vars]; n_vars],
            directed: vec![vec![false; n_vars]; n_vars],
            sep_sets: HashMap::new(),
            alpha,
        }
    }

    /// PC skeleton phase: removes the edge i — j whenever some conditioning
    /// set (of growing size, drawn from the current neighbours of i) makes
    /// the partial correlation statistically insignificant; the separating
    /// set is recorded for the orientation phase.
    pub fn learn_skeleton(&mut self, data: &[Vec<f64>]) {
        let n = self.n_vars;
        // No self-loops.
        for i in 0..n {
            self.skeleton[i][i] = false;
        }
        // Order-0 tests: marginal correlations.
        for i in 0..n {
            for j in (i + 1)..n {
                let r = partial_correlation(data, i, j, &[]);
                let p = fisher_z_test(r, data.len(), 0);
                if p > self.alpha {
                    self.skeleton[i][j] = false;
                    self.skeleton[j][i] = false;
                    self.sep_sets.insert((i, j), vec![]);
                    self.sep_sets.insert((j, i), vec![]);
                }
            }
        }
        // Higher-order tests with conditioning sets drawn from adj(i)\{j}.
        for cond_size in 1..n.saturating_sub(1) {
            for i in 0..n {
                let adj_i: Vec<usize> = (0..n).filter(|&k| k != i && self.skeleton[i][k]).collect();
                for &j in &adj_i {
                    // The edge may already have been removed via an earlier j.
                    if !self.skeleton[i][j] {
                        continue;
                    }
                    let adj_minus_j: Vec<usize> =
                        adj_i.iter().copied().filter(|&k| k != j).collect();
                    if adj_minus_j.len() < cond_size {
                        continue;
                    }
                    for cond_set in subsets(&adj_minus_j, cond_size) {
                        let r = partial_correlation(data, i, j, &cond_set);
                        let p = fisher_z_test(r, data.len(), cond_size);
                        if p > self.alpha {
                            self.skeleton[i][j] = false;
                            self.skeleton[j][i] = false;
                            self.sep_sets.insert((i, j), cond_set.clone());
                            self.sep_sets.insert((j, i), cond_set);
                            // Edge removed; stop testing further sets.
                            break;
                        }
                    }
                }
            }
        }
    }

    /// Orients unshielded triples i — k — j (with i, j non-adjacent) as the
    /// collider i -> k <- j whenever k is absent from the separating set of
    /// (i, j). Orientation is recorded by setting `directed` and dropping
    /// the reverse skeleton entries (k -> i / k -> j), which makes the
    /// skeleton matrix asymmetric.
    pub fn orient_v_structures(&mut self) {
        let n = self.n_vars;
        for i in 0..n {
            for k in 0..n {
                if i == k || !self.skeleton[i][k] {
                    continue;
                }
                for j in (i + 1)..n {
                    // Need k adjacent to j and i NOT adjacent to j
                    // (unshielded triple).
                    if j == k || !self.skeleton[k][j] || self.skeleton[i][j] {
                        continue;
                    }
                    let sep = self.sep_sets.get(&(i, j)).cloned().unwrap_or_default();
                    if !sep.contains(&k) {
                        self.directed[i][k] = true;
                        self.directed[j][k] = true;
                        self.skeleton[k][i] = false;
                        self.skeleton[k][j] = false;
                    }
                }
            }
        }
    }

    /// Meek rules 1 and 2, iterated to a fixed point:
    /// (1) given i -> j and an undirected j — k with i, k non-adjacent,
    ///     orient j -> k (avoids creating a new v-structure);
    /// (2) given an undirected i — j and a directed path i -> k -> j,
    ///     orient i -> j (avoids creating a cycle).
    pub fn apply_meek_rules(&mut self) {
        let n = self.n_vars;
        let mut changed = true;
        while changed {
            changed = false;
            // Rule 1.
            for i in 0..n {
                for j in 0..n {
                    if !self.directed[i][j] {
                        continue;
                    }
                    for k in 0..n {
                        if k == i || k == j {
                            continue;
                        }
                        if self.skeleton[j][k]
                            && !self.directed[j][k]
                            && !self.directed[k][j]
                            && !self.skeleton[i][k]
                        {
                            self.directed[j][k] = true;
                            self.skeleton[k][j] = false;
                            changed = true;
                        }
                    }
                }
            }
            // Rule 2.
            for i in 0..n {
                for j in 0..n {
                    if i == j || !self.skeleton[i][j] || self.directed[i][j] {
                        continue;
                    }
                    for k in 0..n {
                        if k == i || k == j {
                            continue;
                        }
                        if self.directed[i][k] && self.directed[k][j] {
                            self.directed[i][j] = true;
                            self.skeleton[j][i] = false;
                            changed = true;
                        }
                    }
                }
            }
        }
    }

    /// Full pipeline: skeleton search, v-structure orientation, Meek rules.
    pub fn run(&mut self, data: &[Vec<f64>]) {
        self.learn_skeleton(data);
        self.orient_v_structures();
        self.apply_meek_rules();
    }
}
pub fn partial_correlation(data: &[Vec<f64>], i: usize, j: usize, cond: &[usize]) -> f64 {
if cond.is_empty() {
return pearson_correlation(data, i, j);
}
if cond.len() == 1 {
let k = cond[0];
let r_ij = pearson_correlation(data, i, j);
let r_ik = pearson_correlation(data, i, k);
let r_jk = pearson_correlation(data, j, k);
let denom = ((1.0 - r_ik * r_ik) * (1.0 - r_jk * r_jk)).sqrt();
if denom < 1e-12 {
return 0.0;
}
return (r_ij - r_ik * r_jk) / denom;
}
let last = cond[cond.len() - 1];
let rest = &cond[..cond.len() - 1];
let r_ij_rest = partial_correlation(data, i, j, rest);
let r_ik_rest = partial_correlation(data, i, last, rest);
let r_jk_rest = partial_correlation(data, j, last, rest);
let denom = ((1.0 - r_ik_rest * r_ik_rest) * (1.0 - r_jk_rest * r_jk_rest)).sqrt();
if denom < 1e-12 {
return 0.0;
}
(r_ij_rest - r_ik_rest * r_jk_rest) / denom
}
/// Pearson correlation between columns `i` and `j` of row-major `data`.
/// Returns 0.0 when either column is (numerically) constant, and clamps
/// the result to [-1, 1] to absorb floating-point drift.
pub fn pearson_correlation(data: &[Vec<f64>], i: usize, j: usize) -> f64 {
    let n = data.len() as f64;
    // First pass: column means.
    let (mut sum_i, mut sum_j) = (0.0_f64, 0.0_f64);
    for row in data {
        sum_i += row[i];
        sum_j += row[j];
    }
    let mean_i = sum_i / n;
    let mean_j = sum_j / n;
    // Second pass: covariance and variances of the centred columns.
    let (mut cov, mut sq_i, mut sq_j) = (0.0_f64, 0.0_f64, 0.0_f64);
    for row in data {
        let di = row[i] - mean_i;
        let dj = row[j] - mean_j;
        cov += di * dj;
        sq_i += di * di;
        sq_j += dj * dj;
    }
    cov /= n;
    let std_i = (sq_i / n).sqrt();
    let std_j = (sq_j / n).sqrt();
    if std_i < 1e-12 || std_j < 1e-12 {
        return 0.0;
    }
    (cov / (std_i * std_j)).clamp(-1.0, 1.0)
}
/// Two-sided p-value for H0: the (partial) correlation `r` is zero, using
/// Fisher's z-transform with standard error 1 / sqrt(n - |cond| - 3).
/// `r` is clamped away from ±1 so the transform stays finite; the degrees
/// of freedom are floored at 1 for tiny samples.
pub fn fisher_z_test(r: f64, n: usize, cond_size: usize) -> f64 {
    let r_clamped = r.clamp(-0.9999, 0.9999);
    let z = ((1.0 + r_clamped) / (1.0 - r_clamped)).ln() * 0.5;
    let df = (n as f64 - cond_size as f64 - 3.0).max(1.0);
    let se = df.sqrt().recip();
    let stat = (z / se).abs();
    2.0 * (1.0 - standard_normal_cdf(stat))
}
/// Standard normal CDF via the Abramowitz & Stegun 26.2.17 polynomial
/// approximation (|error| < 7.5e-8), mirrored for negative arguments.
fn standard_normal_cdf(x: f64) -> f64 {
    // Polynomial coefficients, lowest degree first (no constant term).
    const COEFFS: [f64; 5] = [
        0.319_381_530,
        -0.356_563_782,
        1.781_477_937,
        -1.821_255_978,
        1.330_274_429,
    ];
    let t = 1.0 / (1.0 + 0.2316419 * x.abs());
    // Horner evaluation, then one final multiply by t for the missing
    // constant term.
    let mut poly = 0.0_f64;
    for &c in COEFFS.iter().rev() {
        poly = poly * t + c;
    }
    poly *= t;
    let pdf = (-0.5 * x * x).exp() / (2.0 * std::f64::consts::PI).sqrt();
    let upper_tail_cdf = 1.0 - pdf * poly;
    if x >= 0.0 {
        upper_tail_cdf
    } else {
        // Symmetry: Phi(-x) = 1 - Phi(x).
        1.0 - upper_tail_cdf
    }
}
/// All k-element subsets of `set`, preserving element order within each
/// subset. Yields a single empty subset for k == 0 and nothing when `set`
/// has fewer than k elements.
fn subsets(set: &[usize], k: usize) -> Vec<Vec<usize>> {
    if k == 0 {
        return vec![Vec::new()];
    }
    if set.len() < k {
        return Vec::new();
    }
    let mut out = Vec::new();
    // Only positions that leave at least k-1 elements after them can start
    // a subset.
    for start in 0..=(set.len() - k) {
        let head = set[start];
        for mut tail in subsets(&set[start + 1..], k - 1) {
            tail.insert(0, head);
            out.push(tail);
        }
    }
    out
}
/// Counterfactual reasoning on top of a linear SCM via the
/// abduction–action–prediction recipe.
#[derive(Debug, Clone)]
pub struct CounterfactualQuery {
    // The structural model all queries are evaluated against.
    pub scm: StructuralCausalModel,
}

impl CounterfactualQuery {
    /// Wraps an SCM for counterfactual queries.
    pub fn new(scm: StructuralCausalModel) -> Self {
        Self { scm }
    }

    /// Interventional mean E[outcome | do(target = x_val)], Monte-Carlo
    /// averaged over `noise_samples`.
    pub fn query_do(
        &self,
        target: usize,
        x_val: f64,
        outcome: usize,
        noise_samples: &[Vec<f64>],
    ) -> f64 {
        self.scm
            .average_causal_effect(target, x_val, outcome, noise_samples)
    }

    /// Counterfactual value of `outcome` had `target` been `x_val`, given
    /// the (possibly partial) observation `obs`.
    ///
    /// Abduction: noise terms of observed nodes are recovered by inverting
    /// their structural equations (unobserved nodes get zero noise, i.e.
    /// their point prediction). Action: `target` is set by intervention.
    /// Prediction: the modified model is re-simulated with the recovered
    /// noise.
    pub fn counterfactual(
        &self,
        obs: &[Option<f64>],
        target: usize,
        x_val: f64,
        outcome: usize,
    ) -> f64 {
        let n = self.scm.graph.n_nodes;
        assert_eq!(obs.len(), n);
        let order = self
            .scm
            .graph
            .topological_sort()
            .expect("SCM must be acyclic");
        let mut x = vec![0.0_f64; n];
        let mut noise = vec![0.0_f64; n];
        for &v in &order {
            if let Some(val) = obs[v] {
                x[v] = val;
                // Invert the structural equation to recover the noise term:
                // eps_v = (x_v - intercept_v - sum coeff * parent) / sigma_v.
                let pred: f64 = self.scm.graph.parents[v]
                    .iter()
                    .zip(self.scm.coefficients[v].iter())
                    .map(|(&p, &c)| c * x[p])
                    .sum::<f64>();
                let residual = val - self.scm.intercepts[v] - pred;
                noise[v] = if self.scm.noise_std[v].abs() > 1e-12 {
                    residual / self.scm.noise_std[v]
                } else {
                    0.0
                };
            } else {
                // Unobserved: assume zero noise and propagate the point
                // prediction from the parents.
                noise[v] = 0.0;
                let pred: f64 = self.scm.graph.parents[v]
                    .iter()
                    .zip(self.scm.coefficients[v].iter())
                    .map(|(&p, &c)| c * x[p])
                    .sum::<f64>();
                x[v] = self.scm.intercepts[v] + pred;
            }
        }
        let intervened = self.scm.intervene(target, x_val);
        intervened.sample_with_noise(&noise)[outcome]
    }

    /// Among noise samples whose factual world has the treatment at
    /// (approximately) `t_val` and the outcome at or above `threshold_y`,
    /// the fraction where setting the treatment to `t_counter` would have
    /// dropped the outcome below `threshold_y`. Returns 0.0 when no sample
    /// satisfies the factual condition.
    ///
    /// NOTE(review): the treatment filter `x < t_val - 0.5` is one-sided —
    /// it only excludes values well below `t_val`, never values above it;
    /// confirm this matches the intended "treatment taken" condition.
    pub fn probability_of_necessity(
        &self,
        treatment: usize,
        outcome: usize,
        t_val: f64,
        t_counter: f64,
        threshold_y: f64,
        noise_samples: &[Vec<f64>],
    ) -> f64 {
        let mut count = 0;
        let mut denom = 0;
        for noise in noise_samples {
            let x_actual = self.scm.sample_with_noise(noise);
            // Keep only samples matching the factual condition.
            if x_actual[treatment] < t_val - 0.5 || x_actual[outcome] < threshold_y {
                continue;
            }
            denom += 1;
            let counter_scm = self.scm.intervene(treatment, t_counter);
            let x_counter = counter_scm.sample_with_noise(noise);
            if x_counter[outcome] < threshold_y {
                count += 1;
            }
        }
        if denom == 0 {
            0.0
        } else {
            count as f64 / denom as f64
        }
    }
}
/// Sample covariance matrix of row-major `data`, returned as a row-major
/// flattened p x p vector (entry (j, k) at index j * p + k). Uses the
/// unbiased (n - 1) normalization; panics on empty input, like computing
/// `data[0].len()` implies.
pub fn sample_covariance(data: &[Vec<f64>]) -> Vec<f64> {
    let n = data.len();
    let p = data[0].len();
    // Column means, accumulated row by row.
    let mut means = vec![0.0_f64; p];
    for row in data {
        for j in 0..p {
            means[j] += row[j];
        }
    }
    for m in means.iter_mut() {
        *m /= n as f64;
    }
    // Accumulate the upper triangle of the centred cross-products.
    let mut cov = vec![0.0_f64; p * p];
    for row in data {
        for j in 0..p {
            let dj = row[j] - means[j];
            for k in j..p {
                cov[j * p + k] += dj * (row[k] - means[k]);
            }
        }
    }
    // Normalize and mirror into the lower triangle.
    let denom = (n - 1) as f64;
    for j in 0..p {
        for k in j..p {
            cov[j * p + k] /= denom;
            cov[k * p + j] = cov[j * p + k];
        }
    }
    cov
}
#[cfg(test)]
mod tests {
    //! Unit tests for the graph primitives, d-separation, the SCM and its
    //! interventions/counterfactuals, the estimators, and the numeric
    //! helpers.
    use super::*;

    // Fixture: X0 -> X1 -> X2.
    fn simple_chain() -> CausalGraph {
        let mut g = CausalGraph::new(3);
        g.add_edge(0, 1);
        g.add_edge(1, 2);
        g
    }

    // Fixture: X1 <- X0 -> X2.
    fn fork_graph() -> CausalGraph {
        let mut g = CausalGraph::new(3);
        g.add_edge(0, 1);
        g.add_edge(0, 2);
        g
    }

    // Fixture: X0 -> X2 <- X1.
    fn collider_graph() -> CausalGraph {
        let mut g = CausalGraph::new(3);
        g.add_edge(0, 2);
        g.add_edge(1, 2);
        g
    }

    #[test]
    fn test_topological_sort_chain() {
        let g = simple_chain();
        let order = g.topological_sort().unwrap();
        assert_eq!(order, vec![0, 1, 2]);
    }

    #[test]
    fn test_topological_sort_fork() {
        let g = fork_graph();
        let order = g.topological_sort().unwrap();
        // The root must come first; the leaves may appear in either order.
        assert_eq!(order[0], 0);
    }

    #[test]
    fn test_ancestors() {
        let g = simple_chain();
        let anc = g.ancestors(2);
        assert!(anc.contains(&0));
        assert!(anc.contains(&1));
        assert!(!anc.contains(&2));
    }

    #[test]
    fn test_descendants() {
        let g = simple_chain();
        let desc = g.descendants(0);
        assert!(desc.contains(&1));
        assert!(desc.contains(&2));
    }

    #[test]
    fn test_d_separation_chain_blocked_by_middle() {
        let g = simple_chain();
        assert!(g.d_separated(&[0], &[2], &[1]));
    }

    #[test]
    fn test_d_separation_chain_not_blocked_empty() {
        let g = simple_chain();
        assert!(!g.d_separated(&[0], &[2], &[]));
    }

    #[test]
    fn test_d_separation_fork() {
        let g = fork_graph();
        // Conditioning on the common cause blocks; otherwise connected.
        assert!(g.d_separated(&[1], &[2], &[0]));
        assert!(!g.d_separated(&[1], &[2], &[]));
    }

    #[test]
    fn test_d_separation_collider_blocked_by_default() {
        let g = collider_graph();
        assert!(g.d_separated(&[0], &[1], &[]));
    }

    #[test]
    fn test_d_separation_collider_opened_by_conditioning() {
        let g = collider_graph();
        assert!(!g.d_separated(&[0], &[1], &[2]));
    }

    #[test]
    fn test_is_acyclic() {
        let g = simple_chain();
        assert!(g.is_acyclic());
    }

    #[test]
    fn test_creates_cycle_detected() {
        let mut g = CausalGraph::new(3);
        g.add_edge(0, 1);
        g.add_edge(1, 2);
        // Closing 2 -> 0 would complete the cycle 0 -> 1 -> 2 -> 0.
        assert!(g.creates_cycle(2, 0));
    }

    #[test]
    fn test_markov_blanket() {
        let g = simple_chain();
        let blanket = g.markov_blanket(1);
        assert!(blanket.contains(&0));
        assert!(blanket.contains(&2));
        assert!(!blanket.contains(&1));
    }

    #[test]
    fn test_scm_sample_basic() {
        let mut scm = StructuralCausalModel::new(2);
        scm.add_edge(0, 1, 2.0);
        scm.noise_std = vec![1.0, 0.0];
        let noise = vec![3.0, 0.0];
        let x = scm.sample_with_noise(&noise);
        // x0 = noise, x1 = 2 * x0.
        assert!((x[0] - 3.0).abs() < 1e-10);
        assert!((x[1] - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_scm_intervention_removes_parents() {
        let mut scm = StructuralCausalModel::new(2);
        scm.add_edge(0, 1, 2.0);
        let intervened = scm.intervene(1, 5.0);
        assert!(intervened.graph.parents[1].is_empty());
        let x = intervened.sample_with_noise(&[1.0, 0.0]);
        assert!((x[1] - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_scm_total_effect_chain() {
        let mut scm = StructuralCausalModel::new(3);
        scm.add_edge(0, 1, 2.0);
        scm.add_edge(1, 2, 3.0);
        // Single path 0 -> 1 -> 2 with product 2 * 3.
        let effect = scm.total_effect_linear(0, 2);
        assert!((effect - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_scm_no_effect_for_independent_vars() {
        let mut scm = StructuralCausalModel::new(3);
        scm.add_edge(0, 1, 1.0);
        // No directed path from 0 to 2.
        let effect = scm.total_effect_linear(0, 2);
        assert!(effect.abs() < 1e-10);
    }

    #[test]
    fn test_backdoor_check_valid_adjustment() {
        // 0 confounds the 1 -> 2 relation; adjusting for 0 suffices.
        let mut g = CausalGraph::new(3);
        g.add_edge(0, 1);
        g.add_edge(1, 2);
        g.add_edge(0, 2);
        let bd = BackdoorCriterion::new(g);
        assert!(bd.check(1, 2, &[0]));
    }

    #[test]
    fn test_backdoor_fails_if_descendant_in_set() {
        let mut g = CausalGraph::new(3);
        g.add_edge(0, 1);
        g.add_edge(1, 2);
        // 1 is a descendant of the treatment, so it is inadmissible.
        let bd = BackdoorCriterion::new(g);
        assert!(!bd.check(0, 2, &[1]));
    }

    #[test]
    fn test_propensity_score_fit_and_predict() {
        let mut psm = PropensityScoreMatching::new(2);
        let covariates: Vec<Vec<f64>> = (0..100).map(|i| vec![(i as f64) / 100.0, 0.5]).collect();
        let treatment: Vec<f64> = covariates
            .iter()
            .map(|x| if x[0] > 0.5 { 1.0 } else { 0.0 })
            .collect();
        psm.fit(&covariates, &treatment, 0.5, 200);
        let ps = psm.predict(&covariates);
        assert_eq!(ps.len(), 100);
        // Logistic outputs always lie strictly inside (0, 1).
        for p in &ps {
            assert!(*p > 0.0 && *p < 1.0);
        }
    }

    #[test]
    fn test_ate_sign_positive() {
        let mut psm = PropensityScoreMatching::new(1);
        let n = 200;
        let covariates: Vec<Vec<f64>> = (0..n).map(|i| vec![(i as f64) / n as f64]).collect();
        let treatment: Vec<f64> = covariates
            .iter()
            .map(|x| if x[0] > 0.5 { 1.0 } else { 0.0 })
            .collect();
        // True treatment effect is +2.
        let outcome: Vec<f64> = covariates
            .iter()
            .zip(treatment.iter())
            .map(|(x, t)| x[0] + 2.0 * t + 0.1)
            .collect();
        psm.fit(&covariates, &treatment, 0.3, 300);
        let ate = psm.estimate_ate(&covariates, &treatment, &outcome);
        assert!(ate > 0.0, "ATE should be positive, got {ate}");
    }

    #[test]
    fn test_iv_estimation_simple() {
        let n = 500;
        let z: Vec<Vec<f64>> = (0..n).map(|i| vec![(i as f64 % 2.0)]).collect();
        let d: Vec<f64> = z.iter().map(|zi| zi[0] + 0.5).collect();
        // True effect of d on y is 2.
        let y: Vec<f64> = d.iter().map(|di| 2.0 * di + 1.0).collect();
        let mut iv = InstrumentalVariables::new(1, 1);
        iv.fit_2sls(&y, &d, &z);
        assert!(
            (iv.second_stage - 2.0).abs() < 0.5,
            "IV est = {}",
            iv.second_stage
        );
    }

    #[test]
    fn test_iv_first_stage_f_stat() {
        let n = 200;
        let z: Vec<Vec<f64>> = (0..n).map(|i| vec![(i as f64 / n as f64)]).collect();
        let d: Vec<f64> = z.iter().map(|zi| 2.0 * zi[0]).collect();
        let y: Vec<f64> = d.iter().map(|di| di + 1.0).collect();
        let mut iv = InstrumentalVariables::new(1, 1);
        iv.fit_2sls(&y, &d, &z);
        let f = iv.first_stage_f_stat(&y, &d, &z);
        assert!(
            f > 10.0,
            "F-stat should be large for strong instrument, got {f}"
        );
    }

    #[test]
    fn test_pearson_correlation_perfect() {
        let data: Vec<Vec<f64>> = (0..50).map(|i| vec![i as f64, 2.0 * i as f64]).collect();
        let r = pearson_correlation(&data, 0, 1);
        assert!((r - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_pearson_correlation_zero() {
        // A constant column has zero variance, so r is defined as 0.
        let data: Vec<Vec<f64>> = (0..100).map(|i| vec![i as f64, 1.0]).collect();
        let r = pearson_correlation(&data, 0, 1);
        assert!(r.abs() < 1e-10, "r={r}");
    }

    #[test]
    fn test_partial_correlation_returns_in_range() {
        let data: Vec<Vec<f64>> = (0..50)
            .map(|i| vec![i as f64, 2.0 * i as f64 + 1.0, i as f64 * 0.5])
            .collect();
        let r = partial_correlation(&data, 0, 1, &[2]);
        assert!((-1.0..=1.0).contains(&r));
    }

    #[test]
    fn test_fisher_z_test_high_correlation() {
        let r = 0.0;
        let p = fisher_z_test(r, 100, 0);
        assert!(p > 0.05, "Should not reject independence for r=0");
    }

    #[test]
    fn test_causal_discovery_skeleton_independent() {
        let data: Vec<Vec<f64>> = (0..100)
            .map(|i| vec![(i as f64).sin(), (i as f64 * 2.3 + 1.0).cos()])
            .collect();
        let mut cd = CausalDiscovery::new(2, 0.01);
        // Smoke test: skeleton learning runs without panicking.
        cd.learn_skeleton(&data);
        assert!(cd.n_vars == 2);
    }

    #[test]
    fn test_subsets_correctness() {
        // C(3, 2) = 3.
        let v = vec![0, 1, 2];
        let subs = subsets(&v, 2);
        assert_eq!(subs.len(), 3);
    }

    #[test]
    fn test_subsets_empty() {
        // k = 0 yields exactly one subset: the empty one.
        let v = vec![0, 1];
        let subs = subsets(&v, 0);
        assert_eq!(subs.len(), 1);
        assert!(subs[0].is_empty());
    }

    #[test]
    fn test_counterfactual_simple_chain() {
        let mut scm = StructuralCausalModel::new(2);
        scm.add_edge(0, 1, 2.0);
        scm.noise_std = vec![1.0, 1.0];
        let query = CounterfactualQuery::new(scm);
        // Observed (1, 2) implies noise1 = 0; do(X0 = 3) gives X1 = 6.
        let obs = vec![Some(1.0), Some(2.0)];
        let cf = query.counterfactual(&obs, 0, 3.0, 1);
        assert!((cf - 6.0).abs() < 1e-10, "cf={cf}");
    }

    #[test]
    fn test_counterfactual_intercept() {
        let mut scm = StructuralCausalModel::new(2);
        scm.intercepts[1] = 5.0;
        scm.noise_std = vec![1.0, 1.0];
        let query = CounterfactualQuery::new(scm);
        // X1 has no parents, so intervening on X0 leaves it at its
        // abducted value 5 + noise1 = 7.
        let obs = vec![Some(0.0), Some(7.0)];
        let cf = query.counterfactual(&obs, 0, 10.0, 1);
        assert!((cf - 7.0).abs() < 1e-10, "cf={cf}");
    }

    #[test]
    fn test_sample_covariance_diagonal() {
        let data: Vec<Vec<f64>> = (0..100).map(|i| vec![i as f64, 0.0]).collect();
        let cov = sample_covariance(&data);
        // Var(X0) > 0; Var(X1) = 0 for the constant column.
        assert!(cov[0] > 0.0);
        assert!(cov[3].abs() < 1e-10);
    }

    #[test]
    fn test_sample_covariance_symmetric() {
        let data: Vec<Vec<f64>> = (0..50).map(|i| vec![i as f64, (i as f64).sin()]).collect();
        let cov = sample_covariance(&data);
        assert!((cov[1] - cov[2]).abs() < 1e-12);
    }

    #[test]
    fn test_full_scm_pipeline() {
        let mut scm = StructuralCausalModel::new(3);
        scm.add_edge(0, 1, 1.0);
        scm.add_edge(0, 2, 1.0);
        scm.add_edge(1, 2, 2.0);
        scm.noise_std = vec![1.0, 0.5, 0.5];
        // Deterministic pseudo-noise so the test is reproducible.
        let noise_samples: Vec<Vec<f64>> = (0..500)
            .map(|i| {
                let u = (i as f64 * 0.01).sin();
                let x = (i as f64 * 0.013).cos();
                let y = (i as f64 * 0.017).sin();
                vec![u, x, y]
            })
            .collect();
        // In a linear SCM the ACE difference equals the direct coefficient.
        let ace0 = scm.average_causal_effect(1, 0.0, 2, &noise_samples);
        let ace1 = scm.average_causal_effect(1, 1.0, 2, &noise_samples);
        let diff = ace1 - ace0;
        assert!((diff - 2.0).abs() < 0.1, "ACE diff = {diff}");
    }
}