friedrich/parameters/prior.rs
1//! Prior
2//!
//! When asked to predict a value for an input that is too dissimilar to the known inputs, the model will return the prior.
//! Furthermore, the process is fitted on the residual of the prior, meaning that a good prior can significantly improve the precision of the model.
5//!
//! The prior can be a constant, but also a polynomial or any other model of the data.
//! User-defined priors should implement the [`Prior`] trait.
8
9use crate::algebra::{SMatrix, SVector};
10use nalgebra::DVector;
11use nalgebra::{storage::Storage, Dyn, U1};
12
13//---------------------------------------------------------------------------------------
14// TRAIT
15
16/// The Prior trait.
17///
18/// User-defined kernels should implement this trait.
19pub trait Prior
20{
21 /// Default value for the prior.
22 fn default(input_dimension: usize) -> Self;
23
24 /// Takes and input and return an output.
25 fn prior<S: Storage<f64, Dyn, Dyn>>(&self, input: &SMatrix<S>) -> DVector<f64>;
26
27 /// Optional, function that fits the prior on training data.
28 fn fit<SM: Storage<f64, Dyn, Dyn> + Clone, SV: Storage<f64, Dyn, U1>>(&mut self,
29 _training_inputs: &SMatrix<SM>,
30 _training_outputs: &SVector<SV>)
31 {
32 }
33}
34
35//---------------------------------------------------------------------------------------
36// CLASSICAL PRIOR
37
/// Prior that is identically zero.
///
/// Whatever the input, every value it returns is `0`.
#[cfg_attr(feature = "friedrich_serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Debug, Clone, Copy)]
pub struct ZeroPrior {}
44
45impl Prior for ZeroPrior
46{
47 fn default(_input_dimension: usize) -> Self
48 {
49 Self {}
50 }
51
52 fn prior<S: Storage<f64, Dyn, Dyn>>(&self, input: &SMatrix<S>) -> DVector<f64>
53 {
54 DVector::zeros(input.nrows())
55 }
56}
57
58//-----------------------------------------------
59
/// Prior returning the same constant for every input.
///
/// Fitting it replaces the constant with the mean of the training outputs.
#[cfg_attr(feature = "friedrich_serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Debug, Clone)]
pub struct ConstantPrior
{
    /// Value returned by the prior for every input.
    c: f64
}

impl ConstantPrior
{
    /// Builds a prior that always returns `c`.
    pub fn new(c: f64) -> Self
    {
        ConstantPrior { c }
    }
}
78}
79
80impl Prior for ConstantPrior
81{
82 fn default(_input_dimension: usize) -> Self
83 {
84 Self::new(0f64)
85 }
86
87 fn prior<S: Storage<f64, Dyn, Dyn>>(&self, input: &SMatrix<S>) -> DVector<f64>
88 {
89 DVector::from_element(input.nrows(), self.c)
90 }
91
92 /// the prior is fitted on the mean of the training outputs
93 fn fit<SM: Storage<f64, Dyn, Dyn>, SV: Storage<f64, Dyn, U1>>(&mut self,
94 _training_inputs: &SMatrix<SM>,
95 training_outputs: &SVector<SV>)
96 {
97 self.c = training_outputs.mean();
98 }
99}
100
101//-----------------------------------------------
102
103/// The Linear prior.
104///
105/// This prior is a linear function which can be fit on the training data.
106#[derive(Clone, Debug)]
107#[cfg_attr(feature = "friedrich_serde", derive(serde::Deserialize, serde::Serialize))]
108pub struct LinearPrior
109{
110 weights: DVector<f64>,
111 intercept: f64
112}
113
114impl LinearPrior
115{
116 /// Constructs a new linear prior.
117 /// The first row of w is the bias such that `prior = [1|input] * w`
118 pub fn new(weights: DVector<f64>, intercept: f64) -> Self
119 {
120 LinearPrior { weights, intercept }
121 }
122}
123
124impl Prior for LinearPrior
125{
126 fn default(input_dimension: usize) -> Self
127 {
128 Self { weights: DVector::zeros(input_dimension), intercept: 0f64 }
129 }
130
131 fn prior<S: Storage<f64, Dyn, Dyn>>(&self, input: &SMatrix<S>) -> DVector<f64>
132 {
133 let mut result = input * &self.weights;
134 result.add_scalar_mut(self.intercept);
135 result
136 }
137
138 /// Performs a linear fit to set the value of the prior.
139 fn fit<SM: Storage<f64, Dyn, Dyn> + Clone, SV: Storage<f64, Dyn, U1>>(&mut self,
140 training_inputs: &SMatrix<SM>,
141 training_outputs: &SVector<SV>)
142 {
143 // Solve linear system using an SVD decomposition.
144 let weights = training_inputs.clone()
145 .insert_column(0, 1.) // Add constant term for non-zero intercept.
146 .svd(true, true)
147 .solve(training_outputs, 0.)
148 .expect("Linear prior fit : solve failed.");
149
150 // TODO Solve cannot be used with qr and full_piv_lu due to issue 667
151 // (https://github.com/rustsim/nalgebra/issues/667).
152
153 // TODO Solve_mut fails on non-square systems with both qr and full_piv_lu due to issue 672
154 // (https://github.com/rustsim/nalgebra/issues/672).
155
156 // Extracts weights and intercept.
157 self.intercept = weights[0];
158 self.weights = weights.remove_row(0);
159 }
160}