easy_ml/neural_networks.rs
/*!
Neural Network training examples

# XOR

The following code shows a simple network using the sigmoid activation function
to learn the non-linear XOR function. Use of a non-linear activation function
is very important, as without one the network would not be able to remap the
inputs into a new space that can be linearly separated.

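To see why, note that a stack of purely linear layers collapses into a single linear layer,
because matrix multiplication is associative: `(x * w1) * w2 = x * (w1 * w2)`. The following
is a minimal sketch of this with the `Matrix` type, using arbitrary illustrative values:

```
use easy_ml::matrices::Matrix;

let x: Matrix<f32> = Matrix::row(vec![ 1.0, 2.0 ]);
let w1 = Matrix::from(vec![ vec![ 1.0, 2.0 ], vec![ 3.0, 4.0 ] ]);
let w2 = Matrix::from(vec![ vec![ 5.0, 6.0 ], vec![ 7.0, 8.0 ] ]);
// applying two linear layers in sequence
let hidden = &x * &w1;
let two_layers = &hidden * &w2;
// is exactly the same as applying one combined linear layer
let combined = &w1 * &w2;
let one_layer = &x * &combined;
assert_eq!(two_layers.get(0, 0), one_layer.get(0, 0));
assert_eq!(two_layers.get(0, 1), one_layer.get(0, 1));
```
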
Rather than symbolically differentiating the model `y = sigmoid(sigmoid(x * w1) * w2) * w3`,
the [Record](super::differentiation::Record) struct is used to perform reverse
[automatic differentiation](super::differentiation). This adds a slight
memory overhead but also makes it easy to experiment with adding or tweaking layers
of the network or trying different activation functions like ReLU or tanh.

Note that the gradients recorded in each epoch must be cleared before training on
the next one.

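As a minimal sketch of the `Record` API before the full examples below: a single variable
can be tracked on a `WengertList`, its derivative queried, and the list then cleared so it
can be reused for the next computation:

```
use easy_ml::differentiation::{Record, WengertList};

let list = WengertList::new();
let x = Record::variable(0.5, &list);
// y = x^3, recorded as we compute it
let y = x * x * x;
let derivatives = y.derivatives();
// dy/dx = 3x^2 = 0.75 at x = 0.5
assert!((derivatives[&x] - 0.75).abs() < 1e-6);
// wipe the recorded history so the list can be reused for another computation
list.clear();
```
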
## Matrix APIs

```
use easy_ml::matrices::Matrix;
use easy_ml::matrices::views::{MatrixRange, MatrixView, MatrixRef, NoInteriorMutability, IndexRange};
use easy_ml::numeric::Numeric;
use easy_ml::numeric::extra::Real;
use easy_ml::differentiation::{Record, RecordMatrix, WengertList, Index};

use rand::{Rng, SeedableRng};
use rand::distr::StandardUniform;

use textplots::{Chart, Plot, Shape};

/**
 * Utility function to create a list of random numbers.
 */
fn n_random_numbers<R: Rng>(random_generator: &mut R, n: usize) -> Vec<f32> {
    random_generator.sample_iter(StandardUniform).take(n).collect()
}

/**
 * The sigmoid function which will be used as a non-linear activation function.
 *
 * This is written for a generic type, so it can be used with records and also
 * with normal floats.
 */
fn sigmoid<T: Real + Copy>(x: T) -> T {
    // 1 / (1 + e^-x)
    T::one() / (T::one() + (-x).exp())
}

/**
 * A simple three-layer neural network that outputs a scalar.
 */
fn model(
    input: &Matrix<f32>, w1: &Matrix<f32>, w2: &Matrix<f32>, w3: &Matrix<f32>
) -> f32 {
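    // in this example: input (1x3) * w1 (3x3) -> 1x3, * w2 (3x3) -> 1x3, * w3 (3x1) -> 1x1 scalar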
    (((input * w1).map(sigmoid) * w2).map(sigmoid) * w3).scalar()
}

/**
 * A simple three-layer neural network that outputs a scalar, using RecordMatrix types for the
 * inputs to track derivatives.
 */
fn model_training<'a, I>(
    input: &RecordMatrix<'a, f32, I>,
    w1: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    w2: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    w3: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>
) -> Record<'a, f32>
where
    I: MatrixRef<(f32, Index)> + NoInteriorMutability,
{
    (((input * w1).map(sigmoid).unwrap() * w2).map(sigmoid).unwrap() * w3).get_as_record(0, 0)
}

/**
 * Computes mean squared loss of the network against all the training data.
 */
fn mean_squared_loss(
    inputs: &Vec<Matrix<f32>>,
    w1: &Matrix<f32>,
    w2: &Matrix<f32>,
    w3: &Matrix<f32>,
    labels: &Vec<f32>
) -> f32 {
    inputs.iter().enumerate().fold(0.0, |acc, (i, input)| {
        let output = model(input, w1, w2, w3);
        let correct = labels[i];
        // sum up the squared loss
        acc + ((correct - output) * (correct - output))
    }) / inputs.len() as f32
}

/**
 * Computes mean squared loss of the network against all the training data, using RecordMatrix
 * types for the inputs to track derivatives.
 */
fn mean_squared_loss_training<'a>(
    inputs: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    w1: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    w2: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    w3: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    labels: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
) -> Record<'a, f32> {
    let rows = inputs.rows();
    let columns = inputs.columns();
    let history = w1.history();
    (0..rows).map(|r| {
        // take each row as its own RecordMatrix input
        (r, RecordMatrix::from_existing(
            history,
            MatrixView::from(
                MatrixRange::from(
                    inputs,
                    IndexRange::new(r, 1),
                    IndexRange::new(0, columns),
                )
            )
        ))
    }).fold(Record::constant(0.0), |acc, (r, input)| {
        let output = model_training(&input, w1, w2, w3);
        let correct = labels.get_as_record(0, r);
        // sum up the squared loss
        acc + ((correct - output) * (correct - output))
    }) / (rows as f32)
}

/**
 * Updates the weight matrices by one step of gradient descent.
 *
 * Note that here we need the methods defined on Record / RecordMatrix to do backprop. There is
 * no non-training version of this we can define without derivative tracking.
 */
fn step_gradient<'a>(
    inputs: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    w1: &mut RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    w2: &mut RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    w3: &mut RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    labels: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    learning_rate: f32,
    list: &'a WengertList<f32>
) -> f32 {
    let loss = mean_squared_loss_training(inputs, w1, w2, w3, labels);
    let derivatives = loss.derivatives();
    // update each element in the weight matrices by its derivative scaled by the learning rate
    // unwrapping here is fine because we're not changing the history of any variables
    w1.map_mut(|x| x - (derivatives[&x] * learning_rate)).unwrap();
    w2.map_mut(|x| x - (derivatives[&x] * learning_rate)).unwrap();
    w3.map_mut(|x| x - (derivatives[&x] * learning_rate)).unwrap();
    // reset gradients
    list.clear();
    w1.reset();
    w2.reset();
    w3.reset();
    // return the loss
    loss.number
}

// use a fixed seed random generator from the rand_chacha crate
let mut random_generator = rand_chacha::ChaCha8Rng::seed_from_u64(25);

// randomly initialise the weights using the fixed seed generator for reproducibility
let list = WengertList::new();
// w1 will be a 3x3 matrix
let mut w1 = RecordMatrix::variables(
    &list,
    Matrix::from(
        vec![
            n_random_numbers(&mut random_generator, 3),
            n_random_numbers(&mut random_generator, 3),
            n_random_numbers(&mut random_generator, 3)
        ]
    )
);
// w2 will be a 3x3 matrix
let mut w2 = RecordMatrix::variables(
    &list,
    Matrix::from(
        vec![
            n_random_numbers(&mut random_generator, 3),
            n_random_numbers(&mut random_generator, 3),
            n_random_numbers(&mut random_generator, 3)
        ]
    )
);
// w3 will be a 3x1 column matrix
let mut w3 = RecordMatrix::variables(
    &list,
    Matrix::column(n_random_numbers(&mut random_generator, 3))
);
println!("w1 {}", w1);
println!("w2 {}", w2);
println!("w3 {}", w3);

// define XOR inputs, with a bias added to each input
let inputs = RecordMatrix::constants(
    Matrix::from(
        vec![
            vec![ 0.0, 0.0, 1.0 ],
            vec![ 0.0, 1.0, 1.0 ],
            vec![ 1.0, 0.0, 1.0 ],
            vec![ 1.0, 1.0, 1.0 ],
        ]
    )
);
// define XOR outputs which will be used as labels
let labels = RecordMatrix::constants(
    Matrix::row(vec![ 0.0, 1.0, 1.0, 0.0 ])
);
let learning_rate = 0.2;
let epochs = 4000;

// do the gradient descent and save the loss at each epoch
let mut losses = Vec::with_capacity(epochs);
for _ in 0..epochs {
    losses.push(step_gradient(&inputs, &mut w1, &mut w2, &mut w3, &labels, learning_rate, &list));
}

// now plot the training loss
let mut chart = Chart::new(180, 60, 0.0, epochs as f32);
chart.lineplot(
    &Shape::Lines(&losses.iter()
        .cloned()
        .enumerate()
        .map(|(i, x)| (i as f32, x))
        .collect::<Vec<(f32, f32)>>())
).display();

// note that with different hyperparameters, starting weights, or less training
// the network may not have converged and could still be outputting 0.5 for everything.
// The chart plot with this configuration is particularly interesting because the loss
// hovers around 0.3 to 0.2 for a while (while outputting 0.5 for every input) before
// finally learning how to remap the input data in a way which can then be linearly
// separated to achieve ~0.0 loss.

// check that the weights are sensible
println!("w1 {}", w1);
println!("w2 {}", w2);
println!("w3 {}", w3);
// check that the network has learned XOR properly

let row_1 = RecordMatrix::from_existing(
    Some(&list),
    MatrixView::from(MatrixRange::from(&inputs, 0..1, 0..3))
);
let row_2 = RecordMatrix::from_existing(
    Some(&list),
    MatrixView::from(MatrixRange::from(&inputs, 1..2, 0..3))
);
let row_3 = RecordMatrix::from_existing(
    Some(&list),
    MatrixView::from(MatrixRange::from(&inputs, 2..3, 0..3))
);
let row_4 = RecordMatrix::from_existing(
    Some(&list),
    MatrixView::from(MatrixRange::from(&inputs, 3..4, 0..3))
);
println!("0 0: {:?}", model_training(&row_1, &w1, &w2, &w3).number);
println!("0 1: {:?}", model_training(&row_2, &w1, &w2, &w3).number);
println!("1 0: {:?}", model_training(&row_3, &w1, &w2, &w3).number);
println!("1 1: {:?}", model_training(&row_4, &w1, &w2, &w3).number);
assert!(losses[epochs - 1] < 0.02);

// we can also extract the learned weights once done with training and avoid the memory
// overhead of Record
let w1_final = w1.view().map(|(x, _)| x);
let w2_final = w2.view().map(|(x, _)| x);
let w3_final = w3.view().map(|(x, _)| x);
println!("0 0: {:?}", model(&row_1.view().map(|(x, _)| x), &w1_final, &w2_final, &w3_final));
println!("0 1: {:?}", model(&row_2.view().map(|(x, _)| x), &w1_final, &w2_final, &w3_final));
println!("1 0: {:?}", model(&row_3.view().map(|(x, _)| x), &w1_final, &w2_final, &w3_final));
println!("1 1: {:?}", model(&row_4.view().map(|(x, _)| x), &w1_final, &w2_final, &w3_final));
```

## Tensor APIs

```
use easy_ml::tensors::Tensor;
use easy_ml::tensors::views::{TensorView, TensorRef};
use easy_ml::numeric::{Numeric, NumericRef};
use easy_ml::numeric::extra::{Real, RealRef, Exp};
use easy_ml::differentiation::{Record, RecordTensor, WengertList, Index};

use rand::{Rng, SeedableRng};
use rand::distr::StandardUniform;

use textplots::{Chart, Plot, Shape};

/**
 * Utility function to create a list of random numbers.
 */
fn n_random_numbers<R: Rng>(random_generator: &mut R, n: usize) -> Vec<f32> {
    random_generator.sample_iter(StandardUniform).take(n).collect()
}

/**
 * The sigmoid function which will be used as a non-linear activation function.
 *
 * This is written for a generic type, so it can be used with records and also
 * with normal floats.
 */
fn sigmoid<T: Real + Copy>(x: T) -> T {
    // 1 / (1 + e^-x)
    T::one() / (T::one() + (-x).exp())
}

/**
 * A simple three-layer neural network that outputs a scalar.
 */
fn model<I>(
    input: &TensorView<f32, I, 2>,
    w1: &Tensor<f32, 2>,
    w2: &Tensor<f32, 2>,
    w3: &Tensor<f32, 2>,
) -> f32
where
    I: TensorRef<f32, 2>,
{
    (((input * w1).map(sigmoid) * w2).map(sigmoid) * w3).first()
}

/**
 * A simple three-layer neural network that outputs a scalar, using RecordTensor types for
 * the inputs to track derivatives.
 */
fn model_training<'a, I>(
    input: &RecordTensor<'a, f32, I, 2>,
    w1: &RecordTensor<'a, f32, Tensor<(f32, Index), 2>, 2>,
    w2: &RecordTensor<'a, f32, Tensor<(f32, Index), 2>, 2>,
    w3: &RecordTensor<'a, f32, Tensor<(f32, Index), 2>, 2>,
) -> Record<'a, f32>
where
    I: TensorRef<(f32, Index), 2>,
{
    (((input * w1).map(sigmoid).unwrap() * w2).map(sigmoid).unwrap() * w3)
        .index()
        .get_as_record([0, 0])
}

/**
 * Computes mean squared loss of the network against all the training data.
 */
fn mean_squared_loss(
    inputs: &Tensor<f32, 3>,
    w1: &Tensor<f32, 2>,
    w2: &Tensor<f32, 2>,
    w3: &Tensor<f32, 2>,
    labels: &Tensor<f32, 1>,
) -> f32 {
    let inputs_shape = inputs.shape();
    let number_of_samples = inputs_shape[0].1;
    let samples_name = inputs_shape[0].0;
    {
        let mut sum = 0.0;
        for i in 0..number_of_samples {
            let input = inputs.select([(samples_name, i)]);
            let output = model(&input, w1, w2, w3);
            let correct = labels.index().get([i]);
            // sum up the squared loss
            sum = sum + ((correct - output) * (correct - output));
        }
        sum / (number_of_samples as f32)
    }
}

/**
 * Computes mean squared loss of the network against all the training data, using RecordTensor
 * types for the inputs to track derivatives.
 */
fn mean_squared_loss_training<'a>(
    inputs: &RecordTensor<'a, f32, Tensor<(f32, Index), 3>, 3>,
    w1: &RecordTensor<'a, f32, Tensor<(f32, Index), 2>, 2>,
    w2: &RecordTensor<'a, f32, Tensor<(f32, Index), 2>, 2>,
    w3: &RecordTensor<'a, f32, Tensor<(f32, Index), 2>, 2>,
    labels: &RecordTensor<'a, f32, Tensor<(f32, Index), 1>, 1>,
) -> Record<'a, f32> {
    let inputs_shape = inputs.shape();
    let number_of_samples = inputs_shape[0].1;
    let samples_name = inputs_shape[0].0;
    let history = w1.history();
    {
        let mut sum = Record::constant(0.0);
        for i in 0..number_of_samples {
            let input = inputs.view();
            let input = RecordTensor::from_existing(
                history,
                input.select([(samples_name, i)])
            );
            let output = model_training(&input, w1, w2, w3);
            let correct = labels.index().get_as_record([i]);
            // sum up the squared loss
            sum = sum + ((correct - output) * (correct - output));
        }
        sum / (number_of_samples as f32)
    }
}

/**
 * Updates the weight matrices by one step of gradient descent.
 *
 * Note that here we need the methods defined on Record / RecordTensor to do backprop. There is
 * no non-training version of this we can define without derivative tracking.
 */
fn step_gradient<'a>(
    inputs: &RecordTensor<'a, f32, Tensor<(f32, Index), 3>, 3>,
    w1: &mut RecordTensor<'a, f32, Tensor<(f32, Index), 2>, 2>,
    w2: &mut RecordTensor<'a, f32, Tensor<(f32, Index), 2>, 2>,
    w3: &mut RecordTensor<'a, f32, Tensor<(f32, Index), 2>, 2>,
    labels: &RecordTensor<'a, f32, Tensor<(f32, Index), 1>, 1>,
    learning_rate: f32,
    list: &'a WengertList<f32>
) -> f32 {
    let loss = mean_squared_loss_training(inputs, w1, w2, w3, labels);
    let derivatives = loss.derivatives();
    // update each element in the weight matrices by its derivative scaled by the learning rate
    // unwrapping here is fine because we're not changing the history of any variables
    w1.map_mut(|x| x - (derivatives[&x] * learning_rate)).unwrap();
    w2.map_mut(|x| x - (derivatives[&x] * learning_rate)).unwrap();
    w3.map_mut(|x| x - (derivatives[&x] * learning_rate)).unwrap();
    // reset gradients
    list.clear();
    w1.reset();
    w2.reset();
    w3.reset();
    // return the loss
    loss.number
}

// use a fixed seed random generator from the rand_chacha crate
let mut random_generator = rand_chacha::ChaCha8Rng::seed_from_u64(25);

// randomly initialise the weights using the fixed seed generator for reproducibility
let list = WengertList::new();
// w1 will be a 3x3 matrix
let mut w1 = RecordTensor::variables(
    &list,
    Tensor::from([("r", 3), ("c", 3)], n_random_numbers(&mut random_generator, 9))
);
// w2 will be a 3x3 matrix
let mut w2 = RecordTensor::variables(
    &list,
    Tensor::from([("r", 3), ("c", 3)], n_random_numbers(&mut random_generator, 9))
);
// w3 will be a 3x1 column matrix
// Note: we keep the shape here as 3x1 instead of just a length 3 vector to keep matrix
// multiplication simple, like in the Matrix example
let mut w3 = RecordTensor::variables(
    &list,
    Tensor::from([("r", 3), ("c", 1)], n_random_numbers(&mut random_generator, 3))
);
println!("w1 {}", w1);
println!("w2 {}", w2);
println!("w3 {}", w3);

// define XOR inputs, with a bias added to each input
// again, the matrix multiplication is simpler if we stick to row matrices
// instead of actual vectors here
let inputs = RecordTensor::constants(
    Tensor::from([("sample", 4), ("r", 1), ("c", 3)], vec![
        0.0, 0.0, 1.0,

        0.0, 1.0, 1.0,

        1.0, 0.0, 1.0,

        1.0, 1.0, 1.0
    ])
);
// define XOR outputs which will be used as labels
let labels = RecordTensor::constants(
    Tensor::from([("sample", 4)], vec![ 0.0, 1.0, 1.0, 0.0 ])
);
let learning_rate = 0.2;
let epochs = 4000;

// do the gradient descent and save the loss at each epoch
let mut losses = Vec::with_capacity(epochs);
for _ in 0..epochs {
    losses.push(
        step_gradient(&inputs, &mut w1, &mut w2, &mut w3, &labels, learning_rate, &list)
    );
}

// now plot the training loss
let mut chart = Chart::new(180, 60, 0.0, epochs as f32);
chart.lineplot(
    &Shape::Lines(&losses.iter()
        .cloned()
        .enumerate()
        .map(|(i, x)| (i as f32, x))
        .collect::<Vec<(f32, f32)>>())
).display();

// note that with different hyperparameters, starting weights, or less training
// the network may not have converged and could still be outputting 0.5 for everything.
// The chart plot with this configuration is particularly interesting because the loss
// hovers around 0.3 to 0.2 for a while (while outputting 0.5 for every input) before
// finally learning how to remap the input data in a way which can then be linearly
// separated to achieve ~0.0 loss.

// check that the weights are sensible
println!("w1 {}", w1);
println!("w2 {}", w2);
println!("w3 {}", w3);
// check that the network has learned XOR properly

{
    let inputs = inputs.view();
    let row_1 = RecordTensor::from_existing(
        Some(&list),
        inputs.select([("sample", 0)])
    );
    let row_2 = RecordTensor::from_existing(
        Some(&list),
        inputs.select([("sample", 1)])
    );
    let row_3 = RecordTensor::from_existing(
        Some(&list),
        inputs.select([("sample", 2)])
    );
    let row_4 = RecordTensor::from_existing(
        Some(&list),
        inputs.select([("sample", 3)])
    );

    println!("0 0: {:?}", model_training(&row_1, &w1, &w2, &w3).number);
    println!("0 1: {:?}", model_training(&row_2, &w1, &w2, &w3).number);
    println!("1 0: {:?}", model_training(&row_3, &w1, &w2, &w3).number);
    println!("1 1: {:?}", model_training(&row_4, &w1, &w2, &w3).number);
}
assert!(losses[epochs - 1] < 0.02);

// we can also extract the learned weights once done with training and avoid the memory
// overhead of Record
let w1_final = w1.view().map(|(x, _)| x);
let w2_final = w2.view().map(|(x, _)| x);
let w3_final = w3.view().map(|(x, _)| x);
let inputs_final = inputs.view().map(|(x, _)| x);
println!(
    "0 0: {:?}",
    model(&inputs_final.select([("sample", 0)]), &w1_final, &w2_final, &w3_final)
);
println!(
    "0 1: {:?}",
    model(&inputs_final.select([("sample", 1)]), &w1_final, &w2_final, &w3_final)
);
println!(
    "1 0: {:?}",
    model(&inputs_final.select([("sample", 2)]), &w1_final, &w2_final, &w3_final)
);
println!(
    "1 1: {:?}",
    model(&inputs_final.select([("sample", 3)]), &w1_final, &w2_final, &w3_final)
);
```

# Handwritten digit recognition on the MNIST dataset

[Web Assembly example](super::web_assembly#handwritten-digit-recognition-on-the-mnist-dataset)
*/