# eegpt-rs
Pure-Rust inference for the EEGPT (Pretrained Transformer for Universal and Reliable Representation of EEG Signals) foundation model, built on Burn 0.20.

EEGPT is pretrained on large-scale EEG data with a dual self-supervised objective: spatio-temporal representation alignment plus masked reconstruction. It uses learned channel embeddings for spatial flexibility and summary tokens for a compact global representation.
## Architecture
```
EEG [B, C, T]
 │
 ├─ Conv2d Patch Embedding (patch_size=64, stride=32)
 │    → [B, n_patches, C, embed_dim]
 │
 ├─ Channel Embedding (learned per-channel vectors)
 │    → [B, n_patches, C, embed_dim]
 │
 ├─ Per-patch Transformer (concat summary tokens, self-attention)
 │    [B*n_patches, C+embed_num, embed_dim] → blocks → extract summary
 │    → [B, n_patches, embed_num, embed_dim]
 │
 └─ LinearConstraintProbe (flatten → Linear → flatten → Linear)
      → [B, n_outputs]
```
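The patch count implied by the Conv2d parameters above follows the standard sliding-window formula; a small sketch (the function name is illustrative, not part of the crate's API):

```rust
/// Number of temporal patches from a 1-D sliding window over `t` samples,
/// using the usual convolution formula: floor((t - patch_size) / stride) + 1.
fn n_patches(t: usize, patch_size: usize, stride: usize) -> usize {
    (t - patch_size) / stride + 1
}

fn main() {
    // With patch_size=64 and stride=32 as above, a 1000-sample window
    // yields 30 overlapping patches.
    println!("{}", n_patches(1000, 64, 32)); // prints 30
}
```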
## Benchmarks
Benchmarked on an Apple M3 Max (arm64, macOS 26.3.1) with `embed_dim=512`, `depth=2`, `heads=8`. Python baseline: PyTorch 2.8.0. Rust backends: NdArray, Accelerate BLAS, and Metal GPU.
### Inference Latency

### Speedup vs Python

The Metal GPU backend achieves up to 3.3x speedup over PyTorch, with consistent 2–3x improvement across all configurations.
### Channel Scaling

### Time Scaling

### Summary Table
| Configuration | Python (PyTorch) | Rust (Accelerate) | Rust (Metal GPU, speedup vs PyTorch) |
|---|---|---|---|
| 4ch × 1000t | 5.46 ms | 12.91 ms | 2.61 ms (2.1x) |
| 8ch × 1000t | 6.93 ms | 15.15 ms | 2.59 ms (2.7x) |
| 16ch × 1000t | 10.13 ms | 20.84 ms | 4.38 ms (2.3x) |
| 22ch × 1000t | 13.33 ms | 24.43 ms | 4.05 ms (3.3x) |
| 32ch × 1000t | 16.27 ms | 32.62 ms | 6.18 ms (2.6x) |
| 64ch × 1000t | 28.68 ms | — | 9.43 ms (3.0x) |
| 22ch × 2000t | 22.18 ms | 44.03 ms | 8.77 ms (2.5x) |
| 22ch × 4000t | 41.82 ms | 85.27 ms | 17.01 ms (2.5x) |
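The parenthesized speedups in the table appear to be the PyTorch latency divided by the Metal latency, rounded to one decimal; a quick check (hypothetical helper, not crate code):

```rust
/// Speedup of one backend over another, from per-inference latencies.
fn speedup(baseline_ms: f64, candidate_ms: f64) -> f64 {
    baseline_ms / candidate_ms
}

fn main() {
    // 22ch × 1000t row: 13.33 ms (PyTorch) vs 4.05 ms (Metal) ≈ 3.3x.
    println!("{:.1}x", speedup(13.33, 4.05)); // prints "3.3x"
}
```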
**Note:** EEGPT processes each temporal patch through the full transformer independently (B×n_patches forward passes), making it more compute-intensive per sample than models like REVE. The Metal GPU backend excels by parallelizing these independent patch computations.
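To make the per-patch cost concrete, here is a sketch of the effective workload the inner transformer sees; the `embed_num = 4` value is a placeholder assumption, not a figure stated in this README:

```rust
/// Effective workload of the per-patch transformer: it runs over
/// b * n_patches independent sequences, each of length channels + embed_num
/// (one token per channel plus the appended summary tokens).
fn inner_workload(b: usize, n_patches: usize, channels: usize, embed_num: usize) -> (usize, usize) {
    (b * n_patches, channels + embed_num)
}

fn main() {
    // 22 channels × 1000 samples → 30 patches (patch_size=64, stride=32);
    // embed_num = 4 is an assumed value for illustration only.
    let (eff_batch, seq_len) = inner_workload(1, 30, 22, 4);
    println!("{eff_batch} sequences of {seq_len} tokens"); // 30 sequences of 26 tokens
}
```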
## Numerical Parity
Python ↔ Rust output difference: < 7.3×10⁻⁷ (f32 precision limit).
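A parity check of this kind reduces to a maximum absolute difference over the flattened outputs; a minimal sketch, not the crate's actual test harness:

```rust
/// Maximum absolute element-wise difference between two equal-length
/// f32 buffers, e.g. flattened model outputs from two implementations.
fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    a.iter()
        .zip(b)
        .map(|(x, y)| (x - y).abs())
        .fold(0.0_f32, f32::max)
}

fn main() {
    let python_out = [0.10_f32, -0.25, 0.40];
    let rust_out = [0.10_f32, -0.25, 0.40 + 5e-7];
    // Differences at this scale sit at the f32 precision limit.
    println!("{:e}", max_abs_diff(&python_out, &rust_out));
}
```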
## Build

## Pretrained Weights

Available on HuggingFace.
## Citation

## Author

## License

Apache-2.0