rai-nn 0.7.0

ML framework with Ergonomic APIs in Rust
Documentation

RAI

Rust Docs Status Latest Version

ML framework with ergonomic APIs in Rust. Lazy computation and composable transformations.

Installation

cargo add rai

Code snippets

Function transformations (jvp, vjp, grad, value_and_grad)

use rai::{backend::Cpu, grad, Func, Tensor, F32};

/// Example function to be differentiated: returns the result of calling
/// `sin` on the input tensor.
fn f(x: &Tensor) -> Tensor {
    x.sin()
}

/// Builds the gradient-of-gradient of `f`, applies it to a tensor of ones,
/// and prints both the result's `dot_graph()` output and the result itself.
fn main() {
    // Input tensor: ones with shape `[1]` and dtype `F32`, created on the `Cpu` backend.
    let input = Tensor::ones([1], F32, &Cpu);

    // `grad` composes: applying it twice differentiates `f` a second time.
    let second_order = grad(grad(f));
    let result = second_order.apply((&input,));

    println!("{}", result.dot_graph());
    println!("{}", result);
}

NN modules, optimizers, and loss functions

/// Computes the training loss for one batch.
///
/// Runs `model.forward` on `input`, passes the resulting logits and `labels`
/// to `softmax_cross_entropy`, and reduces the result with `.mean(..)`.
///
/// Returns the loss together with the logits wrapped in `Aux`, so callers can
/// still read the logits after the function has been transformed
/// (presumably `Aux` marks an output excluded from differentiation — confirm
/// against the `rai` docs).
fn loss_fn<M>(model: &M, input: &Tensor, labels: &Tensor) -> (Tensor, Aux<Tensor>)
where
    M: TrainableModule<Input = Tensor, Output = Tensor>,
{
    let predictions = model.forward(input);
    let mean_loss = softmax_cross_entropy(&predictions, labels).mean(..);
    (mean_loss, Aux(predictions))
}

/// Performs a single training step on `model`.
///
/// Wraps `loss_fn` with `value_and_grad`, applies it to
/// `(model, input, labels)`, hands the resulting gradients to
/// `optimizer.step`, and writes the parameters it returns back into the model
/// via `update_params`.
fn train_step<M, O>(optimizer: &mut O, model: &M, input: &Tensor, labels: &Tensor)
where
    M: TrainableModule<
        Input = Tensor,
        Output = Tensor,
        Tensors = HashMap<usize, Tensor>,
        Gradient = HashMap<usize, Tensor>,
    >,
    O: Optimizer,
{
    let loss_and_grad = value_and_grad(loss_fn);

    // The loss and logits are produced alongside the gradients but are not used here.
    let ((_loss, Aux(_logits)), (grads, ..)) = loss_and_grad.apply((model, input, labels));

    let mut updated = optimizer.step(&grads);
    // NOTE(review): presumably forces evaluation of the lazily-built tensors
    // before they are stored in the model — confirm against the `rai` docs.
    eval(&updated);
    model.update_params(&mut updated);
}

Examples

  • linear_regression
    • cargo run --bin linear_regression --release
  • mnist
    • cargo run --bin mnist --release
  • phi2
    • cargo run --bin phi2 --release

LICENSE

This project is licensed under either of

  • Apache License, Version 2.0
  • MIT license

at your option.