brainharmony 0.1.0

Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML

brainharmony-rs

Pure-Rust inference for the Brain-Harmony multimodal brain foundation model, built on Burn 0.20.

Brain-Harmony unifies morphology (T1 MRI) and function (fMRI) into 1D tokens using a Vision Transformer with brain gradient + geometric harmonics positional embeddings and a JEPA self-supervised learning framework.

Benchmark

Setup: Brain-Harmony ViT-Base encoder (12 layers, 768-dim, 12 heads, 7200 patches), single sample, 10 runs, Apple M4 Pro.

Backend                      Best    Median   vs Python CPU
Python MPS (Apple GPU)       2.2s    2.4s     1.6x faster
Python CPU (PyTorch 2.11)    3.6s    3.8s     baseline
Rust wgpu f16 (Metal GPU)    4.8s    4.8s     1.3x slower
Rust wgpu f32 (Metal GPU)    6.4s    6.8s     1.8x slower
Rust Accelerate (CPU)        45.4s   45.7s    12.6x slower

Rust GPU f16 is within 1.3x of PyTorch CPU thanks to tiled attention (tile=1024), which keeps softmax working sets in GPU cache. The remaining gap versus Python MPS comes from Burn's generated wgpu shaders, which do not yet match Apple's hand-tuned Metal Performance Shaders. For production use, the wgpu-f16 backend is recommended.
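The tiling idea can be illustrated on the CPU: a numerically stable streaming softmax walks the score row in fixed-size tiles, carrying only a running max and a rescaled running sum instead of materializing the whole row. This is a minimal std-only sketch of the technique, not Burn's actual kernel; the tile size and data are illustrative.

```rust
/// Streaming (tiled) softmax: folds `scores` tile by tile into a running
/// (max, sum) pair, so each tile's working set stays small and cache-resident.
fn tiled_softmax(scores: &[f32], tile: usize) -> Vec<f32> {
    let mut running_max = f32::NEG_INFINITY;
    let mut running_sum = 0.0f32;
    // Pass 1: fold each tile into the running statistics.
    for chunk in scores.chunks(tile) {
        let tile_max = chunk.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let new_max = running_max.max(tile_max);
        // Rescale the previous sum to the new max before adding this tile.
        running_sum = running_sum * (running_max - new_max).exp()
            + chunk.iter().map(|&s| (s - new_max).exp()).sum::<f32>();
        running_max = new_max;
    }
    // Pass 2: normalize with the final statistics.
    scores
        .iter()
        .map(|&s| (s - running_max).exp() / running_sum)
        .collect()
}

fn main() {
    let scores = vec![1.0, 2.0, 3.0, 4.0, 5.0];
    // tile=2 for illustration; the GPU kernel above uses tile=1024.
    let probs = tiled_softmax(&scores, 2);
    println!("sum = {:.6}", probs.iter().sum::<f32>()); // probabilities sum to 1
}
```

The two-pass form matches a plain softmax exactly; the point is that pass 1 never needs more than one tile of scores in fast memory at a time.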

Project Structure

brainharmony-rs/
├── src/
│   ├── lib.rs                    # Module declarations, re-exports
│   ├── config.rs                 # ModelConfig, DataConfig, YamlConfig
│   ├── error.rs                  # BrainHarmonyError
│   ├── weights.rs                # SafeTensors weight loading
│   ├── data.rs                   # GradientData, GeohData, SignalInput
│   ├── inference.rs              # BrainHarmonyEncoder
│   ├── predictor_api.rs          # BrainHarmonyPredictor (JEPA pipeline)
│   ├── classification.rs         # ClassificationHead, MLPHead
│   ├── masks.rs                  # Spatiotemporal masking
│   ├── csv_export.rs             # CSV export utilities
│   ├── hf_download.rs            # HuggingFace Hub integration
│   ├── model/
│   │   ├── encoder.rs            # FlexVisionTransformer
│   │   ├── decoder.rs            # VisionTransformerPredictor
│   │   ├── patch_embed.rs        # FlexiPatchEmbed (dynamic patch size)
│   │   ├── pos_embed.rs          # Brain gradient + geometric harmonics
│   │   ├── attention.rs          # Multi-head attention with masking
│   │   ├── feedforward.rs        # MLP with fast GELU
│   │   ├── block.rs              # Pre-norm transformer block
│   │   └── norm.rs               # LayerNorm wrapper
│   └── bin/infer.rs              # CLI binary
├── examples/                     # embed, batch, classify, csv_export, profile, bench
├── tests/                        # config, data_loading, model, masks, classification, csv_export
├── scripts/
│   ├── convert_weights.py        # PyTorch -> SafeTensors converter
│   └── bench_chart.py            # Generate benchmark chart
└── benchmark.sh                  # Multi-backend benchmark runner

Quick Start

Install

git clone https://github.com/eugenehp/brainharmony-rs
cd brainharmony-rs
cargo build --release

Build Variants

# CPU (default — NdArray + Rayon)
cargo build --release

# CPU + Apple Accelerate (recommended on macOS)
cargo build --release --features accelerate

# GPU via Metal (macOS) or Vulkan (Linux)
cargo build --release --no-default-features --features wgpu

# GPU half-precision
cargo build --release --no-default-features --features wgpu-f16

Convert Weights

python scripts/convert_weights.py \
    --input checkpoints/harmonizer/model.pth \
    --output data/brainharmony.safetensors

Run Inference

./target/release/infer \
      --weights data/brainharmony.safetensors \
      --gradient data/gradient_mapping_400.csv \
      --geoh data/schaefer400_roi_eigenmodes.csv \
      --input data/signal.safetensors \
      --output embeddings.safetensors

Library Usage

use brainharmony::{BrainHarmonyEncoder, ModelConfig, DataConfig};

// B is any Burn backend (e.g. burn::backend::NdArray for CPU).
let (enc, ms) = BrainHarmonyEncoder::<B>::from_weights(
    "model.safetensors",
    "gradient_mapping_400.csv",
    "schaefer400_roi_eigenmodes.csv",
    &ModelConfig::default(),
    &DataConfig::default(),
    &device,
)?;

let result = enc.encode_safetensors("signal.safetensors")?;
result.save_safetensors("embeddings.safetensors")?;

API

Type                      Use case
BrainHarmonyEncoder<B>    Produce latent embeddings from brain signals
BrainHarmonyPredictor<B>  Full JEPA pipeline (encoder + predictor)
ClassificationHead<B>     Linear classification with global average pooling
MLPHead<B>                3-layer MLP classification head (stage 2)

Architecture

Input: [B, 1, 400, 864] raw fMRI signal (400 ROIs x 18 patches x 48 time points)

FlexiPatchEmbed ──→ [B, 7200, 768]     (Conv2d with dynamic patch size)
       │
  + PosEmbed ──────→ brain gradient (30d) + geometric harmonics (200d)
       │                projected, averaged, normalized to [-1,1]
       │
  12x Block ───────→ pre-norm, multi-head attention, GELU MLP
       │
  LayerNorm ───────→ [B, 7200, 768]     output embeddings
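The shapes above follow from simple arithmetic (400 ROIs x 864 time points with 48-point temporal patches gives 400 x 18 = 7200 tokens), and the positional-embedding normalization to [-1,1] can be sketched as a min-max rescale. Both are shown in this std-only sketch; the min-max form is an assumption about how the rescale is done, and the values are illustrative.

```rust
/// Token count for a [rois, time] signal split into temporal patches of `patch_t` points.
fn token_count(rois: usize, time: usize, patch_t: usize) -> usize {
    assert!(time % patch_t == 0, "time axis must divide evenly into patches");
    rois * (time / patch_t)
}

/// Min-max rescale of a positional feature vector into [-1, 1]
/// (assumed form of the normalization named in the diagram).
fn normalize_to_unit_range(v: &[f32]) -> Vec<f32> {
    let min = v.iter().cloned().fold(f32::INFINITY, f32::min);
    let max = v.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    v.iter().map(|&x| 2.0 * (x - min) / (max - min) - 1.0).collect()
}

fn main() {
    // 400 ROIs x 864 time points, 48-point patches -> 400 x 18 = 7200 tokens.
    assert_eq!(token_count(400, 864, 48), 7200);
    // Endpoints map to -1 and 1, the midpoint to 0.
    assert_eq!(normalize_to_unit_range(&[0.0, 5.0, 10.0]), vec![-1.0, 0.0, 1.0]);
    println!("tokens = {}", token_count(400, 864, 48));
}
```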

Backends

Feature            Backend                    Notes
ndarray (default)  CPU (NdArray + Rayon)      Baseline
accelerate         CPU + Apple Accelerate     macOS, ~1.5x faster matmul
openblas-system    CPU + OpenBLAS             Linux (apt install libopenblas-dev)
wgpu               GPU (Metal / Vulkan, f32)  --no-default-features --features wgpu
wgpu-f16           GPU (half precision)       --no-default-features --features wgpu-f16
hf-download        HuggingFace Hub client     Auto-download weights

Tests

cargo test

47 tests covering config, data loading, model components, masking, classification, and CSV export.

License

MIT