// echo_state_network/optimizer/ridge.rs

1use nalgebra as na;
2use serde::{Deserialize, Serialize};
3
4/// Ridge regression model.
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct Ridge {
7    beta: f64,
8    x_xt: na::DMatrix<f64>,
9    d_xt: na::DMatrix<f64>,
10}
11
12impl Ridge {
13    /// Create a new Ridge regression model.
14    /// 'n_x' is the number of input variables and 'n_y' is the number of output variables.
15    /// 'beta' is the regularization parameter.
16    pub fn new(n_x: u64, n_y: u64, beta: f64) -> Self {
17        let x_xt = na::DMatrix::zeros(n_x as usize, n_x as usize);
18        let d_xt = na::DMatrix::zeros(n_y as usize, n_x as usize);
19
20        Ridge { beta, x_xt, d_xt }
21    }
22
23    /// Update the internal state of the Ridge regression model.
24    /// 'x' is the input vector (explanatory variable) and 'd' is the output vector (response variable).
25    pub fn set_data(&mut self, x: &na::DVector<f64>, d: &na::DVector<f64>) {
26        self.x_xt = self.x_xt.clone() + x.clone() * x.clone().transpose();
27        self.d_xt = self.d_xt.clone() + d.clone() * x.clone().transpose();
28    }
29
30    /// Fit the Ridge regression model and return the weight matrix.
31    pub fn fit(&self) -> na::DMatrix<f64> {
32        let n_x = self.x_xt.ncols();
33        let x_xt_inv = (self.x_xt.clone() + self.beta * na::DMatrix::identity(n_x, n_x))
34            .try_inverse()
35            .unwrap();
36
37        self.d_xt.clone() * x_xt_inv
38    }
39}
40
41impl std::fmt::Display for Ridge {
42    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
43        let mut displayed = format!("Beta: {}", self.beta);
44        displayed.push_str(&format!("\nx_xt:\n{}", self.x_xt));
45        displayed.push_str(&format!("\nd_xt:\n{}", self.d_xt));
46        write!(f, "{}", displayed)
47    }
48}