echo_state_network/reservoir_computing/
echo_state_network.rs1use nalgebra as na;
2
3use crate::*;
4
5pub struct EchoStateNetwork {
6 input: Input,
7 reservoir: Reservoir,
8 output: Output,
9 previous_y: na::DVector<f64>,
10 output_function: fn(&na::DVector<f64>) -> na::DVector<f64>,
11 inverse_output_function: fn(&na::DVector<f64>) -> na::DVector<f64>,
12 is_classification: bool,
13 n_y: u64,
14 n_u: u64,
15 feedback: Option<Feedback>,
16 is_noisy: bool,
17 rls: Option<RLS>,
18 ridge: Option<Ridge>,
19}
20
21impl EchoStateNetwork {
22 #[allow(clippy::too_many_arguments)]
23 pub fn new(
24 n_u: u64,
25 n_y: u64,
26 n_x: u64,
27 density: f64,
28 input_scale: f64,
29 rho: f64,
30 activation: fn(f64) -> f64,
31 feedback_scale: Option<f64>,
32 noise_level: Option<f64>,
33 leaking_rate: f64,
34 output_function: fn(&na::DVector<f64>) -> na::DVector<f64>,
35 inverse_output_function: fn(&na::DVector<f64>) -> na::DVector<f64>,
36 is_classification: bool,
37 ridge_beta: f64,
38 ) -> Self {
39 EchoStateNetwork {
40 input: Input::new(n_u, n_x, input_scale),
41 reservoir: Reservoir::new(n_x, density, rho, activation, leaking_rate, None),
42 output: Output::new(n_y, n_x),
43 previous_y: na::DVector::zeros(n_y as usize),
44 output_function,
45 inverse_output_function,
46 is_classification,
47 n_y,
48 n_u,
49 feedback: feedback_scale.map(|scale| Feedback::new(n_y, n_x, scale)),
50 is_noisy: noise_level.is_some(),
51 rls: Some(RLS::new(n_x, n_y, 1.0, 1.0)),
52 ridge: Some(Ridge::new(n_x, n_y, ridge_beta)),
53 }
54 }
55
56 pub fn serde_json(&self) -> serde_json::Result<String> {
57 let input = serde_json::to_string(&self.input)?;
58 let reservoir = serde_json::to_string(&self.reservoir)?;
59 let output = serde_json::to_string(&self.output)?;
60 let feedback = if let Some(fdb) = self.feedback.clone() {
61 serde_json::to_string(&fdb)?
62 } else {
63 "None".to_string()
64 };
65 let json = format!(
66 r#"{{
67 "input": {},
68 "reservoir": {},
69 "output": {},
70 "feedback": {}
71 }}"#,
72 input, reservoir, output, feedback
73 );
74 Ok(json)
75 }
76}
77
78impl ReservoirComputing for EchoStateNetwork {
79 fn train(&mut self, teaching_input: &[f64], teaching_output: &[f64]) {
80 if self.rls.is_none() {
81 panic!("RLS is not initialized");
82 }
83
84 let x = na::DVector::from_vec(teaching_input.to_vec());
85 let d = na::DVector::from_vec(teaching_output.to_vec());
86
87 self.rls.as_mut().unwrap().set_data(&x, &d);
88 let weight = self.rls.as_ref().unwrap().fit();
89 self.output.set_weight(weight);
90 }
91
92 fn offline_train(&mut self, teaching_input: &[Vec<f64>], teaching_output: &[Vec<f64>]) {
93 if self.ridge.is_none() {
94 panic!("Ridge is not initialized");
95 }
96
97 let train_length = teaching_input.len();
98 let input_elements = teaching_input
99 .iter()
100 .flatten()
101 .cloned()
102 .collect::<Vec<f64>>();
103 let teaching_input = na::DMatrix::from_column_slice(
104 self.n_u as usize,
105 train_length,
106 input_elements.as_slice(),
107 );
108 let output_elements = teaching_output
109 .iter()
110 .flatten()
111 .cloned()
112 .collect::<Vec<f64>>();
113 let teaching_output = na::DMatrix::from_column_slice(
114 self.n_y as usize,
115 train_length,
116 output_elements.as_slice(),
117 );
118
119 for n in 0..train_length {
120 let mut x_in = self.input.call(&teaching_input.column(n).clone_owned());
121
122 if let Some(fdb) = self.feedback.clone() {
123 let x_fdb = fdb.give_feedback(&self.previous_y);
124 x_in += x_fdb;
125 }
126
127 if self.is_noisy {
128 todo!()
129 }
130
131 let x_res = self.reservoir.call(x_in);
132
133 if self.is_classification {
134 todo!()
135 }
136
137 let d = teaching_output.column(n).clone_owned();
138 let d = (self.inverse_output_function)(&d);
139
140 self.ridge.as_mut().unwrap().set_data(&x_res, &d);
141
142 self.previous_y = d.clone();
143 }
144
145 let output_weight = self.ridge.as_mut().unwrap().fit();
146 self.output.set_weight(output_weight);
147 }
148
149 fn estimate(&mut self, input: &[f64]) -> Vec<f64> {
150 let input = na::DVector::from_column_slice(input);
151
152 let mut x_in = self.input.call(&input);
153
154 if let Some(fdb) = self.feedback.clone() {
155 let x_fdb = fdb.give_feedback(&self.previous_y);
156 x_in += x_fdb;
157 }
158
159 let x_res = self.reservoir.call(x_in);
160
161 if self.is_classification {
162 todo!()
163 }
164
165 let y_estimated = self.output.call(&x_res);
166 let y_estimated = (self.output_function)(&y_estimated);
167
168 self.previous_y = y_estimated.clone();
169
170 y_estimated.as_slice().to_vec()
171 }
172}