1#![deny(missing_docs)]
10use std::fmt::Display;
11use crate::value::Expr;
12use rand::{distributions::Uniform, prelude::Distribution, thread_rng};
13
14pub struct Neuron {
19    w: Vec<Expr>,
20    b: Expr,
21    activation: Activation,
22}
23
24pub struct Layer {
29    neurons: Vec<Neuron>,
30}
31
32pub struct MLP {
38    layers: Vec<Layer>,
39}
40
/// Non-linearity applied to a neuron's weighted sum `w · x + b`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Activation {
    /// Identity: the weighted sum is returned unchanged.
    None,
    /// Rectified linear unit, computed with `Expr::relu`.
    ReLU,
    /// Hyperbolic tangent, computed with `Expr::tanh`.
    Tanh,
}
53
54impl Neuron {
55    pub fn new(n_inputs: u32, activation: Activation) -> Neuron {
62        let between = Uniform::new_inclusive(-1.0, 1.0);
63        let mut rng = thread_rng();
64
65        let weights = (1..=n_inputs)
66            .map(|_| between.sample(&mut rng))
67            .map(|n| Expr::new_leaf(n))
68            .collect();
69
70        Neuron {
71            w: weights,
72            b: Expr::new_leaf(between.sample(&mut rng)),
73            activation,
74        }
75    }
76
77    pub fn forward(&self, x: Vec<Expr>) -> Expr {
82        assert_eq!(
83            x.len(),
84            self.w.len(),
85            "Number of inputs must match number of weights"
86        );
87
88        let mut sum = Expr::new_leaf(0.0);
89
90        for (i, x_i) in x.iter().enumerate() {
92            sum = sum + (x_i.clone() * self.w[i].clone());
93        }
94
95        let sum = sum + self.b.clone();
96        match self.activation {
97            Activation::None => sum,
98            Activation::ReLU => sum.relu(),
99            Activation::Tanh => sum.tanh(),
100        }
101    }
102}
103
104impl Display for Neuron {
105    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106        let weights = self.w
107            .iter()
108            .map(|w| format!("{:.2}", w.result))
109            .collect::<Vec<_>>()
110            .join(", ");
111        write!(f, "Neuron: w: [{:}], b: {:.2}", weights, self.b.result)
112    }
113}
114
115impl Layer {
116    pub fn new(n_inputs: u32, n_outputs: u32, activation: Activation) -> Layer {
121        Layer {
122            neurons: (0..n_outputs).map(|_| Neuron::new(n_inputs, activation)).collect(),
123        }
124    }
125
126    pub fn forward(&self, x: Vec<Expr>) -> Vec<Expr> {
130        self.neurons.iter().map(|n| n.forward(x.clone())).collect()
131    }
132}
133
134impl Display for Layer {
135    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136        let neurons = self.neurons
137            .iter()
138            .map(|n| format!("{:}", n))
139            .collect::<Vec<_>>()
140            .join("\n - ");
141        write!(f, "Layer:\n - {:}", neurons)
142    }
143}
144
145impl MLP {
146    pub fn new(
154        n_inputs: u32, input_activation: Activation,
155        n_hidden: Vec<u32>, hidden_activation: Activation,
156        n_outputs: u32, output_activation: Activation) -> MLP {
157
158        let mut layers = Vec::new();
159
160        layers.push(Layer::new(n_inputs, n_hidden[0], input_activation));
161        for i in 1..n_hidden.len() {
162            layers.push(Layer::new(n_hidden[i - 1], n_hidden[i], hidden_activation));
163        }
164        layers.push(Layer::new(n_hidden[n_hidden.len() - 1], n_outputs, output_activation));
165
166        MLP { layers }
167    }
168
169    pub fn forward(&self, x: Vec<Expr>) -> Vec<Expr> {
173        let mut y = x;
174        for layer in &self.layers {
175            y = layer.forward(y);
176        }
177        y
178    }
179}
180
181impl Display for MLP {
182    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183        let layers = self.layers
184            .iter()
185            .map(|l| format!("{:}", l))
186            .collect::<Vec<_>>()
187            .join("\n\n");
188        write!(f, "MLP:\n{:}", layers)
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn can_instantiate_neuron() {
        let n = Neuron::new(3, Activation::None);

        assert_eq!(n.w.len(), 3);
        // Weights and bias are sampled from the closed interval [-1, 1].
        for w in &n.w {
            assert!((-1.0..=1.0).contains(&w.result));
        }
        assert!((-1.0..=1.0).contains(&n.b.result));
    }

    #[test]
    fn can_do_forward_pass_neuron() {
        let n = Neuron::new(3, Activation::None);

        let x = vec![
            Expr::new_leaf(0.0),
            Expr::new_leaf(1.0),
            Expr::new_leaf(2.0),
        ];

        let y = n.forward(x);

        // With no activation the output is exactly w · x + b.
        let expected = n.w[0].result * 0.0
            + n.w[1].result * 1.0
            + n.w[2].result * 2.0
            + n.b.result;
        assert!(
            (y.result - expected).abs() < 1e-6,
            "forward pass gave {}, expected {}",
            y.result,
            expected
        );
    }

    #[test]
    #[should_panic(expected = "Number of inputs must match number of weights")]
    fn neuron_rejects_wrong_input_count() {
        let n = Neuron::new(3, Activation::None);
        let _ = n.forward(vec![Expr::new_leaf(1.0)]);
    }

    #[test]
    fn can_instantiate_layer() {
        let l = Layer::new(3, 2, Activation::None);

        assert_eq!(l.neurons.len(), 2);
        assert!(l.neurons.iter().all(|n| n.w.len() == 3));
    }

    #[test]
    fn can_do_forward_pass_layer() {
        let l = Layer::new(3, 2, Activation::Tanh);

        let x = vec![
            Expr::new_leaf(0.0),
            Expr::new_leaf(1.0),
            Expr::new_leaf(2.0),
        ];

        let y = l.forward(x);

        assert_eq!(y.len(), 2);
        // tanh squashes every output into [-1, 1].
        assert!(y.iter().all(|out| out.result.abs() <= 1.0));
    }

    #[test]
    fn can_instantiate_mlp() {
        let m = MLP::new(
            3,
            Activation::None,
            vec![2, 2],
            Activation::Tanh,
            1,
            Activation::None,
        );

        // (neurons in the layer, inputs per neuron) for each layer, in order.
        let shape: Vec<(usize, usize)> = m
            .layers
            .iter()
            .map(|l| (l.neurons.len(), l.neurons[0].w.len()))
            .collect();
        assert_eq!(shape, vec![(2, 3), (2, 2), (1, 2)]);
    }

    #[test]
    fn can_do_forward_pass_mlp() {
        let m = MLP::new(
            3,
            Activation::None,
            vec![2, 2],
            Activation::Tanh,
            1,
            Activation::None,
        );

        let x = vec![
            Expr::new_leaf(0.0),
            Expr::new_leaf(1.0),
            Expr::new_leaf(2.0),
        ];

        let y = m.forward(x);

        assert_eq!(y.len(), 1);
    }

    #[test]
    fn can_learn() {
        let mlp = MLP::new(
            3,
            Activation::None,
            vec![2, 2],
            Activation::Tanh,
            1,
            Activation::None,
        );

        let mut inputs = vec![
            vec![Expr::new_leaf(2.0), Expr::new_leaf(3.0), Expr::new_leaf(-1.0)],
            vec![Expr::new_leaf(3.0), Expr::new_leaf(-1.0), Expr::new_leaf(0.5)],
            vec![Expr::new_leaf(0.5), Expr::new_leaf(1.0), Expr::new_leaf(1.0)],
            vec![Expr::new_leaf(1.0), Expr::new_leaf(1.0), Expr::new_leaf(-1.0)],
        ];
        // Inputs and targets are data, not parameters: keep `learn` off them.
        for input in inputs.iter_mut().flatten() {
            input.is_learnable = false;
        }

        let mut targets = vec![
            Expr::new_leaf(1.0),
            Expr::new_leaf(-1.0),
            Expr::new_leaf(-1.0),
            Expr::new_leaf(1.0),
        ];
        for target in &mut targets {
            target.is_learnable = false;
        }

        let predicted = inputs
            .iter()
            .map(|x| mlp.forward(x.clone()))
            .map(|outputs| outputs[0].clone())
            .collect::<Vec<_>>();

        // Sum of squared errors over all training examples.
        let mut loss = predicted
            .iter()
            .zip(targets.iter())
            .map(|(p, t)| {
                let mut diff = p.clone() - t.clone();
                diff.is_learnable = false;

                let mut squared_exponent = Expr::new_leaf(2.0);
                squared_exponent.is_learnable = false;

                let mut squared_error = diff.pow(squared_exponent);
                squared_error.is_learnable = false;

                squared_error
            })
            .sum::<Expr>();

        let first_loss = loss.result;
        loss.learn(1e-04);
        loss.recalculate();
        let second_loss = loss.result;

        assert!(
            second_loss < first_loss,
            "Loss should decrease after learning ({} >= {})",
            second_loss,
            first_loss
        );
    }
}