causal_hub/estimators/parameters/sufficient_statistics/trajectory/
categorical.rs1use dry::macro_for;
2use itertools::Itertools;
3use ndarray::prelude::*;
4use rayon::prelude::*;
5
6use crate::{
7 datasets::{CatTrj, CatTrjs, CatWtdTrj, CatWtdTrjs, Dataset},
8 estimators::{CSSEstimator, ParCSSEstimator, SSE},
9 models::CatCIMS,
10 types::Set,
11 utils::MI,
12};
13
14impl CSSEstimator<CatCIMS> for SSE<'_, CatTrj> {
15 fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIMS {
16 assert!(
18 x.is_disjoint(z),
19 "Variables and conditioning variables must be disjoint."
20 );
21
22 let shape = self.dataset.shape();
24
25 let m_idx_x = MI::new(x.iter().map(|&i| shape[i]));
27 let m_idx_z = MI::new(z.iter().map(|&i| shape[i]));
28 let s_x = m_idx_x.shape().product();
30 let s_z = m_idx_z.shape().product();
31
32 let mut n_xz: Array3<usize> = Array::zeros((s_z, s_x, s_x));
34 let mut t_xz: Array2<f64> = Array::zeros((s_z, s_x));
36
37 self.dataset
39 .values()
40 .rows()
41 .into_iter()
42 .zip(self.dataset.times())
43 .tuple_windows()
44 .for_each(|((e_i, t_i), (e_j, t_j))| {
46 let idx_x_i = m_idx_x.ravel(x.iter().map(|&i| e_i[i] as usize));
48 let idx_x_j = m_idx_x.ravel(x.iter().map(|&i| e_j[i] as usize));
49 let idx_z = m_idx_z.ravel(z.iter().map(|&i| e_i[i] as usize));
51 n_xz[[idx_z, idx_x_i, idx_x_j]] += (idx_x_i != idx_x_j) as usize;
53 t_xz[[idx_z, idx_x_i]] += t_j - t_i;
55 });
56
57 let n_xz = n_xz.mapv(|x| x as f64);
59 let n = n_xz.sum();
61
62 CatCIMS::new(n_xz, t_xz, n)
63 }
64}
65
66impl CSSEstimator<CatCIMS> for SSE<'_, CatWtdTrj> {
67 fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIMS {
68 let w = self.dataset.weight();
70 let s = SSE::new(self.dataset.trajectory()).fit(x, z);
72 let n_xz = s.sample_conditional_counts();
74 let t_xz = s.sample_conditional_times();
75 let n = s.sample_size();
76 CatCIMS::new(n_xz * w, t_xz * w, n * w)
78 }
79}
80
81macro_for!($type in [CatTrjs, CatWtdTrjs] {
83
84 impl CSSEstimator<CatCIMS> for SSE<'_, $type> {
85 fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIMS {
86 let shape = self.dataset.shape();
88
89 let s_x = x.iter().map(|&i| shape[i]).product();
91 let s_z = z.iter().map(|&i| shape[i]).product();
92
93 let s = CatCIMS::new(
95 Array3::zeros((s_z, s_x, s_x)),
97 Array2::zeros((s_z, s_x)),
99 0.,
101 );
102
103 self.dataset
105 .into_iter()
106 .fold(s, |s_a, trj_b| s_a + SSE::new(trj_b).fit(x, z))
108 }
109 }
110
111 impl ParCSSEstimator<CatCIMS> for SSE<'_, $type> {
112 fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIMS {
113 let shape = self.dataset.shape();
115
116 let s_x = x.iter().map(|&i| shape[i]).product();
118 let s_z = z.iter().map(|&i| shape[i]).product();
119
120 let s = CatCIMS::new(
122 Array3::zeros((s_z, s_x, s_x)),
124 Array2::zeros((s_z, s_x)),
126 0.,
128 );
129
130 self.dataset
132 .par_iter()
133 .fold(
135 || s.clone(),
136 |s_a, trj_b| s_a + SSE::new(trj_b).fit(x, z),
137 )
138 .reduce(
139 || s.clone(),
140 |s_a, s_b| s_a + s_b
141 )
142 }
143 }
144});