§jax-rs: JAX in Rust
A machine learning framework for the web, running on WebGPU & Wasm.
§Key Features
- NumPy-compatible API: Familiar array creation and manipulation
- Automatic differentiation: grad, vjp, and jvp for computing gradients
- Vectorization: vmap for batching operations
- JIT compilation: Fused kernel execution for performance
- Multiple backends: CPU (debugging), WebAssembly, WebGPU
- Rust memory safety: No manual reference counting; resources are cleaned up automatically via Drop
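Forward-mode differentiation, which is what jvp computes, can be illustrated without the crate. The sketch below is a minimal, std-only example using dual numbers; the Dual type and the scalar jvp function are hypothetical illustrations of the technique, not the jax-rs API, whose signatures may differ.

```rust
use std::ops::{Add, Mul};

/// A dual number `val + tan·ε` with `ε² = 0`; the tangent carries the
/// directional derivative alongside the primal value.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Dual {
    val: f64, // primal value
    tan: f64, // tangent (derivative along the chosen direction)
}

impl Add for Dual {
    type Output = Dual;
    // Sum rule: (a + a'ε) + (b + b'ε) = (a + b) + (a' + b')ε
    fn add(self, rhs: Dual) -> Dual {
        Dual { val: self.val + rhs.val, tan: self.tan + rhs.tan }
    }
}

impl Mul for Dual {
    type Output = Dual;
    // Product rule: (a + a'ε)(b + b'ε) = ab + (a'b + ab')ε
    fn mul(self, rhs: Dual) -> Dual {
        Dual { val: self.val * rhs.val, tan: self.tan * rhs.val + self.val * rhs.tan }
    }
}

/// Hypothetical scalar `jvp`: evaluates `f` at `x` while pushing the tangent `v`
/// through it, returning `(f(x), f'(x) * v)`.
fn jvp<F: Fn(Dual) -> Dual>(f: F, x: f64, v: f64) -> (f64, f64) {
    let out = f(Dual { val: x, tan: v });
    (out.val, out.tan)
}

fn main() {
    // f(x) = x^3 + x, so f'(x) = 3x^2 + 1.
    let f = |x: Dual| x * x * x + x;
    let (y, dy) = jvp(f, 2.0, 1.0);
    println!("f(2) = {y}, f'(2) = {dy}"); // prints "f(2) = 10, f'(2) = 13"
}
```

Each Dual carries a value and a tangent, and the overloaded operators apply the sum and product rules, so a single evaluation yields both f(x) and the directional derivative f'(x)·v. Reverse mode (vjp, grad) instead records the computation and propagates cotangents backwards, which is cheaper when a function has many inputs and a single output.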
§Quick Start
use jax_rs::{Array, DType, Shape};
// Create arrays
let x = Array::zeros(Shape::new(vec![2, 3]), DType::Float32);
§Re-exports
pub use trace::grad;
pub use trace::jit;
pub use trace::value_and_grad;
pub use trace::vmap;
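As a rough picture of what a vmap-style transform does, the sketch below lifts a per-example function to one that runs over a whole batch. The vmap shown here is a hypothetical, std-only stand-in that simply loops over rows; it is not the signature of the re-exported trace::vmap, and a real vectorizing transform would typically rewrite the function into batched array operations instead of looping.

```rust
/// Hypothetical `vmap`-style combinator: lifts a per-example function
/// `f: &[f32] -> f32` to one that is applied to every row of a batch.
/// This sketch just loops; it does not vectorize anything.
fn vmap<F>(f: F) -> impl Fn(&[Vec<f32>]) -> Vec<f32>
where
    F: Fn(&[f32]) -> f32,
{
    move |batch: &[Vec<f32>]| -> Vec<f32> {
        batch.iter().map(|row| f(row.as_slice())).collect()
    }
}

fn main() {
    // Per-example function: dot product with a fixed weight vector.
    let w = vec![1.0_f32, 2.0, 3.0];
    let dot = move |x: &[f32]| x.iter().zip(&w).map(|(a, b)| a * b).sum::<f32>();

    // The batched version maps over the leading (batch) dimension.
    let batched_dot = vmap(dot);
    let batch: Vec<Vec<f32>> = vec![vec![1.0, 0.0, 0.0], vec![1.0, 1.0, 1.0]];
    println!("{:?}", batched_dot(batch.as_slice())); // prints "[1.0, 6.0]"
}
```

The point of the design is that the per-example function is written once, with no batch dimension in mind, and the transform supplies the batching.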
§Modules
- backend: Backend implementations for different compute devices.
- nn: Neural network operations and activation functions.
- ops: Array operations and transformations.
- optim: Optimization algorithms for training neural networks.
- random: Random number generation with reproducible PRNG keys.
- scipy: SciPy special functions.
- trace: Tracing infrastructure for JIT compilation and transformations.
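The random module is described as using reproducible PRNG keys. Assuming it follows JAX's model (randomness is a pure function of an explicit key, and keys are split rather than mutated), the idea can be sketched in plain Rust. Key, split, and uniform are hypothetical names, and the SplitMix64 mixer is a stand-in chosen for brevity: JAX itself uses Threefry, and this toy construction is not statistically vetted.

```rust
/// Hypothetical explicit PRNG key. Randomness is a pure function of the key,
/// so the same key always yields the same numbers.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct Key(u64);

/// SplitMix64 finalizer: a cheap, bijective 64-bit mixing function.
fn mix(mut z: u64) -> u64 {
    z = z.wrapping_add(0x9E37_79B9_7F4A_7C15);
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^ (z >> 31)
}

impl Key {
    fn new(seed: u64) -> Key {
        Key(mix(seed))
    }

    /// Derive two child keys; by convention the parent is not reused afterwards.
    fn split(self) -> (Key, Key) {
        (
            Key(mix(self.0 ^ 0x5555_5555_5555_5555)),
            Key(mix(self.0 ^ 0xAAAA_AAAA_AAAA_AAAA)),
        )
    }

    /// Draw `n` uniform floats in [0, 1) deterministically from this key,
    /// using the top 53 bits of each mixed counter value.
    fn uniform(self, n: usize) -> Vec<f64> {
        (0..n as u64)
            .map(|i| (mix(self.0.wrapping_add(i)) >> 11) as f64 / (1u64 << 53) as f64)
            .collect()
    }
}

fn main() {
    let key = Key::new(42);
    // Same key, same draws: reproducible by construction.
    assert_eq!(key.uniform(3), key.uniform(3));
    // Splitting yields distinct child keys with their own streams.
    let (k1, k2) = key.split();
    assert_ne!(k1, k2);
    println!("{:?}", k1.uniform(2));
}
```

Because no hidden global state is mutated, a computation that receives the same key always reproduces the same samples, which is what makes results repeatable across runs.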
§Structs
§Enums
§Functions
- default_device: Get or set the default device.
- set_default_device: Set the default device for array operations.
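One plausible shape for a process-wide default device is a global value guarded by a mutex. In this sketch, the Device enum (mirroring the CPU, WebAssembly, and WebGPU backends listed above) and both function bodies are hypothetical; the actual default_device and set_default_device in jax-rs may have different signatures and semantics.

```rust
use std::sync::Mutex;

/// Hypothetical device enum mirroring the backends listed above.
#[allow(dead_code)] // Wasm is not used in this short demo
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Device {
    Cpu,
    Wasm,
    WebGpu,
}

/// Process-wide default; `Mutex::new` is `const`, so no lazy initialization is needed.
static DEFAULT_DEVICE: Mutex<Device> = Mutex::new(Device::Cpu);

/// Hypothetical getter: read the current default device.
fn default_device() -> Device {
    *DEFAULT_DEVICE.lock().unwrap()
}

/// Hypothetical setter: change the device used by new array operations,
/// returning the previous one so callers can restore it later.
fn set_default_device(device: Device) -> Device {
    let mut guard = DEFAULT_DEVICE.lock().unwrap();
    std::mem::replace(&mut *guard, device)
}

fn main() {
    assert_eq!(default_device(), Device::Cpu);
    let previous = set_default_device(Device::WebGpu);
    println!("{:?} -> {:?}", previous, default_device()); // prints "Cpu -> WebGpu"
    set_default_device(previous); // restore the earlier default
}
```

Returning the previous device is a design choice of this sketch: it lets a caller temporarily switch backends and then restore the old setting.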