// causal_hub/estimators/parameters/sufficient_statistics/table/gaussian.rs

1use ndarray::prelude::*;
2use rayon::prelude::*;
3
4use crate::{
5    datasets::{Dataset, GaussTable, GaussWtdTable},
6    estimators::{CSSEstimator, ParCSSEstimator, SSE},
7    models::GaussCPDS,
8    types::{AXIS_CHUNK_LENGTH, Set},
9};
10
11impl SSE<'_, GaussTable> {
12    fn fit(d: ArrayView2<f64>, x: &Set<usize>, z: &Set<usize>) -> GaussCPDS {
13        // Select the columns of the variables.
14        let mut d_x = Array::zeros((d.nrows(), x.len()));
15        for (i, &j) in x.iter().enumerate() {
16            d_x.column_mut(i).assign(&d.column(j));
17        }
18        // Compute the mean.
19        let mu_x = d_x.mean_axis(Axis(0)).unwrap();
20
21        // Select the columns of the conditioning variables.
22        let mut d_z = Array::zeros((d.nrows(), z.len()));
23        for (i, &j) in z.iter().enumerate() {
24            d_z.column_mut(i).assign(&d.column(j));
25        }
26        // Compute the mean.
27        let mu_z = d_z.mean_axis(Axis(0)).unwrap();
28
29        // Compute the second moment statistics.
30        let m_xx = d_x.t().dot(&d_x);
31        let m_xz = d_x.t().dot(&d_z);
32        let m_zz = d_z.t().dot(&d_z);
33
34        // Get the sample size.
35        let n = d.nrows() as f64;
36
37        // Return the sufficient statistics.
38        GaussCPDS::new(mu_x, mu_z, m_xx, m_xz, m_zz, n)
39    }
40}
41
42impl CSSEstimator<GaussCPDS> for SSE<'_, GaussTable> {
43    fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> GaussCPDS {
44        // Assert variables and conditioning variables must be disjoint.
45        assert!(
46            x.is_disjoint(z),
47            "Variables and conditioning variables must be disjoint."
48        );
49        // Get the values.
50        let d = self.dataset.values();
51        // Return the sufficient statistics.
52        Self::fit(d.view(), x, z)
53    }
54}
55
56impl ParCSSEstimator<GaussCPDS> for SSE<'_, GaussTable> {
57    fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> GaussCPDS {
58        // Assert variables and conditioning variables must be disjoint.
59        assert!(
60            x.is_disjoint(z),
61            "Variables and conditioning variables must be disjoint."
62        );
63
64        // Initialize the sufficient statistics.
65        let s_xz = {
66            let n = 0.;
67            let mu_x = Array::zeros(x.len());
68            let mu_z = Array::zeros(z.len());
69            let m_xx = Array::zeros((x.len(), x.len()));
70            let m_xz = Array::zeros((x.len(), z.len()));
71            let m_zz = Array::zeros((z.len(), z.len()));
72            GaussCPDS::new(mu_x, mu_z, m_xx, m_xz, m_zz, n)
73        };
74
75        // Get the values.
76        let d = self.dataset.values();
77
78        // Get the values.
79        d.axis_chunks_iter(Axis(0), AXIS_CHUNK_LENGTH)
80            .into_par_iter()
81            // Compute the sufficient statistics for each chunk.
82            .map(|d| Self::fit(d, x, z))
83            // Aggregate the sufficient statistics.
84            .fold(|| s_xz.clone(), |a, b| a + b)
85            .reduce(|| s_xz.clone(), |a, b| a + b)
86    }
87}
88
89impl SSE<'_, GaussWtdTable> {
90    fn fit(
91        d: ArrayView2<f64>,
92        norm_w: ArrayView2<f64>,
93        sum_w: f64,
94        x: &Set<usize>,
95        z: &Set<usize>,
96    ) -> GaussCPDS {
97        // Select the columns of the variables.
98        let mut d_x = Array::zeros((d.nrows(), x.len()));
99        for (i, &j) in x.iter().enumerate() {
100            d_x.column_mut(i).assign(&d.column(j));
101        }
102        // Compute the weighted mean.
103        let mu_x = (&norm_w * &d_x).mean_axis(Axis(0)).unwrap();
104
105        // Select the columns of the conditioning variables.
106        let mut d_z = Array::zeros((d.nrows(), z.len()));
107        for (i, &j) in z.iter().enumerate() {
108            d_z.column_mut(i).assign(&d.column(j));
109        }
110        // Compute the weighted mean.
111        let mu_z = (&norm_w * &d_z).mean_axis(Axis(0)).unwrap();
112
113        // Compute the root weights for centering.
114        let sqrt_w = norm_w.mapv(f64::sqrt);
115        let d_sqrt_w_x = &sqrt_w * &d_x;
116        let d_sqrt_w_z = &sqrt_w * &d_z;
117
118        // Compute the weighted second moment statistics.
119        let m_xx = d_sqrt_w_x.t().dot(&d_sqrt_w_x);
120        let m_xz = d_sqrt_w_x.t().dot(&d_sqrt_w_z);
121        let m_zz = d_sqrt_w_z.t().dot(&d_sqrt_w_z);
122
123        // Get the sample (mass) size.
124        let n = sum_w;
125
126        // Return the sufficient statistics.
127        GaussCPDS::new(mu_x, mu_z, m_xx, m_xz, m_zz, n)
128    }
129}
130
131impl CSSEstimator<GaussCPDS> for SSE<'_, GaussWtdTable> {
132    fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> GaussCPDS {
133        // Assert variables and conditioning variables must be disjoint.
134        assert!(
135            x.is_disjoint(z),
136            "Variables and conditioning variables must be disjoint."
137        );
138
139        // Get the values.
140        let d = self.dataset.values().values();
141        // Get the weights.
142        let w = self.dataset.weights();
143        // Sum the weights to normalize.
144        let sum_w = w.sum();
145        // Normalize the weights.
146        let w = w / sum_w;
147        // Align the axis for broadcasting.
148        let w = w.insert_axis(Axis(1));
149
150        // Return the sufficient statistics.
151        Self::fit(d.view(), w.view(), sum_w, x, z)
152    }
153}
154
155impl ParCSSEstimator<GaussCPDS> for SSE<'_, GaussWtdTable> {
156    fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> GaussCPDS {
157        // Assert variables and conditioning variables must be disjoint.
158        assert!(
159            x.is_disjoint(z),
160            "Variables and conditioning variables must be disjoint."
161        );
162
163        // Initialize the sufficient statistics.
164        let s_xz = {
165            let n = 0.;
166            let mu_x = Array::zeros(x.len());
167            let mu_z = Array::zeros(z.len());
168            let m_xx = Array::zeros((x.len(), x.len()));
169            let m_xz = Array::zeros((x.len(), z.len()));
170            let m_zz = Array::zeros((z.len(), z.len()));
171            GaussCPDS::new(mu_x, mu_z, m_xx, m_xz, m_zz, n)
172        };
173
174        // Get the values.
175        let values = self.dataset.values().values();
176        // Get the weights.
177        let weights = self.dataset.weights();
178
179        // Sum the weights to normalize.
180        let sum_w: f64 = weights.par_iter().sum();
181        // Normalize the weights.
182        let weights = {
183            // Clone the weights.
184            let mut weights = weights.clone();
185            // Normalize the weights in parallel.
186            weights
187                .axis_chunks_iter_mut(Axis(0), AXIS_CHUNK_LENGTH)
188                .into_par_iter()
189                .for_each(|mut w| w /= sum_w);
190            // Return the normalized weights.
191            weights
192        };
193        // Align the axis for broadcasting.
194        let weights = weights.insert_axis(Axis(1));
195
196        // Get the values.
197        values
198            .axis_chunks_iter(Axis(0), AXIS_CHUNK_LENGTH)
199            .into_par_iter()
200            .zip(weights.axis_chunks_iter(Axis(0), AXIS_CHUNK_LENGTH))
201            // Compute the sufficient statistics for each chunk.
202            .map(|(d, w)| Self::fit(d, w, sum_w, x, z))
203            // Aggregate the sufficient statistics.
204            .fold(|| s_xz.clone(), |a, b| a + b)
205            .reduce(|| s_xz.clone(), |a, b| a + b)
206    }
207}