causal_hub/estimators/parameters/sufficient_statistics/table/
categorical.rs1use ndarray::prelude::*;
2use rayon::prelude::*;
3
4use crate::{
5 datasets::{CatTable, CatWtdTable, Dataset},
6 estimators::{CSSEstimator, ParCSSEstimator, SSE},
7 models::CatCPDS,
8 types::{AXIS_CHUNK_LENGTH, Set},
9 utils::MI,
10};
11
12impl CSSEstimator<CatCPDS> for SSE<'_, CatTable> {
13 fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCPDS {
14 assert!(
16 x.is_disjoint(z),
17 "Variables and conditioning variables must be disjoint."
18 );
19
20 let shape = self.dataset.shape();
22 let m_idx_x = MI::new(x.iter().map(|&i| shape[i]));
24 let m_idx_z = MI::new(z.iter().map(|&i| shape[i]));
25 let s_x = m_idx_x.shape().product();
27 let s_z = m_idx_z.shape().product();
28 let mut n_xz: Array2<usize> = Array::zeros((s_z, s_x));
30
31 self.dataset.values().rows().into_iter().for_each(|row| {
33 let idx_x = m_idx_x.ravel(x.iter().map(|&i| row[i] as usize));
35 let idx_z = m_idx_z.ravel(z.iter().map(|&i| row[i] as usize));
36 n_xz[[idx_z, idx_x]] += 1;
38 });
39
40 let n_xz = n_xz.mapv(|x| x as f64);
42 let n = n_xz.sum();
44
45 CatCPDS::new(n_xz, n)
47 }
48}
49
50impl ParCSSEstimator<CatCPDS> for SSE<'_, CatTable> {
51 fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCPDS {
52 assert!(
54 x.is_disjoint(z),
55 "Variables and conditioning variables must be disjoint."
56 );
57
58 let shape = self.dataset.shape();
60 let m_idx_x = MI::new(x.iter().map(|&i| shape[i]));
62 let m_idx_z = MI::new(z.iter().map(|&i| shape[i]));
63 let s_x = m_idx_x.shape().product();
65 let s_z = m_idx_z.shape().product();
66 let n_xz: Array2<usize> = Array::zeros((s_z, s_x));
68
69 let n_xz = self
71 .dataset
72 .values()
73 .axis_chunks_iter(Axis(0), AXIS_CHUNK_LENGTH)
74 .into_par_iter()
75 .map(|values| {
76 let mut n_xz = n_xz.clone();
78 values.rows().into_iter().for_each(|row| {
80 let idx_x = m_idx_x.ravel(x.iter().map(|&i| row[i] as usize));
82 let idx_z = m_idx_z.ravel(z.iter().map(|&i| row[i] as usize));
83 n_xz[[idx_z, idx_x]] += 1;
85 });
86 n_xz
88 })
89 .fold(|| n_xz.clone(), |a, b| a + b)
91 .reduce(|| n_xz.clone(), |a, b| a + b);
92
93 let n_xz = n_xz.mapv(|x| x as f64);
95 let n = n_xz.sum();
97
98 CatCPDS::new(n_xz, n)
100 }
101}
102
103impl CSSEstimator<CatCPDS> for SSE<'_, CatWtdTable> {
104 fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCPDS {
105 assert!(
107 x.is_disjoint(z),
108 "Variables and conditioning variables must be disjoint."
109 );
110
111 let shape = self.dataset.shape();
113 let m_idx_x = MI::new(x.iter().map(|&i| shape[i]));
115 let m_idx_z = MI::new(z.iter().map(|&i| shape[i]));
116 let s_x = m_idx_x.shape().product();
118 let s_z = m_idx_z.shape().product();
119 let mut n_xz: Array2<f64> = Array::zeros((s_z, s_x));
121
122 let values = self.dataset.values().values();
124 let weights = self.dataset.weights();
125
126 values
128 .rows()
129 .into_iter()
130 .zip(weights)
131 .for_each(|(row, &weight)| {
132 let idx_x = m_idx_x.ravel(x.iter().map(|&i| row[i] as usize));
134 let idx_z = m_idx_z.ravel(z.iter().map(|&i| row[i] as usize));
135 n_xz[[idx_z, idx_x]] += weight;
137 });
138
139 let n = n_xz.sum();
141
142 CatCPDS::new(n_xz, n)
144 }
145}
146
147impl ParCSSEstimator<CatCPDS> for SSE<'_, CatWtdTable> {
148 fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCPDS {
149 assert!(
151 x.is_disjoint(z),
152 "Variables and conditioning variables must be disjoint."
153 );
154
155 let shape = self.dataset.shape();
157 let m_idx_x = MI::new(x.iter().map(|&i| shape[i]));
159 let m_idx_z = MI::new(z.iter().map(|&i| shape[i]));
160 let s_x = m_idx_x.shape().product();
162 let s_z = m_idx_z.shape().product();
163 let n_xz: Array2<f64> = Array::zeros((s_z, s_x));
165
166 let values = self.dataset.values().values();
168 let weights = self.dataset.weights();
169
170 let n_xz = values
172 .axis_chunks_iter(Axis(0), AXIS_CHUNK_LENGTH)
173 .into_par_iter()
174 .zip(weights.axis_chunks_iter(Axis(0), AXIS_CHUNK_LENGTH))
175 .map(|(values, weights)| {
176 let mut n_xz = n_xz.clone();
178 values
180 .rows()
181 .into_iter()
182 .zip(weights)
183 .for_each(|(row, &weight)| {
184 let idx_x = m_idx_x.ravel(x.iter().map(|&i| row[i] as usize));
186 let idx_z = m_idx_z.ravel(z.iter().map(|&i| row[i] as usize));
187 n_xz[[idx_z, idx_x]] += weight;
189 });
190 n_xz
192 })
193 .fold(|| n_xz.clone(), |a, b| a + b)
195 .reduce(|| n_xz.clone(), |a, b| a + b);
196
197 let n = n_xz.sum();
199
200 CatCPDS::new(n_xz, n)
202 }
203}