causal_hub/estimators/parameters/sufficient_statistics/table/
categorical.rs

1use ndarray::prelude::*;
2use rayon::prelude::*;
3
4use crate::{
5    datasets::{CatTable, CatWtdTable, Dataset},
6    estimators::{CSSEstimator, ParCSSEstimator, SSE},
7    models::CatCPDS,
8    types::{AXIS_CHUNK_LENGTH, Set},
9    utils::MI,
10};
11
12impl CSSEstimator<CatCPDS> for SSE<'_, CatTable> {
13    fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCPDS {
14        // Assert variables and conditioning variables must be disjoint.
15        assert!(
16            x.is_disjoint(z),
17            "Variables and conditioning variables must be disjoint."
18        );
19
20        // Get the shape.
21        let shape = self.dataset.shape();
22        // Initialize the multi index.
23        let m_idx_x = MI::new(x.iter().map(|&i| shape[i]));
24        let m_idx_z = MI::new(z.iter().map(|&i| shape[i]));
25        // Get the shape of the conditioned and conditioning variables.
26        let s_x = m_idx_x.shape().product();
27        let s_z = m_idx_z.shape().product();
28        // Initialize the joint counts.
29        let mut n_xz: Array2<usize> = Array::zeros((s_z, s_x));
30
31        // Count the occurrences of the states.
32        self.dataset.values().rows().into_iter().for_each(|row| {
33            // Get the value of X and Z as index.
34            let idx_x = m_idx_x.ravel(x.iter().map(|&i| row[i] as usize));
35            let idx_z = m_idx_z.ravel(z.iter().map(|&i| row[i] as usize));
36            // Increment the joint counts.
37            n_xz[[idx_z, idx_x]] += 1;
38        });
39
40        // Cast the counts to floating point.
41        let n_xz = n_xz.mapv(|x| x as f64);
42        // Compute the sample size.
43        let n = n_xz.sum();
44
45        // Return the sufficient statistics.
46        CatCPDS::new(n_xz, n)
47    }
48}
49
50impl ParCSSEstimator<CatCPDS> for SSE<'_, CatTable> {
51    fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCPDS {
52        // Assert variables and conditioning variables must be disjoint.
53        assert!(
54            x.is_disjoint(z),
55            "Variables and conditioning variables must be disjoint."
56        );
57
58        // Get the shape.
59        let shape = self.dataset.shape();
60        // Initialize the multi index.
61        let m_idx_x = MI::new(x.iter().map(|&i| shape[i]));
62        let m_idx_z = MI::new(z.iter().map(|&i| shape[i]));
63        // Get the shape of the conditioned and conditioning variables.
64        let s_x = m_idx_x.shape().product();
65        let s_z = m_idx_z.shape().product();
66        // Initialize the joint counts.
67        let n_xz: Array2<usize> = Array::zeros((s_z, s_x));
68
69        // Count the occurrences of the states.
70        let n_xz = self
71            .dataset
72            .values()
73            .axis_chunks_iter(Axis(0), AXIS_CHUNK_LENGTH)
74            .into_par_iter()
75            .map(|values| {
76                // Clone the zeros joint counts.
77                let mut n_xz = n_xz.clone();
78                // Count the occurrences of the states.
79                values.rows().into_iter().for_each(|row| {
80                    // Get the value of X and Z as index.
81                    let idx_x = m_idx_x.ravel(x.iter().map(|&i| row[i] as usize));
82                    let idx_z = m_idx_z.ravel(z.iter().map(|&i| row[i] as usize));
83                    // Increment the joint counts.
84                    n_xz[[idx_z, idx_x]] += 1;
85                });
86                // Return the local joint counts.
87                n_xz
88            })
89            // Aggregate the local joint counts.
90            .fold(|| n_xz.clone(), |a, b| a + b)
91            .reduce(|| n_xz.clone(), |a, b| a + b);
92
93        // Cast the counts to floating point.
94        let n_xz = n_xz.mapv(|x| x as f64);
95        // Compute the sample size.
96        let n = n_xz.sum();
97
98        // Return the sufficient statistics.
99        CatCPDS::new(n_xz, n)
100    }
101}
102
103impl CSSEstimator<CatCPDS> for SSE<'_, CatWtdTable> {
104    fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCPDS {
105        // Assert variables and conditioning variables must be disjoint.
106        assert!(
107            x.is_disjoint(z),
108            "Variables and conditioning variables must be disjoint."
109        );
110
111        // Get the shape.
112        let shape = self.dataset.shape();
113        // Initialize the multi index.
114        let m_idx_x = MI::new(x.iter().map(|&i| shape[i]));
115        let m_idx_z = MI::new(z.iter().map(|&i| shape[i]));
116        // Get the shape of the conditioned and conditioning variables.
117        let s_x = m_idx_x.shape().product();
118        let s_z = m_idx_z.shape().product();
119        // Initialize the joint counts.
120        let mut n_xz: Array2<f64> = Array::zeros((s_z, s_x));
121
122        // Get the unweighted values and weights.
123        let values = self.dataset.values().values();
124        let weights = self.dataset.weights();
125
126        // Count the occurrences of the states.
127        values
128            .rows()
129            .into_iter()
130            .zip(weights)
131            .for_each(|(row, &weight)| {
132                // Get the value of X and Z as index.
133                let idx_x = m_idx_x.ravel(x.iter().map(|&i| row[i] as usize));
134                let idx_z = m_idx_z.ravel(z.iter().map(|&i| row[i] as usize));
135                // Increment the joint counts.
136                n_xz[[idx_z, idx_x]] += weight;
137            });
138
139        // Compute the sample size.
140        let n = n_xz.sum();
141
142        // Return the sufficient statistics.
143        CatCPDS::new(n_xz, n)
144    }
145}
146
147impl ParCSSEstimator<CatCPDS> for SSE<'_, CatWtdTable> {
148    fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCPDS {
149        // Assert variables and conditioning variables must be disjoint.
150        assert!(
151            x.is_disjoint(z),
152            "Variables and conditioning variables must be disjoint."
153        );
154
155        // Get the shape.
156        let shape = self.dataset.shape();
157        // Initialize the multi index.
158        let m_idx_x = MI::new(x.iter().map(|&i| shape[i]));
159        let m_idx_z = MI::new(z.iter().map(|&i| shape[i]));
160        // Get the shape of the conditioned and conditioning variables.
161        let s_x = m_idx_x.shape().product();
162        let s_z = m_idx_z.shape().product();
163        // Initialize the joint counts.
164        let n_xz: Array2<f64> = Array::zeros((s_z, s_x));
165
166        // Get the unweighted values and weights.
167        let values = self.dataset.values().values();
168        let weights = self.dataset.weights();
169
170        // Count the occurrences of the states.
171        let n_xz = values
172            .axis_chunks_iter(Axis(0), AXIS_CHUNK_LENGTH)
173            .into_par_iter()
174            .zip(weights.axis_chunks_iter(Axis(0), AXIS_CHUNK_LENGTH))
175            .map(|(values, weights)| {
176                // Clone the zeros joint counts.
177                let mut n_xz = n_xz.clone();
178                // Count the occurrences of the states.
179                values
180                    .rows()
181                    .into_iter()
182                    .zip(weights)
183                    .for_each(|(row, &weight)| {
184                        // Get the value of X and Z as index.
185                        let idx_x = m_idx_x.ravel(x.iter().map(|&i| row[i] as usize));
186                        let idx_z = m_idx_z.ravel(z.iter().map(|&i| row[i] as usize));
187                        // Increment the joint counts.
188                        n_xz[[idx_z, idx_x]] += weight;
189                    });
190                // Return the local joint counts.
191                n_xz
192            })
193            // Aggregate the local joint counts.
194            .fold(|| n_xz.clone(), |a, b| a + b)
195            .reduce(|| n_xz.clone(), |a, b| a + b);
196
197        // Compute the sample size.
198        let n = n_xz.sum();
199
200        // Return the sufficient statistics.
201        CatCPDS::new(n_xz, n)
202    }
203}