#![allow(clippy::double_parens)]
/*!
 * (Automatic) Differentiation helpers
 *
 * # Automatic Differentiation
 *
 * This module provides structs for performing Forward and Reverse Automatic Differentiation.
 *
 * ## Automatic Differentiation is not [Numerical Differentiation](https://en.wikipedia.org/wiki/Numerical_differentiation)
 *
 * You were probably introduced to differentiation as numeric differentiation,
 * ie if you have a function 3x<sup>2</sup> then you can estimate its gradient
 * at some value x by computing 3x<sup>2</sup> and 3(x+h)<sup>2</sup> where h
 * is a very small value. The secant line through these two points approximates the tangent,
 * so calculating (f(x+h) - f(x)) / h gives you an approximation of the gradient. Unfortunately
 * floating point numbers in computers have limited precision, so this method is only approximate
 * and can result in floating point errors. 1 + 1 might equal 2, but as you go smaller
 * 1 + 10<sup>-i</sup> starts to look rather like 1 as i goes into double digits, because the
 * small term is lost when it is added to a much larger one.
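 *
 * To see that precision problem in practice, here is a minimal standalone sketch (it does
 * not use this module) of a forward difference estimate of the gradient of 3x<sup>2</sup>,
 * compared with the exact gradient 6x. The helper names `forward_difference` and
 * `exact_gradient` are hypothetical and only used for this illustration.
 *
 * ```rust
 * // Estimates the gradient of `f` at `x` as (f(x + h) - f(x)) / h.
 * // `forward_difference` is a hypothetical helper for this illustration only.
 * fn forward_difference(f: impl Fn(f64) -> f64, x: f64, h: f64) -> f64 {
 *     (f(x + h) - f(x)) / h
 * }
 *
 * // The symbolic gradient of 3x^2, used as the reference value.
 * fn exact_gradient(x: f64) -> f64 {
 *     6.0 * x
 * }
 * ```
 *
 * At x = 2 with f64, a step of h = 10<sup>-5</sup> gives an estimate within about
 * 3 × 10<sup>-5</sup> of the exact value 12. With h = 10<sup>-15</sup>, most of h is rounded
 * away when it is added to x, and the estimate comes out at roughly 10.66.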
 *
 * ## Automatic Differentiation is not Symbolic Differentiation
 *
 * If you were taught calculus you have probably done plenty of symbolic differentiation
 * by hand. A function 3x<sup>2</sup> can be symbolically differentiated into 6x by applying
 * simple rules to manipulate the algebra. Unfortunately the rules aren't so simple for
 * more complex expressions such as [exponents](https://www.wolframalpha.com/input/?i=d%28x%5Ee%5E2%29%2Fdx),
 * [logs](https://www.wolframalpha.com/input/?i=d%28log%28log%28x%29%29%29%2Fdx) or
 * [trigonometry](https://www.wolframalpha.com/input/?i=d%28sin%28cos%28x%29%29%29%2Fdx).
 * Symbolic differentiation can give you expressions which are just as or more complicated
 * than the original, and doing it by hand can be error prone. Symbolic Differentiation is
 * also tricky to relate to algorithmic computations that use control structures.
 *
 * ## [Automatic Differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)
 *
 * Automatic Differentiation computes the derivative of a function without rewriting
 * the function, as symbolic differentiation does, and without the precision issues of numerical
 * differentiation. It does this by splitting the derivative into lots of tiny steps of basic
 * operations like addition and multiplication, which are combined using the chain rule. The
 * downside is that more memory is used than in symbolic or numerical differentiation, as
 * derivatives have to be tracked through the computational graph.
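 *
 * As a small worked example of those tiny steps, here is a hand-written standalone sketch
 * for y = sin(x<sup>2</sup>). It assumes f64 inputs and uses a hypothetical helper named
 * `sin_of_square`. Each basic step computes its value together with its local derivative,
 * and the chain rule multiplies the local derivatives together.
 *
 * ```rust
 * // Returns (y, dy/dx) for y = sin(x^2), built from two basic steps.
 * fn sin_of_square(x: f64) -> (f64, f64) {
 *     // step 1: w = x^2, with dw/dx = 2x
 *     let w = x * x;
 *     let dw_dx = 2.0 * x;
 *     // step 2: y = sin(w), with dy/dw = cos(w)
 *     let y = w.sin();
 *     let dy_dw = w.cos();
 *     // chain rule: dy/dx = dy/dw * dw/dx
 *     (y, dy_dw * dw_dx)
 * }
 * ```
 *
 * This gives the same result as the symbolic derivative 2x cos(x<sup>2</sup>) without ever
 * building that expression. Forward and reverse mode differ only in the order in which these
 * local derivatives are combined.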
 *
 * # Forward Differentiation
 *
 * Forward Differentiation computes all the gradients in a computational graph with respect
 * to an input. For example, if you have a function f(x, y) = 5x<sup>3</sup> - 4x<sup>2</sup> +
 * 10x - y, then for some actual value of x and y you can compute f(x,y) and δf(x,y)/δx
 * together in one forward pass using forward differentiation. You can also make another pass
 * and compute f(x,y) and δf(x,y)/δy for some actual value of x and y. It is possible to avoid
 * redundantly calculating f(x,y) multiple times, but I am waiting on const generics to implement
 * this. Regardless, forward differentiation requires making at least N+1 passes of the
 * function to compute the derivatives of the output with respect to N inputs, and the current
 * implementation will make 2N. However, you do get the gradients for every output in a
 * single pass. This is poorly suited to neural nets as they often have a single output (loss)
 * to differentiate many inputs with respect to.
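 *
 * A minimal standalone sketch of forward mode for the f(x, y) above follows. It uses a
 * hypothetical `Dual` type with only the operations f needs (this module's type for this is
 * `Trace`). Each pass marks one input as the variable, with a starting derivative of 1, and
 * treats the other input as a constant, with a starting derivative of 0.
 *
 * ```rust
 * use std::ops::{Add, Mul, Sub};
 *
 * // A hypothetical dual number: a value and its derivative with respect to one input.
 * #[derive(Clone, Copy, Debug, PartialEq)]
 * struct Dual {
 *     number: f64,
 *     derivative: f64,
 * }
 *
 * impl Dual {
 *     fn constant(c: f64) -> Dual {
 *         Dual { number: c, derivative: 0.0 }
 *     }
 *     fn variable(x: f64) -> Dual {
 *         Dual { number: x, derivative: 1.0 }
 *     }
 * }
 *
 * impl Add for Dual {
 *     type Output = Dual;
 *     fn add(self, rhs: Dual) -> Dual {
 *         Dual { number: self.number + rhs.number, derivative: self.derivative + rhs.derivative }
 *     }
 * }
 *
 * impl Sub for Dual {
 *     type Output = Dual;
 *     fn sub(self, rhs: Dual) -> Dual {
 *         Dual { number: self.number - rhs.number, derivative: self.derivative - rhs.derivative }
 *     }
 * }
 *
 * impl Mul for Dual {
 *     type Output = Dual;
 *     // product rule: (uv)' = u'v + uv'
 *     fn mul(self, rhs: Dual) -> Dual {
 *         Dual {
 *             number: self.number * rhs.number,
 *             derivative: self.derivative * rhs.number + self.number * rhs.derivative,
 *         }
 *     }
 * }
 *
 * // f(x, y) = 5x^3 - 4x^2 + 10x - y
 * fn f(x: Dual, y: Dual) -> Dual {
 *     Dual::constant(5.0) * x * x * x - Dual::constant(4.0) * x * x + Dual::constant(10.0) * x - y
 * }
 *
 * // One forward pass per input. Returns (f, δf/δx, δf/δy).
 * fn f_and_gradient(x: f64, y: f64) -> (f64, f64, f64) {
 *     let pass_x = f(Dual::variable(x), Dual::constant(y));
 *     let pass_y = f(Dual::constant(x), Dual::variable(y));
 *     (pass_x.number, pass_x.derivative, pass_y.derivative)
 * }
 * ```
 *
 * For x = 2 and y = 3, `f_and_gradient(2.0, 3.0)` returns f = 41, δf/δx = 15x<sup>2</sup> -
 * 8x + 10 = 54 and δf/δy = -1. It needs one pass per input and recomputes f in each one.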
 *
 * # Reverse Mode Differentiation
 *
 * Reverse Mode Differentiation computes all the gradients in a computational graph for
 * the same output. For example, if you have a function f(x, y) = 5x<sup>3</sup> -
 * 4x<sup>2</sup> + 10x - y, then for some actual value of x and y you can compute f(x,y)
 * and store all the intermediate results. You can then run a backward pass on the output
 * of f(x, y) and obtain δf(x,y)/δx and δf(x,y)/δy for the actual values of x and y in a
 * single pass. The catch is that reverse mode must store as many intermediate values as
 * there are steps in the function, which can use much more memory than forward mode.
 * Reverse mode also requires making N backward passes to get the gradients for N different
 * outputs. This is well suited to neural nets because we often have a single output (loss)
 * to differentiate many inputs with respect to. However, reverse mode will be slower than
 * forward mode if the number of inputs is small or there are many outputs.
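 *
 * A minimal standalone sketch of reverse mode for the same f(x, y) follows. It uses a
 * hypothetical `Tape` type and `Var` handle rather than this module's `WengertList` and
 * `Record`. The forward pass records each basic operation with the indexes of its two parents
 * and its local derivatives with respect to them. A single backward pass then yields the
 * gradient with respect to every input.
 *
 * ```rust
 * // A value recorded on the tape, remembering where on the tape it lives.
 * #[derive(Clone, Copy)]
 * struct Var {
 *     value: f64,
 *     index: usize,
 * }
 *
 * // Each entry stores two (parent index, local derivative) pairs. Unary entries give their
 * // second parent a derivative of 0, and inputs point at themselves with derivatives of 0.
 * struct Tape {
 *     entries: Vec<[(usize, f64); 2]>,
 * }
 *
 * impl Tape {
 *     fn new() -> Tape {
 *         Tape { entries: Vec::new() }
 *     }
 *     fn push(&mut self, value: f64, parents: [(usize, f64); 2]) -> Var {
 *         self.entries.push(parents);
 *         Var { value, index: self.entries.len() - 1 }
 *     }
 *     fn variable(&mut self, value: f64) -> Var {
 *         let index = self.entries.len();
 *         self.push(value, [(index, 0.0), (index, 0.0)])
 *     }
 *     fn add(&mut self, a: Var, b: Var) -> Var {
 *         self.push(a.value + b.value, [(a.index, 1.0), (b.index, 1.0)])
 *     }
 *     fn sub(&mut self, a: Var, b: Var) -> Var {
 *         self.push(a.value - b.value, [(a.index, 1.0), (b.index, -1.0)])
 *     }
 *     fn mul(&mut self, a: Var, b: Var) -> Var {
 *         self.push(a.value * b.value, [(a.index, b.value), (b.index, a.value)])
 *     }
 *     // multiplying by a constant is recorded as a unary operation
 *     fn scale(&mut self, c: f64, a: Var) -> Var {
 *         self.push(c * a.value, [(a.index, c), (a.index, 0.0)])
 *     }
 *     // The parents of every operation come earlier on the tape, so walking it in
 *     // reverse finishes each entry's gradient before passing it on to its parents.
 *     fn backward(&self, output: Var) -> Vec<f64> {
 *         let mut gradients = vec![0.0; self.entries.len()];
 *         gradients[output.index] = 1.0;
 *         for i in (0..self.entries.len()).rev() {
 *             let gradient = gradients[i];
 *             for &(parent, local_derivative) in &self.entries[i] {
 *                 gradients[parent] += gradient * local_derivative;
 *             }
 *         }
 *         gradients
 *     }
 * }
 *
 * // f(x, y) = 5x^3 - 4x^2 + 10x - y, recorded one basic operation at a time
 * fn f(tape: &mut Tape, x: Var, y: Var) -> Var {
 *     let x2 = tape.mul(x, x);
 *     let x3 = tape.mul(x2, x);
 *     let a = tape.scale(5.0, x3);
 *     let b = tape.scale(4.0, x2);
 *     let c = tape.scale(10.0, x);
 *     let a_minus_b = tape.sub(a, b);
 *     let sum = tape.add(a_minus_b, c);
 *     tape.sub(sum, y)
 * }
 * ```
 *
 * For x = 2 and y = 3, one call to `backward` on the output of `f` gives both
 * δf/δx = 54 and δf/δy = -1. The cost is keeping every intermediate entry on the tape.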
 *
 * # Usage
 *
 * [See sub module for usage examples](usage)
 *
 * # Further information
 *
 * - [Automatic Differentiation Step by Step](https://medium.com/@marksaroufim/automatic-differentiation-step-by-step-24240f97a6e6)
 * - [Forward Mode Automatic Differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation#Automatic_differentiation_using_dual_numbers)
 * - [Reverse Mode Automatic Differentiation](https://rufflewind.com/2016-12-30/reverse-mode-automatic-differentiation)
 * - [Automatic Differentiation: The most criminally underused tool in the potential machine learning toolbox?](https://justindomke.wordpress.com/2009/02/17/automatic-differentiation-the-most-criminally-underused-tool-in-the-potential-machine-learning-toolbox/)
 * - [Yes you should understand backprop](https://medium.com/@karpathy/yes-you-should-understand-backprop-e2f06eab496b)
 */

mod container_record;
mod functions;
pub mod operations;
pub mod record_operations;
pub mod trace_operations;
pub mod usage;

pub use container_record::*;

use crate::numeric::{Numeric, NumericRef};

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

/**
 * A trait with no methods which is implemented for all primitive types.
 *
 * Importantly this trait is not implemented for Traces (or Records), to stop the compiler
 * from trying to evaluate nested Traces of Traces or Records of Records as Numeric types.
 * There is no reason to create a Trace of a Trace or Record of a Record; it won't do
 * anything a Trace or Record can't, except use more memory.
 *
 * The boilerplate implementations for primitives are performed with a macro.
 * If a primitive type is missing from this list, please open an issue to add it in.
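 *
 * A minimal standalone sketch of this kind of marker trait boilerplate follows. It assumes a
 * hypothetical `impl_primitive!` macro and type list, and the crate's own macro and list may
 * differ. `Wrapper` is a hypothetical stand-in for a Trace or Record.
 *
 * ```rust
 * // A marker trait with no methods.
 * trait Primitive {}
 *
 * // A hypothetical macro that writes out the empty impl for each listed type.
 * macro_rules! impl_primitive {
 *     ($($t:ty),*) => {
 *         $(impl Primitive for $t {})*
 *     };
 * }
 *
 * impl_primitive!(u8, i8, u16, i16, u32, i32, u64, i64, u128, i128, usize, isize, f32, f64);
 *
 * // Only accepts types implementing the marker, and does not implement it itself.
 * struct Wrapper<T: Primitive> {
 *     number: T,
 * }
 * ```
 *
 * Because `Wrapper` does not implement `Primitive`, a `Wrapper<Wrapper<f64>>` is rejected at
 * compile time. This is the same kind of guard that stops Traces of Traces.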
 */
pub trait Primitive {}

/**
 * A dual number which traces a real number and keeps track of its derivative.
 * This is used to perform Forward Automatic Differentiation.
 *
 * Trace implements only first order differentiation. For example, given a function
 * 3x<sup>2</sup>, you can use calculus to work out that its derivative with respect
 * to x is 6x. You can also take the derivative of 6x with respect to x and work out
 * that the second derivative is 6. By instead writing the function 3x<sup>2</sup> in
 * code using Trace types as your numbers you can compute the first order derivative
 * for a given value of x by passing your function `Trace { number: x, derivative: 1.0 }`.
 *
 * ```
 * use easy_ml::differentiation::Trace;
 * let x = Trace { number: 3.2, derivative: 1.0 };
 * let dx = Trace::constant(3.0) * x * x;
 * assert_eq!(dx.derivative, 3.2 * 6.0);
 * ```
 *
 * Why one for the starting derivative? Because δx/δx = 1, as with symbolic
 * differentiation.
 *
 * # Acknowledgments
 *
 * The wikipedia page on [Automatic Differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)
 * provided a very useful overview and explanation for understanding Forward Mode Automatic
 * Differentiation as well as the implementation rules.
 */
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Trace<T: Primitive> {
    /**
     * The real number
     */
    pub number: T,
    /**
     * The first order derivative of this number.
     */
    pub derivative: T,
}

/**
 * The main set of methods for using Trace types for Forward Differentiation.
 *
 * The general steps are
 * 1. create one variable
 * 2. create as many constants as needed
 * 3. do operations on the variable and constants
 * 4. the outputs will have derivatives computed which can be accessed from
 * the `.derivative` field, with each derivative being the output with respect
 * to the input variable.
 * 5. if you need derivatives for a different input then do everything all over again
 * or do them all in parallel
 */
impl<T: Numeric + Primitive> Trace<T> {
    /**
     * Constants are lifted to Traces with a derivative of 0
     *
     * Why zero for the starting derivative? Because for any constant C
     * δC/δx = 0, as with symbolic differentiation.
     */
    pub fn constant(c: T) -> Trace<T> {
        Trace {
            number: c,
            derivative: T::zero(),
        }
    }

    /**
     * Lifts a variable that you want to find the derivative of a function with respect to.
     * The Trace starts with a derivative of 1.
     *
     * Why one for the starting derivative? Because δx/δx = 1, as with symbolic
     * differentiation.
     */
    pub fn variable(x: T) -> Trace<T> {
        Trace {
            number: x,
            derivative: T::one(),
        }
    }

    /**
     * Computes the derivative of a function with respect to its input x.
     *
     * This is a shorthand for `(function(Trace::variable(x))).derivative`
     *
     * In the more general case, if you provide a function with an input x
     * and it returns N outputs y<sub>1</sub> to y<sub>N</sub> then you
     * have computed all the derivatives δy<sub>i</sub>/δx for i = 1 to N.
     */
    pub fn derivative(function: impl FnOnce(Trace<T>) -> Trace<T>, x: T) -> T {
        (function(Trace::variable(x))).derivative
    }
}

impl<T: Numeric + Primitive> Trace<T>
where
    for<'a> &'a T: NumericRef<T>,
{
    /**
     * Creates a new Trace from a reference to an existing Trace by applying
     * some unary function to it which operates on the type the Trace wraps.
     *
     * To compute the new trace, the unary function of some input x to some
     * output y is needed along with its derivative with respect to its input x.
     *
     * For example, tanh is a commonly used activation function, but the Real trait
     * does not include this operation and Trace has no operations for it specifically.
     * However, you can use this function to compute the tanh of a Trace like so:
     *
     * ```
     * use easy_ml::differentiation::Trace;
     * let x = Trace::variable(0.7f32);
     * // the derivative of tanh(x) is sech(x) * sech(x) which is equivalent to
     * // 1 / (cosh(x) * cosh(x))
     * let y = x.unary(|x| x.tanh(), |x| 1.0 / (x.cosh() * x.cosh()));
     * assert_eq!(y.derivative, 1.0f32 / (0.7f32.cosh() * 0.7f32.cosh()));
     * ```
     */
    #[inline]
    pub fn unary(&self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) -> Trace<T> {
        Trace {
            number: fx(self.number.clone()),
            derivative: self.derivative.clone() * dfx_dx(self.number.clone()),
        }
    }

    /**
     * Creates a new Trace from a reference to two existing Traces by applying
     * some binary function to them which operates on two arguments of the type
     * the Traces wrap.
     *
     * To compute the new trace, the binary function of some inputs x and y to some
     * output z is needed along with its derivative with respect to its first input x and
     * its derivative with respect to its second input y.
     *
     * For example, atan2 takes two arguments, but the Real trait
     * does not include this operation and Trace has no operations for it specifically.
     * However, you can use this function to compute the atan2 of two Traces like so:
     *
     * ```
     * use easy_ml::differentiation::Trace;
     * let x = Trace::variable(3.0f32);
     * let y = Trace::variable(3.0f32);
     * // the derivative of atan2 with respect to x is y/(x*x + y*y)
     * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdx
     * // the derivative of atan2 with respect to y is -x/(x*x + y*y)
     * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdy
     * // as x and y are both variables here, z.derivative is the sum δz/δx + δz/δy
     * let z = x.binary(&y,
     *     |x, y| x.atan2(y),
     *     |x, y| y/((x*x) + (y*y)),
     *     |x, y| -x/((x*x) + (y*y))
     * );
     * ```
     */
    #[inline]
    pub fn binary(
        &self,
        rhs: &Trace<T>,
        fxy: impl Fn(T, T) -> T,
        dfxy_dx: impl Fn(T, T) -> T,
        dfxy_dy: impl Fn(T, T) -> T,
    ) -> Trace<T> {
        Trace {
            number: fxy(self.number.clone(), rhs.number.clone()),
            #[rustfmt::skip]
            derivative: (
                ((self.derivative.clone() * dfxy_dx(self.number.clone(), rhs.number.clone()))
                + (rhs.derivative.clone() * dfxy_dy(self.number.clone(), rhs.number.clone())))
            ),
        }
    }
}

use std::cell::RefCell;

/**
 * WengertLists are indexed with [`usize`].
 */
pub type Index = usize;

/**
 * A list of operations performed in a forward pass of a dynamic computational graph,
 * used for Reverse Mode Automatic Differentiation.
 *
 * This is dynamic, as in, you build the [Wengert list](https://en.wikipedia.org/wiki/Automatic_differentiation#Reverse_accumulation)
 * at runtime by performing operations like addition and multiplication on
 * [Records](Record) that were created with that Wengert list.
 *
 * When you perform a backward pass to obtain the gradients you travel back up the
 * computational graph using the stored intermediate values from this list to compute
 * all the gradients of the inputs and every intermediate step with respect to an output.
 *
 * Although sophisticated implementations can make the Wengert list only log(N) in length
 * by storing only some of the intermediate steps of N computational steps, this implementation
 * is not as sophisticated, and will store all of them.
 *
 * # Panics
 *
 * Every operation and nearly every method a Record has involves manipulating the
 * record's history on its referenced WengertList. This WengertList itself maintains
 * a [RefCell] which tracks borrows at runtime rather than compile time. This is necessary to
 * maintain the illusion that Records are just ordinary numbers, and the side effects of doing
 * arithmetic with Records are limited to their referenced WengertList. Hence, the Rust
 * compiler correctly infers that it is not safe to share references to WengertLists between
 * threads, nor transfer Records across threads. If you called a method on two Records that both
 * mutably borrowed from the same WengertList at once, which could be trivially done with
 * multiple threads, then the code would panic. Easy ML shouldn't allow you to do this
 * in safe Rust because each mutable borrow of the WengertList is dropped at the end of each
 * Record method call, and you can't call two methods simultaneously without threading.
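 *
 * A minimal standalone sketch of the RefCell pattern described above follows, using a
 * hypothetical `List` type in place of WengertList. Appending goes through `&self`, and each
 * mutable borrow ends when the method returns.
 *
 * ```rust
 * use std::cell::RefCell;
 *
 * // A hypothetical list that, like WengertList, is mutated through shared references.
 * struct List {
 *     entries: RefCell<Vec<u32>>,
 * }
 *
 * impl List {
 *     fn new() -> List {
 *         List { entries: RefCell::new(Vec::new()) }
 *     }
 *
 *     // Appends through `&self` and returns the new entry's index. The mutable borrow
 *     // is released when `entries` goes out of scope at the end of the call.
 *     fn append(&self, value: u32) -> usize {
 *         let mut entries = self.entries.borrow_mut();
 *         entries.push(value);
 *         entries.len() - 1
 *     }
 * }
 * ```
 *
 * Two shared references to the same `List` can append one after the other without issue.
 * Holding one `borrow_mut` while asking for another is the situation that panics.
 * `RefCell::try_borrow_mut` reports that case as an error, where `borrow_mut` would panic.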
 */
#[derive(Debug)]
pub struct WengertList<T> {
    // It is necessary to wrap the vec in a RefCell to allow for mutating
    // this list from immutable references held by each Record
    operations: RefCell<Vec<Operation<T>>>,
}

struct BorrowedWengertList<'a, T> {
    operations: &'a mut Vec<Operation<T>>,
}

/**
 * A binary operation to record on a WengertList. For unary operations the
 * right derivative is set to 0, and for nullary operations both derivatives
 * are set to 0.
 *
 * Each operation acts like a node in an upside down binary tree, with two parents that
 * each node was computed from. The main difference is that the numerical
 * index of those parents in the WengertList is stored, rather than any pointers.
 */
#[derive(Debug)]
struct Operation<T> {
    left_parent: Index,
    right_parent: Index,
    left_derivative: T,
    right_derivative: T,
}

/**
 * Computed derivatives of a computational graph for some output [Record] variable.
 *
 * This can be indexed by any Record used in the computational graph to get
 * the derivative with respect to that input.
 *
 * Indexing using Records not involved in the computational graph, or involved
 * in a different one, will return nonsense or index out of bounds and panic. In
 * the future this may be changed to always panic.
 */
#[derive(Debug)]
pub struct Derivatives<T> {
    derivatives: Vec<T>,
}

/**
 * Derivatives of a Cloneable type implement Clone
 */
impl<T: Clone> Clone for Derivatives<T> {
    fn clone(&self) -> Self {
        Derivatives {
            derivatives: self.derivatives.clone(),
        }
    }
}

impl<T: Clone + Primitive> Derivatives<T> {
    /**
     * Queries the derivative at the provided record as input.
     *
     * If you construct a Derivatives object for some output y,
     * and call .at(&x) on it for some input x, this returns dy/dx.
     */
    pub fn at(&self, input: &Record<T>) -> T {
        self.derivatives[input.index].clone()
    }
}

impl<'a, T: Primitive> std::ops::Index<&Record<'a, T>> for Derivatives<T> {
    type Output = T;
    /**
     * Queries the derivative at the provided record as input.
     *
     * If you construct a Derivatives object for some output y,
     * and index it with `&x` for some input x, this returns dy/dx.
     */
    fn index(&self, input: &Record<'a, T>) -> &Self::Output {
        &self.derivatives[input.index]
    }
}

impl<T> std::convert::From<Derivatives<T>> for Vec<T> {
    /**
     * Converts the Derivatives struct into a Vec of derivatives that
     * can be indexed with `usize`s. The indexes correspond to the
     * index field on Records.
     */
    fn from(derivatives: Derivatives<T>) -> Self {
        derivatives.derivatives
    }
}

/**
 * Any operation of a Cloneable type implements Clone
 */
impl<T: Clone + Primitive> Clone for Operation<T> {
    fn clone(&self) -> Self {
        Operation {
            left_parent: self.left_parent,
            right_parent: self.right_parent,
            left_derivative: self.left_derivative.clone(),
            right_derivative: self.right_derivative.clone(),
        }
    }
}

/**
 * A wrapper around a real number which records it going through the computational
 * graph. This is used to perform Reverse Mode Automatic Differentiation.
 *
 * # Panics
 *
 * Every operation and nearly every method a Record has involves manipulating the
 * record's history on its referenced [WengertList]. This WengertList itself maintains
 * a [RefCell] which tracks borrows at runtime rather than compile time. This is necessary to
 * maintain the illusion that Records are just ordinary numbers, and the side effects of doing
 * arithmetic with Records are limited to their referenced WengertList. Hence, the Rust
 * compiler infers that it is not safe to share references to WengertLists between threads,
 * nor transfer Records across threads. If you called a method on two Records that both
 * mutably borrowed from the same WengertList at once, which could be trivially done with
 * multiple threads, then the code would panic. Easy ML shouldn't allow you to do this
 * in safe Rust because each mutable borrow of the WengertList is dropped at the end of each
 * Record method call, and you can't call two methods simultaneously without threading.
 *
 * # Acknowledgments
 *
 * A [tutorial by Rufflewind](https://rufflewind.com/2016-12-30/reverse-mode-automatic-differentiation)
 * and the associated [MIT licensed](http://opensource.org/licenses/MIT)
 * [source code](https://github.com/Rufflewind/revad/blob/master/src/tape.rs) were invaluable
 * in understanding how to implement Reverse Mode Automatic Differentiation.
 */
#[derive(Debug)]
pub struct Record<'a, T: Primitive> {
    // A record consists of a number used in the forward pass, as
    // well as a WengertList of operations performed on the numbers
    // and each record needs to know which point in the history they
    // are for.
    /**
     * The real number
     */
    pub number: T,
    history: Option<&'a WengertList<T>>,
    /**
     * The index of this number in its [WengertList]. The first entry will be 0,
     * the next 1 and so on.
     *
     * In normal use cases you should not need to read this field,
     * you can index [Derivatives] directly with Records.
     */
    pub index: Index,
}

/**
 * The main set of methods for using Record types for Reverse Differentiation.
 *
 * The general steps are
 * 1. create a `WengertList`
 * 2. create variables from this list
 * 3. do operations on the variables
 * 4. call `.derivatives()` on the output you want to compute derivatives for
 * 5. index the `Derivatives` object with the input variables to get the derivatives
 * with respect to each input
 * 6. if you want to make another pass call `clear()` on the `WengertList`
 * and then call `reset()` on all of the variables to forget the gradients already
 * computed (the order of `clear` then `reset` is very important!).
 *
 * Constants can be used to save memory if you have numbers that
 * you do not need to compute the gradients with respect to.
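 *
 * A minimal standalone sketch of why `clear` must come before `reset` follows. It uses
 * hypothetical `Tape` and `Var` types that only track indexes. They are not this module's
 * WengertList and Record, but they follow the same rule that `reset` appends a fresh entry.
 *
 * ```rust
 * use std::cell::RefCell;
 *
 * // A hypothetical stand-in for a WengertList that only counts its entries.
 * struct Tape {
 *     length: RefCell<usize>,
 * }
 *
 * impl Tape {
 *     fn new() -> Tape {
 *         Tape { length: RefCell::new(0) }
 *     }
 *     // Records a new entry and returns its index.
 *     fn append(&self) -> usize {
 *         let mut length = self.length.borrow_mut();
 *         *length += 1;
 *         *length - 1
 *     }
 *     fn clear(&self) {
 *         *self.length.borrow_mut() = 0;
 *     }
 *     fn len(&self) -> usize {
 *         *self.length.borrow()
 *     }
 * }
 *
 * // A hypothetical stand-in for a Record variable that remembers its index on the tape.
 * struct Var<'a> {
 *     tape: &'a Tape,
 *     index: usize,
 * }
 *
 * impl<'a> Var<'a> {
 *     fn new(tape: &'a Tape) -> Var<'a> {
 *         Var { tape, index: tape.append() }
 *     }
 *     // Places the variable back on the tape by taking a fresh index.
 *     fn reset(&mut self) {
 *         self.index = self.tape.append();
 *     }
 *     // An index is only meaningful if it points at an entry that still exists.
 *     fn is_valid(&self) -> bool {
 *         self.index < self.tape.len()
 *     }
 * }
 * ```
 *
 * Clearing first and then resetting gives every variable a fresh index on the now empty
 * tape. Resetting first and then clearing leaves each variable holding an index to an entry
 * that no longer exists.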
 */
impl<'a, T: Numeric + Primitive> Record<'a, T> {
    /**
     * Creates an untracked Record which has no backing WengertList.
     *
     * This is provided for using constants along with Records in operations.
     *
     * For example with y = x + 4 the computation graph could be conceived as
     * a y node with parent nodes of x and 4 combined with the operation +.
     * However there is no need to record the derivatives of a constant, so
     * instead the computation graph can be conceived as a y node with a single
     * parent node of x and the unary operation of +4.
     *
     * This is also used for the type level constructors required by Numeric
     * which are also considered constants.
     */
    pub fn constant(c: T) -> Record<'a, T> {
        Record {
            number: c,
            history: None,
            index: 0,
        }
    }

    /**
     * Creates a record backed by the provided WengertList.
     *
     * The record cannot live longer than the WengertList, hence
     * the following example does not compile
     *
     * ```compile_fail
     * use easy_ml::differentiation::Record;
     * use easy_ml::differentiation::WengertList;
     * let record = {
     *     let list = WengertList::new();
     *     Record::variable(1.0, &list)
     * }; // list no longer in scope
     * ```
     *
     * You can alternatively use the [record constructor on the WengertList type](WengertList::variable()).
     */
    pub fn variable(x: T, history: &'a WengertList<T>) -> Record<'a, T> {
        Record {
            number: x,
            history: Some(history),
            index: history.append_nullary(),
        }
    }

    /**
     * Creates a record from a constant/variable directly, most likely obtained by getting a
     * tensor view of an existing [container](RecordContainer). **The inputs are not checked for
     * validity**. It is possible to pass in the wrong Wengert list here or even numbers with
     * indexes that aren’t tracked on the WengertList.
     *
     * It is recommended to use this constructor only in conjunction with manipulating an existing
     * container and not for creating new variables. Any variables created outside of
     * `Record::variable` / `RecordContainer::variables` would have to be manually added to the
     * correct Wengert list, and any arithmetic operations would also need tracking correctly.
     */
    pub fn from_existing(number: (T, Index), history: Option<&'a WengertList<T>>) -> Record<'a, T> {
        Record {
            number: number.0,
            history,
            index: number.1,
        }
    }

    /**
     * Resets this Record to place it back on its WengertList, for use
     * in performing another derivation after clearing the WengertList.
     */
    pub fn reset(&mut self) {
        match self.history {
            None => (), // noop
            Some(history) => self.index = history.append_nullary(),
        };
    }

    /**
     * A convenience helper function which takes a Record by value and
     * calls [reset](Record::reset()) on it.
     */
    pub fn do_reset(mut x: Record<T>) -> Record<T> {
        x.reset();
        x
    }

    /**
     * Gets the WengertList this Record is backed by if a variable, and [None] if a constant.
     */
    pub fn history(&self) -> Option<&'a WengertList<T>> {
        self.history
    }
}

impl<'a, T: Numeric + Primitive> Record<'a, T>
where
    for<'t> &'t T: NumericRef<T>,
{
    /**
     * Performs a backward pass up this record's WengertList from this
     * record as the output, computing all the derivatives for the inputs
     * involving this output.
     *
     * If you have N inputs x<sub>1</sub> to x<sub>N</sub>, and this output is y,
     * then this computes all the derivatives δy/δx<sub>i</sub> for i = 1 to N.
     *
     * # Panics
     *
     * Panics if the Record has no backing WengertList, ie it was created as a
     * constant.
     */
    #[track_caller]
    pub fn derivatives(&self) -> Derivatives<T> {
        match self.try_derivatives() {
            None => panic!("Record has no WengertList to find derivatives from"),
            Some(d) => d,
        }
    }

    /**
     * Performs a backward pass up this record's WengertList from this
     * record as the output, computing all the derivatives for the inputs
     * involving this output.
     *
     * If this record has no WengertList, ie it's a constant, None is returned instead.
     *
     * If you have N inputs x<sub>1</sub> to x<sub>N</sub>, and this output is y,
     * then this computes all the derivatives δy/δx<sub>i</sub> for i = 1 to N.
     */
    pub fn try_derivatives(&self) -> Option<Derivatives<T>> {
        let history = self.history?;
        let operations = history.operations.borrow();

        let mut derivatives = vec![T::zero(); operations.len()];

        // δy/δy = 1
        derivatives[self.index] = T::one();

        // Go back up the computation graph to the inputs
        for i in (0..operations.len()).rev() {
            let operation = operations[i].clone();
            let derivative = derivatives[i].clone();
            // The chain rule allows breaking up the derivative of the output y
            // with respect to the input x into many smaller derivatives that
            // are summed together.
            // δy/δx = δy/δw * δw/δx
            // δy/δx = sum for all i parents of y ( δy/δw_i * δw_i/δx )
            derivatives[operation.left_parent] = derivatives[operation.left_parent].clone()
                + derivative.clone() * operation.left_derivative;
            derivatives[operation.right_parent] = derivatives[operation.right_parent].clone()
                + derivative * operation.right_derivative;
        }

        Some(Derivatives { derivatives })
    }
}

impl<T: Primitive> WengertList<T> {
    /**
     * Creates a new empty WengertList from which Records can be constructed.
     */
    pub fn new() -> WengertList<T> {
        WengertList {
            operations: RefCell::new(Vec::new()),
        }
    }
}

impl<T: Primitive> Default for WengertList<T> {
    fn default() -> Self {
        Self::new()
    }
}
662
663impl<T> WengertList<T> {
664    /**
665     * Clears a WengertList to make it empty again. After clearing a WengertList
666     * you must reset all the Records still using that list. Then you can perform
667     * another computation and get new gradients.
668     */
669    pub fn clear(&self) {
670        self.operations.borrow_mut().clear();
671    }
672}
673
674impl<T: Numeric + Primitive> WengertList<T> {
675    /**
676     * Creates a record backed by this WengertList.
677     *
678     * You can alternatively use the [record constructor on the Record type](Record::variable()).
679     */
680    pub fn variable(&self, x: T) -> Record<T> {
681        Record {
682            number: x,
683            history: Some(self),
684            index: self.append_nullary(),
685        }
686    }
687
688    /**
689     * Adds a value to the list which does not have any parent values.
690     */
691    fn append_nullary(&self) -> Index {
692        use std::ops::DerefMut;
693        let mut borrow = self.operations.borrow_mut();
694        BorrowedWengertList::new(borrow.deref_mut()).append_nullary()
695    }
696
697    /**
698     * Adds a number of values to the list which do not have any parent values, returning
699     * the index of the first added value, the others will be contiguously afterwards.
700     *
701     * If values is 0, returns the first index that would be used but wasn't.
702     */
703    fn append_nullary_repeating(&self, values: usize) -> Index {
704        let mut operations = self.operations.borrow_mut();
705        // insert into end of list
706        let starting_index = operations.len();
707        for i in 0..values {
708            let index = starting_index + i;
709            operations.push(Operation {
710                // this index of the child is used for both indexes as these
711                // won't be needed but will always be valid (ie point to a
712                // real entry on the list)
713                left_parent: index,
714                right_parent: index,
715                // for the parents 0 is used to zero out these calculations
716                // as there are no parents
717                left_derivative: T::zero(),
718                right_derivative: T::zero(),
719            });
720        }
721        starting_index
722    }
723
724    /**
725     * Adds a value to the list which has one parent.
726     *
727     * For an output w_N which depends on one parent w_N-1
728     * the derivative cached here is δw_N / δw_N-1
729     *
730     * For example, if z = sin(x), then δz/δx = cos(x)
731     */
732    fn append_unary(&self, parent: Index, derivative: T) -> Index {
733        use std::ops::DerefMut;
734        let mut borrow = self.operations.borrow_mut();
735        BorrowedWengertList::new(borrow.deref_mut()).append_unary(parent, derivative)
736    }

    /**
     * Adds a value to the list which has two parents.
     *
     * For an output w_N which depends on two parents w_N-1
     * and w_N-2 the derivatives cached here are δw_N / δw_N-1
     * and δw_N / δw_N-2.
     *
     * For example, if z = y + x, then δz/δy = 1 and δz/δx = 1.
     * For example, if z = y * x, then δz/δy = x and δz/δx = y.
     */
    fn append_binary(
        &self,
        left_parent: Index,
        left_derivative: T,
        right_parent: Index,
        right_derivative: T,
    ) -> Index {
        use std::ops::DerefMut;
        let mut borrow = self.operations.borrow_mut();
        BorrowedWengertList::new(borrow.deref_mut()).append_binary(
            left_parent,
            left_derivative,
            right_parent,
            right_derivative,
        )
    }

    /**
     * Borrows the WengertList mutably for batch operations. It is *very* important to
     * hold onto the borrow only for as long as needed and then drop it immediately. To avoid
     * panics, Easy ML needs to ensure that 100% of method calls on the public API release their
     * borrow before they finish executing. This was previously enforced by not having any batch
     * append APIs, but they are needed for RecordContainer. Calling borrow again while still
     * holding the first borrow would trigger a panic. Holding onto the borrow after the public
     * API method has finished would likewise cause the next method that borrows the list to
     * panic. Taking `op` as a closure limits the borrow to the duration of this call.
     */
    fn borrow<F>(&self, op: F)
    where
        F: FnOnce(&mut BorrowedWengertList<T>),
    {
        use std::ops::DerefMut;
        let mut borrow = self.operations.borrow_mut();
        op(&mut BorrowedWengertList::new(borrow.deref_mut()));
    }
}
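
// A standalone sketch (not part of the easy_ml API) of the borrow discipline that the
// `borrow` doc comment above describes. While a `RefCell` mutable borrow is held, a second
// `borrow_mut` would panic, and `try_borrow_mut` reports this as an `Err` instead. Scoping
// the borrow inside a closure releases it before the helper returns. The helper `with_ops`
// is hypothetical and only mimics the shape of `WengertList::borrow`.
#[test]
fn sketch_refcell_borrow_is_released_after_closure() {
    use std::cell::RefCell;

    // Hypothetical stand-in for `WengertList::borrow`: the `RefMut` only lives for this call.
    fn with_ops<F: FnOnce(&mut Vec<u32>)>(ops: &RefCell<Vec<u32>>, op: F) {
        let mut borrow = ops.borrow_mut();
        op(&mut borrow);
    }

    let ops = RefCell::new(Vec::new());
    with_ops(&ops, |v| {
        v.push(1);
        v.push(2);
    });
    // The closure's borrow has ended, so borrowing mutably again succeeds.
    assert!(ops.try_borrow_mut().is_ok());

    // While a borrow is held, any further mutable borrow fails (`borrow_mut` would panic).
    let held = ops.borrow_mut();
    assert!(ops.try_borrow_mut().is_err());
    drop(held);
    assert_eq!(ops.borrow().len(), 2);
}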

/**
 * Any WengertList of a Clone type implements Clone. The clone holds an independent
 * copy of the recorded operations.
 */
impl<T: Clone + Primitive> Clone for WengertList<T> {
    fn clone(&self) -> Self {
        WengertList {
            operations: RefCell::new(self.operations.borrow().clone()),
        }
    }
}
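
// A standalone sketch of what the `Clone` impl above implies. Cloning the `Vec` inside the
// `RefCell` produces an independent snapshot, so operations appended to the clone afterwards
// are not seen by the original. `Snapshot` is a hypothetical stand-in for `WengertList`
// that stores plain `f64` values, not easy_ml's type.
#[test]
fn sketch_cloned_list_is_an_independent_snapshot() {
    use std::cell::RefCell;

    struct Snapshot {
        operations: RefCell<Vec<f64>>,
    }

    impl Clone for Snapshot {
        fn clone(&self) -> Self {
            // Same shape as the impl above: borrow, clone the Vec, wrap it in a new RefCell.
            Snapshot {
                operations: RefCell::new(self.operations.borrow().clone()),
            }
        }
    }

    let original = Snapshot {
        operations: RefCell::new(vec![0.0, 1.0]),
    };
    let copy = original.clone();
    copy.operations.borrow_mut().push(2.0);
    // Only the clone sees the appended value.
    assert_eq!(original.operations.borrow().len(), 2);
    assert_eq!(copy.operations.borrow().len(), 3);
}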

/**
 * Methods for appending Operations after borrowing the Wengert list.
 */
impl<'a, T: Numeric + Primitive> BorrowedWengertList<'a, T> {
    fn new(operations: &mut Vec<Operation<T>>) -> BorrowedWengertList<T> {
        BorrowedWengertList { operations }
    }

    /**
     * Adds a value to the list which does not have any parent values.
     */
    fn append_nullary(&mut self) -> Index {
        // insert into end of list
        let index = self.operations.len();
        self.operations.push(Operation {
            // this index of the child is used for both indexes as these
            // won't be needed but will always be valid (ie point to a
            // real entry on the list)
            left_parent: index,
            right_parent: index,
            // for the parents 0 is used to zero out these calculations
            // as there are no parents
            left_derivative: T::zero(),
            right_derivative: T::zero(),
        });
        index
    }

    /**
     * Adds a value to the list which has one parent.
     *
     * For an output w_N which depends on one parent w_N-1
     * the derivative cached here is δw_N / δw_N-1.
     *
     * For example, if z = sin(x), then δz/δx = cos(x).
     */
    fn append_unary(&mut self, parent: Index, derivative: T) -> Index {
        // insert into end of list
        let index = self.operations.len();
        self.operations.push(Operation {
            left_parent: parent,
            // this index of the child is used as this index won't be needed
            // but will always be valid (ie points to a real entry on the list)
            right_parent: index,
            left_derivative: derivative,
            // for the right parent 0 is used to zero out this calculation
            // as there is no right parent
            right_derivative: T::zero(),
        });
        index
    }

    /**
     * Adds a value to the list which has two parents.
     *
     * For an output w_N which depends on two parents w_N-1
     * and w_N-2 the derivatives cached here are δw_N / δw_N-1
     * and δw_N / δw_N-2.
     *
     * For example, if z = y + x, then δz/δy = 1 and δz/δx = 1.
     * For example, if z = y * x, then δz/δy = x and δz/δx = y.
     */
    fn append_binary(
        &mut self,
        left_parent: Index,
        left_derivative: T,
        right_parent: Index,
        right_derivative: T,
    ) -> Index {
        // insert into end of list
        let index = self.operations.len();
        self.operations.push(Operation {
            left_parent,
            right_parent,
            left_derivative,
            right_derivative,
        });
        index
    }
}
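
// A standalone sketch of how entries appended in the layout above can be consumed by a
// reverse sweep. The types `ToyOp` and `ToyList` are hypothetical and are not easy_ml's.
// Nullary entries point both parents at themselves with zero derivatives. Unary entries
// point the unused right parent at themselves with a zero derivative. A sweep can therefore
// treat every entry as binary, because the unused terms always add zero. The sweep itself is
// an assumption written for illustration, not a copy of easy_ml's implementation.
#[test]
fn sketch_reverse_sweep_over_toy_wengert_list() {
    #[derive(Clone, Copy)]
    struct ToyOp {
        left_parent: usize,
        right_parent: usize,
        left_derivative: f64,
        right_derivative: f64,
    }

    struct ToyList {
        ops: Vec<ToyOp>,
    }

    impl ToyList {
        // Appends an entry and returns its index (used directly for binary entries).
        fn push(
            &mut self,
            left_parent: usize,
            left_derivative: f64,
            right_parent: usize,
            right_derivative: f64,
        ) -> usize {
            self.ops.push(ToyOp {
                left_parent,
                right_parent,
                left_derivative,
                right_derivative,
            });
            self.ops.len() - 1
        }

        // Mirrors append_nullary: both parents are this entry and both derivatives are zero.
        fn nullary(&mut self) -> usize {
            let index = self.ops.len();
            self.push(index, 0.0, index, 0.0)
        }

        // Mirrors append_nullary_repeating: returns the first of `values` contiguous indices.
        fn nullary_repeating(&mut self, values: usize) -> usize {
            let start = self.ops.len();
            for _ in 0..values {
                self.nullary();
            }
            start
        }

        // Mirrors append_unary: the unused right parent is this entry with a zero derivative.
        fn unary(&mut self, parent: usize, derivative: f64) -> usize {
            let index = self.ops.len();
            self.push(parent, derivative, index, 0.0)
        }

        // Derivatives of the entry at `output` with respect to every earlier entry,
        // accumulated by walking the list backwards and applying the chain rule.
        fn derivatives(&self, output: usize) -> Vec<f64> {
            let mut adjoints = vec![0.0; self.ops.len()];
            adjoints[output] = 1.0;
            for i in (0..=output).rev() {
                let op = self.ops[i];
                let adjoint = adjoints[i];
                adjoints[op.left_parent] += adjoint * op.left_derivative;
                adjoints[op.right_parent] += adjoint * op.right_derivative;
            }
            adjoints
        }
    }

    let (x_value, y_value) = (2.0_f64, 3.0_f64);
    let mut list = ToyList { ops: Vec::new() };
    // x and y are inputs, so they get contiguous nullary entries
    let x = list.nullary_repeating(2);
    let y = x + 1;
    // z = y * x, so dz/dy = x and dz/dx = y (a binary entry stores both partials)
    let z_value = y_value * x_value;
    let z = list.push(y, x_value, x, y_value);
    // w = sin(z), so dw/dz = cos(z)
    let w = list.unary(z, z_value.cos());
    let d = list.derivatives(w);
    // Chain rule: dw/dx = cos(y * x) * y and dw/dy = cos(y * x) * x
    assert!((d[x] - z_value.cos() * y_value).abs() < 1e-12);
    assert!((d[y] - z_value.cos() * x_value).abs() < 1e-12);
}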

impl<'a, T: Numeric + Primitive> Record<'a, T>
where
    for<'t> &'t T: NumericRef<T>,
{
    /**
     * Creates a new Record from a reference to an existing Record by applying
     * some unary function to it which operates on the type the Record wraps.
     *
     * To compute the new record, the unary function of some input x to some
     * output y is needed along with its derivative with respect to its input x.
     *
     * If this Record is a constant (it has no history), only the unary function is
     * evaluated, nothing is added to the WengertList and the result is also a constant.
     *
     * For example, tanh is a commonly used activation function, but the Real trait
     * does not include this operation and Record has no operations for it specifically.
     * However, you can use this function to compute the tanh of a Record like so:
     *
     * ```
     * use easy_ml::differentiation::{Record, WengertList};
     * let list = WengertList::new();
     * let x = Record::variable(0.7f32, &list);
     * // the derivative of tanh(x) is sech(x) * sech(x) which is equivalent to
     * // 1 / (cosh(x) * cosh(x))
     * let y = x.unary(|x| x.tanh(), |x| 1.0 / (x.cosh() * x.cosh()));
     * assert_eq!(y.derivatives()[&x], 1.0f32 / (0.7f32.cosh() * 0.7f32.cosh()));
     * ```
     */
    #[inline]
    pub fn unary(&self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) -> Record<T> {
        match self.history {
            None => Record {
                number: fx(self.number.clone()),
                history: None,
                index: 0,
            },
            Some(history) => Record {
                number: fx(self.number.clone()),
                history: Some(history),
                index: history.append_unary(self.index, dfx_dx(self.number.clone())),
            },
        }
    }

    /**
     * Creates a new Record from a reference to two existing Records by applying
     * some binary function to them which operates on two arguments of the type
     * the Records wrap.
     *
     * To compute the new record, the binary function of some inputs x and y to some
     * output z is needed along with its derivative with respect to its first input x and
     * its derivative with respect to its second input y.
     *
     * If neither Record has a history, the result is a constant and nothing is added to
     * the WengertList. If only one of them has a history, only the derivative with respect
     * to that Record is recorded. Panics if the Records are not using the same WengertList.
     *
     * For example, atan2 takes two arguments, but the Real trait
     * does not include this operation and Record has no operations for it specifically.
     * However, you can use this function to compute the atan2 of two Records like so:
     *
     * ```
     * use easy_ml::differentiation::{Record, WengertList};
     * let list = WengertList::new();
     * let x = Record::variable(3.0f32, &list);
     * let y = Record::variable(3.0f32, &list);
     * // the derivative of atan2 with respect to x is y/(x*x + y*y)
     * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdx
     * // the derivative of atan2 with respect to y is -x/(x*x + y*y)
     * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdy
     * let z = x.binary(&y,
     *     |x, y| x.atan2(y),
     *     |x, y| y/((x*x) + (y*y)),
     *     |x, y| -x/((x*x) + (y*y))
     * );
     * let derivatives = z.derivatives();
     * let dx = derivatives[&x];
     * let dy = derivatives[&y];
     * ```
     */
    #[inline]
    #[track_caller]
    pub fn binary(
        &self,
        rhs: &Record<'a, T>,
        fxy: impl Fn(T, T) -> T,
        dfxy_dx: impl Fn(T, T) -> T,
        dfxy_dy: impl Fn(T, T) -> T,
    ) -> Record<T> {
        assert!(
            record_operations::same_list(self, rhs),
            "Records must be using the same WengertList"
        );
        match (self.history, rhs.history) {
            (None, None) => Record {
                number: fxy(self.number.clone(), rhs.number.clone()),
                history: None,
                index: 0,
            },
            (Some(history), None) => Record {
                number: fxy(self.number.clone(), rhs.number.clone()),
                history: Some(history),
                index: history.append_unary(
                    // if rhs didn't have a history, don't track that derivative
                    self.index,
                    dfxy_dx(self.number.clone(), rhs.number.clone()),
                ),
            },
            (None, Some(history)) => Record {
                number: fxy(self.number.clone(), rhs.number.clone()),
                history: Some(history),
                index: history.append_unary(
                    // if self didn't have a history, don't track that derivative
                    rhs.index,
                    dfxy_dy(self.number.clone(), rhs.number.clone()),
                ),
            },
            (Some(history), Some(_)) => Record {
                number: fxy(self.number.clone(), rhs.number.clone()),
                history: Some(history),
                index: history.append_binary(
                    self.index,
                    dfxy_dx(self.number.clone(), rhs.number.clone()),
                    rhs.index,
                    dfxy_dy(self.number.clone(), rhs.number.clone()),
                ),
            },
        }
    }
}
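
// A standalone sketch (not part of the easy_ml API) of one way to sanity check the
// derivative closures passed to `unary` and `binary` before using them. Each closure is
// compared with a central finite difference of the function itself. The helper name, step
// size and tolerance are illustrative assumptions. Rust's `a.atan2(b)` treats `a` as the
// y coordinate and `b` as the x coordinate, so for `x.atan2(y)` the partial derivatives are
// y / (x^2 + y^2) and -x / (x^2 + y^2), matching the closures in the `binary` example.
#[test]
fn sketch_check_derivative_closures_with_finite_differences() {
    // Central difference (f(x + h) - f(x - h)) / 2h, with truncation error O(h^2)
    fn central_difference(f: impl Fn(f64) -> f64, at: f64, h: f64) -> f64 {
        (f(at + h) - f(at - h)) / (2.0 * h)
    }

    let h: f64 = 1e-5;
    let tolerance: f64 = 1e-6;

    // unary: tanh with the derivative closure used in the `unary` example
    let dfx_dx = |x: f64| 1.0 / (x.cosh() * x.cosh());
    let x: f64 = 0.7;
    let numeric = central_difference(|a| a.tanh(), x, h);
    assert!((numeric - dfx_dx(x)).abs() < tolerance);

    // binary: atan2 with the partial derivative closures used in the `binary` example
    let dfxy_dx = |x: f64, y: f64| y / ((x * x) + (y * y));
    let dfxy_dy = |x: f64, y: f64| -x / ((x * x) + (y * y));
    let (x, y): (f64, f64) = (3.0, 2.0);
    let numeric_dx = central_difference(|a| a.atan2(y), x, h);
    let numeric_dy = central_difference(|b| x.atan2(b), y, h);
    assert!((numeric_dx - dfxy_dx(x, y)).abs() < tolerance);
    assert!((numeric_dy - dfxy_dy(x, y)).abs() < tolerance);
}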

#[cfg(test)]
#[should_panic]
#[test]
fn test_record_derivatives_when_no_history() {
    let record = Record::constant(1.0);
    record.derivatives();
}

#[test]
fn test_sync() {
    fn assert_sync<T: Sync>() {}
    assert_sync::<Trace<f64>>();
}

#[test]
fn test_send() {
    fn assert_send<T: Send>() {}
    assert_send::<Trace<f64>>();
}