Crate jax_rs

jax-rs: JAX in Rust

A machine learning framework for the web, running on WebGPU & Wasm.

Key Features

  • NumPy-compatible API: Familiar array creation and manipulation
  • Automatic differentiation: grad, vjp, jvp for computing gradients
  • Vectorization: vmap for batching operations
  • JIT compilation: Fused kernel execution for performance
  • Multiple backends: CPU (debugging), WebAssembly, WebGPU
  • Rust memory safety: No manual reference counting, automatic cleanup via Drop
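Forward-mode differentiation (what jvp computes) can be illustrated with dual numbers: each value carries a tangent, and every operation applies its derivative rule to that tangent. The sketch below is a minimal scalar version under that assumption; the `Dual` type and this `jvp` function are hypothetical illustrations, not the jax_rs `trace` implementation or its signatures.

```rust
use std::ops::{Add, Mul};

/// Hypothetical dual number `val + tan * eps` with eps^2 = 0.
/// The tangent carries the directional derivative through each op.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Dual {
    val: f64,
    tan: f64,
}

impl Add for Dual {
    type Output = Dual;
    // Sum rule: tangents add.
    fn add(self, o: Dual) -> Dual {
        Dual { val: self.val + o.val, tan: self.tan + o.tan }
    }
}

impl Mul for Dual {
    type Output = Dual;
    // Product rule: (a + a' eps)(b + b' eps) = ab + (a'b + ab') eps
    fn mul(self, o: Dual) -> Dual {
        Dual { val: self.val * o.val, tan: self.tan * o.val + self.val * o.tan }
    }
}

impl Dual {
    // Chain rule: d(sin a) = cos(a) * da
    fn sin(self) -> Dual {
        Dual { val: self.val.sin(), tan: self.tan * self.val.cos() }
    }
}

/// Scalar JVP sketch: evaluate `f` at `x` while pushing tangent `t`
/// through it. Returns (f(x), f'(x) * t).
fn jvp<F: Fn(Dual) -> Dual>(f: F, x: f64, t: f64) -> (f64, f64) {
    let out = f(Dual { val: x, tan: t });
    (out.val, out.tan)
}
```

For example, `jvp(|x| x * x, 3.0, 1.0)` yields `(9.0, 6.0)`: one forward pass gives both the value and the derivative along the chosen direction.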

Quick Start

```rust
use jax_rs::{Array, DType, Shape};

// Create a 2x3 array of zeros with 32-bit float elements
let x = Array::zeros(Shape::new(vec![2, 3]), DType::Float32);
```

Re-exports

pub use trace::grad;
pub use trace::jit;
pub use trace::value_and_grad;
pub use trace::vmap;
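In JAX, grad and value_and_grad use reverse-mode differentiation: the forward pass is recorded, then adjoints are propagated backwards from the output. The sketch below shows that idea for a scalar function using a Wengert list (tape). `Tape`, `Var`, `leaf`, and this `value_and_grad` are hypothetical names for illustration only; they are not the jax_rs `trace` module's types or signatures.

```rust
use std::cell::RefCell;

/// A recorded value: its position on the tape and its forward value.
#[derive(Clone, Copy, Debug)]
struct Var {
    idx: usize,
    val: f64,
}

/// Wengert list: for each recorded value, its (parent index,
/// local partial derivative) pairs. Leaves have no parents.
struct Tape {
    parents: RefCell<Vec<Vec<(usize, f64)>>>,
}

impl Tape {
    fn new() -> Tape {
        Tape { parents: RefCell::new(Vec::new()) }
    }

    fn push(&self, val: f64, deps: Vec<(usize, f64)>) -> Var {
        let mut p = self.parents.borrow_mut();
        p.push(deps);
        Var { idx: p.len() - 1, val }
    }

    fn leaf(&self, val: f64) -> Var {
        self.push(val, Vec::new())
    }

    fn add(&self, a: Var, b: Var) -> Var {
        // d(a+b)/da = 1, d(a+b)/db = 1
        self.push(a.val + b.val, vec![(a.idx, 1.0), (b.idx, 1.0)])
    }

    fn mul(&self, a: Var, b: Var) -> Var {
        // d(ab)/da = b, d(ab)/db = a
        self.push(a.val * b.val, vec![(a.idx, b.val), (b.idx, a.val)])
    }

    fn sin(&self, a: Var) -> Var {
        // d(sin a)/da = cos a
        self.push(a.val.sin(), vec![(a.idx, a.val.cos())])
    }
}

/// Reverse-mode sketch: record `f` on a tape, then sweep it backwards
/// accumulating adjoints. Returns (f(x), df/dx).
fn value_and_grad<F: Fn(&Tape, Var) -> Var>(f: F, x: f64) -> (f64, f64) {
    let tape = Tape::new();
    let input = tape.leaf(x);
    let out = f(&tape, input);
    let parents = tape.parents.borrow();
    let mut adj = vec![0.0; parents.len()];
    adj[out.idx] = 1.0; // seed: d(out)/d(out) = 1
    // Consumers always sit later on the tape, so a reverse sweep sees each
    // adjoint fully accumulated before propagating it to its parents.
    for i in (0..parents.len()).rev() {
        for &(p, w) in &parents[i] {
            let contrib = w * adj[i];
            adj[p] += contrib;
        }
    }
    (out.val, adj[input.idx])
}
```

One backward sweep yields the gradient with respect to every recorded input, which is why reverse mode suits scalar losses with many parameters.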

Modules

backend
Backend implementations for different compute devices.
nn
Neural network operations and activation functions.
ops
Array operations and transformations.
optim
Optimization algorithms for training neural networks.
random
Random number generation with reproducible PRNG keys.
scipy
SciPy special functions.
trace
Tracing infrastructure for JIT compilation and transformations.
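Reproducible randomness with PRNG keys, as in JAX, means a key is an immutable value: the same key always yields the same numbers, and fresh randomness comes from splitting a key into children. The sketch below illustrates that contract with a hypothetical `Key` type. It swaps in the SplitMix64 mixing function for simplicity; JAX itself uses the Threefry counter-based generator, and this is not the jax_rs `random` module's API.

```rust
/// Hypothetical PRNG key: an immutable 64-bit state.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct Key(u64);

/// SplitMix64 increment (golden-ratio constant).
const GAMMA: u64 = 0x9e37_79b9_7f4a_7c15;

/// SplitMix64 finalizer: a bijective 64-bit mixing function.
fn mix(mut z: u64) -> u64 {
    z = (z ^ (z >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb);
    z ^ (z >> 31)
}

impl Key {
    fn new(seed: u64) -> Key {
        Key(mix(seed))
    }

    /// Derive two child keys. The parent is never mutated, so splitting
    /// the same key always yields the same pair.
    fn split(self) -> (Key, Key) {
        (
            Key(mix(self.0.wrapping_add(GAMMA))),
            Key(mix(self.0.wrapping_add(GAMMA.wrapping_mul(2)))),
        )
    }

    /// Uniform f64 in [0, 1) built from the top 53 bits of the mixed state.
    fn uniform(self) -> f64 {
        (mix(self.0) >> 11) as f64 / (1u64 << 53) as f64
    }
}
```

Because no hidden global state is advanced, a computation that threads keys explicitly produces identical samples on every run and on every backend.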

Structs

Array
A multidimensional numeric array.
Shape
Shape of an n-dimensional array.
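A shape determines how a multi-index maps to a position in a flat buffer. The sketch below assumes row-major (C-order) layout, which is NumPy's default; whether jax_rs stores `Array` data this way is not stated here, and `row_major_strides` and `flat_index` are hypothetical helpers.

```rust
/// Row-major strides in elements: the last axis is contiguous and each
/// earlier axis steps over the product of all later dimensions.
fn row_major_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}

/// Flat buffer offset of a multi-index, or None if the index has the
/// wrong rank or is out of bounds on any axis.
fn flat_index(shape: &[usize], index: &[usize]) -> Option<usize> {
    if index.len() != shape.len() || index.iter().zip(shape).any(|(&i, &d)| i >= d) {
        return None;
    }
    let strides = row_major_strides(shape);
    Some(index.iter().zip(&strides).map(|(i, s)| i * s).sum())
}
```

For a shape of `[2, 3]` the strides are `[3, 1]`, so element `[1, 2]` lives at offset `1 * 3 + 2 * 1 = 5`.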

Enums

DType
Numerical data type for array contents.
Device
Compute device for array operations.
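A dtype fixes the size of each element, so together with a shape it fixes the size of a dense buffer. The sketch below uses a hypothetical `ToyDType`; only `Float32` corresponds to a variant shown in these docs, and the other variants and the `buffer_bytes` helper are illustrative assumptions rather than the jax_rs `DType` definition.

```rust
/// Hypothetical dtype enum; only Float32 mirrors a variant from the docs,
/// the others are illustrative.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ToyDType {
    Float32,
    Float16,
    Int32,
}

impl ToyDType {
    /// Bytes occupied by one element of this type.
    fn size_in_bytes(self) -> usize {
        match self {
            ToyDType::Float32 | ToyDType::Int32 => 4,
            ToyDType::Float16 => 2,
        }
    }
}

/// Bytes for a dense buffer holding `shape` elements of `dtype`
/// (a scalar shape `[]` holds one element).
fn buffer_bytes(shape: &[usize], dtype: ToyDType) -> usize {
    shape.iter().product::<usize>() * dtype.size_in_bytes()
}
```

Under these assumptions, the 2x3 `Float32` array from the Quick Start would occupy 6 elements times 4 bytes, that is 24 bytes.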

Functions

default_device
Get or set the default device.
set_default_device
Set the default device for array operations.
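A process-wide default device can be modeled as a global guarded by a lock, so any thread can read or replace it. The sketch below is a hypothetical illustration: `ToyDevice` mirrors the three backends named under Key Features, and `toy_default_device` / `toy_set_default_device` are stand-ins, not the jax_rs functions; how the crate actually stores its default (global, thread-local, or otherwise) is not stated in these docs.

```rust
use std::sync::Mutex;

/// Hypothetical device enum mirroring the backends listed under Key Features.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ToyDevice {
    Cpu,
    Wasm,
    WebGpu,
}

/// Global default, serialized by a Mutex. `Mutex::new` is a const fn,
/// so the static needs no lazy initialization.
static DEFAULT_DEVICE: Mutex<ToyDevice> = Mutex::new(ToyDevice::Cpu);

/// Read the current default device.
fn toy_default_device() -> ToyDevice {
    *DEFAULT_DEVICE.lock().unwrap()
}

/// Replace the default device for subsequent operations.
fn toy_set_default_device(device: ToyDevice) {
    *DEFAULT_DEVICE.lock().unwrap() = device;
}
```

A global default keeps array constructors short, at the cost of shared mutable state; passing a device explicitly avoids that coupling when different parts of a program target different backends.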