mamba-rs 0.1.1

Mamba SSM (Selective State Space Model) implementation in Rust with optional CUDA GPU acceleration.

Supports both inference and training, including full backward pass with BPTT through recurrent SSM state. Custom CUDA kernels for GPU-accelerated forward and backward passes.

Reference: Gu & Dao, Mamba: Linear-Time Sequence Modeling with Selective State Spaces (2023).

Features

  • Inference — zero-allocation single-step recurrent forward pass, microsecond-scale on CPU (see Performance)
  • Training — full backward pass with BPTT through SSM hidden state
  • Burn-in — warm up recurrent state from history before training window
  • CUDA — custom kernels for SSM recurrence, conv1d, fused activations
  • Standalone — no framework dependency (no PyTorch, no Burn, no Candle)
  • f32 — native single precision, TF32 Tensor Cores on Ampere/Hopper

Quick Start

Add the dependency to Cargo.toml:

[dependencies]
mamba-rs = "0.1"

Then drive the model one step at a time:

use mamba_rs::{MambaConfig, MambaState, MambaStepScratch, MambaWeights, mamba_step};

let cfg = MambaConfig::default(); // d_model=128, 3 layers
let weights = load_weights(); // your weight loading
let mut state = MambaState::zeros(cfg.n_layers, cfg.d_inner(), cfg.d_state, cfg.d_conv);
let mut scratch = MambaStepScratch::new(&cfg);
let mut output = vec![0.0f32; cfg.d_model];

// single-step inference (recurrent, O(1) per step)
mamba_step(&input, &mut output, &weights, &mut state.layers, &mut scratch, &cfg, input_dim);

// reset state on sequence boundary
state.reset();
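The streaming pattern above — constant work per step, state carried across calls, burn-in over history, reset at sequence boundaries — can be sketched with a self-contained toy recurrence. `ToyState` and `toy_step` below are illustrative stand-ins, not the crate's API:

```rust
// Toy single-step recurrence: O(1) work per step, state carried across
// calls, warmed by burn-in, reset at sequence boundaries.

struct ToyState {
    h: f32, // recurrent hidden state (one channel for brevity)
}

impl ToyState {
    fn zeros() -> Self {
        ToyState { h: 0.0 }
    }
    fn reset(&mut self) {
        self.h = 0.0;
    }
}

/// One recurrent step: h <- a*h + b*x, then y = c*h + d*x.
fn toy_step(state: &mut ToyState, x: f32, a: f32, b: f32, c: f32, d: f32) -> f32 {
    state.h = a * state.h + b * x;
    c * state.h + d * x
}

fn main() {
    let (a, b, c, d) = (0.9f32, 1.0, 1.0, 0.1);
    let history = [0.5f32, -0.2, 0.3]; // burn-in context
    let window = [1.0f32, 0.0, -1.0];  // window to actually process

    let mut state = ToyState::zeros();

    // Burn-in: warm the hidden state from history, discarding outputs.
    for &x in &history {
        let _ = toy_step(&mut state, x, a, b, c, d);
    }

    // Process the window with the warmed state.
    let mut outputs = Vec::with_capacity(window.len());
    for &x in &window {
        outputs.push(toy_step(&mut state, x, a, b, c, d));
    }
    println!("{:?}", outputs);

    // New sequence: reset state at the boundary.
    state.reset();
    assert_eq!(state.h, 0.0);
}
```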

GPU (CUDA)

[dependencies]
mamba-rs = { version = "0.1", features = ["cuda"] }

Requires NVIDIA GPU + CUDA toolkit. Kernels compiled at runtime via NVRTC.

Architecture

    input [B, T, input_dim]
        |
    input_proj (linear + bias)
        |
        v
    +--------- x N layers ---------+
    |                               |
    |   residual                    |
    |      |                        |
    |   RmsNorm                     |
    |      |                        |
    |   in_proj ----+---- gate      |
    |      |             |          |
    |   conv1d           |          |
    |      |             |          |
    |   SiLU          SiLU          |
    |      |             |          |
    |   x_proj           |          |
    |    / | \           |          |
    |  dt  B  C          |          |
    |   |                |          |
    |  dt_proj           |          |
    |   |                |          |
    |  softplus          |          |
    |   |                |          |
    |  SSM recurrence    |          |
    |  h = A*h + B*x     |          |
    |  y = C*h + D*x     |          |
    |      |             |          |
    |      +--- gate * --+          |
    |            |                  |
    |        out_proj               |
    |            |                  |
    |      + residual               |
    |                               |
    +-------------------------------+

    norm_f (RmsNorm)
        |
    output [B, T, d_model]
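The SSM-recurrence box in the diagram can be made concrete for a single channel with state size N. The sketch below follows the paper's discretization — a_bar = exp(dt·A) for the state transition, with dt·B as a simplified (Euler) input discretization — then applies `h = a_bar⊙h + dt·B·x`, `y = C·h + D·x`, and the SiLU gate from the parallel branch. Names are illustrative, not the crate's API:

```rust
// One selective-SSM step for a single channel, state size N.

fn softplus(x: f32) -> f32 {
    (1.0 + x.exp()).ln()
}

fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

/// a: [N] continuous-time transition (negative), b/c: [N] input-dependent
/// projections, d: skip coefficient, gate: value from the gating branch.
fn ssm_step(
    h: &mut [f32],
    x: f32,
    dt_raw: f32,
    a: &[f32],
    b: &[f32],
    c: &[f32],
    d: f32,
    gate: f32,
) -> f32 {
    let dt = softplus(dt_raw); // guarantees dt > 0
    let mut y = d * x;         // skip connection D*x
    for n in 0..h.len() {
        let a_bar = (dt * a[n]).exp();        // discretized transition
        h[n] = a_bar * h[n] + dt * b[n] * x;  // state update
        y += c[n] * h[n];                     // readout C*h
    }
    y * silu(gate) // multiplicative gate from the parallel branch
}

fn main() {
    let n = 4;
    let a = vec![-1.0f32; n];
    let b = vec![0.5f32; n];
    let c = vec![1.0f32; n];
    let mut h = vec![0.0f32; n];
    let y = ssm_step(&mut h, 1.0, 0.0, &a, &b, &c, 0.1, 1.0);
    println!("y = {y}");
}
```

With dt_raw = 0, dt = softplus(0) = ln 2, so a_bar = exp(-ln 2) = 0.5: each step halves the old state before adding the new input's contribution.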

Performance

CPU inference (single-step recurrent, GH200 ARM Neoverse V2):

Config    d_model  Layers  Params   Latency
small          64       2     70K     61 μs
default       128       3    366K    375 μs
medium        256       4    1.8M    2.2 ms
large         512       6   10.4M   13.6 ms

Zero heap allocations per step. All buffers pre-allocated.
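The zero-allocation claim rests on a simple pattern: every intermediate buffer is sized once at startup and borrowed mutably on each step, so the hot path only writes into existing memory. A minimal sketch of that pattern (names illustrative, not the crate's API):

```rust
// Pre-allocated scratch pattern: allocate once, reuse every step.

struct Scratch {
    proj: Vec<f32>, // intermediate activations, sized once at startup
}

impl Scratch {
    fn new(dim: usize) -> Self {
        Scratch { proj: vec![0.0; dim] }
    }
}

/// Hot path: fills `scratch.proj` in place; no allocation occurs here.
fn step(input: &[f32], output: &mut [f32], scratch: &mut Scratch) {
    for (p, &x) in scratch.proj.iter_mut().zip(input) {
        *p = x * 2.0; // stand-in for a real projection
    }
    for (o, &p) in output.iter_mut().zip(scratch.proj.iter()) {
        *o = p + 1.0; // stand-in for a second stage
    }
}

fn main() {
    let mut scratch = Scratch::new(3);
    let input = [1.0f32, 2.0, 3.0];
    let mut output = [0.0f32; 3];
    for _ in 0..1000 {
        step(&input, &mut output, &mut scratch); // reuses the same buffers
    }
    println!("{:?}", output); // [3.0, 5.0, 7.0]
}
```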

GPU SGEMM throughput (GH200 H100, TF32 Tensor Cores, d_model=128):

Batch   cuBLAS SGEMM         Latency
B=1     1x128 x 128x128        25 μs
B=128   128x128 x 128x128      10 μs
B=256   256x128 x 128x128      10 μs

Highlights

  • Full backward pass with hand-derived analytical gradients (no autograd dependency)
  • BPTT through SSM recurrent state across time steps
  • Burn-in support for warming hidden state from historical context
  • Zero-allocation inference path with pre-allocated scratch buffers
  • Custom CUDA kernels for SSM recurrence, conv1d, and fused activations
  • Flat contiguous weight buffers for efficient optimizer fusion
  • Compatible with CUDA Graph capture for low-latency GPU execution
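What "hand-derived gradients with BPTT through the recurrent state" means can be shown on a scalar stand-in: for h_t = a·h_{t-1} + x_t with loss L = Σ h_t², the reverse sweep carries the adjoint λ_t = ∂L/∂h_t = 2h_t + a·λ_{t+1} backward through time and accumulates dL/da = Σ λ_t·h_{t-1}. The sketch below (not the crate's code) checks this against central finite differences:

```rust
// Scalar BPTT: backprop through the recurrent state of h_t = a*h_{t-1} + x_t.

fn forward(a: f32, xs: &[f32]) -> (Vec<f32>, f32) {
    // h[0] is the initial state; h[t+1] is the state after consuming xs[t].
    let mut h = vec![0.0f32; xs.len() + 1];
    let mut loss = 0.0;
    for (t, &x) in xs.iter().enumerate() {
        h[t + 1] = a * h[t] + x;
        loss += h[t + 1] * h[t + 1]; // L = sum of squared states
    }
    (h, loss)
}

/// Reverse sweep: adjoint grad_h = 2*h_t + a*grad_h carries the gradient
/// backward through time; dL/da accumulates grad_h * h_{t-1} at each step.
fn backward_a(a: f32, h: &[f32]) -> f32 {
    let (mut grad_h, mut grad_a) = (0.0f32, 0.0f32);
    for t in (1..h.len()).rev() {
        grad_h = a * grad_h + 2.0 * h[t];
        grad_a += grad_h * h[t - 1];
    }
    grad_a
}

fn main() {
    let a = 0.8f32;
    let xs = [1.0f32, -0.5, 0.25, 2.0];

    let (h, _) = forward(a, &xs);
    let analytic = backward_a(a, &h);

    // Central finite difference as an independent check.
    let eps = 1e-3;
    let numeric = (forward(a + eps, &xs).1 - forward(a - eps, &xs).1) / (2.0 * eps);

    assert!((analytic - numeric).abs() < 1e-2);
    println!("analytic = {analytic}, numeric = {numeric}");
}
```

The real backward pass applies the same adjoint recursion per channel and per state dimension, with additional terms for the input-dependent dt, B, and C projections.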

Differences from Python

            Python (state-spaces/mamba)   mamba-rs
Backward    PyTorch autograd              Manual BPTT (hand-derived)
Kernels     Triton + CUDA C++             CUDA (NVRTC, runtime compile)
Framework   PyTorch                       None (standalone)
Training    Via autograd                  Explicit forward + backward
Precision   fp16/bf16/fp32                f32 (TF32 on GPU)
Burn-in     Not exposed                   First-class API

Citation

@inproceedings{mamba,
  title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
  author={Gu, Albert and Dao, Tri},
  booktitle={International Conference on Learning Representations},
  year={2024}
}

License

Dual-licensed under MIT or Apache-2.0.