1use std::ops::{Div, DivAssign, Mul, MulAssign};
2
3use approx::{AbsDiffEq, RelativeEq};
4use itertools::Itertools;
5use ndarray::prelude::*;
6use ndarray_linalg::Determinant;
7
8use crate::{
9 datasets::{GaussEv, GaussEvT},
10 models::{CPD, GaussCPD, GaussCPDP, Labelled, Phi},
11 types::{LN_2_PI, Labels, Set},
12 utils::PseudoInverse,
13};
14
15#[derive(Clone, Debug)]
17pub struct GaussPhiK {
18 k: Array2<f64>,
20 h: Array1<f64>,
22 g: f64,
24}
25
26impl GaussPhiK {
27 pub fn new(k: Array2<f64>, h: Array1<f64>, g: f64) -> Self {
46 assert!(k.is_square(), "Precision matrix must be square.");
48 assert_eq!(
50 k.nrows(),
51 h.len(),
52 "Information vector length must match precision matrix size."
53 );
54 assert!(
56 k.iter().all(|x| x.is_finite()),
57 "Precision matrix must be finite."
58 );
59 assert_eq!(k, k.t(), "Precision matrix must be symmetric.");
61 assert!(
63 h.iter().all(|x| x.is_finite()),
64 "Information vector must be finite."
65 );
66 assert!(g.is_finite(), "Log-normalization constant must be finite.");
68
69 Self { k, h, g }
70 }
71
72 #[inline]
79 pub const fn precision_matrix(&self) -> &Array2<f64> {
80 &self.k
81 }
82
83 #[inline]
90 pub const fn information_vector(&self) -> &Array1<f64> {
91 &self.h
92 }
93
94 #[inline]
101 pub const fn log_normalization_constant(&self) -> f64 {
102 self.g
103 }
104}
105
106impl PartialEq for GaussPhiK {
107 fn eq(&self, other: &Self) -> bool {
108 self.k.eq(&other.k) && self.h.eq(&other.h) && self.g.eq(&other.g)
109 }
110}
111
112impl AbsDiffEq for GaussPhiK {
113 type Epsilon = f64;
114
115 fn default_epsilon() -> Self::Epsilon {
116 Self::Epsilon::default_epsilon()
117 }
118
119 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
120 self.k.abs_diff_eq(&other.k, epsilon)
121 && self.h.abs_diff_eq(&other.h, epsilon)
122 && self.g.abs_diff_eq(&other.g, epsilon)
123 }
124}
125
126impl RelativeEq for GaussPhiK {
127 fn default_max_relative() -> Self::Epsilon {
128 Self::Epsilon::default_max_relative()
129 }
130
131 fn relative_eq(
132 &self,
133 other: &Self,
134 epsilon: Self::Epsilon,
135 max_relative: Self::Epsilon,
136 ) -> bool {
137 self.k.relative_eq(&other.k, epsilon, max_relative)
138 && self.h.relative_eq(&other.h, epsilon, max_relative)
139 && self.g.relative_eq(&other.g, epsilon, max_relative)
140 }
141}
142
143#[derive(Clone, Debug)]
145pub struct GaussPhi {
146 labels: Labels,
148 parameters: GaussPhiK,
150}
151
152impl Labelled for GaussPhi {
153 #[inline]
154 fn labels(&self) -> &Labels {
155 &self.labels
156 }
157}
158
159impl PartialEq for GaussPhi {
160 fn eq(&self, other: &Self) -> bool {
161 self.labels.eq(&other.labels) && self.parameters.eq(&other.parameters)
162 }
163}
164
165impl AbsDiffEq for GaussPhi {
166 type Epsilon = f64;
167
168 fn default_epsilon() -> Self::Epsilon {
169 Self::Epsilon::default_epsilon()
170 }
171
172 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
173 self.labels.eq(&other.labels) && self.parameters.abs_diff_eq(&other.parameters, epsilon)
174 }
175}
176
177impl RelativeEq for GaussPhi {
178 fn default_max_relative() -> Self::Epsilon {
179 Self::Epsilon::default_max_relative()
180 }
181
182 fn relative_eq(
183 &self,
184 other: &Self,
185 epsilon: Self::Epsilon,
186 max_relative: Self::Epsilon,
187 ) -> bool {
188 self.labels.eq(&other.labels)
189 && self
190 .parameters
191 .relative_eq(&other.parameters, epsilon, max_relative)
192 }
193}
194
195impl MulAssign<&GaussPhi> for GaussPhi {
196 fn mul_assign(&mut self, rhs: &GaussPhi) {
197 let mut labels = self.labels.clone();
199 labels.extend(rhs.labels.clone());
200 labels.sort();
202
203 let n = labels.len();
205
206 let lhs_m: Vec<_> = labels.iter().map(|l| self.labels.get_index_of(l)).collect();
208 let lhs_k = Array::from_shape_fn((n, n), |(i, j)| match (lhs_m[i], lhs_m[j]) {
210 (Some(i), Some(j)) => self.parameters.k[[i, j]],
211 _ => 0.,
212 });
213 let lhs_h = Array::from_shape_fn(n, |i| match lhs_m[i] {
214 Some(i) => self.parameters.h[i],
215 _ => 0.,
216 });
217 let lhs_g = self.parameters.g;
218
219 let rhs_m: Vec<_> = labels.iter().map(|l| rhs.labels.get_index_of(l)).collect();
221 let rhs_k = Array::from_shape_fn((n, n), |(i, j)| match (rhs_m[i], rhs_m[j]) {
223 (Some(i), Some(j)) => rhs.parameters.k[[i, j]],
224 _ => 0.,
225 });
226 let rhs_h = Array::from_shape_fn(n, |i| match rhs_m[i] {
227 Some(i) => rhs.parameters.h[i],
228 _ => 0.,
229 });
230 let rhs_g = rhs.parameters.g;
231
232 let k = lhs_k + rhs_k;
234 let h = lhs_h + rhs_h;
235 let g = lhs_g + rhs_g;
236 let parameters = GaussPhiK::new(k, h, g);
238
239 self.labels = labels;
241 self.parameters = parameters;
243 }
244}
245
246impl Mul<&GaussPhi> for &GaussPhi {
247 type Output = GaussPhi;
248
249 #[inline]
250 fn mul(self, rhs: &GaussPhi) -> Self::Output {
251 let mut lhs = self.clone();
252 lhs *= rhs;
253 lhs
254 }
255}
256
257impl DivAssign<&GaussPhi> for GaussPhi {
258 fn div_assign(&mut self, rhs: &GaussPhi) {
259 let mut labels = self.labels.clone();
261 labels.extend(rhs.labels.clone());
262 labels.sort();
264
265 let n = labels.len();
267
268 let lhs_m: Vec<_> = labels.iter().map(|l| self.labels.get_index_of(l)).collect();
270 let lhs_k = Array::from_shape_fn((n, n), |(i, j)| match (lhs_m[i], lhs_m[j]) {
272 (Some(i), Some(j)) => self.parameters.k[[i, j]],
273 _ => 0.,
274 });
275 let lhs_h = Array::from_shape_fn(n, |i| match lhs_m[i] {
276 Some(i) => self.parameters.h[i],
277 _ => 0.,
278 });
279 let lhs_g = self.parameters.g;
280
281 let rhs_m: Vec<_> = labels.iter().map(|l| rhs.labels.get_index_of(l)).collect();
283 let rhs_k = Array::from_shape_fn((n, n), |(i, j)| match (rhs_m[i], rhs_m[j]) {
285 (Some(i), Some(j)) => rhs.parameters.k[[i, j]],
286 _ => 0.,
287 });
288 let rhs_h = Array::from_shape_fn(n, |i| match rhs_m[i] {
289 Some(i) => rhs.parameters.h[i],
290 _ => 0.,
291 });
292 let rhs_g = rhs.parameters.g;
293
294 let k_prime = lhs_k - rhs_k;
296 let h_prime = lhs_h - rhs_h;
297 let g_prime = lhs_g - rhs_g;
298 let parameters = GaussPhiK::new(k_prime, h_prime, g_prime);
300
301 self.labels = labels;
303 self.parameters = parameters;
305 }
306}
307
308impl Div<&GaussPhi> for &GaussPhi {
309 type Output = GaussPhi;
310
311 #[inline]
312 fn div(self, rhs: &GaussPhi) -> Self::Output {
313 let mut lhs = self.clone();
314 lhs /= rhs;
315 lhs
316 }
317}
318
319impl Phi for GaussPhi {
320 type CPD = GaussCPD;
321 type Parameters = GaussPhiK;
322 type Evidence = GaussEv;
323
324 #[inline]
325 fn parameters(&self) -> &Self::Parameters {
326 &self.parameters
327 }
328
329 #[inline]
330 fn parameters_size(&self) -> usize {
331 let k = {
332 let k = self.parameters.k.nrows();
334 k * (k + 1) / 2
335 };
336
337 k + self.parameters.h.len() + 1
338 }
339
340 fn condition(&self, e: &Self::Evidence) -> Self {
341 assert_eq!(
343 e.labels(),
344 self.labels(),
345 "Failed to condition on evidence: \n\
346 \t expected: evidence labels to match potential labels , \n\
347 \t found: potential labels = {:?} , \n\
348 \t evidence labels = {:?} .",
349 self.labels(),
350 e.labels(),
351 );
352
353 let e = e.evidences().iter().flatten();
355 let e = e.cloned().map(|e| match e {
357 GaussEvT::CertainPositive { event, value } => (event, value),
358 });
360
361 let y: Set<_> = e.clone().map(|(event, _)| event).collect();
363 let x: Set<_> = &Set::from_iter(0..self.labels.len()) - &y;
364
365 let labels: Labels = x.iter().map(|&x| self.labels[x].clone()).collect();
367
368 let _y = Array::from_iter(e.map(|(_, value)| value));
370
371 let k = self.parameters.precision_matrix();
373 let h = self.parameters.information_vector();
375 let g = self.parameters.log_normalization_constant();
377
378 let k_prime = Array::from_shape_fn((x.len(), x.len()), |(i, j)| k[[x[i], x[j]]]);
380 let h_prime = {
382 let k_xy = Array::from_shape_fn((x.len(), y.len()), |(i, j)| k[[x[i], y[j]]]);
384 let h_x = Array::from_shape_fn(x.len(), |i| h[x[i]]);
386 h_x - k_xy.dot(&_y)
388 };
389 let g_prime = {
391 let k_yy = Array::from_shape_fn((y.len(), y.len()), |(i, j)| k[[y[i], y[j]]]);
393 let h_y = Array::from_shape_fn(y.len(), |i| h[y[i]]);
395 g + h_y.dot(&_y) - 0.5 * _y.dot(&k_yy).dot(&_y)
397 };
398
399 let parameters = GaussPhiK::new(k_prime, h_prime, g_prime);
401
402 Self::new(labels, parameters)
404 }
405
406 fn marginalize(&self, x: &Set<usize>) -> Self {
407 if x.is_empty() {
409 return self.clone();
410 }
411
412 x.iter().for_each(|&x| {
414 assert!(
415 x < self.labels.len(),
416 "Variable index out of bounds: \n\
417 \t expected: x < {} , \n\
418 \t found: x == {} .",
419 self.labels.len(),
420 x,
421 );
422 });
423
424 let v: Set<_> = Set::from_iter(0..self.labels.len());
426 let z: Set<_> = &v - x;
427
428 let labels_z: Labels = z.iter().map(|&i| self.labels[i].clone()).collect();
430
431 let k = self.parameters.precision_matrix();
433 let h = self.parameters.information_vector();
435 let g = self.parameters.log_normalization_constant();
437
438 let s_xx = {
440 let k_xx = Array::from_shape_fn((x.len(), x.len()), |(i, j)| k[[x[i], x[j]]]);
442 k_xx.pinv()
444 };
445 let k_zx = Array::from_shape_fn((z.len(), x.len()), |(i, j)| k[[z[i], x[j]]]);
447 let h_x = Array::from_shape_fn(x.len(), |i| h[x[i]]);
449
450 let k_zx_dot_s_xx = k_zx.dot(&s_xx);
452
453 let k_prime = {
455 let k_zz = Array::from_shape_fn((z.len(), z.len()), |(i, j)| k[[z[i], z[j]]]);
457 let k_xz = Array::from_shape_fn((x.len(), z.len()), |(i, j)| k[[x[i], z[j]]]);
458 k_zz - k_zx_dot_s_xx.dot(&k_xz)
460 };
461 let h_prime = {
463 let h_z = Array::from_shape_fn(z.len(), |i| h[z[i]]);
465 h_z - k_zx_dot_s_xx.dot(&h_x)
467 };
468 let g_prime = {
470 let n_ln_2_pi = s_xx.nrows() as f64 * LN_2_PI;
472 let (_, ln_det) = s_xx.sln_det().expect("Failed to compute the determinant.");
473 g + 0.5 * (n_ln_2_pi + ln_det + h_x.dot(&s_xx).dot(&h_x))
474 };
475
476 let parameters = GaussPhiK::new(k_prime, h_prime, g_prime);
478
479 Self::new(labels_z, parameters)
481 }
482
483 #[inline]
484 fn normalize(&self) -> Self {
485 self.clone()
487 }
488
489 fn from_cpd(cpd: Self::CPD) -> Self {
490 let mut labels = cpd.labels().clone();
492 labels.extend(cpd.conditioning_labels().clone());
493
494 let parameters = cpd.parameters();
496 let (a, b, s) = (
498 parameters.coefficients(),
499 parameters.intercept(),
500 parameters.covariance(),
501 );
502
503 let k_xx = s.pinv(); let k_xz = -&k_xx.dot(a); let k_zx = -a.t().dot(&k_xx); let k_zz = a.t().dot(&k_xx).dot(a); let k_prime = {
514 let (n, m) = (a.nrows(), a.ncols());
515 let mut k = Array::zeros((n + m, n + m));
516 k.slice_mut(s![0..n, 0..n]).assign(&k_xx);
517 k.slice_mut(s![0..n, n..n + m]).assign(&k_xz);
518 k.slice_mut(s![n..n + m, 0..n]).assign(&k_zx);
519 k.slice_mut(s![n..n + m, n..n + m]).assign(&k_zz);
520 k
521 };
522
523 let h_x = k_xx.dot(b); let h_z = k_zx.dot(b); let h_prime = {
532 let mut h = Array::zeros(h_x.len() + h_z.len());
533 h.slice_mut(s![0..h_x.len()]).assign(&h_x);
534 h.slice_mut(s![h_x.len()..]).assign(&h_z);
535 h
536 };
537
538 let g_prime = {
540 let n_ln_2_pi = s.nrows() as f64 * LN_2_PI;
541 let (_, ln_det) = s.sln_det().expect("Failed to compute the determinant.");
542 -0.5 * (n_ln_2_pi + ln_det + b.dot(&h_x))
543 };
544
545 let parameters = GaussPhiK::new(k_prime, h_prime, g_prime);
547
548 Self::new(labels, parameters)
550 }
551
552 fn into_cpd(self, x: &Set<usize>, z: &Set<usize>) -> Self::CPD {
553 assert!(
555 x.is_disjoint(z),
556 "Variables and conditioning variables must be disjoint."
557 );
558 assert!(
560 (x | z).iter().sorted().cloned().eq(0..self.labels.len()),
561 "Variables and conditioning variables must cover all potential variables."
562 );
563
564 let labels_x: Labels = x.iter().map(|&i| self.labels[i].clone()).collect();
566 let labels_z: Labels = z.iter().map(|&i| self.labels[i].clone()).collect();
567
568 let k = self.parameters.precision_matrix();
570 let h = self.parameters.information_vector();
572
573 let s = {
575 let k_xx = Array::from_shape_fn((x.len(), x.len()), |(i, j)| k[[x[i], x[j]]]);
577 k_xx.pinv()
579 };
580 let a = {
582 let k_xz = Array::from_shape_fn((x.len(), z.len()), |(i, j)| k[[x[i], z[j]]]);
584 -s.dot(&k_xz)
586 };
587 let b = {
589 let h_x = Array::from_shape_fn(x.len(), |i| h[x[i]]);
591 s.dot(&h_x)
593 };
594
595 let parameters = GaussCPDP::new(a, b, s);
597
598 GaussCPD::new(labels_x, labels_z, parameters)
600 }
601}
602
603impl GaussPhi {
604 pub fn new(mut labels: Labels, mut parameters: GaussPhiK) -> Self {
616 assert_eq!(
618 parameters.precision_matrix().nrows(),
619 labels.len(),
620 "Precision matrix rows must match labels length."
621 );
622 assert_eq!(
623 parameters.information_vector().len(),
624 labels.len(),
625 "Information vector length must match labels length."
626 );
627
628 if !labels.is_sorted() {
630 let mut indices: Vec<_> = (0..labels.len()).collect();
632 indices.sort_by_key(|&i| labels.get_index(i).unwrap());
633 labels.sort();
635
636 let mut k = parameters.k.clone();
638 for (i, &j) in indices.iter().enumerate() {
640 k.row_mut(i).assign(¶meters.k.row(j));
641 }
642 parameters.k = k.clone();
643 for (i, &j) in indices.iter().enumerate() {
645 k.column_mut(i).assign(¶meters.k.column(j));
646 }
647 parameters.k = k;
648
649 let mut h = parameters.h.clone();
651 for (i, &j) in indices.iter().enumerate() {
653 h[i] = parameters.h[j];
654 }
655 parameters.h = h;
656 }
657
658 Self { labels, parameters }
659 }
660}