/*!
Neural Network training examples

# XOR

The following code shows a simple network using the sigmoid activation function
to learn the non-linear XOR function. Use of a non-linear activation function
is very important: without one the layers compose into a single linear
transformation, and XOR is not linearly separable, so the network would be unable
to remap the inputs into a new space that can be linearly separated.
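
A quick sketch of why the non-linearity matters (with made up numbers, using the
same `Matrix` type as the full example below): without an activation function in
between, two layers compose into a single linear map, which could never separate
XOR's classes.

```
use easy_ml::matrices::Matrix;

let x = Matrix::row(vec![ 1.0, -1.0 ]);
let w1 = Matrix::from(vec![
    vec![ 1.0, 2.0 ],
    vec![ 3.0, 4.0 ]
]);
let w2 = Matrix::column(vec![ 0.5, 1.5 ]);
// applying the layers one at a time gives exactly the same output as applying
// the single matrix w1 * w2, so with no activation function the network is
// only ever one linear transformation of its input
let layered = ((&x * &w1) * &w2).scalar();
let collapsed = (&x * &(&w1 * &w2)).scalar();
assert_eq!(layered, collapsed);
```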

Rather than symbolically differentiating the model `y = sigmoid(sigmoid(x * w1) * w2) * w3`,
the [Record](super::differentiation::Record) struct is used to perform reverse
[automatic differentiation](super::differentiation). This adds a slight
memory overhead but also makes it easy to experiment with adding or tweaking layers
of the network or trying different activation functions such as ReLU or tanh.

Note that the gradients recorded in each epoch must be cleared before training in
the next one.
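
As a minimal illustration of how [Record](super::differentiation::Record) and
[WengertList](super::differentiation::WengertList) behave (using the same API as
the full example below, on a made up scalar function):

```
use easy_ml::differentiation::{Record, WengertList};

// y = x^2 + x, evaluated at x = 2
let list = WengertList::new();
let x = Record::variable(2.0, &list);
let y = (x * x) + x;
let derivatives = y.derivatives();
// dy/dx = 2x + 1, which is 5 at x = 2
assert_eq!(derivatives[&x], 5.0);
// once the derivatives have been computed, the operations recorded on the
// list can be cleared so it can be reused
list.clear();
```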

Also note that the calls to `model` and `mean_squared_loss` use Rust's turbofish syntax
to tell the compiler explicitly which type we're monomorphising with, as otherwise the
compiler fails to infer the type.
```
use easy_ml::matrices::Matrix;
use easy_ml::numeric::{Numeric, NumericRef};
use easy_ml::numeric::extra::{Real, RealRef, Exp};
use easy_ml::differentiation::{Record, WengertList};

use rand::{Rng, SeedableRng};
use rand::distributions::Standard;

use textplots::{Chart, Plot, Shape};

/**
 * Utility function to create a list of random numbers.
 */
fn n_random_numbers<R: Rng>(random_generator: &mut R, n: usize) -> Vec<f32> {
    random_generator.sample_iter(Standard).take(n).collect()
}

/**
 * The sigmoid function which will be used as a non-linear activation function.
 *
 * This is written for a generic type, so it can be used with records and also
 * with normal floats.
 */
fn sigmoid<T: Numeric + Real + Copy>(x: T) -> T {
    // 1 / (1 + e^-x)
    T::one() / (T::one() + (-x).exp())
}

/**
 * A simple three layer neural network that outputs a scalar.
 *
 * This is written for a generic type, so it can be used with records and also
 * with normal floats.
 */
fn model<T: Numeric + Real + Copy>(
    input: &Matrix<T>, w1: &Matrix<T>, w2: &Matrix<T>, w3: &Matrix<T>
) -> T
where for<'a> &'a T: NumericRef<T> + RealRef<T> {
    (((input * w1).map(sigmoid) * w2).map(sigmoid) * w3).scalar()
}

/**
 * Computes mean squared loss of the network against all the training data.
 *
 * This is written for a generic type, so it can be used with records and also
 * with normal floats.
 */
fn mean_squared_loss<T: Numeric + Real + Copy>(
   inputs: &Vec<Matrix<T>>, w1: &Matrix<T>, w2: &Matrix<T>, w3: &Matrix<T>, labels: &Vec<T>
) -> T
where for<'a> &'a T: NumericRef<T> + RealRef<T> {
    inputs.iter().enumerate().fold(T::zero(), |acc, (i, input)| {
        let output = model::<T>(input, w1, w2, w3);
        let correct = labels[i];
        // sum up the squared loss
        acc + ((correct - output) * (correct - output))
    }) / T::from_usize(inputs.len()).unwrap()
}

/**
 * Updates the weight matrices by taking one step of gradient descent.
 *
 * Note that here we are no longer generic over the type, as we need the methods
 * defined on Record to do the backpropagation.
 */
fn step_gradient(
    inputs: &Vec<Matrix<Record<f32>>>,
    w1: &mut Matrix<Record<f32>>, w2: &mut Matrix<Record<f32>>, w3: &mut Matrix<Record<f32>>,
    labels: &Vec<Record<f32>>, learning_rate: f32, list: &WengertList<f32>
) -> f32 {
    let loss = mean_squared_loss::<Record<f32>>(inputs, w1, w2, w3, labels);
    let derivatives = loss.derivatives();
    // update each element in the weight matrices by the derivatives
    w1.map_mut(|x| x - (derivatives[&x] * learning_rate));
    w2.map_mut(|x| x - (derivatives[&x] * learning_rate));
    w3.map_mut(|x| x - (derivatives[&x] * learning_rate));
    // reset gradients
    list.clear();
    w1.map_mut(Record::do_reset);
    w2.map_mut(Record::do_reset);
    w3.map_mut(Record::do_reset);
    // return the loss
    loss.number
}

// use a fixed seed random generator from the rand_chacha crate
let mut random_generator = rand_chacha::ChaCha8Rng::seed_from_u64(25);

// randomly initialise the weights using the fixed seed generator for reproducibility
let list = WengertList::new();
// w1 will be a 3x3 matrix
let mut w1 = Matrix::from(vec![
    n_random_numbers(&mut random_generator, 3),
    n_random_numbers(&mut random_generator, 3),
    n_random_numbers(&mut random_generator, 3)
]).map(|x| Record::variable(x, &list));
// w2 will also be a 3x3 matrix
let mut w2 = Matrix::from(vec![
    n_random_numbers(&mut random_generator, 3),
    n_random_numbers(&mut random_generator, 3),
    n_random_numbers(&mut random_generator, 3)
]).map(|x| Record::variable(x, &list));
// w3 will be a 3x1 column matrix
let mut w3 = Matrix::column(n_random_numbers(&mut random_generator, 3))
.map(|x| Record::variable(x, &list));
println!("w1 {}", w1);
println!("w2 {}", w2);
println!("w3 {}", w3);

// define the XOR inputs, with a constant 1.0 appended to each input to act as a bias term
let inputs = vec![
    Matrix::row(vec![ 0.0, 0.0, 1.0 ]).map(|x| Record::constant(x)),
    Matrix::row(vec![ 0.0, 1.0, 1.0 ]).map(|x| Record::constant(x)),
    Matrix::row(vec![ 1.0, 0.0, 1.0 ]).map(|x| Record::constant(x)),
    Matrix::row(vec![ 1.0, 1.0, 1.0 ]).map(|x| Record::constant(x))
];
// define XOR outputs which will be used as labels
let labels = vec![ 0.0, 1.0, 1.0, 0.0 ].into_iter().map(|x| Record::constant(x)).collect();
let learning_rate = 0.2;
let epochs = 4000;

// do the gradient descent and save the loss at each epoch
let mut losses = Vec::with_capacity(epochs);
for _ in 0..epochs {
    losses.push(step_gradient(&inputs, &mut w1, &mut w2, &mut w3, &labels, learning_rate, &list));
}

// now plot the training loss
let mut chart = Chart::new(180, 60, 0.0, epochs as f32);
chart.lineplot(
    Shape::Lines(&losses.iter()
        .cloned()
        .enumerate()
        .map(|(i, x)| (i as f32, x))
        .collect::<Vec<(f32, f32)>>())
    ).display();

// note that with different hyperparameters, starting weights, or fewer epochs of training
// the network may not have converged and could still be outputting 0.5 for everything;
// the loss curve with this configuration is particularly interesting because the loss
// hovers around 0.3 to 0.2 for a while (while outputting 0.5 for every input) before
// finally learning how to remap the input data in a way which can then be linearly
// separated to achieve ~0.0 loss.

// check that the weights are sensible
println!("w1 {}", w1);
println!("w2 {}", w2);
println!("w3 {}", w3);
// check that the network has learned XOR properly
println!("0 0: {:?}", model::<Record<f32>>(&inputs[0], &w1, &w2, &w3).number);
println!("0 1: {:?}", model::<Record<f32>>(&inputs[1], &w1, &w2, &w3).number);
println!("1 0: {:?}", model::<Record<f32>>(&inputs[2], &w1, &w2, &w3).number);
println!("1 1: {:?}", model::<Record<f32>>(&inputs[3], &w1, &w2, &w3).number);
assert!(losses[epochs - 1] < 0.02);

// we can also extract the learned weights once done with training and avoid the memory
// overhead of Record
let w1_final = w1.map(|x| x.number);
let w2_final = w2.map(|x| x.number);
let w3_final = w3.map(|x| x.number);
println!("0 0: {:?}", model::<f32>(&inputs[0].map(|x| x.number), &w1_final, &w2_final, &w3_final));
println!("0 1: {:?}", model::<f32>(&inputs[1].map(|x| x.number), &w1_final, &w2_final, &w3_final));
println!("1 0: {:?}", model::<f32>(&inputs[2].map(|x| x.number), &w1_final, &w2_final, &w3_final));
println!("1 1: {:?}", model::<f32>(&inputs[3].map(|x| x.number), &w1_final, &w2_final, &w3_final));
```

# [Handwritten digit recognition on the MNIST dataset](super::web_assembly#handwritten-digit-recognition-on-the-mnist-dataset)
 */