# mamba-rs
Mamba SSM (Selective State Space Model) implementation in Rust with optional CUDA GPU acceleration.
Supports both inference and training, including a full backward pass with BPTT through the recurrent SSM state, plus custom CUDA kernels for GPU-accelerated forward and backward passes.
Reference: Gu & Dao, Mamba: Linear-Time Sequence Modeling with Selective State Spaces (2023).
## Features

- Inference — zero-allocation single-step recurrent forward pass, microsecond-scale on CPU (see Performance)
- Training — full backward pass with BPTT through SSM hidden state
- Burn-in — warm up recurrent state from history before training window
- CUDA — custom kernels for SSM recurrence, conv1d, fused activations
- Standalone — no framework dependency (no PyTorch, no Burn, no Candle)
- f32 — native single precision, TF32 Tensor Cores on Ampere/Hopper
## Quick Start

```toml
[dependencies]
mamba-rs = "0.1"
```

Type and function paths below are illustrative; see the crate docs for exact names:

```rust
use mamba_rs::{Config, State, Scratch, mamba_step};

let cfg = Config::default();        // d_model=128, 3 layers
let weights = load_weights(&cfg);   // your weight loading
let mut state = State::zeros(&cfg);
let mut scratch = Scratch::new(&cfg);
let mut output = vec![0.0f32; cfg.d_model];
let input = vec![0.0f32; cfg.input_dim]; // one timestep of features

// single-step inference (recurrent, O(1) per step)
mamba_step(&cfg, &weights, &input, &mut state, &mut scratch, &mut output);

// reset state on sequence boundary
state.reset();
```
## GPU (CUDA)

```toml
[dependencies]
mamba-rs = { version = "0.1", features = ["cuda"] }
```

Requires an NVIDIA GPU and the CUDA toolkit. Kernels are compiled at runtime via NVRTC.
## Architecture

```
 input [B, T, input_dim]
          |
          v
 input_proj (linear + bias)
          |
          v
+--------- x N layers ---------+
|                              |
|  residual                    |
|     |                        |
|  RmsNorm                     |
|     |                        |
|  in_proj ----+---- gate      |
|     |        |               |
|  conv1d      |               |
|     |        |               |
|   SiLU     SiLU              |
|     |        |               |
|  x_proj      |               |
|   / | \      |               |
| dt  B  C     |               |
|  |           |               |
| dt_proj      |               |
|  |           |               |
| softplus     |               |
|  |           |               |
| SSM recurrence               |
|  h = A*h + B*x               |
|  y = C*h + D*x               |
|     |        |               |
|     +- gate *+               |
|         |                    |
|      out_proj                |
|         |                    |
|    + residual                |
|                              |
+------------------------------+
          |
    norm_f (RmsNorm)
          |
          v
 output [B, T, d_model]
```
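Per channel, the SSM recurrence in the diagram reduces to a short loop over the state dimension. A minimal sketch (illustrative, not the crate's API; `a` holds the negative-real diagonal of A, discretized here by zero-order hold):

```rust
/// One recurrent step for a single channel with state size N = h.len().
/// Updates h in place and returns the scalar output y.
/// dt, b, c are the per-step selective parameters; d is the skip weight.
fn ssm_step(h: &mut [f32], a: &[f32], b: &[f32], c: &[f32], d: f32, dt: f32, x: f32) -> f32 {
    let mut y = 0.0f32;
    for n in 0..h.len() {
        let a_bar = (dt * a[n]).exp();       // discretize A (zero-order hold)
        h[n] = a_bar * h[n] + dt * b[n] * x; // h = A_bar*h + B_bar*x
        y += c[n] * h[n];                    // y = C*h
    }
    y + d * x                                // skip connection: + D*x
}
```

Because only `h` persists across calls, the step is O(d_state) in time and needs no per-step allocation.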
## Performance
CPU inference (single-step recurrent, GH200 ARM Neoverse V2):
| Config | d_model | layers | params | latency |
|---|---|---|---|---|
| small | 64 | 2 | 70K | 61 μs |
| default | 128 | 3 | 366K | 375 μs |
| medium | 256 | 4 | 1.8M | 2.2 ms |
| large | 512 | 6 | 10.4M | 13.6 ms |
Zero heap allocations per step. All buffers pre-allocated.
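The zero-allocation claim means the hot path only borrows buffers that were sized up front; nothing is allocated per step. A minimal sketch of the pattern (struct fields here are hypothetical, not the crate's actual layout):

```rust
/// Hypothetical scratch space: every intermediate buffer sized once.
struct Scratch {
    conv_buf: Vec<f32>,
    ssm_buf: Vec<f32>,
}

impl Scratch {
    fn new(d_inner: usize, d_state: usize) -> Self {
        Scratch {
            conv_buf: vec![0.0; d_inner],
            ssm_buf: vec![0.0; d_inner * d_state],
        }
    }
}

/// The step writes into borrowed buffers; no Vec is created inside.
fn step_into(scratch: &mut Scratch, x: &[f32], out: &mut [f32]) {
    for (b, &xi) in scratch.conv_buf.iter_mut().zip(x) {
        *b = xi * 2.0; // stand-in for the real per-step math
    }
    out.copy_from_slice(&scratch.conv_buf[..out.len()]);
}
```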
GPU GEMM latency (GH200 H100, TF32 Tensor Cores, d_model=128, cuBLAS SGEMM):

| Batch | GEMM shape | Latency |
|---|---|---|
| B=1 | 1x128 x 128x128 | 25 μs |
| B=128 | 128x128 x 128x128 | 10 μs |
| B=256 | 256x128 x 128x128 | 10 μs |
## Highlights
- Full backward pass with hand-derived analytical gradients (no autograd dependency)
- BPTT through SSM recurrent state across time steps
- Burn-in support for warming hidden state from historical context
- Zero-allocation inference path with pre-allocated scratch buffers
- Custom CUDA kernels for SSM recurrence, conv1d, and fused activations
- Flat contiguous weight buffers for efficient optimizer fusion
- Compatible with CUDA Graph capture for low-latency GPU execution
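The hand-derived BPTT follows directly from the recurrence `h[t] = a_bar*h[t-1] + bx[t]`, `y[t] = c*h[t]`: the state gradient flows backward through time as `dh[t-1] += a_bar * dh[t]`. A single-channel sketch of this pattern (not the crate's actual backward code):

```rust
/// Backward through T steps of h[t] = a_bar*h[t-1] + bx[t], y[t] = c*h[t].
/// Given dy (dL/dy per step) and the saved forward states hs,
/// returns (dL/dc, dL/dbx[t]) with dL/dh propagated backward in time.
fn ssm_backward(a_bar: f32, c: f32, hs: &[f32], dy: &[f32]) -> (f32, Vec<f32>) {
    let t_len = hs.len();
    let mut dc = 0.0f32;
    let mut dbx = vec![0.0f32; t_len];
    let mut dh = 0.0f32; // dL/dh flowing backward through time
    for t in (0..t_len).rev() {
        dh += c * dy[t];     // y[t] = c*h[t] adds c*dy[t] to dh[t]
        dc += hs[t] * dy[t]; // grad wrt c
        dbx[t] = dh;         // h[t] = ... + bx[t]  =>  d bx[t] = dh[t]
        dh *= a_bar;         // chain rule into h[t-1]
    }
    (dc, dbx)
}
```

Note the forward states `hs` must be saved (or recomputed) for the backward pass, which is why training stores activations while inference does not.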
## Differences from Python

| | Python (state-spaces/mamba) | mamba-rs |
|---|---|---|
| Backward | PyTorch autograd | Manual BPTT (hand-derived) |
| Kernels | Triton + CUDA C++ | CUDA (NVRTC, runtime compile) |
| Framework | PyTorch | None (standalone) |
| Training | Via autograd | Explicit forward + backward |
| Precision | fp16/bf16/fp32 | f32 (TF32 on GPU) |
| Burn-in | Not exposed | First-class API |
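Burn-in in practice: advance the recurrent state over historical context before the training or evaluation window, so the first supervised step starts from a warmed state rather than zeros. An illustrative single-channel sketch of the state-threading pattern (a real model loops over channels and layers, but the idea is the same):

```rust
/// Warm the hidden state over `history` (outputs discarded),
/// then run the same recurrence over `window` (outputs kept).
fn run_with_burn_in(a: f32, b: f32, c: f32, dt: f32,
                    history: &[f32], window: &[f32]) -> Vec<f32> {
    let a_bar = (dt * a).exp();
    let mut h = 0.0f32;
    // burn-in: update state, discard outputs, no gradients needed
    for &x in history {
        h = a_bar * h + dt * b * x;
    }
    // supervised window: same recurrence, outputs collected
    window
        .iter()
        .map(|&x| {
            h = a_bar * h + dt * b * x;
            c * h
        })
        .collect()
}
```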
## Citation
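If you use this implementation, cite the Mamba paper referenced above:

```bibtex
@article{gu2023mamba,
  title   = {Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
  author  = {Gu, Albert and Dao, Tri},
  journal = {arXiv preprint arXiv:2312.00752},
  year    = {2023}
}
```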
## License
Dual-licensed under MIT or Apache-2.0.