causal_hub/models/bayesian_network/gaussian/
potential.rs

1use std::ops::{Div, DivAssign, Mul, MulAssign};
2
3use approx::{AbsDiffEq, RelativeEq};
4use itertools::Itertools;
5use ndarray::prelude::*;
6use ndarray_linalg::Determinant;
7
8use crate::{
9    datasets::{GaussEv, GaussEvT},
10    models::{CPD, GaussCPD, GaussCPDP, Labelled, Phi},
11    types::{LN_2_PI, Labels, Set},
12    utils::PseudoInverse,
13};
14
15/// Parameters of a Gaussian potential.
16#[derive(Clone, Debug)]
17pub struct GaussPhiK {
18    /// Precision matrix |X| x |X|.
19    k: Array2<f64>,
20    /// Information vector |X|.
21    h: Array1<f64>,
22    /// Log-normalization constant.
23    g: f64,
24}
25
26impl GaussPhiK {
27    /// Creates a new Gaussian potential with the given parameters.
28    ///
29    /// # Arguments
30    ///
31    /// * `k` - Precision matrix |X| x |X|.
32    /// * `h` - Information vector |X|.
33    /// * `g` - Log-normalization constant.
34    ///
35    /// # Panics
36    ///
37    /// * Panics if `k` is not square and symmetric.
38    /// * Panics if the length of `h` does not match the size of `k`.
39    /// * Panics if `k`, `h`, or `g` contain non-finite values.
40    ///
41    /// # Results
42    ///
43    /// A new Gaussian potential instance.
44    ///
45    pub fn new(k: Array2<f64>, h: Array1<f64>, g: f64) -> Self {
46        // Assert K is square.
47        assert!(k.is_square(), "Precision matrix must be square.");
48        // Assert the length of h matches the size of K.
49        assert_eq!(
50            k.nrows(),
51            h.len(),
52            "Information vector length must match precision matrix size."
53        );
54        // Assert K is finite.
55        assert!(
56            k.iter().all(|x| x.is_finite()),
57            "Precision matrix must be finite."
58        );
59        // Assert K is symmetric.
60        assert_eq!(k, k.t(), "Precision matrix must be symmetric.");
61        // Assert h is finite.
62        assert!(
63            h.iter().all(|x| x.is_finite()),
64            "Information vector must be finite."
65        );
66        // Assert g is finite.
67        assert!(g.is_finite(), "Log-normalization constant must be finite.");
68
69        Self { k, h, g }
70    }
71
72    /// Returns the precision matrix.
73    ///
74    /// # Returns
75    ///
76    /// A reference to the precision matrix.
77    ///    
78    #[inline]
79    pub const fn precision_matrix(&self) -> &Array2<f64> {
80        &self.k
81    }
82
83    /// Returns the information vector.
84    ///
85    /// # Returns
86    ///
87    /// A reference to the information vector.
88    ///
89    #[inline]
90    pub const fn information_vector(&self) -> &Array1<f64> {
91        &self.h
92    }
93
94    /// Returns the log-normalization constant.
95    ///
96    /// # Returns
97    ///
98    /// The log-normalization constant.
99    ///
100    #[inline]
101    pub const fn log_normalization_constant(&self) -> f64 {
102        self.g
103    }
104}
105
106impl PartialEq for GaussPhiK {
107    fn eq(&self, other: &Self) -> bool {
108        self.k.eq(&other.k) && self.h.eq(&other.h) && self.g.eq(&other.g)
109    }
110}
111
112impl AbsDiffEq for GaussPhiK {
113    type Epsilon = f64;
114
115    fn default_epsilon() -> Self::Epsilon {
116        Self::Epsilon::default_epsilon()
117    }
118
119    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
120        self.k.abs_diff_eq(&other.k, epsilon)
121            && self.h.abs_diff_eq(&other.h, epsilon)
122            && self.g.abs_diff_eq(&other.g, epsilon)
123    }
124}
125
126impl RelativeEq for GaussPhiK {
127    fn default_max_relative() -> Self::Epsilon {
128        Self::Epsilon::default_max_relative()
129    }
130
131    fn relative_eq(
132        &self,
133        other: &Self,
134        epsilon: Self::Epsilon,
135        max_relative: Self::Epsilon,
136    ) -> bool {
137        self.k.relative_eq(&other.k, epsilon, max_relative)
138            && self.h.relative_eq(&other.h, epsilon, max_relative)
139            && self.g.relative_eq(&other.g, epsilon, max_relative)
140    }
141}
142
143/// A Gaussian potential.
144#[derive(Clone, Debug)]
145pub struct GaussPhi {
146    // Labels of the variables.
147    labels: Labels,
148    // Parameters.
149    parameters: GaussPhiK,
150}
151
152impl Labelled for GaussPhi {
153    #[inline]
154    fn labels(&self) -> &Labels {
155        &self.labels
156    }
157}
158
159impl PartialEq for GaussPhi {
160    fn eq(&self, other: &Self) -> bool {
161        self.labels.eq(&other.labels) && self.parameters.eq(&other.parameters)
162    }
163}
164
165impl AbsDiffEq for GaussPhi {
166    type Epsilon = f64;
167
168    fn default_epsilon() -> Self::Epsilon {
169        Self::Epsilon::default_epsilon()
170    }
171
172    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
173        self.labels.eq(&other.labels) && self.parameters.abs_diff_eq(&other.parameters, epsilon)
174    }
175}
176
177impl RelativeEq for GaussPhi {
178    fn default_max_relative() -> Self::Epsilon {
179        Self::Epsilon::default_max_relative()
180    }
181
182    fn relative_eq(
183        &self,
184        other: &Self,
185        epsilon: Self::Epsilon,
186        max_relative: Self::Epsilon,
187    ) -> bool {
188        self.labels.eq(&other.labels)
189            && self
190                .parameters
191                .relative_eq(&other.parameters, epsilon, max_relative)
192    }
193}
194
195impl MulAssign<&GaussPhi> for GaussPhi {
196    fn mul_assign(&mut self, rhs: &GaussPhi) {
197        // Get the union of the labels.
198        let mut labels = self.labels.clone();
199        labels.extend(rhs.labels.clone());
200        // Sort the labels.
201        labels.sort();
202
203        // Get the number of variables.
204        let n = labels.len();
205
206        // Order LHS indices w.r.t. new labels.
207        let lhs_m: Vec<_> = labels.iter().map(|l| self.labels.get_index_of(l)).collect();
208        // Allocate extended LHS parameters.
209        let lhs_k = Array::from_shape_fn((n, n), |(i, j)| match (lhs_m[i], lhs_m[j]) {
210            (Some(i), Some(j)) => self.parameters.k[[i, j]],
211            _ => 0.,
212        });
213        let lhs_h = Array::from_shape_fn(n, |i| match lhs_m[i] {
214            Some(i) => self.parameters.h[i],
215            _ => 0.,
216        });
217        let lhs_g = self.parameters.g;
218
219        // Order RHS indices w.r.t. new labels.
220        let rhs_m: Vec<_> = labels.iter().map(|l| rhs.labels.get_index_of(l)).collect();
221        // Allocate extended RHS parameters.
222        let rhs_k = Array::from_shape_fn((n, n), |(i, j)| match (rhs_m[i], rhs_m[j]) {
223            (Some(i), Some(j)) => rhs.parameters.k[[i, j]],
224            _ => 0.,
225        });
226        let rhs_h = Array::from_shape_fn(n, |i| match rhs_m[i] {
227            Some(i) => rhs.parameters.h[i],
228            _ => 0.,
229        });
230        let rhs_g = rhs.parameters.g;
231
232        // Sum parameters.
233        let k = lhs_k + rhs_k;
234        let h = lhs_h + rhs_h;
235        let g = lhs_g + rhs_g;
236        // Assemble parameters.
237        let parameters = GaussPhiK::new(k, h, g);
238
239        // Update the labels.
240        self.labels = labels;
241        // Update the parameters.
242        self.parameters = parameters;
243    }
244}
245
246impl Mul<&GaussPhi> for &GaussPhi {
247    type Output = GaussPhi;
248
249    #[inline]
250    fn mul(self, rhs: &GaussPhi) -> Self::Output {
251        let mut lhs = self.clone();
252        lhs *= rhs;
253        lhs
254    }
255}
256
257impl DivAssign<&GaussPhi> for GaussPhi {
258    fn div_assign(&mut self, rhs: &GaussPhi) {
259        // Get the union of the labels.
260        let mut labels = self.labels.clone();
261        labels.extend(rhs.labels.clone());
262        // Sort the labels.
263        labels.sort();
264
265        // Get the number of variables.
266        let n = labels.len();
267
268        // Order LHS indices w.r.t. new labels.
269        let lhs_m: Vec<_> = labels.iter().map(|l| self.labels.get_index_of(l)).collect();
270        // Allocate extended LHS parameters.
271        let lhs_k = Array::from_shape_fn((n, n), |(i, j)| match (lhs_m[i], lhs_m[j]) {
272            (Some(i), Some(j)) => self.parameters.k[[i, j]],
273            _ => 0.,
274        });
275        let lhs_h = Array::from_shape_fn(n, |i| match lhs_m[i] {
276            Some(i) => self.parameters.h[i],
277            _ => 0.,
278        });
279        let lhs_g = self.parameters.g;
280
281        // Order RHS indices w.r.t. new labels.
282        let rhs_m: Vec<_> = labels.iter().map(|l| rhs.labels.get_index_of(l)).collect();
283        // Allocate extended RHS parameters.
284        let rhs_k = Array::from_shape_fn((n, n), |(i, j)| match (rhs_m[i], rhs_m[j]) {
285            (Some(i), Some(j)) => rhs.parameters.k[[i, j]],
286            _ => 0.,
287        });
288        let rhs_h = Array::from_shape_fn(n, |i| match rhs_m[i] {
289            Some(i) => rhs.parameters.h[i],
290            _ => 0.,
291        });
292        let rhs_g = rhs.parameters.g;
293
294        // Sum parameters.
295        let k_prime = lhs_k - rhs_k;
296        let h_prime = lhs_h - rhs_h;
297        let g_prime = lhs_g - rhs_g;
298        // Assemble parameters.
299        let parameters = GaussPhiK::new(k_prime, h_prime, g_prime);
300
301        // Update the labels.
302        self.labels = labels;
303        // Update the parameters.
304        self.parameters = parameters;
305    }
306}
307
308impl Div<&GaussPhi> for &GaussPhi {
309    type Output = GaussPhi;
310
311    #[inline]
312    fn div(self, rhs: &GaussPhi) -> Self::Output {
313        let mut lhs = self.clone();
314        lhs /= rhs;
315        lhs
316    }
317}
318
319impl Phi for GaussPhi {
320    type CPD = GaussCPD;
321    type Parameters = GaussPhiK;
322    type Evidence = GaussEv;
323
324    #[inline]
325    fn parameters(&self) -> &Self::Parameters {
326        &self.parameters
327    }
328
329    #[inline]
330    fn parameters_size(&self) -> usize {
331        let k = {
332            // Precision matrix is symmetric.
333            let k = self.parameters.k.nrows();
334            k * (k + 1) / 2
335        };
336
337        k + self.parameters.h.len() + 1
338    }
339
340    fn condition(&self, e: &Self::Evidence) -> Self {
341        // Assert that the evidence labels match the potential labels.
342        assert_eq!(
343            e.labels(),
344            self.labels(),
345            "Failed to condition on evidence: \n\
346            \t expected:    evidence labels to match potential labels , \n\
347            \t found:       potential labels = {:?} , \n\
348            \t              evidence  labels = {:?} .",
349            self.labels(),
350            e.labels(),
351        );
352
353        // Get the evidence and remove nones.
354        let e = e.evidences().iter().flatten();
355        // Assert that the evidence is certain and positive.
356        let e = e.cloned().map(|e| match e {
357            GaussEvT::CertainPositive { event, value } => (event, value),
358            /* _ => panic! NOTE: No other variant so far. */
359        });
360
361        // Get X and Y from the evidence.
362        let y: Set<_> = e.clone().map(|(event, _)| event).collect();
363        let x: Set<_> = &Set::from_iter(0..self.labels.len()) - &y;
364
365        // Select the labels of the conditioned potential.
366        let labels: Labels = x.iter().map(|&x| self.labels[x].clone()).collect();
367
368        // Get the values from the evidence.
369        let _y = Array::from_iter(e.map(|(_, value)| value));
370
371        // Get the precision matrix.
372        let k = self.parameters.precision_matrix();
373        // Get the information vector.
374        let h = self.parameters.information_vector();
375        // Get the log-normalization constant.
376        let g = self.parameters.log_normalization_constant();
377
378        // Compute the precision matrix as K_xx from K and X.
379        let k_prime = Array::from_shape_fn((x.len(), x.len()), |(i, j)| k[[x[i], x[j]]]);
380        // Compute the information vector.
381        let h_prime = {
382            // Get K_xy from K, X and Y.
383            let k_xy = Array::from_shape_fn((x.len(), y.len()), |(i, j)| k[[x[i], y[j]]]);
384            // Get h_x from h and X.
385            let h_x = Array::from_shape_fn(x.len(), |i| h[x[i]]);
386            // Compute h as: h' = h_x - K_xy * y.
387            h_x - k_xy.dot(&_y)
388        };
389        // Compute the log-normalization constant.
390        let g_prime = {
391            // Get K_yy from K and Y.
392            let k_yy = Array::from_shape_fn((y.len(), y.len()), |(i, j)| k[[y[i], y[j]]]);
393            // Get h_y from h and Y.
394            let h_y = Array::from_shape_fn(y.len(), |i| h[y[i]]);
395            // Compute g as: g' = g + h_y^T * y - 0.5 * y^T * K_yy * y.
396            g + h_y.dot(&_y) - 0.5 * _y.dot(&k_yy).dot(&_y)
397        };
398
399        // Assemble the parameters.
400        let parameters = GaussPhiK::new(k_prime, h_prime, g_prime);
401
402        // Return the conditioned potential.
403        Self::new(labels, parameters)
404    }
405
406    fn marginalize(&self, x: &Set<usize>) -> Self {
407        // Base case: if no variables to marginalize, return self.
408        if x.is_empty() {
409            return self.clone();
410        }
411
412        // Assert X is a subset of the variables.
413        x.iter().for_each(|&x| {
414            assert!(
415                x < self.labels.len(),
416                "Variable index out of bounds: \n\
417                \t expected:    x <  {} , \n\
418                \t found:       x == {} .",
419                self.labels.len(),
420                x,
421            );
422        });
423
424        // Get Z as V \ X.
425        let v: Set<_> = Set::from_iter(0..self.labels.len());
426        let z: Set<_> = &v - x;
427
428        // Get the labels of the marginalized potential.
429        let labels_z: Labels = z.iter().map(|&i| self.labels[i].clone()).collect();
430
431        // Get the precision matrix.
432        let k = self.parameters.precision_matrix();
433        // Get the information vector.
434        let h = self.parameters.information_vector();
435        // Get the log-normalization constant.
436        let g = self.parameters.log_normalization_constant();
437
438        // Compute the covariance matrix as: S_xx = (K_xx)^(-1).
439        let s_xx = {
440            // Get K_xx from K and X.
441            let k_xx = Array::from_shape_fn((x.len(), x.len()), |(i, j)| k[[x[i], x[j]]]);
442            // Compute the covariance as: S = (K_xx)^(-1)
443            k_xx.pinv()
444        };
445        // Get K_zx from K, Z and X.
446        let k_zx = Array::from_shape_fn((z.len(), x.len()), |(i, j)| k[[z[i], x[j]]]);
447        // Get h_x from h and X.
448        let h_x = Array::from_shape_fn(x.len(), |i| h[x[i]]);
449
450        // Compute K_zx * S_xx once.
451        let k_zx_dot_s_xx = k_zx.dot(&s_xx);
452
453        // Compute the marginalized precision matrix.
454        let k_prime = {
455            // Get K_zz and K_xz from K, X and Z.
456            let k_zz = Array::from_shape_fn((z.len(), z.len()), |(i, j)| k[[z[i], z[j]]]);
457            let k_xz = Array::from_shape_fn((x.len(), z.len()), |(i, j)| k[[x[i], z[j]]]);
458            // Compute the precision matrix as: K' = K_zz - K_zx * (K_xx)^(-1) * K_xz
459            k_zz - k_zx_dot_s_xx.dot(&k_xz)
460        };
461        // Compute the marginalized information vector.
462        let h_prime = {
463            // Get h_z from h, X and Z.
464            let h_z = Array::from_shape_fn(z.len(), |i| h[z[i]]);
465            // Compute the information vector as: h' = h_z - K_zx * (K_xx)^(-1) * h_x
466            h_z - k_zx_dot_s_xx.dot(&h_x)
467        };
468        // Compute the marginalized log-normalization constant.
469        let g_prime = {
470            // Compute the log-normalization constant as: g' = g + 0.5 * (ln|2 pi (K_xx)^-1| + h_x^T * (K_xx)^-1 * h_x)
471            let n_ln_2_pi = s_xx.nrows() as f64 * LN_2_PI;
472            let (_, ln_det) = s_xx.sln_det().expect("Failed to compute the determinant.");
473            g + 0.5 * (n_ln_2_pi + ln_det + h_x.dot(&s_xx).dot(&h_x))
474        };
475
476        // Assemble the parameters.
477        let parameters = GaussPhiK::new(k_prime, h_prime, g_prime);
478
479        // Return the marginalized potential.
480        Self::new(labels_z, parameters)
481    }
482
483    #[inline]
484    fn normalize(&self) -> Self {
485        // The potential is already normalized.
486        self.clone()
487    }
488
489    fn from_cpd(cpd: Self::CPD) -> Self {
490        // Merge labels and conditioning labels in this order.
491        let mut labels = cpd.labels().clone();
492        labels.extend(cpd.conditioning_labels().clone());
493
494        // Get the parameters from the CPD.
495        let parameters = cpd.parameters();
496        // Get the coefficients and covariance.
497        let (a, b, s) = (
498            parameters.coefficients(),
499            parameters.intercept(),
500            parameters.covariance(),
501        );
502
503        // Compute the precision matrix as:
504        //
505        // | K_xx  K_xz |
506        // | K_zx  K_zz |
507        //
508        let k_xx = s.pinv(); //                 Precision of X.
509        let k_xz = -&k_xx.dot(a); //            Cross-precision of X and Z.    
510        let k_zx = -a.t().dot(&k_xx); //        Cross-precision of Z and X.
511        let k_zz = a.t().dot(&k_xx).dot(a); //  Induced precision of Z.
512        // Assemble the precision matrix.
513        let k_prime = {
514            let (n, m) = (a.nrows(), a.ncols());
515            let mut k = Array::zeros((n + m, n + m));
516            k.slice_mut(s![0..n, 0..n]).assign(&k_xx);
517            k.slice_mut(s![0..n, n..n + m]).assign(&k_xz);
518            k.slice_mut(s![n..n + m, 0..n]).assign(&k_zx);
519            k.slice_mut(s![n..n + m, n..n + m]).assign(&k_zz);
520            k
521        };
522
523        // Compute the information vector as:
524        //
525        // | h_x | = | K_xx * b |
526        // | h_z | = | K_zx * b |
527        //
528        let h_x = k_xx.dot(b); // Information of X.
529        let h_z = k_zx.dot(b); // Information of Z.
530        // Assemble the information vector.
531        let h_prime = {
532            let mut h = Array::zeros(h_x.len() + h_z.len());
533            h.slice_mut(s![0..h_x.len()]).assign(&h_x);
534            h.slice_mut(s![h_x.len()..]).assign(&h_z);
535            h
536        };
537
538        // Compute the log-normalization constant.
539        let g_prime = {
540            let n_ln_2_pi = s.nrows() as f64 * LN_2_PI;
541            let (_, ln_det) = s.sln_det().expect("Failed to compute the determinant.");
542            -0.5 * (n_ln_2_pi + ln_det + b.dot(&h_x))
543        };
544
545        // Construct the parameters.
546        let parameters = GaussPhiK::new(k_prime, h_prime, g_prime);
547
548        // Return the potential.
549        Self::new(labels, parameters)
550    }
551
552    fn into_cpd(self, x: &Set<usize>, z: &Set<usize>) -> Self::CPD {
553        // Assert that X and Z are disjoint.
554        assert!(
555            x.is_disjoint(z),
556            "Variables and conditioning variables must be disjoint."
557        );
558        // Assert that X and Z cover all variables.
559        assert!(
560            (x | z).iter().sorted().cloned().eq(0..self.labels.len()),
561            "Variables and conditioning variables must cover all potential variables."
562        );
563
564        // Split labels into labels and conditioning labels.
565        let labels_x: Labels = x.iter().map(|&i| self.labels[i].clone()).collect();
566        let labels_z: Labels = z.iter().map(|&i| self.labels[i].clone()).collect();
567
568        // Get the precision matrix.
569        let k = self.parameters.precision_matrix();
570        // Get the information vector.
571        let h = self.parameters.information_vector();
572
573        // Compute the covariance matrix.
574        let s = {
575            // Get K_xx from K and X.
576            let k_xx = Array::from_shape_fn((x.len(), x.len()), |(i, j)| k[[x[i], x[j]]]);
577            // Compute the covariance as: S = (K_xx)^(-1)
578            k_xx.pinv()
579        };
580        // Compute the coefficient matrix.
581        let a = {
582            // Get K_xz from K, X, and Z.
583            let k_xz = Array::from_shape_fn((x.len(), z.len()), |(i, j)| k[[x[i], z[j]]]);
584            // Compute the coefficients as: A = - (K_xx)^(-1) * K_xz
585            -s.dot(&k_xz)
586        };
587        // Compute the intercept vector.
588        let b = {
589            // Get h_x from h and X.
590            let h_x = Array::from_shape_fn(x.len(), |i| h[x[i]]);
591            // Compute the intercept as: b = (K_xx)^(-1) * h_x
592            s.dot(&h_x)
593        };
594
595        // Assemble the parameters.
596        let parameters = GaussCPDP::new(a, b, s);
597
598        // Create the new CPD.
599        GaussCPD::new(labels_x, labels_z, parameters)
600    }
601}
602
603impl GaussPhi {
604    /// Creates a new Gaussian potential with the given labels and parameters.
605    ///
606    /// # Arguments
607    ///
608    /// * `labels` - Labels of the variables.
609    /// * `parameters` - Parameters of the potential.
610    ///
611    /// # Results
612    ///
613    /// A new Gaussian potential instance.
614    ///
615    pub fn new(mut labels: Labels, mut parameters: GaussPhiK) -> Self {
616        // Assert parameters shape matches labels length.
617        assert_eq!(
618            parameters.precision_matrix().nrows(),
619            labels.len(),
620            "Precision matrix rows must match labels length."
621        );
622        assert_eq!(
623            parameters.information_vector().len(),
624            labels.len(),
625            "Information vector length must match labels length."
626        );
627
628        // Sort labels if not sorted and permute parameters accordingly.
629        if !labels.is_sorted() {
630            // Get the new indices order w.r.t. sorted labels.
631            let mut indices: Vec<_> = (0..labels.len()).collect();
632            indices.sort_by_key(|&i| labels.get_index(i).unwrap());
633            // Sort the labels.
634            labels.sort();
635
636            // Clone the precision matrix.
637            let mut k = parameters.k.clone();
638            // Permute the precision matrix rows.
639            for (i, &j) in indices.iter().enumerate() {
640                k.row_mut(i).assign(&parameters.k.row(j));
641            }
642            parameters.k = k.clone();
643            // Permute the precision matrix columns.
644            for (i, &j) in indices.iter().enumerate() {
645                k.column_mut(i).assign(&parameters.k.column(j));
646            }
647            parameters.k = k;
648
649            // Clone the information vector.
650            let mut h = parameters.h.clone();
651            // Permute the information vector.
652            for (i, &j) in indices.iter().enumerate() {
653                h[i] = parameters.h[j];
654            }
655            parameters.h = h;
656        }
657
658        Self { labels, parameters }
659    }
660}