/*!
Neural Network training examples

# XOR

The following code shows a simple network using the sigmoid activation function
to learn the non-linear XOR function. Using a non-linear activation function is
very important, as without one the network would not be able to remap the inputs
into a new space that can be linearly separated (no single line through the 2D
XOR inputs separates the 0 outputs from the 1 outputs).

Rather than symbolically differentiating the model y = sigmoid(sigmoid(x * w1) * w2) * w3,
the [Record](super::differentiation::Record) struct is used to perform reverse
[automatic differentiation](super::differentiation). This adds a slight memory
overhead but also makes it easy to experiment with adding or tweaking layers of
the network or trying different activation functions such as ReLU or tanh.

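For example, a tanh activation written against the same generic bounds as the
`sigmoid` function below might look like the following sketch (the `tanh` helper
is illustrative only and is not used in the worked example):

```
use easy_ml::numeric::Numeric;
use easy_ml::numeric::extra::{Real, Exp};

// tanh(x) = (e^x - e^-x) / (e^x + e^-x)
// like sigmoid, this is generic so it works with Records as well as normal floats
fn tanh<T: Numeric + Real + Copy>(x: T) -> T {
    let positive = x.exp();
    let negative = (-x).exp();
    (positive - negative) / (positive + negative)
}
```
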
Note that the gradients recorded in each epoch must be cleared before training
in the next one.

Also note that the calls to `model` and `mean_squared_loss` use Rust's turbofish
syntax to tell the compiler explicitly which type we're monomorphising, as
otherwise the compiler cannot infer the type on its own.

```
use easy_ml::matrices::Matrix;
use easy_ml::numeric::{Numeric, NumericRef};
use easy_ml::numeric::extra::{Real, RealRef, Exp};
use easy_ml::differentiation::{Record, WengertList};
use rand::{Rng, SeedableRng};
use rand::distributions::Standard;
use textplots::{Chart, Plot, Shape};
/**
* Utility function to create a list of random numbers.
*/
fn n_random_numbers<R: Rng>(random_generator: &mut R, n: usize) -> Vec<f32> {
    random_generator.sample_iter(Standard).take(n).collect()
}
/**
* The sigmoid function which will be used as a non linear activation function.
*
* This is written for a generic type, so it can be used with records and also
* with normal floats.
*/
fn sigmoid<T: Numeric + Real + Copy>(x: T) -> T {
    // 1 / (1 + e^-x)
    T::one() / (T::one() + (-x).exp())
}
/**
* A simple three layer neural network that outputs a scalar.
*
* This is written for a generic type, so it can be used with records and also
* with normal floats.
*/
fn model<T: Numeric + Real + Copy>(
    input: &Matrix<T>, w1: &Matrix<T>, w2: &Matrix<T>, w3: &Matrix<T>
) -> T
where for<'a> &'a T: NumericRef<T> + RealRef<T> {
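    // shapes: (1x3 input) * (3x3 w1) = 1x3, sigmoid applied elementwise,
    // then * (3x3 w2) = 1x3, sigmoid again, then * (3x1 w3) = 1x1, taken as a scalar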
    (((input * w1).map(sigmoid) * w2).map(sigmoid) * w3).scalar()
}
/**
* Computes mean squared loss of the network against all the training data.
*
* This is written for a generic type, so it can be used with records and also
* with normal floats.
*/
fn mean_squared_loss<T: Numeric + Real + Copy>(
    inputs: &Vec<Matrix<T>>, w1: &Matrix<T>, w2: &Matrix<T>, w3: &Matrix<T>, labels: &Vec<T>
) -> T
where for<'a> &'a T: NumericRef<T> + RealRef<T> {
    inputs.iter().enumerate().fold(T::zero(), |acc, (i, input)| {
        let output = model::<T>(input, w1, w2, w3);
        let correct = labels[i];
        // sum up the squared loss
        acc + ((correct - output) * (correct - output))
    }) / T::from_usize(inputs.len()).unwrap()
}
/**
* Updates the weight matrices to step the gradient by one step.
*
* Note that here we are no longer generic over the type, we need the methods
* defined on Record to do backprop.
*/
fn step_gradient(
    inputs: &Vec<Matrix<Record<f32>>>,
    w1: &mut Matrix<Record<f32>>, w2: &mut Matrix<Record<f32>>, w3: &mut Matrix<Record<f32>>,
    labels: &Vec<Record<f32>>, learning_rate: f32, list: &WengertList<f32>
) -> f32 {
    let loss = mean_squared_loss::<Record<f32>>(inputs, w1, w2, w3, labels);
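    // run the reverse pass, computing the derivative of the loss with respect to
    // every Record variable recorded on the WengertList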
    let derivatives = loss.derivatives();
    // update each element in the weight matrices by the derivatives
    w1.map_mut(|x| x - (derivatives[&x] * learning_rate));
    w2.map_mut(|x| x - (derivatives[&x] * learning_rate));
    w3.map_mut(|x| x - (derivatives[&x] * learning_rate));
    // reset gradients
    list.clear();
    w1.map_mut(Record::do_reset);
    w2.map_mut(Record::do_reset);
    w3.map_mut(Record::do_reset);
    // return the loss
    loss.number
}
// use a fixed seed random generator from the rand crate
let mut random_generator = rand_chacha::ChaCha8Rng::seed_from_u64(25);
// randomly initialise the weights using the fixed seed generator for reproducibility
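// the WengertList is the tape which records every operation performed on Records
// so that the reverse pass can compute derivatives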
let list = WengertList::new();
// w1 will be a 3x3 matrix
let mut w1 = Matrix::from(vec![
    n_random_numbers(&mut random_generator, 3),
    n_random_numbers(&mut random_generator, 3),
    n_random_numbers(&mut random_generator, 3)
]).map(|x| Record::variable(x, &list));
// w2 will also be a 3x3 matrix
let mut w2 = Matrix::from(vec![
    n_random_numbers(&mut random_generator, 3),
    n_random_numbers(&mut random_generator, 3),
    n_random_numbers(&mut random_generator, 3)
]).map(|x| Record::variable(x, &list));
// w3 will be a 3x1 column matrix
let mut w3 = Matrix::column(n_random_numbers(&mut random_generator, 3))
    .map(|x| Record::variable(x, &list));
println!("w1 {}", w1);
println!("w2 {}", w2);
println!("w3 {}", w3);
// define the XOR inputs, with a constant bias of 1.0 appended to each input
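// Record::constant is used because we don't need derivatives with respect to the input data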
let inputs = vec![
    Matrix::row(vec![ 0.0, 0.0, 1.0 ]).map(|x| Record::constant(x)),
    Matrix::row(vec![ 0.0, 1.0, 1.0 ]).map(|x| Record::constant(x)),
    Matrix::row(vec![ 1.0, 0.0, 1.0 ]).map(|x| Record::constant(x)),
    Matrix::row(vec![ 1.0, 1.0, 1.0 ]).map(|x| Record::constant(x))
];
// define XOR outputs which will be used as labels
let labels = vec![ 0.0, 1.0, 1.0, 0.0 ].drain(..).map(|x| Record::constant(x)).collect();
let learning_rate = 0.2;
let epochs = 4000;
// do the gradient descent and save the loss at each epoch
let mut losses = Vec::with_capacity(epochs);
for _ in 0..epochs {
    losses.push(step_gradient(&inputs, &mut w1, &mut w2, &mut w3, &labels, learning_rate, &list))
}
// now plot the training loss
let mut chart = Chart::new(180, 60, 0.0, epochs as f32);
chart.lineplot(
    Shape::Lines(&losses.iter()
        .cloned()
        .enumerate()
        .map(|(i, x)| (i as f32, x))
        .collect::<Vec<(f32, f32)>>())
).display();
// note that with different hyperparameters, different starting weights, or less
// training the network may not have converged and could still be outputting 0.5
// for everything. The plot for this configuration is particularly interesting
// because the loss hovers around 0.3 to 0.2 for a while (while the network
// outputs 0.5 for every input) before it finally learns how to remap the input
// data in a way which can then be linearly separated, reaching ~0.0 loss.
// check that the weights are sensible
println!("w1 {}", w1);
println!("w2 {}", w2);
println!("w3 {}", w3);
// check that the network has learned XOR properly
println!("0 0: {:?}", model::<Record<f32>>(&inputs[0], &w1, &w2, &w3).number);
println!("0 1: {:?}", model::<Record<f32>>(&inputs[1], &w1, &w2, &w3).number);
println!("1 0: {:?}", model::<Record<f32>>(&inputs[2], &w1, &w2, &w3).number);
println!("1 1: {:?}", model::<Record<f32>>(&inputs[3], &w1, &w2, &w3).number);
assert!(losses[epochs - 1] < 0.02);
// once training is done we can also extract the learned weights and avoid the
// memory overhead of Record
let w1_final = w1.map(|x| x.number);
let w2_final = w2.map(|x| x.number);
let w3_final = w3.map(|x| x.number);
println!("0 0: {:?}", model::<f32>(&inputs[0].map(|x| x.number), &w1_final, &w2_final, &w3_final));
println!("0 1: {:?}", model::<f32>(&inputs[1].map(|x| x.number), &w1_final, &w2_final, &w3_final));
println!("1 0: {:?}", model::<f32>(&inputs[2].map(|x| x.number), &w1_final, &w2_final, &w3_final));
println!("1 1: {:?}", model::<f32>(&inputs[3].map(|x| x.number), &w1_final, &w2_final, &w3_final));
```
# [Handwritten digit recognition on the MNIST dataset](super::web_assembly#handwritten-digit-recognition-on-the-mnist-dataset)
*/