# mamba-rs
Mamba SSM (Selective State Space Model) implementation in Rust with optional CUDA GPU acceleration.
Supports both inference and training, including a full backward pass with BPTT through the recurrent SSM state, plus custom CUDA kernels for GPU-accelerated forward and backward passes.
Reference: Gu & Dao, Mamba: Linear-Time Sequence Modeling with Selective State Spaces (2023).
## Features
- Inference — zero-allocation single-step recurrent forward pass
- GPU Inference — CUDA kernels with optional CUDA Graph capture
- Training — full backward pass with BPTT through SSM hidden state
- Burn-in — warm up recurrent state from history before training window
- CUDA — custom kernels for SSM recurrence, conv1d, fused activations
- Modular — 3-level API: MambaLayer (pure mixer) / MambaBlock (norm+residual) / MambaBackbone (full)
- Serialization — save/load weights via safetensors (HuggingFace standard)
- Standalone — no framework dependency (no PyTorch, no Burn, no Candle)
- f32 — native single precision, TF32 Tensor Cores on Ampere/Hopper
## Quick Start
```toml
[dependencies]
mamba-rs = "0.1"
```

```rust
// Type and function names below are illustrative; check the crate docs for exact signatures.
use mamba_rs::*;

let cfg = MambaConfig::default();           // d_model=128, 3 layers
let weights = load_weights(&cfg);           // your weight loading
let mut state = MambaState::zeros(&cfg);
let mut scratch = Scratch::new(&cfg);
let input = vec![0.0f32; cfg.input_dim];    // one timestep of features
let mut output = vec![0.0f32; cfg.d_model];

// single-step inference (recurrent, O(1) per step)
mamba_step(&weights, &input, &mut state, &mut scratch, &mut output);

// reset state on a sequence boundary
state.reset();
```
### GPU (CUDA)

```toml
[dependencies]
mamba-rs = { version = "0.1", features = ["cuda"] }
```
Requires NVIDIA GPU + CUDA toolkit. Kernels compiled at runtime via NVRTC.
## Architecture
```
 input [B, T, input_dim]
           |
      input_proj (linear + bias)
           |
           v
+----------- x N layers -----------+
|                                  |
|  residual                        |
|     |                            |
|  RmsNorm                         |
|     |                            |
|  in_proj --------+----- gate     |
|     |            |               |
|  conv1d          |               |
|     |            |               |
|   SiLU          SiLU             |
|     |            |               |
|  x_proj          |               |
|   / | \          |               |
|  dt B C          |               |
|   |              |               |
|  dt_proj         |               |
|   |              |               |
|  softplus        |               |
|   |              |               |
|  SSM recurrence  |               |
|  h = A*h + B*x   |               |
|  y = C*h + D*x   |               |
|     |            |               |
|     +-- gate * --+               |
|            |                     |
|         out_proj                 |
|            |                     |
|        + residual                |
|                                  |
+----------------------------------+
           |
      norm_f (RmsNorm)
           |
 output [B, T, d_model]
```
## Performance
Default config: d_model=128, 3 layers, d_inner=256, d_state=16, 366K params.
### GPU Inference (T=1 step)
| Batch | GH200 (H100) | GH200 + CUDA Graph | Ada (RTX 6000) | Ada + CUDA Graph |
|---|---|---|---|---|
| B=1 | 155 μs | 115 μs | 124 μs | 79 μs |
| B=4 | 193 μs | 140 μs | 145 μs | 99 μs |
| B=16 | 200 μs | 147 μs | 148 μs | 102 μs |
| B=64 | 201 μs | 152 μs | 156 μs | 115 μs |
| B=128 | 212 μs | 164 μs | 176 μs | 138 μs |
CUDA Graph eliminates kernel launch overhead (~40 μs saved).
### GPU Training (B=1, T=32)
| | GH200 (H100) | Ada (RTX 6000) |
|---|---|---|
| Forward | 966 μs | 706 μs |
| Forward + Backward | 2,353 μs | 1,739 μs |
### CPU Inference (T=1 step, B=1)
| Config | d_model | layers | params | GH200 (Grace ARM) | Ada (Xeon) |
|---|---|---|---|---|---|
| small | 64 | 2 | 70K | 61 μs | — |
| default | 128 | 3 | 366K | 377 μs | 348 μs |
| medium | 256 | 4 | 1.8M | 2.2 ms | — |
| large | 512 | 6 | 10.4M | 13.6 ms | — |
### CPU Training (B=1, T=32)
| | GH200 (Grace ARM) | Ada (Xeon) |
|---|---|---|
| Forward | 3,066 μs | 3,260 μs |
| Forward + Backward | 13,064 μs | 14,973 μs |
### Speedups
| | GPU vs CPU speedup |
|---|---|
| Inference B=1 (CUDA Graph) | 3.3x (GH200), 4.4x (Ada) |
| Training Fwd+Bwd T=32 | 5.5x (GH200), 8.6x (Ada) |
Zero heap allocations per inference step. All buffers pre-allocated.
## Precision
GPU uses TF32 Tensor Cores (10-bit mantissa, ~1e-3 per-op precision). Validated against CPU f32:
| Check | Tolerance | Actual max diff |
|---|---|---|
| GPU vs CPU inference (20 steps) | 1e-2 | 0.003 |
| GPU vs CPU training forward (T=8) | 1e-2 | 0.003 |
| GPU vs CPU training backward (33 weight groups) | 0.15 | 0.124 |
| CUDA Graph vs non-graph | 1e-5 | < 1e-6 |
| CPU finite-diff gradient check | 5e-2 | < 1e-2 |
All 26 correctness tests pass on both GH200 (H100, Driver 595, CUDA 13.2) and Ada (RTX 6000, Driver 595, CUDA 13.2).
## Modular API
Three levels matching the original architecture (Gu & Dao, 2023):
```rust
// Level 1: Pure mixer — no norm, no residual (like the Mamba class in mamba_simple.py)
mamba_layer_step(/* args elided */);

// Level 2: Block — pre-norm + mixer + residual (like the Block class in block.py)
mamba_block_step(/* args elided */);

// Level 3: Full backbone — input_proj + N blocks + norm_f
mamba_step(/* args elided */);
```
Use Level 1 to integrate Mamba into custom architectures with your own normalization and residual patterns.
## Weight Serialization
Save and load weights using the safetensors format (HuggingFace standard):
```rust
// Names are illustrative; check the crate docs for exact signatures.
use mamba_rs::serialize::{save, load};

// Save
save("mamba.safetensors", &weights)?;

// Load
let weights = load("mamba.safetensors")?;
let backbone = MambaBackbone::from_weights(&cfg, &weights)?;
```
Compatible with Python safetensors library for cross-framework weight exchange.
## GPU Inference
```rust
// Names are illustrative; check the crate docs for exact signatures.
use mamba_rs::GpuMambaBackbone;

let mut gpu = GpuMambaBackbone::new(&cfg, &weights)?;
gpu.capture_graph()?; // optional: CUDA Graph capture removes kernel launch overhead (see Performance)
gpu.step(&input, &mut output)?;
gpu.reset()?;         // episode boundary
```
## Highlights
- Analytical gradients derived by hand — no autograd framework needed
- BPTT through SSM recurrent state across timesteps
- Burn-in for warming hidden state from historical context
- Zero-allocation inference with pre-allocated scratch buffers
- Custom CUDA kernels compiled at runtime via NVRTC
- Flat contiguous weight buffers for optimizer fusion
- CUDA Graph capture for minimal kernel launch overhead
- safetensors serialization for Python/HuggingFace interop
## Citation
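For the referenced paper (Gu & Dao, 2023, arXiv:2312.00752), a standard BibTeX entry:

```bibtex
@article{gu2023mamba,
  title   = {Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
  author  = {Gu, Albert and Dao, Tri},
  journal = {arXiv preprint arXiv:2312.00752},
  year    = {2023}
}
```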
## License
Dual-licensed under MIT or Apache-2.0.