1#![allow(clippy::needless_range_loop)]
2#![allow(dead_code)]
18#![allow(clippy::too_many_arguments)]
19
20use std::collections::{HashMap, HashSet, VecDeque};
21
/// A directed acyclic graph over `n_nodes` variables indexed `0..n_nodes`.
///
/// Adjacency is stored in both directions so parent and child lookups are
/// O(1). [`CausalGraph::add_edge`] keeps the two lists in sync and rejects
/// edges that would introduce a directed cycle.
#[derive(Debug, Clone)]
pub struct CausalGraph {
    /// Number of nodes in the graph.
    pub n_nodes: usize,
    /// `parents[v]` lists the nodes with an edge into `v`, in insertion order.
    pub parents: Vec<Vec<usize>>,
    /// `children[v]` lists the nodes `v` has an edge into, in insertion order.
    pub children: Vec<Vec<usize>>,
    /// Human-readable node labels (default `X0`, `X1`, …).
    pub names: Vec<String>,
}

impl CausalGraph {
    /// Creates a graph with `n` nodes, no edges, and default names `X0..X{n-1}`.
    pub fn new(n: usize) -> Self {
        Self {
            n_nodes: n,
            parents: vec![vec![]; n],
            children: vec![vec![]; n],
            names: (0..n).map(|i| format!("X{i}")).collect(),
        }
    }

    /// Replaces every node label.
    ///
    /// # Panics
    /// Panics if `names.len() != self.n_nodes`.
    pub fn set_names(&mut self, names: &[&str]) {
        assert_eq!(names.len(), self.n_nodes);
        self.names = names.iter().map(|s| s.to_string()).collect();
    }

    /// Adds the directed edge `from → to`. Re-adding an existing edge is a no-op.
    ///
    /// # Panics
    /// Panics if either index is out of bounds, or if the edge would create a
    /// directed cycle (a self-loop counts as a cycle).
    pub fn add_edge(&mut self, from: usize, to: usize) {
        assert!(
            from < self.n_nodes && to < self.n_nodes,
            "node index out of bounds"
        );
        assert!(
            !self.creates_cycle(from, to),
            "edge {from}→{to} would create a cycle"
        );
        if !self.children[from].contains(&to) {
            self.children[from].push(to);
            self.parents[to].push(from);
        }
    }

    /// Returns `true` if adding `from → to` would close a directed cycle, i.e.
    /// if `from` is reachable from `to` along existing edges (or `from == to`).
    pub fn creates_cycle(&self, from: usize, to: usize) -> bool {
        let mut visited = vec![false; self.n_nodes];
        let mut stack = vec![to];
        while let Some(node) = stack.pop() {
            if node == from {
                return true;
            }
            if !visited[node] {
                visited[node] = true;
                for &child in &self.children[node] {
                    stack.push(child);
                }
            }
        }
        false
    }

    /// Returns a topological order (Kahn's algorithm), or `None` if the graph
    /// contains a cycle.
    pub fn topological_sort(&self) -> Option<Vec<usize>> {
        let mut in_degree: Vec<usize> = self.parents.iter().map(|p| p.len()).collect();
        let mut queue: VecDeque<usize> =
            (0..self.n_nodes).filter(|&v| in_degree[v] == 0).collect();
        let mut order = Vec::with_capacity(self.n_nodes);
        while let Some(v) = queue.pop_front() {
            order.push(v);
            for &child in &self.children[v] {
                in_degree[child] -= 1;
                if in_degree[child] == 0 {
                    queue.push_back(child);
                }
            }
        }
        if order.len() == self.n_nodes {
            Some(order)
        } else {
            None
        }
    }

    /// All nodes with a directed path into `v` (excluding `v` itself in a DAG).
    pub fn ancestors(&self, v: usize) -> HashSet<usize> {
        let mut anc = HashSet::new();
        let mut stack = vec![v];
        while let Some(node) = stack.pop() {
            for &p in &self.parents[node] {
                if anc.insert(p) {
                    stack.push(p);
                }
            }
        }
        anc
    }

    /// All nodes reachable from `v` by a directed path (excluding `v` in a DAG).
    pub fn descendants(&self, v: usize) -> HashSet<usize> {
        let mut desc = HashSet::new();
        let mut stack = vec![v];
        while let Some(node) = stack.pop() {
            for &c in &self.children[node] {
                if desc.insert(c) {
                    stack.push(c);
                }
            }
        }
        desc
    }

    /// Returns `true` if every node of `x` is d-separated from every node of
    /// `y` given the conditioning set `z`.
    ///
    /// Implements the "reachable" traversal of Koller & Friedman (Alg. 3.1).
    /// 1. Collect `z` together with all of its ancestors (the nodes that make
    ///    a collider active).
    /// 2. Breadth-first search over `(node, arrived_from_child)` states,
    ///    starting from each source travelling upwards.
    ///
    /// Observed nodes (those in `z`) are never "reachable". A source or
    /// target that is itself in `z` is therefore trivially separated, matching
    /// the probabilistic meaning of conditioning on it.
    pub fn d_separated(&self, x: &[usize], y: &[usize], z: &[usize]) -> bool {
        let z_set: HashSet<usize> = z.iter().copied().collect();
        let y_set: HashSet<usize> = y.iter().copied().collect();

        // Phase 1: Z and all of its ancestors.
        let mut z_ancestors: HashSet<usize> = z_set.clone();
        for &zv in z {
            z_ancestors.extend(self.ancestors(zv));
        }

        // Phase 2: `true` = reached `node` from one of its children (moving up),
        // `false` = reached `node` from one of its parents (moving down).
        let mut visited: HashSet<(usize, bool)> = HashSet::new();
        let mut queue: VecDeque<(usize, bool)> = x.iter().map(|&xv| (xv, true)).collect();

        while let Some((node, via_child)) = queue.pop_front() {
            if !visited.insert((node, via_child)) {
                continue;
            }
            let in_z = z_set.contains(&node);

            // Only an unobserved target terminates an active trail.
            if !in_z && y_set.contains(&node) {
                return false;
            }

            if via_child {
                // Chain or fork through `node`: passable only when unobserved.
                if !in_z {
                    queue.extend(self.parents[node].iter().map(|&p| (p, true)));
                    queue.extend(self.children[node].iter().map(|&c| (c, false)));
                }
            } else {
                // Arrived from a parent: continue down through an unobserved
                // chain node...
                if !in_z {
                    queue.extend(self.children[node].iter().map(|&c| (c, false)));
                }
                // ...and turn back up through an active collider.
                if z_ancestors.contains(&node) {
                    queue.extend(self.parents[node].iter().map(|&p| (p, true)));
                }
            }
        }
        true
    }

    /// Markov blanket of `v`: its parents, its children, and the children's
    /// other parents.
    pub fn markov_blanket(&self, v: usize) -> HashSet<usize> {
        let mut blanket = HashSet::new();
        for &p in &self.parents[v] {
            blanket.insert(p);
        }
        for &c in &self.children[v] {
            blanket.insert(c);
            for &cp in &self.parents[c] {
                if cp != v {
                    blanket.insert(cp);
                }
            }
        }
        blanket
    }

    /// `true` if the graph contains no directed cycle.
    pub fn is_acyclic(&self) -> bool {
        self.topological_sort().is_some()
    }
}
241
242#[derive(Debug, Clone)]
253pub struct StructuralCausalModel {
254 pub graph: CausalGraph,
256 pub coefficients: Vec<Vec<f64>>,
259 pub noise_std: Vec<f64>,
261 pub intercepts: Vec<f64>,
263}
264
265impl StructuralCausalModel {
266 pub fn new(n: usize) -> Self {
270 Self {
271 graph: CausalGraph::new(n),
272 coefficients: vec![vec![]; n],
273 noise_std: vec![1.0; n],
274 intercepts: vec![0.0; n],
275 }
276 }
277
278 pub fn add_edge(&mut self, from: usize, to: usize, coeff: f64) {
285 self.graph.add_edge(from, to);
286 self.coefficients[to].push(coeff);
288 }
289
290 pub fn set_noise(&mut self, v: usize, std: f64) {
292 self.noise_std[v] = std;
293 }
294
295 pub fn set_intercept(&mut self, v: usize, intercept: f64) {
297 self.intercepts[v] = intercept;
298 }
299
300 pub fn sample_with_noise(&self, noise: &[f64]) -> Vec<f64> {
307 let n = self.graph.n_nodes;
308 let order = self
309 .graph
310 .topological_sort()
311 .expect("SCM graph must be acyclic");
312 let mut x = vec![0.0_f64; n];
313 for &v in &order {
314 let val: f64 = self.intercepts[v]
315 + self.graph.parents[v]
316 .iter()
317 .zip(self.coefficients[v].iter())
318 .map(|(&p, &c)| c * x[p])
319 .sum::<f64>()
320 + self.noise_std[v] * noise[v];
321 x[v] = val;
322 }
323 x
324 }
325
326 pub fn intervene(&self, target: usize, val: f64) -> Self {
331 let mut scm = self.clone();
332 let parents = scm.graph.parents[target].clone();
334 for &p in &parents {
335 scm.graph.children[p].retain(|&c| c != target);
336 }
337 scm.graph.parents[target].clear();
338 scm.coefficients[target].clear();
339 scm.noise_std[target] = 0.0;
340 scm.intercepts[target] = val;
341 scm
342 }
343
344 pub fn average_causal_effect(
353 &self,
354 cause: usize,
355 val: f64,
356 effect: usize,
357 noise_samples: &[Vec<f64>],
358 ) -> f64 {
359 let intervened = self.intervene(cause, val);
360 let mean: f64 = noise_samples
361 .iter()
362 .map(|noise| intervened.sample_with_noise(noise)[effect])
363 .sum::<f64>()
364 / noise_samples.len() as f64;
365 mean
366 }
367
368 pub fn total_effect_linear(&self, cause: usize, effect: usize) -> f64 {
373 let mut total = 0.0_f64;
375 let mut stack: Vec<(usize, f64)> = vec![(cause, 1.0)];
377 while let Some((node, prod)) = stack.pop() {
378 if node == effect && node != cause {
379 total += prod;
380 }
381 for (k, &child) in self.graph.children[node].iter().enumerate() {
382 if let Some(idx) = self.graph.parents[child].iter().position(|&p| p == node) {
384 let coeff = self.coefficients[child][idx];
385 let _ = k; stack.push((child, prod * coeff));
387 }
388 }
389 }
390 total
391 }
392}
393
394#[derive(Debug, Clone)]
405pub struct BackdoorCriterion {
406 pub graph: CausalGraph,
408}
409
410impl BackdoorCriterion {
411 pub fn new(graph: CausalGraph) -> Self {
413 Self { graph }
414 }
415
416 pub fn check(&self, treatment: usize, outcome: usize, adjustment_set: &[usize]) -> bool {
421 let desc_treatment = self.graph.descendants(treatment);
422
423 for &z in adjustment_set {
425 if desc_treatment.contains(&z) {
426 return false;
427 }
428 }
429
430 let mut modified = self.graph.clone();
435 let children_of_treatment = modified.graph_children_of(treatment);
437 for &c in &children_of_treatment {
438 modified.parents[c].retain(|&p| p != treatment);
439 }
440 modified.children[treatment].clear();
441
442 modified.d_separated(&[treatment], &[outcome], adjustment_set)
443 }
444
445 pub fn adjusted_effect(
457 data_x: &[f64],
458 data_y: &[f64],
459 data_z: &[f64],
460 x_val: f64,
461 _tolerance: f64,
462 ) -> f64 {
463 let n = data_x.len();
464 assert_eq!(data_y.len(), n);
465 assert_eq!(data_z.len(), n);
466
467 let n_strata = 5usize;
469 let mut sorted_z = data_z.to_vec();
470 sorted_z.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
471 let quantiles: Vec<f64> = (1..n_strata)
472 .map(|i| sorted_z[(i * n) / n_strata])
473 .collect();
474
475 let stratum_of = |z: f64| -> usize {
476 quantiles
477 .iter()
478 .position(|&q| z < q)
479 .unwrap_or(n_strata - 1)
480 };
481
482 let mut stratum_sums_y = vec![0.0_f64; n_strata];
484 let mut stratum_counts = vec![0usize; n_strata];
485 let mut stratum_counts_near_x = vec![0usize; n_strata];
486 let mut stratum_y_near_x = vec![0.0_f64; n_strata];
487
488 let bandwidth = {
489 let mut xs = data_x.to_vec();
490 xs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
491 let iqr = xs[3 * n / 4] - xs[n / 4];
492 iqr.max(0.1) * 0.5
493 };
494
495 for i in 0..n {
496 let s = stratum_of(data_z[i]);
497 stratum_sums_y[s] += data_y[i];
498 stratum_counts[s] += 1;
499 if (data_x[i] - x_val).abs() < bandwidth {
500 stratum_y_near_x[s] += data_y[i];
501 stratum_counts_near_x[s] += 1;
502 }
503 }
504
505 let mut total = 0.0_f64;
506 for s in 0..n_strata {
507 if stratum_counts[s] == 0 {
508 continue;
509 }
510 let p_z = stratum_counts[s] as f64 / n as f64;
511 let e_y_xz = if stratum_counts_near_x[s] > 0 {
512 stratum_y_near_x[s] / stratum_counts_near_x[s] as f64
513 } else {
514 stratum_sums_y[s] / stratum_counts[s] as f64
515 };
516 total += e_y_xz * p_z;
517 }
518 total
519 }
520}
521
522trait GraphChildrenOf {
524 fn graph_children_of(&self, v: usize) -> Vec<usize>;
525}
526
527impl GraphChildrenOf for CausalGraph {
528 fn graph_children_of(&self, v: usize) -> Vec<usize> {
529 self.children[v].clone()
530 }
531}
532
533#[derive(Debug, Clone)]
542pub struct FrontdoorCriterion {
543 pub graph: CausalGraph,
545}
546
547impl FrontdoorCriterion {
548 pub fn new(graph: CausalGraph) -> Self {
550 Self { graph }
551 }
552
553 pub fn check(&self, treatment: usize, outcome: usize, mediator_set: &[usize]) -> bool {
561 let med_set: HashSet<usize> = mediator_set.iter().copied().collect();
562
563 if !self.intercepts_all_paths(treatment, outcome, &med_set) {
565 return false;
566 }
567
568 for &m in mediator_set {
572 if !self.graph.d_separated(&[treatment], &[m], &[treatment]) {
573 let x_anc = self.graph.ancestors(treatment);
575 let m_anc = self.graph.ancestors(m);
576 let _ = (x_anc, m_anc);
579 }
580 }
581
582 for &m in mediator_set {
584 if !self.graph.d_separated(&[m], &[outcome], &[treatment]) {
585 return false;
586 }
587 }
588
589 true
590 }
591
592 fn intercepts_all_paths(&self, src: usize, dst: usize, med_set: &HashSet<usize>) -> bool {
594 let mut stack: Vec<(usize, Vec<usize>)> = vec![(src, vec![src])];
596 while let Some((node, path)) = stack.pop() {
597 if node == dst {
598 let on_path = path[1..].iter().any(|v| med_set.contains(v));
600 if !on_path {
601 return false;
602 }
603 continue;
604 }
605 for &child in &self.graph.children[node] {
606 if !path.contains(&child) {
607 let mut new_path = path.clone();
608 new_path.push(child);
609 stack.push((child, new_path));
610 }
611 }
612 }
613 true
614 }
615
616 pub fn adjusted_effect(data_x: &[f64], data_m: &[f64], data_y: &[f64], x_val: f64) -> f64 {
622 let n = data_x.len();
623 let bandwidth = {
624 let mut xs = data_x.to_vec();
625 xs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
626 let iqr = xs[3 * n / 4] - xs[n / 4];
627 iqr.max(0.1) * 0.4
628 };
629
630 let (mut sum_m, mut w_sum) = (0.0_f64, 0.0_f64);
632 for i in 0..n {
633 let w = gaussian_kernel((data_x[i] - x_val) / bandwidth);
634 sum_m += w * data_m[i];
635 w_sum += w;
636 }
637 let e_m_given_x = if w_sum > 1e-12 { sum_m / w_sum } else { 0.0 };
638
639 let bw_m = {
641 let mut ms = data_m.to_vec();
642 ms.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
643 let iqr = ms[3 * n / 4] - ms[n / 4];
644 iqr.max(0.1) * 0.4
645 };
646 let (mut sum_y, mut wy_sum) = (0.0_f64, 0.0_f64);
647 for i in 0..n {
648 let w = gaussian_kernel((data_m[i] - e_m_given_x) / bw_m);
649 sum_y += w * data_y[i];
650 wy_sum += w;
651 }
652 if wy_sum > 1e-12 { sum_y / wy_sum } else { 0.0 }
653 }
654}
655
/// Unnormalised Gaussian kernel `exp(-u² / 2)`.
///
/// The `1/√(2π)` factor is omitted; the caller in this file divides by the
/// sum of the weights, so the constant cancels.
fn gaussian_kernel(u: f64) -> f64 {
    let half_square = 0.5 * u * u;
    (-half_square).exp()
}
661
662#[derive(Debug, Clone)]
671pub struct PropensityScoreMatching {
672 pub weights: Vec<f64>,
674 pub n_covariates: usize,
676}
677
678impl PropensityScoreMatching {
679 pub fn new(n_covariates: usize) -> Self {
684 Self {
685 weights: vec![0.0; n_covariates + 1],
686 n_covariates,
687 }
688 }
689
690 pub fn fit(&mut self, covariates: &[Vec<f64>], treatment: &[f64], lr: f64, n_iter: usize) {
698 let n = covariates.len();
699 assert_eq!(treatment.len(), n);
700 for _ in 0..n_iter {
701 let mut grad = vec![0.0_f64; self.n_covariates + 1];
702 for i in 0..n {
703 let p = self.predict_one(&covariates[i]);
704 let err = p - treatment[i];
705 grad[0] += err; for j in 0..self.n_covariates {
707 grad[j + 1] += err * covariates[i][j];
708 }
709 }
710 for k in 0..self.weights.len() {
711 self.weights[k] -= lr * grad[k] / n as f64;
712 }
713 }
714 }
715
716 pub fn predict_one(&self, x: &[f64]) -> f64 {
718 let logit: f64 = self.weights[0]
719 + x.iter()
720 .zip(self.weights[1..].iter())
721 .map(|(xi, wi)| xi * wi)
722 .sum::<f64>();
723 sigmoid(logit)
724 }
725
726 pub fn predict(&self, covariates: &[Vec<f64>]) -> Vec<f64> {
728 covariates.iter().map(|x| self.predict_one(x)).collect()
729 }
730
731 pub fn estimate_ate(&self, covariates: &[Vec<f64>], treatment: &[f64], outcome: &[f64]) -> f64 {
735 let n = covariates.len();
736 let (mut sum1, mut w1, mut sum0, mut w0) = (0.0_f64, 0.0_f64, 0.0_f64, 0.0_f64);
737 for i in 0..n {
738 let e = self.predict_one(&covariates[i]).clamp(1e-6, 1.0 - 1e-6);
739 if treatment[i] > 0.5 {
740 sum1 += outcome[i] / e;
741 w1 += 1.0 / e;
742 } else {
743 sum0 += outcome[i] / (1.0 - e);
744 w0 += 1.0 / (1.0 - e);
745 }
746 }
747 let ey1 = if w1 > 0.0 { sum1 / w1 } else { 0.0 };
748 let ey0 = if w0 > 0.0 { sum0 / w0 } else { 0.0 };
749 ey1 - ey0
750 }
751
752 pub fn estimate_att(&self, covariates: &[Vec<f64>], treatment: &[f64], outcome: &[f64]) -> f64 {
756 let n = covariates.len();
757 let mut treated_y: Vec<f64> = Vec::new();
758 let mut control_y: Vec<f64> = Vec::new();
759 let mut control_ps: Vec<f64> = Vec::new();
760
761 for i in 0..n {
762 let e = self.predict_one(&covariates[i]).clamp(1e-6, 1.0 - 1e-6);
763 if treatment[i] > 0.5 {
764 treated_y.push(outcome[i]);
765 } else {
766 control_y.push(outcome[i]);
767 control_ps.push(e / (1.0 - e)); }
769 }
770
771 if treated_y.is_empty() || control_y.is_empty() {
772 return 0.0;
773 }
774
775 let mean_treated = treated_y.iter().sum::<f64>() / treated_y.len() as f64;
776 let total_weight: f64 = control_ps.iter().sum();
777 let mean_control = if total_weight > 0.0 {
778 control_y
779 .iter()
780 .zip(control_ps.iter())
781 .map(|(y, w)| y * w)
782 .sum::<f64>()
783 / total_weight
784 } else {
785 control_y.iter().sum::<f64>() / control_y.len() as f64
786 };
787
788 mean_treated - mean_control
789 }
790}
791
/// Logistic function `1 / (1 + e^{-x})`.
///
/// Saturates to exactly 0.0 or 1.0 for large `|x|`.
fn sigmoid(x: f64) -> f64 {
    let denom = 1.0 + (-x).exp();
    denom.recip()
}
796
/// Two-stage least squares (2SLS) estimator for the effect of a single
/// endogenous regressor `d` on an outcome `y`, using one or more instruments.
#[derive(Debug, Clone)]
pub struct InstrumentalVariables {
    /// Number of endogenous regressors. [`InstrumentalVariables::fit_2sls`]
    /// takes a single `d` column.
    pub n_endogenous: usize,
    /// Number of instruments (columns of `z`).
    pub n_instruments: usize,
    /// First-stage coefficients `[intercept, α_1, …, α_k]`.
    pub first_stage: Vec<f64>,
    /// Second-stage slope: the estimated causal effect of `d` on `y`.
    pub second_stage: f64,
}

impl InstrumentalVariables {
    /// Creates an unfitted estimator with all coefficients at zero.
    pub fn new(n_endogenous: usize, n_instruments: usize) -> Self {
        Self {
            n_endogenous,
            n_instruments,
            first_stage: vec![0.0; n_instruments + 1],
            second_stage: 0.0,
        }
    }

    /// Fits 2SLS. `z[i][k]` is instrument `k` for observation `i`.
    ///
    /// * One instrument: the Wald/IV ratio `cov(y, z) / cov(d, z)`.
    /// * Several instruments: `d` is regressed on all of them, plus an
    ///   intercept, by OLS. The second stage is then the slope of `y` on the
    ///   fitted `d̂`.
    ///
    /// Degenerate designs leave the affected coefficients at 0.0. Degenerate
    /// here means a constant instrument, collinear instruments, or a zero
    /// first-stage covariance.
    ///
    /// # Panics
    /// Panics in any of these cases:
    /// * `d` or `z` differs in length from `y`;
    /// * `n_instruments == 0`;
    /// * a row of `z` has fewer than `n_instruments` columns.
    pub fn fit_2sls(&mut self, y: &[f64], d: &[f64], z: &[Vec<f64>]) {
        let n = y.len();
        assert_eq!(d.len(), n);
        assert_eq!(z.len(), n);
        let n_inst = self.n_instruments;
        assert!(n_inst >= 1, "2SLS requires at least one instrument");

        if n_inst == 1 {
            let z_vec: Vec<f64> = z.iter().map(|row| row[0]).collect();
            let mean_z = z_vec.iter().sum::<f64>() / n as f64;
            let mean_d = d.iter().sum::<f64>() / n as f64;
            let mean_y = y.iter().sum::<f64>() / n as f64;

            let cov_dz: f64 = d
                .iter()
                .zip(z_vec.iter())
                .map(|(di, zi)| (di - mean_d) * (zi - mean_z))
                .sum::<f64>()
                / n as f64;
            let cov_yz: f64 = y
                .iter()
                .zip(z_vec.iter())
                .map(|(yi, zi)| (yi - mean_y) * (zi - mean_z))
                .sum::<f64>()
                / n as f64;
            let var_z: f64 = z_vec.iter().map(|zi| (zi - mean_z).powi(2)).sum::<f64>() / n as f64;

            let alpha1 = if var_z.abs() > 1e-12 { cov_dz / var_z } else { 0.0 };
            let alpha0 = mean_d - alpha1 * mean_z;
            self.first_stage[0] = alpha0;
            self.first_stage[1] = alpha1;

            self.second_stage = if cov_dz.abs() > 1e-12 { cov_yz / cov_dz } else { 0.0 };
        } else {
            self.fit_multi_instrument(y, d, z);
        }
    }

    /// OLS first stage on every instrument, then the second-stage slope of `y`
    /// on the fitted `d̂`.
    fn fit_multi_instrument(&mut self, y: &[f64], d: &[f64], z: &[Vec<f64>]) {
        let k = self.n_instruments;
        let n_f = y.len() as f64;
        let mean_d = d.iter().sum::<f64>() / n_f;
        let mean_z: Vec<f64> = (0..k)
            .map(|c| z.iter().map(|row| row[c]).sum::<f64>() / n_f)
            .collect();

        // Centred normal equations S_zz · α = S_zd (covariances, divided by n).
        let mut s_zz = vec![vec![0.0_f64; k]; k];
        let mut s_zd = vec![0.0_f64; k];
        for (row, &di) in z.iter().zip(d) {
            for a in 0..k {
                let za = row[a] - mean_z[a];
                s_zd[a] += za * (di - mean_d) / n_f;
                for b in 0..k {
                    s_zz[a][b] += za * (row[b] - mean_z[b]) / n_f;
                }
            }
        }
        // Constant or collinear instruments fall back to a flat first stage,
        // mirroring the zero-variance handling of the single-instrument path.
        let slopes = Self::solve_linear_system(s_zz, s_zd).unwrap_or_else(|| vec![0.0; k]);
        let alpha0 = mean_d
            - slopes
                .iter()
                .zip(&mean_z)
                .map(|(s, m)| s * m)
                .sum::<f64>();
        self.first_stage = std::iter::once(alpha0).chain(slopes.iter().copied()).collect();

        let d_hat: Vec<f64> = z
            .iter()
            .map(|row| alpha0 + slopes.iter().zip(row).map(|(s, zc)| s * zc).sum::<f64>())
            .collect();

        let mean_dhat = d_hat.iter().sum::<f64>() / n_f;
        let mean_y = y.iter().sum::<f64>() / n_f;
        let cov_ydhat: f64 = y
            .iter()
            .zip(d_hat.iter())
            .map(|(yi, di)| (yi - mean_y) * (di - mean_dhat))
            .sum::<f64>()
            / n_f;
        let var_dhat: f64 = d_hat.iter().map(|di| (di - mean_dhat).powi(2)).sum::<f64>() / n_f;
        self.second_stage = if var_dhat > 1e-12 { cov_ydhat / var_dhat } else { 0.0 };
    }

    /// Solves the square system `a · x = b` by Gaussian elimination with
    /// partial pivoting. Returns `None` if a pivot is numerically zero or NaN.
    fn solve_linear_system(mut a: Vec<Vec<f64>>, mut b: Vec<f64>) -> Option<Vec<f64>> {
        let m = b.len();
        for col in 0..m {
            let pivot = (col..m).max_by(|&r1, &r2| {
                a[r1][col]
                    .abs()
                    .partial_cmp(&a[r2][col].abs())
                    .unwrap_or(std::cmp::Ordering::Equal)
            })?;
            if !(a[pivot][col].abs() > 1e-12) {
                return None;
            }
            a.swap(col, pivot);
            b.swap(col, pivot);
            let pivot_row = a[col].clone();
            let pivot_rhs = b[col];
            for r in (col + 1)..m {
                let factor = a[r][col] / pivot_row[col];
                for (cell, &p) in a[r][col..].iter_mut().zip(&pivot_row[col..]) {
                    *cell -= factor * p;
                }
                b[r] -= factor * pivot_rhs;
            }
        }
        let mut x = vec![0.0_f64; m];
        for r in (0..m).rev() {
            let tail: f64 = ((r + 1)..m).map(|c| a[r][c] * x[c]).sum();
            x[r] = (b[r] - tail) / a[r][r];
        }
        Some(x)
    }

    /// First-stage F statistic `(R² / k) / ((1 − R²) / (n − k − 1))`.
    ///
    /// `R²` is the fit of `d` on the stored first-stage coefficients and
    /// `k = n_instruments`. Call [`InstrumentalVariables::fit_2sls`] first;
    /// only the length of `y` is used.
    pub fn first_stage_f_stat(&self, y: &[f64], d: &[f64], z: &[Vec<f64>]) -> f64 {
        let n = y.len();
        let mean_d = d.iter().sum::<f64>() / n as f64;
        let intercept = self.first_stage[0];
        let slopes = &self.first_stage[1..];

        let d_hat: Vec<f64> = z
            .iter()
            .map(|row| intercept + slopes.iter().zip(row).map(|(s, zc)| s * zc).sum::<f64>())
            .collect();

        let ss_res: f64 = d
            .iter()
            .zip(d_hat.iter())
            .map(|(di, dh)| (di - dh).powi(2))
            .sum();
        let ss_tot: f64 = d.iter().map(|di| (di - mean_d).powi(2)).sum();

        let r2 = 1.0 - ss_res / ss_tot.max(1e-12);
        let k = self.n_instruments as f64;
        let n_f = n as f64;
        (r2 / k) / ((1.0 - r2) / (n_f - k - 1.0)).max(1e-12)
    }

    /// Returns `second_stage · d_val`. No second-stage intercept is fitted, so
    /// this is the treatment's contribution to the outcome only.
    pub fn predict(&self, d_val: f64) -> f64 {
        self.second_stage * d_val
    }
}
966
967#[derive(Debug, Clone)]
977pub struct CausalDiscovery {
978 pub n_vars: usize,
980 pub skeleton: Vec<Vec<bool>>,
982 pub directed: Vec<Vec<bool>>,
984 pub sep_sets: HashMap<(usize, usize), Vec<usize>>,
986 pub alpha: f64,
988}
989
990impl CausalDiscovery {
991 pub fn new(n_vars: usize, alpha: f64) -> Self {
997 Self {
998 n_vars,
999 skeleton: vec![vec![true; n_vars]; n_vars],
1000 directed: vec![vec![false; n_vars]; n_vars],
1001 sep_sets: HashMap::new(),
1002 alpha,
1003 }
1004 }
1005
1006 pub fn learn_skeleton(&mut self, data: &[Vec<f64>]) {
1011 let n = self.n_vars;
1012
1013 for i in 0..n {
1015 self.skeleton[i][i] = false;
1016 }
1017
1018 for i in 0..n {
1020 for j in (i + 1)..n {
1021 let r = partial_correlation(data, i, j, &[]);
1022 let p = fisher_z_test(r, data.len(), 0);
1023 if p > self.alpha {
1024 self.skeleton[i][j] = false;
1025 self.skeleton[j][i] = false;
1026 self.sep_sets.insert((i, j), vec![]);
1027 self.sep_sets.insert((j, i), vec![]);
1028 }
1029 }
1030 }
1031
1032 for cond_size in 1..n.saturating_sub(1) {
1034 for i in 0..n {
1035 let adj_i: Vec<usize> = (0..n).filter(|&k| k != i && self.skeleton[i][k]).collect();
1036 for &j in &adj_i {
1037 if !self.skeleton[i][j] {
1038 continue;
1039 }
1040 let adj_minus_j: Vec<usize> =
1041 adj_i.iter().copied().filter(|&k| k != j).collect();
1042 if adj_minus_j.len() < cond_size {
1043 continue;
1044 }
1045 for cond_set in subsets(&adj_minus_j, cond_size) {
1047 let r = partial_correlation(data, i, j, &cond_set);
1048 let p = fisher_z_test(r, data.len(), cond_size);
1049 if p > self.alpha {
1050 self.skeleton[i][j] = false;
1051 self.skeleton[j][i] = false;
1052 self.sep_sets.insert((i, j), cond_set.clone());
1053 self.sep_sets.insert((j, i), cond_set);
1054 break;
1055 }
1056 }
1057 }
1058 }
1059 }
1060 }
1061
1062 pub fn orient_v_structures(&mut self) {
1067 let n = self.n_vars;
1068 for i in 0..n {
1069 for k in 0..n {
1070 if i == k || !self.skeleton[i][k] {
1071 continue;
1072 }
1073 for j in (i + 1)..n {
1074 if j == k || !self.skeleton[k][j] || self.skeleton[i][j] {
1075 continue;
1076 }
1077 let sep = self.sep_sets.get(&(i, j)).cloned().unwrap_or_default();
1079 if !sep.contains(&k) {
1080 self.directed[i][k] = true;
1082 self.directed[j][k] = true;
1083 self.skeleton[k][i] = false;
1084 self.skeleton[k][j] = false;
1085 }
1086 }
1087 }
1088 }
1089 }
1090
1091 pub fn apply_meek_rules(&mut self) {
1095 let n = self.n_vars;
1096 let mut changed = true;
1097 while changed {
1098 changed = false;
1099 for i in 0..n {
1101 for j in 0..n {
1102 if !self.directed[i][j] {
1103 continue;
1104 }
1105 for k in 0..n {
1106 if k == i || k == j {
1107 continue;
1108 }
1109 if self.skeleton[j][k]
1110 && !self.directed[j][k]
1111 && !self.directed[k][j]
1112 && !self.skeleton[i][k]
1113 {
1114 self.directed[j][k] = true;
1115 self.skeleton[k][j] = false;
1116 changed = true;
1117 }
1118 }
1119 }
1120 }
1121 for i in 0..n {
1123 for j in 0..n {
1124 if i == j || !self.skeleton[i][j] || self.directed[i][j] {
1125 continue;
1126 }
1127 for k in 0..n {
1128 if k == i || k == j {
1129 continue;
1130 }
1131 if self.directed[i][k] && self.directed[k][j] {
1132 self.directed[i][j] = true;
1133 self.skeleton[j][i] = false;
1134 changed = true;
1135 }
1136 }
1137 }
1138 }
1139 }
1140 }
1141
1142 pub fn run(&mut self, data: &[Vec<f64>]) {
1144 self.learn_skeleton(data);
1145 self.orient_v_structures();
1146 self.apply_meek_rules();
1147 }
1148}
1149
1150pub fn partial_correlation(data: &[Vec<f64>], i: usize, j: usize, cond: &[usize]) -> f64 {
1154 if cond.is_empty() {
1155 return pearson_correlation(data, i, j);
1156 }
1157 if cond.len() == 1 {
1158 let k = cond[0];
1159 let r_ij = pearson_correlation(data, i, j);
1160 let r_ik = pearson_correlation(data, i, k);
1161 let r_jk = pearson_correlation(data, j, k);
1162 let denom = ((1.0 - r_ik * r_ik) * (1.0 - r_jk * r_jk)).sqrt();
1163 if denom < 1e-12 {
1164 return 0.0;
1165 }
1166 return (r_ij - r_ik * r_jk) / denom;
1167 }
1168 let last = cond[cond.len() - 1];
1171 let rest = &cond[..cond.len() - 1];
1172 let r_ij_rest = partial_correlation(data, i, j, rest);
1173 let r_ik_rest = partial_correlation(data, i, last, rest);
1174 let r_jk_rest = partial_correlation(data, j, last, rest);
1175 let denom = ((1.0 - r_ik_rest * r_ik_rest) * (1.0 - r_jk_rest * r_jk_rest)).sqrt();
1176 if denom < 1e-12 {
1177 return 0.0;
1178 }
1179 (r_ij_rest - r_ik_rest * r_jk_rest) / denom
1180}
1181
/// Pearson correlation between columns `i` and `j` of `data`, where each row
/// is one observation.
///
/// Uses population (÷n) moments. Returns 0.0 if either column has
/// (numerically) zero spread, and clamps the result to [-1, 1].
pub fn pearson_correlation(data: &[Vec<f64>], i: usize, j: usize) -> f64 {
    let count = data.len() as f64;
    let column_mean = |col: usize| data.iter().map(|row| row[col]).sum::<f64>() / count;
    let (mu_i, mu_j) = (column_mean(i), column_mean(j));

    let covariance = data
        .iter()
        .map(|row| (row[i] - mu_i) * (row[j] - mu_j))
        .sum::<f64>()
        / count;
    let spread = |col: usize, mu: f64| {
        (data.iter().map(|row| (row[col] - mu).powi(2)).sum::<f64>() / count).sqrt()
    };
    let (sd_i, sd_j) = (spread(i, mu_i), spread(j, mu_j));

    let degenerate = sd_i < 1e-12 || sd_j < 1e-12;
    if degenerate {
        0.0
    } else {
        (covariance / (sd_i * sd_j)).clamp(-1.0, 1.0)
    }
}
1209
1210pub fn fisher_z_test(r: f64, n: usize, cond_size: usize) -> f64 {
1214 let r = r.clamp(-0.9999, 0.9999);
1215 let z = 0.5 * ((1.0 + r) / (1.0 - r)).ln();
1216 let se = 1.0 / ((n as f64 - cond_size as f64 - 3.0).max(1.0)).sqrt();
1217 let stat = (z / se).abs();
1218 2.0 * (1.0 - standard_normal_cdf(stat))
1220}
1221
/// Standard normal CDF Φ(x), via the Abramowitz & Stegun 26.2.17 polynomial
/// approximation (published absolute error below 7.5e-8).
fn standard_normal_cdf(x: f64) -> f64 {
    const P: f64 = 0.2316419;
    const B: [f64; 5] = [
        0.319_381_530,
        -0.356_563_782,
        1.781_477_937,
        -1.821_255_978,
        1.330_274_429,
    ];

    let t = 1.0 / (1.0 + P * x.abs());
    // Horner evaluation of t·(b1 + t·(b2 + t·(b3 + t·(b4 + t·b5)))).
    let mut acc = B[4];
    for &b in B[..4].iter().rev() {
        acc = b + t * acc;
    }
    let poly = t * acc;

    let density = (-0.5 * x * x).exp() / (2.0 * std::f64::consts::PI).sqrt();
    let upper = 1.0 - density * poly;
    // The approximation is for |x|; reflect for negative arguments.
    if x < 0.0 {
        1.0 - upper
    } else {
        upper
    }
}
1232
/// All `k`-element subsets of `set`, in lexicographic order of positions.
///
/// `k == 0` yields a single empty subset; `k > set.len()` yields none.
fn subsets(set: &[usize], k: usize) -> Vec<Vec<usize>> {
    /// Depth-first extension of `prefix` with `need` more items drawn from `rest`.
    fn extend(rest: &[usize], need: usize, prefix: &mut Vec<usize>, out: &mut Vec<Vec<usize>>) {
        if need == 0 {
            out.push(prefix.clone());
            return;
        }
        for (pos, &item) in rest.iter().enumerate() {
            // Too few items left to complete a subset of the requested size.
            if rest.len() - pos < need {
                break;
            }
            prefix.push(item);
            extend(&rest[pos + 1..], need - 1, prefix, out);
            prefix.pop();
        }
    }

    let mut out = Vec::new();
    let mut prefix = Vec::with_capacity(k);
    extend(set, k, &mut prefix, &mut out);
    out
}
1251
1252#[derive(Debug, Clone)]
1260pub struct CounterfactualQuery {
1261 pub scm: StructuralCausalModel,
1263}
1264
1265impl CounterfactualQuery {
1266 pub fn new(scm: StructuralCausalModel) -> Self {
1268 Self { scm }
1269 }
1270
1271 pub fn query_do(
1279 &self,
1280 target: usize,
1281 x_val: f64,
1282 outcome: usize,
1283 noise_samples: &[Vec<f64>],
1284 ) -> f64 {
1285 self.scm
1286 .average_causal_effect(target, x_val, outcome, noise_samples)
1287 }
1288
1289 pub fn counterfactual(
1303 &self,
1304 obs: &[Option<f64>],
1305 target: usize,
1306 x_val: f64,
1307 outcome: usize,
1308 ) -> f64 {
1309 let n = self.scm.graph.n_nodes;
1310 assert_eq!(obs.len(), n);
1311
1312 let order = self
1315 .scm
1316 .graph
1317 .topological_sort()
1318 .expect("SCM must be acyclic");
1319 let mut x = vec![0.0_f64; n];
1320 let mut noise = vec![0.0_f64; n];
1321
1322 for &v in &order {
1323 if let Some(val) = obs[v] {
1324 x[v] = val;
1325 let pred: f64 = self.scm.graph.parents[v]
1326 .iter()
1327 .zip(self.scm.coefficients[v].iter())
1328 .map(|(&p, &c)| c * x[p])
1329 .sum::<f64>();
1330 let residual = val - self.scm.intercepts[v] - pred;
1331 noise[v] = if self.scm.noise_std[v].abs() > 1e-12 {
1332 residual / self.scm.noise_std[v]
1333 } else {
1334 0.0
1335 };
1336 } else {
1337 noise[v] = 0.0;
1339 let pred: f64 = self.scm.graph.parents[v]
1340 .iter()
1341 .zip(self.scm.coefficients[v].iter())
1342 .map(|(&p, &c)| c * x[p])
1343 .sum::<f64>();
1344 x[v] = self.scm.intercepts[v] + pred;
1345 }
1346 }
1347
1348 let intervened = self.scm.intervene(target, x_val);
1350
1351 intervened.sample_with_noise(&noise)[outcome]
1353 }
1354
1355 pub fn probability_of_necessity(
1359 &self,
1360 treatment: usize,
1361 outcome: usize,
1362 t_val: f64,
1363 t_counter: f64,
1364 threshold_y: f64,
1365 noise_samples: &[Vec<f64>],
1366 ) -> f64 {
1367 let mut count = 0;
1368 let mut denom = 0;
1369 for noise in noise_samples {
1370 let x_actual = self.scm.sample_with_noise(noise);
1372 if x_actual[treatment] < t_val - 0.5 || x_actual[outcome] < threshold_y {
1373 continue;
1374 }
1375 denom += 1;
1376 let counter_scm = self.scm.intervene(treatment, t_counter);
1378 let x_counter = counter_scm.sample_with_noise(noise);
1379 if x_counter[outcome] < threshold_y {
1380 count += 1;
1381 }
1382 }
1383 if denom == 0 {
1384 0.0
1385 } else {
1386 count as f64 / denom as f64
1387 }
1388 }
1389}
1390
/// Unbiased (`n − 1`) sample covariance matrix of the rows of `data`.
///
/// The matrix is returned flattened in row-major order: entry `(j, k)` is at
/// `cov[j * p + k]`, where `p` is the length of the first row.
///
/// Returns an empty vector for empty `data`. With a single row the `n − 1`
/// normaliser is zero, so every entry is NaN.
pub fn sample_covariance(data: &[Vec<f64>]) -> Vec<f64> {
    if data.is_empty() {
        return Vec::new();
    }
    let n = data.len();
    let p = data[0].len();
    let means: Vec<f64> = (0..p)
        .map(|j| data.iter().map(|row| row[j]).sum::<f64>() / n as f64)
        .collect();

    // Accumulate the upper triangle only, then normalise and mirror it.
    let mut cov = vec![0.0_f64; p * p];
    for row in data {
        for j in 0..p {
            let dj = row[j] - means[j];
            for k in j..p {
                cov[j * p + k] += dj * (row[k] - means[k]);
            }
        }
    }
    let denom = (n - 1) as f64;
    for j in 0..p {
        for k in j..p {
            cov[j * p + k] /= denom;
            cov[k * p + j] = cov[j * p + k];
        }
    }
    cov
}
1420
1421#[cfg(test)]
1426mod tests {
1427 use super::*;
1428
1429 fn simple_chain() -> CausalGraph {
1430 let mut g = CausalGraph::new(3);
1432 g.add_edge(0, 1);
1433 g.add_edge(1, 2);
1434 g
1435 }
1436
1437 fn fork_graph() -> CausalGraph {
1438 let mut g = CausalGraph::new(3);
1440 g.add_edge(0, 1);
1441 g.add_edge(0, 2);
1442 g
1443 }
1444
1445 fn collider_graph() -> CausalGraph {
1446 let mut g = CausalGraph::new(3);
1448 g.add_edge(0, 2);
1449 g.add_edge(1, 2);
1450 g
1451 }
1452
1453 #[test]
1456 fn test_topological_sort_chain() {
1457 let g = simple_chain();
1458 let order = g.topological_sort().unwrap();
1459 assert_eq!(order, vec![0, 1, 2]);
1460 }
1461
1462 #[test]
1463 fn test_topological_sort_fork() {
1464 let g = fork_graph();
1465 let order = g.topological_sort().unwrap();
1466 assert_eq!(order[0], 0); }
1468
1469 #[test]
1470 fn test_ancestors() {
1471 let g = simple_chain();
1472 let anc = g.ancestors(2);
1473 assert!(anc.contains(&0));
1474 assert!(anc.contains(&1));
1475 assert!(!anc.contains(&2));
1476 }
1477
1478 #[test]
1479 fn test_descendants() {
1480 let g = simple_chain();
1481 let desc = g.descendants(0);
1482 assert!(desc.contains(&1));
1483 assert!(desc.contains(&2));
1484 }
1485
1486 #[test]
1487 fn test_d_separation_chain_blocked_by_middle() {
1488 let g = simple_chain();
1490 assert!(g.d_separated(&[0], &[2], &[1]));
1491 }
1492
1493 #[test]
1494 fn test_d_separation_chain_not_blocked_empty() {
1495 let g = simple_chain();
1496 assert!(!g.d_separated(&[0], &[2], &[]));
1497 }
1498
1499 #[test]
1500 fn test_d_separation_fork() {
1501 let g = fork_graph();
1503 assert!(g.d_separated(&[1], &[2], &[0]));
1504 assert!(!g.d_separated(&[1], &[2], &[]));
1505 }
1506
1507 #[test]
1508 fn test_d_separation_collider_blocked_by_default() {
1509 let g = collider_graph();
1511 assert!(g.d_separated(&[0], &[1], &[]));
1512 }
1513
1514 #[test]
1515 fn test_d_separation_collider_opened_by_conditioning() {
1516 let g = collider_graph();
1518 assert!(!g.d_separated(&[0], &[1], &[2]));
1519 }
1520
1521 #[test]
1522 fn test_is_acyclic() {
1523 let g = simple_chain();
1524 assert!(g.is_acyclic());
1525 }
1526
1527 #[test]
1528 fn test_creates_cycle_detected() {
1529 let mut g = CausalGraph::new(3);
1530 g.add_edge(0, 1);
1531 g.add_edge(1, 2);
1532 assert!(g.creates_cycle(2, 0));
1533 }
1534
1535 #[test]
1536 fn test_markov_blanket() {
1537 let g = simple_chain();
1539 let blanket = g.markov_blanket(1);
1540 assert!(blanket.contains(&0));
1541 assert!(blanket.contains(&2));
1542 assert!(!blanket.contains(&1));
1543 }
1544
1545 #[test]
1548 fn test_scm_sample_basic() {
1549 let mut scm = StructuralCausalModel::new(2);
1551 scm.add_edge(0, 1, 2.0);
1552 scm.noise_std = vec![1.0, 0.0];
1553 let noise = vec![3.0, 0.0];
1554 let x = scm.sample_with_noise(&noise);
1555 assert!((x[0] - 3.0).abs() < 1e-10);
1556 assert!((x[1] - 6.0).abs() < 1e-10);
1557 }
1558
1559 #[test]
1560 fn test_scm_intervention_removes_parents() {
1561 let mut scm = StructuralCausalModel::new(2);
1562 scm.add_edge(0, 1, 2.0);
1563 let intervened = scm.intervene(1, 5.0);
1564 assert!(intervened.graph.parents[1].is_empty());
1565 let x = intervened.sample_with_noise(&[1.0, 0.0]);
1566 assert!((x[1] - 5.0).abs() < 1e-10);
1567 }
1568
1569 #[test]
1570 fn test_scm_total_effect_chain() {
1571 let mut scm = StructuralCausalModel::new(3);
1573 scm.add_edge(0, 1, 2.0);
1574 scm.add_edge(1, 2, 3.0);
1575 let effect = scm.total_effect_linear(0, 2);
1576 assert!((effect - 6.0).abs() < 1e-10);
1577 }
1578
1579 #[test]
1580 fn test_scm_no_effect_for_independent_vars() {
1581 let mut scm = StructuralCausalModel::new(3);
1582 scm.add_edge(0, 1, 1.0);
1583 let effect = scm.total_effect_linear(0, 2);
1584 assert!(effect.abs() < 1e-10);
1585 }
1586
1587 #[test]
1590 fn test_backdoor_check_valid_adjustment() {
1591 let mut g = CausalGraph::new(3); g.add_edge(0, 1); g.add_edge(1, 2); g.add_edge(0, 2); let bd = BackdoorCriterion::new(g);
1598 assert!(bd.check(1, 2, &[0]));
1600 }
1601
1602 #[test]
1603 fn test_backdoor_fails_if_descendant_in_set() {
1604 let mut g = CausalGraph::new(3); g.add_edge(0, 1);
1607 g.add_edge(1, 2);
1608 let bd = BackdoorCriterion::new(g);
1609 assert!(!bd.check(0, 2, &[1])); }
1611
1612 #[test]
1615 fn test_propensity_score_fit_and_predict() {
1616 let mut psm = PropensityScoreMatching::new(2);
1617 let covariates: Vec<Vec<f64>> = (0..100).map(|i| vec![(i as f64) / 100.0, 0.5]).collect();
1618 let treatment: Vec<f64> = covariates
1619 .iter()
1620 .map(|x| if x[0] > 0.5 { 1.0 } else { 0.0 })
1621 .collect();
1622 psm.fit(&covariates, &treatment, 0.5, 200);
1623 let ps = psm.predict(&covariates);
1624 assert_eq!(ps.len(), 100);
1625 for p in &ps {
1627 assert!(*p > 0.0 && *p < 1.0);
1628 }
1629 }
1630
1631 #[test]
1632 fn test_ate_sign_positive() {
1633 let mut psm = PropensityScoreMatching::new(1);
1635 let n = 200;
1636 let covariates: Vec<Vec<f64>> = (0..n).map(|i| vec![(i as f64) / n as f64]).collect();
1637 let treatment: Vec<f64> = covariates
1638 .iter()
1639 .map(|x| if x[0] > 0.5 { 1.0 } else { 0.0 })
1640 .collect();
1641 let outcome: Vec<f64> = covariates
1642 .iter()
1643 .zip(treatment.iter())
1644 .map(|(x, t)| x[0] + 2.0 * t + 0.1)
1645 .collect();
1646 psm.fit(&covariates, &treatment, 0.3, 300);
1647 let ate = psm.estimate_ate(&covariates, &treatment, &outcome);
1648 assert!(ate > 0.0, "ATE should be positive, got {ate}");
1649 }
1650
1651 #[test]
1654 fn test_iv_estimation_simple() {
1655 let n = 500;
1658 let z: Vec<Vec<f64>> = (0..n).map(|i| vec![(i as f64 % 2.0)]).collect();
1659 let d: Vec<f64> = z.iter().map(|zi| zi[0] + 0.5).collect();
1660 let y: Vec<f64> = d.iter().map(|di| 2.0 * di + 1.0).collect();
1661
1662 let mut iv = InstrumentalVariables::new(1, 1);
1663 iv.fit_2sls(&y, &d, &z);
1664 assert!(
1665 (iv.second_stage - 2.0).abs() < 0.5,
1666 "IV est = {}",
1667 iv.second_stage
1668 );
1669 }
1670
1671 #[test]
1672 fn test_iv_first_stage_f_stat() {
1673 let n = 200;
1674 let z: Vec<Vec<f64>> = (0..n).map(|i| vec![(i as f64 / n as f64)]).collect();
1675 let d: Vec<f64> = z.iter().map(|zi| 2.0 * zi[0]).collect();
1676 let y: Vec<f64> = d.iter().map(|di| di + 1.0).collect();
1677 let mut iv = InstrumentalVariables::new(1, 1);
1678 iv.fit_2sls(&y, &d, &z);
1679 let f = iv.first_stage_f_stat(&y, &d, &z);
1680 assert!(
1681 f > 10.0,
1682 "F-stat should be large for strong instrument, got {f}"
1683 );
1684 }
1685
1686 #[test]
1689 fn test_pearson_correlation_perfect() {
1690 let data: Vec<Vec<f64>> = (0..50).map(|i| vec![i as f64, 2.0 * i as f64]).collect();
1691 let r = pearson_correlation(&data, 0, 1);
1692 assert!((r - 1.0).abs() < 1e-10);
1693 }
1694
1695 #[test]
1696 fn test_pearson_correlation_zero() {
1697 let data: Vec<Vec<f64>> = (0..100).map(|i| vec![i as f64, 1.0]).collect();
1699 let r = pearson_correlation(&data, 0, 1);
1700 assert!(r.abs() < 1e-10, "r={r}");
1701 }
1702
1703 #[test]
1704 fn test_partial_correlation_returns_in_range() {
1705 let data: Vec<Vec<f64>> = (0..50)
1706 .map(|i| vec![i as f64, 2.0 * i as f64 + 1.0, i as f64 * 0.5])
1707 .collect();
1708 let r = partial_correlation(&data, 0, 1, &[2]);
1709 assert!((-1.0..=1.0).contains(&r));
1710 }
1711
1712 #[test]
1713 fn test_fisher_z_test_high_correlation() {
1714 let r = 0.0; let p = fisher_z_test(r, 100, 0);
1716 assert!(p > 0.05, "Should not reject independence for r=0");
1717 }
1718
1719 #[test]
1720 fn test_causal_discovery_skeleton_independent() {
1721 let data: Vec<Vec<f64>> = (0..100)
1723 .map(|i| vec![(i as f64).sin(), (i as f64 * 2.3 + 1.0).cos()])
1724 .collect();
1725 let mut cd = CausalDiscovery::new(2, 0.01);
1726 cd.learn_skeleton(&data);
1727 assert!(cd.n_vars == 2);
1729 }
1730
1731 #[test]
1732 fn test_subsets_correctness() {
1733 let v = vec![0, 1, 2];
1734 let subs = subsets(&v, 2);
1735 assert_eq!(subs.len(), 3);
1736 }
1737
1738 #[test]
1739 fn test_subsets_empty() {
1740 let v = vec![0, 1];
1741 let subs = subsets(&v, 0);
1742 assert_eq!(subs.len(), 1);
1743 assert!(subs[0].is_empty());
1744 }
1745
1746 #[test]
1749 fn test_counterfactual_simple_chain() {
1750 let mut scm = StructuralCausalModel::new(2);
1753 scm.add_edge(0, 1, 2.0);
1754 scm.noise_std = vec![1.0, 1.0];
1755 let query = CounterfactualQuery::new(scm);
1756 let obs = vec![Some(1.0), Some(2.0)];
1757 let cf = query.counterfactual(&obs, 0, 3.0, 1);
1758 assert!((cf - 6.0).abs() < 1e-10, "cf={cf}");
1760 }
1761
1762 #[test]
1763 fn test_counterfactual_intercept() {
1764 let mut scm = StructuralCausalModel::new(2);
1767 scm.intercepts[1] = 5.0;
1768 scm.noise_std = vec![1.0, 1.0];
1769 let query = CounterfactualQuery::new(scm);
1770 let obs = vec![Some(0.0), Some(7.0)];
1771 let cf = query.counterfactual(&obs, 0, 10.0, 1);
1772 assert!((cf - 7.0).abs() < 1e-10, "cf={cf}");
1774 }
1775
1776 #[test]
1779 fn test_sample_covariance_diagonal() {
1780 let data: Vec<Vec<f64>> = (0..100).map(|i| vec![i as f64, 0.0]).collect();
1781 let cov = sample_covariance(&data);
1782 assert!(cov[0] > 0.0);
1784 assert!(cov[3].abs() < 1e-10); }
1786
1787 #[test]
1788 fn test_sample_covariance_symmetric() {
1789 let data: Vec<Vec<f64>> = (0..50).map(|i| vec![i as f64, (i as f64).sin()]).collect();
1790 let cov = sample_covariance(&data);
1791 assert!((cov[1] - cov[2]).abs() < 1e-12); }
1793
1794 #[test]
1797 fn test_full_scm_pipeline() {
1798 let mut scm = StructuralCausalModel::new(3);
1802 scm.add_edge(0, 1, 1.0); scm.add_edge(0, 2, 1.0); scm.add_edge(1, 2, 2.0); scm.noise_std = vec![1.0, 0.5, 0.5];
1806
1807 let noise_samples: Vec<Vec<f64>> = (0..500)
1809 .map(|i| {
1810 let u = (i as f64 * 0.01).sin();
1811 let x = (i as f64 * 0.013).cos();
1812 let y = (i as f64 * 0.017).sin();
1813 vec![u, x, y]
1814 })
1815 .collect();
1816
1817 let ace0 = scm.average_causal_effect(1, 0.0, 2, &noise_samples);
1818 let ace1 = scm.average_causal_effect(1, 1.0, 2, &noise_samples);
1819 let diff = ace1 - ace0;
1820 assert!((diff - 2.0).abs() < 0.1, "ACE diff = {diff}");
1822 }
1823}