//! MLI intends to provide modern, data-driven abstractions for machine learning.
//!
//! MLI provides traits that work much like combine's `Parser` in that one should only
//! need to create tensor processing primitives and then string them together to craft
//! a large system that continues to be useful as a primitive in an even larger system.
//!
//! To understand MLI, we actually just need to go over some basic Rust concepts.
//!
//! In Rust, the following relationships relate the owned versions of data with the borrowed versions:
//! - `T` -> `&mut T` -> `&T`
//! - `Vec<T>` -> `&mut [T]` -> `&[T]`
//! - `String` -> `&mut str` -> `&str`
//! - `FnOnce` -> `FnMut` -> `Fn`
//!
//! Most people don't think about the last one, because it is a trait, but it is critical to understand.
//! A `FnOnce` owns its enclosed environment. It can implement `FnMut` because, since it owns the environment,
//! it can create a `&mut` to it. Similarly, a `FnMut` has a mutable borrow of its environment, thus it can
//! downgrade that mutable borrow to an immutable borrow.
//!
//! In MLI, the equivalent is:
//! - `Graph` -> `&mut Graph` -> `&Graph`
//!
//! Let us decompose the notion of what a compute graph is. A compute graph simply takes several inputs,
//! produces several intermediary outputs, and then produces the actual outputs. These
//! intermediary outputs are important because they affect the gradients of the functions that come after them.
//! If we discard them, then they will have to be recomputed from the input. If we only have a simple
//! convolution operation as our whole graph, there are no intermediary computations, just the
//! input and the output. However, we are more likely to have several layers of convolution, activation
//! functions, splitting, and merging. If we have just two layers of convolution with no activation function,
//! the first layer produces an output that is needed to calculate the second layer's gradient `𝛿output/𝛿filter`.
//!
//! We would like to have an abstraction that treats a whole graph, its inputs, intermediary computations, and
//! its outputs, the same way we would treat a single simple static function like a `tanh`. We would also like
//! that abstraction to be zero-cost and ideally parallel.
//!
//! This is where the magic happens. The traits [`Forward`] and [`Backward`] only take `&self` as an
//! argument. This is because they use the trainable variables immutably to do forward and backward
//! propagation. The deltas to the trainable variables within the graph are represented by the [`Backward::TrainDelta`] type.
//! This means that if we have a graph composed of [`Forward`] and [`Backward`] items, we can compute
//! all of the deltas for a batch without ever mutating the graph, thus we can compute the deltas
//! for every sample in the batch **in parallel**. Since the learning rate is intended to be incorporated into the
//! delta beforehand, you can just sum all of the deltas in the batch and train the network.
//!
//! Due to Rust's ownership and borrowing rules, however, the network that gets trained can't be the
//! same value we hand out to all of the threads. This is where the references come in.
//! The `Graph` is trainable since we can mutate it. A `&Graph` immutably borrows
//! the `Graph`, and through that borrow we can run [`Forward`] and [`Backward`] propagation across
//! multiple threads. We can then sum all of the deltas from the threads when they are done and
//! update the `Graph`, which is no longer borrowed.
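//!
//! As a rough sketch of that workflow (nothing here is taken from this crate: `model`, `samples`,
//! and a `TrainDelta` type that can be summed with `+` are all hypothetical):
//!
//! ```ignore
//! // Workers only need `&model`, which still implements `Forward` and `Backward`.
//! let shared = &model;
//! let batch_delta = samples
//!     .iter() // this could just as well be a parallel iterator (e.g. rayon's `par_iter`)
//!     .map(|(input, output_delta)| {
//!         let (internal, _output) = shared.forward(input);
//!         shared.backward_train(input, &internal, output_delta)
//!     })
//!     .reduce(|a, b| a + b) // sum the per-sample deltas
//!     .expect("the batch must not be empty");
//! // The immutable borrows have ended, so the model can now be mutated.
//! model.train(&batch_delta);
//! ```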
//!
//! Currently there are three main traits:
//! - [`Forward`]
//!     - Implemented on anything that can go into a forward graph.
//!     - Outputs the intermediary computations (used to compute gradients later) and the final output.
//! - [`Backward`]
//!     - Implemented on anything that can go into a backward graph.
//!     - Propagates gradients backwards through the neural network.
//!     - This is provided the original input and any intermediary computations.
//!     - Propagates the change from the output to the input and trainable variables.
//!     - Takes `&self` and does not do the actual training, so this can be run in parallel.
//! - [`Train`]
//!     - Implemented on anything that can go into a training graph.
//!     - Uses the change to update the trainable variables.
//!         - Change can be normalized across a mini-batch before being passed.
//!     - Implemented on the mutable version of the graph.
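//!
//! Putting these three traits together, here is a minimal sketch of a single trainable operation.
//! The `Scale` op (computing `f(x) = w * x`) is made up for illustration and is not part of this crate:
//!
//! ```ignore
//! use mli::{Backward, Forward, Train};
//!
//! struct Scale {
//!     weight: f32,
//! }
//!
//! impl Forward for Scale {
//!     type Input = f32;
//!     type Internal = ();
//!     type Output = f32;
//!
//!     fn forward(&self, &x: &f32) -> ((), f32) {
//!         // No intermediary values are needed for this op, so `Internal` is `()`.
//!         ((), self.weight * x)
//!     }
//! }
//!
//! impl Backward for Scale {
//!     type OutputDelta = f32;
//!     type InputDelta = f32;
//!     type TrainDelta = f32;
//!
//!     fn backward(&self, &x: &f32, _internal: &(), &output_delta: &f32) -> (f32, f32) {
//!         // Δx = Δf * 𝛿f/𝛿x = Δf * w and Δv = Δf * 𝛿f/𝛿v = Δf * x.
//!         (output_delta * self.weight, output_delta * x)
//!     }
//! }
//!
//! impl Train for Scale {
//!     fn train(&mut self, &train_delta: &f32) {
//!         // Effectively `v += Δv`; the learning rate is already folded into the delta.
//!         self.weight += train_delta;
//!     }
//! }
//!
//! let mut op = Scale { weight: 2.0 };
//! let (internal, output) = op.forward(&3.0);
//! assert_eq!(output, 6.0);
//! // Suppose the (learning-rate-scaled) loss delta at the output is 0.5.
//! let input_delta = op.propogate(&3.0, &internal, &0.5);
//! assert_eq!(input_delta, 1.0); // 0.5 * w
//! assert_eq!(op.weight, 3.5);   // 2.0 + 0.5 * x
//! ```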

#![no_std]

mod chain;
pub use chain::*;

/// This trait is for algorithms that have an input and produce an output.
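///
/// A minimal sketch of an implementor (the `Square` op is made up for illustration and is not
/// part of this crate):
///
/// ```ignore
/// use mli::Forward;
///
/// struct Square;
///
/// impl Forward for Square {
///     type Input = f32;
///     type Internal = ();
///     type Output = f32;
///
///     fn forward(&self, &x: &f32) -> ((), f32) {
///         // This op keeps no intermediary state, so `Internal` is `()`.
///         ((), x * x)
///     }
/// }
///
/// assert_eq!(Square.run(&3.0), 9.0);
/// ```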
pub trait Forward {
    type Input;
    type Internal;
    type Output;

    /// `forward` produces:
    /// - All intermediary values produced, which can be reused in the backpropagation step
    /// - The output `f` given an `input`
    fn forward(&self, input: &Self::Input) -> (Self::Internal, Self::Output);

    /// `run` only produces the output from the input. The default implementation uses `forward`.
    /// You can make a custom implementation of this to avoid the overhead of producing and returning
    /// the internal variables.
    fn run(&self, input: &Self::Input) -> Self::Output {
        self.forward(input).1
    }
}

/// This trait indicates support of backwards propagation.
///
/// This trait also contains the pieces needed for training ([`Backward::TrainDelta`] and
/// [`Backward::backward_train`]) if training is possible. If training is not possible, this trait
/// can still be implemented with those definitions being empty. In that case, machine learning
/// algorithms will still be able to backpropagate over this operation, but training it will be a no-op.
pub trait Backward: Forward {
    type OutputDelta;
    type InputDelta;
    type TrainDelta;

    /// `backward` produces the change required in the input and trainable variables.
    ///
    /// Key:
    /// - `f` is the output
    /// - `x` is the input
    /// - `v` is the trainable variables
    /// - `E` is the loss or error
    /// - `delta` is equivalent to `Δf` or `-η * 𝛿E/𝛿f`
    ///
    /// This method should produce `(Δx, Δv)`:
    /// - `Δx = Δf * 𝛿f/𝛿x`
    /// - `Δv = Δf * 𝛿f/𝛿v`
    ///
    /// `𝛿f/𝛿x` and `𝛿f/𝛿v` can be an approximation, particularly if the function is not differentiable
    /// (e.g. ReLU and signum).
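    ///
    /// As a worked example, for a hypothetical op computing `f(x) = w * x` with `w = 2`, an
    /// input `x = 3`, and an output delta `Δf = 0.5`, this method should produce
    /// `Δx = 0.5 * 2 = 1` and `Δv = 0.5 * 3 = 1.5`.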
    fn backward(
        &self,
        input: &Self::Input,
        internal: &Self::Internal,
        output_delta: &Self::OutputDelta,
    ) -> (Self::InputDelta, Self::TrainDelta);

    /// See [`Backward::backward`] for documentation.
    fn backward_input(
        &self,
        input: &Self::Input,
        internal: &Self::Internal,
        output_delta: &Self::OutputDelta,
    ) -> Self::InputDelta {
        self.backward(input, internal, output_delta).0
    }

    /// See [`Backward::backward`] for documentation.
    fn backward_train(
        &self,
        input: &Self::Input,
        internal: &Self::Internal,
        output_delta: &Self::OutputDelta,
    ) -> Self::TrainDelta {
        self.backward(input, internal, output_delta).1
    }
}

/// This trait is implemented on all operations that can be included in a trainable model.
///
/// This trait is implemented on all structures that implement [`Backward`] and are mutable,
/// or on `Static` structures even if they are immutable. This exists so that an immutable model can have
/// backpropagation run against it in several different threads without the need for locking,
/// and then the mutable version of the model can implement `Train` to receive the gradients
/// and be updated. If training is not possible, this trait can still be implemented with those
/// definitions being empty. In that case, machine learning algorithms will still be able to
/// incorporate this operation, but training it will be a no-op. Implementing this trait is
/// necessary for an operation to be included in a trainable model.
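///
/// A minimal sketch of an implementation, reusing the hypothetical `Scale` op (`f(x) = w * x`)
/// from the crate-level documentation:
///
/// ```ignore
/// impl Train for Scale {
///     fn train(&mut self, &train_delta: &f32) {
///         // Effectively `v += Δv`; the learning rate is already folded into the delta.
///         self.weight += train_delta;
///     }
/// }
/// ```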
pub trait Train: Backward {
    /// `train` takes in a train delta `Δv` and applies it to the trained variables.
    ///
    /// This should effectively perform `v += Δv`.
    fn train(&mut self, train_delta: &Self::TrainDelta);

    /// `propogate` should have the same effect as running [`Backward::backward`] followed by [`Train::train`].
    ///
    /// See the source (by clicking the source link on the right) for the simple default definition.
    /// If the training is running on the same instance as the backpropagation, then `propogate`
    /// might perform better since it may not need to allocate more memory to update the trained variables.
    /// If an implementation can do this efficiently, it should provide a custom implementation.
    fn propogate(
        &mut self,
        input: &Self::Input,
        internal: &Self::Internal,
        output_delta: &Self::OutputDelta,
    ) -> Self::InputDelta {
        let (input_delta, train_delta) = self.backward(input, internal, output_delta);
        self.train(&train_delta);
        input_delta
    }
}

pub trait Graph: Train + Sized {
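    /// `chain` composes `self` with another operation `other`, producing a [`Chain`] that feeds
    /// the output of `self` into `other`.
    ///
    /// A rough sketch of usage, assuming the two ops have compatible types, that [`Chain`]
    /// implements the traits for them, and reusing the hypothetical `Scale` op from the
    /// crate-level documentation:
    ///
    /// ```ignore
    /// let model = Scale { weight: 2.0 }.chain(Scale { weight: 0.5 });
    /// assert_eq!(model.run(&3.0), 3.0); // (3.0 * 2.0) * 0.5
    /// ```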
    fn chain<U>(self, other: U) -> Chain<Self, U> {
        Chain(self, other)
    }
}

impl<T> Graph for T where T: Train {}

impl<'a, T> Forward for &'a T
where
    T: Forward,
{
    type Input = T::Input;
    type Internal = T::Internal;
    type Output = T::Output;

    fn forward(&self, input: &Self::Input) -> (Self::Internal, Self::Output) {
        T::forward(self, input)
    }

    fn run(&self, input: &Self::Input) -> Self::Output {
        T::run(self, input)
    }
}

impl<'a, T> Forward for &'a mut T
where
    T: Forward,
{
    type Input = T::Input;
    type Internal = T::Internal;
    type Output = T::Output;

    fn forward(&self, input: &Self::Input) -> (Self::Internal, Self::Output) {
        T::forward(self, input)
    }

    fn run(&self, input: &Self::Input) -> Self::Output {
        T::run(self, input)
    }
}

impl<'a, T> Backward for &'a T
where
    T: Backward,
{
    type OutputDelta = T::OutputDelta;
    type InputDelta = T::InputDelta;
    type TrainDelta = T::TrainDelta;

    fn backward(
        &self,
        input: &Self::Input,
        internal: &Self::Internal,
        output_delta: &Self::OutputDelta,
    ) -> (Self::InputDelta, Self::TrainDelta) {
        T::backward(self, input, internal, output_delta)
    }

    fn backward_input(
        &self,
        input: &Self::Input,
        internal: &Self::Internal,
        output_delta: &Self::OutputDelta,
    ) -> Self::InputDelta {
        T::backward_input(self, input, internal, output_delta)
    }

    fn backward_train(
        &self,
        input: &Self::Input,
        internal: &Self::Internal,
        output_delta: &Self::OutputDelta,
    ) -> Self::TrainDelta {
        T::backward_train(self, input, internal, output_delta)
    }
}

impl<'a, T> Backward for &'a mut T
where
    T: Backward,
{
    type OutputDelta = T::OutputDelta;
    type InputDelta = T::InputDelta;
    type TrainDelta = T::TrainDelta;

    fn backward(
        &self,
        input: &Self::Input,
        internal: &Self::Internal,
        output_delta: &Self::OutputDelta,
    ) -> (Self::InputDelta, Self::TrainDelta) {
        T::backward(self, input, internal, output_delta)
    }

    fn backward_input(
        &self,
        input: &Self::Input,
        internal: &Self::Internal,
        output_delta: &Self::OutputDelta,
    ) -> Self::InputDelta {
        T::backward_input(self, input, internal, output_delta)
    }

    fn backward_train(
        &self,
        input: &Self::Input,
        internal: &Self::Internal,
        output_delta: &Self::OutputDelta,
    ) -> Self::TrainDelta {
        T::backward_train(self, input, internal, output_delta)
    }
}

impl<'a, T> Train for &'a mut T
where
    T: Train,
{
    fn train(&mut self, train_delta: &Self::TrainDelta) {
        T::train(self, train_delta)
    }

    fn propogate(
        &mut self,
        input: &Self::Input,
        internal: &Self::Internal,
        output_delta: &Self::OutputDelta,
    ) -> Self::InputDelta {
        T::propogate(self, input, internal, output_delta)
    }
}