echo_state_network/optimizer/
rls.rs1use nalgebra as na;
2use serde::{Deserialize, Serialize};
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct RLS {
7 p: na::DMatrix<f64>,
9 lambda: f64,
11 weight: na::DMatrix<f64>,
13}
14
15impl RLS {
16 pub fn new(n_x: u64, n_y: u64, lambda: f64, alpha: f64) -> Self {
17 let mut p = na::DMatrix::identity(n_x as usize, n_x as usize);
18 p *= 1.0 / alpha;
19
20 let weight = na::DMatrix::zeros(n_y as usize, n_x as usize);
21
22 RLS { p, lambda, weight }
23 }
24
25 pub fn set_data(&mut self, x: &na::DVector<f64>, d: &na::DVector<f64>) {
26 let p = self.p.clone();
27
28 let p1 = p.clone() / self.lambda;
29 let p2 = p.clone() * x.clone() * x.clone().transpose() * p.clone().transpose();
30 let p3 = self.lambda + (x.clone().transpose() * p.clone() * x.clone())[(0, 0)];
31
32 self.p = p1 - p2 / p3;
33
34 let y = self.weight.clone() * x.clone();
35 let w1 = self.weight.clone();
36 let w2 = (1. / self.lambda) * (d.clone() - y.clone()) * (p * x.clone()).transpose();
37
38 self.weight = w1 + w2;
39 }
40
41 pub fn fit(&self) -> na::DMatrix<f64> {
42 self.weight.clone()
43 }
44}
45
46impl std::fmt::Display for RLS {
47 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
48 let mut displayed = format!("Lambda: {}", self.lambda);
49 displayed.push_str(&format!("\nP:\n{}", self.p));
50 displayed.push_str(&format!("\nWeight:\n{}", self.weight));
51 write!(f, "{}", displayed)
52 }
53}