Ergonomics & safety focused deep learning in Rust. Main features include:
- Const generic tensor library with tensors up to 4d!
- A large library of tensor operations (matrix multiplication, arithmetic, activation functions, etc.).
- Safe & easy to use neural network building blocks.
- Standard deep learning optimizers such as Sgd and Adam.
- Reverse mode auto differentiation implementation.
- Serialization to/from .npy and .npz for transferring models to/from Python.
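The .npy format is simple enough to sketch by hand. The function below is a hypothetical helper, not part of this crate, that writes a 1-D f32 array in NPY version 1.0 layout: the magic string \x93NUMPY, a two-byte version, a little-endian u16 header length, a Python-dict header padded with spaces to a 64-byte boundary and ended by a newline, and then the raw little-endian data. An .npz file is a zip archive of such .npy entries. How the crate itself writes these files may differ.

```rust
/// Hypothetical helper: serialize a 1-D f32 slice as NPY v1.0 bytes ('<f4', C order).
fn to_npy_1d(data: &[f32]) -> Vec<u8> {
    // Header is a Python dict literal; a 1-D shape needs the trailing comma, e.g. (3,).
    let mut header = format!(
        "{{'descr': '<f4', 'fortran_order': False, 'shape': ({},), }}",
        data.len()
    );
    // magic (6 bytes) + version (2) + header-length field (2) = 10-byte preamble.
    let preamble = 10;
    // Pad with spaces so preamble + header + '\n' is a multiple of 64.
    let unpadded = preamble + header.len() + 1;
    let padding = (64 - unpadded % 64) % 64;
    header.push_str(&" ".repeat(padding));
    header.push('\n');

    let mut out = Vec::with_capacity(preamble + header.len() + data.len() * 4);
    out.extend_from_slice(b"\x93NUMPY");
    out.push(1); // major version
    out.push(0); // minor version
    // Assumes the header fits in a u16, which always holds for a 1-D shape.
    out.extend_from_slice(&(header.len() as u16).to_le_bytes());
    out.extend_from_slice(header.as_bytes());
    for x in data {
        out.extend_from_slice(&x.to_le_bytes());
    }
    out
}
```

For three elements the header pads out to 118 bytes, so the data starts at byte 128.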
A quick tutorial
- Tensors (crate::tensor::Tensor) can be created from normal Rust arrays. See crate::tensor.
let x = Tensor2D::new([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
let y: Tensor2D<2, 3> = Tensor2D::ones();
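To illustrate why keeping shapes in the type is useful, here is a stripped-down sketch, not the crate's implementation, of a Tensor2D whose dimensions are const generic parameters. The matmul signature only accepts a right-hand side whose row count equals the left side's column count, so a shape mismatch becomes a compile error instead of a runtime panic. The data field and the matmul method are illustrative names.

```rust
/// Minimal 2-D tensor: the shape (M x N) lives in the type, the values in a nested array.
#[derive(Debug, PartialEq)]
struct Tensor2D<const M: usize, const N: usize> {
    data: [[f32; N]; M],
}

impl<const M: usize, const N: usize> Tensor2D<M, N> {
    /// Build a tensor from a normal Rust array; M and N are inferred from its shape.
    fn new(data: [[f32; N]; M]) -> Self {
        Self { data }
    }

    /// Every element set to 1.0; the shape comes from the type annotation.
    fn ones() -> Self {
        Self { data: [[1.0; N]; M] }
    }

    /// (M x N) * (N x P) -> (M x P). Mismatched inner dimensions do not type-check.
    fn matmul<const P: usize>(&self, rhs: &Tensor2D<N, P>) -> Tensor2D<M, P> {
        let mut out = [[0.0f32; P]; M];
        for i in 0..M {
            for j in 0..P {
                for k in 0..N {
                    out[i][j] += self.data[i][k] * rhs.data[k][j];
                }
            }
        }
        Tensor2D { data: out }
    }
}
```

With x as the 2x3 tensor above, x.matmul(&Tensor2D::<3, 2>::ones()) compiles and yields a 2x2 result, while x.matmul(&x) is rejected by the compiler.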
- Neural networks are built with types. Tuples are sequential models. See crate::nn.
type Mlp = (
    Linear<5, 3>,
    ReLU,
    Linear<3, 2>,
);
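A tuple can act as a sequential model because a trait can be implemented for tuples whose elements chain together. The sketch below uses a hypothetical Module trait with an associated Output type (similar in spirit to crate::nn::Module, but not its actual definition) and two toy layers, Scale and ReLU, over plain arrays. The impl for (A, B, C) feeds each layer's output into the next.

```rust
/// Hypothetical layer trait: maps an input to an output type chosen by the layer.
trait Module<Input> {
    type Output;
    fn forward(&self, input: Input) -> Self::Output;
}

/// Element-wise ReLU over a fixed-size array.
struct ReLU;
impl<const N: usize> Module<[f32; N]> for ReLU {
    type Output = [f32; N];
    fn forward(&self, input: [f32; N]) -> [f32; N] {
        input.map(|x| x.max(0.0))
    }
}

/// Toy layer that multiplies every element by a constant.
struct Scale(f32);
impl<const N: usize> Module<[f32; N]> for Scale {
    type Output = [f32; N];
    fn forward(&self, input: [f32; N]) -> [f32; N] {
        input.map(|x| x * self.0)
    }
}

/// A 3-tuple runs its elements in order: A's output feeds B, B's output feeds C.
impl<Input, A, B, C> Module<Input> for (A, B, C)
where
    A: Module<Input>,
    B: Module<A::Output>,
    C: Module<B::Output>,
{
    type Output = C::Output;
    fn forward(&self, input: Input) -> Self::Output {
        let x = self.0.forward(input);
        let x = self.1.forward(x);
        self.2.forward(x)
    }
}
```

Because each bound names the previous element's Output, a tuple whose layers do not fit together fails to compile, which is the same property that makes type Mlp above well-formed.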
- Instantiate models with Default, and randomize their parameters with crate::nn::ResetParams.
let mut mlp: Linear<5, 2> = Default::default();
mlp.reset_params(&mut rng);
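Here is one way randomizing a linear layer's parameters could look. This sketch assumes a PyTorch-style uniform initialization in [-1/sqrt(I), 1/sqrt(I)), where I is the number of input features; the crate's actual distribution may differ. To stay dependency-free it uses a hypothetical XorShift64 generator in place of a rand::Rng.

```rust
/// Hypothetical dependency-free RNG (Marsaglia xorshift64); state must be non-zero.
struct XorShift64(u64);

impl XorShift64 {
    /// Uniform f32 in [0, 1) built from the top 24 bits of the next state.
    fn next_f32(&mut self) -> f32 {
        self.0 ^= self.0 << 13;
        self.0 ^= self.0 >> 7;
        self.0 ^= self.0 << 17;
        (self.0 >> 40) as f32 / (1u64 << 24) as f32
    }
}

/// Sketch of a linear layer: O rows of I weights, plus one bias per output.
struct Linear<const I: usize, const O: usize> {
    weight: [[f32; I]; O],
    bias: [f32; O],
}

impl<const I: usize, const O: usize> Default for Linear<I, O> {
    /// All-zero parameters, as a starting point before randomization.
    fn default() -> Self {
        Self { weight: [[0.0; I]; O], bias: [0.0; O] }
    }
}

impl<const I: usize, const O: usize> Linear<I, O> {
    /// Assumed scheme: draw every parameter uniformly from [-1/sqrt(I), 1/sqrt(I)).
    fn reset_params(&mut self, rng: &mut XorShift64) {
        let bound = 1.0 / (I as f32).sqrt();
        let mut sample = || (2.0 * rng.next_f32() - 1.0) * bound;
        for row in self.weight.iter_mut() {
            for w in row.iter_mut() {
                *w = sample();
            }
        }
        for b in self.bias.iter_mut() {
            *b = sample();
        }
    }
}
```

Scaling the range by 1/sqrt(I) keeps the variance of each output roughly independent of the layer's fan-in, which is the usual motivation for this choice.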
- Pass data through networks with crate::nn::Module
let mut mlp: Linear<5, 2> = Default::default();
let x = Tensor1D::zeros(); // Rust infers that x must be a `Tensor1D<5>` because it's passed to mlp.forward()!
let y = mlp.forward(x); // Rust also infers that `y` is a `Tensor1D<2>`!
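The same inference works for any forward pass whose lengths are const generics. In this dependency-free sketch, Linear<I, O> stores an O x I weight matrix and computes y = Wx + b on plain arrays. The zeros helper is hypothetical; like Tensor1D::zeros(), its length N is never written down but inferred from how the result is used.

```rust
/// Sketch of a linear layer: one weight row per output feature, plus a bias.
struct Linear<const I: usize, const O: usize> {
    weight: [[f32; I]; O],
    bias: [f32; O],
}

impl<const I: usize, const O: usize> Default for Linear<I, O> {
    fn default() -> Self {
        Self { weight: [[0.0; I]; O], bias: [0.0; O] }
    }
}

impl<const I: usize, const O: usize> Linear<I, O> {
    /// y = W x + b; both I and O come from the layer's type.
    fn forward(&self, x: [f32; I]) -> [f32; O] {
        let mut y = self.bias;
        for (o, row) in self.weight.iter().enumerate() {
            y[o] += row.iter().zip(x.iter()).map(|(w, xi)| w * xi).sum::<f32>();
        }
        y
    }
}

/// Hypothetical helper: a zero vector whose length N is inferred by the compiler.
fn zeros<const N: usize>() -> [f32; N] {
    [0.0; N]
}
```

In let x = zeros(); mlp.forward(x) with mlp: Linear<5, 2>, the compiler fixes N = 5 from forward's parameter type and gives the result type [f32; 2].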
- Trace gradients using crate::tensor::trace()
// tensors default to not having a tape
let x: Tensor1D<10, NoneTape> = Tensor1D::zeros();
// `.trace()` clones `x` and inserts a gradient tape.
let x_t: Tensor1D<10, OwnedTape> = x.trace();
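// `model` can be any module mapping a Tensor1D<10> to a Tensor1D<5>, e.g. a single linear layer.
let mut model: Linear<10, 5> = Default::default();
// The tape from the input is moved through the network during .forward().
let y: Tensor1D<5, NoneTape> = model.forward(x);
let y_t: Tensor1D<5, OwnedTape> = model.forward(x_t);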
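The idea behind a gradient tape is reverse-mode automatic differentiation: record each operation as it runs, then replay the record backwards, applying the chain rule. The scalar sketch below shows that idea only. It keeps one shared Tape and refers to values by index, whereas this crate attaches an OwnedTape to the tensor that flows through forward(); the Tape type and its methods are hypothetical.

```rust
/// Minimal scalar gradient tape. Each node stores its value and, for every
/// parent, the local derivative d(node)/d(parent).
#[derive(Default)]
struct Tape {
    values: Vec<f64>,
    parents: Vec<Vec<(usize, f64)>>,
}

impl Tape {
    fn push(&mut self, value: f64, parents: Vec<(usize, f64)>) -> usize {
        self.values.push(value);
        self.parents.push(parents);
        self.values.len() - 1
    }

    /// An input with no parents.
    fn leaf(&mut self, value: f64) -> usize {
        self.push(value, vec![])
    }

    fn add(&mut self, a: usize, b: usize) -> usize {
        let v = self.values[a] + self.values[b];
        self.push(v, vec![(a, 1.0), (b, 1.0)])
    }

    fn mul(&mut self, a: usize, b: usize) -> usize {
        let (va, vb) = (self.values[a], self.values[b]);
        self.push(va * vb, vec![(a, vb), (b, va)])
    }

    fn relu(&mut self, a: usize) -> usize {
        let va = self.values[a];
        self.push(va.max(0.0), vec![(a, if va > 0.0 { 1.0 } else { 0.0 })])
    }

    /// Nodes are appended in evaluation order, so walking them in reverse
    /// visits every node after all of its consumers (chain rule).
    fn backward(&self, output: usize) -> Vec<f64> {
        let mut grads = vec![0.0; self.values.len()];
        grads[output] = 1.0;
        for i in (0..=output).rev() {
            for &(parent, local) in &self.parents[i] {
                grads[parent] += grads[i] * local;
            }
        }
        grads
    }
}
```

For y = relu(w * x + b) with x = 3, w = 2 and b = 1, backward(y) gives dy/dw = 3, dy/dx = 2 and dy/db = 1.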
- Compute gradients with crate::tensor_ops::backward(). See crate::tensor_ops.
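// compute cross entropy loss against the target probabilities
let y_true: Tensor1D<5> = Tensor1D::new([1.0, 0.0, 0.0, 0.0, 0.0]);
let loss: Tensor0D<OwnedTape> = cross_entropy_with_logits_loss(y_t, &y_true);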
// call `backward()` to compute gradients. The tensor *must* have `OwnedTape`!
let gradients: Gradients = loss.backward();
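For a single example with logits z and target probabilities t, cross-entropy with logits is L = -Σ_i t_i (z_i - log Σ_j e^{z_j}), and when t sums to 1 its gradient with respect to the logits is softmax(z) - t. The functions below are a standalone sketch of that math, using the log-sum-exp trick for numerical stability. They do not reproduce how cross_entropy_with_logits_loss reduces over batches.

```rust
/// log(sum(exp(z))), shifted by max(z) so exp() cannot overflow.
fn log_sum_exp(z: &[f64]) -> f64 {
    let max = z.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    max + z.iter().map(|&v| (v - max).exp()).sum::<f64>().ln()
}

/// -sum_i t_i * log_softmax(z)_i, where log_softmax(z)_i = z_i - log_sum_exp(z).
fn cross_entropy_with_logits(logits: &[f64], target: &[f64]) -> f64 {
    let lse = log_sum_exp(logits);
    -logits.iter().zip(target).map(|(&z, &t)| t * (z - lse)).sum::<f64>()
}

/// Analytic gradient w.r.t. the logits: softmax(z) - t (assumes t sums to 1).
fn cross_entropy_grad(logits: &[f64], target: &[f64]) -> Vec<f64> {
    let lse = log_sum_exp(logits);
    logits.iter().zip(target).map(|(&z, &t)| (z - lse).exp() - t).collect()
}
```

With logits [0, 0] and target [1, 0], the loss is ln 2 and the gradient is [-0.5, 0.5]. This is the quantity a call to backward() propagates from the loss back into the network.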
- Use an optimizer from crate::optim to optimize your network!
// Use stochastic gradient descent (Sgd), with a learning rate of 1e-2, and 0.9 momentum.
let mut opt = Sgd::new(SgdConfig {
    lr: 1e-2,
    momentum: Some(Momentum::Classic(0.9)),
});
// pass the gradients & the model into the optimizer's update method
opt.update(&mut model, gradients);
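Here is the update rule a classic-momentum Sgd step applies to each parameter, assuming the PyTorch-style formulation v ← μ·v + g, θ ← θ - lr·v. Whether Momentum::Classic follows exactly this form is an assumption; check crate::optim to confirm. The Sgd struct here is a hypothetical flat-slice version, not the crate's type.

```rust
/// Hypothetical SGD with classic (heavy-ball) momentum over a flat parameter slice.
struct Sgd {
    lr: f32,
    momentum: f32,
    velocity: Vec<f32>, // one running velocity per parameter, starting at zero
}

impl Sgd {
    fn new(lr: f32, momentum: f32, num_params: usize) -> Self {
        Self { lr, momentum, velocity: vec![0.0; num_params] }
    }

    /// v <- momentum * v + g; theta <- theta - lr * v
    fn update(&mut self, params: &mut [f32], grads: &[f32]) {
        assert_eq!(params.len(), grads.len());
        for ((p, &g), v) in params.iter_mut().zip(grads).zip(self.velocity.iter_mut()) {
            *v = self.momentum * *v + g;
            *p -= self.lr * *v;
        }
    }
}
```

Minimizing f(θ) = θ² from θ = 1 with lr = 0.1 and momentum 0.9, the first two steps move θ to 0.8 and then 0.46. The velocity lets later steps grow past a plain lr·g step while the gradient keeps pointing the same way.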
Modules
Collection of traits to describe Nd arrays.
A collection of data utility classes such as one_hot_encode() and SubsetIterator.
Implementations of GradientTape and generic Nd array containers via Gradients.
Standard loss functions such as mse_loss(), cross_entropy_with_logits_loss(), and more.
Contains all public exports.
A simple implementation of a UID used as a unique key for tensors.
Structs
Used to assert things about const generics
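Compile-time checks on const generics can also be written on stable Rust with a panicking associated constant, which is evaluated when the generic is instantiated. The crate's own Assert struct may work differently (for example through trait bounds); the sketch below uses this alternative technique, and AssertRank and make_shape are hypothetical names that enforce the "up to 4d" limit.

```rust
/// Hypothetical compile-time check that a rank R is at most 4.
struct AssertRank<const R: usize>;

impl<const R: usize> AssertRank<R> {
    // Panicking in a const context becomes a compile error at instantiation.
    const OK: () = assert!(R <= 4, "tensors are limited to 4 dimensions");
}

/// Accepts a shape only if its rank passes the check above.
fn make_shape<const R: usize>(dims: [usize; R]) -> [usize; R] {
    // Referencing the associated const forces it to be evaluated for this R.
    let () = AssertRank::<R>::OK;
    dims
}
```

make_shape([1, 2, 3, 4]) builds normally, while make_shape([1, 2, 3, 4, 5]) fails with the assertion message. Because the error is raised after monomorphization, it may only surface during a full build rather than cargo check.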
Constants
The library used for BLAS. Configure with crate features.
Functions
Sets a CPU sse flag to flush denormal floating point numbers to zero. The opposite of this is keep_denormals().
Sets a CPU flag to keep denormal floating point numbers. The opposite of this is flush_denormals_to_zero().
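Denormal (subnormal) floats are non-zero values smaller in magnitude than f32::MIN_POSITIVE, the smallest normal f32. Arithmetic on them can be much slower on many CPUs, which is why a flush-to-zero flag exists. The sketch below only uses std to detect a denormal and apply the same flush in software for a single value; the CPU flag does this in hardware for every operation, and the helper names are hypothetical.

```rust
use std::num::FpCategory;

/// True if `x` is subnormal: non-zero, but below f32::MIN_POSITIVE in magnitude.
fn is_denormal(x: f32) -> bool {
    x.classify() == FpCategory::Subnormal
}

/// Software analogue of flush-to-zero for one value, keeping the sign bit.
fn flush_to_zero(x: f32) -> f32 {
    if is_denormal(x) {
        0.0f32.copysign(x)
    } else {
        x
    }
}
```

For example, f32::MIN_POSITIVE / 2.0 is subnormal and flushes to 0.0, while f32::MIN_POSITIVE itself is left unchanged.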