// echo_state_network/optimizer/rls.rs

1use nalgebra as na;
2use serde::{Deserialize, Serialize};
3
4/// Recursive Least Squares (RLS) optimizer.
/// Recursive Least Squares (RLS) optimizer.
///
/// Trains a linear readout online: each `set_data` call refines the weight
/// matrix from a single (input, target) pair without storing past samples.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RLS {
    /// Auxiliary variable: the recursively-maintained inverse of the
    /// forgetting-weighted input autocorrelation matrix (n_x x n_x),
    /// initialized to (1/alpha) * I in `new`.
    p: na::DMatrix<f64>,
    /// Forgetting factor; presumably 0 < lambda <= 1, with smaller values
    /// discounting older samples faster — confirm against callers.
    lambda: f64,
    /// Weight matrix being estimated (n_y rows x n_x columns).
    weight: na::DMatrix<f64>,
}
14
15impl RLS {
16    pub fn new(n_x: u64, n_y: u64, lambda: f64, alpha: f64) -> Self {
17        let mut p = na::DMatrix::identity(n_x as usize, n_x as usize);
18        p *= 1.0 / alpha;
19
20        let weight = na::DMatrix::zeros(n_y as usize, n_x as usize);
21
22        RLS { p, lambda, weight }
23    }
24
25    pub fn set_data(&mut self, x: &na::DVector<f64>, d: &na::DVector<f64>) {
26        let p = self.p.clone();
27
28        let p1 = p.clone() / self.lambda;
29        let p2 = p.clone() * x.clone() * x.clone().transpose() * p.clone().transpose();
30        let p3 = self.lambda + (x.clone().transpose() * p.clone() * x.clone())[(0, 0)];
31
32        self.p = p1 - p2 / p3;
33
34        let y = self.weight.clone() * x.clone();
35        let w1 = self.weight.clone();
36        let w2 = (1. / self.lambda) * (d.clone() - y.clone()) * (p * x.clone()).transpose();
37
38        self.weight = w1 + w2;
39    }
40
41    pub fn fit(&self) -> na::DMatrix<f64> {
42        self.weight.clone()
43    }
44}
45
46impl std::fmt::Display for RLS {
47    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
48        let mut displayed = format!("Lambda: {}", self.lambda);
49        displayed.push_str(&format!("\nP:\n{}", self.p));
50        displayed.push_str(&format!("\nWeight:\n{}", self.weight));
51        write!(f, "{}", displayed)
52    }
53}