1use dry::macro_for;
2use ndarray::prelude::*;
3use statrs::function::gamma::ln_gamma;
4
5use crate::{
6 datasets::{CatTable, CatTrj, CatTrjs, CatWtdTable, CatWtdTrj, CatWtdTrjs},
7 estimators::{
8 CIMEstimator, CPDEstimator, CSSEstimator, ParCIMEstimator, ParCPDEstimator,
9 ParCSSEstimator, SSE,
10 },
11 models::{CatCIM, CatCIMS, CatCPD, CatCPDS, Labelled},
12 types::{Labels, Set, States},
13};
14
/// A Bayesian estimator (BE) that fits models from a borrowed dataset.
///
/// The estimator stores only a reference to the dataset plus a prior of type
/// `T`. With `T = ()` no prior has been chosen yet; the estimator trait
/// implementations in this module then substitute a default prior. Use
/// `with_prior` to attach an explicit one.
#[derive(Debug)]
pub struct BE<'a, D, T> {
    dataset: &'a D,
    prior: T,
}

// `Clone` and `Copy` are implemented by hand rather than derived: `derive`
// would add `D: Clone` / `D: Copy` bounds, yet the estimator only holds `&D`,
// which is always `Copy`. Only the prior has to be clonable / copyable.
impl<D, T: Clone> Clone for BE<'_, D, T> {
    #[inline]
    fn clone(&self) -> Self {
        Self {
            dataset: self.dataset,
            prior: self.prior.clone(),
        }
    }
}

impl<D, T: Copy> Copy for BE<'_, D, T> {}
21
22impl<'a, D> BE<'a, D, ()> {
23 #[inline]
34 pub const fn new(dataset: &'a D) -> Self {
35 Self { dataset, prior: () }
36 }
37}
38
39impl<'a, D, T> BE<'a, D, T> {
40 #[inline]
51 pub fn with_prior<U>(self, prior: U) -> BE<'a, D, U> {
52 BE {
53 dataset: self.dataset,
54 prior,
55 }
56 }
57}
58
59impl<D, T> Labelled for BE<'_, D, T>
60where
61 D: Labelled,
62{
63 #[inline]
64 fn labels(&self) -> &Labels {
65 self.dataset.labels()
66 }
67}
68
69impl BE<'_, CatTable, usize> {
70 fn fit(
72 states: &States,
73 shape: &Array1<usize>,
74 x: &Set<usize>,
75 z: &Set<usize>,
76 sample_statistics: CatCPDS,
77 prior: usize,
78 ) -> CatCPD {
79 let n_xz = sample_statistics.sample_conditional_counts();
81 let n_z = n_xz.sum_axis(Axis(1)).insert_axis(Axis(1));
83
84 let alpha = prior;
86 assert!(alpha > 0, "Alpha must be positive.");
88
89 let alpha = alpha as f64;
91
92 let n_xz = n_xz + alpha;
94 let n_z = n_z + alpha * x.iter().map(|&i| shape[i]).product::<usize>() as f64;
95 let parameters = &n_xz / &n_z;
97
98 let sample_log_likelihood = Some((&n_xz * parameters.ln()).sum());
100
101 let conditioning_states = z
103 .iter()
104 .map(|&i| {
105 let (k, v) = states.get_index(i).unwrap();
106 (k.clone(), v.clone())
107 })
108 .collect();
109 let states = x
111 .iter()
112 .map(|&i| {
113 let (k, v) = states.get_index(i).unwrap();
114 (k.clone(), v.clone())
115 })
116 .collect();
117
118 let sample_statistics = Some(sample_statistics);
120
121 CatCPD::with_optionals(
123 states,
124 conditioning_states,
125 parameters,
126 sample_statistics,
127 sample_log_likelihood,
128 )
129 }
130}
131
132macro_for!($type in [CatTable, CatWtdTable] {
134
135 impl CPDEstimator<CatCPD> for BE<'_, $type, ()> {
136 #[inline]
137 fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCPD {
138 self.clone().with_prior(1).fit(x, z)
140 }
141 }
142
143 impl CPDEstimator<CatCPD> for BE<'_, $type, usize> {
144 #[inline]
145 fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCPD {
146 let (states, shape, prior) = (self.dataset.states(), self.dataset.shape(), self.prior);
148 let sample_statistics = SSE::new(self.dataset).fit(x, z);
150 BE::<'_, CatTable, _>::fit(states, shape, x, z, sample_statistics, prior)
152 }
153 }
154
155 impl ParCPDEstimator<CatCPD> for BE<'_, $type, ()> {
156 #[inline]
157 fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCPD {
158 self.clone().with_prior(1).fit(x, z)
160 }
161 }
162
163 impl ParCPDEstimator<CatCPD> for BE<'_, $type, usize> {
164 #[inline]
165 fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCPD {
166 let (states, shape, prior) = (self.dataset.states(), self.dataset.shape(), self.prior);
168 let sample_statistics = SSE::new(self.dataset).par_fit(x, z);
170 BE::<'_, CatTable, _>::fit(states, shape, x, z, sample_statistics, prior)
172 }
173 }
174
175});
176
177impl BE<'_, CatTrj, (usize, f64)> {
178 fn fit(
180 states: &States,
181 x: &Set<usize>,
182 z: &Set<usize>,
183 sample_statistics: CatCIMS,
184 prior: (usize, f64),
185 ) -> CatCIM {
186 let (alpha, tau) = prior;
188 assert!(alpha > 0, "Alpha must be positive.");
190 assert!(tau > 0.0, "Tau must be positive.");
192
193 let n_xz = sample_statistics.sample_conditional_counts();
195 let t_xz = sample_statistics.sample_conditional_times();
196
197 let t_xz = &t_xz.clone().insert_axis(Axis(2));
199
200 let s_z = n_xz.shape()[0] as f64;
202 let alpha = alpha as f64 / s_z;
204 let tau = tau / s_z;
205
206 let n_xz = n_xz + alpha;
208 let t_xz = t_xz + tau;
209 let mut parameters = &n_xz / &t_xz;
211 parameters.outer_iter_mut().for_each(|mut q| {
213 q.diag_mut().fill(0.);
215 let q_neg_sum = -q.sum_axis(Axis(1));
217 q.diag_mut().assign(&q_neg_sum);
219 });
220
221 let sample_log_likelihood = Some({
223 let n_z = n_xz.sum_axis(Axis(2));
225 let t_z = t_xz.sum_axis(Axis(2));
226 let ll_q_xz = {
228 (&n_z + 1.).mapv(ln_gamma).sum() + (alpha + 1.) * f64::ln(tau) - (ln_gamma(alpha + 1.) + ((&n_z + 1.) * &t_z.ln()).sum())
231 };
232 let ll_p_xz = {
234 (ln_gamma(alpha) - n_z.mapv(ln_gamma).sum()) + (ln_gamma(alpha) - n_xz.mapv(ln_gamma).sum())
237 };
238 ll_q_xz + ll_p_xz
240 });
241
242 let conditioning_states = z
244 .iter()
245 .map(|&i| {
246 let (k, v) = states.get_index(i).unwrap();
247 (k.clone(), v.clone())
248 })
249 .collect();
250 let states = x
252 .iter()
253 .map(|&i| {
254 let (k, v) = states.get_index(i).unwrap();
255 (k.clone(), v.clone())
256 })
257 .collect();
258
259 let sample_statistics = Some(sample_statistics);
261
262 CatCIM::with_optionals(
264 states,
265 conditioning_states,
266 parameters,
267 sample_statistics,
268 sample_log_likelihood,
269 )
270 }
271}
272
273macro_for!($type in [CatTrj, CatWtdTrj, CatTrjs, CatWtdTrjs] {
275
276 impl CIMEstimator<CatCIM> for BE<'_, $type, ()> {
277 #[inline]
278 fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIM {
279 self.clone().with_prior((1, 1.)).fit(x, z)
281 }
282 }
283
284 impl CIMEstimator<CatCIM> for BE<'_, $type, (usize, f64)> {
285 #[inline]
286 fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIM {
287 let (states, prior) = (self.dataset.states(), self.prior);
289 let sample_statistics = SSE::new(self.dataset).fit(x, z);
291 BE::<'_, CatTrj, _>::fit(states, x, z, sample_statistics, prior)
293 }
294 }
295
296});
297
298macro_for!($type in [CatTrjs, CatWtdTrjs] {
300
301 impl ParCIMEstimator<CatCIM> for BE<'_, $type, ()> {
302 #[inline]
303 fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIM {
304 self.clone().with_prior((1, 1.)).fit(x, z)
306 }
307 }
308
309 impl ParCIMEstimator<CatCIM> for BE<'_, $type, (usize, f64)> {
310 #[inline]
311 fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIM {
312 let (states, prior) = (self.dataset.states(), self.prior);
314 let sample_statistics = SSE::new(self.dataset).par_fit(x, z);
316 BE::<'_, CatTrj, _>::fit(states, x, z, sample_statistics, prior)
318 }
319 }
320
321});