#![deny(missing_docs)]
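//! Feed-forward neural network building blocks ([`Neuron`], [`Layer`], [`MLP`])
//! composed from the autograd [`Expr`] values in [`crate::value`].
//!
//! A minimal usage sketch, mirroring the tests at the bottom of this file (the
//! exact import path depends on where this module sits in the crate, so the
//! snippet is not compiled as a doctest):
//!
//! ```ignore
//! // A network with 3 inputs, two hidden layers of 2 neurons, and 1 output.
//! let mlp = MLP::new(
//!     3, Activation::None,
//!     vec![2, 2], Activation::Tanh,
//!     1, Activation::None,
//! );
//! let y = mlp.forward(vec![Expr::new_leaf(0.0), Expr::new_leaf(1.0), Expr::new_leaf(2.0)]);
//! assert_eq!(y.len(), 1);
//! ```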
use std::fmt::Display;
use crate::value::Expr;
use rand::{distributions::Uniform, prelude::Distribution, thread_rng};

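/// A single neuron: one weight per input, a bias, and an activation function.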
pub struct Neuron {
    w: Vec<Expr>,
    b: Expr,
    activation: Activation,
}

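/// A fully connected layer of [`Neuron`]s that all receive the same inputs.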
pub struct Layer {
    neurons: Vec<Neuron>,
}

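/// A multi-layer perceptron: a stack of fully connected [`Layer`]s applied in sequence.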
pub struct MLP {
    layers: Vec<Layer>,
}

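/// The activation function applied to a neuron's weighted sum.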
#[derive(Debug, Copy, Clone)]
pub enum Activation {
    /// No activation; the weighted sum is returned as-is.
    None,
    /// Rectified linear unit: `max(0, x)`.
    ReLU,
    /// Hyperbolic tangent, squashing the output into `(-1, 1)`.
    Tanh,
}

impl Neuron {
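    /// Creates a neuron with `n_inputs` weights. Each weight and the bias are
    /// initialized uniformly at random in `[-1.0, 1.0]`.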
    pub fn new(n_inputs: u32, activation: Activation) -> Neuron {
        let between = Uniform::new_inclusive(-1.0, 1.0);
        let mut rng = thread_rng();

        let weights = (1..=n_inputs)
            .map(|_| between.sample(&mut rng))
            .map(|n| Expr::new_leaf(n))
            .collect();

        Neuron {
            w: weights,
            b: Expr::new_leaf(between.sample(&mut rng)),
            activation,
        }
    }

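    /// Computes the weighted sum of the inputs plus the bias and applies the
    /// activation function.
    ///
    /// # Panics
    ///
    /// Panics if the number of inputs does not match the number of weights.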
    pub fn forward(&self, x: Vec<Expr>) -> Expr {
        assert_eq!(
            x.len(),
            self.w.len(),
            "Number of inputs must match number of weights"
        );

        let mut sum = Expr::new_leaf(0.0);

        for (x_i, w_i) in x.iter().zip(self.w.iter()) {
            sum = sum + (x_i.clone() * w_i.clone());
        }

        let sum = sum + self.b.clone();
        match self.activation {
            Activation::None => sum,
            Activation::ReLU => sum.relu(),
            Activation::Tanh => sum.tanh(),
        }
    }
}

impl Display for Neuron {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let weights = self
            .w
            .iter()
            .map(|w| format!("{:.2}", w.result))
            .collect::<Vec<_>>()
            .join(", ");
        write!(f, "Neuron: w: [{}], b: {:.2}", weights, self.b.result)
    }
}

impl Layer {
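    /// Creates a layer of `n_outputs` neurons, each taking `n_inputs` inputs and
    /// using the same `activation`.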
    pub fn new(n_inputs: u32, n_outputs: u32, activation: Activation) -> Layer {
        Layer {
            neurons: (0..n_outputs)
                .map(|_| Neuron::new(n_inputs, activation))
                .collect(),
        }
    }

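    /// Feeds the same inputs to every neuron in the layer and returns one output
    /// per neuron.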
    pub fn forward(&self, x: Vec<Expr>) -> Vec<Expr> {
        self.neurons.iter().map(|n| n.forward(x.clone())).collect()
    }
}

impl Display for Layer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let neurons = self
            .neurons
            .iter()
            .map(|n| format!("{}", n))
            .collect::<Vec<_>>()
            .join("\n - ");
        write!(f, "Layer:\n - {}", neurons)
    }
}

impl MLP {
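    /// Creates a network with one layer per entry in `n_hidden`, followed by an
    /// output layer of `n_outputs` neurons. `input_activation` is applied to the
    /// first layer, `hidden_activation` to the remaining hidden layers, and
    /// `output_activation` to the output layer.
    ///
    /// # Panics
    ///
    /// Panics if `n_hidden` is empty.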
    pub fn new(
        n_inputs: u32,
        input_activation: Activation,
        n_hidden: Vec<u32>,
        hidden_activation: Activation,
        n_outputs: u32,
        output_activation: Activation,
    ) -> MLP {
        let mut layers = Vec::new();

        layers.push(Layer::new(n_inputs, n_hidden[0], input_activation));
        for i in 1..n_hidden.len() {
            layers.push(Layer::new(n_hidden[i - 1], n_hidden[i], hidden_activation));
        }
        layers.push(Layer::new(n_hidden[n_hidden.len() - 1], n_outputs, output_activation));

        MLP { layers }
    }

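    /// Feeds the inputs through each layer in order and returns the outputs of
    /// the final layer.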
    pub fn forward(&self, x: Vec<Expr>) -> Vec<Expr> {
        let mut y = x;
        for layer in &self.layers {
            y = layer.forward(y);
        }
        y
    }
}

impl Display for MLP {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let layers = self
            .layers
            .iter()
            .map(|l| format!("{}", l))
            .collect::<Vec<_>>()
            .join("\n\n");
        write!(f, "MLP:\n{}", layers)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn can_instantiate_neuron() {
        let n = Neuron::new(3, Activation::None);

        assert_eq!(n.w.len(), 3);
        for w in &n.w {
            assert!(w.result >= -1.0 && w.result <= 1.0);
        }
    }

    #[test]
    fn can_do_forward_pass_neuron() {
        let n = Neuron::new(3, Activation::None);

        let x = vec![
            Expr::new_leaf(0.0),
            Expr::new_leaf(1.0),
            Expr::new_leaf(2.0),
        ];

        let _ = n.forward(x);
    }

    #[test]
    fn can_instantiate_layer() {
        let l = Layer::new(3, 2, Activation::None);

        assert_eq!(l.neurons.len(), 2);
        assert_eq!(l.neurons[0].w.len(), 3);
    }

    #[test]
    fn can_do_forward_pass_layer() {
        let l = Layer::new(3, 2, Activation::Tanh);

        let x = vec![
            Expr::new_leaf(0.0),
            Expr::new_leaf(1.0),
            Expr::new_leaf(2.0),
        ];

        let y = l.forward(x);

        assert_eq!(y.len(), 2);
    }

    #[test]
    fn can_instantiate_mlp() {
        let m = MLP::new(
            3, Activation::None,
            vec![2, 2], Activation::Tanh,
            1, Activation::None,
        );

        assert_eq!(m.layers.len(), 3);
        assert_eq!(m.layers[0].neurons.len(), 2);
        assert_eq!(m.layers[0].neurons[0].w.len(), 3);
        assert_eq!(m.layers[1].neurons.len(), 2);
        assert_eq!(m.layers[1].neurons[0].w.len(), 2);
        assert_eq!(m.layers[2].neurons.len(), 1);
        assert_eq!(m.layers[2].neurons[0].w.len(), 2);
    }

    #[test]
    fn can_do_forward_pass_mlp() {
        let m = MLP::new(
            3, Activation::None,
            vec![2, 2], Activation::Tanh,
            1, Activation::None,
        );

        let x = vec![
            Expr::new_leaf(0.0),
            Expr::new_leaf(1.0),
            Expr::new_leaf(2.0),
        ];

        let y = m.forward(x);

        assert_eq!(y.len(), 1);
    }

    #[test]
    fn can_learn() {
        let mlp = MLP::new(
            3, Activation::None,
            vec![2, 2], Activation::Tanh,
            1, Activation::None,
        );

        // Training inputs are constants, not parameters, so they are marked as not learnable.
        let mut inputs = vec![
            vec![Expr::new_leaf(2.0), Expr::new_leaf(3.0), Expr::new_leaf(-1.0)],
            vec![Expr::new_leaf(3.0), Expr::new_leaf(-1.0), Expr::new_leaf(0.5)],
            vec![Expr::new_leaf(0.5), Expr::new_leaf(1.0), Expr::new_leaf(1.0)],
            vec![Expr::new_leaf(1.0), Expr::new_leaf(1.0), Expr::new_leaf(-1.0)],
        ];
        inputs.iter_mut().for_each(|instance| {
            instance.iter_mut().for_each(|input| input.is_learnable = false)
        });

        // Targets are constants as well.
        let mut targets = vec![
            Expr::new_leaf(1.0),
            Expr::new_leaf(-1.0),
            Expr::new_leaf(-1.0),
            Expr::new_leaf(1.0),
        ];
        targets.iter_mut().for_each(|target| target.is_learnable = false);

        let predicted = inputs
            .iter()
            .map(|x| mlp.forward(x.clone()))
            .map(|x| x[0].clone())
            .collect::<Vec<_>>();

        // Loss: sum of squared errors over the training instances.
        let mut loss = predicted
            .iter()
            .zip(targets.iter())
            .map(|(p, t)| {
                let mut diff = p.clone() - t.clone();
                diff.is_learnable = false;

                let mut squared_exponent = Expr::new_leaf(2.0);
                squared_exponent.is_learnable = false;

                let mut mse = diff.clone().pow(squared_exponent);
                mse.is_learnable = false;

                mse
            })
            .sum::<Expr>();

        let first_loss = loss.result.clone();
        loss.learn(1e-04);
        loss.recalculate();
        let second_loss = loss.result.clone();

        assert!(
            second_loss < first_loss,
            "Loss should decrease after learning ({} >= {})",
            second_loss,
            first_loss
        );
    }
}