Reverse mode automatic differentiation. Currently only gradients can be computed. Suggestions on how to extend the functionality to Hessian matrices are definitely welcome.
Additionally, only functions $f: \mathbb{R}^n \rightarrow \mathbb{R}$ (scalar output) are supported. However, a vector-valued function can be handled by differentiating each of its scalar components separately.
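For reference, the textbook reverse-accumulation recurrence (the formulation used by Griewank and Walther) is shown below. It describes reverse mode in general rather than this crate's exact internals. Number the intermediate values $v_1, \dots, v_N$ so that $v_1, \dots, v_n$ are the inputs $x_1, \dots, x_n$ and $v_N = f$, write $i \prec j$ when $v_i$ is a direct argument of $v_j$, and let the adjoint be $\bar{v}_i = \partial f / \partial v_i$:

```math
\bar{v}_N = 1, \qquad
\bar{v}_i = \sum_{j \,:\, i \prec j} \bar{v}_j \, \frac{\partial v_j}{\partial v_i} \quad (i = N-1, \dots, 1), \qquad
\frac{\partial f}{\partial x_k} = \bar{v}_k \quad (k = 1, \dots, n).
```

For a vector output $f:\mathbb{R}^n \rightarrow \mathbb{R}^m$, running this sweep once per component $f_1, \dots, f_m$ yields the $m$ rows of the Jacobian, which is the component-by-component approach described above.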
- Reverse (Adjoint) Mode
    - Implementation via operator and function overloading.
    - Useful when the number of outputs is smaller than the number of inputs,
      i.e. for functions $f:\mathbb{R}^n \rightarrow \mathbb{R}^m$, where $m \ll n$.
- Forward (Tangent) Mode
    - Implementation via dual numbers.
    - Useful when the number of outputs is larger than the number of inputs,
      i.e. for functions $f:\mathbb{R}^n \rightarrow \mathbb{R}^m$, where $m \gg n$.
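Forward mode is not what this module implements, but a small dual-number sketch makes the cost contrast above concrete. The `Dual` type and the function `f` below are illustrative assumptions, not part of this crate. Each forward pass carries one directional derivative, so a function of n inputs needs n passes to recover its full gradient, whereas reverse mode obtains every partial from a single backward sweep.

```rust
use std::ops::{Add, Mul};

/// A dual number a + b·ε with ε² = 0: `val` carries f(x), `der` carries f'(x).
#[derive(Clone, Copy, Debug, PartialEq)]
struct Dual {
    val: f64,
    der: f64,
}

impl Dual {
    /// An independent variable being differentiated: d(x)/dx = 1.
    fn var(x: f64) -> Self {
        Dual { val: x, der: 1.0 }
    }

    /// A value held fixed during this pass: its derivative is 0.
    fn constant(c: f64) -> Self {
        Dual { val: c, der: 0.0 }
    }

    /// exp(a + bε) = exp(a) + b·exp(a)ε.
    fn exp(self) -> Self {
        let e = self.val.exp();
        Dual { val: e, der: self.der * e }
    }
}

impl Add for Dual {
    type Output = Dual;
    // Sum rule: (a + bε) + (c + dε) = (a + c) + (b + d)ε.
    fn add(self, rhs: Dual) -> Dual {
        Dual { val: self.val + rhs.val, der: self.der + rhs.der }
    }
}

impl Mul for Dual {
    type Output = Dual;
    // Product rule follows from ε² = 0: (a + bε)(c + dε) = ac + (ad + bc)ε.
    fn mul(self, rhs: Dual) -> Dual {
        Dual {
            val: self.val * rhs.val,
            der: self.val * rhs.der + self.der * rhs.val,
        }
    }
}

/// f(x, y) = x² + y² + exp(x·y), evaluated on dual numbers.
fn f(x: Dual, y: Dual) -> Dual {
    x * x + y * y + (x * y).exp()
}

fn main() {
    // First pass: seed x, hold y constant, giving df/dx.
    let df_dx = f(Dual::var(1.0), Dual::constant(2.0));
    // Second pass: seed y, hold x constant, giving df/dy.
    let df_dy = f(Dual::constant(1.0), Dual::var(2.0));
    println!("f = {}, df/dx = {}, df/dy = {}", df_dx.val, df_dx.der, df_dy.der);
}
```

Two passes are needed here for the two inputs; at x = 1, y = 2 they reproduce $\partial f/\partial x = 2x + y e^{xy}$ and $\partial f/\partial y = 2y + x e^{xy}$.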
```rust
use RustQuant::autodiff::*;

// Create a new Graph to store the computations.
let g = Graph::new();

// Assign variables.
let x = g.var(69.);
let y = g.var(420.);

// Define a function.
// Note: exp(69 * 420) overflows f64 to infinity, so expect infinite results here.
let f = {
    let a = x.powi(2);
    let b = y.powi(2);

    a + b + (x * y).exp()
};

// Accumulate the gradient.
let gradient = f.accumulate();

println!("Function = {}", f);
println!("Gradient = {:?}", gradient.wrt([x, y]));
```
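To show roughly what `Graph::new`, `g.var` and `accumulate` do behind the scenes, here is a minimal self-contained tape sketch. `Tape`, `Var` and `Node` are hypothetical names and the layout is an assumption, not RustQuant's actual internals: every operation pushes a node recording its parents and the local partial derivative with respect to each, and the reverse sweep walks the tape backwards accumulating adjoints. It uses x = 1 and y = 2 so that exp(x * y) stays finite.

```rust
use std::cell::RefCell;
use std::ops::{Add, Mul};

/// One recorded operation: (parent index, local partial wrt that parent).
/// Hypothetical layout; inputs (leaves) have no dependencies.
struct Node {
    deps: Vec<(usize, f64)>,
}

/// A minimal Wengert list. Hypothetical stand-in for a `Graph`.
struct Tape {
    nodes: RefCell<Vec<Node>>,
}

/// A handle to a value on the tape. Hypothetical stand-in for a `Variable`.
#[derive(Clone, Copy)]
struct Var<'t> {
    tape: &'t Tape,
    index: usize,
    value: f64,
}

impl Tape {
    fn new() -> Self {
        Tape { nodes: RefCell::new(Vec::new()) }
    }

    /// Append a node and return its index.
    fn push(&self, deps: Vec<(usize, f64)>) -> usize {
        let mut nodes = self.nodes.borrow_mut();
        nodes.push(Node { deps });
        nodes.len() - 1
    }

    /// Create an input variable (a leaf with no dependencies).
    fn var(&self, value: f64) -> Var<'_> {
        let index = self.push(Vec::new());
        Var { tape: self, index, value }
    }
}

impl<'t> Var<'t> {
    /// Record a unary operation whose local derivative is `d`.
    fn unary(self, value: f64, d: f64) -> Var<'t> {
        let index = self.tape.push(vec![(self.index, d)]);
        Var { tape: self.tape, index, value }
    }

    fn exp(self) -> Var<'t> {
        let e = self.value.exp();
        self.unary(e, e)
    }

    fn powi(self, n: i32) -> Var<'t> {
        self.unary(self.value.powi(n), n as f64 * self.value.powi(n - 1))
    }

    /// Reverse sweep: seed this node's adjoint with 1 and walk the tape backwards.
    fn accumulate(&self) -> Vec<f64> {
        let nodes = self.tape.nodes.borrow();
        let mut adjoints = vec![0.0; nodes.len()];
        adjoints[self.index] = 1.0;
        for i in (0..=self.index).rev() {
            let adj = adjoints[i];
            for &(parent, partial) in &nodes[i].deps {
                adjoints[parent] += partial * adj;
            }
        }
        adjoints
    }
}

impl<'t> Add for Var<'t> {
    type Output = Var<'t>;
    fn add(self, rhs: Var<'t>) -> Var<'t> {
        // d(a + b)/da = 1, d(a + b)/db = 1.
        let index = self.tape.push(vec![(self.index, 1.0), (rhs.index, 1.0)]);
        Var { tape: self.tape, index, value: self.value + rhs.value }
    }
}

impl<'t> Mul for Var<'t> {
    type Output = Var<'t>;
    fn mul(self, rhs: Var<'t>) -> Var<'t> {
        // d(a * b)/da = b, d(a * b)/db = a.
        let index = self.tape.push(vec![(self.index, rhs.value), (rhs.index, self.value)]);
        Var { tape: self.tape, index, value: self.value * rhs.value }
    }
}

fn main() {
    let tape = Tape::new();
    let x = tape.var(1.0);
    let y = tape.var(2.0);
    let f = x.powi(2) + y.powi(2) + (x * y).exp();
    let grad = f.accumulate();
    // Analytically: df/dx = 2x + y·exp(xy), df/dy = 2y + x·exp(xy).
    println!("f = {}, df/dx = {}, df/dy = {}", f.value, grad[x.index], grad[y.index]);
}
```

Because every node is appended after its parents, iterating over indices in reverse is already a valid reverse topological order, so no explicit graph sort is needed.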
You can also generate Graphviz (dot) code to visualize the computation graphs:
```rust
println!("{}", graphviz(&graph, &variables));
```
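The exact output of `graphviz` is not reproduced here, but the idea can be sketched independently: walk the recorded operations and emit one DOT node per vertex and one edge per dependency. `Op` and `to_dot` are hypothetical helpers for illustration, not this crate's API.

```rust
use std::fmt::Write;

/// A recorded operation: a display label and the indices of its parents.
/// Hypothetical layout; not the crate's `Vertex` type.
struct Op {
    label: &'static str,
    parents: Vec<usize>,
}

/// Render a Wengert list as Graphviz DOT: one node per operation and one
/// edge from each parent to the operation that consumes it.
fn to_dot(ops: &[Op]) -> String {
    let mut dot = String::from("digraph {\n");
    for (i, op) in ops.iter().enumerate() {
        writeln!(dot, "    {} [label=\"{}\"];", i, op.label).unwrap();
        for &p in &op.parents {
            writeln!(dot, "    {} -> {};", p, i).unwrap();
        }
    }
    dot.push_str("}\n");
    dot
}

fn main() {
    // f = x * y + exp(x), recorded in evaluation order.
    let ops = vec![
        Op { label: "x", parents: vec![] },
        Op { label: "y", parents: vec![] },
        Op { label: "mul", parents: vec![0, 1] },
        Op { label: "exp", parents: vec![0] },
        Op { label: "add", parents: vec![2, 3] },
    ];
    print!("{}", to_dot(&ops));
}
```

Piping the printed text into the Graphviz `dot` tool (for example `dot -Tsvg > graph.svg`) renders the graph.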
The computation graph from computing Black-Scholes Greeks is shown at the following link:
It is clearly a work in progress, but gives a general idea of how the computation graph is structured.
If you want to improve the visualization, please feel free to submit a PR!
Re-exports
pub use accumulate::*;
pub use gradient::*;
pub use graph::*;
pub use graphviz::*;
pub use vertex::*;
pub use overload::*;
pub use variable::*;
Modules
- `accumulate`: The `Accumulate` trait. This reverse accumulation trait is used to reverse accumulate the gradient for different types.
- `gradient`: Implements the gradient computation. This module contains the `Gradient` trait; each implementation of `wrt` returns the chosen partial derivatives.
- `graph`: The computation `Graph` (also known as a tape or Wengert list). This module contains its implementation.
- `graphviz`: Visualisation of the `Graph` using Graphviz.
- `overload`: Operator/function overloading. This module contains the overloaded operators and primitive functions, which Griewank and Walther (*Evaluating Derivatives*) call the “elemental library”. Operations such as `+` and `*` are redefined, along with primitive functions such as `sin`, `exp`, and `log`. Each overload has an associated test to ensure functionality.
- `variable`: `Variable`s for `autodiff`. This module contains the implementation of the `Variable` structure.
- `vertex`: Implements `Vertex` (nodes) for the `Graph`.
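The elemental-library idea behind `overload` is that each primitive knows its own local derivative, which is all a reverse sweep needs to store per node. The sketch below is a hypothetical illustration (`Elemental`, `local_partial` and `central_difference` are not part of this crate), and checking each derivative against a central finite difference is one possible style of per-overload test, not necessarily the one the crate uses.

```rust
/// A few unary elementals with their value and local derivative.
/// Hypothetical enum for illustration; not the crate's `overload` API.
#[derive(Clone, Copy, Debug)]
enum Elemental {
    Sin,
    Exp,
    Log,
}

impl Elemental {
    /// The primitive itself, φ(x).
    fn eval(self, x: f64) -> f64 {
        match self {
            Elemental::Sin => x.sin(),
            Elemental::Exp => x.exp(),
            Elemental::Log => x.ln(),
        }
    }

    /// φ'(x): the local partial a reverse-mode tape records for this node.
    fn local_partial(self, x: f64) -> f64 {
        match self {
            Elemental::Sin => x.cos(),
            Elemental::Exp => x.exp(),
            Elemental::Log => 1.0 / x,
        }
    }
}

/// Central finite difference, used only to sanity-check `local_partial`.
fn central_difference(op: Elemental, x: f64) -> f64 {
    let h = 1e-6;
    (op.eval(x + h) - op.eval(x - h)) / (2.0 * h)
}

fn main() {
    let x = 0.7;
    for op in [Elemental::Sin, Elemental::Exp, Elemental::Log] {
        let exact = op.local_partial(x);
        let approx = central_difference(op, x);
        println!("{:?}: analytic {:.6}, finite difference {:.6}", op, exact, approx);
    }
}
```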