causal_hub/estimators/parameters/maximum_likelihood.rs

1use dry::macro_for;
2use ndarray::prelude::*;
3use ndarray_linalg::Determinant;
4
5use crate::{
6    datasets::{
7        CatTable, CatTrj, CatTrjs, CatWtdTable, CatWtdTrj, CatWtdTrjs, GaussTable, GaussWtdTable,
8    },
9    estimators::{
10        CIMEstimator, CPDEstimator, CSSEstimator, ParCIMEstimator, ParCPDEstimator,
11        ParCSSEstimator, SSE,
12    },
13    models::{CatCIM, CatCIMS, CatCPD, CatCPDS, GaussCPD, GaussCPDP, GaussCPDS, Labelled},
14    types::{LN_2_PI, Labels, Set, States},
15    utils::PseudoInverse,
16};
17
/// A maximum likelihood estimator (MLE) bound to a borrowed dataset.
///
/// The estimator only stores a shared reference to the dataset, so it is cheap to copy
/// regardless of whether the dataset type itself implements `Clone` or `Copy`.
#[derive(Debug)]
pub struct MLE<'a, D> {
    /// The dataset the estimator is fitted on.
    dataset: &'a D,
}

// `Clone` and `Copy` are implemented by hand rather than derived: `#[derive]` would add
// `D: Clone` / `D: Copy` bounds even though only a reference to `D` is stored, which would
// make the estimator non-copyable for every (non-`Copy`) dataset type.
impl<D> Clone for MLE<'_, D> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<D> Copy for MLE<'_, D> {}

impl<'a, D> MLE<'a, D> {
    /// Creates a new maximum likelihood estimator.
    ///
    /// # Arguments
    ///
    /// * `dataset` - A reference to the dataset to fit the estimator to.
    ///
    /// # Returns
    ///
    /// A new `MLE` instance borrowing `dataset`.
    ///
    #[inline]
    pub const fn new(dataset: &'a D) -> Self {
        Self { dataset }
    }
}
40
41impl<D> Labelled for MLE<'_, D>
42where
43    D: Labelled,
44{
45    #[inline]
46    fn labels(&self) -> &Labels {
47        self.dataset.labels()
48    }
49}
50
51impl MLE<'_, CatTable> {
52    fn fit(states: &States, x: &Set<usize>, z: &Set<usize>, sample_statistics: CatCPDS) -> CatCPD {
53        // Get the conditional counts.
54        let n_xz = sample_statistics.sample_conditional_counts();
55        // Marginalize the counts.
56        let n_z = &n_xz.sum_axis(Axis(1)).insert_axis(Axis(1));
57
58        // Assert the marginal counts are not zero.
59        assert!(
60            n_z.iter().all(|&x| x > 0.),
61            "Failed to get non-zero counts.",
62        );
63
64        // Compute the parameters by normalizing the counts.
65        let parameters = n_xz / n_z;
66
67        // Set epsilon to avoid ln(0).
68        let eps = f64::MIN_POSITIVE;
69        // Compute the sample log-likelihood, avoiding ln(0).
70        let sample_log_likelihood = (n_xz * (&parameters + eps).ln()).sum();
71
72        // Subset the conditioning labels, states and shape.
73        let conditioning_states = z
74            .iter()
75            .map(|&i| {
76                let (k, v) = states.get_index(i).unwrap();
77                (k.clone(), v.clone())
78            })
79            .collect();
80        // Get the labels of the conditioned variables.
81        let states = x
82            .iter()
83            .map(|&i| {
84                let (k, v) = states.get_index(i).unwrap();
85                (k.clone(), v.clone())
86            })
87            .collect();
88
89        // Wrap the sample statistics in an option.
90        let sample_statistics = Some(sample_statistics);
91        // Wrap the sample log-likelihood in an option.
92        let sample_log_likelihood = Some(sample_log_likelihood);
93
94        // Construct the CPD.
95        CatCPD::with_optionals(
96            states,
97            conditioning_states,
98            parameters,
99            sample_statistics,
100            sample_log_likelihood,
101        )
102    }
103}
104
105// Implement the CatCPD estimator for the MLE struct.
106macro_for!($type in [CatTable, CatWtdTable] {
107
108    impl CPDEstimator<CatCPD> for MLE<'_, $type> {
109        fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCPD {
110            // Get states.
111            let states = self.dataset.states();
112            // Compute sufficient statistics.
113            let sample_statistics = SSE::new(self.dataset).fit(x, z);
114            // Fit the CPD given the sufficient statistics.
115            MLE::<'_, CatTable>::fit(states, x, z, sample_statistics)
116        }
117    }
118
119    impl ParCPDEstimator<CatCPD> for MLE<'_, $type> {
120        fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCPD {
121            // Get states.
122            let states = self.dataset.states();
123            // Compute sufficient statistics in parallel.
124            let sample_statistics = SSE::new(self.dataset).par_fit(x, z);
125            // Fit the CPD given the sufficient statistics.
126            MLE::<'_, CatTable>::fit(states, x, z, sample_statistics)
127        }
128    }
129
130});
131
132impl MLE<'_, GaussTable> {
133    fn fit(
134        labels: &Labels,
135        x: &Set<usize>,
136        z: &Set<usize>,
137        sample_statistics: GaussCPDS,
138    ) -> GaussCPD {
139        // Get the sample covariance matrices and size.
140        let (mu_x, mu_z, s_xx, s_xz, s_zz, n) = (
141            sample_statistics.sample_response_mean(),
142            sample_statistics.sample_design_mean(),
143            sample_statistics.sample_response_covariance(),
144            sample_statistics.sample_cross_covariance(),
145            sample_statistics.sample_design_covariance(),
146            sample_statistics.sample_size(),
147        );
148
149        // Compute the parameters in closed form.
150        let (a, b, s) = if z.is_empty() {
151            // Compute the parameters as the empirical mean and covariance.
152            let a = Array2::zeros((x.len(), 0));
153            let b = mu_x.clone();
154            let s = s_xx / n;
155            // Return the parameters.
156            (a, b, s)
157        } else {
158            // Compute the pseudo-inverse of S_zz.
159            let s_zz_pinv = s_zz.pinv();
160            // Compute the coefficient matrix.
161            let a = s_xz.dot(&s_zz_pinv);
162            // Compute the intercept vector.
163            let b = mu_x - &a.dot(mu_z);
164            // Compute the covariance matrix.
165            let s = (s_xx - &a.dot(&s_xz.t())) / n;
166            // Return the parameters.
167            (a, b, s)
168        };
169
170        // Compute the sample log-likelihood.
171        let p = x.len() as f64;
172        let (_, ln_det) = s.sln_det().expect("Failed to compute determinant of S.");
173        let sample_log_likelihood = -0.5 * n * (p * LN_2_PI + ln_det + p);
174
175        // Construct the CPD parameters.
176        let parameters = GaussCPDP::new(a, b, s);
177
178        // Subset the conditioning labels, states and shape.
179        let conditioning_labels = z.iter().map(|&i| labels[i].clone()).collect();
180        // Get the labels of the conditioned variables.
181        let labels = x.iter().map(|&i| labels[i].clone()).collect();
182
183        // Wrap the sample statistics in an option.
184        let sample_statistics = Some(sample_statistics);
185        // Wrap the sample log-likelihood in an option.
186        let sample_log_likelihood = Some(sample_log_likelihood);
187
188        // Construct the CPD.
189        GaussCPD::with_optionals(
190            labels,
191            conditioning_labels,
192            parameters,
193            sample_statistics,
194            sample_log_likelihood,
195        )
196    }
197}
198
199// Implement the GaussCPD estimator for the MLE struct.
200macro_for!($type in [GaussTable, GaussWtdTable] {
201
202    impl CPDEstimator<GaussCPD> for MLE<'_, $type> {
203        fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> GaussCPD {
204            // Get labels.
205            let labels = self.dataset.labels();
206            // Compute sufficient statistics.
207            let sample_statistics = SSE::new(self.dataset).fit(x, z);
208            // Fit the CPD given the sufficient statistics.
209            MLE::<'_, GaussTable>::fit(labels, x, z, sample_statistics)
210        }
211    }
212
213    impl ParCPDEstimator<GaussCPD> for MLE<'_, $type> {
214        fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> GaussCPD {
215            // Get labels.
216            let labels = self.dataset.labels();
217            // Compute sufficient statistics in parallel.
218            let sample_statistics = SSE::new(self.dataset).par_fit(x, z);
219            // Fit the CPD given the sufficient statistics.
220            MLE::<'_, GaussTable>::fit(labels, x, z, sample_statistics)
221        }
222    }
223
224});
225
226impl MLE<'_, CatTrj> {
227    // Fit a CIM given sufficient statistics.
228    fn fit(states: &States, x: &Set<usize>, z: &Set<usize>, sample_statistics: CatCIMS) -> CatCIM {
229        // Get the conditional counts and times.
230        let n_xz = sample_statistics.sample_conditional_counts();
231        let t_xz = sample_statistics.sample_conditional_times();
232
233        // Assert the conditional times counts are not zero.
234        assert!(
235            t_xz.iter().all(|&x| x > 0.),
236            "Failed to get non-zero conditional times."
237        );
238
239        // Insert axis to align the dimensions.
240        let t_xz = &t_xz.clone().insert_axis(Axis(2));
241
242        // Estimate the parameters by normalizing the counts.
243        let mut parameters = n_xz / t_xz;
244        // Fix the diagonal.
245        parameters.outer_iter_mut().for_each(|mut q| {
246            // Fill the diagonal with zeros.
247            q.diag_mut().fill(0.);
248            // Compute the negative sum of the rows.
249            let q_neg_sum = -q.sum_axis(Axis(1));
250            // Assign the negative sum to the diagonal.
251            q.diag_mut().assign(&q_neg_sum);
252        });
253
254        // Set epsilon to avoid ln(0).
255        let eps = f64::MIN_POSITIVE;
256        // Compute the sample log-likelihood, avoiding ln(0).
257        let sample_log_likelihood = {
258            // Compute the sample log-likelihood.
259            let ll_q_xz = {
260                // Sum counts, aligning the dimensions.
261                let n_z = n_xz.sum_axis(Axis(2));
262                let t_z = t_xz.sum_axis(Axis(2));
263                // Clone the parameters.
264                let mut q_z = Array::zeros(n_z.dim());
265                // Get the diagonals.
266                parameters
267                    .outer_iter()
268                    .zip(q_z.outer_iter_mut())
269                    .for_each(|(p, mut q)| {
270                        q.assign(&(-&p.diag()));
271                    });
272                // Compute the sample log-likelihood.
273                (&n_z * (&q_z + eps).ln()).sum() + (-&q_z * &t_z).sum()
274            };
275            // Compute the sample log-likelihood.
276            let ll_p_xz = {
277                // Clone the parameters.
278                let mut p_xz = parameters.clone();
279                // Set diagonal to zero.
280                p_xz.outer_iter_mut().for_each(|mut p| {
281                    // Fill the diagonal with zeros.
282                    p.diag_mut().fill(0.);
283                });
284                // Normalize the parameters, align the dimensions.
285                p_xz /= &p_xz.sum_axis(Axis(2)).insert_axis(Axis(2));
286                // Compute the sample log-likelihood.
287                (n_xz * (p_xz + eps).ln()).sum()
288            };
289            // Return the total log-likelihood.
290            ll_q_xz + ll_p_xz
291        };
292
293        // Subset the conditioning labels, states and shape.
294        let conditioning_states = z
295            .iter()
296            .map(|&i| {
297                let (k, v) = states.get_index(i).unwrap();
298                (k.clone(), v.clone())
299            })
300            .collect();
301        // Get the labels of the conditioned variables.
302        let states = x
303            .iter()
304            .map(|&i| {
305                let (k, v) = states.get_index(i).unwrap();
306                (k.clone(), v.clone())
307            })
308            .collect();
309
310        // Wrap the sufficient statistics in an option.
311        let sample_statistics = Some(sample_statistics);
312        // Wrap the sample log-likelihood in an option.
313        let sample_log_likelihood = Some(sample_log_likelihood);
314
315        // Construct the CIM.
316        CatCIM::with_optionals(
317            states,
318            conditioning_states,
319            parameters,
320            sample_statistics,
321            sample_log_likelihood,
322        )
323    }
324}
325
326// Implement the CatCIM estimator for the MLE struct.
327macro_for!($type in [CatTrj, CatWtdTrj, CatTrjs, CatWtdTrjs] {
328
329    impl CIMEstimator<CatCIM> for MLE<'_, $type> {
330        fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIM {
331            // Get states.
332            let states = self.dataset.states();
333            // Compute sufficient statistics.
334            let sample_statistics = SSE::new(self.dataset).fit(x, z);
335            // Fit the CIM given the sufficient statistics.
336            MLE::<'_, CatTrj>::fit(states, x, z, sample_statistics)
337        }
338    }
339
340});
341
342// Implement the parallel version of the CIM estimator for the MLE struct.
343macro_for!($type in [CatTrjs, CatWtdTrjs] {
344
345    impl ParCIMEstimator<CatCIM> for MLE<'_, $type> {
346        fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIM {
347            // Get states.
348            let states = self.dataset.states();
349            // Compute sufficient statistics in parallel.
350            let sample_statistics = SSE::new(self.dataset).par_fit(x, z);
351            // Fit the CIM given the sufficient statistics.
352            MLE::<'_, CatTrj>::fit(states, x, z, sample_statistics)
353        }
354    }
355
356});