echo_state_network/
echo_state_network.rs1use std::vec;
2
3use nalgebra as na;
4
5use crate::*;
6
7pub struct EchoStateNetwork {
8 input: Input,
9 reservoir: Reservoir,
10 output: Output,
11 previous_y: na::DVector<f64>,
12 output_function: fn(&na::DVector<f64>) -> na::DVector<f64>,
13 inverse_output_function: fn(&na::DVector<f64>) -> na::DVector<f64>,
14 is_classification: bool,
15 n_y: u64,
16 n_u: u64,
17 feedback: Option<Feedback>,
18 is_noisy: bool,
19}
20
21impl EchoStateNetwork {
22 #[allow(clippy::too_many_arguments)]
23 pub fn new(
24 n_u: u64,
25 n_y: u64,
26 n_x: u64,
27 density: f64,
28 input_scale: f64,
29 rho: f64,
30 activation: fn(f64) -> f64,
31 feedback_scale: Option<f64>,
32 noise_level: Option<f64>,
33 leaking_rate: f64,
34 output_function: fn(&na::DVector<f64>) -> na::DVector<f64>,
35 inverse_output_function: fn(&na::DVector<f64>) -> na::DVector<f64>,
36 is_classification: bool,
37 ) -> Self {
38 EchoStateNetwork {
39 input: Input::new(n_u, n_x, input_scale),
40 reservoir: Reservoir::new(n_x, density, rho, activation, leaking_rate, None),
41 output: Output::new(n_y, n_x),
42 previous_y: na::DVector::zeros(n_y as usize),
43 output_function,
44 inverse_output_function,
45 is_classification,
46 n_y,
47 n_u,
48 feedback: feedback_scale.map(|scale| Feedback::new(n_y, n_x, scale)),
49 is_noisy: noise_level.is_some(),
50 }
51 }
52
53 pub fn train(
54 &mut self,
55 teaching_input: &[Vec<f64>],
56 teaching_output: &[Vec<f64>],
57 optimizer: &mut Ridge,
58 ) -> Vec<Vec<f64>> {
59 let train_length = teaching_input.len();
60 let input_elements = teaching_input
61 .iter()
62 .flatten()
63 .cloned()
64 .collect::<Vec<f64>>();
65 let teaching_input = na::DMatrix::from_column_slice(
66 self.n_u as usize,
67 train_length,
68 input_elements.as_slice(),
69 );
70 let output_elements = teaching_output
71 .iter()
72 .flatten()
73 .cloned()
74 .collect::<Vec<f64>>();
75 let teaching_output = na::DMatrix::from_column_slice(
76 self.n_y as usize,
77 train_length,
78 output_elements.as_slice(),
79 );
80
81 let mut y_log = vec![];
82
83 for n in 0..train_length {
84 let mut x_in = self.input.call(&teaching_input.column(n).clone_owned());
85
86 if let Some(fdb) = self.feedback.clone() {
87 let x_fdb = fdb.give_feedback(&self.previous_y);
88 x_in += x_fdb;
89 }
90
91 if self.is_noisy {
92 todo!()
93 }
94
95 let x_res = self.reservoir.call(x_in);
96
97 if self.is_classification {
98 todo!()
99 }
100
101 let d = teaching_output.column(n).clone_owned();
102 let d = (self.inverse_output_function)(&d);
103
104 optimizer.set_data(&x_res, &d);
105
106 let y = self.output.call(&x_res);
107 let output = (self.output_function)(&y);
108 y_log.push(output.as_slice().to_vec());
109 self.previous_y = d.clone();
110 }
111
112 let output_weight = optimizer.fit();
113 self.output.set_weight(output_weight);
114
115 y_log
116 }
117
118 pub fn estimate(&mut self, input: &[Vec<f64>]) -> Vec<Vec<f64>> {
119 let test_length = input.len();
120 let input_elements = input.iter().flatten().cloned().collect::<Vec<f64>>();
121 let input = na::DMatrix::from_column_slice(
122 self.n_u as usize,
123 test_length,
124 input_elements.as_slice(),
125 );
126
127 let mut y_log = vec![];
128
129 for n in 0..test_length {
130 let mut x_in = self.input.call(&input.column(n).clone_owned());
131
132 if let Some(fdb) = self.feedback.clone() {
133 let x_fdb = fdb.give_feedback(&self.previous_y);
134 x_in += x_fdb;
135 }
136
137 let x_res = self.reservoir.call(x_in);
138
139 if self.is_classification {
140 todo!()
141 }
142
143 let y_estimated = self.output.call(&x_res);
144 let y_estimated = (self.output_function)(&y_estimated);
145 y_log.push(y_estimated.as_slice().to_vec());
146
147 self.previous_y = y_estimated;
148 }
149
150 y_log
151 }
152
153 pub fn serde_json(&self) -> serde_json::Result<String> {
154 let input = serde_json::to_string(&self.input)?;
155 let reservoir = serde_json::to_string(&self.reservoir)?;
156 let output = serde_json::to_string(&self.output)?;
157 let feedback = if let Some(fdb) = self.feedback.clone() {
158 serde_json::to_string(&fdb)?
159 } else {
160 "null".to_string()
161 };
162 let json = format!(
163 r#"{{
164 "input": {},
165 "reservoir": {},
166 "output": {},
167 "feedback": {}
168 }}"#,
169 input, reservoir, output, feedback
170 );
171 Ok(json)
172 }
173}