# labram-rs
Pure-Rust inference for the LaBraM (Large Brain Model) EEG foundation model, built on Burn 0.20.
LaBraM is a BEiTv2-inspired transformer pretrained on large-scale EEG data with neural tokenization. It features Q/K normalization, gamma residual scaling, and channel-aware position embeddings from the standard 10-20 system.
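The two less common ingredients here, Q/K normalization and γ (LayerScale-style) residual scaling, can be sketched in a few lines of framework-free Rust. This is an illustrative simplification, not the crate's actual Burn code:

```rust
/// L2-normalize a vector. Q/K normalization applies this to each query
/// and key before the dot product, bounding the attention logits.
fn l2_normalize(v: &[f32]) -> Vec<f32> {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
    v.iter().map(|x| x / norm).collect()
}

/// Attention logit for one (query, key) pair under Q/K normalization:
/// a dot product of unit vectors, so it always lies in [-1, 1].
fn qk_norm_logit(q: &[f32], k: &[f32]) -> f32 {
    l2_normalize(q)
        .iter()
        .zip(l2_normalize(k))
        .map(|(qi, ki)| qi * ki)
        .sum()
}

/// Gamma residual: x + γ ⊙ sublayer(x), where γ is a learned
/// per-channel vector (initialized small, so early training stays
/// close to the identity).
fn gamma_residual(x: &[f32], sublayer_out: &[f32], gamma: &[f32]) -> Vec<f32> {
    x.iter()
        .zip(sublayer_out)
        .zip(gamma)
        .map(|((xi, si), gi)| xi + gi * si)
        .collect()
}

fn main() {
    // Parallel q and k give the maximum logit, 1.0.
    println!("logit = {}", qk_norm_logit(&[3.0, 4.0], &[3.0, 4.0]));
    // γ = 0.1 scales the sublayer's contribution down tenfold.
    println!("{:?}", gamma_residual(&[1.0, 2.0], &[10.0, 10.0], &[0.1, 0.1]));
}
```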
## Architecture
```text
EEG [B, C, T]
 │
 ├─ Reshape to patches [B, C, n_patches, patch_size]
 │
 ├─ TemporalConv (3× Conv2d + GELU + GroupNorm)
 │     → [B, C·n_patches, embed_dim = 200]
 │
 ├─ Prepend [CLS] token + position embedding + temporal embedding
 │
 ├─ Transformer blocks (BEiTv2-style)
 │   ├─ LayerNorm → Attention (Q/K norm, no QKV bias) → γ₁ scaling → residual
 │   └─ LayerNorm → MLP (GELU) → γ₂ scaling → residual
 │
 ├─ LayerNorm → extract [CLS] token
 │
 └─ Linear → [B, n_outputs]
```
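The shape bookkeeping in the first two stages can be checked with a few lines of plain Rust. The `patch_size = 200` used in the demo is an assumption for illustration; the crate's actual configuration may differ:

```rust
/// Token-sequence length after patching: each channel is cut into
/// T / patch_size non-overlapping patches, each patch becomes one
/// embedded token, and a single [CLS] token is prepended.
fn seq_len(channels: usize, timepoints: usize, patch_size: usize) -> usize {
    assert!(timepoints % patch_size == 0, "T must divide evenly into patches");
    let n_patches = timepoints / patch_size;
    1 + channels * n_patches // [CLS] + C * n_patches patch tokens
}

fn main() {
    // 8 channels × 800 timepoints with an assumed patch_size of 200:
    // 800 / 200 = 4 patches per channel, so 8 * 4 + 1 = 33 tokens.
    println!("{}", seq_len(8, 800, 200)); // → 33
}
```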
## Benchmarks
Benchmarked on Apple M3 Max (arm64), embed_dim=200, depth=2, heads=10.
### Inference Latency
| Configuration | Python (PyTorch) | Rust (Accelerate) | Rust (Metal GPU) |
|---|---|---|---|
| 8ch × 800t | 2.20 ms | 2.08 ms (1.06x) | 5.60 ms (0.39x) |
| 8ch × 1600t | 2.56 ms | 2.99 ms (0.86x) | 5.53 ms (0.46x) |
| 8ch × 3200t | 3.04 ms | 5.17 ms (0.59x) | 5.65 ms (0.54x) |

Speedup relative to the PyTorch baseline is shown in parentheses.
Note: LaBraM is a compact model (embed_dim = 200), and its small matrix multiplies play to PyTorch's heavily optimized BLAS kernels, which are hard to beat. Rust (Accelerate) achieves parity on small inputs, and with the full 12-layer production model the relative gap narrows further.
### Numerical Parity
Python ↔ Rust output difference: < 2.3×10⁻¹⁰ — essentially bit-perfect.
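A parity check of this kind boils down to a max-absolute-difference over the two models' flattened outputs. A generic sketch (the crate's actual comparison harness may differ):

```rust
/// Maximum absolute elementwise difference between two output tensors
/// (flattened to slices) — the usual metric for cross-framework parity.
fn max_abs_diff(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "outputs must have the same shape");
    a.iter()
        .zip(b)
        .map(|(x, y)| (x - y).abs())
        .fold(0.0, f64::max)
}

fn main() {
    // Hypothetical logits from the Python and Rust forward passes.
    let python_out = [0.123456789012, -1.5, 2.25];
    let rust_out = [0.123456789013, -1.5, 2.25];
    println!("max |Δ| = {:e}", max_abs_diff(&python_out, &rust_out));
}
```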
## Build
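Assuming a standard Cargo layout (the crate's specific feature flags for the Accelerate and Metal backends are not documented here, so none are shown):

```shell
# Standard Cargo release build.
cargo build --release

# Run the test suite.
cargo test
```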
## Pretrained Weights
Available on HuggingFace.
## Citation

## Author

## License
Apache-2.0