mamba-rs 0.1.2

Mamba SSM (Selective State Space Model) implementation in Rust with optional CUDA GPU acceleration.

Supports both inference and training, including full backward pass with BPTT through recurrent SSM state. Custom CUDA kernels for GPU-accelerated forward and backward passes.

Reference: Gu & Dao, Mamba: Linear-Time Sequence Modeling with Selective State Spaces (2023).

Features

  • Inference — zero-allocation single-step recurrent forward pass
  • GPU Inference — CUDA kernels with optional CUDA Graph capture
  • Training — full backward pass with BPTT through SSM hidden state
  • Burn-in — warm up recurrent state from history before training window
  • CUDA — custom kernels for SSM recurrence, conv1d, fused activations
  • Modular — 3-level API: MambaLayer (pure mixer) / MambaBlock (norm+residual) / MambaBackbone (full)
  • Serialization — save/load weights via safetensors (HuggingFace standard)
  • Standalone — no framework dependency (no PyTorch, no Burn, no Candle)
  • f32 — native single precision, TF32 Tensor Cores on Ampere/Hopper

Quick Start

[dependencies]
mamba-rs = "0.1"

use mamba_rs::{MambaConfig, MambaState, MambaStepScratch, MambaWeights, mamba_step};

let cfg = MambaConfig::default(); // d_model=128, 3 layers
let weights: MambaWeights = load_weights(); // your weight loading
let input_dim = 16; // width of one raw input vector (example value)
let input = vec![0.0f32; input_dim]; // one timestep of input
let mut state = MambaState::zeros(cfg.n_layers, cfg.d_inner(), cfg.d_state, cfg.d_conv);
let mut scratch = MambaStepScratch::new(&cfg);
let mut output = vec![0.0f32; cfg.d_model];

// single-step inference (recurrent, O(1) per step)
mamba_step(&input, &mut output, &weights, &mut state.layers, &mut scratch, &cfg, input_dim);

// reset state on sequence boundary
state.reset();

GPU (CUDA)

[dependencies]
mamba-rs = { version = "0.1", features = ["cuda"] }

Requires NVIDIA GPU + CUDA toolkit. Kernels compiled at runtime via NVRTC.

Architecture

    input [B, T, input_dim]
        |
    input_proj (linear + bias)
        |
        v
    +--------- x N layers ---------+
    |                               |
    |   residual                    |
    |      |                        |
    |   RmsNorm                     |
    |      |                        |
    |   in_proj ----+---- gate      |
    |      |             |          |
    |   conv1d           |          |
    |      |             |          |
    |   SiLU          SiLU          |
    |      |             |          |
    |   x_proj           |          |
    |    / | \           |          |
    |  dt  B  C          |          |
    |   |                |          |
    |  dt_proj           |          |
    |   |                |          |
    |  softplus          |          |
    |   |                |          |
    |  SSM recurrence    |          |
    |  h = A*h + B*x     |          |
    |  y = C*h + D*x     |          |
    |      |             |          |
    |      +--- gate * --+          |
    |            |                  |
    |        out_proj               |
    |            |                  |
    |      + residual               |
    |                               |
    +-------------------------------+

    norm_f (RmsNorm)
        |
    output [B, T, d_model]
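The selective-scan core in the middle of the diagram reduces, per channel, to a few lines. The toy below (single channel, d_state = 2, illustrative values — not the crate's API) discretizes A with the input-dependent dt produced by softplus, updates h, and reads out y:

```rust
// Toy single-channel SSM step, mirroring the diagram:
// discretize A per-step with dt, update h, then y = C·h + D*x.
fn ssm_step(
    h: &mut [f32; 2],
    x: f32,
    dt: f32,
    a: [f32; 2],
    b: [f32; 2],
    c: [f32; 2],
    d: f32,
) -> f32 {
    let mut y = 0.0;
    for i in 0..2 {
        let a_bar = (dt * a[i]).exp(); // zero-order-hold discretization of A
        let b_bar = dt * b[i];         // Euler discretization of B, as in Mamba
        h[i] = a_bar * h[i] + b_bar * x; // h = A_bar * h + B_bar * x
        y += c[i] * h[i];                // y = C · h
    }
    y + d * x // skip connection: + D * x
}

fn main() {
    let mut h = [0.0f32; 2];
    // softplus keeps dt positive, as in the diagram: softplus(1.0) = ln(1 + e)
    let dt = (1.0f32).exp().ln_1p();
    let y = ssm_step(&mut h, 1.0, dt, [-1.0, -2.0], [1.0, 1.0], [1.0, 1.0], 0.5);
    println!("y = {y:.4}"); // ≈ 3.1265
}
```

Because h carries all history, each new timestep costs O(d_state) per channel rather than re-scanning the sequence, which is what makes the O(1)-per-step recurrent inference mode possible.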

Performance

Default config: d_model=128, 3 layers, d_inner=256, d_state=16, 366K params.

GPU Inference (T=1 step)

Batch   GH200 (H100)   GH200 + CUDA Graph   Ada (RTX 6000)   Ada + CUDA Graph
B=1     155 μs         115 μs               124 μs           79 μs
B=4     193 μs         140 μs               145 μs           99 μs
B=16    200 μs         147 μs               148 μs           102 μs
B=64    201 μs         152 μs               156 μs           115 μs
B=128   212 μs         164 μs               176 μs           138 μs

CUDA Graph eliminates kernel launch overhead (~40 μs saved).

GPU Training (B=1, T=32)

                     GH200 (H100)   Ada (RTX 6000)
Forward              966 μs         706 μs
Forward + Backward   2,353 μs       1,739 μs

CPU Inference (T=1 step, B=1)

Config    d_model   layers   params   GH200 (Grace ARM)   Ada (Xeon)
small     64        2        70K      61 μs
default   128       3        366K     377 μs              348 μs
medium    256       4        1.8M     2.2 ms
large     512       6        10.4M    13.6 ms

CPU Training (B=1, T=32)

                     GH200 (Grace ARM)   Ada (Xeon)
Forward              3,066 μs            3,260 μs
Forward + Backward   13,064 μs           14,973 μs

Speedups

Workload                      GPU vs CPU speedup
Inference B=1 (CUDA Graph)    3.3x (GH200), 4.4x (Ada)
Training Fwd+Bwd T=32         5.5x (GH200), 8.6x (Ada)

Zero heap allocations per inference step. All buffers pre-allocated.
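The pre-allocated-buffer pattern behind the zero-allocation claim can be sketched with a self-contained toy (the types and the placeholder math are illustrative, not the crate's real ones): allocate scratch once up front, then every step borrows the same buffer mutably and writes into a caller-provided output slice.

```rust
// Toy sketch of the pre-allocated-scratch pattern.
struct Scratch {
    proj: Vec<f32>, // expanded projection buffer, length 2 * d
}

impl Scratch {
    fn new(d: usize) -> Self {
        Scratch { proj: vec![0.0; 2 * d] }
    }
}

/// One step: no Vec::new, no Box, no reallocation inside the hot path.
fn step(input: &[f32], output: &mut [f32], scratch: &mut Scratch) {
    let d = input.len();
    // stage 1: "expand" into the pre-allocated buffer (placeholder math)
    for j in 0..2 * d {
        scratch.proj[j] = 0.5 * input[j % d];
    }
    // stage 2: "contract" back into the caller's output slice
    for j in 0..d {
        output[j] = scratch.proj[j] + scratch.proj[j + d];
    }
}

fn main() {
    let mut scratch = Scratch::new(4); // allocate once, up front
    let mut out = vec![0.0f32; 4];
    for t in 0..3 {
        let input = [t as f32; 4];            // stack-allocated input
        step(&input, &mut out, &mut scratch); // zero heap allocations here
    }
    println!("{out:?}");
}
```

Keeping the allocation out of the hot loop avoids allocator latency and keeps per-step timing deterministic, which matters at the microsecond scale of the tables above.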

Precision

GPU uses TF32 Tensor Cores (10-bit mantissa, ~1e-3 per-op precision). Validated against CPU f32:

Check                                            Tolerance   Actual max diff
GPU vs CPU inference (20 steps)                  1e-2        0.003
GPU vs CPU training forward (T=8)                1e-2        0.003
GPU vs CPU training backward (33 weight groups)  0.15        0.124
CUDA Graph vs non-graph                          1e-5        < 1e-6
CPU finite-diff gradient check                   5e-2        < 1e-2

All 26 correctness tests pass on both GH200 (H100, Driver 595, CUDA 13.2) and Ada (RTX 6000, Driver 595, CUDA 13.2).
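The finite-difference gradient check can be illustrated with a self-contained toy (the crate's real check runs over full weight tensors; here a single SiLU weight stands in): compare a hand-derived analytical gradient against a central difference.

```rust
// SiLU activation: z * sigmoid(z)
fn silu(z: f32) -> f32 {
    z / (1.0 + (-z).exp())
}

// Hand-derived SiLU gradient: sigma(z) * (1 + z * (1 - sigma(z)))
fn silu_grad(z: f32) -> f32 {
    let s = 1.0 / (1.0 + (-z).exp());
    s * (1.0 + z * (1.0 - s))
}

fn main() {
    // f(w) = silu(w * x); check df/dw = x * silu'(w * x)
    let (w, x, eps) = (0.7f32, 1.3f32, 1e-3f32);
    let analytic = x * silu_grad(w * x);
    // central finite difference, the same style of check as the table above
    let numeric = (silu((w + eps) * x) - silu((w - eps) * x)) / (2.0 * eps);
    assert!((analytic - numeric).abs() < 5e-2);
    println!("analytic={analytic:.5} numeric={numeric:.5}");
}
```

A central difference has O(eps^2) truncation error, so in f32 a step around 1e-3 balances truncation against round-off, which is why tolerances like 5e-2 are used rather than machine epsilon.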

Modular API

Three levels matching the original architecture (Gu & Dao, 2023):

// Level 1: Pure mixer — no norm, no residual (like Mamba class in mamba_simple.py)
mamba_layer_step(input, output, layer_weights, state, scratch, cfg);

// Level 2: Block — pre-norm + mixer + residual (like Block class in block.py)
mamba_block_step(hidden, layer_weights, state, scratch, cfg);

// Level 3: Full backbone — input_proj + N blocks + norm_f
mamba_step(input, output, weights, states, scratch, cfg, input_dim);

Use Level 1 to integrate Mamba into custom architectures with your own normalization and residual patterns.
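As an illustration of that integration point, here is a self-contained sketch of a custom pre-norm residual block around a stand-in mixer closure (in real use the mixer would call mamba_layer_step; rms_norm here is a bare RMSNorm without a learned gain, purely for illustration):

```rust
// Bare RMSNorm (no learned gain), as an example of "your own normalization".
fn rms_norm(x: &[f32], eps: f32) -> Vec<f32> {
    let ms = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv = 1.0 / (ms + eps).sqrt();
    x.iter().map(|v| v * inv).collect()
}

// Hypothetical custom block: your own norm + residual around a Level-1 mixer.
fn custom_block(hidden: &mut [f32], mixer: impl Fn(&[f32]) -> Vec<f32>) {
    let normed = rms_norm(hidden, 1e-5); // your normalization choice
    let mixed = mixer(&normed);          // mamba_layer_step would go here
    for (h, m) in hidden.iter_mut().zip(mixed) {
        *h += m;                         // your residual pattern
    }
}

fn main() {
    let mut hidden = vec![3.0f32, 4.0];
    custom_block(&mut hidden, |x| x.to_vec()); // identity mixer as a stand-in
    println!("{hidden:?}");
}
```

Swapping the closure for the Level-1 mixer keeps norm placement and residual scaling entirely under your control, which Level 2 and Level 3 fix for you.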

Weight Serialization

Save and load weights using the safetensors format (HuggingFace standard):

use mamba_rs::serialize;

// Save
serialize::save(Path::new("model.safetensors"), backbone.weights(), cfg, input_dim)?;

// Load
let (weights, cfg, input_dim) = serialize::load(Path::new("model.safetensors"))?;
let backbone = MambaBackbone::from_weights(cfg, weights)?;

Compatible with Python safetensors library for cross-framework weight exchange.

GPU Inference

use mamba_rs::gpu::inference::GpuMambaBackbone;

let mut gpu = GpuMambaBackbone::new(0, cpu_weights, cfg, input_dim, batch)?;
gpu.capture_graph()?; // optional: CUDA Graph capture cuts kernel launch overhead (~40 μs/step)

gpu.step(&input, &mut output)?;
gpu.reset()?; // episode boundary

Highlights

  • Analytical gradients derived by hand — no autograd framework needed
  • BPTT through SSM recurrent state across timesteps
  • Burn-in for warming hidden state from historical context
  • Zero-allocation inference with pre-allocated scratch buffers
  • Custom CUDA kernels compiled at runtime via NVRTC
  • Flat contiguous weight buffers for optimizer fusion
  • CUDA Graph capture for minimal kernel launch overhead
  • safetensors serialization for Python/HuggingFace interop
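The burn-in idea can be shown with a self-contained toy (an exponential moving average stands in for mamba_step): roll the recurrent state forward over historical inputs with no loss or gradient, so the training window starts from a realistic hidden state instead of zeros.

```rust
// Stand-in recurrence for illustration only (not the crate's API).
fn recur(state: &mut f32, x: f32) {
    *state = 0.9 * *state + 0.1 * x; // exponential moving average
}

fn main() {
    let history = vec![1.0f32; 100]; // context before the training window
    let window = vec![2.0f32; 32];   // the actual training window (T = 32)

    let mut state = 0.0f32;
    for &x in &history {
        recur(&mut state, x); // burn-in: forward pass only, no BPTT
    }
    println!("state after burn-in: {state:.4}"); // has converged toward 1.0

    for &x in &window {
        recur(&mut state, x); // training region: forward + backward run here
    }
}
```

Without burn-in, the first timesteps of every training window see a cold zero state and produce gradients that reflect the reset rather than the data; warming the state over history removes that bias.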

Citation

@inproceedings{mamba,
  title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
  author={Gu, Albert and Dao, Tri},
  booktitle={International Conference on Learning Representations},
  year={2024}
}

License

Dual-licensed under MIT or Apache-2.0.