easy_ml/differentiation.rs
#![allow(clippy::double_parens)]
/*!
 * (Automatic) Differentiation helpers
 *
 * # Automatic Differentiation
 *
 * This module provides structs for performing Forward and Reverse Automatic Differentiation.
 *
 * ## Automatic Differentiation is not [Numerical Differentiation](https://en.wikipedia.org/wiki/Numerical_differentiation)
 *
 * You were probably introduced to differentiation as numerical differentiation,
 * ie if you have a function 3x<sup>2</sup> then you can estimate its gradient
 * at some value x by computing 3x<sup>2</sup> and 3(x+h)<sup>2</sup> where h
 * is a very small value. The tangent line these two points create gives you an approximation
 * of the gradient when you calculate (f(x+h) - f(x)) / h. Unfortunately floating
 * point numbers in computers have limited precision, so this method is only approximate
 * and can result in floating point errors. 1 + 1 might equal 2, but as you go smaller
 * 1 + 10<sup>-i</sup> starts to look rather like 1 as i goes
 * into double digits.
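 *
 * A minimal sketch of this precision problem in plain `f64` arithmetic (the function and
 * step sizes are just illustrative):
 *
 * ```
 * let f = |x: f64| 3.0 * x * x;
 * let x = 2.0;
 * // the true gradient at x = 2 is 6x = 12
 * let good = (f(x + 1e-5) - f(x)) / 1e-5;
 * assert!((good - 12.0).abs() < 1e-3);
 * // but make h too small and x + h rounds to exactly x,
 * // so the estimated gradient collapses to 0
 * let bad = (f(x + 1e-16) - f(x)) / 1e-16;
 * assert_eq!(bad, 0.0);
 * ```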
 *
 * ## Automatic Differentiation is not Symbolic Differentiation
 *
 * If you were taught calculus you have probably done plenty of symbolic differentiation
 * by hand. A function 3x<sup>2</sup> can be symbolically differentiated into 6x by applying
 * simple rules to manipulate the algebra. Unfortunately the rules aren't so simple for
 * more complex expressions such as [exponents](https://www.wolframalpha.com/input/?i=d%28x%5Ee%5E2%29%2Fdx),
 * [logs](https://www.wolframalpha.com/input/?i=d%28log%28log%28x%29%29%29%2Fdx) or
 * [trigonometry](https://www.wolframalpha.com/input/?i=d%28sin%28cos%28x%29%29%29%2Fdx).
 * Symbolic differentiation can give you expressions which are just as or more complicated
 * than the original, and doing it by hand can be error prone. Symbolic Differentiation is
 * also tricky to relate to algorithmic computations that use control structures.
 *
 * ## [Automatic Differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)
 *
 * Automatic Differentiation computes the derivative of a function without rewriting
 * the function as symbolic differentiation does and without the precision issues of numerical
 * differentiation by splitting the derivative into lots of tiny steps of basic operations
 * like addition and multiplication. These are combined using the chain rule. The downside
 * is more memory is used than in symbolic or numerical differentiation, as derivatives have
 * to be tracked through the computational graph.
 *
 * # Forward Differentiation
 *
 * Forward Differentiation computes all the gradients in a computational graph with respect
 * to an input. For example, if you have a function f(x, y) = 5x<sup>3</sup> - 4x<sup>2</sup> +
 * 10x - y, then for some actual value of x and y you can compute f(x,y) and δf(x,y)/δx
 * together in one forward pass using forward differentiation. You can also make another pass
 * and compute f(x,y) and δf(x,y)/δy for some actual value of x and y. It is possible to avoid
 * redundantly calculating f(x,y) multiple times, but I am waiting on const generics to implement
 * this. Regardless, forward differentiation requires making at least N+1 passes of the
 * function to compute the derivatives of the output with respect to N inputs - and the current
 * implementation will make 2N. However, you do get the gradients for every output in a
 * single pass. This is poorly suited to neural nets as they often have a single output (loss)
 * to differentiate very many inputs with respect to.
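 *
 * As a minimal sketch of those two passes, here is f(x, y) = 5x<sup>3</sup> - 4x<sup>2</sup> + 10x - y
 * written with the [Trace] type below (relying on the operator overloads from
 * [trace_operations]; the values chosen are just illustrative):
 *
 * ```
 * use easy_ml::differentiation::Trace;
 * fn f(x: Trace<f32>, y: Trace<f32>) -> Trace<f32> {
 *     (Trace::constant(5.0) * x * x * x) - (Trace::constant(4.0) * x * x)
 *         + (Trace::constant(10.0) * x) - y
 * }
 * // one pass per input: x is the variable in the first pass, y in the second
 * let df_dx = f(Trace::variable(2.0), Trace::constant(3.0)).derivative;
 * let df_dy = f(Trace::constant(2.0), Trace::variable(3.0)).derivative;
 * assert_eq!(df_dx, 54.0); // 15x² - 8x + 10 at x = 2
 * assert_eq!(df_dy, -1.0);
 * ```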
 *
 * # Reverse Mode Differentiation
 *
 * Reverse Mode Differentiation computes all the gradients in a computational graph for
 * the same output. For example, if you have a function f(x, y) = 5x<sup>3</sup> -
 * 4x<sup>2</sup> + 10x - y, then for some actual value of x and y you can compute f(x,y)
 * and store all the intermediate results. You can then run a backward pass on the output
 * of f(x, y) and obtain δf(x,y)/δx and δf(x,y)/δy for the actual values of x and y in a
 * single pass. The catch is that reverse mode must store as many intermediate values as
 * there are steps in the function, which can use much more memory than forward mode.
 * Reverse mode also requires making N backward passes to get the gradients for N different
 * outputs. This is well suited to neural nets because we often have a single output (loss)
 * to differentiate many inputs with respect to. However, reverse mode will be slower than
 * forward mode if the number of inputs is small or there are many outputs.
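 *
 * A minimal sketch of the same function in reverse mode, using the [Record] and
 * [WengertList] types below (relying on the operator overloads from [record_operations];
 * the values chosen are just illustrative). One forward pass records the graph, then a
 * single backward pass yields both gradients:
 *
 * ```
 * use easy_ml::differentiation::{Record, WengertList};
 * let list = WengertList::new();
 * let x = Record::variable(2.0f32, &list);
 * let y = Record::variable(3.0f32, &list);
 * // f(x, y) = 5x³ - 4x² + 10x - y
 * let f = (Record::constant(5.0) * x * x * x) - (Record::constant(4.0) * x * x)
 *     + (Record::constant(10.0) * x) - y;
 * let derivatives = f.derivatives(); // single backward pass
 * assert_eq!(derivatives[&x], 54.0); // δf/δx = 15x² - 8x + 10
 * assert_eq!(derivatives[&y], -1.0); // δf/δy = -1
 * ```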
 *
 * # Usage
 *
 * [See sub module for usage examples](usage)
 *
 * # Further information
 *
 * - [Automatic Differentiation Step by Step](https://medium.com/@marksaroufim/automatic-differentiation-step-by-step-24240f97a6e6)
 * - [Forward Mode Automatic Differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation#Automatic_differentiation_using_dual_numbers)
 * - [Reverse Mode Automatic Differentiation](https://rufflewind.com/2016-12-30/reverse-mode-automatic-differentiation)
 * - [Automatic Differentiation: The most criminally underused tool in the potential machine learning toolbox?](https://justindomke.wordpress.com/2009/02/17/automatic-differentiation-the-most-criminally-underused-tool-in-the-potential-machine-learning-toolbox/)
 * - [Yes you should understand backprop](https://medium.com/@karpathy/yes-you-should-understand-backprop-e2f06eab496b)
 */

mod container_record;
mod functions;
pub mod operations;
pub mod record_operations;
pub mod trace_operations;
pub mod usage;

pub use container_record::*;

use crate::numeric::{Numeric, NumericRef};

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

/**
 * A trait with no methods which is implemented for all primitive types.
 *
 * Importantly this trait is not implemented for Traces (or Records), to stop the compiler
 * from trying to evaluate nested Traces of Traces or Records of Records as Numeric types.
 * There is no reason to create a Trace of a Trace or Record of a Record, it won't do
 * anything a Trace or Record can't except use more memory.
 *
 * The boilerplate implementations for primitives are generated with a macro.
 * If a primitive type is missing from this list, please open an issue to add it in.
 */
pub trait Primitive {}

/**
 * A dual number which traces a real number and keeps track of its derivative.
 * This is used to perform Forward Automatic Differentiation.
 *
 * Trace implements only first order differentiation. For example, given a function
 * 3x<sup>2</sup>, you can use calculus to work out that its derivative with respect
 * to x is 6x. You can also take the derivative of 6x with respect to x and work out
 * that the second derivative is 6. By instead writing the function 3x<sup>2</sup> in
 * code using Trace types as your numbers you can compute the first order derivative
 * for a given value of x by passing your function `Trace { number: x, derivative: 1.0 }`.
 *
 * ```
 * use easy_ml::differentiation::Trace;
 * let x = Trace { number: 3.2, derivative: 1.0 };
 * let y = Trace::constant(3.0) * x * x;
 * assert_eq!(y.derivative, 3.2 * 6.0);
 * ```
 *
 * Why one for the starting derivative? Because δx/δx = 1, as with symbolic
 * differentiation.
 *
 * # Acknowledgments
 *
 * The wikipedia page on [Automatic Differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)
 * provided a very useful overview and explanation for understanding Forward Mode Automatic
 * Differentiation as well as the implementation rules.
 */
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Trace<T: Primitive> {
    /**
     * The real number
     */
    pub number: T,
    /**
     * The first order derivative of this number.
     */
    pub derivative: T,
}

/**
 * The main set of methods for using Trace types for Forward Differentiation.
 *
 * The general steps are
 * 1. create one variable
 * 2. create as many constants as needed
 * 3. do operations on the variable and constants
 * 4. the outputs will have derivatives computed which can be accessed from
 *    the `.derivative` field, with each derivative being the derivative of that
 *    output with respect to the input variable
 * 5. if you need derivatives for a different input then do everything all over again
 *    or do them all in parallel
 */
impl<T: Numeric + Primitive> Trace<T> {
    /**
     * Constants are lifted to Traces with a derivative of 0
     *
     * Why zero for the starting derivative? Because for any constant C
     * δC/δx = 0, as with symbolic differentiation.
     */
    pub fn constant(c: T) -> Trace<T> {
        Trace {
            number: c,
            derivative: T::zero(),
        }
    }

    /**
     * Lifts a value to a variable that you want to find the derivative of
     * a function with respect to; the Trace starts with a derivative of 1
     *
     * Why one for the starting derivative? Because δx/δx = 1, as with symbolic
     * differentiation.
     */
    pub fn variable(x: T) -> Trace<T> {
        Trace {
            number: x,
            derivative: T::one(),
        }
    }

    /**
     * Computes the derivative of a function with respect to its input x.
     *
     * This is a shorthand for `(function(Trace::variable(x))).derivative`
     *
     * In the more general case, if you provide a function with an input x
     * and it returns N outputs y<sub>1</sub> to y<sub>N</sub> then you
     * have computed all the derivatives δy<sub>i</sub>/δx for i = 1 to N.
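     *
     * For example, a minimal sketch differentiating 3x<sup>2</sup> at x = 2:
     *
     * ```
     * use easy_ml::differentiation::Trace;
     * let derivative = Trace::derivative(|x| Trace::constant(3.0) * x * x, 2.0);
     * assert_eq!(derivative, 12.0); // 6x at x = 2
     * ```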
     */
    pub fn derivative(function: impl FnOnce(Trace<T>) -> Trace<T>, x: T) -> T {
        (function(Trace::variable(x))).derivative
    }
}

impl<T: Numeric + Primitive> Trace<T>
where
    for<'a> &'a T: NumericRef<T>,
{
    /**
     * Creates a new Trace from a reference to an existing Trace by applying
     * some unary function to it which operates on the type the Trace wraps.
     *
     * To compute the new trace, the unary function of some input x to some
     * output y is needed along with its derivative with respect to its input x.
     *
     * For example, tanh is a commonly used activation function, but the Real trait
     * does not include this operation and Trace has no operations for it specifically.
     * However, you can use this function to compute the tanh of a Trace like so:
     *
     * ```
     * use easy_ml::differentiation::Trace;
     * let x = Trace::variable(0.7f32);
     * // the derivative of tanh(x) is sech(x) * sech(x) which is equivalent to
     * // 1 / (cosh(x) * cosh(x))
     * let y = x.unary(|x| x.tanh(), |x| 1.0 / (x.cosh() * x.cosh()));
     * assert_eq!(y.derivative, 1.0f32 / (0.7f32.cosh() * 0.7f32.cosh()));
     * ```
     */
    #[inline]
    pub fn unary(&self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) -> Trace<T> {
        Trace {
            number: fx(self.number.clone()),
            derivative: self.derivative.clone() * dfx_dx(self.number.clone()),
        }
    }

    /**
     * Creates a new Trace from references to two existing Traces by applying
     * some binary function to them which operates on two arguments of the type
     * the Traces wrap.
     *
     * To compute the new trace, the binary function of some inputs x and y to some
     * output z is needed along with its derivative with respect to its first input x and
     * its derivative with respect to its second input y.
     *
     * For example, atan2 takes two arguments, but the Real trait
     * does not include this operation and Trace has no operations for it specifically.
     * However, you can use this function to compute the atan2 of two Traces like so:
     *
     * ```
     * use easy_ml::differentiation::Trace;
     * let x = Trace::variable(3.0f32);
     * let y = Trace::variable(3.0f32);
     * // the derivative of atan2 with respect to x is y/(x*x + y*y)
     * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdx
     * // the derivative of atan2 with respect to y is -x/(x*x + y*y)
     * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdy
     * let z = x.binary(&y,
     *     |x, y| x.atan2(y),
     *     |x, y| y/((x*x) + (y*y)),
     *     |x, y| -x/((x*x) + (y*y))
     * );
     * ```
     */
    #[inline]
    pub fn binary(
        &self,
        rhs: &Trace<T>,
        fxy: impl Fn(T, T) -> T,
        dfxy_dx: impl Fn(T, T) -> T,
        dfxy_dy: impl Fn(T, T) -> T,
    ) -> Trace<T> {
        Trace {
            number: fxy(self.number.clone(), rhs.number.clone()),
            #[rustfmt::skip]
            derivative: (
                ((self.derivative.clone() * dfxy_dx(self.number.clone(), rhs.number.clone()))
                + (rhs.derivative.clone() * dfxy_dy(self.number.clone(), rhs.number.clone())))
            ),
        }
    }
}

use std::cell::RefCell;

/**
 * WengertLists are indexed with [`usize`].
 */
pub type Index = usize;

/**
 * A list of operations performed in a forward pass of a dynamic computational graph,
 * used for Reverse Mode Automatic Differentiation.
 *
 * This is dynamic, as in, you build the [Wengert list](https://en.wikipedia.org/wiki/Automatic_differentiation#Reverse_accumulation)
 * at runtime by performing operations like addition and multiplication on
 * [Records](Record) that were created with that Wengert list.
 *
 * When you perform a backward pass to obtain the gradients you travel back up the
 * computational graph using the stored intermediate values from this list to compute
 * all the gradients of the inputs and every intermediate step with respect to an output.
 *
 * Although sophisticated implementations can make the Wengert list only log(N) in length
 * by storing only some of the intermediate steps of N computational steps, this implementation
 * is not as sophisticated, and will store all of them.
 *
 * # Panics
 *
 * Every operation and nearly every method a Record has involves manipulating the
 * record's history on its referenced WengertList. This WengertList itself maintains
 * a [RefCell] which tracks borrows at runtime rather than compile time. This is necessary to
 * maintain the illusion that Records are just ordinary numbers, and the side effects of doing
 * arithmetic with Records are limited to their referenced WengertList. Hence, the Rust
 * compiler correctly infers that it is not safe to share references to WengertLists between
 * threads, nor transfer Records across threads. If you called a method on two Records that both
 * mutably borrowed from the same WengertList at once, which could be trivially done with
 * multiple threads, then the code would panic. Easy ML shouldn't allow you to do this
 * in safe Rust because each mutable borrow of the WengertList is dropped at the end of each
 * Record method call, and you can't call two methods simultaneously without threading.
 */
#[derive(Debug)]
pub struct WengertList<T> {
    // It is necessary to wrap the vec in a RefCell to allow for mutating
    // this list from immutable references held by each Record
    operations: RefCell<Vec<Operation<T>>>,
}

struct BorrowedWengertList<'a, T> {
    operations: &'a mut Vec<Operation<T>>,
}

/**
 * A binary operation to record on a WengertList. For unary operations the
 * right derivative is set to 0, and for nullary operations both derivatives
 * are set to 0.
 *
 * Each operation acts like a node in an upside down binary tree, with two parents that
 * each node was computed from. The main difference is that the numerical
 * index of those parents in the WengertList is stored, rather than any pointers.
 */
#[derive(Debug)]
struct Operation<T> {
    left_parent: Index,
    right_parent: Index,
    left_derivative: T,
    right_derivative: T,
}

/**
 * Computed derivatives of a computational graph for some output [Record] variable.
 *
 * This can be indexed by any Record used in the computational graph to get
 * the derivative with respect to that input.
 *
 * Indexing using Records not involved in the computational graph, or involved
 * in a different one, will return nonsense or panic with an index out of bounds. In
 * the future this may be changed to always panic.
 */
#[derive(Debug)]
pub struct Derivatives<T> {
    derivatives: Vec<T>,
}

/**
 * Any Derivatives of a Cloneable type implements Clone
 */
impl<T: Clone> Clone for Derivatives<T> {
    fn clone(&self) -> Self {
        Derivatives {
            derivatives: self.derivatives.clone(),
        }
    }
}

impl<T: Clone + Primitive> Derivatives<T> {
    /**
     * Queries the derivative at the provided record as input.
     *
     * If you construct a Derivatives object for some output y,
     * and call .at(&x) on it for some input x, this returns dy/dx.
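     *
     * For example, a small sketch with an `f64` graph:
     *
     * ```
     * use easy_ml::differentiation::{Record, WengertList};
     * let list = WengertList::new();
     * let x = Record::variable(1.5, &list);
     * let y = x.unary(|x| x * x, |x| 2.0 * x);
     * assert_eq!(y.derivatives().at(&x), 3.0); // dy/dx = 2x
     * ```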
     */
    pub fn at(&self, input: &Record<T>) -> T {
        self.derivatives[input.index].clone()
    }
}

impl<'a, T: Primitive> std::ops::Index<&Record<'a, T>> for Derivatives<T> {
    type Output = T;
    /**
     * Queries the derivative at the provided record as input.
     *
     * If you construct a Derivatives object for some output y,
     * and index it with some input x, this returns dy/dx.
     */
    fn index(&self, input: &Record<'a, T>) -> &Self::Output {
        &self.derivatives[input.index]
    }
}

impl<T> std::convert::From<Derivatives<T>> for Vec<T> {
    /**
     * Converts the Derivatives struct into a Vec of derivatives that
     * can be indexed with `usize`s. The indexes correspond to the
     * index field on Records.
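     *
     * A small sketch of the conversion, assuming an `f64` graph:
     *
     * ```
     * use easy_ml::differentiation::{Record, WengertList};
     * let list = WengertList::new();
     * let x = Record::variable(1.0, &list);
     * let y = x.unary(|x| x * x, |x| 2.0 * x);
     * let derivatives: Vec<f64> = y.derivatives().into();
     * assert_eq!(derivatives[x.index], 2.0); // dy/dx = 2x
     * ```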
     */
    fn from(derivatives: Derivatives<T>) -> Self {
        derivatives.derivatives
    }
}

/**
 * Any Operation of a Cloneable type implements Clone
 */
impl<T: Clone + Primitive> Clone for Operation<T> {
    fn clone(&self) -> Self {
        Operation {
            left_parent: self.left_parent,
            right_parent: self.right_parent,
            left_derivative: self.left_derivative.clone(),
            right_derivative: self.right_derivative.clone(),
        }
    }
}

/**
 * A wrapper around a real number which records it going through the computational
 * graph. This is used to perform Reverse Mode Automatic Differentiation.
 *
 * # Panics
 *
 * Every operation and nearly every method a Record has involves manipulating the
 * record's history on its referenced [WengertList]. This WengertList itself maintains
 * a [RefCell] which tracks borrows at runtime rather than compile time. This is necessary to
 * maintain the illusion that Records are just ordinary numbers, and the side effects of doing
 * arithmetic with Records are limited to their referenced WengertList. Hence, the Rust
 * compiler infers that it is not safe to share references to WengertLists between threads,
 * nor transfer Records across threads. If you called a method on two Records that both
 * mutably borrowed from the same WengertList at once, which could be trivially done with
 * multiple threads, then the code would panic. Easy ML shouldn't allow you to do this
 * in safe Rust because each mutable borrow of the WengertList is dropped at the end of each
 * Record method call, and you can't call two methods simultaneously without threading.
 *
 * # Acknowledgments
 *
 * A [tutorial by Rufflewind](https://rufflewind.com/2016-12-30/reverse-mode-automatic-differentiation)
 * and the associated [MIT licensed](http://opensource.org/licenses/MIT)
 * [source code](https://github.com/Rufflewind/revad/blob/master/src/tape.rs) were invaluable
 * in providing understanding on how to implement Reverse Mode Automatic Differentiation.
 */
#[derive(Debug)]
pub struct Record<'a, T: Primitive> {
    // A record consists of a number used in the forward pass, as
    // well as a WengertList of operations performed on the numbers
    // and each record needs to know which point in the history they
    // are for.
    /**
     * The real number
     */
    pub number: T,
    history: Option<&'a WengertList<T>>,
    /**
     * The index of this number in its [WengertList]. The first entry will be 0,
     * the next 1 and so on.
     *
     * In normal use cases you should not need to read this field,
     * you can index [Derivatives] directly with Records.
     */
    pub index: Index,
}

/**
 * The main set of methods for using Record types for Reverse Differentiation.
 *
 * The general steps are
 * 1. create a `WengertList`
 * 2. create variables from this list
 * 3. do operations on the variables
 * 4. from the output you want to compute derivatives for call `.derivatives()`
 * 5. index the `Derivatives` object with the input variables to get the derivatives
 *    with respect to each input
 * 6. if you want to make another pass call `clear()` on the `WengertList`
 *    and then call `reset()` on all of the variables to forget the gradients already
 *    computed (the order of `clear` then `reset` is very important!).
 *
 * Constants can be used to save memory if you have numbers that
 * you do not need to compute the gradients with respect to. These steps are
 * sketched end to end below.
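 *
 * A minimal end-to-end sketch of these steps, assuming `f64` numbers and the operator
 * overloads from [record_operations]:
 *
 * ```
 * use easy_ml::differentiation::{Record, WengertList};
 * // 1: create a WengertList
 * let list = WengertList::new();
 * // 2: create variables from this list
 * let mut x = Record::variable(2.0, &list);
 * let mut y = Record::variable(3.0, &list);
 * // 3: do operations on the variables, here z = xy + x
 * let z = (x * y) + x;
 * // 4: compute derivatives from the output
 * let derivatives = z.derivatives();
 * // 5: index the Derivatives object with the input variables
 * assert_eq!(derivatives[&x], 4.0); // δz/δx = y + 1
 * assert_eq!(derivatives[&y], 2.0); // δz/δy = x
 * // 6: clear then reset for another pass
 * list.clear();
 * x.reset();
 * y.reset();
 * ```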
 */
impl<'a, T: Numeric + Primitive> Record<'a, T> {
    /**
     * Creates an untracked Record which has no backing WengertList.
     *
     * This is provided for using constants along with Records in operations.
     *
     * For example with y = x + 4 the computation graph could be conceived as
     * a y node with parent nodes of x and 4 combined with the operation +.
     * However there is no need to record the derivatives of a constant, so
     * instead the computation graph can be conceived as a y node with a single
     * parent node of x and the unary operation of +4.
     *
     * This is also used for the type level constructors required by Numeric
     * which are also considered constants.
     */
    pub fn constant(c: T) -> Record<'a, T> {
        Record {
            number: c,
            history: None,
            index: 0,
        }
    }

    /**
     * Creates a record backed by the provided WengertList.
     *
     * The record cannot live longer than the WengertList, hence
     * the following example does not compile
     *
     * ```compile_fail
     * use easy_ml::differentiation::Record;
     * use easy_ml::differentiation::WengertList;
     * let record = {
     *     let list = WengertList::new();
     *     Record::variable(1.0, &list)
     * }; // list no longer in scope
     * ```
     *
     * You can alternatively use the [record constructor on the WengertList type](WengertList::variable()).
     */
    pub fn variable(x: T, history: &'a WengertList<T>) -> Record<'a, T> {
        Record {
            number: x,
            history: Some(history),
            index: history.append_nullary(),
        }
    }

    /**
     * Creates a record from a constant/variable directly, most likely obtained by getting a
     * tensor view of an existing [container](RecordContainer). **The inputs are not checked for
     * validity**. It is possible to pass in the wrong Wengert list here or even numbers with
     * indexes that aren't tracked on the WengertList.
     *
     * It is recommended to use this constructor only in conjunction with manipulating an existing
     * container and not for creating new variables. Any variables created outside of
     * `Record::variable` / `RecordContainer::variables` would have to be manually added to the
     * correct Wengert list, and any arithmetic operations would also need tracking correctly.
     */
    pub fn from_existing(number: (T, Index), history: Option<&'a WengertList<T>>) -> Record<'a, T> {
        Record {
            number: number.0,
            history,
            index: number.1,
        }
    }

    /**
     * Resets this Record to place it back on its WengertList, for use
     * in performing another derivation after clearing the WengertList.
     */
    pub fn reset(&mut self) {
        match self.history {
            None => (), // noop
            Some(history) => self.index = history.append_nullary(),
        };
    }

    /**
     * A convenience helper function which takes a Record by value and
     * calls [reset](Record::reset()) on it.
     */
    pub fn do_reset(mut x: Record<T>) -> Record<T> {
        x.reset();
        x
    }

    /**
     * Gets the WengertList this Record is backed by if a variable, and [None] if a constant.
     */
    pub fn history(&self) -> Option<&'a WengertList<T>> {
        self.history
    }
}

impl<'a, T: Numeric + Primitive> Record<'a, T>
where
    for<'t> &'t T: NumericRef<T>,
{
    /**
     * Performs a backward pass up this record's WengertList from this
     * record as the output, computing all the derivatives of this output
     * with respect to the inputs.
     *
     * If you have N inputs x<sub>1</sub> to x<sub>N</sub>, and this output is y,
     * then this computes all the derivatives δy/δx<sub>i</sub> for i = 1 to N.
     *
     * # Panics
     *
     * Panics if the Record has no backing WengertList, ie it was created as a
     * constant.
     */
    #[track_caller]
    pub fn derivatives(&self) -> Derivatives<T> {
        match self.try_derivatives() {
            None => panic!("Record has no WengertList to find derivatives from"),
            Some(d) => d,
        }
    }

    /**
     * Performs a backward pass up this record's WengertList from this
     * record as the output, computing all the derivatives of this output
     * with respect to the inputs.
     *
     * If this record has no WengertList, ie it's a constant, None is returned instead.
     *
     * If you have N inputs x<sub>1</sub> to x<sub>N</sub>, and this output is y,
     * then this computes all the derivatives δy/δx<sub>i</sub> for i = 1 to N.
     */
    pub fn try_derivatives(&self) -> Option<Derivatives<T>> {
        let history = self.history?;
        let operations = history.operations.borrow();

        let mut derivatives = vec![T::zero(); operations.len()];

        // δy/δy = 1
        derivatives[self.index] = T::one();

        // Go back up the computation graph to the inputs
        for i in (0..operations.len()).rev() {
            let operation = operations[i].clone();
            let derivative = derivatives[i].clone();
            // The chain rule allows breaking up the derivative of the output y
            // with respect to the input x into many smaller derivatives that
            // are summed together.
            // δy/δx = δy/δw * δw/δx
            // δy/δx = sum for all i parents of y ( δy/δw_i * δw_i/δx )
            derivatives[operation.left_parent] = derivatives[operation.left_parent].clone()
                + derivative.clone() * operation.left_derivative;
            derivatives[operation.right_parent] = derivatives[operation.right_parent].clone()
                + derivative * operation.right_derivative;
        }

        Some(Derivatives { derivatives })
    }
}

impl<T: Primitive> WengertList<T> {
    /**
     * Creates a new empty WengertList from which Records can be constructed.
     */
    pub fn new() -> WengertList<T> {
        WengertList {
            operations: RefCell::new(Vec::new()),
        }
    }
}

impl<T: Primitive> Default for WengertList<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T> WengertList<T> {
    /**
     * Clears a WengertList to make it empty again. After clearing a WengertList
     * you must reset all the Records still using that list. Then you can perform
     * another computation and get new gradients.
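     *
     * A minimal sketch of the clear-then-reset cycle, assuming an `f64` graph:
     *
     * ```
     * use easy_ml::differentiation::{Record, WengertList};
     * let list = WengertList::new();
     * let mut x = Record::variable(2.0, &list);
     * let y = x.unary(|x| x * x, |x| 2.0 * x);
     * assert_eq!(y.derivatives()[&x], 4.0);
     * // clear the list first, then reset the variables still in use
     * list.clear();
     * x.reset();
     * let z = x.unary(|x| x * x * x, |x| 3.0 * x * x);
     * assert_eq!(z.derivatives()[&x], 12.0);
     * ```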
     */
    pub fn clear(&self) {
        self.operations.borrow_mut().clear();
    }
}

impl<T: Numeric + Primitive> WengertList<T> {
    /**
     * Creates a record backed by this WengertList.
     *
     * You can alternatively use the [record constructor on the Record type](Record::variable()).
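     *
     * For example (using [Record::unary] for the arithmetic):
     *
     * ```
     * use easy_ml::differentiation::WengertList;
     * let list = WengertList::new();
     * let x = list.variable(2.0);
     * let y = x.unary(|x| x * x, |x| 2.0 * x);
     * assert_eq!(y.derivatives()[&x], 4.0);
     * ```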
     */
    pub fn variable(&self, x: T) -> Record<T> {
        Record {
            number: x,
            history: Some(self),
            index: self.append_nullary(),
        }
    }

    /**
     * Adds a value to the list which does not have any parent values.
     */
    fn append_nullary(&self) -> Index {
        use std::ops::DerefMut;
        let mut borrow = self.operations.borrow_mut();
        BorrowedWengertList::new(borrow.deref_mut()).append_nullary()
    }

    /**
     * Adds a number of values to the list which do not have any parent values, returning
     * the index of the first added value; the others will follow contiguously afterwards.
     *
     * If values is 0, returns the first index that would be used but wasn't.
     */
    fn append_nullary_repeating(&self, values: usize) -> Index {
        let mut operations = self.operations.borrow_mut();
        // insert into end of list
        let starting_index = operations.len();
        for i in 0..values {
            let index = starting_index + i;
            operations.push(Operation {
                // this index of the child is used for both indexes as these
                // won't be needed but will always be valid (ie point to a
                // real entry on the list)
                left_parent: index,
                right_parent: index,
                // for the parents 0 is used to zero out these calculations
                // as there are no parents
                left_derivative: T::zero(),
                right_derivative: T::zero(),
            });
        }
        starting_index
    }

    /**
     * Adds a value to the list which has one parent.
     *
     * For an output w_N which depends on one parent w_N-1
     * the derivative cached here is δw_N / δw_N-1
     *
     * For example, if z = sin(x), then δz/δx = cos(x)
     */
    fn append_unary(&self, parent: Index, derivative: T) -> Index {
        use std::ops::DerefMut;
        let mut borrow = self.operations.borrow_mut();
        BorrowedWengertList::new(borrow.deref_mut()).append_unary(parent, derivative)
    }

    /**
     * Adds a value to the list which has two parents.
     *
     * For an output w_N which depends on two parents w_N-1
     * and w_N-2 the derivatives cached here are δw_N / δw_N-1
     * and δw_N / δw_N-2.
     *
     * For example, if z = y + x, then δz/δy = 1 and δz/δx = 1.
     * For example, if z = y * x, then δz/δy = x and δz/δx = y.
     */
    fn append_binary(
        &self,
        left_parent: Index,
        left_derivative: T,
        right_parent: Index,
        right_derivative: T,
    ) -> Index {
        use std::ops::DerefMut;
        let mut borrow = self.operations.borrow_mut();
        BorrowedWengertList::new(borrow.deref_mut()).append_binary(
            left_parent,
            left_derivative,
            right_parent,
            right_derivative,
        )
    }

    /**
     * Borrows the WengertList mutably for batch operations. It is *very* important to
     * hold onto the borrow only for as long as needed then drop it immediately. To avoid panics
     * Easy ML needs to ensure 100% of method calls on the public API do not maintain a borrow
     * after they finish executing. This was previously enforced by not having any batch
     * append APIs, but they're needed for RecordContainer. Calling borrow again while still
     * holding the first would trigger a panic, as would holding onto the borrow after the
     * public API method is finished.
     */
    fn borrow<F>(&self, op: F)
    where
        F: FnOnce(&mut BorrowedWengertList<T>),
    {
        use std::ops::DerefMut;
        let mut borrow = self.operations.borrow_mut();
        op(&mut BorrowedWengertList::new(borrow.deref_mut()));
    }
}

/**
 * Any WengertList of a Cloneable type implements Clone
 */
impl<T: Clone + Primitive> Clone for WengertList<T> {
    fn clone(&self) -> Self {
        WengertList {
            operations: RefCell::new(self.operations.borrow().clone()),
        }
    }
}

/**
 * Methods for appending Operations after borrowing the Wengert list.
 */
impl<'a, T: Numeric + Primitive> BorrowedWengertList<'a, T> {
    fn new(operations: &mut Vec<Operation<T>>) -> BorrowedWengertList<T> {
        BorrowedWengertList { operations }
    }

    /**
     * Adds a value to the list which does not have any parent values.
     */
    fn append_nullary(&mut self) -> Index {
        // insert into end of list
        let index = self.operations.len();
        self.operations.push(Operation {
            // this index of the child is used for both indexes as these
            // won't be needed but will always be valid (ie point to a
            // real entry on the list)
            left_parent: index,
            right_parent: index,
            // for the parents 0 is used to zero out these calculations
            // as there are no parents
            left_derivative: T::zero(),
            right_derivative: T::zero(),
        });
        index
    }

    /**
     * Adds a value to the list which has one parent.
     *
     * For an output w_N which depends on one parent w_N-1
     * the derivative cached here is δw_N / δw_N-1
     *
     * For example, if z = sin(x), then δz/δx = cos(x)
     */
    fn append_unary(&mut self, parent: Index, derivative: T) -> Index {
        // insert into end of list
        let index = self.operations.len();
        self.operations.push(Operation {
            left_parent: parent,
            // this index of the child is used as this index won't be needed
            // but will always be valid (ie points to a real entry on the list)
            right_parent: index,
            left_derivative: derivative,
            // for the right parent 0 is used to zero out this calculation
            // as there is no right parent
            right_derivative: T::zero(),
        });
        index
    }

    /**
     * Adds a value to the list which has two parents.
     *
     * For an output w_N which depends on two parents w_N-1
     * and w_N-2 the derivatives cached here are δw_N / δw_N-1
     * and δw_N / δw_N-2.
     *
     * For example, if z = y + x, then δz/δy = 1 and δz/δx = 1.
     * For example, if z = y * x, then δz/δy = x and δz/δx = y.
     */
    fn append_binary(
        &mut self,
        left_parent: Index,
        left_derivative: T,
        right_parent: Index,
        right_derivative: T,
    ) -> Index {
        // insert into end of list
        let index = self.operations.len();
        self.operations.push(Operation {
            left_parent,
            right_parent,
            left_derivative,
            right_derivative,
        });
        index
    }
}

impl<'a, T: Numeric + Primitive> Record<'a, T>
where
    for<'t> &'t T: NumericRef<T>,
{
    /**
     * Creates a new Record from a reference to an existing Record by applying
     * some unary function to it which operates on the type the Record wraps.
     *
     * To compute the new record, the unary function of some input x to some
     * output y is needed along with its derivative with respect to its input x.
     *
     * For example, tanh is a commonly used activation function, but the Real trait
     * does not include this operation and Record has no operations for it specifically.
     * However, you can use this function to compute the tanh of a Record like so:
     *
     * ```
     * use easy_ml::differentiation::{Record, WengertList};
     * let list = WengertList::new();
     * let x = Record::variable(0.7f32, &list);
     * // the derivative of tanh(x) is sech(x) * sech(x) which is equivalent to
     * // 1 / (cosh(x) * cosh(x))
     * let y = x.unary(|x| x.tanh(), |x| 1.0 / (x.cosh() * x.cosh()));
     * assert_eq!(y.derivatives()[&x], 1.0f32 / (0.7f32.cosh() * 0.7f32.cosh()));
     * ```
     */
    #[inline]
    pub fn unary(&self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) -> Record<T> {
        match self.history {
            None => Record {
                number: fx(self.number.clone()),
                history: None,
                index: 0,
            },
            Some(history) => Record {
                number: fx(self.number.clone()),
                history: Some(history),
                index: history.append_unary(self.index, dfx_dx(self.number.clone())),
            },
        }
    }

    /**
     * Creates a new Record from references to two existing Records by applying
     * some binary function to them which operates on two arguments of the type
     * the Records wrap.
     *
     * To compute the new record, the binary function of some inputs x and y to some
     * output z is needed along with its derivative with respect to its first input x and
     * its derivative with respect to its second input y.
     *
     * For example, atan2 takes two arguments, but the Real trait
     * does not include this operation and Record has no operations for it specifically.
     * However, you can use this function to compute the atan2 of two Records like so:
     *
     * ```
     * use easy_ml::differentiation::{Record, WengertList};
     * let list = WengertList::new();
     * let x = Record::variable(3.0f32, &list);
     * let y = Record::variable(3.0f32, &list);
     * // the derivative of atan2 with respect to x is y/(x*x + y*y)
     * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdx
     * // the derivative of atan2 with respect to y is -x/(x*x + y*y)
     * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdy
     * let z = x.binary(&y,
     *     |x, y| x.atan2(y),
     *     |x, y| y/((x*x) + (y*y)),
     *     |x, y| -x/((x*x) + (y*y))
     * );
     * let derivatives = z.derivatives();
     * let dx = derivatives[&x];
     * let dy = derivatives[&y];
     * ```
     */
    #[inline]
    #[track_caller]
    pub fn binary(
        &self,
        rhs: &Record<'a, T>,
        fxy: impl Fn(T, T) -> T,
        dfxy_dx: impl Fn(T, T) -> T,
        dfxy_dy: impl Fn(T, T) -> T,
    ) -> Record<T> {
        assert!(
            record_operations::same_list(self, rhs),
            "Records must be using the same WengertList"
        );
        match (self.history, rhs.history) {
            (None, None) => Record {
                number: fxy(self.number.clone(), rhs.number.clone()),
                history: None,
                index: 0,
            },
            (Some(history), None) => Record {
                number: fxy(self.number.clone(), rhs.number.clone()),
                history: Some(history),
                index: history.append_unary(
                    // if rhs didn't have a history, don't track that derivative
                    self.index,
                    dfxy_dx(self.number.clone(), rhs.number.clone()),
                ),
            },
            (None, Some(history)) => Record {
                number: fxy(self.number.clone(), rhs.number.clone()),
                history: Some(history),
                index: history.append_unary(
                    // if self didn't have a history, don't track that derivative
                    rhs.index,
                    dfxy_dy(self.number.clone(), rhs.number.clone()),
                ),
            },
            (Some(history), Some(_)) => Record {
                number: fxy(self.number.clone(), rhs.number.clone()),
                history: Some(history),
                index: history.append_binary(
                    self.index,
                    dfxy_dx(self.number.clone(), rhs.number.clone()),
                    rhs.index,
                    dfxy_dy(self.number.clone(), rhs.number.clone()),
                ),
            },
        }
    }
}

#[cfg(test)]
#[should_panic]
#[test]
fn test_record_derivatives_when_no_history() {
    let record = Record::constant(1.0);
    record.derivatives();
}

#[test]
fn test_sync() {
    fn assert_sync<T: Sync>() {}
    assert_sync::<Trace<f64>>();
}

#[test]
fn test_send() {
    fn assert_send<T: Send>() {}
    assert_send::<Trace<f64>>();
}