easy_ml/differentiation.rs
#![allow(clippy::double_parens)]
/*!
 * (Automatic) Differentiation helpers
 *
 * # Automatic Differentiation
 *
 * This module provides structs for performing Forward and Reverse Automatic Differentiation.
 *
 * ## Automatic Differentiation is not [Numerical Differentiation](https://en.wikipedia.org/wiki/Numerical_differentiation)
 *
 * You were probably introduced to differentiation as numeric differentiation,
 * ie if you have a function 3x<sup>2</sup> then you can estimate its gradient
 * at some value x by computing 3x<sup>2</sup> and 3(x+h)<sup>2</sup> where h
 * is a very small value. The tangent line these two points create gives you an approximation
 * of the gradient when you calculate (f(x+h) - f(x)) / h. Unfortunately floating
 * point numbers in computers have limited precision, so this method is only approximate
 * and can result in floating point errors. 1 + 1 might equal 2, but as you go smaller
 * 10<sup>-i</sup> + 10<sup>-i</sup> starts to look rather like 10<sup>-i</sup> as i goes
 * into double digits.
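 *
 * A small sketch of that precision loss: the numerical estimate of the gradient of
 * 3x<sup>2</sup> drifts away from the exact answer 6x once h gets very small, because
 * the subtraction cancels most of the significant digits (the step sizes below are
 * illustrative choices, not anything this module prescribes):
 *
 * ```
 * let f = |x: f64| 3.0 * x * x;
 * let x = 2.0;
 * for &h in &[1e-4, 1e-8, 1e-12] {
 *     // (f(x+h) - f(x)) / h approximates f'(x) = 6x = 12
 *     let estimate = (f(x + h) - f(x)) / h;
 *     println!("h = {:e}: estimate = {}, error = {:e}", h, estimate, (estimate - 12.0).abs());
 * }
 * ```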
 *
 * ## Automatic Differentiation is not Symbolic Differentiation
 *
 * If you were taught calculus you have probably done plenty of symbolic differentiation
 * by hand. A function 3x<sup>2</sup> can be symbolically differentiated into 6x by applying
 * simple rules to manipulate the algebra. Unfortunately the rules aren't so simple for
 * more complex expressions such as [exponents](https://www.wolframalpha.com/input/?i=d%28x%5Ee%5E2%29%2Fdx),
 * [logs](https://www.wolframalpha.com/input/?i=d%28log%28log%28x%29%29%29%2Fdx) or
 * [trigonometry](https://www.wolframalpha.com/input/?i=d%28sin%28cos%28x%29%29%29%2Fdx).
 * Symbolic differentiation can give you expressions which are just as or more complicated
 * than the original, and doing it by hand can be error prone. Symbolic Differentiation is
 * also tricky to relate to algorithmic computations that use control structures.
 *
 * ## [Automatic Differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)
 *
 * Automatic Differentiation computes the derivative of a function without rewriting
 * the function as symbolic differentiation does, and without the precision issues of
 * numerical differentiation, by splitting the derivative into lots of tiny steps of
 * basic operations like addition and multiplication. These are combined using the chain
 * rule. The downside is that more memory is used than in symbolic or numerical
 * differentiation, as derivatives have to be tracked through the computational graph.
 *
 * # Forward Differentiation
 *
 * Forward Differentiation computes all the gradients in a computational graph with respect
 * to an input. For example, if you have a function f(x, y) = 5x<sup>3</sup> - 4x<sup>2</sup> +
 * 10x - y, then for some actual value of x and y you can compute f(x,y) and δf(x,y)/δx
 * together in one forward pass using forward differentiation. You can also make another pass
 * and compute f(x,y) and δf(x,y)/δy for some actual value of x and y. Forward differentiation
 * in this way requires making N passes of the function to compute the derivatives of the
 * output with respect to N inputs. However, you do get the gradients for every output in a
 * single pass. This is poorly suited to neural nets, as they often have a single output (loss)
 * to differentiate many inputs with respect to.
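 *
 * A sketch of those two passes for f(x, y) = 5x<sup>3</sup> - 4x<sup>2</sup> + 10x - y,
 * using the [Trace] type defined below and the arithmetic operator overloads this crate
 * provides for it; in each pass only the input we are differentiating with respect to is
 * lifted as a variable:
 *
 * ```
 * use easy_ml::differentiation::Trace;
 * fn f(x: Trace<f32>, y: Trace<f32>) -> Trace<f32> {
 *     (Trace::constant(5.0) * x * x * x) - (Trace::constant(4.0) * x * x)
 *         + (Trace::constant(10.0) * x) - y
 * }
 * // first pass: δf/δx = 15x² - 8x + 10, evaluated at x = 2
 * let df_dx = f(Trace::variable(2.0), Trace::constant(3.0)).derivative;
 * assert_eq!(df_dx, (15.0 * 4.0) - (8.0 * 2.0) + 10.0);
 * // second pass: δf/δy = -1
 * let df_dy = f(Trace::constant(2.0), Trace::variable(3.0)).derivative;
 * assert_eq!(df_dy, -1.0);
 * ```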
 *
 * # Reverse Mode Differentiation
 *
 * Reverse Mode Differentiation computes all the gradients in a computational graph for
 * the same output. For example, if you have a function f(x, y) = 5x<sup>3</sup> -
 * 4x<sup>2</sup> + 10x - y, then for some actual value of x and y you can compute f(x,y)
 * and store all the intermediate results. You can then run a backward pass on the output
 * of f(x, y) and obtain δf(x,y)/δx and δf(x,y)/δy for the actual values of x and y in a
 * single pass. The catch is that reverse mode must store as many intermediate values as
 * there are steps in the function, which can use much more memory than forward mode.
 * Reverse mode also requires making N backward passes to get the gradients for N different
 * outputs. This is well suited to neural nets because we often have a single output (loss)
 * to differentiate many inputs with respect to. However, reverse mode will be slower than
 * forward mode if the number of inputs is small or there are many outputs.
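 *
 * A sketch of that single backward pass for the same f(x, y), using the [Record] and
 * [WengertList] types defined below (this sketch assumes the operator overloads in
 * [record_operations], which let Records of a Copy type be used by value like
 * ordinary numbers):
 *
 * ```
 * use easy_ml::differentiation::{Record, WengertList};
 * let list = WengertList::new();
 * let x = Record::variable(2.0f32, &list);
 * let y = Record::variable(3.0f32, &list);
 * let f = (Record::constant(5.0) * x * x * x) - (Record::constant(4.0) * x * x)
 *     + (Record::constant(10.0) * x) - y;
 * // one backward pass computes both gradients
 * let derivatives = f.derivatives();
 * assert_eq!(derivatives[&x], (15.0 * 4.0) - (8.0 * 2.0) + 10.0); // δf/δx at x = 2
 * assert_eq!(derivatives[&y], -1.0); // δf/δy
 * ```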
 *
 * # Usage
 *
 * [See sub module for usage examples](usage)
 *
 * # Further information
 *
 * - [Automatic Differentiation Step by Step](https://medium.com/@marksaroufim/automatic-differentiation-step-by-step-24240f97a6e6)
 * - [Forward Mode Automatic Differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation#Automatic_differentiation_using_dual_numbers)
 * - [Reverse Mode Automatic Differentiation](https://rufflewind.com/2016-12-30/reverse-mode-automatic-differentiation)
 * - [Automatic Differentiation: The most criminally underused tool in the potential machine learning toolbox?](https://justindomke.wordpress.com/2009/02/17/automatic-differentiation-the-most-criminally-underused-tool-in-the-potential-machine-learning-toolbox/)
 * - [Yes you should understand backprop](https://medium.com/@karpathy/yes-you-should-understand-backprop-e2f06eab496b)
 */

mod container_record;
mod functions;
pub mod operations;
pub mod record_operations;
pub mod trace_operations;
pub mod usage;

pub use container_record::*;

use crate::numeric::{Numeric, NumericRef};

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

/**
 * A trait with no methods which is implemented for all primitive types.
 *
 * Importantly this trait is not implemented for Traces (or Records), to stop the compiler
 * from trying to evaluate nested Traces of Traces or Records of Records as Numeric types.
 * There is no reason to create a Trace of a Trace or Record of a Record, it won't do
 * anything a Trace or Record can't except use more memory.
 *
 * The boilerplate implementations for primitives are performed with a macro.
 * If a primitive type is missing from this list, please open an issue to add it in.
 */
pub trait Primitive {}

/**
 * A dual number which traces a real number and keeps track of its derivative.
 * This is used to perform Forward Automatic Differentiation.
 *
 * Trace implements only first order differentiation. For example, given a function
 * 3x<sup>2</sup>, you can use calculus to work out that its derivative with respect
 * to x is 6x. You can also take the derivative of 6x with respect to x and work out
 * that the second derivative is 6. By instead writing the function 3x<sup>2</sup> in
 * code using Trace types as your numbers you can compute the first order derivative
 * for a given value of x by passing `Trace { number: x, derivative: 1.0 }` into
 * your function.
 *
 * ```
 * use easy_ml::differentiation::Trace;
 * let x = Trace { number: 3.2, derivative: 1.0 };
 * let dx = Trace::constant(3.0) * x * x;
 * assert_eq!(dx.derivative, 3.2 * 6.0);
 * ```
 *
 * Why the one for the starting derivative? Because δx/δx = 1, as with symbolic
 * differentiation.
 *
 * # Acknowledgments
 *
 * The wikipedia page on [Automatic Differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)
 * provided a very useful overview and explanation for understanding Forward Mode Automatic
 * Differentiation as well as the implementation rules.
 */
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Trace<T: Primitive> {
    /**
     * The real number
     */
    pub number: T,
    /**
     * The first order derivative of this number.
     */
    // If we loosen this type from T to a tensor of T of some const-generic
    // dimensionality then we can calculate higher order derivatives with a single
    // Trace type.
    // However, Trace<Trace<f64>> can do such a calculation already for 2nd order
    // (and so on) and requires far less complexity in the API so this might not
    // be that worthwhile. Tensor<T, 1> introduces a lot of boxing that might also
    // hurt first order performance.
    pub derivative: T,
}

/**
 * The main set of methods for using Trace types for Forward Differentiation.
 *
 * The general steps are
 * 1. create one variable
 * 2. create as many constants as needed
 * 3. do operations on the variable and constants
 * 4. the outputs will have derivatives computed which can be accessed from
 *    the `.derivative` field, with each derivative being the derivative of that
 *    output with respect to the input variable
 * 5. if you need derivatives for a different input then do everything all over
 *    again, or do all the passes in parallel
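 *
 * A minimal sketch of these steps (see the [usage] module for fuller examples):
 *
 * ```
 * use easy_ml::differentiation::Trace;
 * // differentiate f(x) = x² + 1 at x = 4 with respect to x
 * let x = Trace::variable(4.0);
 * let one = Trace::constant(1.0);
 * let y = (x * x) + one;
 * assert_eq!(y.number, 17.0);
 * assert_eq!(y.derivative, 8.0); // δy/δx = 2x
 * ```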
 */
impl<T: Numeric + Primitive> Trace<T> {
    /**
     * Constants are lifted to Traces with a derivative of 0
     *
     * Why zero for the starting derivative? Because for any constant C
     * δC/δx = 0, as with symbolic differentiation.
     */
    pub fn constant(c: T) -> Trace<T> {
        Trace {
            number: c,
            derivative: T::zero(),
        }
    }

    /**
     * Variables are lifted to Traces with a derivative of 1. Use this to lift the
     * input you want to find the derivative of a function with respect to.
     *
     * Why the one for the starting derivative? Because δx/δx = 1, as with symbolic
     * differentiation.
     */
    pub fn variable(x: T) -> Trace<T> {
        Trace {
            number: x,
            derivative: T::one(),
        }
    }

    /**
     * Computes the derivative of a function with respect to its input x.
     *
     * This is a shorthand for `(function(Trace::variable(x))).derivative`
     *
     * In the more general case, if you provide a function with an input x
     * and it returns N outputs y<sub>1</sub> to y<sub>N</sub> then you
     * have computed all the derivatives δy<sub>i</sub>/δx for i = 1 to N.
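     *
     * For example, reusing the 3x<sup>2</sup> function from above:
     *
     * ```
     * use easy_ml::differentiation::Trace;
     * // δ(3x²)/δx = 6x, which is 12 at x = 2
     * let derivative = Trace::derivative(|x| Trace::constant(3.0) * x * x, 2.0);
     * assert_eq!(derivative, 12.0);
     * ```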
     */
    pub fn derivative(function: impl FnOnce(Trace<T>) -> Trace<T>, x: T) -> T {
        (function(Trace::variable(x))).derivative
    }
}

impl<T: Numeric + Primitive> Trace<T>
where
    for<'a> &'a T: NumericRef<T>,
{
    /**
     * Creates a new Trace from a reference to an existing Trace by applying
     * some unary function to it which operates on the type the Trace wraps.
     *
     * To compute the new trace, the unary function of some input x to some
     * output y is needed along with its derivative with respect to its input x.
     *
     * For example, tanh is a commonly used activation function, but the Real trait
     * does not include this operation and Trace has no operations for it specifically.
     * However, you can use this function to compute the tanh of a Trace like so:
     *
     * ```
     * use easy_ml::differentiation::Trace;
     * let x = Trace::variable(0.7f32);
     * // the derivative of tanh(x) is sech(x) * sech(x) which is equivalent to
     * // 1 / (cosh(x) * cosh(x))
     * let y = x.unary(|x| x.tanh(), |x| 1.0 / (x.cosh() * x.cosh()));
     * assert_eq!(y.derivative, 1.0f32 / (0.7f32.cosh() * 0.7f32.cosh()));
     * ```
     */
    #[inline]
    pub fn unary(&self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) -> Trace<T> {
        Trace {
            number: fx(self.number.clone()),
            derivative: self.derivative.clone() * dfx_dx(self.number.clone()),
        }
    }

    /**
     * Creates a new Trace from a reference to two existing Traces by applying
     * some binary function to them which operates on two arguments of the type
     * the Traces wrap.
     *
     * To compute the new trace, the binary function of some inputs x and y to some
     * output z is needed along with its derivative with respect to its first input x and
     * its derivative with respect to its second input y.
     *
     * For example, atan2 takes two arguments, but the Real trait
     * does not include this operation and Trace has no operations for it specifically.
     * However, you can use this function to compute the atan2 of two Traces like so:
     *
     * ```
     * use easy_ml::differentiation::Trace;
     * let x = Trace::variable(3.0f32);
     * let y = Trace::variable(3.0f32);
     * // the derivative of atan2 with respect to x is y/(x*x + y*y)
     * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdx
     * // the derivative of atan2 with respect to y is -x/(x*x + y*y)
     * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdy
     * let z = x.binary(&y,
     *     |x, y| x.atan2(y),
     *     |x, y| y/((x*x) + (y*y)),
     *     |x, y| -x/((x*x) + (y*y))
     * );
     * ```
     */
    #[inline]
    pub fn binary(
        &self,
        rhs: &Trace<T>,
        fxy: impl Fn(T, T) -> T,
        dfxy_dx: impl Fn(T, T) -> T,
        dfxy_dy: impl Fn(T, T) -> T,
    ) -> Trace<T> {
        Trace {
            number: fxy(self.number.clone(), rhs.number.clone()),
            #[rustfmt::skip]
            derivative: (
                ((self.derivative.clone() * dfxy_dx(self.number.clone(), rhs.number.clone()))
                + (rhs.derivative.clone() * dfxy_dy(self.number.clone(), rhs.number.clone())))
            ),
        }
    }
}

use std::cell::RefCell;

/**
 * WengertLists are indexed with [`usize`].
 */
pub type Index = usize;

/**
 * A list of operations performed in a forward pass of a dynamic computational graph,
 * used for Reverse Mode Automatic Differentiation.
 *
 * This is dynamic, as in, you build the [Wengert list](https://en.wikipedia.org/wiki/Automatic_differentiation#Reverse_accumulation)
 * at runtime by performing operations like addition and multiplication on
 * [Records](Record) that were created with that Wengert list.
 *
 * When you perform a backward pass to obtain the gradients you travel back up the
 * computational graph using the stored intermediate values from this list to compute
 * all the gradients of the inputs and every intermediate step with respect to an output.
 *
 * Although sophisticated implementations can make the Wengert list only log(N) in length
 * by storing only some of the intermediate steps of N computational steps, this implementation
 * is not as sophisticated, and will store all of them.
 *
 * # Panics
 *
 * Every operation and nearly every method a Record has involves manipulating the
 * record's history on its referenced WengertList. This WengertList itself maintains
 * a [RefCell] which tracks borrows at runtime rather than compile time. This is necessary to
 * maintain the illusion that Records are just ordinary numbers, and the side effects of doing
 * arithmetic with Records are limited to their referenced WengertList. Hence, the Rust
 * compiler correctly infers that it is not safe to share references to WengertLists between
 * threads, nor transfer Records across threads. If you called a method on two Records that both
 * mutably borrowed from the same WengertList at once, which could be trivially done with
 * multiple threads, then the code would panic. Easy ML shouldn't allow you to do this
 * in safe Rust because each mutable borrow of the WengertList is dropped at the end of each
 * Record method call, and you can't call two methods simultaneously without threading.
 */
#[derive(Debug)]
pub struct WengertList<T> {
    // It is necessary to wrap the vec in a RefCell to allow for mutating
    // this list from immutable references held by each Record
    operations: RefCell<Vec<Operation<T>>>,
}

struct BorrowedWengertList<'a, T> {
    operations: &'a mut Vec<Operation<T>>,
}

/**
 * A binary operation to record on a WengertList. For unary operations the
 * right derivative is set to 0, and for nullary operations both derivatives
 * are set to 0.
 *
 * Each operation acts like a node in an upside down binary tree, with two parents that
 * each node was computed from. The main difference is that the numerical
 * index of those parents in the WengertList is stored, rather than any pointers.
 */
#[derive(Debug)]
struct Operation<T> {
    left_parent: Index,
    right_parent: Index,
    left_derivative: T,
    right_derivative: T,
}

/**
 * Computed derivatives of a computational graph for some output [Record] variable.
 *
 * This can be indexed by any Record used in the computational graph to get
 * the derivative with respect to that input.
 *
 * Indexing using Records not involved in the computational graph, or involved
 * in a different one, will return nonsense or panic with an index out of bounds.
 * In the future this may be changed to always panic.
 */
#[derive(Debug)]
pub struct Derivatives<T> {
    derivatives: Vec<T>,
}

/**
 * Any Derivatives of a Cloneable type implements Clone
 */
impl<T: Clone> Clone for Derivatives<T> {
    fn clone(&self) -> Self {
        Derivatives {
            derivatives: self.derivatives.clone(),
        }
    }
}

impl<T: Clone + Primitive> Derivatives<T> {
    /**
     * Queries the derivative at the provided record as input.
     *
     * If you construct a Derivatives object for some output y,
     * and call .at(&x) on it for some input x, this returns dy/dx.
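     *
     * For example, a small sketch using [unary](Record::unary) so that it relies only on
     * items from this module:
     *
     * ```
     * use easy_ml::differentiation::{Record, WengertList};
     * let list = WengertList::new();
     * let x = Record::variable(3.0, &list);
     * // y = x², δy/δx = 2x
     * let y = x.unary(|x| x * x, |x| 2.0 * x);
     * assert_eq!(y.derivatives().at(&x), 6.0);
     * ```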
     */
    pub fn at(&self, input: &Record<T>) -> T {
        self.derivatives[input.index].clone()
    }
}

impl<'a, T: Primitive> std::ops::Index<&Record<'a, T>> for Derivatives<T> {
    type Output = T;
    /**
     * Queries the derivative at the provided record as input.
     *
     * If you construct a Derivatives object for some output y,
     * and index it with &x for some input x, this returns dy/dx.
     */
    fn index(&self, input: &Record<'a, T>) -> &Self::Output {
        &self.derivatives[input.index]
    }
}

impl<T> std::convert::From<Derivatives<T>> for Vec<T> {
    /**
     * Converts the Derivatives struct into a Vec of derivatives that
     * can be indexed with `usize`s. The indexes correspond to the
     * index field on Records.
     */
    fn from(derivatives: Derivatives<T>) -> Self {
        derivatives.derivatives
    }
}

/**
 * Any Operation of a Cloneable type implements Clone
 */
impl<T: Clone + Primitive> Clone for Operation<T> {
    fn clone(&self) -> Self {
        Operation {
            left_parent: self.left_parent,
            right_parent: self.right_parent,
            left_derivative: self.left_derivative.clone(),
            right_derivative: self.right_derivative.clone(),
        }
    }
}

/**
 * A wrapper around a real number which records it going through the computational
 * graph. This is used to perform Reverse Mode Automatic Differentiation.
 *
 * # Panics
 *
 * Every operation and nearly every method a Record has involves manipulating the
 * record's history on its referenced [WengertList]. This WengertList itself maintains
 * a [RefCell] which tracks borrows at runtime rather than compile time. This is necessary to
 * maintain the illusion that Records are just ordinary numbers, and the side effects of doing
 * arithmetic with Records are limited to their referenced WengertList. Hence, the Rust
 * compiler infers that it is not safe to share references to WengertLists between threads,
 * nor transfer Records across threads. If you called a method on two Records that both
 * mutably borrowed from the same WengertList at once, which could be trivially done with
 * multiple threads, then the code would panic. Easy ML shouldn't allow you to do this
 * in safe Rust because each mutable borrow of the WengertList is dropped at the end of each
 * Record method call, and you can't call two methods simultaneously without threading.
 *
 * # Acknowledgments
 *
 * A [tutorial by Rufflewind](https://rufflewind.com/2016-12-30/reverse-mode-automatic-differentiation)
 * and the associated [MIT licensed](http://opensource.org/licenses/MIT)
 * [source code](https://github.com/Rufflewind/revad/blob/master/src/tape.rs) were invaluable
 * in providing understanding on how to implement Reverse Mode Automatic Differentiation.
 */
#[derive(Debug)]
pub struct Record<'a, T: Primitive> {
    // A record consists of a number used in the forward pass, as
    // well as a WengertList of operations performed on the numbers
    // and each record needs to know which point in the history they
    // are for.
    /**
     * The real number
     */
    pub number: T,
    history: Option<&'a WengertList<T>>,
    /**
     * The index of this number in its [WengertList]. The first entry will be 0,
     * the next 1 and so on.
     *
     * In normal use cases you should not need to read this field,
     * you can index [Derivatives] directly with Records.
     */
    pub index: Index,
}

/**
 * The main set of methods for using Record types for Reverse Differentiation.
 *
 * The general steps are
 * 1. create a `WengertList`
 * 2. create variables from this list
 * 3. do operations on the variables
 * 4. from the output you want to compute derivatives for call `.derivatives()`
 * 5. index the `Derivatives` object with the input variables to get the derivatives
 *    with respect to each input
 * 6. if you want to make another pass call `clear()` on the `WengertList`
 *    and then call `reset()` on all of the variables to forget the gradients already
 *    computed (the order of `clear` then `reset` is very important!).
 *
 * Constants can be used to save memory if you have numbers that
 * you do not need to compute the gradients with respect to.
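 *
 * A minimal sketch of these steps, using [binary](Record::binary) directly so the
 * example relies only on items defined in this module (the operator overloads in
 * [record_operations] work the same way):
 *
 * ```
 * use easy_ml::differentiation::{Record, WengertList};
 * let list = WengertList::new();
 * let mut x = Record::variable(2.0, &list);
 * let mut y = Record::variable(3.0, &list);
 * // z = f(x, y) = xy + x, with δf/δx = y + 1 and δf/δy = x
 * let z = x.binary(&y, |x, y| x * y + x, |_x, y| y + 1.0, |x, _y| x);
 * let derivatives = z.derivatives();
 * assert_eq!(derivatives[&x], 4.0);
 * assert_eq!(derivatives[&y], 2.0);
 * // to reuse the variables in another computation: clear first, then reset
 * list.clear();
 * x.reset();
 * y.reset();
 * ```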
 */
impl<'a, T: Numeric + Primitive> Record<'a, T> {
    /**
     * Creates an untracked Record which has no backing WengertList.
     *
     * This is provided for using constants along with Records in operations.
     *
     * For example with y = x + 4 the computation graph could be conceived as
     * a y node with parent nodes of x and 4 combined with the operation +.
     * However there is no need to record the derivatives of a constant, so
     * instead the computation graph can be conceived as a y node with a single
     * parent node of x and the unary operation of +4.
     *
     * This is also used for the type level constructors required by Numeric
     * which are also considered constants.
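     *
     * A small sketch of the distinction:
     *
     * ```
     * use easy_ml::differentiation::{Record, WengertList};
     * let list = WengertList::new();
     * let x = Record::variable(1.0, &list);
     * let c = Record::constant(4.0);
     * assert!(x.history().is_some());
     * // constants are untracked, so nothing is recorded on any WengertList for them
     * assert!(c.history().is_none());
     * ```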
     */
    pub fn constant(c: T) -> Record<'a, T> {
        Record {
            number: c,
            history: None,
            index: 0,
        }
    }

    /**
     * Creates a record backed by the provided WengertList.
     *
     * The record cannot live longer than the WengertList, hence
     * the following example does not compile
     *
     * ```compile_fail
     * use easy_ml::differentiation::Record;
     * use easy_ml::differentiation::WengertList;
     * let record = {
     *     let list = WengertList::new();
     *     Record::variable(1.0, &list)
     * }; // list no longer in scope
     * ```
     *
     * You can alternatively use the [record constructor on the WengertList type](WengertList::variable()).
     */
    pub fn variable(x: T, history: &'a WengertList<T>) -> Record<'a, T> {
        Record {
            number: x,
            history: Some(history),
            index: history.append_nullary(),
        }
    }

    /**
     * Creates a record from a constant/variable directly, most likely obtained by getting a
     * tensor view of an existing [container](RecordContainer). **The inputs are not checked for
     * validity**. It is possible to pass in the wrong Wengert list here or even numbers with
     * indexes that aren’t tracked on the WengertList.
     *
     * It is recommended to use this constructor only in conjunction with manipulating an existing
     * container and not for creating new variables. Any variables created outside of
     * `Record::variable` / `RecordContainer::variables` would have to be manually added to the
     * correct Wengert list, and any arithmetic operations would also need tracking correctly.
     */
    pub fn from_existing(number: (T, Index), history: Option<&'a WengertList<T>>) -> Record<'a, T> {
        Record {
            number: number.0,
            history,
            index: number.1,
        }
    }

    /**
     * Resets this Record to place it back on its WengertList, for use
     * in performing another derivation after clearing the WengertList.
     */
    pub fn reset(&mut self) {
        match self.history {
            None => (), // noop
            Some(history) => self.index = history.append_nullary(),
        };
    }

    /**
     * A convenience helper function which takes a Record by value and
     * calls [reset](Record::reset()) on it.
     */
    pub fn do_reset(mut x: Record<T>) -> Record<T> {
        x.reset();
        x
    }

    /**
     * Gets the WengertList this Record is backed by if a variable, and [None] if a constant.
     */
    pub fn history(&self) -> Option<&'a WengertList<T>> {
        self.history
    }
}

impl<'a, T: Numeric + Primitive> Record<'a, T>
where
    for<'t> &'t T: NumericRef<T>,
{
    /**
     * Performs a backward pass up this record's WengertList from this
     * record as the output, computing all the derivatives for the inputs
     * involving this output.
     *
     * If you have N inputs x<sub>1</sub> to x<sub>N</sub>, and this output is y,
     * then this computes all the derivatives δy/δx<sub>i</sub> for i = 1 to N.
     *
     * # Panics
     *
     * Panics if the Record has no backing WengertList, ie it was created as a
     * constant.
     */
    #[track_caller]
    pub fn derivatives(&self) -> Derivatives<T> {
        match self.try_derivatives() {
            None => panic!("Record has no WengertList to find derivatives from"),
            Some(d) => d,
        }
    }

    /**
     * Performs a backward pass up this record's WengertList from this
     * record as the output, computing all the derivatives for the inputs
     * involving this output.
     *
     * If this record has no WengertList, ie it's a constant, None is returned instead.
     *
     * If you have N inputs x<sub>1</sub> to x<sub>N</sub>, and this output is y,
     * then this computes all the derivatives δy/δx<sub>i</sub> for i = 1 to N.
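     *
     * For example:
     *
     * ```
     * use easy_ml::differentiation::{Record, WengertList};
     * let list = WengertList::new();
     * let x = Record::variable(1.0, &list);
     * assert!(x.try_derivatives().is_some());
     * // a constant has no WengertList to traverse
     * assert!(Record::constant(1.0).try_derivatives().is_none());
     * ```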
     */
    pub fn try_derivatives(&self) -> Option<Derivatives<T>> {
        let history = self.history?;
        let operations = history.operations.borrow();

        let mut derivatives = vec![T::zero(); operations.len()];

        // δy/δy = 1
        derivatives[self.index] = T::one();

        // Go back up the computation graph to the inputs
        for i in (0..operations.len()).rev() {
            let operation = operations[i].clone();
            let derivative = derivatives[i].clone();
            // The chain rule allows breaking up the derivative of the output y
            // with respect to the input x into many smaller derivatives that
            // are summed together.
            // δy/δx = δy/δw * δw/δx
            // δy/δx = sum for all i parents of y ( δy/δw_i * δw_i/δx )
            derivatives[operation.left_parent] = derivatives[operation.left_parent].clone()
                + derivative.clone() * operation.left_derivative;
            derivatives[operation.right_parent] = derivatives[operation.right_parent].clone()
                + derivative * operation.right_derivative;
        }

        Some(Derivatives { derivatives })
    }
}

impl<T: Primitive> WengertList<T> {
    /**
     * Creates a new empty WengertList from which Records can be constructed.
     */
    pub fn new() -> WengertList<T> {
        WengertList {
            operations: RefCell::new(Vec::new()),
        }
    }
}

impl<T: Primitive> Default for WengertList<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T> WengertList<T> {
    /**
     * Clears a WengertList to make it empty again. After clearing a WengertList
     * you must reset all the Records still using that list. Then you can perform
     * another computation and get new gradients.
     */
    pub fn clear(&self) {
        self.operations.borrow_mut().clear();
    }
}

impl<T: Numeric + Primitive> WengertList<T> {
    /**
     * Creates a record backed by this WengertList.
     *
     * You can alternatively use the [record constructor on the Record type](Record::variable()).
     */
    pub fn variable(&self, x: T) -> Record<'_, T> {
        Record {
            number: x,
            history: Some(self),
            index: self.append_nullary(),
        }
    }

    /**
     * Adds a value to the list which does not have any parent values.
     */
    fn append_nullary(&self) -> Index {
        use std::ops::DerefMut;
        let mut borrow = self.operations.borrow_mut();
        BorrowedWengertList::new(borrow.deref_mut()).append_nullary()
    }

    /**
     * Adds a number of values to the list which do not have any parent values, returning
     * the index of the first added value; the others follow contiguously afterwards.
     *
     * If values is 0, returns the first index that would be used but wasn't.
     */
    fn append_nullary_repeating(&self, values: usize) -> Index {
        let mut operations = self.operations.borrow_mut();
        // insert into end of list
        let starting_index = operations.len();
        for i in 0..values {
            let index = starting_index + i;
            operations.push(Operation {
                // this index of the child is used for both indexes as these
                // won't be needed but will always be valid (ie point to a
                // real entry on the list)
                left_parent: index,
                right_parent: index,
                // the derivatives are set to 0 to zero out these calculations
                // as there are no parents
                left_derivative: T::zero(),
                right_derivative: T::zero(),
            });
        }
        starting_index
    }

    /**
     * Adds a value to the list which has one parent.
     *
     * For an output w_N which depends on one parent w_N-1
     * the derivative cached here is δw_N / δw_N-1
     *
     * For example, if z = sin(x), then δz/δx = cos(x)
     */
    fn append_unary(&self, parent: Index, derivative: T) -> Index {
        use std::ops::DerefMut;
        let mut borrow = self.operations.borrow_mut();
        BorrowedWengertList::new(borrow.deref_mut()).append_unary(parent, derivative)
    }

    /**
     * Adds a value to the list which has two parents.
     *
     * For an output w_N which depends on two parents w_N-1
     * and w_N-2 the derivatives cached here are δw_N / δw_N-1
     * and δw_N / δw_N-2.
     *
     * For example, if z = y + x, then δz/δy = 1 and δz/δx = 1
     * For example, if z = y * x, then δz/δy = x and δz/δx = y
     */
    fn append_binary(
        &self,
        left_parent: Index,
        left_derivative: T,
        right_parent: Index,
        right_derivative: T,
    ) -> Index {
        use std::ops::DerefMut;
        let mut borrow = self.operations.borrow_mut();
        BorrowedWengertList::new(borrow.deref_mut()).append_binary(
            left_parent,
            left_derivative,
            right_parent,
            right_derivative,
        )
    }

    /**
     * Borrows the WengertList mutably for batch operations. It is *very* important to
     * hold onto the borrow only for as long as needed then drop it immediately. To avoid panics
     * Easy ML needs to ensure 100% of method calls on the public API do not maintain a borrow
     * after they finish executing. This was previously enforced by not having any batch
     * append APIs, but they're needed for RecordContainer. Calling borrow again while still
     * holding the first would trigger a panic, as would holding onto the borrow after the
     * public API method is finished.
     */
    fn borrow<F>(&self, op: F)
    where
        F: FnOnce(&mut BorrowedWengertList<T>),
    {
        use std::ops::DerefMut;
        let mut borrow = self.operations.borrow_mut();
        op(&mut BorrowedWengertList::new(borrow.deref_mut()));
    }
}

/**
 * Any WengertList of a Cloneable type implements Clone
 */
impl<T: Clone + Primitive> Clone for WengertList<T> {
    fn clone(&self) -> Self {
        WengertList {
            operations: RefCell::new(self.operations.borrow().clone()),
        }
    }
}

/**
 * Methods for appending Operations after borrowing the Wengert list.
 */
impl<'a, T: Numeric + Primitive> BorrowedWengertList<'a, T> {
    fn new(operations: &mut Vec<Operation<T>>) -> BorrowedWengertList<'_, T> {
        BorrowedWengertList { operations }
    }

    /**
     * Adds a value to the list which does not have any parent values.
     */
    fn append_nullary(&mut self) -> Index {
        // insert into end of list
        let index = self.operations.len();
        self.operations.push(Operation {
            // this index of the child is used for both indexes as these
            // won't be needed but will always be valid (ie point to a
            // real entry on the list)
            left_parent: index,
            right_parent: index,
            // the derivatives are set to 0 to zero out these calculations
            // as there are no parents
            left_derivative: T::zero(),
            right_derivative: T::zero(),
        });
        index
    }

    /**
     * Adds a value to the list which has one parent.
     *
     * For an output w_N which depends on one parent w_N-1
     * the derivative cached here is δw_N / δw_N-1
     *
     * For example, if z = sin(x), then δz/δx = cos(x)
     */
    fn append_unary(&mut self, parent: Index, derivative: T) -> Index {
        // insert into end of list
        let index = self.operations.len();
        self.operations.push(Operation {
            left_parent: parent,
            // this index of the child is used as this index won't be needed
            // but will always be valid (ie points to a real entry on the list)
            right_parent: index,
            left_derivative: derivative,
            // the right derivative is set to 0 to zero out this calculation
            // as there is no right parent
            right_derivative: T::zero(),
        });
        index
    }

    /**
     * Adds a value to the list which has two parents.
     *
     * For an output w_N which depends on two parents w_N-1
     * and w_N-2 the derivatives cached here are δw_N / δw_N-1
     * and δw_N / δw_N-2.
     *
     * For example, if z = y + x, then δz/δy = 1 and δz/δx = 1
     * For example, if z = y * x, then δz/δy = x and δz/δx = y
     */
    fn append_binary(
        &mut self,
        left_parent: Index,
        left_derivative: T,
        right_parent: Index,
        right_derivative: T,
    ) -> Index {
        // insert into end of list
        let index = self.operations.len();
        self.operations.push(Operation {
            left_parent,
            right_parent,
            left_derivative,
            right_derivative,
        });
        index
    }
}

impl<'a, T: Numeric + Primitive> Record<'a, T>
where
    for<'t> &'t T: NumericRef<T>,
{
    /**
     * Creates a new Record from a reference to an existing Record by applying
     * some unary function to it which operates on the type the Record wraps.
     *
     * To compute the new record, the unary function of some input x to some
     * output y is needed along with its derivative with respect to its input x.
     *
     * For example, tanh is a commonly used activation function, but the Real trait
     * does not include this operation and Record has no operations for it specifically.
     * However, you can use this function to compute the tanh of a Record like so:
     *
     * ```
     * use easy_ml::differentiation::{Record, WengertList};
     * let list = WengertList::new();
     * let x = Record::variable(0.7f32, &list);
     * // the derivative of tanh(x) is sech(x) * sech(x) which is equivalent to
     * // 1 / (cosh(x) * cosh(x))
     * let y = x.unary(|x| x.tanh(), |x| 1.0 / (x.cosh() * x.cosh()));
     * assert_eq!(y.derivatives()[&x], 1.0f32 / (0.7f32.cosh() * 0.7f32.cosh()));
     * ```
     */
    #[inline]
    pub fn unary(&self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) -> Record<'_, T> {
        match self.history {
            None => Record {
                number: fx(self.number.clone()),
                history: None,
                index: 0,
            },
            Some(history) => Record {
                number: fx(self.number.clone()),
                history: Some(history),
                index: history.append_unary(self.index, dfx_dx(self.number.clone())),
            },
        }
    }

    /**
     * Creates a new Record from a reference to two existing Records by applying
     * some binary function to them which operates on two arguments of the type
     * the Records wrap.
     *
     * To compute the new record, the binary function of some inputs x and y to some
     * output z is needed along with its derivative with respect to its first input x and
     * its derivative with respect to its second input y.
     *
     * For example, atan2 takes two arguments, but the Real trait
     * does not include this operation and Record has no operations for it specifically.
     * However, you can use this function to compute the atan2 of two Records like so:
     *
     * ```
     * use easy_ml::differentiation::{Record, WengertList};
     * let list = WengertList::new();
     * let x = Record::variable(3.0f32, &list);
     * let y = Record::variable(3.0f32, &list);
     * // the derivative of atan2 with respect to x is y/(x*x + y*y)
     * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdx
     * // the derivative of atan2 with respect to y is -x/(x*x + y*y)
     * // https://www.wolframalpha.com/input/?i=d%28atan2%28x%2Cy%29%29%2Fdy
     * let z = x.binary(&y,
     *     |x, y| x.atan2(y),
     *     |x, y| y/((x*x) + (y*y)),
     *     |x, y| -x/((x*x) + (y*y))
     * );
     * let derivatives = z.derivatives();
     * let dx = derivatives[&x];
     * let dy = derivatives[&y];
     * ```
     */
    #[inline]
    #[track_caller]
    pub fn binary(
        &self,
        rhs: &Record<'a, T>,
        fxy: impl Fn(T, T) -> T,
        dfxy_dx: impl Fn(T, T) -> T,
        dfxy_dy: impl Fn(T, T) -> T,
    ) -> Record<'_, T> {
        assert!(
            record_operations::same_list(self, rhs),
            "Records must be using the same WengertList"
        );
        match (self.history, rhs.history) {
            (None, None) => Record {
                number: fxy(self.number.clone(), rhs.number.clone()),
                history: None,
                index: 0,
            },
            (Some(history), None) => Record {
                number: fxy(self.number.clone(), rhs.number.clone()),
                history: Some(history),
                index: history.append_unary(
                    // if rhs didn't have a history, don't track that derivative
                    self.index,
                    dfxy_dx(self.number.clone(), rhs.number.clone()),
                ),
            },
            (None, Some(history)) => Record {
                number: fxy(self.number.clone(), rhs.number.clone()),
                history: Some(history),
                index: history.append_unary(
                    // if self didn't have a history, don't track that derivative
                    rhs.index,
                    dfxy_dy(self.number.clone(), rhs.number.clone()),
                ),
            },
            (Some(history), Some(_)) => Record {
                number: fxy(self.number.clone(), rhs.number.clone()),
                history: Some(history),
                index: history.append_binary(
                    self.index,
                    dfxy_dx(self.number.clone(), rhs.number.clone()),
                    rhs.index,
                    dfxy_dy(self.number.clone(), rhs.number.clone()),
                ),
            },
        }
    }
}

#[cfg(test)]
#[should_panic]
#[test]
fn test_record_derivatives_when_no_history() {
    let record = Record::constant(1.0);
    record.derivatives();
}

#[test]
fn test_sync() {
    fn assert_sync<T: Sync>() {}
    assert_sync::<Trace<f64>>();
}

#[test]
fn test_send() {
    fn assert_send<T: Send>() {}
    assert_send::<Trace<f64>>();
}