use crate::algebra::{SMatrix, SVector};
use nalgebra::DVector;
use nalgebra::{storage::Storage, Dyn, U1};
/// A prior: a function that assigns a value to every row of an input matrix.
///
/// A prior is created with [`Prior::default`] and evaluated with [`Prior::prior`]. It may
/// optionally adapt its parameters to training data through [`Prior::fit`].
pub trait Prior
{
    /// Builds the default prior for inputs with `input_dimension` columns.
    fn default(input_dimension: usize) -> Self;

    /// Evaluates the prior on `input`.
    ///
    /// The implementations in this module return one value per row of `input`.
    fn prior<S>(&self, input: &SMatrix<S>) -> DVector<f64>
        where S: Storage<f64, Dyn, Dyn>;

    /// Adapts the prior's parameters to the given training data.
    ///
    /// The provided implementation is a no-op, suitable for priors without trainable parameters.
    fn fit<SM, SV>(&mut self, _training_inputs: &SMatrix<SM>, _training_outputs: &SVector<SV>)
        where SM: Storage<f64, Dyn, Dyn> + Clone,
              SV: Storage<f64, Dyn, U1>
    {
    }
}
/// Prior that evaluates to zero for every input row.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "friedrich_serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ZeroPrior {}

impl Prior for ZeroPrior
{
    /// The zero prior has no parameters, so the input dimension is ignored.
    fn default(_input_dimension: usize) -> Self
    {
        ZeroPrior {}
    }

    /// Returns a vector of zeros with one entry per row of `input`.
    fn prior<S: Storage<f64, Dyn, Dyn>>(&self, input: &SMatrix<S>) -> DVector<f64>
    {
        let rows = input.nrows();
        DVector::from_element(rows, 0f64)
    }
}
/// Prior that evaluates to the same constant for every input row.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "friedrich_serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ConstantPrior
{
    /// Value produced for every input row.
    c: f64
}

impl ConstantPrior
{
    /// Creates a prior that always evaluates to `c`.
    pub fn new(c: f64) -> Self
    {
        ConstantPrior { c }
    }
}
impl Prior for ConstantPrior
{
    /// Starts from a constant of zero; the input dimension is not used.
    fn default(_input_dimension: usize) -> Self
    {
        Self { c: 0.0 }
    }

    /// Returns the stored constant repeated once per row of `input`.
    fn prior<S: Storage<f64, Dyn, Dyn>>(&self, input: &SMatrix<S>) -> DVector<f64>
    {
        let constant = self.c;
        DVector::from_element(input.nrows(), constant)
    }

    /// Sets the constant to the mean of `training_outputs`; the training inputs are ignored.
    fn fit<SM: Storage<f64, Dyn, Dyn>, SV: Storage<f64, Dyn, U1>>(&mut self,
                                                                  _training_inputs: &SMatrix<SM>,
                                                                  training_outputs: &SVector<SV>)
    {
        let mean_output = training_outputs.mean();
        self.c = mean_output;
    }
}
/// Affine prior: each input row `x` is mapped to `x · weights + intercept`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "friedrich_serde", derive(serde::Serialize, serde::Deserialize))]
pub struct LinearPrior
{
    /// One coefficient per input column.
    weights: DVector<f64>,
    /// Constant offset added to every evaluation.
    intercept: f64
}

impl LinearPrior
{
    /// Creates a linear prior from explicit `weights` and `intercept`.
    pub fn new(weights: DVector<f64>, intercept: f64) -> Self
    {
        Self { weights, intercept }
    }
}
impl Prior for LinearPrior
{
    /// Returns a prior with `input_dimension` zero weights and a zero intercept.
    fn default(input_dimension: usize) -> Self
    {
        Self { weights: DVector::zeros(input_dimension), intercept: 0f64 }
    }

    /// Evaluates `input * weights`, then adds `intercept` to every entry.
    fn prior<S: Storage<f64, Dyn, Dyn>>(&self, input: &SMatrix<S>) -> DVector<f64>
    {
        let mut result = input * &self.weights;
        result.add_scalar_mut(self.intercept);
        result
    }

    /// Fits `weights` and `intercept` to the training data by least squares, using an SVD.
    ///
    /// Singular values below a relative cutoff are treated as zero, so the solve behaves like a
    /// minimum-norm pseudo-inverse on (numerically) rank-deficient data instead of blowing up.
    ///
    /// # Panics
    /// Panics with "Linear prior fit : solve failed." if the SVD solve returns an error.
    fn fit<SM: Storage<f64, Dyn, Dyn> + Clone, SV: Storage<f64, Dyn, U1>>(&mut self,
                                                                           training_inputs: &SMatrix<SM>,
                                                                           training_outputs: &SVector<SV>)
    {
        // Prepend a column of ones so that the first solved coefficient is the intercept.
        let design = training_inputs.clone().insert_column(0, 1.);
        let (nrows, ncols) = design.shape();
        let svd = design.svd(true, true);

        // A cutoff of 0 only discards singular values that are exactly zero. Values that are
        // zero up to rounding would be inverted and yield arbitrarily large weights; this
        // happens e.g. when a feature is constant or collinear with another over the training
        // set. Use the conventional relative threshold eps * max(m, n) * sigma_max, the same
        // default as NumPy's `lstsq` rcond. The result is >= 0, as `solve` requires.
        let largest_singular_value = svd.singular_values.iter().copied().fold(0f64, f64::max);
        let tolerance = f64::EPSILON * (nrows.max(ncols) as f64) * largest_singular_value;

        let weights = svd.solve(training_outputs, tolerance)
                         .expect("Linear prior fit : solve failed.");

        // First coefficient is the intercept; the remaining ones are the per-column weights.
        self.intercept = weights[0];
        self.weights = weights.remove_row(0);
    }
}