reverse
reverse is a lightweight, zero-dependency crate for performing reverse-mode automatic differentiation in Rust. This is useful for functions with many inputs and a small number of outputs, because the gradient of a particular output with respect to all of the inputs can be computed in a single backward pass.
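A minimal Wengert list, written from scratch with only the standard library, shows why a single pass is enough. This is an illustrative sketch, not the crate's implementation. The Node layout and the input, push, and gradient methods are hypothetical. Each node records its parents and its local partial derivative with respect to each parent. One reverse sweep over the tape then accumulates the gradient of one output with respect to every input at once.

```rust
// Illustrative sketch of a Wengert list; NOT the reverse crate's internals.

/// One tape entry: indices of up to two parent nodes and the local
/// partial derivative of this node with respect to each parent.
#[derive(Clone, Copy)]
struct Node {
    parents: [usize; 2],
    partials: [f64; 2],
}

/// Nodes are appended in evaluation order, so every parent index is
/// smaller than the index of the node that uses it.
struct Tape {
    nodes: Vec<Node>,
    values: Vec<f64>,
}

impl Tape {
    fn new() -> Self {
        Tape { nodes: Vec::new(), values: Vec::new() }
    }

    fn push(&mut self, value: f64, parents: [usize; 2], partials: [f64; 2]) -> usize {
        self.nodes.push(Node { parents, partials });
        self.values.push(value);
        self.nodes.len() - 1
    }

    /// An input is a leaf: it points at itself with zero partials, so it
    /// passes nothing further back during the reverse sweep.
    fn input(&mut self, value: f64) -> usize {
        let i = self.nodes.len();
        self.push(value, [i, i], [0.0, 0.0])
    }

    fn add(&mut self, a: usize, b: usize) -> usize {
        let v = self.values[a] + self.values[b];
        self.push(v, [a, b], [1.0, 1.0])
    }

    fn mul(&mut self, a: usize, b: usize) -> usize {
        let (va, vb) = (self.values[a], self.values[b]);
        self.push(va * vb, [a, b], [vb, va])
    }

    fn sin(&mut self, a: usize) -> usize {
        let va = self.values[a];
        self.push(va.sin(), [a, a], [va.cos(), 0.0])
    }

    /// One reverse sweep: seed d(out)/d(out) = 1, then visit nodes from last
    /// to first, adding each node's adjoint times its local partials to its parents.
    fn gradient(&self, out: usize) -> Vec<f64> {
        let mut adjoints = vec![0.0; self.nodes.len()];
        adjoints[out] = 1.0;
        for i in (0..=out).rev() {
            let node = self.nodes[i];
            let adj = adjoints[i];
            for k in 0..2 {
                adjoints[node.parents[k]] += node.partials[k] * adj;
            }
        }
        adjoints
    }
}

/// f(x, y) = x * y + sin(x), evaluated at (2, 3).
fn example() -> (f64, f64, f64) {
    let mut tape = Tape::new();
    let x = tape.input(2.0);
    let y = tape.input(3.0);
    let xy = tape.mul(x, y);
    let sx = tape.sin(x);
    let f = tape.add(xy, sx);
    // A single backward pass yields df/dx and df/dy together.
    let grads = tape.gradient(f);
    (tape.values[f], grads[x], grads[y])
}

fn main() {
    let (f, dfdx, dfdy) = example();
    // Analytically, df/dx = y + cos(x) and df/dy = x.
    println!("f = {f}, df/dx = {dfdx}, df/dy = {dfdy}");
}
```

Both partials come out of the same call to gradient. A function with many more inputs would still need only one reverse sweep per output.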
Usage
A tape (also called a Wengert list) is created with Tape::new(). Variables can then be added to the tape, either individually (.add_var) or as a slice (.add_vars). This yields differentiable variables of type Var<'a>.
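The lifetime in Var<'a> indicates that each variable borrows the tape it belongs to. The sketch below shows one way such handles could work. It assumes a RefCell-backed tape, so that many handles can share &Tape while new entries are still being appended. Only the names Tape, Var, add_var, and add_vars come from this page. The fields and the value and node_count methods are hypothetical, and this sketch stores values only, not operations.

```rust
// Hypothetical sketch of tape-borrowing handles; NOT the reverse crate's internals.
use std::cell::RefCell;

/// Stores the value of each recorded node. RefCell lets many Vars share
/// &Tape while new nodes are still appended through that shared reference.
struct Tape {
    values: RefCell<Vec<f64>>,
}

/// A variable is a small handle: a shared borrow of its tape plus an index.
/// The lifetime 'a ties each Var to its tape, so it cannot outlive it.
#[derive(Clone, Copy)]
struct Var<'a> {
    tape: &'a Tape,
    index: usize,
}

impl Tape {
    fn new() -> Self {
        Tape { values: RefCell::new(Vec::new()) }
    }

    /// Record one input value and return a handle to it.
    fn add_var(&self, value: f64) -> Var<'_> {
        let mut values = self.values.borrow_mut();
        values.push(value);
        Var { tape: self, index: values.len() - 1 }
    }

    /// Record every value in the slice, in order.
    fn add_vars(&self, values: &[f64]) -> Vec<Var<'_>> {
        values.iter().map(|&v| self.add_var(v)).collect()
    }

    fn node_count(&self) -> usize {
        self.values.borrow().len()
    }
}

impl Var<'_> {
    fn value(&self) -> f64 {
        self.tape.values.borrow()[self.index]
    }
}

fn main() {
    let tape = Tape::new();
    let x = tape.add_var(1.5);
    let params = tape.add_vars(&[5.0, 2.0, 0.0]);
    // x and params all borrow tape; dropping tape before using them would not compile.
    println!("x = {}, params[1] = {}, nodes = {}", x.value(), params[1].value(), tape.node_count());
}
```

Because a handle in this sketch is just a reference plus an index, it can derive Copy and be passed around by value cheaply. The borrow checker then rejects any use of a Var after its tape has been dropped.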
Differentiable variables can be manipulated like f64 values: the operations applied to them are tracked by the tape, and gradients with respect to other variables can be calculated. Operations can also be performed between variables and ordinary f64 values, in which case the f64 values are treated as constants with no gradients.
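Treating f64 operands as constants follows naturally if operators are overloaded for Var. That is one plausible mechanism, assumed here and not taken from the crate's source. In the sketch below, a constant such as 3.0 never gets a tape entry of its own. It only appears inside the local partial derivative of the node it touches, so no gradient is ever computed for it. Node, record, adjoints, and the value and index fields are hypothetical.

```rust
// Hypothetical sketch of Var/f64 mixing via operator overloading; NOT the reverse crate.
use std::cell::RefCell;
use std::ops::{Add, Mul};

/// Tape entry: two parent indices and the local partials.
#[derive(Clone, Copy)]
struct Node {
    parents: [usize; 2],
    partials: [f64; 2],
}

struct Tape {
    nodes: RefCell<Vec<Node>>,
}

#[derive(Clone, Copy)]
struct Var<'a> {
    tape: &'a Tape,
    index: usize,
    value: f64,
}

impl Tape {
    fn new() -> Self {
        Tape { nodes: RefCell::new(Vec::new()) }
    }

    /// Append a node and return a Var handle to it.
    fn record(&self, parents: [usize; 2], partials: [f64; 2], value: f64) -> Var<'_> {
        let mut nodes = self.nodes.borrow_mut();
        nodes.push(Node { parents, partials });
        Var { tape: self, index: nodes.len() - 1, value }
    }

    /// A leaf points at itself with zero partials.
    fn add_var(&self, value: f64) -> Var<'_> {
        let index = self.nodes.borrow().len();
        self.record([index, index], [0.0, 0.0], value)
    }
}

impl Var<'_> {
    /// Adjoint of every tape node with respect to self, from one reverse sweep.
    fn adjoints(&self) -> Vec<f64> {
        let nodes = self.tape.nodes.borrow();
        let mut adj = vec![0.0; nodes.len()];
        adj[self.index] = 1.0;
        for i in (0..=self.index).rev() {
            let (n, a) = (nodes[i], adj[i]);
            for k in 0..2 {
                adj[n.parents[k]] += n.partials[k] * a;
            }
        }
        adj
    }
}

// Var * Var: both operands are tape nodes and both receive a partial.
impl<'a> Mul for Var<'a> {
    type Output = Var<'a>;
    fn mul(self, rhs: Var<'a>) -> Var<'a> {
        let parents = [self.index, rhs.index];
        self.tape.record(parents, [rhs.value, self.value], self.value * rhs.value)
    }
}

// Var * f64: the constant gets no tape node; it only scales the local partial.
impl<'a> Mul<f64> for Var<'a> {
    type Output = Var<'a>;
    fn mul(self, c: f64) -> Var<'a> {
        self.tape.record([self.index, self.index], [c, 0.0], self.value * c)
    }
}

// Var + f64: the local partial is 1 and, again, the constant is not recorded.
impl<'a> Add<f64> for Var<'a> {
    type Output = Var<'a>;
    fn add(self, c: f64) -> Var<'a> {
        self.tape.record([self.index, self.index], [1.0, 0.0], self.value + c)
    }
}

/// f(x, y) = x * y * 3 + 1 at (2, 5); 3.0 and 1.0 are plain f64 constants.
fn example() -> (f64, f64, f64) {
    let tape = Tape::new();
    let x = tape.add_var(2.0);
    let y = tape.add_var(5.0);
    let f = x * y * 3.0 + 1.0;
    let adj = f.adjoints();
    (f.value, adj[x.index], adj[y.index])
}

fn main() {
    let (f, dfdx, dfdy) = example();
    println!("f = {f}, df/dx = {dfdx}, df/dy = {dfdy}");
}
```

For f(x, y) = 3xy + 1 at (2, 5), the sweep yields df/dx = 3y = 15 and df/dy = 3x = 6. The constants 3 and 1 simply have no slot in the adjoint vector.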
You can define functions that take Var<'a> values as input (optionally along with other fixed data of type f64) and return a Var<'a>, and such functions will be differentiable. For example:
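use reverse::*;

fn main() {
    // Create a tape and record three differentiable parameters on it.
    let tape = Tape::new();
    let params = tape.add_vars(&[5., 2., 0.]);
    // Plain f64 data: treated as constants with no gradients.
    let data = [1., 2.];

    let result = diff_fn(&params, &data);
    // Compute the gradients of `result`, then select those w.r.t. `params`.
    let gradients = result.grad();
    println!("{:?}", gradients.wrt(&params));
}

// Differentiable with respect to `params`; `data` enters only as constants.
fn diff_fn<'a>(params: &[Var<'a>], data: &[f64]) -> Var<'a> {
    params[0].powf(params[1]) + data[0].sin() - params[2].asinh() / data[1]
}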
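The formula in diff_fn can be checked independently of the crate. The sketch below rewrites it over plain f64 values and compares hand-derived partial derivatives with central finite differences. The helper names f, analytic_grad, and numeric_grad are hypothetical. At params = [5, 2, 0] and data = [1, 2], the hand-derived partials are 10, 25 ln 5 (about 40.24), and -0.5. These are the values one would expect the example above to print, assuming .wrt(&params) returns one gradient per parameter in order.

```rust
// Independent check using plain f64 only; does not use the reverse crate.
// Helper names are illustrative.

/// The formula from diff_fn, rewritten over plain f64 values.
fn f(p: &[f64; 3], data: &[f64; 2]) -> f64 {
    p[0].powf(p[1]) + data[0].sin() - p[2].asinh() / data[1]
}

/// Hand-derived partial derivatives of f with respect to p[0], p[1], p[2].
fn analytic_grad(p: &[f64; 3], data: &[f64; 2]) -> [f64; 3] {
    [
        // d/dp0 of p0^p1 is p1 * p0^(p1 - 1)
        p[1] * p[0].powf(p[1] - 1.0),
        // d/dp1 of p0^p1 is p0^p1 * ln(p0) (valid for p0 > 0)
        p[0].powf(p[1]) * p[0].ln(),
        // d/dp2 of -asinh(p2) / d1 is -1 / (d1 * sqrt(1 + p2^2))
        -1.0 / (data[1] * (1.0 + p[2] * p[2]).sqrt()),
    ]
}

/// Central finite differences: two evaluations of f per parameter.
fn numeric_grad(p: &[f64; 3], data: &[f64; 2], h: f64) -> [f64; 3] {
    let mut g = [0.0; 3];
    for i in 0..3 {
        let (mut up, mut down) = (*p, *p);
        up[i] += h;
        down[i] -= h;
        g[i] = (f(&up, data) - f(&down, data)) / (2.0 * h);
    }
    g
}

fn main() {
    let params = [5.0, 2.0, 0.0];
    let data = [1.0, 2.0];
    println!("analytic: {:?}", analytic_grad(&params, &data));
    println!("numeric:  {:?}", numeric_grad(&params, &data, 1e-6));
}
```

The finite-difference estimate costs two extra evaluations of f per input. Reverse mode obtains all of the partials from one backward pass.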
Structs
- Tape: A tape (Wengert list) that tracks differentiable variables, intermediate values, and the operations applied to each.
- Var: A differentiable variable. This is the main type that users will interact with.