causal_hub/estimators/parameters/
maximum_likelihood.rs1use dry::macro_for;
2use ndarray::prelude::*;
3use ndarray_linalg::Determinant;
4
5use crate::{
6 datasets::{
7 CatTable, CatTrj, CatTrjs, CatWtdTable, CatWtdTrj, CatWtdTrjs, GaussTable, GaussWtdTable,
8 },
9 estimators::{
10 CIMEstimator, CPDEstimator, CSSEstimator, ParCIMEstimator, ParCPDEstimator,
11 ParCSSEstimator, SSE,
12 },
13 models::{CatCIM, CatCIMS, CatCPD, CatCPDS, GaussCPD, GaussCPDP, GaussCPDS, Labelled},
14 types::{LN_2_PI, Labels, Set, States},
15 utils::PseudoInverse,
16};
17
/// Maximum-likelihood estimator that borrows the dataset it fits on.
///
/// The struct only stores a shared reference, so it is `Clone` and `Copy`
/// for every `D`, whether or not `D` itself is.
#[derive(Debug)]
pub struct MLE<'a, D> {
    /// Dataset the estimator reads from.
    dataset: &'a D,
}

// `Clone`/`Copy` are written by hand: the derive would add `D: Clone` and
// `D: Copy` bounds, which are unnecessary for a struct that only holds `&D`.
impl<D> Clone for MLE<'_, D> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<D> Copy for MLE<'_, D> {}
23
24impl<'a, D> MLE<'a, D> {
25 #[inline]
36 pub const fn new(dataset: &'a D) -> Self {
37 Self { dataset }
38 }
39}
40
41impl<D> Labelled for MLE<'_, D>
42where
43 D: Labelled,
44{
45 #[inline]
46 fn labels(&self) -> &Labels {
47 self.dataset.labels()
48 }
49}
50
51impl MLE<'_, CatTable> {
52 fn fit(states: &States, x: &Set<usize>, z: &Set<usize>, sample_statistics: CatCPDS) -> CatCPD {
53 let n_xz = sample_statistics.sample_conditional_counts();
55 let n_z = &n_xz.sum_axis(Axis(1)).insert_axis(Axis(1));
57
58 assert!(
60 n_z.iter().all(|&x| x > 0.),
61 "Failed to get non-zero counts.",
62 );
63
64 let parameters = n_xz / n_z;
66
67 let eps = f64::MIN_POSITIVE;
69 let sample_log_likelihood = (n_xz * (¶meters + eps).ln()).sum();
71
72 let conditioning_states = z
74 .iter()
75 .map(|&i| {
76 let (k, v) = states.get_index(i).unwrap();
77 (k.clone(), v.clone())
78 })
79 .collect();
80 let states = x
82 .iter()
83 .map(|&i| {
84 let (k, v) = states.get_index(i).unwrap();
85 (k.clone(), v.clone())
86 })
87 .collect();
88
89 let sample_statistics = Some(sample_statistics);
91 let sample_log_likelihood = Some(sample_log_likelihood);
93
94 CatCPD::with_optionals(
96 states,
97 conditioning_states,
98 parameters,
99 sample_statistics,
100 sample_log_likelihood,
101 )
102 }
103}
104
105macro_for!($type in [CatTable, CatWtdTable] {
107
108 impl CPDEstimator<CatCPD> for MLE<'_, $type> {
109 fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCPD {
110 let states = self.dataset.states();
112 let sample_statistics = SSE::new(self.dataset).fit(x, z);
114 MLE::<'_, CatTable>::fit(states, x, z, sample_statistics)
116 }
117 }
118
119 impl ParCPDEstimator<CatCPD> for MLE<'_, $type> {
120 fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCPD {
121 let states = self.dataset.states();
123 let sample_statistics = SSE::new(self.dataset).par_fit(x, z);
125 MLE::<'_, CatTable>::fit(states, x, z, sample_statistics)
127 }
128 }
129
130});
131
132impl MLE<'_, GaussTable> {
133 fn fit(
134 labels: &Labels,
135 x: &Set<usize>,
136 z: &Set<usize>,
137 sample_statistics: GaussCPDS,
138 ) -> GaussCPD {
139 let (mu_x, mu_z, s_xx, s_xz, s_zz, n) = (
141 sample_statistics.sample_response_mean(),
142 sample_statistics.sample_design_mean(),
143 sample_statistics.sample_response_covariance(),
144 sample_statistics.sample_cross_covariance(),
145 sample_statistics.sample_design_covariance(),
146 sample_statistics.sample_size(),
147 );
148
149 let (a, b, s) = if z.is_empty() {
151 let a = Array2::zeros((x.len(), 0));
153 let b = mu_x.clone();
154 let s = s_xx / n;
155 (a, b, s)
157 } else {
158 let s_zz_pinv = s_zz.pinv();
160 let a = s_xz.dot(&s_zz_pinv);
162 let b = mu_x - &a.dot(mu_z);
164 let s = (s_xx - &a.dot(&s_xz.t())) / n;
166 (a, b, s)
168 };
169
170 let p = x.len() as f64;
172 let (_, ln_det) = s.sln_det().expect("Failed to compute determinant of S.");
173 let sample_log_likelihood = -0.5 * n * (p * LN_2_PI + ln_det + p);
174
175 let parameters = GaussCPDP::new(a, b, s);
177
178 let conditioning_labels = z.iter().map(|&i| labels[i].clone()).collect();
180 let labels = x.iter().map(|&i| labels[i].clone()).collect();
182
183 let sample_statistics = Some(sample_statistics);
185 let sample_log_likelihood = Some(sample_log_likelihood);
187
188 GaussCPD::with_optionals(
190 labels,
191 conditioning_labels,
192 parameters,
193 sample_statistics,
194 sample_log_likelihood,
195 )
196 }
197}
198
199macro_for!($type in [GaussTable, GaussWtdTable] {
201
202 impl CPDEstimator<GaussCPD> for MLE<'_, $type> {
203 fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> GaussCPD {
204 let labels = self.dataset.labels();
206 let sample_statistics = SSE::new(self.dataset).fit(x, z);
208 MLE::<'_, GaussTable>::fit(labels, x, z, sample_statistics)
210 }
211 }
212
213 impl ParCPDEstimator<GaussCPD> for MLE<'_, $type> {
214 fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> GaussCPD {
215 let labels = self.dataset.labels();
217 let sample_statistics = SSE::new(self.dataset).par_fit(x, z);
219 MLE::<'_, GaussTable>::fit(labels, x, z, sample_statistics)
221 }
222 }
223
224});
225
226impl MLE<'_, CatTrj> {
227 fn fit(states: &States, x: &Set<usize>, z: &Set<usize>, sample_statistics: CatCIMS) -> CatCIM {
229 let n_xz = sample_statistics.sample_conditional_counts();
231 let t_xz = sample_statistics.sample_conditional_times();
232
233 assert!(
235 t_xz.iter().all(|&x| x > 0.),
236 "Failed to get non-zero conditional times."
237 );
238
239 let t_xz = &t_xz.clone().insert_axis(Axis(2));
241
242 let mut parameters = n_xz / t_xz;
244 parameters.outer_iter_mut().for_each(|mut q| {
246 q.diag_mut().fill(0.);
248 let q_neg_sum = -q.sum_axis(Axis(1));
250 q.diag_mut().assign(&q_neg_sum);
252 });
253
254 let eps = f64::MIN_POSITIVE;
256 let sample_log_likelihood = {
258 let ll_q_xz = {
260 let n_z = n_xz.sum_axis(Axis(2));
262 let t_z = t_xz.sum_axis(Axis(2));
263 let mut q_z = Array::zeros(n_z.dim());
265 parameters
267 .outer_iter()
268 .zip(q_z.outer_iter_mut())
269 .for_each(|(p, mut q)| {
270 q.assign(&(-&p.diag()));
271 });
272 (&n_z * (&q_z + eps).ln()).sum() + (-&q_z * &t_z).sum()
274 };
275 let ll_p_xz = {
277 let mut p_xz = parameters.clone();
279 p_xz.outer_iter_mut().for_each(|mut p| {
281 p.diag_mut().fill(0.);
283 });
284 p_xz /= &p_xz.sum_axis(Axis(2)).insert_axis(Axis(2));
286 (n_xz * (p_xz + eps).ln()).sum()
288 };
289 ll_q_xz + ll_p_xz
291 };
292
293 let conditioning_states = z
295 .iter()
296 .map(|&i| {
297 let (k, v) = states.get_index(i).unwrap();
298 (k.clone(), v.clone())
299 })
300 .collect();
301 let states = x
303 .iter()
304 .map(|&i| {
305 let (k, v) = states.get_index(i).unwrap();
306 (k.clone(), v.clone())
307 })
308 .collect();
309
310 let sample_statistics = Some(sample_statistics);
312 let sample_log_likelihood = Some(sample_log_likelihood);
314
315 CatCIM::with_optionals(
317 states,
318 conditioning_states,
319 parameters,
320 sample_statistics,
321 sample_log_likelihood,
322 )
323 }
324}
325
326macro_for!($type in [CatTrj, CatWtdTrj, CatTrjs, CatWtdTrjs] {
328
329 impl CIMEstimator<CatCIM> for MLE<'_, $type> {
330 fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIM {
331 let states = self.dataset.states();
333 let sample_statistics = SSE::new(self.dataset).fit(x, z);
335 MLE::<'_, CatTrj>::fit(states, x, z, sample_statistics)
337 }
338 }
339
340});
341
342macro_for!($type in [CatTrjs, CatWtdTrjs] {
344
345 impl ParCIMEstimator<CatCIM> for MLE<'_, $type> {
346 fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIM {
347 let states = self.dataset.states();
349 let sample_statistics = SSE::new(self.dataset).par_fit(x, z);
351 MLE::<'_, CatTrj>::fit(states, x, z, sample_statistics)
353 }
354 }
355
356});