# Ryft: A Rust Framework for Tracing, Automatic Differentiation, and Just-In-Time Compilation
> [!WARNING]
> `ryft` is currently a work in progress and is evolving very actively. APIs and crate boundaries may change.
`ryft` is a Rust library, inspired by JAX, for building machine learning systems. It aims to bring type-safe
tracing, automatic differentiation, and just-in-time compilation to Rust, with support for hardware accelerators.
The top-level `ryft` crate is an umbrella crate that re-exports functionality from a few different crates through a
single entry point:
- `ryft-core`: Intended home for core tracing, automatic differentiation, and just-in-time (JIT) compilation. This crate is still at an early stage and should not be depended upon; it is expected to start taking shape in the coming months. Today, the most complete and usable part of `ryft-core` is the `Parameterized` API.
- `ryft-macros`: Procedural macros used by `ryft` and `ryft-core` (e.g., parameter-related derive macros).
- `ryft-mlir`: High-level, ownership-aware Rust bindings for MLIR and the MLIR dialects used by XLA tooling.
- `ryft-pjrt`: High-level, ownership-aware Rust bindings for PJRT plugins, clients, buffers, and program execution.
- `ryft-xla-sys`: Low-level `-sys` bindings for XLA/MLIR/PJRT APIs, plus native artifact/toolchain wiring.
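Because `ryft` re-exports these crates, downstream projects would typically depend only on the umbrella crate. A minimal `Cargo.toml` entry might look like the following (the version number here is illustrative, not a published release):

```toml
[dependencies]
# Pulls in ryft-core and, via the default "xla" feature,
# ryft-mlir, ryft-pjrt, and ryft-xla-sys.
ryft = "0.1"
```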
## Feature Flags
The `ryft` crate enables the `xla` feature by default, which brings in the `ryft-mlir`, `ryft-pjrt`, and
`ryft-xla-sys` dependencies. Accelerator-specific features (e.g., `cuda-12`, `cuda-13`, `rocm-7`, `tpu`, `neuron`,
and `metal`) are forwarded through the crate stack (`ryft` -> `ryft-core` -> `ryft-pjrt` -> `ryft-xla-sys`). For
feature semantics, platform/runtime requirements, and artifact-loading behavior, refer to:
- `crates/ryft-xla-sys/README.md`: Reference for XLA dependencies and instructions on configuring how pre-built binaries are obtained for supported platforms.
- `crates/ryft-pjrt/README.md`: Reference for our PJRT bindings.
- `crates/ryft-mlir/README.md`: Reference for our MLIR bindings.
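For instance, to target an accelerator instead of the default CPU path, a downstream crate can enable one of the forwarded features in its `Cargo.toml` (the version number is illustrative):

```toml
[dependencies]
# The "cuda-13" feature is forwarded down the crate stack:
# ryft -> ryft-core -> ryft-pjrt -> ryft-xla-sys.
ryft = { version = "0.1", features = ["cuda-13"] }
```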
## Example: Low-Level StableHLO Matrix Multiplication
The following example uses the low-level MLIR and PJRT APIs provided by `ryft::mlir` and `ryft::pjrt` to build a toy
StableHLO matrix multiplication module programmatically, compile it, and execute it on the CPU plugin. Note that this
is quite low-level and verbose. `ryft::core` will make compiling and executing programs like this a lot more
ergonomic, similar to what JAX accomplishes in Python. Updates on that crate should be coming in the next few weeks
or months.
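For reference, the kind of StableHLO module the example constructs programmatically corresponds to textual IR along these lines (the shapes and the exact pretty-printed form are illustrative, not taken from the example):

```mlir
module {
  func.func @main(%lhs: tensor<2x3xf32>, %rhs: tensor<3x4xf32>) -> tensor<2x4xf32> {
    // Matrix multiplication: contract dimension 1 of %lhs with dimension 0 of %rhs.
    %0 = stablehlo.dot_general %lhs, %rhs, contracting_dims = [1] x [0]
        : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32>
    return %0 : tensor<2x4xf32>
  }
}
```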
> [!NOTE]
> If you want to run on CUDA 13 instead, enable `ryft`'s `cuda-13` feature and replace `load_cpu_plugin()` with `load_cuda_13_plugin()` in the example code below.
```rust
use ryft::mlir::*;
use ryft::pjrt::*;
```
## Why "Ryft"?
The name for this framework started from the idea of Rust + Lift: "lifting" computations through tracing so they can
be transformed for automatic differentiation and just-in-time compilation. That naturally suggested the name *rift*.
Since that name was already taken, I chose *ryft* as a close alternative with the same original inspiration.
The short, catchy spelling also matches a core goal of the project: fast & efficient execution.