//! ppca/prior.rs

use nalgebra::{DMatrix, DVector};
use serde_derive::{Deserialize, Serialize};

4/// A prior for the PPCA model. Use this class to mitigate overfit on training (especially on
5/// frequently masked dimensions) and to input _a priori_ knowledge on what the PPCA should look
6/// like.
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct Prior {
9    mean: Option<DVector<f64>>,
10    mean_covariance: Option<DMatrix<f64>>,
11    mean_precision: Option<DMatrix<f64>>,
12    isotropic_noise_alpha: Option<f64>,
13    isotropic_noise_beta: Option<f64>,
14    transformation_precision: f64,
15}
16
17impl Default for Prior {
18    fn default() -> Prior {
19        Prior {
20            mean: None,
21            mean_covariance: None,
22            mean_precision: None,
23            isotropic_noise_alpha: None,
24            isotropic_noise_beta: None,
25            transformation_precision: 0.0,
26        }
27    }
28}
29
30impl Prior {
31    /// Add a prior to the mean of the PPCA. The prior is a normal multivariate distribution.
32    pub fn with_mean_prior(mut self, mean: DVector<f64>, mean_covariance: DMatrix<f64>) -> Self {
33        assert_eq!(mean.len(), mean_covariance.nrows());
34        assert_eq!(mean.len(), mean_covariance.ncols());
35        self.mean = Some(mean);
36        self.mean_precision = Some(
37            mean_covariance
38                .clone()
39                .try_inverse()
40                .expect("mean covariance should be invertible"),
41        );
42        self.mean_covariance = Some(mean_covariance);
43
44        self
45    }
46
47    /// Add an isotropic noise prior. The prior is an Inverse Gamma distribution with shape `alpha`
48    /// and rate `beta`.
49    pub fn with_isotropic_noise_prior(mut self, alpha: f64, beta: f64) -> Self {
50        assert!(alpha >= 0.0);
51        assert!(beta >= 0.0);
52        self.isotropic_noise_alpha = Some(alpha);
53        self.isotropic_noise_beta = Some(beta);
54
55        self
56    }
57
58    /// Impose an independent Normal prior to each dimension of the transformation matrix. The
59    /// precision is the inverse of the variance of the Normal distribution (`1 / sigma ^ 2`).
60    pub fn with_transformation_precision(mut self, precision: f64) -> Self {
61        assert!(precision >= 0.0);
62        self.transformation_precision = precision;
63
64        self
65    }
66
67    pub fn mean(&self) -> Option<&DVector<f64>> {
68        self.mean.as_ref()
69    }
70
71    pub fn mean_covariance(&self) -> Option<&DMatrix<f64>> {
72        self.mean_covariance.as_ref()
73    }
74
75    pub fn has_isotropic_noise_prior(&self) -> bool {
76        self.isotropic_noise_alpha.is_some()
77    }
78
79    pub fn isotropic_noise_alpha(&self) -> f64 {
80        self.isotropic_noise_alpha
81            .expect("isotropic noise prior not set")
82    }
83
84    pub fn isotropic_noise_beta(&self) -> f64 {
85        self.isotropic_noise_beta
86            .expect("isotropic noise prior not set")
87    }
88
89    pub fn transformation_precision(&self) -> f64 {
90        self.transformation_precision
91    }
92
93    pub fn has_mean_prior(&self) -> bool {
94        self.mean.is_some()
95    }
96
97    pub(crate) fn smooth_mean(&self, mean: DVector<f64>, precision: DMatrix<f64>) -> DVector<f64> {
98        let (prior_mean, prior_precision) = self
99            .mean
100            .as_ref()
101            .zip(self.mean_precision.as_ref())
102            .expect("mean prior not set");
103        let total_precision = prior_precision + &precision;
104        let numerator = prior_precision * prior_mean + &precision * mean;
105
106        total_precision
107            .qr()
108            .solve(&numerator)
109            .expect("total precision matrix is always invertible")
110    }
111}