causal_hub/estimators/parameters/sufficient_statistics/table/
gaussian.rs1use ndarray::prelude::*;
2use rayon::prelude::*;
3
4use crate::{
5 datasets::{Dataset, GaussTable, GaussWtdTable},
6 estimators::{CSSEstimator, ParCSSEstimator, SSE},
7 models::GaussCPDS,
8 types::{AXIS_CHUNK_LENGTH, Set},
9};
10
11impl SSE<'_, GaussTable> {
12 fn fit(d: ArrayView2<f64>, x: &Set<usize>, z: &Set<usize>) -> GaussCPDS {
13 let mut d_x = Array::zeros((d.nrows(), x.len()));
15 for (i, &j) in x.iter().enumerate() {
16 d_x.column_mut(i).assign(&d.column(j));
17 }
18 let mu_x = d_x.mean_axis(Axis(0)).unwrap();
20
21 let mut d_z = Array::zeros((d.nrows(), z.len()));
23 for (i, &j) in z.iter().enumerate() {
24 d_z.column_mut(i).assign(&d.column(j));
25 }
26 let mu_z = d_z.mean_axis(Axis(0)).unwrap();
28
29 let m_xx = d_x.t().dot(&d_x);
31 let m_xz = d_x.t().dot(&d_z);
32 let m_zz = d_z.t().dot(&d_z);
33
34 let n = d.nrows() as f64;
36
37 GaussCPDS::new(mu_x, mu_z, m_xx, m_xz, m_zz, n)
39 }
40}
41
42impl CSSEstimator<GaussCPDS> for SSE<'_, GaussTable> {
43 fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> GaussCPDS {
44 assert!(
46 x.is_disjoint(z),
47 "Variables and conditioning variables must be disjoint."
48 );
49 let d = self.dataset.values();
51 Self::fit(d.view(), x, z)
53 }
54}
55
56impl ParCSSEstimator<GaussCPDS> for SSE<'_, GaussTable> {
57 fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> GaussCPDS {
58 assert!(
60 x.is_disjoint(z),
61 "Variables and conditioning variables must be disjoint."
62 );
63
64 let s_xz = {
66 let n = 0.;
67 let mu_x = Array::zeros(x.len());
68 let mu_z = Array::zeros(z.len());
69 let m_xx = Array::zeros((x.len(), x.len()));
70 let m_xz = Array::zeros((x.len(), z.len()));
71 let m_zz = Array::zeros((z.len(), z.len()));
72 GaussCPDS::new(mu_x, mu_z, m_xx, m_xz, m_zz, n)
73 };
74
75 let d = self.dataset.values();
77
78 d.axis_chunks_iter(Axis(0), AXIS_CHUNK_LENGTH)
80 .into_par_iter()
81 .map(|d| Self::fit(d, x, z))
83 .fold(|| s_xz.clone(), |a, b| a + b)
85 .reduce(|| s_xz.clone(), |a, b| a + b)
86 }
87}
88
89impl SSE<'_, GaussWtdTable> {
90 fn fit(
91 d: ArrayView2<f64>,
92 norm_w: ArrayView2<f64>,
93 sum_w: f64,
94 x: &Set<usize>,
95 z: &Set<usize>,
96 ) -> GaussCPDS {
97 let mut d_x = Array::zeros((d.nrows(), x.len()));
99 for (i, &j) in x.iter().enumerate() {
100 d_x.column_mut(i).assign(&d.column(j));
101 }
102 let mu_x = (&norm_w * &d_x).mean_axis(Axis(0)).unwrap();
104
105 let mut d_z = Array::zeros((d.nrows(), z.len()));
107 for (i, &j) in z.iter().enumerate() {
108 d_z.column_mut(i).assign(&d.column(j));
109 }
110 let mu_z = (&norm_w * &d_z).mean_axis(Axis(0)).unwrap();
112
113 let sqrt_w = norm_w.mapv(f64::sqrt);
115 let d_sqrt_w_x = &sqrt_w * &d_x;
116 let d_sqrt_w_z = &sqrt_w * &d_z;
117
118 let m_xx = d_sqrt_w_x.t().dot(&d_sqrt_w_x);
120 let m_xz = d_sqrt_w_x.t().dot(&d_sqrt_w_z);
121 let m_zz = d_sqrt_w_z.t().dot(&d_sqrt_w_z);
122
123 let n = sum_w;
125
126 GaussCPDS::new(mu_x, mu_z, m_xx, m_xz, m_zz, n)
128 }
129}
130
131impl CSSEstimator<GaussCPDS> for SSE<'_, GaussWtdTable> {
132 fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> GaussCPDS {
133 assert!(
135 x.is_disjoint(z),
136 "Variables and conditioning variables must be disjoint."
137 );
138
139 let d = self.dataset.values().values();
141 let w = self.dataset.weights();
143 let sum_w = w.sum();
145 let w = w / sum_w;
147 let w = w.insert_axis(Axis(1));
149
150 Self::fit(d.view(), w.view(), sum_w, x, z)
152 }
153}
154
155impl ParCSSEstimator<GaussCPDS> for SSE<'_, GaussWtdTable> {
156 fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> GaussCPDS {
157 assert!(
159 x.is_disjoint(z),
160 "Variables and conditioning variables must be disjoint."
161 );
162
163 let s_xz = {
165 let n = 0.;
166 let mu_x = Array::zeros(x.len());
167 let mu_z = Array::zeros(z.len());
168 let m_xx = Array::zeros((x.len(), x.len()));
169 let m_xz = Array::zeros((x.len(), z.len()));
170 let m_zz = Array::zeros((z.len(), z.len()));
171 GaussCPDS::new(mu_x, mu_z, m_xx, m_xz, m_zz, n)
172 };
173
174 let values = self.dataset.values().values();
176 let weights = self.dataset.weights();
178
179 let sum_w: f64 = weights.par_iter().sum();
181 let weights = {
183 let mut weights = weights.clone();
185 weights
187 .axis_chunks_iter_mut(Axis(0), AXIS_CHUNK_LENGTH)
188 .into_par_iter()
189 .for_each(|mut w| w /= sum_w);
190 weights
192 };
193 let weights = weights.insert_axis(Axis(1));
195
196 values
198 .axis_chunks_iter(Axis(0), AXIS_CHUNK_LENGTH)
199 .into_par_iter()
200 .zip(weights.axis_chunks_iter(Axis(0), AXIS_CHUNK_LENGTH))
201 .map(|(d, w)| Self::fit(d, w, sum_w, x, z))
203 .fold(|| s_xz.clone(), |a, b| a + b)
205 .reduce(|| s_xz.clone(), |a, b| a + b)
206 }
207}