jax_rs/
lib.rs

1//! # jax-rs: JAX in Rust
2//!
3//! A machine learning framework for the web, running on WebGPU & Wasm.
4//!
5//! ## Key Features
6//!
7//! - **NumPy-compatible API**: Familiar array creation and manipulation
8//! - **Automatic differentiation**: `grad`, `value_and_grad`, `vjp`, `jvp` for computing gradients
9//! - **Vectorization**: `vmap` for batching operations
10//! - **JIT compilation**: `jit` fuses operations into kernels for faster execution
11//! - **Multiple backends**: CPU (debugging), WebAssembly, WebGPU
12//! - **Rust memory safety**: No manual reference counting; resources are released automatically via `Drop`
13//!
14//! ## Quick Start
15//!
16//! ```rust,no_run
17//! use jax_rs::{Array, DType, Shape};
18//!
19//! // Create a 2x3 array of zeros
20//! let x = Array::zeros(Shape::new(vec![2, 3]), DType::Float32);
21//! ```
22
23#![warn(missing_docs)]
24#![warn(clippy::all)]
25
26mod array;
27pub mod backend;
28mod buffer;
29mod device;
30mod dtype;
31pub mod nn;
32pub mod ops;
33pub mod optim;
34pub mod random;
35pub mod scipy;
36mod shape;
37pub mod trace;
38
39// Public exports
40pub use array::Array;
41pub use device::{default_device, set_default_device, Device};
42pub use dtype::DType;
43pub use shape::Shape;
44pub use trace::{grad, jit, value_and_grad, vmap};