mamba-rs 0.1.3

Mamba SSM (Selective State Space Model) implementation in Rust with optional CUDA GPU acceleration.

Supports both inference and training, including a full backward pass with BPTT through the recurrent SSM state. Custom CUDA kernels accelerate the forward and backward passes on GPU.

Reference: Gu & Dao, Mamba: Linear-Time Sequence Modeling with Selective State Spaces (2023).

Features

  • Inference — zero-allocation single-step recurrent forward pass
  • GPU Inference — CUDA kernels with optional CUDA Graph capture
  • Training — full backward pass with BPTT through SSM hidden state
  • Burn-in — warm up recurrent state from history before training window
  • CUDA — custom kernels for SSM recurrence, conv1d, fused activations
  • Modular — 3-level API: MambaLayer (pure mixer) / MambaBlock (norm+residual) / MambaBackbone (full)
  • Serialization — save/load weights via safetensors (HuggingFace standard)
  • Standalone — no framework dependency (no PyTorch, no Burn, no Candle)
  • f32 — native single precision, TF32 Tensor Cores on Ampere/Hopper
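For reference, the "fused activations" above are SiLU and softplus. A standalone sketch of the two functions (illustrative, not the crate's internal kernels):

```rust
// SiLU(x) = x * sigmoid(x) = x / (1 + e^-x), used on the conv and gate paths.
fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

// softplus(x) = ln(1 + e^x); for large x, ln(1 + e^x) ≈ x (numerically stable).
// In Mamba, softplus keeps the step size dt strictly positive.
fn softplus(x: f32) -> f32 {
    if x > 20.0 { x } else { (1.0 + x.exp()).ln() }
}

fn main() {
    println!("silu(1.0)     = {}", silu(1.0));
    println!("softplus(0.0) = {}", softplus(0.0)); // ln 2 ≈ 0.693
}
```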

Quick Start

[dependencies]
mamba-rs = "0.1"

use mamba_rs::{MambaConfig, MambaState, MambaStepScratch, MambaWeights, mamba_step};

let cfg = MambaConfig::default(); // d_model=128, 3 layers
let input_dim = 8; // width of your raw input features (example value)
let weights = load_weights(); // your weight loading
let mut state = MambaState::zeros(cfg.n_layers, cfg.d_inner(), cfg.d_state, cfg.d_conv);
let mut scratch = MambaStepScratch::new(&cfg);
let input = vec![0.0f32; input_dim];
let mut output = vec![0.0f32; cfg.d_model];

// single-step inference (recurrent, O(1) per step)
mamba_step(&input, &mut output, &weights, &mut state.layers, &mut scratch, &cfg, input_dim);

// reset state on sequence boundary
state.reset();

GPU (CUDA)

[dependencies]
mamba-rs = { version = "0.1", features = ["cuda"] }

Requires NVIDIA GPU + CUDA toolkit. Kernels compiled at runtime via NVRTC.

Architecture

    input [B, T, input_dim]
        |
    input_proj (linear + bias)
        |
        v
    +--------- x N layers ---------+
    |                               |
    |   residual                    |
    |      |                        |
    |   RmsNorm                     |
    |      |                        |
    |   in_proj ----+---- gate      |
    |      |             |          |
    |   conv1d           |          |
    |      |             |          |
    |   SiLU          SiLU          |
    |      |             |          |
    |   x_proj           |          |
    |    / | \           |          |
    |  dt  B  C          |          |
    |   |                |          |
    |  dt_proj           |          |
    |   |                |          |
    |  softplus          |          |
    |   |                |          |
    |  SSM recurrence    |          |
    |  h = A*h + B*x     |          |
    |  y = C*h + D*x     |          |
    |      |             |          |
    |      +--- gate * --+          |
    |            |                  |
    |        out_proj               |
    |            |                  |
    |      + residual               |
    |                               |
    +-------------------------------+

    norm_f (RmsNorm)
        |
    output [B, T, d_model]

Performance

Default config: d_model=128, 3 layers, d_inner=256, d_state=16, 366K params.

GPU Inference (T=1 step)

| Batch | GH200 (H100) | GH200 + CUDA Graph | Intel (RTX 6000 Ada) | Intel + CUDA Graph |
|-------|--------------|--------------------|----------------------|--------------------|
| B=1   | 155 μs | 115 μs | 124 μs | 79 μs |
| B=4   | 193 μs | 140 μs | 145 μs | 99 μs |
| B=16  | 200 μs | 147 μs | 148 μs | 102 μs |
| B=64  | 201 μs | 152 μs | 156 μs | 115 μs |
| B=128 | 212 μs | 164 μs | 176 μs | 138 μs |

CUDA Graph eliminates kernel launch overhead (~40 μs saved).

GPU Training (B=1, T=32)

| | GH200 (H100) | Intel (RTX 6000) |
|---|---|---|
| Forward | 788 μs | 706 μs |
| Forward + Backward | 2,176 μs | 1,739 μs |

CPU Inference (T=1 step, B=1)

| Config | d_model | layers | params | GH200 (Grace, 72 cores) | Intel (Xeon Gold 5412U) |
|---|---|---|---|---|---|
| small | 64 | 2 | 70K | 21 μs | 26 μs |
| default | 128 | 3 | 366K | 76 μs | 86 μs |
| medium | 256 | 4 | 1.8M | 261 μs | 349 μs |
| large | 512 | 6 | 10.4M | 1,226 μs | 2,651 μs |

CPU Training (B=1, T=32)

| Config | d_model | layers | params | GH200 fwd / bwd / total | Intel fwd / bwd / total |
|---|---|---|---|---|---|
| small | 64 | 2 | 70K | 1,096 / 1,901 / 2,998 μs | 1,223 / 1,915 / 3,153 μs |
| default | 128 | 3 | 366K | 3,746 / 10,156 / 13,915 μs | 4,252 / 12,024 / 16,276 μs |
| medium | 256 | 4 | 1.8M | 12,333 / 68,020 / 80,429 μs | 14,249 / 54,683 / 68,931 μs |
| large | 512 | 6 | 10.4M | 49,325 / 428,821 / 478,146 μs | 65,911 / 877,676 / 943,588 μs |

Parallel Training (default config, T=32, all cores)

GH200 Grace (72 cores):

| Batch | Forward | Backward | Total | Samples/sec | Speedup |
|---|---|---|---|---|---|
| B=16 | 3,990 μs | 13,588 μs | 17,583 μs | 910 | 12.6x |
| B=64 | 5,139 μs | 20,898 μs | 26,037 μs | 2,458 | 34.3x |
| B=128 | 8,976 μs | 31,151 μs | 40,217 μs | 3,183 | 44.3x |

Intel (Xeon Gold 5412U, 24 cores / 48 threads):

| Batch | Forward | Backward | Total | Samples/sec | Speedup |
|---|---|---|---|---|---|
| B=16 | 10,536 μs | 23,466 μs | 34,002 μs | 471 | 7.6x |
| B=64 | 21,970 μs | 65,521 μs | 87,491 μs | 728 | 11.9x |
| B=128 | 38,216 μs | 103,329 μs | 141,545 μs | 903 | 14.7x |

Training parallelizes across the batch using thread-local gradient accumulation with an epoch-based reduce: ~61% parallel efficiency on GH200 (44.3x on 72 cores) and ~61% on Intel (14.7x on 24 cores).
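The accumulation scheme can be sketched with toy per-sample "gradients" (illustrative only, not the crate's training loop): each worker sums into its own buffer, and the buffers are reduced once after joining, so the hot loop needs no locks or atomics.

```rust
use std::thread;

// Each thread accumulates gradients for its batch shard into a private
// buffer; the main thread performs a single reduce over all buffers.
fn parallel_grads(batches: Vec<Vec<f32>>, n_params: usize) -> Vec<f32> {
    let handles: Vec<_> = batches
        .into_iter()
        .map(|batch| {
            thread::spawn(move || {
                let mut local = vec![0.0f32; n_params];
                for x in batch {
                    // stand-in for the per-sample backward pass
                    for g in local.iter_mut() {
                        *g += x;
                    }
                }
                local
            })
        })
        .collect();

    // epoch-based reduce: sum the thread-local buffers into one gradient
    let mut total = vec![0.0f32; n_params];
    for h in handles {
        let local = h.join().unwrap();
        for (t, l) in total.iter_mut().zip(&local) {
            *t += *l;
        }
    }
    total
}

fn main() {
    let grads = parallel_grads(vec![vec![1.0, 2.0], vec![3.0]], 4);
    println!("{grads:?}"); // each param saw 1 + 2 + 3 = 6
}
```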

Speedups

| | GPU vs CPU |
|---|---|
| Inference, B=1 (CUDA Graph) | 3.3x (GH200), 4.4x (Intel) |
| Training, Fwd+Bwd, T=32 | 5.5x (GH200), 8.6x (Intel) |

Zero heap allocations per inference step. All buffers pre-allocated.

Precision

GPU uses TF32 Tensor Cores (10-bit mantissa, ~1e-3 per-op precision). Validated against CPU f32:

| Check | Tolerance | Actual max diff |
|---|---|---|
| GPU vs CPU inference (20 steps) | 1e-2 | 0.003 |
| GPU vs CPU training forward (T=8) | 1e-2 | 0.003 |
| GPU vs CPU training backward (33 weight groups) | 0.15 | 0.124 |
| CUDA Graph vs non-graph | 1e-5 | < 1e-6 |
| CPU finite-diff gradient check | 5e-2 | < 1e-2 |

All 26 correctness tests pass on both GH200 (H100, Driver 595, CUDA 13.2) and Intel (RTX 6000, Driver 595, CUDA 13.2).
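The finite-difference check above compares each analytical gradient against the central difference (f(w+ε) − f(w−ε)) / 2ε. A minimal sketch of the idea on a toy function f(w) = Σ wᵢ² (whose analytical gradient is 2w), not the crate's actual test:

```rust
// Toy objective: f(w) = sum of squares, with known gradient 2*w.
fn f(w: &[f32]) -> f32 {
    w.iter().map(|x| x * x).sum()
}

// Returns the max absolute difference between the central-difference
// gradient and the supplied analytical gradient.
fn grad_check(w: &[f32], analytic: &[f32], eps: f32) -> f32 {
    let mut max_diff = 0.0f32;
    let mut wp = w.to_vec();
    for i in 0..w.len() {
        wp[i] = w[i] + eps;
        let fp = f(&wp);
        wp[i] = w[i] - eps;
        let fm = f(&wp);
        wp[i] = w[i]; // restore before the next coordinate
        let fd = (fp - fm) / (2.0 * eps); // central difference
        max_diff = max_diff.max((fd - analytic[i]).abs());
    }
    max_diff
}

fn main() {
    let w = [0.3f32, -1.2, 0.7];
    let analytic: Vec<f32> = w.iter().map(|x| 2.0 * x).collect();
    let diff = grad_check(&w, &analytic, 1e-3);
    assert!(diff < 1e-2); // analogous to the 5e-2 tolerance in the table
    println!("max diff = {diff}");
}
```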

Modular API

Three levels matching the original architecture (Gu & Dao, 2023):

// Level 1: Pure mixer — no norm, no residual (like Mamba class in mamba_simple.py)
mamba_layer_step(input, output, layer_weights, state, scratch, cfg);

// Level 2: Block — pre-norm + mixer + residual (like Block class in block.py)
mamba_block_step(hidden, layer_weights, state, scratch, cfg);

// Level 3: Full backbone — input_proj + N blocks + norm_f
mamba_step(input, output, weights, states, scratch, cfg, input_dim);

Use Level 1 to integrate Mamba into custom architectures with your own normalization and residual patterns.

Weight Serialization

Save and load weights using the safetensors format (HuggingFace standard):

use std::path::Path;
use mamba_rs::serialize;

// Save
serialize::save(Path::new("model.safetensors"), backbone.weights(), cfg, input_dim)?;

// Load
let (weights, cfg, input_dim) = serialize::load(Path::new("model.safetensors"))?;
let backbone = MambaBackbone::from_weights(cfg, weights)?;

Compatible with Python safetensors library for cross-framework weight exchange.

GPU Inference

use mamba_rs::gpu::inference::GpuMambaBackbone;

let mut gpu = GpuMambaBackbone::new(0, cpu_weights, cfg, input_dim, batch)?;
gpu.capture_graph()?; // optional: CUDA Graph capture removes kernel launch overhead (~40 μs/step)

gpu.step(&input, &mut output)?;
gpu.reset()?; // episode boundary

Highlights

  • Analytical gradients derived by hand — no autograd framework needed
  • BPTT through SSM recurrent state across timesteps
  • Burn-in for warming hidden state from historical context
  • Zero-allocation inference with pre-allocated scratch buffers
  • Custom CUDA kernels compiled at runtime via NVRTC
  • Flat contiguous weight buffers for optimizer fusion
  • CUDA Graph capture for minimal kernel launch overhead
  • safetensors serialization for Python/HuggingFace interop
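BPTT through the recurrent state means the state gradient is threaded backward in time through every step. A toy scalar sketch of the pattern (illustrative, not the crate's derivation) for h[t] = a·h[t-1] + b·x[t] with loss L = Σₜ h[t]:

```rust
// Forward records the state trajectory; backward threads dL/dh[t]
// backwards in time, mirroring (in miniature) BPTT through the SSM state.
// Returns (dL/da, dL/db).
fn bptt(a: f32, b: f32, xs: &[f32]) -> (f32, f32) {
    // forward: h[0] = 0, h[t+1] = a * h[t] + b * x[t]
    let mut h = vec![0.0f32];
    for &x in xs {
        let prev = *h.last().unwrap();
        h.push(a * prev + b * x);
    }

    // backward: dL/dh[t+1] = 1 (direct) + a * dL/dh[t+2] (through recurrence)
    let (mut da, mut db, mut gh) = (0.0f32, 0.0f32, 0.0f32);
    for t in (0..xs.len()).rev() {
        gh = 1.0 + a * gh;
        da += gh * h[t];  // d h[t+1] / d a = h[t]
        db += gh * xs[t]; // d h[t+1] / d b = x[t]
    }
    (da, db)
}

fn main() {
    // L = b*(x1 + x2) + a*b*x1, so dL/da = b*x1 and dL/db = x1 + x2 + a*x1
    let (da, db) = bptt(0.5, 1.0, &[1.0, 2.0]);
    println!("dL/da = {da}, dL/db = {db}");
}
```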

Citation

@inproceedings{mamba,
  title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
  author={Gu, Albert and Dao, Tri},
  booktitle={International Conference on Learning Representations},
  year={2024}
}

License

Dual-licensed under MIT or Apache-2.0.