rai-core 0.5.0

ML framework with Ergonomic APIs in Rust
Documentation

RAI

Rust Docs Status Latest Version

ML framework with ergonomic APIs in Rust, featuring lazy computation and composable transformations.

Note: It requires Rust nightly with the following unstable features enabled: `fn_traits` and `unboxed_closures`.

Installation

cargo add rai

Code snippets

Function transformations (jvp, vjp, grad, value_and_grad)

use rai::backend::Cpu;
use rai::{grad, DType, Tensor};

/// The function being differentiated in this example: returns `x.sin()`,
/// the sine of the input tensor.
fn f(x: &Tensor) -> Tensor {
    x.sin()
}

/// Applies `grad` to `f` twice, evaluates the resulting function on a
/// one-element `F32` tensor of ones allocated on the `Cpu` backend, and
/// prints the first output's `dot_graph()` rendering followed by the
/// output tensor itself.
fn main() {
    // `grad` applied twice: derivative of the derivative of `f`.
    let d2f = grad(grad(f));

    let device = &Cpu;
    let ones = Tensor::ones([1], DType::F32, device);
    let outputs = d2f([ones]);

    // Only the first (and only) input's result is inspected.
    let second = &outputs[0];
    println!("{}", second.dot_graph());
    println!("{}", second);
}

NN modules, optimizers, and loss functions

/// Runs `model` forward on `input` and scores the result against `labels`.
///
/// Returns the mean of `softmax_cross_entropy(logits, labels)` as the loss,
/// paired with the raw logits wrapped in `Aux`. Presumably `Aux` marks an
/// auxiliary output that `value_and_grad` passes through without
/// differentiating it; confirm against the `rai` docs.
fn loss_fn<M>(model: &M, input: &Tensor, labels: &Tensor) -> (Tensor, Aux<Tensor>)
where
    M: Module + 'static,
{
    let predictions = model.forward(input);
    let per_example = softmax_cross_entropy(&predictions, labels);
    // Tuple fields are evaluated left to right, so `mean` is built before
    // `predictions` is moved into `Aux`.
    (per_example.mean(..), Aux(predictions))
}

/// Performs a single training update of `model` on one `(input, labels)` batch.
///
/// The steps are:
/// 1. Evaluate `value_and_grad(loss_fn)` to get the loss, the logits and the gradients.
/// 2. Ask `optimizer` for updated parameters from those gradients.
/// 3. Force evaluation of the parameters with `eval`.
/// 4. Write the parameters back into `model` via `Module::update`.
fn train_step<O, M>(optimizer: &mut O, model: &M, input: &Tensor, labels: &Tensor)
where
    O: Optimizer,
    M: Module + 'static,
{
    let loss_and_grads = value_and_grad(loss_fn);
    // Named `_`-prefixed bindings (not a bare `_`) keep the loss and logits
    // alive until the end of the function, exactly as before.
    let ((_value, Aux(_aux_logits)), gradients) = loss_and_grads((model, input, labels));

    let mut new_params = optimizer.step(&gradients);
    eval(&new_params);
    model.update(&mut new_params);
}

Examples

LICENSE

This project is licensed under either of

at your option.