reverse_differentiable 0.1.2

Automatic differentiation of functions.

reverse


Reverse mode automatic differentiation in Rust.

To use this in your crate, add the following to Cargo.toml:

[dependencies]
reverse = "0.1"

Examples

use reverse::*;

fn main() {
  let graph = Graph::new();
  let a = graph.add_var(2.5);
  let b = graph.add_var(14.);
  let c = (a.sin().powi(2) + b.ln() * 3.) - 5.;
  let gradients = c.grad();

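  // d/da sin(a)^2 = 2 sin(a) cos(a) = sin(2a); the f64 suffix lets .sin() resolve on the literal
  assert_eq!(gradients.wrt(&a), (2. * 2.5_f64).sin());
  // d/db (3 ln(b)) = 3 / b
  assert_eq!(gradients.wrt(&b), 3. / 14.);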
}
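
To make the mechanics behind the example above more concrete, here is a minimal tape-based sketch of reverse mode differentiation using only the standard library. It is an illustration under stated assumptions, not this crate's implementation: Tape, Var, scale, offset and the other names are hypothetical and do not correspond to the crate's Graph API. Each operation records its local partial derivatives on a tape, and grad walks the tape backwards to accumulate adjoints.

```rust
use std::cell::RefCell;

// One tape entry: up to two parent indices and the local partial
// derivative of this node with respect to each parent.
#[derive(Clone, Copy)]
struct Node {
    parents: [usize; 2],
    partials: [f64; 2],
}

// Hypothetical minimal tape (NOT the crate's `Graph` type).
struct Tape {
    nodes: RefCell<Vec<Node>>,
}

#[derive(Clone, Copy)]
struct Var<'t> {
    tape: &'t Tape,
    index: usize,
    value: f64,
}

impl Tape {
    fn new() -> Self {
        Tape { nodes: RefCell::new(Vec::new()) }
    }

    fn push(&self, parents: [usize; 2], partials: [f64; 2]) -> usize {
        let mut nodes = self.nodes.borrow_mut();
        nodes.push(Node { parents, partials });
        nodes.len() - 1
    }

    // Input variables have no real parents: they point at themselves
    // with zero weight, so the reverse sweep leaves them unchanged.
    fn var(&self, value: f64) -> Var<'_> {
        let index = self.nodes.borrow().len();
        self.push([index, index], [0.0, 0.0]);
        Var { tape: self, index, value }
    }
}

impl<'t> Var<'t> {
    // Record a one-argument operation with its local derivative.
    fn unary(self, value: f64, partial: f64) -> Var<'t> {
        let index = self.tape.push([self.index, self.index], [partial, 0.0]);
        Var { tape: self.tape, index, value }
    }
    fn sin(self) -> Var<'t> {
        self.unary(self.value.sin(), self.value.cos())
    }
    fn ln(self) -> Var<'t> {
        self.unary(self.value.ln(), 1.0 / self.value)
    }
    fn powi(self, n: i32) -> Var<'t> {
        self.unary(self.value.powi(n), n as f64 * self.value.powi(n - 1))
    }
    fn scale(self, k: f64) -> Var<'t> {
        self.unary(self.value * k, k)
    }
    fn offset(self, k: f64) -> Var<'t> {
        self.unary(self.value + k, 1.0)
    }
    fn add(self, other: Var<'t>) -> Var<'t> {
        let index = self.tape.push([self.index, other.index], [1.0, 1.0]);
        Var { tape: self.tape, index, value: self.value + other.value }
    }

    // Reverse sweep: seed the output with 1 and push adjoints to parents.
    // Returns d(self)/d(node) for every node on the tape, by index.
    fn grad(self) -> Vec<f64> {
        let nodes = self.tape.nodes.borrow();
        let mut adjoints = vec![0.0; nodes.len()];
        adjoints[self.index] = 1.0;
        for i in (0..nodes.len()).rev() {
            let node = nodes[i];
            let adjoint = adjoints[i];
            for k in 0..2 {
                adjoints[node.parents[k]] += node.partials[k] * adjoint;
            }
        }
        adjoints
    }
}

fn main() {
    let tape = Tape::new();
    let a = tape.var(2.5);
    let b = tape.var(14.0);
    // Same expression as the example above: c = sin(a)^2 + 3 ln(b) - 5
    let c = a.sin().powi(2).add(b.ln().scale(3.0)).offset(-5.0);
    let g = c.grad();
    println!("dc/da = {}, dc/db = {}", g[a.index], g[b.index]);
}
```

The sketch uses named methods instead of operator overloading to stay short. A single backward pass yields the derivative with respect to every input at once, which is the usual reason to prefer reverse mode when there are many parameters and one scalar output.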

Differentiable Functions

The optional diff feature enables the #[differentiable] attribute macro, which rewrites a function's signature so that the function can be differentiated. Functions written against f64 can then be called on differentiable variables unchanged, and you do not have to spell out the (non-trivial) differentiable variable type yourself.

To use this, add the following to Cargo.toml:

[dependencies]
reverse = { version = "0.1", features = ["diff"] }

Functions must have the type Fn(&[f64], &[&[f64]]) -> f64, where the first argument contains the differentiable parameters and the second argument contains arbitrary arrays of data.
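
As an illustration of that signature, here is a plain f64 function of the required shape, shown without the macro so that it runs on its own. The name squared_error and the line-fitting setup are hypothetical. Per the description above, placing #[differentiable] on such a function is what would make it usable on differentiable variables; whether the macro accepts every construct used here (iterators, closures) is an assumption that the text does not confirm.

```rust
// Hypothetical example: sum of squared residuals for the line y = m * x + k.
// `params` holds the differentiable parameters [m, k];
// `data` holds fixed arrays of data, here [xs, ys].
fn squared_error(params: &[f64], data: &[&[f64]]) -> f64 {
    let (m, k) = (params[0], params[1]);
    let (xs, ys) = (data[0], data[1]);
    xs.iter()
        .zip(ys.iter())
        .map(|(x, y)| (m * x + k - y).powi(2))
        .sum()
}

fn main() {
    let xs = [0.0, 1.0, 2.0];
    let ys = [1.0, 3.0, 5.0];
    // The data lies exactly on y = 2x + 1, so these parameters give zero loss.
    let loss = squared_error(&[2.0, 1.0], &[&xs[..], &ys[..]]);
    println!("loss = {loss}");
}
```

Keeping the parameters and the data in separate arguments is what lets gradients be taken with respect to the parameters only, while the data arrays are treated as constants.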

Example

Here is an example of what the feature allows you to do:

use reverse::*;

fn main() {
    let graph = Graph::new();
    let a = graph.add_var(5.);
    let b = graph.add_var(2.);

    // you can track gradients through the function as usual!
    let res = addmul(&[a, b], &[&[4.]]);
    let grad = res.grad();

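    // res = params[0] + data[0][0] * params[1], so
    // d(res)/d(params[0]) = 1 and d(res)/d(params[1]) = data[0][0] = 4
    assert_eq!(grad.wrt(&a), 1.);
    assert_eq!(grad.wrt(&b), 4.);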
}

// function must have these argument types but can be arbitrarily complex
#[differentiable]
fn addmul(params: &[f64], data: &[&[f64]]) -> f64 {
    params[0] + data[0][0] * params[1]
}
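
The gradients asserted in this example can be sanity-checked independently of the crate with central finite differences on the plain f64 version of addmul. This is only a numerical cross-check sketch: central_diff is a hypothetical helper, not part of the crate, and the step size is an arbitrary choice.

```rust
// Plain f64 version of the function from the example (no macro applied).
fn addmul(params: &[f64], data: &[&[f64]]) -> f64 {
    params[0] + data[0][0] * params[1]
}

// Hypothetical helper: central difference of `f` with respect to params[i].
fn central_diff(
    f: impl Fn(&[f64], &[&[f64]]) -> f64,
    params: &[f64],
    data: &[&[f64]],
    i: usize,
    h: f64,
) -> f64 {
    let mut hi = params.to_vec();
    let mut lo = params.to_vec();
    hi[i] += h;
    lo[i] -= h;
    (f(&hi, data) - f(&lo, data)) / (2.0 * h)
}

fn main() {
    let params = [5.0, 2.0];
    let data: [&[f64]; 1] = [&[4.0]];
    let d0 = central_diff(addmul, &params, &data, 0, 1e-6);
    let d1 = central_diff(addmul, &params, &data, 1, 1e-6);
    // Expected to be close to 1 and 4, up to floating point rounding.
    println!("d/dparams[0] ~ {d0}, d/dparams[1] ~ {d1}");
}
```

Finite differences need two function evaluations per parameter and are only approximate, so they are useful as a test oracle rather than as a replacement for the tracked gradients.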