friedrich/parameters/
prior.rs

1//! Prior
2//!
3//! When asked to predict a value for an input that is too dissimilar to known inputs, the model will return the prior.
//! Furthermore, the process will be fitted on the residual of the prior, meaning that a good prior can significantly improve the precision of the model.
5//!
6//! This can be a constant but also a polynomial or any model of the data.
7//! User-defined priors should implement the Prior trait.
8
9use crate::algebra::{SMatrix, SVector};
10use nalgebra::DVector;
11use nalgebra::{storage::Storage, Dynamic, U1};
12
13//---------------------------------------------------------------------------------------
14// TRAIT
15
/// The Prior trait.
///
/// User-defined priors should implement this trait.
pub trait Prior
{
    /// Returns a default prior for the given input dimension.
    fn default(input_dimension: usize) -> Self;

    /// Takes an input matrix (one sample per row) and returns the prior's output, one value per row.
    fn prior<S: Storage<f64, Dynamic, Dynamic>>(&self, input: &SMatrix<S>) -> DVector<f64>;

    /// Optional: fits the prior on training data.
    /// The default implementation is a no-op (the prior stays unchanged).
    fn fit<SM: Storage<f64, Dynamic, Dynamic> + Clone, SV: Storage<f64, Dynamic, U1>>(&mut self,
                                                                                      _training_inputs: &SMatrix<SM>,
                                                                                      _training_outputs: &SVector<SV>)
    {
    }
}
34
35//---------------------------------------------------------------------------------------
36// CLASSICAL PRIOR
37
/// The Zero prior.
///
/// This prior always returns zero.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "friedrich_serde", derive(serde::Deserialize, serde::Serialize))]
pub struct ZeroPrior {}
44
45impl Prior for ZeroPrior
46{
47    fn default(_input_dimension: usize) -> Self
48    {
49        Self {}
50    }
51
52    fn prior<S: Storage<f64, Dynamic, Dynamic>>(&self, input: &SMatrix<S>) -> DVector<f64>
53    {
54        DVector::zeros(input.nrows())
55    }
56}
57
58//-----------------------------------------------
59
/// The Constant prior.
///
/// This prior returns a constant.
/// It can be fitted to return the mean of the training data.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "friedrich_serde", derive(serde::Deserialize, serde::Serialize))]
pub struct ConstantPrior
{
    // The constant value returned by the prior.
    c: f64
}
70
71impl ConstantPrior
72{
73    /// Constructs a new constant prior
74    pub fn new(c: f64) -> Self
75    {
76        Self { c }
77    }
78}
79
80impl Prior for ConstantPrior
81{
82    fn default(_input_dimension: usize) -> Self
83    {
84        Self::new(0f64)
85    }
86
87    fn prior<S: Storage<f64, Dynamic, Dynamic>>(&self, input: &SMatrix<S>) -> DVector<f64>
88    {
89        DVector::from_element(input.nrows(), self.c)
90    }
91
92    /// the prior is fitted on the mean of the training outputs
93    fn fit<SM: Storage<f64, Dynamic, Dynamic>, SV: Storage<f64, Dynamic, U1>>(&mut self,
94                                                                              _training_inputs: &SMatrix<SM>,
95                                                                              training_outputs: &SVector<SV>)
96    {
97        self.c = training_outputs.mean();
98    }
99}
100
101//-----------------------------------------------
102
/// The Linear prior.
///
/// This prior is a linear function (`input * weights + intercept`) which can be fitted on the training data.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "friedrich_serde", derive(serde::Deserialize, serde::Serialize))]
pub struct LinearPrior
{
    // One weight per input dimension.
    weights: DVector<f64>,
    // Constant term added to the weighted sum of the inputs.
    intercept: f64
}
113
impl LinearPrior
{
    /// Constructs a new linear prior such that `prior = input * weights + intercept`.
    /// `weights` has one element per input dimension; the intercept is stored separately
    /// and is NOT part of `weights`.
    pub fn new(weights: DVector<f64>, intercept: f64) -> Self
    {
        LinearPrior { weights, intercept }
    }
}
123
124impl Prior for LinearPrior
125{
126    fn default(input_dimension: usize) -> Self
127    {
128        Self { weights: DVector::zeros(input_dimension), intercept: 0f64 }
129    }
130
131    fn prior<S: Storage<f64, Dynamic, Dynamic>>(&self, input: &SMatrix<S>) -> DVector<f64>
132    {
133        let mut result = input * &self.weights;
134        result.add_scalar_mut(self.intercept);
135        result
136    }
137
138    /// Performs a linear fit to set the value of the prior.
139    fn fit<SM: Storage<f64, Dynamic, Dynamic> + Clone, SV: Storage<f64, Dynamic, U1>>(&mut self,
140                                                                                      training_inputs: &SMatrix<SM>,
141                                                                                      training_outputs: &SVector<SV>)
142    {
143        // Solve linear system using an SVD decomposition.
144        let weights = training_inputs.clone()
145                                     .insert_column(0, 1.) // Add constant term for non-zero intercept.
146                                     .svd(true, true)
147                                     .solve(training_outputs, 0.)
148                                     .expect("Linear prior fit : solve failed.");
149
150        // TODO Solve cannot be used with qr and full_piv_lu due to issue 667
151        //  (https://github.com/rustsim/nalgebra/issues/667).
152
153        // TODO Solve_mut fails on non-square systems with both qr and full_piv_lu due to issue 672
154        //  (https://github.com/rustsim/nalgebra/issues/672).
155
156        // Extracts weights and intercept.
157        self.intercept = weights[0];
158        self.weights = weights.remove_row(0);
159    }
160}