causal_hub/estimators/parameters/bayesian.rs

1use dry::macro_for;
2use ndarray::prelude::*;
3use statrs::function::gamma::ln_gamma;
4
5use crate::{
6    datasets::{CatTable, CatTrj, CatTrjs, CatWtdTable, CatWtdTrj, CatWtdTrjs},
7    estimators::{
8        CIMEstimator, CPDEstimator, CSSEstimator, ParCIMEstimator, ParCPDEstimator,
9        ParCSSEstimator, SSE,
10    },
11    models::{CatCIM, CatCIMS, CatCPD, CatCPDS, Labelled},
12    types::{Labels, Set, States},
13};
14
/// A Bayesian parameter estimator.
///
/// The estimator borrows a dataset of type `D` and carries a prior of type `T`.
/// An estimator built with `BE::new` starts with the unit prior `()`, which the
/// estimator impls in this module treat as "use the default prior".
#[derive(Debug)]
pub struct BE<'a, D, T> {
    /// The dataset the estimator is fitted to.
    dataset: &'a D,
    /// The prior distribution, or `()` for the default prior.
    prior: T,
}

// `Clone` and `Copy` are implemented by hand instead of derived: the derive
// would demand `D: Clone` / `D: Copy`, yet only a shared reference `&D` is
// stored, and a shared reference is always `Copy`. Only the prior needs a bound.
impl<D, T: Clone> Clone for BE<'_, D, T> {
    #[inline]
    fn clone(&self) -> Self {
        Self {
            dataset: self.dataset,
            prior: self.prior.clone(),
        }
    }
}

impl<D, T: Copy> Copy for BE<'_, D, T> {}
21
22impl<'a, D> BE<'a, D, ()> {
23    /// Creates a new Bayesian estimator.
24    ///
25    /// # Arguments
26    ///
27    /// * `dataset` - A reference to the dataset to fit the estimator to.
28    ///
29    /// # Returns
30    ///
31    /// A new Bayesian estimator.
32    ///
33    #[inline]
34    pub const fn new(dataset: &'a D) -> Self {
35        Self { dataset, prior: () }
36    }
37}
38
39impl<'a, D, T> BE<'a, D, T> {
40    /// Sets the prior distribution.
41    ///
42    /// # Arguments
43    ///
44    /// * `prior` - The prior distribution to set.
45    ///
46    /// # Returns
47    ///
48    /// A new Bayesian estimator with the specified prior.
49    ///
50    #[inline]
51    pub fn with_prior<U>(self, prior: U) -> BE<'a, D, U> {
52        BE {
53            dataset: self.dataset,
54            prior,
55        }
56    }
57}
58
59impl<D, T> Labelled for BE<'_, D, T>
60where
61    D: Labelled,
62{
63    #[inline]
64    fn labels(&self) -> &Labels {
65        self.dataset.labels()
66    }
67}
68
69impl BE<'_, CatTable, usize> {
70    // Fit a CPD given sufficient statistics.
71    fn fit(
72        states: &States,
73        shape: &Array1<usize>,
74        x: &Set<usize>,
75        z: &Set<usize>,
76        sample_statistics: CatCPDS,
77        prior: usize,
78    ) -> CatCPD {
79        // Get the conditional counts.
80        let n_xz = sample_statistics.sample_conditional_counts();
81        // Marginalize the counts.
82        let n_z = n_xz.sum_axis(Axis(1)).insert_axis(Axis(1));
83
84        // Get the prior, as the alpha of the Dirichlet distribution.
85        let alpha = prior;
86        // Assert alpha is positive.
87        assert!(alpha > 0, "Alpha must be positive.");
88
89        // Cast alpha to floating point.
90        let alpha = alpha as f64;
91
92        // Add the prior to the counts.
93        let n_xz = n_xz + alpha;
94        let n_z = n_z + alpha * x.iter().map(|&i| shape[i]).product::<usize>() as f64;
95        // Compute the parameters by normalizing the counts with the prior.
96        let parameters = &n_xz / &n_z;
97
98        // Compute the sample log-likelihood.
99        let sample_log_likelihood = Some((&n_xz * parameters.ln()).sum());
100
101        // Subset the conditioning labels, states and shape.
102        let conditioning_states = z
103            .iter()
104            .map(|&i| {
105                let (k, v) = states.get_index(i).unwrap();
106                (k.clone(), v.clone())
107            })
108            .collect();
109        // Get the labels of the conditioned variables.
110        let states = x
111            .iter()
112            .map(|&i| {
113                let (k, v) = states.get_index(i).unwrap();
114                (k.clone(), v.clone())
115            })
116            .collect();
117
118        // Wrap the sample statistics in an option.
119        let sample_statistics = Some(sample_statistics);
120
121        // Construct the CPD.
122        CatCPD::with_optionals(
123            states,
124            conditioning_states,
125            parameters,
126            sample_statistics,
127            sample_log_likelihood,
128        )
129    }
130}
131
132// Implement the CPD estimator for the BE struct.
133macro_for!($type in [CatTable, CatWtdTable] {
134
135    impl CPDEstimator<CatCPD> for BE<'_, $type, ()> {
136        #[inline]
137        fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCPD {
138            // Default to uniform prior.
139            self.clone().with_prior(1).fit(x, z)
140        }
141    }
142
143    impl CPDEstimator<CatCPD> for BE<'_, $type, usize> {
144        #[inline]
145        fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCPD {
146            // Get (states, shape, prior).
147            let (states, shape, prior) = (self.dataset.states(), self.dataset.shape(), self.prior);
148            // Compute sufficient statistics.
149            let sample_statistics = SSE::new(self.dataset).fit(x, z);
150            // Fit the CPD given the sufficient statistics.
151            BE::<'_, CatTable, _>::fit(states, shape, x, z, sample_statistics, prior)
152        }
153    }
154
155    impl ParCPDEstimator<CatCPD> for BE<'_, $type, ()> {
156        #[inline]
157        fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCPD {
158            // Default to uniform prior.
159            self.clone().with_prior(1).fit(x, z)
160        }
161    }
162
163    impl ParCPDEstimator<CatCPD> for BE<'_, $type, usize> {
164        #[inline]
165        fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCPD {
166            // Get (states, shape, prior).
167            let (states, shape, prior) = (self.dataset.states(), self.dataset.shape(), self.prior);
168            // Compute sufficient statistics in parallel.
169            let sample_statistics = SSE::new(self.dataset).par_fit(x, z);
170            // Fit the CPD given the sufficient statistics.
171            BE::<'_, CatTable, _>::fit(states, shape, x, z, sample_statistics, prior)
172        }
173    }
174
175});
176
impl BE<'_, CatTrj, (usize, f64)> {
    /// Fits a categorical CIM of `x` given `z` from precomputed sufficient statistics.
    ///
    /// The prior is `(alpha, tau)`. `alpha` is a pseudo-count added to every
    /// cell of the conditional counts, and `tau` is a pseudo-time added to
    /// every conditional time. Both are first divided by the length of the
    /// counts' first axis.
    ///
    /// # Arguments
    ///
    /// * `states` - The states of every variable in the dataset.
    /// * `x` - The indices of the conditioned variables.
    /// * `z` - The indices of the conditioning variables.
    /// * `sample_statistics` - The sufficient statistics of `x` given `z`.
    /// * `prior` - The `(alpha, tau)` prior.
    ///
    /// # Panics
    ///
    /// Panics if `alpha` is zero, if `tau` is not strictly positive (NaN
    /// included), or if an index in `x` or `z` is not present in `states`.
    ///
    /// # Returns
    ///
    /// The fitted CIM, carrying the sufficient statistics and a log-likelihood term.
    fn fit(
        states: &States,
        x: &Set<usize>,
        z: &Set<usize>,
        sample_statistics: CatCIMS,
        prior: (usize, f64),
    ) -> CatCIM {
        // Unpack the prior: `alpha` (Dirichlet pseudo-count) and `tau` (Gamma pseudo-time).
        let (alpha, tau) = prior;
        // `alpha` is unsigned, so this only rejects zero.
        assert!(alpha > 0, "Alpha must be positive.");
        // Also rejects NaN, because `NaN > 0.0` is false.
        assert!(tau > 0.0, "Tau must be positive.");

        // Get the conditional counts and times.
        let n_xz = sample_statistics.sample_conditional_counts();
        let t_xz = sample_statistics.sample_conditional_times();

        // Append a trailing length-1 axis (axis 2) to the times so that each
        // time entry broadcasts across the last axis of the counts below.
        let t_xz = &t_xz.clone().insert_axis(Axis(2));

        // Length of the counts' first axis; presumably the number of
        // conditioning configurations of `z` — TODO confirm.
        let s_z = n_xz.shape()[0] as f64;
        // Spread the prior evenly across that axis.
        let alpha = alpha as f64 / s_z;
        let tau = tau / s_z;

        // Add the prior to every count cell (diagonal included) and every time.
        let n_xz = n_xz + alpha;
        let t_xz = t_xz + tau;
        // Estimate the rates as pseudo-counts divided by pseudo-times.
        let mut parameters = &n_xz / &t_xz;
        // Rebuild the diagonal of every square matrix along the first axis so
        // that each row sums to zero, as an intensity matrix requires.
        parameters.outer_iter_mut().for_each(|mut q| {
            // Discard the diagonal values computed above.
            q.diag_mut().fill(0.);
            // Negated row sums, i.e. minus the sum of the off-diagonal entries.
            let q_neg_sum = -q.sum_axis(Axis(1));
            // Write them onto the diagonal.
            q.diag_mut().assign(&q_neg_sum);
        });

        // Compute the log-likelihood term stored on the CIM. At this point
        // `n_xz` and `t_xz` already include the prior.
        let sample_log_likelihood = Some({
            // Totals over the last axis.
            let n_z = n_xz.sum_axis(Axis(2));
            let t_z = t_xz.sum_axis(Axis(2));
            // Appears to be the Gamma (waiting-time) part of the CTBN Bayesian score.
            // NOTE(review): in that score `(alpha + 1) ln(tau)` and
            // `ln_gamma(alpha + 1)` appear once per (z, x) cell, whereas here
            // they are added only once in total; confirm.
            let ll_q_xz = {
                (&n_z + 1.).mapv(ln_gamma).sum() + (alpha + 1.) * f64::ln(tau) //.
                - (ln_gamma(alpha + 1.) + ((&n_z + 1.) * &t_z.ln()).sum())
            };
            // Appears to be the Dirichlet (transition) part of the CTBN Bayesian score.
            // NOTE(review): the usual form adds ln_gamma(alpha + M[x, x']) and
            // subtracts ln_gamma(alpha) per cell, and uses the row sum of alpha
            // in the first bracket. Here `ln_gamma(alpha)` is added once per
            // bracket and the second bracket *subtracts* the ln-gamma of the
            // counts, so the sign and multiplicity look off; confirm.
            let ll_p_xz = {
                (ln_gamma(alpha) - n_z.mapv(ln_gamma).sum())     //.
                + (ln_gamma(alpha) - n_xz.mapv(ln_gamma).sum())
            };
            // Return the total log-likelihood.
            ll_q_xz + ll_p_xz
        });

        // Copy out the (label, states) entries of the conditioning variables.
        let conditioning_states = z
            .iter()
            .map(|&i| {
                let (k, v) = states.get_index(i).unwrap();
                (k.clone(), v.clone())
            })
            .collect();
        // Copy out the (label, states) entries of the conditioned variables.
        let states = x
            .iter()
            .map(|&i| {
                let (k, v) = states.get_index(i).unwrap();
                (k.clone(), v.clone())
            })
            .collect();

        // Wrap the sufficient statistics in an option.
        let sample_statistics = Some(sample_statistics);

        // Construct the CIM.
        CatCIM::with_optionals(
            states,
            conditioning_states,
            parameters,
            sample_statistics,
            sample_log_likelihood,
        )
    }
}
272
273// Implement the CIM estimator for the BE struct.
274macro_for!($type in [CatTrj, CatWtdTrj, CatTrjs, CatWtdTrjs] {
275
276    impl CIMEstimator<CatCIM> for BE<'_, $type, ()> {
277        #[inline]
278        fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIM {
279            // Default to uniform prior.
280            self.clone().with_prior((1, 1.)).fit(x, z)
281        }
282    }
283
284    impl CIMEstimator<CatCIM> for BE<'_, $type, (usize, f64)> {
285        #[inline]
286        fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIM {
287            // Get (states, prior).
288            let (states, prior) = (self.dataset.states(), self.prior);
289            // Compute sufficient statistics.
290            let sample_statistics = SSE::new(self.dataset).fit(x, z);
291            // Fit the CIM given the sufficient statistics.
292            BE::<'_, CatTrj, _>::fit(states, x, z, sample_statistics, prior)
293        }
294    }
295
296});
297
298// Implement the parallel CIM estimator for the BE struct.
299macro_for!($type in [CatTrjs, CatWtdTrjs] {
300
301    impl ParCIMEstimator<CatCIM> for BE<'_, $type, ()> {
302        #[inline]
303        fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIM {
304            // Default to uniform prior.
305            self.clone().with_prior((1, 1.)).fit(x, z)
306        }
307    }
308
309    impl ParCIMEstimator<CatCIM> for BE<'_, $type, (usize, f64)> {
310        #[inline]
311        fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIM {
312            // Get (states, prior).
313            let (states, prior) = (self.dataset.states(), self.prior);
314            // Compute sufficient statistics in parallel.
315            let sample_statistics = SSE::new(self.dataset).par_fit(x, z);
316            // Fit the CIM given the sufficient statistics.
317            BE::<'_, CatTrj, _>::fit(states, x, z, sample_statistics, prior)
318        }
319    }
320
321});