// echo_state_network/optimizer/ridge.rs
use nalgebra as na;
use serde::{Deserialize, Serialize};

4#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct Ridge {
7 beta: f64,
8 x_xt: na::DMatrix<f64>,
9 d_xt: na::DMatrix<f64>,
10}
11
12impl Ridge {
13 pub fn new(n_x: u64, n_y: u64, beta: f64) -> Self {
17 let x_xt = na::DMatrix::zeros(n_x as usize, n_x as usize);
18 let d_xt = na::DMatrix::zeros(n_y as usize, n_x as usize);
19
20 Ridge { beta, x_xt, d_xt }
21 }
22
23 pub fn set_data(&mut self, x: &na::DVector<f64>, d: &na::DVector<f64>) {
26 self.x_xt = self.x_xt.clone() + x.clone() * x.clone().transpose();
27 self.d_xt = self.d_xt.clone() + d.clone() * x.clone().transpose();
28 }
29
30 pub fn fit(&self) -> na::DMatrix<f64> {
32 let n_x = self.x_xt.ncols();
33 let x_xt_inv = (self.x_xt.clone() + self.beta * na::DMatrix::identity(n_x, n_x))
34 .try_inverse()
35 .unwrap();
36
37 self.d_xt.clone() * x_xt_inv
38 }
39}
40
41impl std::fmt::Display for Ridge {
42 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
43 let mut displayed = format!("Beta: {}", self.beta);
44 displayed.push_str(&format!("\nx_xt:\n{}", self.x_xt));
45 displayed.push_str(&format!("\nd_xt:\n{}", self.d_xt));
46 write!(f, "{}", displayed)
47 }
48}