causal_hub/estimators/parameters/sufficient_statistics/trajectory/
categorical.rs

1use dry::macro_for;
2use itertools::Itertools;
3use ndarray::prelude::*;
4use rayon::prelude::*;
5
6use crate::{
7    datasets::{CatTrj, CatTrjs, CatWtdTrj, CatWtdTrjs, Dataset},
8    estimators::{CSSEstimator, ParCSSEstimator, SSE},
9    models::CatCIMS,
10    types::Set,
11    utils::MI,
12};
13
14impl CSSEstimator<CatCIMS> for SSE<'_, CatTrj> {
15    fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIMS {
16        // Assert variables and conditioning variables must be disjoint..
17        assert!(
18            x.is_disjoint(z),
19            "Variables and conditioning variables must be disjoint."
20        );
21
22        // Get the shape.
23        let shape = self.dataset.shape();
24
25        // Initialize the multi index.
26        let m_idx_x = MI::new(x.iter().map(|&i| shape[i]));
27        let m_idx_z = MI::new(z.iter().map(|&i| shape[i]));
28        // Get the shape of the conditioned and conditioning variables.
29        let s_x = m_idx_x.shape().product();
30        let s_z = m_idx_z.shape().product();
31
32        // Initialize the joint counts.
33        let mut n_xz: Array3<usize> = Array::zeros((s_z, s_x, s_x));
34        // Initialize the time spent in that state.
35        let mut t_xz: Array2<f64> = Array::zeros((s_z, s_x));
36
37        // Iterate over the trajectory events.
38        self.dataset
39            .values()
40            .rows()
41            .into_iter()
42            .zip(self.dataset.times())
43            .tuple_windows()
44            // Compare the current and next event.
45            .for_each(|((e_i, t_i), (e_j, t_j))| {
46                // Get the value of X as index.
47                let idx_x_i = m_idx_x.ravel(x.iter().map(|&i| e_i[i] as usize));
48                let idx_x_j = m_idx_x.ravel(x.iter().map(|&i| e_j[i] as usize));
49                // Get the value of Z as index using the strides.
50                let idx_z = m_idx_z.ravel(z.iter().map(|&i| e_i[i] as usize));
51                // Increment the count when conditioned variable transitions.
52                n_xz[[idx_z, idx_x_i, idx_x_j]] += (idx_x_i != idx_x_j) as usize;
53                // Increment the time in that state.
54                t_xz[[idx_z, idx_x_i]] += t_j - t_i;
55            });
56
57        // Cast the counts to floating point.
58        let n_xz = n_xz.mapv(|x| x as f64);
59        // Compute the sample size.
60        let n = n_xz.sum();
61
62        CatCIMS::new(n_xz, t_xz, n)
63    }
64}
65
66impl CSSEstimator<CatCIMS> for SSE<'_, CatWtdTrj> {
67    fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIMS {
68        // Get the weight of the trajectory.
69        let w = self.dataset.weight();
70        // Compute the unweighted sufficient statistics.
71        let s = SSE::new(self.dataset.trajectory()).fit(x, z);
72        // Destructure the sufficient statistics.
73        let n_xz = s.sample_conditional_counts();
74        let t_xz = s.sample_conditional_times();
75        let n = s.sample_size();
76        // Apply the weight to the sufficient statistics.
77        CatCIMS::new(n_xz * w, t_xz * w, n * w)
78    }
79}
80
81// Implement the CSSEstimator and ParCSSEstimator traits for both CatTrjs and CatWtdTrjs.
82macro_for!($type in [CatTrjs, CatWtdTrjs] {
83
84    impl CSSEstimator<CatCIMS> for SSE<'_, $type> {
85        fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIMS {
86            // Get the shape.
87            let shape = self.dataset.shape();
88
89            // Get the shape of the conditioned and conditioning variables.
90            let s_x = x.iter().map(|&i| shape[i]).product();
91            let s_z = z.iter().map(|&i| shape[i]).product();
92
93            // Initialize the sufficient statistics.
94            let s = CatCIMS::new(
95                // Initialize the joint counts.
96                Array3::zeros((s_z, s_x, s_x)),
97                // Initialize the time spent in that state.
98                Array2::zeros((s_z, s_x)),
99                // Initialize the sample size.
100                0.,
101            );
102
103            // Iterate over the trajectories.
104            self.dataset
105                .into_iter()
106                // Sum the sufficient statistics of each trajectory.
107                .fold(s, |s_a, trj_b| s_a + SSE::new(trj_b).fit(x, z))
108        }
109    }
110
111    impl ParCSSEstimator<CatCIMS> for SSE<'_, $type> {
112        fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIMS {
113            // Get the shape.
114            let shape = self.dataset.shape();
115
116            // Get the shape of the conditioned and conditioning variables.
117            let s_x = x.iter().map(|&i| shape[i]).product();
118            let s_z = z.iter().map(|&i| shape[i]).product();
119
120            // Initialize the sufficient statistics.
121            let s = CatCIMS::new(
122                // Initialize the joint counts.
123                Array3::zeros((s_z, s_x, s_x)),
124                // Initialize the time spent in that state.
125                Array2::zeros((s_z, s_x)),
126                // Initialize the sample size.
127                0.,
128            );
129
130            // Iterate over the trajectories in parallel.
131            self.dataset
132                .par_iter()
133                // Sum the sufficient statistics of each trajectory.
134                .fold(
135                    || s.clone(),
136                    |s_a, trj_b| s_a + SSE::new(trj_b).fit(x, z),
137                )
138                .reduce(
139                    || s.clone(),
140                    |s_a, s_b| s_a + s_b
141                )
142        }
143    }
144});