tch-rs
Rust bindings for PyTorch. The goal of the tch crate is to provide thin wrappers around the C++ PyTorch API (a.k.a. libtorch) that stay as close as possible to the original C++ API. More idiomatic Rust bindings could then be developed on top of these wrappers. The documentation can be found on docs.rs.
The code generation for the C API on top of libtorch comes from ocaml-torch.
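To illustrate what a "thin wrapper" over a C API usually means in Rust, here is a minimal sketch of the common ownership (RAII) pattern: a Rust type owns a raw handle returned by the C side and frees it exactly once in `Drop`. The C-style functions (`c_tensor_new`, `c_tensor_sum`, `c_tensor_free`) and the `Tensor` type are hypothetical. They are simulated in Rust so the sketch runs on its own, and they are not tch's actual generated bindings.

```rust
// Hypothetical C-side object, opaque to callers of the C API.
pub struct CTensor {
    data: Vec<f64>,
}

// Hypothetical C-style API, implemented in Rust here so the sketch is runnable.
// A real binding would declare the C functions in an `extern "C"` block instead.
extern "C" fn c_tensor_new(len: usize, fill: f64) -> *mut CTensor {
    Box::into_raw(Box::new(CTensor { data: vec![fill; len] }))
}

unsafe extern "C" fn c_tensor_sum(t: *const CTensor) -> f64 {
    // SAFETY: the caller guarantees `t` is a live handle from c_tensor_new.
    unsafe { (*t).data.iter().sum() }
}

unsafe extern "C" fn c_tensor_free(t: *mut CTensor) {
    if !t.is_null() {
        // SAFETY: `t` came from Box::into_raw and has not been freed yet.
        unsafe { drop(Box::from_raw(t)) }
    }
}

/// Thin, safe wrapper: owns the raw handle and releases it in Drop.
pub struct Tensor {
    raw: *mut CTensor,
}

impl Tensor {
    /// Creates a tensor of `len` elements, all equal to `fill`.
    pub fn full(len: usize, fill: f64) -> Tensor {
        Tensor { raw: c_tensor_new(len, fill) }
    }

    /// Sums all elements by forwarding to the C-style function.
    pub fn sum(&self) -> f64 {
        // SAFETY: `raw` is valid for the lifetime of `self`.
        unsafe { c_tensor_sum(self.raw) }
    }
}

impl Drop for Tensor {
    fn drop(&mut self) {
        // SAFETY: Drop runs once, so the handle is freed exactly once.
        unsafe { c_tensor_free(self.raw) }
    }
}
```

The wrapper adds memory safety but deliberately no new semantics: each method forwards to one C call. That is what keeps such bindings close to the underlying API.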
Getting Started
This crate requires the C++ version of PyTorch (libtorch) to be available on your system. You can either install it manually and let the build script know about it via the LIBTORCH environment variable, or leave LIBTORCH unset, in which case the build script will try to download and extract a pre-built binary version of libtorch.
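The choice the build script makes can be modelled as a small function. This is a hypothetical sketch of the decision described above, not the crate's actual build script, which may check more than this. It also treats an empty LIBTORCH value as unset, which is an assumption.

```rust
use std::path::PathBuf;

/// Where libtorch comes from (hypothetical model of the build script's choice).
#[derive(Debug, PartialEq)]
enum LibtorchSource {
    /// LIBTORCH is set: use the manually installed libtorch at this path.
    Manual(PathBuf),
    /// LIBTORCH is unset: fall back to downloading a pre-built libtorch.
    Download,
}

/// Decides the source from the value of LIBTORCH, if any.
/// Assumption: an empty value is treated the same as an unset variable.
fn libtorch_source(libtorch_var: Option<String>) -> LibtorchSource {
    match libtorch_var {
        Some(dir) if !dir.is_empty() => LibtorchSource::Manual(PathBuf::from(dir)),
        _ => LibtorchSource::Download,
    }
}

/// Reads the real environment. A real build script would also print
/// `cargo:rerun-if-env-changed=LIBTORCH` so Cargo reruns it when the variable changes.
fn libtorch_source_from_env() -> LibtorchSource {
    // env::var fails for unset or non-UTF-8 values; both fall back to Download here.
    libtorch_source(std::env::var("LIBTORCH").ok())
}
```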
Libtorch Manual Install
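- Get libtorch from the download section of the PyTorch website and extract the contents of the zip file.
- Add the following to your .bashrc or equivalent, where /path/to/libtorch is the directory created when extracting the zip file:
  export LIBTORCH=/path/to/libtorch
  export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH
- You should now be able to run some examples, e.g. cargo run --example basics.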
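On Linux, a manual install relies on two environment variables: LIBTORCH tells the build script where libtorch lives, and LD_LIBRARY_PATH lets the dynamic loader find the shared libraries under its lib directory when an example runs. The sketch below shows how the new LD_LIBRARY_PATH value is assembled. The helpers `ld_library_path_with` and `ld_library_path_from_env` are hypothetical and not part of tch.

```rust
use std::env;
use std::ffi::OsString;
use std::path::{Path, PathBuf};

/// Hypothetical helper: returns an LD_LIBRARY_PATH value with `<libtorch>/lib`
/// prepended to the existing entries, if any.
fn ld_library_path_with(
    libtorch: &Path,
    existing: Option<OsString>,
) -> Result<OsString, env::JoinPathsError> {
    // Put libtorch's shared libraries first so they take precedence.
    let mut entries: Vec<PathBuf> = vec![libtorch.join("lib")];
    if let Some(existing) = existing {
        // Keep previous entries but skip empty ones: an empty entry in
        // LD_LIBRARY_PATH is interpreted as the current directory.
        entries.extend(env::split_paths(&existing).filter(|p| !p.as_os_str().is_empty()));
    }
    env::join_paths(entries)
}

/// Hypothetical convenience wrapper reading the current environment.
/// Returns None when LIBTORCH is not set.
fn ld_library_path_from_env() -> Option<Result<OsString, env::JoinPathsError>> {
    let libtorch = env::var_os("LIBTORCH")?;
    Some(ld_library_path_with(Path::new(&libtorch), env::var_os("LD_LIBRARY_PATH")))
}
```

For instance, with LIBTORCH=/opt/libtorch and an existing LD_LIBRARY_PATH of /usr/lib, the helper yields /opt/libtorch/lib:/usr/lib on Unix, which matches the shell export line.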
Examples
The examples include a simple model with one hidden layer.
This model can be trained on the MNIST dataset by running cargo run --example mnist.
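For readers new to the architecture, this dependency-free sketch spells out what a model with one hidden layer computes: a linear layer, a ReLU, and a second linear layer producing one unnormalized score (logit) per class. It is plain Rust written for illustration, not the tch API; in tch the layers, their parameters, and the gradients come from the crate. For MNIST the input would be the 784 pixels of a 28×28 image and the output 10 class logits; the hidden width is a free choice.

```rust
/// A dense layer computing y = W x + b; `weights` holds one row of inputs per output.
struct Linear {
    weights: Vec<Vec<f32>>,
    bias: Vec<f32>,
}

impl Linear {
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        self.weights
            .iter()
            .zip(&self.bias)
            .map(|(row, b)| row.iter().zip(x).map(|(w, xi)| w * xi).sum::<f32>() + b)
            .collect()
    }
}

/// One hidden layer: linear -> ReLU -> linear, returning unnormalized logits.
struct Mlp {
    hidden: Linear,
    output: Linear,
}

impl Mlp {
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        // ReLU zeroes out negative pre-activations of the hidden layer.
        let h: Vec<f32> = self.hidden.forward(x).into_iter().map(|v| v.max(0.0)).collect();
        self.output.forward(&h)
    }
}
```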
More details on the training loop can be found in the detailed tutorial.
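As a rough picture of what a training loop does at each epoch (compute logits, measure the cross-entropy loss, compute gradients, update the parameters), here is a self-contained sketch. To keep the hand-written gradients short, it swaps in softmax regression, which is a single linear layer, instead of the hidden-layer model. It uses full-batch gradient descent on a toy dataset. With tch, gradients would come from libtorch's autograd rather than being derived by hand. All names here are hypothetical.

```rust
/// Numerically stable softmax: subtracting the max avoids overflow in exp.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|z| (z - max).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

/// Softmax regression: weights `w[class][feature]`, bias `b[class]`.
struct SoftmaxRegression {
    w: Vec<Vec<f64>>,
    b: Vec<f64>,
}

impl SoftmaxRegression {
    fn new(features: usize, classes: usize) -> Self {
        SoftmaxRegression { w: vec![vec![0.0; features]; classes], b: vec![0.0; classes] }
    }

    fn logits(&self, x: &[f64]) -> Vec<f64> {
        self.w
            .iter()
            .zip(&self.b)
            .map(|(row, b)| row.iter().zip(x).map(|(w, xi)| w * xi).sum::<f64>() + b)
            .collect()
    }

    /// Index of the largest logit.
    fn predict(&self, x: &[f64]) -> usize {
        let logits = self.logits(x);
        let mut best = 0;
        for (k, &z) in logits.iter().enumerate() {
            if z > logits[best] {
                best = k;
            }
        }
        best
    }

    /// Mean cross-entropy loss over the dataset.
    fn loss(&self, xs: &[Vec<f64>], ys: &[usize]) -> f64 {
        let total: f64 = xs.iter().zip(ys).map(|(x, &y)| -softmax(&self.logits(x))[y].ln()).sum();
        total / xs.len() as f64
    }

    /// One full-batch gradient descent step.
    fn step(&mut self, xs: &[Vec<f64>], ys: &[usize], lr: f64) {
        let n = xs.len() as f64;
        let mut gw = vec![vec![0.0; self.w[0].len()]; self.w.len()];
        let mut gb = vec![0.0; self.b.len()];
        for (x, &y) in xs.iter().zip(ys) {
            let p = softmax(&self.logits(x));
            for k in 0..p.len() {
                // For softmax + cross-entropy, d(loss)/d(logit_k) = p_k - 1[k == y].
                let target = if k == y { 1.0 } else { 0.0 };
                let d = p[k] - target;
                gb[k] += d / n;
                for (j, xj) in x.iter().enumerate() {
                    gw[k][j] += d * xj / n;
                }
            }
        }
        for k in 0..self.w.len() {
            self.b[k] -= lr * gb[k];
            for j in 0..self.w[k].len() {
                self.w[k][j] -= lr * gw[k][j];
            }
        }
    }
}

/// Runs `epochs` steps and records the loss after each one.
fn train(model: &mut SoftmaxRegression, xs: &[Vec<f64>], ys: &[usize], epochs: usize, lr: f64) -> Vec<f64> {
    let mut losses = Vec::with_capacity(epochs);
    for _ in 0..epochs {
        model.step(xs, ys, lr);
        losses.push(model.loss(xs, ys));
    }
    losses
}
```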
Further examples include:
- A simplified version of char-rnn illustrating character-level language modeling using Recurrent Neural Networks.
- Some ResNet examples on CIFAR-10.
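The char-rnn example is built on the recurrence at the heart of a recurrent network: the hidden state after each character depends on that character and on the previous hidden state. This sketch shows a vanilla (Elman) RNN cell, h_t = tanh(W_xh x_t + W_hh h_{t-1} + b), applied to a string. It is a plain-Rust illustration and not tch's RNN API. The output layer that predicts the next character and gated variants such as LSTM or GRU are omitted, and `RnnCell` and its fields are hypothetical names.

```rust
/// Vanilla RNN cell: w_xh is hidden x vocab, w_hh is hidden x hidden.
struct RnnCell {
    w_xh: Vec<Vec<f64>>,
    w_hh: Vec<Vec<f64>>,
    b: Vec<f64>,
}

impl RnnCell {
    /// One step. The input character is given by its vocabulary index, so
    /// W_xh x_t for a one-hot x_t is simply column `ch` of W_xh.
    fn step(&self, ch: usize, h_prev: &[f64]) -> Vec<f64> {
        (0..self.b.len())
            .map(|i| {
                let recurrent: f64 = self.w_hh[i].iter().zip(h_prev).map(|(w, h)| w * h).sum();
                (self.w_xh[i][ch] + recurrent + self.b[i]).tanh()
            })
            .collect()
    }

    /// Runs the cell over `text`, starting from a zero state, and returns
    /// the hidden state after each character.
    fn run(&self, text: &str, vocab: &[char]) -> Vec<Vec<f64>> {
        let mut h = vec![0.0; self.b.len()];
        let mut states = Vec::new();
        for c in text.chars() {
            let ch = vocab.iter().position(|&v| v == c).expect("character not in vocabulary");
            h = self.step(ch, &h);
            states.push(h.clone());
        }
        states
    }
}
```

Because the same weights are reused at every position, the cell can process text of any length. The hidden state is the only channel through which earlier characters influence later predictions.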