/*!
Neural Network training examples

# XOR

The following code shows a simple network using the sigmoid activation function
to learn the non-linear XOR function. Use of a non-linear activation function
is very important, as without one the network would not be able to remap the
inputs into a new space that can be linearly separated.
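Without any non-linear activation the layers would collapse into a single linear map,
since ((x * w1) * w2) * w3 = x * (w1 * w2 * w3), and XOR is not linearly separable.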

Rather than symbolically differentiating the model y = sigmoid(sigmoid(x * w1) * w2) * w3,
the [Record](super::differentiation::Record) struct is used to perform reverse
[automatic differentiation](super::differentiation). This adds a slight
memory overhead but also makes it easy to experiment with adding or tweaking layers
of the network or trying different activation functions like ReLU or tanh.
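
For example, a tanh activation can be written in the same generic style as the sigmoid
used below. This sketch is not part of the worked example, and only assumes the
exponential and arithmetic operations that the `Real` trait already provides:

```
use easy_ml::numeric::extra::Real;

// A sketch of an alternative activation function, written generically like sigmoid
// so it works on plain floats and on Records alike.
fn tanh<T: Real + Copy>(x: T) -> T {
    // tanh(x) = (e^(2x) - 1) / (e^(2x) + 1)
    let e2x = (x + x).exp();
    (e2x - T::one()) / (e2x + T::one())
}

assert!(tanh(0.0_f32).abs() < 1e-6);
```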

Note that the gradients recorded in each epoch must be cleared before training in
the next one.
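
As a minimal sketch of that per-epoch pattern (using the same `clear` and `reset` calls
as the full example below, but on a made-up single weight matrix):

```
use easy_ml::differentiation::{RecordMatrix, WengertList};
use easy_ml::matrices::Matrix;

let list = WengertList::new();
let mut w = RecordMatrix::variables(&list, Matrix::column(vec![ 0.1_f32, 0.2, 0.3 ]));
// ... compute a loss and update w here ...
// then wipe the recorded history and re-register the variables for the next epoch
// (illustrative only; the real examples below reset w1, w2 and w3 after each epoch)
list.clear();
w.reset();
```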

## Matrix APIs

```
use easy_ml::matrices::Matrix;
use easy_ml::matrices::views::{MatrixRange, MatrixView, MatrixRef, NoInteriorMutability, IndexRange};
use easy_ml::numeric::Numeric;
use easy_ml::numeric::extra::Real;
use easy_ml::differentiation::{Record, RecordMatrix, WengertList, Index};

use rand::{Rng, SeedableRng};
use rand::distr::StandardUniform;

use textplots::{Chart, Plot, Shape};

/**
 * Utility function to create a list of random numbers.
 */
fn n_random_numbers<R: Rng>(random_generator: &mut R, n: usize) -> Vec<f32> {
    random_generator.sample_iter(StandardUniform).take(n).collect()
}

/**
 * The sigmoid function which will be used as a non-linear activation function.
 *
 * This is written for a generic type, so it can be used with records and also
 * with normal floats.
 */
fn sigmoid<T: Real + Copy>(x: T) -> T {
    // 1 / (1 + e^-x)
    T::one() / (T::one() + (-x).exp())
}

/**
 * A simple three-layer neural network that outputs a scalar.
 */
fn model(
    input: &Matrix<f32>, w1: &Matrix<f32>, w2: &Matrix<f32>, w3: &Matrix<f32>
) -> f32 {
    (((input * w1).map(sigmoid) * w2).map(sigmoid) * w3).scalar()
}

/**
 * A simple three-layer neural network that outputs a scalar, using RecordMatrix types for the
 * inputs to track derivatives.
 */
fn model_training<'a, I>(
    input: &RecordMatrix<'a, f32, I>,
    w1: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    w2: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    w3: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>
) -> Record<'a, f32>
where
    I: MatrixRef<(f32, Index)> + NoInteriorMutability,
{
    (((input * w1).map(sigmoid).unwrap() * w2).map(sigmoid).unwrap() * w3).get_as_record(0, 0)
}

/**
 * Computes mean squared loss of the network against all the training data.
 */
fn mean_squared_loss(
    inputs: &Vec<Matrix<f32>>,
    w1: &Matrix<f32>,
    w2: &Matrix<f32>,
    w3: &Matrix<f32>,
    labels: &Vec<f32>
) -> f32 {
    inputs.iter().enumerate().fold(0.0, |acc, (i, input)| {
        let output = model(input, w1, w2, w3);
        let correct = labels[i];
        // sum up the squared loss
        acc + ((correct - output) * (correct - output))
    }) / inputs.len() as f32
}

/**
 * Computes mean squared loss of the network against all the training data, using RecordMatrix
 * types for the inputs to track derivatives.
 */
fn mean_squared_loss_training<'a>(
    inputs: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    w1: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    w2: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    w3: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    labels: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
) -> Record<'a, f32> {
    let rows = inputs.rows();
    let columns = inputs.columns();
    let history = w1.history();
    (0..rows).map(|r| {
        // take each row as its own RecordMatrix input
        (r, RecordMatrix::from_existing(
            history,
            MatrixView::from(
                MatrixRange::from(
                    inputs,
                    IndexRange::new(r, 1),
                    IndexRange::new(0, columns),
                )
            )
        ))
    }).fold(Record::constant(0.0), |acc, (r, input)| {
        let output = model_training(&input, w1, w2, w3);
        let correct = labels.get_as_record(0, r);
        // sum up the squared loss
        acc + ((correct - output) * (correct - output))
    }) / (rows as f32)
}

/**
 * Updates the weight matrices by taking one gradient descent step.
 *
 * Note that here we need the methods defined on Record / RecordMatrix to do backprop. There is
 * no non-training version of this we can define without derivative tracking.
 */
fn step_gradient<'a>(
    inputs: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    w1: &mut RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    w2: &mut RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    w3: &mut RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    labels: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    learning_rate: f32,
    list: &'a WengertList<f32>
) -> f32 {
    let loss = mean_squared_loss_training(inputs, w1, w2, w3, labels);
    let derivatives = loss.derivatives();
    // update each element in the weight matrices by the derivatives
    // unwrapping here is fine because we're not changing the history of any variables
    w1.map_mut(|x| x - (derivatives[&x] * learning_rate)).unwrap();
    w2.map_mut(|x| x - (derivatives[&x] * learning_rate)).unwrap();
    w3.map_mut(|x| x - (derivatives[&x] * learning_rate)).unwrap();
    // reset gradients
    list.clear();
    w1.reset();
    w2.reset();
    w3.reset();
    // return the loss
    loss.number
}

// use a fixed seed random generator from the rand crate
let mut random_generator = rand_chacha::ChaCha8Rng::seed_from_u64(25);

// randomly initialise the weights using the fixed seed generator for reproducibility
let list = WengertList::new();
// w1 will be a 3x3 matrix
let mut w1 = RecordMatrix::variables(
    &list,
    Matrix::from(
        vec![
            n_random_numbers(&mut random_generator, 3),
            n_random_numbers(&mut random_generator, 3),
            n_random_numbers(&mut random_generator, 3)
        ]
    )
);
// w2 will be a 3x3 matrix
let mut w2 = RecordMatrix::variables(
    &list,
    Matrix::from(
        vec![
            n_random_numbers(&mut random_generator, 3),
            n_random_numbers(&mut random_generator, 3),
            n_random_numbers(&mut random_generator, 3)
        ]
    )
);
// w3 will be a 3x1 column matrix
let mut w3 = RecordMatrix::variables(
    &list,
    Matrix::column(n_random_numbers(&mut random_generator, 3))
);
println!("w1 {}", w1);
println!("w2 {}", w2);
println!("w3 {}", w3);

// define XOR inputs, with a bias term added to each input
let inputs = RecordMatrix::constants(
    Matrix::from(
        vec![
            vec![ 0.0, 0.0, 1.0 ],
            vec![ 0.0, 1.0, 1.0 ],
            vec![ 1.0, 0.0, 1.0 ],
            vec![ 1.0, 1.0, 1.0 ],
        ]
    )
);
// define XOR outputs which will be used as labels
let labels = RecordMatrix::constants(
    Matrix::row(vec![ 0.0, 1.0, 1.0, 0.0 ])
);
let learning_rate = 0.2;
let epochs = 4000;

// do the gradient descent and save the loss at each epoch
let mut losses = Vec::with_capacity(epochs);
for _ in 0..epochs {
    losses.push(step_gradient(&inputs, &mut w1, &mut w2, &mut w3, &labels, learning_rate, &list));
}

// now plot the training loss
let mut chart = Chart::new(180, 60, 0.0, epochs as f32);
chart.lineplot(
    &Shape::Lines(&losses.iter()
        .cloned()
        .enumerate()
        .map(|(i, x)| (i as f32, x))
        .collect::<Vec<(f32, f32)>>())
    ).display();

// note that with different hyperparameters, different starting weights, or less training,
// the network may not have converged and could still be outputting 0.5 for everything.
// The loss curve for this configuration is particularly interesting because the loss
// hovers around 0.3 to 0.2 for a while (while outputting 0.5 for every input) before
// the network finally learns how to remap the input data in a way which can then be
// linearly separated to achieve ~0.0 loss.

// check that the weights are sensible
println!("w1 {}", w1);
println!("w2 {}", w2);
println!("w3 {}", w3);
// check that the network has learned XOR properly

let row_1 = RecordMatrix::from_existing(
    Some(&list),
    MatrixView::from(MatrixRange::from(&inputs, 0..1, 0..3))
);
let row_2 = RecordMatrix::from_existing(
    Some(&list),
    MatrixView::from(MatrixRange::from(&inputs, 1..2, 0..3))
);
let row_3 = RecordMatrix::from_existing(
    Some(&list),
    MatrixView::from(MatrixRange::from(&inputs, 2..3, 0..3))
);
let row_4 = RecordMatrix::from_existing(
    Some(&list),
    MatrixView::from(MatrixRange::from(&inputs, 3..4, 0..3))
);
println!("0 0: {:?}", model_training(&row_1, &w1, &w2, &w3).number);
println!("0 1: {:?}", model_training(&row_2, &w1, &w2, &w3).number);
println!("1 0: {:?}", model_training(&row_3, &w1, &w2, &w3).number);
println!("1 1: {:?}", model_training(&row_4, &w1, &w2, &w3).number);
assert!(losses[epochs - 1] < 0.02);

// we can also extract the learned weights once done with training and avoid the memory
// overhead of Record
let w1_final = w1.view().map(|(x, _)| x);
let w2_final = w2.view().map(|(x, _)| x);
let w3_final = w3.view().map(|(x, _)| x);
println!("0 0: {:?}", model(&row_1.view().map(|(x, _)| x), &w1_final, &w2_final, &w3_final));
println!("0 1: {:?}", model(&row_2.view().map(|(x, _)| x), &w1_final, &w2_final, &w3_final));
println!("1 0: {:?}", model(&row_3.view().map(|(x, _)| x), &w1_final, &w2_final, &w3_final));
println!("1 1: {:?}", model(&row_4.view().map(|(x, _)| x), &w1_final, &w2_final, &w3_final));
```

## Tensor APIs

```
use easy_ml::tensors::Tensor;
use easy_ml::tensors::views::{TensorView, TensorRef};
use easy_ml::numeric::{Numeric, NumericRef};
use easy_ml::numeric::extra::{Real, RealRef, Exp};
use easy_ml::differentiation::{Record, RecordTensor, WengertList, Index};

use rand::{Rng, SeedableRng};
use rand::distr::StandardUniform;

use textplots::{Chart, Plot, Shape};

/**
 * Utility function to create a list of random numbers.
 */
fn n_random_numbers<R: Rng>(random_generator: &mut R, n: usize) -> Vec<f32> {
    random_generator.sample_iter(StandardUniform).take(n).collect()
}

/**
 * The sigmoid function which will be used as a non-linear activation function.
 *
 * This is written for a generic type, so it can be used with records and also
 * with normal floats.
 */
fn sigmoid<T: Real + Copy>(x: T) -> T {
    // 1 / (1 + e^-x)
    T::one() / (T::one() + (-x).exp())
}

/**
 * A simple three-layer neural network that outputs a scalar.
 */
fn model<I>(
    input: &TensorView<f32, I, 2>,
    w1: &Tensor<f32, 2>,
    w2: &Tensor<f32, 2>,
    w3: &Tensor<f32, 2>,
) -> f32
where
    I: TensorRef<f32, 2>,
{
    (((input * w1).map(sigmoid) * w2).map(sigmoid) * w3).first()
}

/**
 * A simple three-layer neural network that outputs a scalar, using RecordTensor types for
 * the inputs to track derivatives.
 */
fn model_training<'a, I>(
    input: &RecordTensor<'a, f32, I, 2>,
    w1: &RecordTensor<'a, f32, Tensor<(f32, Index), 2>, 2>,
    w2: &RecordTensor<'a, f32, Tensor<(f32, Index), 2>, 2>,
    w3: &RecordTensor<'a, f32, Tensor<(f32, Index), 2>, 2>,
) -> Record<'a, f32>
where
    I: TensorRef<(f32, Index), 2>,
{
    (((input * w1).map(sigmoid).unwrap() * w2).map(sigmoid).unwrap() * w3)
        .index()
        .get_as_record([0, 0])
}

/**
 * Computes mean squared loss of the network against all the training data.
 */
fn mean_squared_loss(
    inputs: &Tensor<f32, 3>,
    w1: &Tensor<f32, 2>,
    w2: &Tensor<f32, 2>,
    w3: &Tensor<f32, 2>,
    labels: &Tensor<f32, 1>,
) -> f32 {
    let inputs_shape = inputs.shape();
    let number_of_samples = inputs_shape[0].1;
    let samples_name = inputs_shape[0].0;
    {
        let mut sum = 0.0;
        for i in 0..number_of_samples {
            let input = inputs.select([(samples_name, i)]);
            let output = model(&input, w1, w2, w3);
            let correct = labels.index().get([i]);
            // sum up the squared loss
            sum = sum + ((correct - output) * (correct - output));
        }
        sum / (number_of_samples as f32)
    }
}

/**
 * Computes mean squared loss of the network against all the training data, using RecordTensor
 * types for the inputs to track derivatives.
 */
fn mean_squared_loss_training<'a>(
    inputs: &RecordTensor<'a, f32, Tensor<(f32, Index), 3>, 3>,
    w1: &RecordTensor<'a, f32, Tensor<(f32, Index), 2>, 2>,
    w2: &RecordTensor<'a, f32, Tensor<(f32, Index), 2>, 2>,
    w3: &RecordTensor<'a, f32, Tensor<(f32, Index), 2>, 2>,
    labels: &RecordTensor<'a, f32, Tensor<(f32, Index), 1>, 1>,
) -> Record<'a, f32> {
    let inputs_shape = inputs.shape();
    let number_of_samples = inputs_shape[0].1;
    let samples_name = inputs_shape[0].0;
    let history = w1.history();
    {
        let mut sum = Record::constant(0.0);
        for i in 0..number_of_samples {
            let input = inputs.view();
            let input = RecordTensor::from_existing(
                history,
                input.select([(samples_name, i)])
            );
            let output = model_training(&input, w1, w2, w3);
            let correct = labels.index().get_as_record([i]);
            // sum up the squared loss
            sum = sum + ((correct - output) * (correct - output));
        }
        sum / (number_of_samples as f32)
    }
}

/**
 * Updates the weight matrices by taking one gradient descent step.
 *
 * Note that here we need the methods defined on Record / RecordTensor to do backprop. There is
 * no non-training version of this we can define without derivative tracking.
 */
fn step_gradient<'a>(
    inputs: &RecordTensor<'a, f32, Tensor<(f32, Index), 3>, 3>,
    w1: &mut RecordTensor<'a, f32, Tensor<(f32, Index), 2>, 2>,
    w2: &mut RecordTensor<'a, f32, Tensor<(f32, Index), 2>, 2>,
    w3: &mut RecordTensor<'a, f32, Tensor<(f32, Index), 2>, 2>,
    labels: &RecordTensor<'a, f32, Tensor<(f32, Index), 1>, 1>,
    learning_rate: f32,
    list: &'a WengertList<f32>
) -> f32 {
    let loss = mean_squared_loss_training(inputs, w1, w2, w3, labels);
    let derivatives = loss.derivatives();
    // update each element in the weight matrices by the derivatives
    // unwrapping here is fine because we're not changing the history of any variables
    w1.map_mut(|x| x - (derivatives[&x] * learning_rate)).unwrap();
    w2.map_mut(|x| x - (derivatives[&x] * learning_rate)).unwrap();
    w3.map_mut(|x| x - (derivatives[&x] * learning_rate)).unwrap();
    // reset gradients
    list.clear();
    w1.reset();
    w2.reset();
    w3.reset();
    // return the loss
    loss.number
}

// use a fixed seed random generator from the rand crate
let mut random_generator = rand_chacha::ChaCha8Rng::seed_from_u64(25);

// randomly initialise the weights using the fixed seed generator for reproducibility
let list = WengertList::new();
// w1 will be a 3x3 matrix
let mut w1 = RecordTensor::variables(
    &list,
    Tensor::from([("r", 3), ("c", 3)], n_random_numbers(&mut random_generator, 9))
);
// w2 will be a 3x3 matrix
let mut w2 = RecordTensor::variables(
    &list,
    Tensor::from([("r", 3), ("c", 3)], n_random_numbers(&mut random_generator, 9))
);
// w3 will be a 3x1 column matrix
// Note: we keep the shape here as 3x1 instead of just a length 3 vector to keep matrix
// multiplication simple like the Matrix example
let mut w3 = RecordTensor::variables(
    &list,
    Tensor::from([("r", 3), ("c", 1)], n_random_numbers(&mut random_generator, 3))
);
println!("w1 {}", w1);
println!("w2 {}", w2);
println!("w3 {}", w3);

// define XOR inputs, with a bias term added to each input
// again, the matrix multiplication is simpler if we stick to row matrices
// instead of actual vectors here
let inputs = RecordTensor::constants(
    Tensor::from([("sample", 4), ("r", 1), ("c", 3)], vec![
        0.0, 0.0, 1.0,

        0.0, 1.0, 1.0,

        1.0, 0.0, 1.0,

        1.0, 1.0, 1.0
    ])
);
// define XOR outputs which will be used as labels
let labels = RecordTensor::constants(
    Tensor::from([("sample", 4)], vec![ 0.0, 1.0, 1.0, 0.0 ])
);
let learning_rate = 0.2;
let epochs = 4000;

// do the gradient descent and save the loss at each epoch
let mut losses = Vec::with_capacity(epochs);
for _ in 0..epochs {
    losses.push(
        step_gradient(&inputs, &mut w1, &mut w2, &mut w3, &labels, learning_rate, &list)
    );
}

// now plot the training loss
let mut chart = Chart::new(180, 60, 0.0, epochs as f32);
chart.lineplot(
    &Shape::Lines(&losses.iter()
        .cloned()
        .enumerate()
        .map(|(i, x)| (i as f32, x))
        .collect::<Vec<(f32, f32)>>())
    ).display();

// note that with different hyperparameters, different starting weights, or less training,
// the network may not have converged and could still be outputting 0.5 for everything.
// The loss curve for this configuration is particularly interesting because the loss
// hovers around 0.3 to 0.2 for a while (while outputting 0.5 for every input) before
// the network finally learns how to remap the input data in a way which can then be
// linearly separated to achieve ~0.0 loss.

// check that the weights are sensible
println!("w1 {}", w1);
println!("w2 {}", w2);
println!("w3 {}", w3);
// check that the network has learned XOR properly

{
    let inputs = inputs.view();
    let row_1 = RecordTensor::from_existing(
        Some(&list),
        inputs.select([("sample", 0)])
    );
    let row_2 = RecordTensor::from_existing(
        Some(&list),
        inputs.select([("sample", 1)])
    );
    let row_3 = RecordTensor::from_existing(
        Some(&list),
        inputs.select([("sample", 2)])
    );
    let row_4 = RecordTensor::from_existing(
        Some(&list),
        inputs.select([("sample", 3)])
    );

    println!("0 0: {:?}", model_training(&row_1, &w1, &w2, &w3).number);
    println!("0 1: {:?}", model_training(&row_2, &w1, &w2, &w3).number);
    println!("1 0: {:?}", model_training(&row_3, &w1, &w2, &w3).number);
    println!("1 1: {:?}", model_training(&row_4, &w1, &w2, &w3).number);
}
assert!(losses[epochs - 1] < 0.02);

// we can also extract the learned weights once done with training and avoid the memory
// overhead of Record
let w1_final = w1.view().map(|(x, _)| x);
let w2_final = w2.view().map(|(x, _)| x);
let w3_final = w3.view().map(|(x, _)| x);
let inputs_final = inputs.view().map(|(x, _)| x);
println!(
    "0 0: {:?}",
    model(&inputs_final.select([("sample", 0)]), &w1_final, &w2_final, &w3_final)
);
println!(
    "0 1: {:?}",
    model(&inputs_final.select([("sample", 1)]), &w1_final, &w2_final, &w3_final)
);
println!(
    "1 0: {:?}",
    model(&inputs_final.select([("sample", 2)]), &w1_final, &w2_final, &w3_final)
);
println!(
    "1 1: {:?}",
    model(&inputs_final.select([("sample", 3)]), &w1_final, &w2_final, &w3_final)
);
```

# Handwritten digit recognition on the MNIST dataset

[Web Assembly example](super::web_assembly#handwritten-digit-recognition-on-the-mnist-dataset)
 */