# brainharmony-rs
Pure-Rust inference for the Brain-Harmony multimodal brain foundation model, built on Burn 0.20.
Brain-Harmony unifies morphology (T1 MRI) and function (fMRI) into 1D tokens using a Vision Transformer with brain gradient + geometric harmonics positional embeddings and a JEPA self-supervised learning framework.
## Benchmark
Setup: Brain-Harmony ViT-Base encoder (12 layers, 768-dim, 12 heads, 7200 patches), single sample, 10 runs, Apple M4 Pro.
| Backend | Best | Median | vs Python CPU |
|---|---|---|---|
| Python MPS (Apple GPU) | 2.2s | 2.4s | 1.6x faster |
| Python CPU (PyTorch 2.11) | 3.6s | 3.8s | baseline |
| Rust wgpu f16 (Metal GPU) | 4.8s | 4.8s | 1.3x slower |
| Rust wgpu f32 (Metal GPU) | 6.4s | 6.8s | 1.8x slower |
| Rust Accelerate (CPU) | 45.4s | 45.7s | 12.6x slower |
Rust GPU f16 comes within 1.3x of PyTorch CPU thanks to tiled attention (tile=1024), which keeps softmax working sets in GPU cache. The remaining gap versus Python MPS comes down to Burn's generated wgpu shaders versus Apple's hand-tuned Metal Performance Shaders. For production use, the `wgpu-f16` backend is recommended.
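The cache-friendliness of tiling rests on the "online" softmax trick: scores are processed in fixed-size tiles, with a running max and normalizer that are rescaled whenever a later tile raises the max, yet the result matches a full-row softmax exactly. Burn's actual kernel is generated shader code; the following is only a plain-Rust sketch of that rescaling idea, not the crate's implementation.

```rust
// Reference full-row softmax, for comparison.
fn softmax_full(scores: &[f32]) -> Vec<f32> {
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

// Streaming ("online") softmax over tiles; tile = 1024 matches the
// setting mentioned above. Only one tile is hot at a time.
fn softmax_tiled(scores: &[f32], tile: usize) -> Vec<f32> {
    let mut m = f32::NEG_INFINITY; // running max
    let mut d = 0.0f32;            // running normalizer
    let mut out = vec![0.0f32; scores.len()];
    for (t, chunk) in scores.chunks(tile).enumerate() {
        let local_max = chunk.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let new_m = m.max(local_max);
        // If the max rose, rescale what was accumulated so far.
        let scale = (m - new_m).exp();
        d *= scale;
        for o in &mut out[..t * tile] {
            *o *= scale;
        }
        for (i, s) in chunk.iter().enumerate() {
            let e = (s - new_m).exp();
            d += e;
            out[t * tile + i] = e;
        }
        m = new_m;
    }
    for o in &mut out {
        *o /= d;
    }
    out
}
```

The same rescaling applies per attention row, which is why the full 7200-wide score row never needs to sit in fast memory at once.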
## Project Structure
```
brainharmony-rs/
├── src/
│   ├── lib.rs              # Module declarations, re-exports
│   ├── config.rs           # ModelConfig, DataConfig, YamlConfig
│   ├── error.rs            # BrainHarmonyError
│   ├── weights.rs          # SafeTensors weight loading
│   ├── data.rs             # GradientData, GeohData, SignalInput
│   ├── inference.rs        # BrainHarmonyEncoder
│   ├── predictor_api.rs    # BrainHarmonyPredictor (JEPA pipeline)
│   ├── classification.rs   # ClassificationHead, MLPHead
│   ├── masks.rs            # Spatiotemporal masking
│   ├── csv_export.rs       # CSV export utilities
│   ├── hf_download.rs      # HuggingFace Hub integration
│   ├── model/
│   │   ├── encoder.rs      # FlexVisionTransformer
│   │   ├── decoder.rs      # VisionTransformerPredictor
│   │   ├── patch_embed.rs  # FlexiPatchEmbed (dynamic patch size)
│   │   ├── pos_embed.rs    # Brain gradient + geometric harmonics
│   │   ├── attention.rs    # Multi-head attention with masking
│   │   ├── feedforward.rs  # MLP with fast GELU
│   │   ├── block.rs        # Pre-norm transformer block
│   │   └── norm.rs         # LayerNorm wrapper
│   └── bin/infer.rs        # CLI binary
├── examples/               # embed, batch, classify, csv_export, profile, bench
├── tests/                  # config, data_loading, model, masks, classification, csv_export
├── scripts/
│   ├── convert_weights.py  # PyTorch -> SafeTensors converter
│   └── bench_chart.py      # Generate benchmark chart
└── benchmark.sh            # Multi-backend benchmark runner
```
## Quick Start

### Install

### Build Variants

The `cargo` commands below follow the feature names listed under Backends.

```sh
# CPU (default — NdArray + Rayon)
cargo build --release

# CPU + Apple Accelerate (recommended on macOS)
cargo build --release --features accelerate

# GPU via Metal (macOS) or Vulkan (Linux)
cargo build --release --no-default-features --features wgpu

# GPU half-precision
cargo build --release --no-default-features --features wgpu-f16
```
### Convert Weights
### Run Inference
## Library Usage

```rust
use brainharmony_rs::BrainHarmonyEncoder;

// `B` is whichever Burn backend the crate was built with;
// paths and argument shapes here are illustrative.
let enc = BrainHarmonyEncoder::<B>::from_weights("model.safetensors")?;
let result = enc.encode_safetensors("input.safetensors")?;
result.save_safetensors("embeddings.safetensors")?;
```
## API

| Type | Use case |
|---|---|
| `BrainHarmonyEncoder<B>` | Produce latent embeddings from brain signals |
| `BrainHarmonyPredictor<B>` | Full JEPA pipeline (encoder + predictor) |
| `ClassificationHead<B>` | Linear classification with global average pooling |
| `MLPHead<B>` | 3-layer MLP classification head (stage 2) |
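To make the pooling-plus-linear step concrete, here is a plain-Rust illustration of what a classification head of this kind computes: average the `[tokens, dim]` embeddings over the token axis, then apply a linear layer to get per-class logits. The `classify` helper and its signature are ours for illustration, not the crate's API.

```rust
// embeddings: [tokens][dim], w: [classes][dim], b: [classes]
fn classify(embeddings: &[Vec<f32>], w: &[Vec<f32>], b: &[f32]) -> Vec<f32> {
    let dim = embeddings[0].len();

    // Global average pooling over the token axis.
    let mut pooled = vec![0.0f32; dim];
    for tok in embeddings {
        for (p, x) in pooled.iter_mut().zip(tok) {
            *p += x;
        }
    }
    for p in &mut pooled {
        *p /= embeddings.len() as f32;
    }

    // Linear layer: logits[c] = w[c] . pooled + b[c]
    w.iter()
        .zip(b)
        .map(|(row, bias)| {
            row.iter().zip(&pooled).map(|(wi, pi)| wi * pi).sum::<f32>() + bias
        })
        .collect()
}
```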
## Architecture

```
Input: [B, 1, 400, 864] raw fMRI signal (400 ROIs x 18 patches x 48 time points)
        │
FlexiPatchEmbed ──→ [B, 7200, 768] (Conv2d with dynamic patch size)
        │
+ PosEmbed ──────→ brain gradient (30d) + geometric harmonics (200d)
        │           projected, averaged, normalized to [-1,1]
        │
12x Block ───────→ pre-norm, multi-head attention, GELU MLP
        │
LayerNorm ───────→ [B, 7200, 768] output embeddings
```
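The 7200-token sequence length follows directly from the patching arithmetic in the diagram: each ROI's 864 time samples split into patches of 48 samples, giving 18 patches per ROI and 400 x 18 = 7200 tokens. A small check (the helper name is ours, not the crate's):

```rust
// Tokens = ROIs * (time samples / samples per patch).
fn token_count(rois: usize, time_len: usize, patch_t: usize) -> usize {
    assert_eq!(time_len % patch_t, 0, "time axis must divide evenly into patches");
    rois * (time_len / patch_t)
}
```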
## Backends

| Feature | Backend | Notes |
|---|---|---|
| `ndarray` (default) | CPU (NdArray + Rayon) | Baseline |
| `accelerate` | CPU + Apple Accelerate | macOS, ~1.5x faster matmul |
| `openblas-system` | CPU + OpenBLAS | Linux (`apt install libopenblas-dev`) |
| `wgpu` | GPU (Metal / Vulkan, f32) | `--no-default-features --features wgpu` |
| `wgpu-f16` | GPU (half precision) | `--no-default-features --features wgpu-f16` |
| `hf-download` | HuggingFace Hub client | Auto-download weights |
## Tests
47 tests covering config, data loading, model components, masking, classification, and CSV export.
## License
MIT