easy_ml/differentiation.rs

#![allow(clippy::double_parens)]
/*!
 * (Automatic) Differentiation helpers
 *
 * # Automatic Differentiation
 *
 * This module provides structs for performing Forward and Reverse Automatic Differentiation.
 *
 * ## Automatic Differentiation is not [Numerical Differentiation](https://en.wikipedia.org/wiki/Numerical_differentiation)
 *
 * You were probably introduced to differentiation as numeric differentiation,
 * ie if you have a function 3x<sup>2</sup> then you can estimate its gradient
 * at some value x by computing 3x<sup>2</sup> and 3(x+h)<sup>2</sup> where h
 * is a very small value. The tangent line these two points create gives you an approximation
 * of the gradient when you calculate (f(x+h) - f(x)) / h. Unfortunately floating
 * point numbers in computers have limited precision, so this method is only approximate
 * and can result in floating point errors. 1 + 1 might equal 2, but as you go smaller,
 * 1 + 10<sup>-i</sup> starts to look rather like 1 as i goes
 * into double digits.
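 *
 * A small sketch of the problem in code (the closure names here are just for
 * illustration):
 *
 * ```
 * // numerically estimating d(3x^2)/dx at x = 3, where the true gradient is 6x = 18
 * let f = |x: f64| 3.0 * x * x;
 * let estimate = |h: f64| (f(3.0 + h) - f(3.0)) / h;
 * assert!((estimate(1e-5) - 18.0).abs() < 1e-3); // a reasonable approximation
 * assert!((estimate(1e-15) - 18.0).abs() > 0.01); // swamped by floating point error
 * ```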
 *
 * ## Automatic Differentiation is not Symbolic Differentiation
 *
 * If you were taught calculus you have probably done plenty of symbolic differentiation
 * by hand. A function 3x<sup>2</sup> can be symbolically differentiated into 6x by applying
 * simple rules to manipulate the algebra. Unfortunately the rules aren't so simple for
 * more complex expressions such as [exponents](https://www.wolframalpha.com/input/?i=d%28x%5Ee%5E2%29%2Fdx),
 * [logs](https://www.wolframalpha.com/input/?i=d%28log%28log%28x%29%29%29%2Fdx) or
 * [trigonometry](https://www.wolframalpha.com/input/?i=d%28sin%28cos%28x%29%29%29%2Fdx).
 * Symbolic differentiation can give you expressions which are just as or more complicated
 * than the original, and doing it by hand can be error prone. Symbolic Differentiation is
 * also tricky to relate to algorithmic computations that use control structures.
 *
 * ## [Automatic Differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)
 *
 * Automatic Differentiation computes the derivative of a function without rewriting
 * the function as symbolic differentiation does and without the precision issues of numerical
 * differentiation by splitting the derivative into lots of tiny steps of basic operations
 * like addition and multiplication. These are combined using the chain rule. The downside
 * is that more memory is used than in symbolic or numerical differentiation, as derivatives
 * have to be tracked through the computational graph.
 *
 * # Forward Differentiation
 *
 * Forward Differentiation computes all the gradients in a computational graph with respect
 * to an input. For example, if you have a function f(x, y) = 5x<sup>3</sup> - 4x<sup>2</sup> +
 * 10x - y, then for some actual value of x and y you can compute f(x,y) and δf(x,y)/δx
 * together in one forward pass using forward differentiation. You can also make another pass
 * and compute f(x,y) and δf(x,y)/δy for some actual value of x and y. Forward differentiation
 * in this way requires making N passes of the function to compute the derivatives of the output
 * with respect to N inputs. However, you do get the gradients for every output in a single pass.
 * This is poorly suited to neural nets as they often have a single output (loss)
 * to differentiate many many inputs with respect to.
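 *
 * A minimal sketch of such a forward pass (only the pass for x is shown):
 *
 * ```
 * use easy_ml::differentiation::Trace;
 * // f(x, y) = 5x^3 - 4x^2 + 10x - y, differentiated with respect to x, so x
 * // is the variable for this pass and y is lifted as a constant
 * let (x, y) = (Trace::variable(2.0), Trace::constant(3.0));
 * let f = (Trace::constant(5.0) * x * x * x) - (Trace::constant(4.0) * x * x)
 *     + (Trace::constant(10.0) * x) - y;
 * // δf(x,y)/δx = 15x^2 - 8x + 10, which is 54 at x = 2
 * assert_eq!(f.derivative, 54.0);
 * ```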
 *
 * # Reverse Mode Differentiation
 *
 * Reverse Mode Differentiation computes all the gradients in a computational graph for
 * the same output. For example, if you have a function f(x, y) = 5x<sup>3</sup> -
 * 4x<sup>2</sup> + 10x - y, then for some actual value of x and y you can compute f(x,y)
 * and store all the intermediate results. You can then run a backward pass on the output
 * of f(x, y) and obtain δf(x,y)/δx and δf(x,y)/δy for the actual values of x and y in a
 * single pass. The catch is that reverse mode must store as many intermediate values as
 * there are steps in the function, which can use much more memory than forward mode.
 * Reverse mode also requires making N backward passes to get the gradients for N different
 * outputs. This is well suited to neural nets because we often have a single output (loss)
 * to differentiate many inputs with respect to. However, reverse mode will be slower than
 * forward mode if the number of inputs is small or there are many outputs.
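 *
 * The same function sketched in reverse mode, where one backward pass yields
 * both gradients at once:
 *
 * ```
 * use easy_ml::differentiation::{Record, WengertList};
 * let list = WengertList::new();
 * let x = Record::variable(2.0, &list);
 * let y = Record::variable(3.0, &list);
 * // f(x, y) = 5x^3 - 4x^2 + 10x - y
 * let f = (Record::constant(5.0) * x * x * x) - (Record::constant(4.0) * x * x)
 *     + (Record::constant(10.0) * x) - y;
 * let derivatives = f.derivatives();
 * assert_eq!(derivatives[&x], 54.0); // δf(x,y)/δx = 15x^2 - 8x + 10
 * assert_eq!(derivatives[&y], -1.0); // δf(x,y)/δy = -1
 * ```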
 *
 * # Usage
 *
 * [See sub module for usage examples](usage)
 *
 * # Further information
 *
 * - [Automatic Differentiation Step by Step](https://medium.com/@marksaroufim/automatic-differentiation-step-by-step-24240f97a6e6)
 * - [Forward Mode Automatic Differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation#Automatic_differentiation_using_dual_numbers)
 * - [Reverse Mode Automatic Differentiation](https://rufflewind.com/2016-12-30/reverse-mode-automatic-differentiation)
 * - [Automatic Differentiation: The most criminally underused tool in the potential machine learning toolbox?](https://justindomke.wordpress.com/2009/02/17/automatic-differentiation-the-most-criminally-underused-tool-in-the-potential-machine-learning-toolbox/)
 * - [Yes you should understand backprop](https://medium.com/@karpathy/yes-you-should-understand-backprop-e2f06eab496b)
 */

mod container_record;
mod functions;
pub mod operations;
pub mod record_operations;
pub mod trace_operations;
pub mod usage;

pub use container_record::*;

use crate::numeric::{Numeric, NumericRef};

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

/**
 * A trait with no methods which is implemented for all primitive types.
 *
 * Importantly this trait is not implemented for Traces (or Records), to stop the compiler
 * from trying to evaluate nested Traces of Traces or Records of Records as Numeric types.
 * There is no reason to create a Trace of a Trace or Record of a Record, it won't do
 * anything a Trace or Record can't except use more memory.
 *
 * The boilerplate implementations for primitives are performed with a macro.
 * If a primitive type is missing from this list, please open an issue to add it in.
 */
pub trait Primitive {}

/**
 * A dual number which traces a real number and keeps track of its derivative.
 * This is used to perform Forward Automatic Differentiation.
 *
 * Trace implements only first order differentiation. For example, given a function
 * 3x<sup>2</sup>, you can use calculus to work out that its derivative with respect
 * to x is 6x. You can also take the derivative of 6x with respect to x and work out
 * that the second derivative is 6. By instead writing the function 3x<sup>2</sup> in
 * code using Trace types as your numbers you can compute the first order derivative
 * for a given value of x by passing your function `Trace { number: x, derivative: 1.0 }`.
 *
 * ```
 * use easy_ml::differentiation::Trace;
 * let x = Trace { number: 3.2, derivative: 1.0 };
 * let dx = Trace::constant(3.0) * x * x;
 * assert_eq!(dx.derivative, 3.2 * 6.0);
 * ```
 *
 * Why the one for the starting derivative? Because δx/δx = 1, as with symbolic
 * differentiation.
 *
 * # Acknowledgments
 *
 * The wikipedia page on [Automatic Differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)
 * provided a very useful overview and explanation for understanding Forward Mode Automatic
 * Differentiation as well as the implementation rules.
 */
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Trace<T: Primitive> {
    /**
     * The real number
     */
    pub number: T,
    /**
     * The first order derivative of this number.
     */
    // If we loosen this type from T to a tensor of T of some const-generic
    // dimensionality then we can calculate higher order derivatives with a single
    // Trace type.
    // However, Trace<Trace<f64>> can do such a calculation already for 2nd order
    // (and so on) and requires far less complexity in the API so this might not
    // be that worthwhile. Tensor<T, 1> introduces a lot of boxing that might also
    // hurt first order performance.
    pub derivative: T,
}

/**
 * The main set of methods for using Trace types for Forward Differentiation.
 *
 * The general steps are
 * 1. create one variable
 * 2. create as many constants as needed
 * 3. do operations on the variable and constants
 * 4. the outputs will have derivatives computed which can be accessed from
 * the `.derivative` field, with each derivative being that of the output with
 * respect to the input variable.
 * 5. if you need derivatives for a different input then do everything all over again
 * or do them all in parallel, as sketched below
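 *
 * A minimal sketch of those steps, making one pass per input of f(x, y) = x * y:
 *
 * ```
 * use easy_ml::differentiation::Trace;
 * let (x, y) = (3.0, 4.0);
 * let df_dx = (Trace::variable(x) * Trace::constant(y)).derivative;
 * let df_dy = (Trace::constant(x) * Trace::variable(y)).derivative;
 * // δ(xy)/δx = y and δ(xy)/δy = x
 * assert_eq!((df_dx, df_dy), (4.0, 3.0));
 * ```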
 */
impl<T: Numeric + Primitive> Trace<T> {
    /**
     * Constants are lifted to Traces with a derivative of 0
     *
     * Why zero for the starting derivative? Because for any constant C
     * δC/δx = 0, as with symbolic differentiation.
     */
    pub fn constant(c: T) -> Trace<T> {
        Trace {
            number: c,
            derivative: T::zero(),
        }
    }

    /**
     * To lift a variable that you want to find the derivative of
     * a function with respect to, the Trace starts with a derivative of 1.
     *
     * Why the one for the starting derivative? Because δx/δx = 1, as with symbolic
     * differentiation.
     */
    pub fn variable(x: T) -> Trace<T> {
        Trace {
            number: x,
            derivative: T::one(),
        }
    }

    /**
     * Computes the derivative of a function with respect to its input x.
     *
     * This is a shorthand for `(function(Trace::variable(x))).derivative`
     *
     * In the more general case, if you provide a function with an input x
     * and it returns N outputs y<sub>1</sub> to y<sub>N</sub> then you
     * have computed all the derivatives δy<sub>i</sub>/δx for i = 1 to N.
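     *
     * A small sketch of the shorthand (reusing the 3x<sup>2</sup> example from above):
     *
     * ```
     * use easy_ml::differentiation::Trace;
     * // the derivative of 3x^2 at x = 2 is 6x = 12
     * assert_eq!(Trace::derivative(|x| Trace::constant(3.0) * x * x, 2.0), 12.0);
     * ```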
     */
    pub fn derivative(function: impl FnOnce(Trace<T>) -> Trace<T>, x: T) -> T {
        (function(Trace::variable(x))).derivative
    }
}

impl<T: Numeric + Primitive> Trace<T>
where
    for<'a> &'a T: NumericRef<T>,
{
    /**
     * Creates a new Trace from a reference to an existing Trace by applying
     * some unary function to it which operates on the type the Trace wraps.
     *
     * To compute the new trace, the unary function of some input x to some
     * output y is needed along with its derivative with respect to its input x.
     *
     * For example, tanh is a commonly used activation function, but the Real trait
     * does not include this operation and Trace has no operations for it specifically.
     * However, you can use this function to compute the tanh of a Trace like so:
     *
     * ```
     * use easy_ml::differentiation::Trace;
     * let x = Trace::variable(0.7f32);
     * // the derivative of tanh(x) is sech(x) * sech(x) which is equivalent to
     * // 1 / (cosh(x) * cosh(x))
     * let y = x.unary(|x| x.tanh(), |x| 1.0 / (x.cosh() * x.cosh()));
     * assert_eq!(y.derivative, 1.0f32 / (0.7f32.cosh() * 0.7f32.cosh()));
     * ```
     */
    #[inline]
    pub fn unary(&self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) -> Trace<T> {
        Trace {
            number: fx(self.number.clone()),
            derivative: self.derivative.clone() * dfx_dx(self.number.clone()),
        }
    }

    /**
     * Creates a new Trace from references to two existing Traces by applying
     * some binary function to them which operates on two arguments of the type
     * the Traces wrap.
     *
     * To compute the new trace, the binary function of some inputs x and y to some
     * output z is needed along with its derivative with respect to its first input x and
     * its derivative with respect to its second input y.
     *
     * For example, atan2 takes two arguments, but the Real trait
     * does not include this operation and Trace has no operations for it specifically.
     * However, you can use this function to compute the atan2 of two Traces like so:
     *
     * ```
     * use easy_ml::differentiation::Trace;
     * let x = Trace::variable(3.0f32);
     * let y = Trace::variable(3.0f32);
     * // the derivative of atan2 with respect to x is y/(x*x + y*y)
     * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdx
     * // the derivative of atan2 with respect to y is -x/(x*x + y*y)
     * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdy
     * let z = x.binary(&y,
     *     |x, y| x.atan2(y),
     *     |x, y| y/((x*x) + (y*y)),
     *     |x, y| -x/((x*x) + (y*y))
     * );
     * ```
     */
    #[inline]
    pub fn binary(
        &self,
        rhs: &Trace<T>,
        fxy: impl Fn(T, T) -> T,
        dfxy_dx: impl Fn(T, T) -> T,
        dfxy_dy: impl Fn(T, T) -> T,
    ) -> Trace<T> {
        Trace {
            number: fxy(self.number.clone(), rhs.number.clone()),
            #[rustfmt::skip]
            derivative: (
                ((self.derivative.clone() * dfxy_dx(self.number.clone(), rhs.number.clone()))
                + (rhs.derivative.clone() * dfxy_dy(self.number.clone(), rhs.number.clone())))
            ),
        }
    }
}

use std::cell::RefCell;

/**
 * WengertLists are indexed with [`usize`].
 */
pub type Index = usize;

/**
 * A list of operations performed in a forward pass of a dynamic computational graph,
 * used for Reverse Mode Automatic Differentiation.
 *
 * This is dynamic, as in, you build the [Wengert list](https://en.wikipedia.org/wiki/Automatic_differentiation#Reverse_accumulation)
 * at runtime by performing operations like addition and multiplication on
 * [Records](Record) that were created with that Wengert list.
 *
 * When you perform a backward pass to obtain the gradients you travel back up the
 * computational graph using the stored intermediate values from this list to compute
 * all the gradients of the inputs and every intermediate step with respect to an output.
 *
 * Although sophisticated implementations can make the Wengert list only log(N) in length
 * by storing only some of the intermediate steps of N computational steps, this implementation
 * is not as sophisticated, and will store all of them.
 *
 * # Panics
 *
 * Every operation and nearly every method a Record has involves manipulating the
 * record's history on its referenced WengertList. This WengertList itself maintains
 * a [RefCell] which tracks borrows at runtime rather than compile time. This is necessary to
 * maintain the illusion that Records are just ordinary numbers, and the side effects of doing
 * arithmetic with Records are limited to their referenced WengertList. Hence, the Rust
 * compiler correctly infers that it is not safe to share references to WengertLists between
 * threads, nor transfer Records across threads. If you called a method on two Records that both
 * mutably borrowed from the same WengertList at once, which could be trivially done with
 * multiple threads, then the code would panic. Easy ML shouldn't allow you to do this
 * in safe Rust because each mutable borrow of the WengertList is dropped at the end of each
 * Record method call, and you can't call two methods simultaneously without threading.
 */
#[derive(Debug)]
pub struct WengertList<T> {
    // It is necessary to wrap the vec in a RefCell to allow for mutating
    // this list from immutable references held by each record.
    operations: RefCell<Vec<Operation<T>>>,
}

struct BorrowedWengertList<'a, T> {
    operations: &'a mut Vec<Operation<T>>,
}

/**
 * A binary operation to record on a WengertList. For unary operations the
 * right derivative is set to 0, and for nullary operations both derivatives
 * are set to 0.
 *
 * Each operation acts like a node in an upside down binary tree, with two parents that
 * each node was computed from. The main difference is that the numerical
 * index of those parents in the WengertList is stored, rather than any pointers.
 */
#[derive(Debug)]
struct Operation<T> {
    left_parent: Index,
    right_parent: Index,
    left_derivative: T,
    right_derivative: T,
}

/**
 * Computed derivatives of a computational graph for some output [Record] variable.
 *
 * This can be indexed by any Record used in the computational graph to get
 * the derivative with respect to that input.
 *
 * Indexing using Records not involved in the computational graph, or involved
 * in a different one, will return nonsense or index out of bounds and panic. In
 * the future this may be changed to always panic.
 */
#[derive(Debug)]
pub struct Derivatives<T> {
    derivatives: Vec<T>,
}

/**
 * Any Derivatives of a Cloneable type implements clone
 */
impl<T: Clone> Clone for Derivatives<T> {
    fn clone(&self) -> Self {
        Derivatives {
            derivatives: self.derivatives.clone(),
        }
    }
}

impl<T: Clone + Primitive> Derivatives<T> {
    /**
     * Queries the derivative at the provided record as input.
     *
     * If you construct a Derivatives object for some output y,
     * and call .at(&x) on it for some input x, this returns dy/dx.
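     *
     * A minimal sketch of a query (assuming `f64` values):
     *
     * ```
     * use easy_ml::differentiation::{Record, WengertList};
     * let list = WengertList::new();
     * let x = Record::variable(1.5, &list);
     * let y = x * x;
     * let derivatives = y.derivatives();
     * assert_eq!(derivatives.at(&x), 3.0); // dy/dx = 2x
     * ```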
     */
    pub fn at(&self, input: &Record<T>) -> T {
        self.derivatives[input.index].clone()
    }
}

impl<'a, T: Primitive> std::ops::Index<&Record<'a, T>> for Derivatives<T> {
    type Output = T;
    /**
     * Queries the derivative at the provided record as input.
     *
     * If you construct a Derivatives object for some output y,
     * and index it with some input x as `derivatives[&x]`, this returns dy/dx.
     */
    fn index(&self, input: &Record<'a, T>) -> &Self::Output {
        &self.derivatives[input.index]
    }
}

impl<T> std::convert::From<Derivatives<T>> for Vec<T> {
    /**
     * Converts the Derivatives struct into a Vec of derivatives that
     * can be indexed with `usize`s. The indexes correspond to the
     * index field on Records.
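     *
     * A small sketch of the conversion (assuming `f64` values):
     *
     * ```
     * use easy_ml::differentiation::{Record, WengertList};
     * let list = WengertList::new();
     * let x = Record::variable(2.0, &list);
     * let y = x * x;
     * let as_vec: Vec<f64> = y.derivatives().into();
     * assert_eq!(as_vec[x.index], 4.0); // dy/dx = 2x
     * ```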
     */
    fn from(derivatives: Derivatives<T>) -> Self {
        derivatives.derivatives
    }
}

/**
 * Any Operation of a Cloneable type implements clone
 */
impl<T: Clone + Primitive> Clone for Operation<T> {
    fn clone(&self) -> Self {
        Operation {
            left_parent: self.left_parent,
            right_parent: self.right_parent,
            left_derivative: self.left_derivative.clone(),
            right_derivative: self.right_derivative.clone(),
        }
    }
}

/**
 * A wrapper around a real number which records it going through the computational
 * graph. This is used to perform Reverse Mode Automatic Differentiation.
 *
 * # Panics
 *
 * Every operation and nearly every method a Record has involves manipulating the
 * record's history on its referenced [WengertList]. This WengertList itself maintains
 * a [RefCell] which tracks borrows at runtime rather than compile time. This is necessary to
 * maintain the illusion that Records are just ordinary numbers, and the side effects of doing
 * arithmetic with Records are limited to their referenced WengertList. Hence, the Rust
 * compiler infers that it is not safe to share references to WengertLists between threads,
 * nor transfer Records across threads. If you called a method on two Records that both
 * mutably borrowed from the same WengertList at once, which could be trivially done with
 * multiple threads, then the code would panic. Easy ML shouldn't allow you to do this
 * in safe Rust because each mutable borrow of the WengertList is dropped at the end of each
 * Record method call, and you can't call two methods simultaneously without threading.
 *
 * # Acknowledgments
 *
 * A [tutorial by Rufflewind](https://rufflewind.com/2016-12-30/reverse-mode-automatic-differentiation)
 * and the associated [MIT licensed](http://opensource.org/licenses/MIT)
 * [source code](https://github.com/Rufflewind/revad/blob/master/src/tape.rs) were invaluable
 * in providing understanding on how to implement Reverse Mode Automatic Differentiation.
 */
#[derive(Debug)]
pub struct Record<'a, T: Primitive> {
    // A record consists of a number used in the forward pass, as
    // well as a WengertList of operations performed on the numbers
    // and each record needs to know which point in the history they
    // are for.
    /**
     * The real number
     */
    pub number: T,
    history: Option<&'a WengertList<T>>,
    /**
     * The index of this number in its [WengertList]. The first entry will be 0,
     * the next 1 and so on.
     *
     * In normal use cases you should not need to read this field,
     * you can index [Derivatives] directly with Records.
     */
    pub index: Index,
}

/**
 * The main set of methods for using Record types for Reverse Differentiation.
 *
 * The general steps are
 * 1. create a `WengertList`
 * 2. create variables from this list
 * 3. do operations on the variables
 * 4. from the output you want to compute derivatives for call `.derivatives()`
 * 5. index the `Derivatives` object with the input variables to get the derivatives
 * with respect to each input
 * 6. if you want to make another pass call `clear()` on the `WengertList`
 * and then call `reset()` on all of the variables to forget the gradients already
 * computed (the order of `clear` then `reset` is very important!), as sketched below.
 *
 * Constants can be used to save memory if you have numbers that
 * you do not need to compute the gradients with respect to.
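 *
 * A minimal sketch of one pass followed by a second pass (assuming `f64` values):
 *
 * ```
 * use easy_ml::differentiation::{Record, WengertList};
 * let list = WengertList::new();
 * let mut x = Record::variable(2.0, &list);
 * let y = x * x;
 * assert_eq!(y.derivatives()[&x], 4.0); // δ(x^2)/δx = 2x
 * // to reuse the list, clear it first and then reset the variables
 * list.clear();
 * x.reset();
 * let y = x * x * x;
 * assert_eq!(y.derivatives()[&x], 12.0); // δ(x^3)/δx = 3x^2
 * ```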
 */
impl<'a, T: Numeric + Primitive> Record<'a, T> {
    /**
     * Creates an untracked Record which has no backing WengertList.
     *
     * This is provided for using constants along with Records in operations.
     *
     * For example with y = x + 4 the computation graph could be conceived as
     * a y node with parent nodes of x and 4 combined with the operation +.
     * However there is no need to record the derivatives of a constant, so
     * instead the computation graph can be conceived as a y node with a single
     * parent node of x and the unary operation of +4.
     *
     * This is also used for the type level constructors required by Numeric
     * which are also considered constants.
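     *
     * A small sketch of a constant in use (assuming `f64` values):
     *
     * ```
     * use easy_ml::differentiation::{Record, WengertList};
     * let list = WengertList::new();
     * let x = Record::variable(2.0, &list);
     * let y = x + Record::constant(4.0);
     * assert_eq!(y.derivatives()[&x], 1.0); // δ(x + 4)/δx = 1
     * ```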
     */
    pub fn constant(c: T) -> Record<'a, T> {
        Record {
            number: c,
            history: None,
            index: 0,
        }
    }

    /**
     * Creates a record backed by the provided WengertList.
     *
     * The record cannot live longer than the WengertList, hence
     * the following example does not compile
     *
     * ```compile_fail
     * use easy_ml::differentiation::Record;
     * use easy_ml::differentiation::WengertList;
     * let record = {
     *     let list = WengertList::new();
     *     Record::variable(1.0, &list)
     * }; // list no longer in scope
     * ```
     *
     * You can alternatively use the [record constructor on the WengertList type](WengertList::variable()).
     */
    pub fn variable(x: T, history: &'a WengertList<T>) -> Record<'a, T> {
        Record {
            number: x,
            history: Some(history),
            index: history.append_nullary(),
        }
    }

    /**
     * Creates a record from a constant/variable directly, most likely obtained by getting a
     * tensor view of an existing [container](RecordContainer). **The inputs are not checked for
     * validity**. It is possible to pass in the wrong Wengert list here or even numbers with
     * indexes that aren't tracked on the WengertList.
     *
     * It is recommended to use this constructor only in conjunction with manipulating an existing
     * container and not for creating new variables. Any variables created outside of
     * `Record::variable` / `RecordContainer::variables` would have to be manually added to the
     * correct Wengert list, and any arithmetic operations would also need tracking correctly.
     */
    pub fn from_existing(number: (T, Index), history: Option<&'a WengertList<T>>) -> Record<'a, T> {
        Record {
            number: number.0,
            history,
            index: number.1,
        }
    }

    /**
     * Resets this Record to place it back on its WengertList, for use
     * in performing another derivation after clearing the WengertList.
     */
    pub fn reset(&mut self) {
        match self.history {
            None => (), // noop
            Some(history) => self.index = history.append_nullary(),
        };
    }

    /**
     * A convenience helper function which takes a Record by value and
     * calls [reset](Record::reset()) on it.
     */
    pub fn do_reset(mut x: Record<T>) -> Record<T> {
        x.reset();
        x
    }

    /**
     * Gets the WengertList this Record is backed by if a variable, and [None] if a constant.
     */
    pub fn history(&self) -> Option<&'a WengertList<T>> {
        self.history
    }
}

impl<'a, T: Numeric + Primitive> Record<'a, T>
where
    for<'t> &'t T: NumericRef<T>,
{
    /**
     * Performs a backward pass up this record's WengertList from this
     * record as the output, computing all the derivatives for the inputs
     * involving this output.
     *
     * If you have N inputs x<sub>1</sub> to x<sub>N</sub>, and this output is y,
     * then this computes all the derivatives δy/δx<sub>i</sub> for i = 1 to N.
     *
     * # Panics
     *
     * Panics if the Record has no backing WengertList, ie it was created as a
     * constant.
     */
    #[track_caller]
    pub fn derivatives(&self) -> Derivatives<T> {
        match self.try_derivatives() {
            None => panic!("Record has no WengertList to find derivatives from"),
            Some(d) => d,
        }
    }

    /**
     * Performs a backward pass up this record's WengertList from this
     * record as the output, computing all the derivatives for the inputs
     * involving this output.
     *
     * If this record has no WengertList, ie it's a constant, None is returned instead.
     *
     * If you have N inputs x<sub>1</sub> to x<sub>N</sub>, and this output is y,
     * then this computes all the derivatives δy/δx<sub>i</sub> for i = 1 to N.
     */
    pub fn try_derivatives(&self) -> Option<Derivatives<T>> {
        let history = self.history?;
        let operations = history.operations.borrow();

        let mut derivatives = vec![T::zero(); operations.len()];

        // δy/δy = 1
        derivatives[self.index] = T::one();

        // Go back up the computation graph to the inputs
        for i in (0..operations.len()).rev() {
            let operation = operations[i].clone();
            let derivative = derivatives[i].clone();
            // The chain rule allows breaking up the derivative of the output y
            // with respect to the input x into many smaller derivatives that
            // are summed together.
            // δy/δx = δy/δw * δw/δx
            // δy/δx = sum for all i parents of y ( δy/δw_i * δw_i/δx )
            derivatives[operation.left_parent] = derivatives[operation.left_parent].clone()
                + derivative.clone() * operation.left_derivative;
            derivatives[operation.right_parent] = derivatives[operation.right_parent].clone()
                + derivative * operation.right_derivative;
        }

        Some(Derivatives { derivatives })
    }
}

impl<T: Primitive> WengertList<T> {
    /**
     * Creates a new empty WengertList from which Records can be constructed.
     */
    pub fn new() -> WengertList<T> {
        WengertList {
            operations: RefCell::new(Vec::new()),
        }
    }
}

impl<T: Primitive> Default for WengertList<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T> WengertList<T> {
    /**
     * Clears a WengertList to make it empty again. After clearing a WengertList
     * you must reset all the Records still using that list. Then you can perform
     * another computation and get new gradients.
     */
    pub fn clear(&self) {
        self.operations.borrow_mut().clear();
    }
}

impl<T: Numeric + Primitive> WengertList<T> {
    /**
     * Creates a record backed by this WengertList.
     *
     * You can alternatively use the [record constructor on the Record type](Record::variable()).
     */
    pub fn variable(&self, x: T) -> Record<'_, T> {
        Record {
            number: x,
            history: Some(self),
            index: self.append_nullary(),
        }
    }

    /**
     * Adds a value to the list which does not have any parent values.
     */
    fn append_nullary(&self) -> Index {
        use std::ops::DerefMut;
        let mut borrow = self.operations.borrow_mut();
        BorrowedWengertList::new(borrow.deref_mut()).append_nullary()
    }

    /**
     * Adds a number of values to the list which do not have any parent values, returning
     * the index of the first added value, the others will follow contiguously afterwards.
     *
     * If values is 0, returns the first index that would be used but wasn't.
     */
    fn append_nullary_repeating(&self, values: usize) -> Index {
        let mut operations = self.operations.borrow_mut();
        // insert into end of list
        let starting_index = operations.len();
        for i in 0..values {
            let index = starting_index + i;
            operations.push(Operation {
                // this index of the child is used for both indexes as these
                // won't be needed but will always be valid (ie point to a
                // real entry on the list)
                left_parent: index,
                right_parent: index,
                // for the parents 0 is used to zero out these calculations
                // as there are no parents
                left_derivative: T::zero(),
                right_derivative: T::zero(),
            });
        }
        starting_index
    }

    /**
     * Adds a value to the list which has one parent.
     *
     * For an output w_N which depends on one parent w_N-1
     * the derivative cached here is δw_N / δw_N-1
     *
     * For example, if z = sin(x), then δz/δx = cos(x)
     */
    fn append_unary(&self, parent: Index, derivative: T) -> Index {
        use std::ops::DerefMut;
        let mut borrow = self.operations.borrow_mut();
        BorrowedWengertList::new(borrow.deref_mut()).append_unary(parent, derivative)
    }

    /**
     * Adds a value to the list which has two parents.
     *
     * For an output w_N which depends on two parents w_N-1
     * and w_N-2 the derivatives cached here are δw_N / δw_N-1
     * and δw_N / δw_N-2.
     *
     * For example, if z = y + x, then δz/δy = 1 and δz/δx = 1
     * For example, if z = y * x, then δz/δy = x and δz/δx = y
     */
    fn append_binary(
        &self,
        left_parent: Index,
        left_derivative: T,
        right_parent: Index,
        right_derivative: T,
    ) -> Index {
        use std::ops::DerefMut;
        let mut borrow = self.operations.borrow_mut();
        BorrowedWengertList::new(borrow.deref_mut()).append_binary(
            left_parent,
            left_derivative,
            right_parent,
            right_derivative,
        )
    }

    /**
     * Borrows the WengertList mutably for batch operations. It is *very* important to
     * hold onto the borrow only for as long as needed then drop it immediately. To avoid panics
     * Easy ML needs to ensure 100% of method calls on the public API do not maintain a borrow
     * after they finish executing. This was previously enforced by not having any batch
     * append APIs, but they're needed for RecordContainer. Calling borrow again while still
     * holding the first would trigger a panic, as would holding onto the borrow after the public
     * API method is finished.
     */
    fn borrow<F>(&self, op: F)
    where
        F: FnOnce(&mut BorrowedWengertList<T>),
    {
        use std::ops::DerefMut;
        let mut borrow = self.operations.borrow_mut();
        op(&mut BorrowedWengertList::new(borrow.deref_mut()));
    }
}

/**
 * Any WengertList of a Cloneable type implements clone
 */
impl<T: Clone + Primitive> Clone for WengertList<T> {
    fn clone(&self) -> Self {
        WengertList {
            operations: RefCell::new(self.operations.borrow().clone()),
        }
    }
}

/**
 * Methods for appending Operations after borrowing the Wengert list.
 */
impl<'a, T: Numeric + Primitive> BorrowedWengertList<'a, T> {
    fn new(operations: &mut Vec<Operation<T>>) -> BorrowedWengertList<'_, T> {
        BorrowedWengertList { operations }
    }

    /**
     * Adds a value to the list which does not have any parent values.
     */
    fn append_nullary(&mut self) -> Index {
        // insert into end of list
        let index = self.operations.len();
        self.operations.push(Operation {
            // this index of the child is used for both indexes as these
            // won't be needed but will always be valid (ie point to a
            // real entry on the list)
            left_parent: index,
            right_parent: index,
            // for the parents 0 is used to zero out these calculations
            // as there are no parents
            left_derivative: T::zero(),
            right_derivative: T::zero(),
        });
        index
    }

    /**
     * Adds a value to the list which has one parent.
     *
     * For an output w_N which depends on one parent w_N-1
     * the derivative cached here is δw_N / δw_N-1
     *
     * For example, if z = sin(x), then δz/δx = cos(x)
     */
    fn append_unary(&mut self, parent: Index, derivative: T) -> Index {
        // insert into end of list
        let index = self.operations.len();
        self.operations.push(Operation {
            left_parent: parent,
            // this index of the child is used as this index won't be needed
            // but will always be valid (ie points to a real entry on the list)
            right_parent: index,
            left_derivative: derivative,
            // for the right parent 0 is used to zero out this calculation
            // as there is no right parent
            right_derivative: T::zero(),
        });
        index
    }

    /**
     * Adds a value to the list which has two parents.
     *
     * For an output w_N which depends on two parents w_N-1
     * and w_N-2 the derivatives cached here are δw_N / δw_N-1
     * and δw_N / δw_N-2.
     *
     * For example, if z = y + x, then δz/δy = 1 and δz/δx = 1
     * For example, if z = y * x, then δz/δy = x and δz/δx = y
     */
    fn append_binary(
        &mut self,
        left_parent: Index,
        left_derivative: T,
        right_parent: Index,
        right_derivative: T,
    ) -> Index {
        // insert into end of list
        let index = self.operations.len();
        self.operations.push(Operation {
            left_parent,
            right_parent,
            left_derivative,
            right_derivative,
        });
        index
    }
}

impl<'a, T: Numeric + Primitive> Record<'a, T>
where
    for<'t> &'t T: NumericRef<T>,
{
    /**
     * Creates a new Record from a reference to an existing Record by applying
     * some unary function to it which operates on the type the Record wraps.
     *
     * To compute the new record, the unary function of some input x to some
     * output y is needed along with its derivative with respect to its input x.
     *
     * For example, tanh is a commonly used activation function, but the Real trait
     * does not include this operation and Record has no operations for it specifically.
     * However, you can use this function to compute the tanh of a Record like so:
     *
     * ```
     * use easy_ml::differentiation::{Record, WengertList};
     * let list = WengertList::new();
     * let x = Record::variable(0.7f32, &list);
     * // the derivative of tanh(x) is sech(x) * sech(x) which is equivalent to
     * // 1 / (cosh(x) * cosh(x))
     * let y = x.unary(|x| x.tanh(), |x| 1.0 / (x.cosh() * x.cosh()));
     * assert_eq!(y.derivatives()[&x], 1.0f32 / (0.7f32.cosh() * 0.7f32.cosh()));
     * ```
     */
    #[inline]
    pub fn unary(&self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) -> Record<'_, T> {
        match self.history {
            None => Record {
                number: fx(self.number.clone()),
                history: None,
                index: 0,
            },
            Some(history) => Record {
                number: fx(self.number.clone()),
                history: Some(history),
                index: history.append_unary(self.index, dfx_dx(self.number.clone())),
            },
        }
    }

    /**
     * Creates a new Record from references to two existing Records by applying
     * some binary function to them which operates on two arguments of the type
     * the Records wrap.
     *
     * To compute the new record, the binary function of some inputs x and y to some
     * output z is needed along with its derivative with respect to its first input x and
     * its derivative with respect to its second input y.
     *
     * For example, atan2 takes two arguments, but the Real trait
     * does not include this operation and Record has no operations for it specifically.
     * However, you can use this function to compute the atan2 of two Records like so:
     *
     * ```
     * use easy_ml::differentiation::{Record, WengertList};
     * let list = WengertList::new();
     * let x = Record::variable(3.0f32, &list);
     * let y = Record::variable(3.0f32, &list);
     * // the derivative of atan2 with respect to x is y/(x*x + y*y)
     * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdx
     * // the derivative of atan2 with respect to y is -x/(x*x + y*y)
     * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdy
     * let z = x.binary(&y,
     *     |x, y| x.atan2(y),
     *     |x, y| y/((x*x) + (y*y)),
     *     |x, y| -x/((x*x) + (y*y))
     * );
     * let derivatives = z.derivatives();
     * let dx = derivatives[&x];
     * let dy = derivatives[&y];
     * ```
     */
    #[inline]
    #[track_caller]
    pub fn binary(
        &self,
        rhs: &Record<'a, T>,
        fxy: impl Fn(T, T) -> T,
        dfxy_dx: impl Fn(T, T) -> T,
        dfxy_dy: impl Fn(T, T) -> T,
    ) -> Record<'_, T> {
        assert!(
            record_operations::same_list(self, rhs),
            "Records must be using the same WengertList"
        );
        match (self.history, rhs.history) {
            (None, None) => Record {
                number: fxy(self.number.clone(), rhs.number.clone()),
                history: None,
                index: 0,
            },
            (Some(history), None) => Record {
                number: fxy(self.number.clone(), rhs.number.clone()),
                history: Some(history),
                index: history.append_unary(
                    // if rhs didn't have a history, don't track that derivative
                    self.index,
                    dfxy_dx(self.number.clone(), rhs.number.clone()),
                ),
            },
            (None, Some(history)) => Record {
                number: fxy(self.number.clone(), rhs.number.clone()),
                history: Some(history),
                index: history.append_unary(
                    // if self didn't have a history, don't track that derivative
                    rhs.index,
                    dfxy_dy(self.number.clone(), rhs.number.clone()),
                ),
            },
            (Some(history), Some(_)) => Record {
                number: fxy(self.number.clone(), rhs.number.clone()),
                history: Some(history),
                index: history.append_binary(
                    self.index,
                    dfxy_dx(self.number.clone(), rhs.number.clone()),
                    rhs.index,
                    dfxy_dy(self.number.clone(), rhs.number.clone()),
                ),
            },
        }
    }
}

#[cfg(test)]
#[should_panic]
#[test]
fn test_record_derivatives_when_no_history() {
    let record = Record::constant(1.0);
    record.derivatives();
}

#[test]
fn test_sync() {
    fn assert_sync<T: Sync>() {}
    assert_sync::<Trace<f64>>();
}

#[test]
fn test_send() {
    fn assert_send<T: Send>() {}
    assert_send::<Trace<f64>>();
}