1use nalgebra::{DMatrix, DVector};
2use serde_derive::{Deserialize, Serialize};
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct Prior {
9 mean: Option<DVector<f64>>,
10 mean_covariance: Option<DMatrix<f64>>,
11 mean_precision: Option<DMatrix<f64>>,
12 isotropic_noise_alpha: Option<f64>,
13 isotropic_noise_beta: Option<f64>,
14 transformation_precision: f64,
15}
16
17impl Default for Prior {
18 fn default() -> Prior {
19 Prior {
20 mean: None,
21 mean_covariance: None,
22 mean_precision: None,
23 isotropic_noise_alpha: None,
24 isotropic_noise_beta: None,
25 transformation_precision: 0.0,
26 }
27 }
28}
29
30impl Prior {
31 pub fn with_mean_prior(mut self, mean: DVector<f64>, mean_covariance: DMatrix<f64>) -> Self {
33 assert_eq!(mean.len(), mean_covariance.nrows());
34 assert_eq!(mean.len(), mean_covariance.ncols());
35 self.mean = Some(mean);
36 self.mean_precision = Some(
37 mean_covariance
38 .clone()
39 .try_inverse()
40 .expect("mean covariance should be invertible"),
41 );
42 self.mean_covariance = Some(mean_covariance);
43
44 self
45 }
46
47 pub fn with_isotropic_noise_prior(mut self, alpha: f64, beta: f64) -> Self {
50 assert!(alpha >= 0.0);
51 assert!(beta >= 0.0);
52 self.isotropic_noise_alpha = Some(alpha);
53 self.isotropic_noise_beta = Some(beta);
54
55 self
56 }
57
58 pub fn with_transformation_precision(mut self, precision: f64) -> Self {
61 assert!(precision >= 0.0);
62 self.transformation_precision = precision;
63
64 self
65 }
66
67 pub fn mean(&self) -> Option<&DVector<f64>> {
68 self.mean.as_ref()
69 }
70
71 pub fn mean_covariance(&self) -> Option<&DMatrix<f64>> {
72 self.mean_covariance.as_ref()
73 }
74
75 pub fn has_isotropic_noise_prior(&self) -> bool {
76 self.isotropic_noise_alpha.is_some()
77 }
78
79 pub fn isotropic_noise_alpha(&self) -> f64 {
80 self.isotropic_noise_alpha
81 .expect("isotropic noise prior not set")
82 }
83
84 pub fn isotropic_noise_beta(&self) -> f64 {
85 self.isotropic_noise_beta
86 .expect("isotropic noise prior not set")
87 }
88
89 pub fn transformation_precision(&self) -> f64 {
90 self.transformation_precision
91 }
92
93 pub fn has_mean_prior(&self) -> bool {
94 self.mean.is_some()
95 }
96
97 pub(crate) fn smooth_mean(&self, mean: DVector<f64>, precision: DMatrix<f64>) -> DVector<f64> {
98 let (prior_mean, prior_precision) = self
99 .mean
100 .as_ref()
101 .zip(self.mean_precision.as_ref())
102 .expect("mean prior not set");
103 let total_precision = prior_precision + &precision;
104 let numerator = prior_precision * prior_mean + &precision * mean;
105
106 total_precision
107 .qr()
108 .solve(&numerator)
109 .expect("total precision matrix is always invertible")
110 }
111}