# brainharmony-rs
Pure-Rust inference for the [Brain-Harmony](https://github.com/eugenehp/Brain-Harmony) multimodal brain foundation model, built on [Burn 0.20](https://burn.dev).
Brain-Harmony unifies morphology (T1 MRI) and function (fMRI) into 1D tokens using a Vision Transformer with brain-gradient and geometric-harmonics positional embeddings, trained under a JEPA self-supervised learning framework.
## Benchmark
<p align="center">
<img src="figures/benchmark.png" alt="Benchmark: Python vs Rust" width="900">
</p>
**Setup:** Brain-Harmony ViT-Base encoder (12 layers, 768-dim, 12 heads, 7200 patches), single sample, 10 runs, Apple M4 Pro.
| Backend | Time (median) | Time (mean) | vs. baseline |
|---|---|---|---|
| Python MPS (Apple GPU) | 2.2s | 2.4s | 1.6x faster |
| Python CPU (PyTorch 2.11) | 3.6s | 3.8s | baseline |
| Rust wgpu f16 (Metal GPU) | 4.8s | 4.8s | 1.3x slower |
| Rust wgpu f32 (Metal GPU) | 6.4s | 6.8s | 1.8x slower |
| Rust Accelerate (CPU) | 45.4s | 45.7s | 12.6x slower |
> **Rust GPU f16 is within 1.3x of PyTorch CPU** thanks to tiled attention (tile=1024), which keeps softmax working sets in GPU cache. The remaining gap versus Python MPS comes down to Burn's generated wgpu shaders versus Apple's hand-tuned Metal Performance Shaders. For production use, the `wgpu-f16` backend is recommended.
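The tiling trick mentioned above can be sketched independently of Burn: a numerically stable softmax computed tile by tile, so each pass streams through a cache-sized slice of the row instead of the whole thing. This is a minimal CPU illustration, not the crate's actual GPU kernel; only the two-pass structure and the idea of a fixed tile size are taken from the text.

```rust
/// Numerically stable softmax over `row`, processed in fixed-size tiles.
/// Pass 1 keeps a running max and a rescaled running sum (the same trick
/// used by online/flash softmax); pass 2 normalizes in place.
/// Illustrative sketch only -- not the crate's actual kernel.
fn tiled_softmax(row: &mut [f32], tile: usize) {
    let mut max = f32::NEG_INFINITY;
    let mut sum = 0.0f32;
    for chunk in row.chunks(tile) {
        let local_max = chunk.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let new_max = max.max(local_max);
        // Rescale the running sum to the new max before adding this tile.
        sum = sum * (max - new_max).exp()
            + chunk.iter().map(|&x| (x - new_max).exp()).sum::<f32>();
        max = new_max;
    }
    for chunk in row.chunks_mut(tile) {
        for x in chunk.iter_mut() {
            *x = (*x - max).exp() / sum;
        }
    }
}
```

Because the running sum is rescaled whenever a larger maximum appears, the result is identical to a single-pass softmax over the full row, regardless of tile size.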
## Project Structure
```
brainharmony-rs/
├── src/
│   ├── lib.rs              # Module declarations, re-exports
│   ├── config.rs           # ModelConfig, DataConfig, YamlConfig
│   ├── error.rs            # BrainHarmonyError
│   ├── weights.rs          # SafeTensors weight loading
│   ├── data.rs             # GradientData, GeohData, SignalInput
│   ├── inference.rs        # BrainHarmonyEncoder
│   ├── predictor_api.rs    # BrainHarmonyPredictor (JEPA pipeline)
│   ├── classification.rs   # ClassificationHead, MLPHead
│   ├── masks.rs            # Spatiotemporal masking
│   ├── csv_export.rs       # CSV export utilities
│   ├── hf_download.rs      # HuggingFace Hub integration
│   ├── model/
│   │   ├── encoder.rs      # FlexVisionTransformer
│   │   ├── decoder.rs      # VisionTransformerPredictor
│   │   ├── patch_embed.rs  # FlexiPatchEmbed (dynamic patch size)
│   │   ├── pos_embed.rs    # Brain gradient + geometric harmonics
│   │   ├── attention.rs    # Multi-head attention with masking
│   │   ├── feedforward.rs  # MLP with fast GELU
│   │   ├── block.rs        # Pre-norm transformer block
│   │   └── norm.rs         # LayerNorm wrapper
│   └── bin/infer.rs        # CLI binary
├── examples/               # embed, batch, classify, csv_export, profile, bench
├── tests/                  # config, data_loading, model, masks, classification, csv_export
├── scripts/
│   ├── convert_weights.py  # PyTorch -> SafeTensors converter
│   └── bench_chart.py      # Generate benchmark chart
└── benchmark.sh            # Multi-backend benchmark runner
```
## Quick Start
### Install
```bash
git clone https://github.com/eugenehp/brainharmony-rs
cd brainharmony-rs
cargo build --release
```
### Build Variants
```bash
# CPU (default — NdArray + Rayon)
cargo build --release
# CPU + Apple Accelerate (recommended on macOS)
cargo build --release --features accelerate
# GPU via Metal (macOS) or Vulkan (Linux)
cargo build --release --no-default-features --features wgpu
# GPU half-precision
cargo build --release --no-default-features --features wgpu-f16
```
### Convert Weights
```bash
python scripts/convert_weights.py \
  --input checkpoints/harmonizer/model.pth \
  --output data/brainharmony.safetensors
```
### Run Inference
```bash
infer --weights data/brainharmony.safetensors \
  --gradient data/gradient_mapping_400.csv \
  --geoh data/schaefer400_roi_eigenmodes.csv \
  --input data/signal.safetensors \
  --output embeddings.safetensors
```
### Library Usage
```rust
use brainharmony::{BrainHarmonyEncoder, ModelConfig, DataConfig};

let (enc, ms) = BrainHarmonyEncoder::<B>::from_weights(
    "model.safetensors",
    "gradient_mapping_400.csv",
    "schaefer400_roi_eigenmodes.csv",
    &ModelConfig::default(),
    &DataConfig::default(),
    &device,
)?;
let result = enc.encode_safetensors("signal.safetensors")?;
result.save_safetensors("embeddings.safetensors")?;
```
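The `B` type parameter is a Burn backend. One plausible way to select it to match the build features is a feature-gated alias; this fragment assumes Burn's standard `NdArray`/`Wgpu` backend re-exports and is not taken from the crate's source, so adjust the paths and feature names to your build.

```rust
// Hypothetical feature-gated backend alias -- assumes Burn's standard
// backend re-exports (burn::backend::NdArray / burn::backend::Wgpu);
// check the crate's Cargo.toml for the actual feature names.
#[cfg(feature = "wgpu")]
type B = burn::backend::Wgpu;
#[cfg(not(feature = "wgpu"))]
type B = burn::backend::NdArray;
```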
## API
| Type | Purpose |
|---|---|
| `BrainHarmonyEncoder<B>` | Produce latent embeddings from brain signals |
| `BrainHarmonyPredictor<B>` | Full JEPA pipeline (encoder + predictor) |
| `ClassificationHead<B>` | Linear classification with global average pooling |
| `MLPHead<B>` | 3-layer MLP classification head (stage 2) |
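Per the table, `ClassificationHead` is global average pooling followed by a linear layer. A backend-free sketch of that computation on an `[n_tokens, dim]` embedding matrix (plain `Vec`s here for illustration; the real head operates on Burn tensors and its exact API may differ):

```rust
/// Global-average-pool `embeddings` ([n_tokens][dim]) into one [dim]
/// vector, then apply a linear layer (weights [n_classes][dim], bias
/// [n_classes]). Backend-free sketch of GAP + linear classification.
fn gap_linear_head(
    embeddings: &[Vec<f32>],
    weights: &[Vec<f32>],
    bias: &[f32],
) -> Vec<f32> {
    let n = embeddings.len() as f32;
    let dim = embeddings[0].len();
    // Global average pooling over the token axis.
    let mut pooled = vec![0.0f32; dim];
    for token in embeddings {
        for (p, &x) in pooled.iter_mut().zip(token) {
            *p += x / n;
        }
    }
    // Linear layer: logits = W * pooled + b.
    weights
        .iter()
        .zip(bias)
        .map(|(w, &b)| {
            w.iter().zip(&pooled).map(|(wi, pi)| wi * pi).sum::<f32>() + b
        })
        .collect()
}
```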
## Architecture
**Input:** `[B, 1, 400, 864]` raw fMRI signal (400 ROIs x 18 patches x 48 time points)
```
FlexiPatchEmbed ──→ [B, 7200, 768]  (Conv2d with dynamic patch size)
      │
+ PosEmbed ──────→ brain gradient (30d) + geometric harmonics (200d),
      │            projected, averaged, normalized to [-1,1]
      │
12x Block ───────→ pre-norm, multi-head attention, GELU MLP
      │
LayerNorm ───────→ [B, 7200, 768] output embeddings
```
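The `[B, 1, 400, 864]` input flattens the 18 patches x 48 time points into the 864-wide last axis (and 400 x 18 = 7200 tokens after patch embedding). Assuming standard row-major (C-order) storage with time as the fastest-varying index, which is an assumption about the layout rather than something stated in the source, the flat offset of `(roi, patch, time)` within one sample is:

```rust
// Assumed row-major layout of one [1, 400, 864] sample, with the
// 18x48 patch/time grid flattened into the last axis (time fastest).
const ROIS: usize = 400;
const PATCHES: usize = 18;
const TIMEPOINTS: usize = 48; // 18 * 48 = 864

fn signal_offset(roi: usize, patch: usize, time: usize) -> usize {
    assert!(roi < ROIS && patch < PATCHES && time < TIMEPOINTS);
    roi * (PATCHES * TIMEPOINTS) + patch * TIMEPOINTS + time
}
```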
## Backends
| Feature | Backend | Notes |
|---|---|---|
| `ndarray` (default) | CPU (NdArray + Rayon) | Baseline |
| `accelerate` | CPU + Apple Accelerate | macOS, ~1.5x faster matmul |
| `openblas-system` | CPU + OpenBLAS | Linux (`apt install libopenblas-dev`) |
| `wgpu` | GPU (Metal / Vulkan, f32) | `--no-default-features --features wgpu` |
| `wgpu-f16` | GPU (half precision) | `--no-default-features --features wgpu-f16` |
| `hf-download` | HuggingFace Hub client | Auto-download weights |
## Tests
```bash
cargo test
```
47 tests covering config, data loading, model components, masking, classification, and CSV export.
## License
MIT