#![deny(missing_docs)]
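//! Neural-network building blocks: [`Neuron`], [`Layer`] and [`MLP`], all built
//! on top of [`Expr`] so that whole networks can be evaluated and trained through
//! the expression graph.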
use std::fmt::Display;
use crate::value::Expr;
use rand::{distributions::Uniform, prelude::Distribution, thread_rng};

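/// A single neuron: a vector of weights, a bias and an activation function applied
/// to the weighted sum of the inputs.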
pub struct Neuron {
    w: Vec<Expr>,
    b: Expr,
    activation: Activation,
}

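/// A fully-connected layer: a collection of neurons that all receive the same inputs.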
pub struct Layer {
    neurons: Vec<Neuron>,
}

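/// A multi-layer perceptron: a stack of fully-connected [`Layer`]s applied in
/// sequence.
///
/// # Examples
///
/// Illustrative usage only; the `my_crate::nn` and `my_crate::value` paths below
/// are placeholders for this crate's actual paths, which is why the example is
/// marked `ignore`:
///
/// ```ignore
/// use my_crate::nn::{Activation, MLP};
/// use my_crate::value::Expr;
///
/// // 3 inputs -> two hidden layers of 2 neurons each -> 1 output.
/// let mlp = MLP::new(
///     3, Activation::None,
///     vec![2, 2], Activation::Tanh,
///     1, Activation::None,
/// );
///
/// let x = vec![
///     Expr::new_leaf(0.5, "x_1"),
///     Expr::new_leaf(-1.0, "x_2"),
///     Expr::new_leaf(2.0, "x_3"),
/// ];
/// let y = mlp.forward(x);
/// assert_eq!(y.len(), 1);
/// ```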
pub struct MLP {
    layers: Vec<Layer>,
}

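/// The activation function applied to a neuron's output.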
#[derive(Debug, Copy, Clone)]
pub enum Activation {
    /// No activation: the weighted sum is returned unchanged.
    None,
    /// Rectified linear unit, `max(0, x)`.
    ReLU,
    /// Hyperbolic tangent, squashing the output into `(-1, 1)`.
    Tanh,
}

impl Neuron {
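    /// Creates a neuron with `n_inputs` weights, each drawn uniformly at random
    /// from `[-1.0, 1.0]`, together with a bias drawn from the same distribution.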
    pub fn new(n_inputs: u32, activation: Activation) -> Neuron {
        let between = Uniform::new_inclusive(-1.0, 1.0);
        let mut rng = thread_rng();

        // One leaf expression per weight, labelled w_0, w_1, ...
        let weights = (1..=n_inputs)
            .map(|_| between.sample(&mut rng))
            .enumerate()
            .map(|(i, n)| Expr::new_leaf(n, &format!("w_{}", i)))
            .collect();

        Neuron {
            w: weights,
            b: Expr::new_leaf(between.sample(&mut rng), "b"),
            activation,
        }
    }

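    /// Computes the neuron's output for the given inputs: the weighted sum of the
    /// inputs plus the bias, passed through the activation function.
    ///
    /// # Panics
    ///
    /// Panics if the number of inputs does not match the number of weights.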
    pub fn forward(&self, x: Vec<Expr>) -> Expr {
        assert_eq!(
            x.len(),
            self.w.len(),
            "Number of inputs must match number of weights"
        );

        let mut sum = Expr::new_leaf(0.0, "0.0");

        for (i, x_i) in x.iter().enumerate() {
            // Accumulate x_i * w_i into the running sum.
            sum = sum + (x_i.clone() * self.w[i].clone());
        }

        let sum = sum + self.b.clone();
        match self.activation {
            Activation::None => sum,
            Activation::ReLU => sum.relu("activation"),
            Activation::Tanh => sum.tanh("activation"),
        }
    }
}

impl Display for Neuron {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let weights = self.w
            .iter()
            .map(|w| format!("{:.2}", w.result))
            .collect::<Vec<_>>()
            .join(", ");
        write!(f, "Neuron: w: [{}], b: {:.2}", weights, self.b.result)
    }
}

impl Layer {
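    /// Creates a layer of `n_outputs` neurons, each taking `n_inputs` inputs and
    /// using the given activation function.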
    pub fn new(n_inputs: u32, n_outputs: u32, activation: Activation) -> Layer {
        Layer {
            neurons: (0..n_outputs).map(|_| Neuron::new(n_inputs, activation)).collect(),
        }
    }

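    /// Feeds the same inputs through every neuron in the layer and returns one
    /// output per neuron.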
    pub fn forward(&self, x: Vec<Expr>) -> Vec<Expr> {
        self.neurons.iter().map(|n| n.forward(x.clone())).collect()
    }
}

impl Display for Layer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let neurons = self.neurons
            .iter()
            .map(|n| format!("{}", n))
            .collect::<Vec<_>>()
            .join("\n - ");
        write!(f, "Layer:\n - {}", neurons)
    }
}

impl MLP {
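    /// Creates an MLP with one layer per entry in `n_hidden` plus a final output
    /// layer. The first layer maps `n_inputs` inputs using `input_activation`, the
    /// remaining hidden layers use `hidden_activation`, and the output layer
    /// produces `n_outputs` values using `output_activation`.
    ///
    /// # Panics
    ///
    /// Panics if `n_hidden` is empty.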
    pub fn new(
        n_inputs: u32, input_activation: Activation,
        n_hidden: Vec<u32>, hidden_activation: Activation,
        n_outputs: u32, output_activation: Activation) -> MLP {

        let mut layers = Vec::new();

        layers.push(Layer::new(n_inputs, n_hidden[0], input_activation));
        for i in 1..n_hidden.len() {
            layers.push(Layer::new(n_hidden[i - 1], n_hidden[i], hidden_activation));
        }
        layers.push(Layer::new(n_hidden[n_hidden.len() - 1], n_outputs, output_activation));

        MLP { layers }
    }

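    /// Runs a forward pass through every layer in turn, feeding each layer's
    /// outputs into the next, and returns the outputs of the final layer.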
    pub fn forward(&self, x: Vec<Expr>) -> Vec<Expr> {
        let mut y = x;
        for layer in &self.layers {
            y = layer.forward(y);
        }
        y
    }
}

impl Display for MLP {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let layers = self.layers
            .iter()
            .map(|l| format!("{}", l))
            .collect::<Vec<_>>()
            .join("\n\n");
        write!(f, "MLP:\n{}", layers)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn can_instantiate_neuron() {
        let n = Neuron::new(3, Activation::None);

        assert_eq!(n.w.len(), 3);
        for i in 0..3 {
            assert!(n.w[i].result >= -1.0 && n.w[i].result <= 1.0);
        }
    }

    #[test]
    fn can_do_forward_pass_neuron() {
        let n = Neuron::new(3, Activation::None);

        let x = vec![
            Expr::new_leaf(0.0, "x_1"),
            Expr::new_leaf(1.0, "x_2"),
            Expr::new_leaf(2.0, "x_3")
        ];

        let _ = n.forward(x);
    }

    #[test]
    fn can_instantiate_layer() {
        let l = Layer::new(3, 2, Activation::None);

        assert_eq!(l.neurons.len(), 2);
        assert_eq!(l.neurons[0].w.len(), 3);
    }

    #[test]
    fn can_do_forward_pass_layer() {
        let l = Layer::new(3, 2, Activation::Tanh);

        let x = vec![
            Expr::new_leaf(0.0, "x_1"),
            Expr::new_leaf(1.0, "x_2"),
            Expr::new_leaf(2.0, "x_3")
        ];

        let y = l.forward(x);

        assert_eq!(y.len(), 2);
    }

    #[test]
    fn can_instantiate_mlp() {
        let m = MLP::new(3, Activation::None,
            vec![2, 2], Activation::Tanh,
            1, Activation::None);

        assert_eq!(m.layers.len(), 3);
        assert_eq!(m.layers[0].neurons.len(), 2);
        assert_eq!(m.layers[0].neurons[0].w.len(), 3);
        assert_eq!(m.layers[1].neurons.len(), 2);
        assert_eq!(m.layers[1].neurons[0].w.len(), 2);
        assert_eq!(m.layers[2].neurons.len(), 1);
        assert_eq!(m.layers[2].neurons[0].w.len(), 2);
    }

    #[test]
    fn can_do_forward_pass_mlp() {
        let m = MLP::new(3, Activation::None,
            vec![2, 2], Activation::Tanh,
            1, Activation::None);

        let x = vec![
            Expr::new_leaf(0.0, "x_1"),
            Expr::new_leaf(1.0, "x_2"),
            Expr::new_leaf(2.0, "x_3")
        ];

        let y = m.forward(x);

        assert_eq!(y.len(), 1);
    }

    #[test]
    fn can_learn() {
        let mlp = MLP::new(3, Activation::None,
            vec![2, 2], Activation::Tanh,
            1, Activation::None);

        let mut inputs = vec![
            vec![Expr::new_leaf(2.0, "x_1,1"), Expr::new_leaf(3.0, "x_1,2"), Expr::new_leaf(-1.0, "x_1,3")],
            vec![Expr::new_leaf(3.0, "x_2,1"), Expr::new_leaf(-1.0, "x_2,2"), Expr::new_leaf(0.5, "x_2,3")],
            vec![Expr::new_leaf(0.5, "x_3,1"), Expr::new_leaf(1.0, "x_3,2"), Expr::new_leaf(1.0, "x_3,3")],
            vec![Expr::new_leaf(1.0, "x_4,1"), Expr::new_leaf(1.0, "x_4,2"), Expr::new_leaf(-1.0, "x_4,3")],
        ];

        // Inputs are constants, not parameters to be learned.
        inputs.iter_mut().for_each(|instance|
            instance.iter_mut().for_each(|input|
                input.is_learnable = false
            )
        );

        let mut targets = vec![
            Expr::new_leaf(1.0, "y_1"),
            Expr::new_leaf(-1.0, "y_2"),
            Expr::new_leaf(-1.0, "y_3"),
            Expr::new_leaf(1.0, "y_4"),
        ];
        targets.iter_mut().for_each(|target| target.is_learnable = false);

        // Forward pass: one prediction per training example.
        let predicted = inputs
            .iter()
            .map(|x| mlp.forward(x.clone()))
            .map(|x| x[0].clone())
            .collect::<Vec<_>>();

        // Squared-error loss summed over all examples.
        let mut loss = predicted
            .iter()
            .zip(targets.iter())
            .map(|(p, t)| {
                let mut diff = p.clone() - t.clone();
                diff.is_learnable = false;

                let mut squared_exponent = Expr::new_leaf(2.0, "2");
                squared_exponent.is_learnable = false;

                let mut mse = diff.clone().pow(squared_exponent, "mse");
                mse.is_learnable = false;

                mse
            })
            .sum::<Expr>();

        // A single gradient step should reduce the loss.
        let first_loss = loss.result.clone();
        loss.learn(1e-04);
        loss.recalculate();
        let second_loss = loss.result.clone();

        assert!(second_loss < first_loss, "Loss should decrease after learning ({} >= {})", second_loss, first_loss);
    }
}