brainjepa-rs
Brain-JEPA fMRI Foundation Model — fully Rust inference pipeline.
brainjepa ports the Brain-JEPA encoder
(NeurIPS 2024, Spotlight) to Rust using RLX (rlx-engine, default).
| Binary | Purpose |
|---|---|
| infer | Encoder embeddings |
| classify | Downstream classification |
| predict | JEPA masked prediction |
Pretrained weights load from safetensors; inference runs without Python or PyTorch.
fMRI parcellated time series (450 ROIs x T time points)
|
v Data loading (CSV / safetensors)
| standardise -> temporal downsample (490 -> 160 frames)
|
v Brain-JEPA encoder (RLX)
| PatchEmbed Conv2d(1, 768, (1,16), (1,16))
| GradientPosEmbed sincos + brain gradient projection
| 12x Block LayerNorm -> MultiHeadAttn -> LayerNorm -> MLP(GELU)
| LayerNorm
|
v
embeddings.safetensors
embeddings [4500, 768] float32 (450 ROIs x 10 time patches x 768 dims)
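The shape arithmetic in the diagram can be checked with a small sketch. The function names `token_count` and `standardise` are hypothetical and not taken from the crate, and the z-score shown here (per ROI, population variance, small epsilon floor) is only one plausible reading of the "standardise" step.

```rust
/// Number of encoder tokens for a parcellated fMRI window.
/// Each ROI row is cut into non-overlapping temporal patches, matching the
/// (1, 16) kernel and stride of the PatchEmbed layer in the diagram.
/// Hypothetical helper, not part of the brainjepa API.
fn token_count(n_rois: usize, n_frames: usize, patch_len: usize) -> Option<usize> {
    if patch_len == 0 || n_frames % patch_len != 0 {
        return None; // frames must tile exactly into patches
    }
    Some(n_rois * (n_frames / patch_len))
}

/// Z-score one ROI time series in place (zero mean, unit variance).
/// Assumption: the real pipeline may normalise along a different axis
/// or use a different epsilon.
fn standardise(series: &mut [f32]) {
    if series.is_empty() {
        return;
    }
    let n = series.len() as f32;
    let mean = series.iter().sum::<f32>() / n;
    let var = series.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
    let std = var.sqrt().max(1e-6); // floor avoids division by zero on flat series
    for x in series.iter_mut() {
        *x = (*x - mean) / std;
    }
}
```

With 450 ROIs, 160 frames and a patch length of 16, `token_count` gives 450 x 10 = 4500 tokens, which is the first dimension of the saved embeddings. The raw 490-frame series does not tile into 16-frame patches, which is one reason the downsample step comes first.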
Prerequisites
# Rust stable >= 1.78
No PyTorch or Python needed at inference time (Python only for one-time weight conversion).
Download weights (HuggingFace)
Pre-converted weights: eugenehp/BrainJEPA
This caches brainjepa.safetensors (~709 MB) and gradient_mapping_450.csv under
~/.cache/huggingface/hub/. The infer CLI resolves them automatically when
--weights / --gradient are omitted (or set BRAINJEPA_WEIGHTS / BRAINJEPA_GRADIENT).
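A minimal sketch of that resolution order, assuming an explicit flag wins over the environment variable, which in turn wins over the cache. The helper `resolve_asset` and its `cache_dir` parameter are hypothetical; the actual HuggingFace cache nests files under per-repo snapshot directories, so the real lookup is more involved than a single `join`.

```rust
use std::path::{Path, PathBuf};

/// Resolve an asset path: explicit CLI value first, then an environment
/// value, then a file inside a local cache directory.
/// Assumption: this precedence mirrors the usual behaviour of a CLI flag
/// with an `env` fallback; the real `infer` binary may differ.
fn resolve_asset(
    cli: Option<PathBuf>,
    env_value: Option<String>,
    cache_dir: &Path,
    file_name: &str,
) -> PathBuf {
    if let Some(path) = cli {
        return path;
    }
    // Treat an empty environment variable as unset.
    if let Some(value) = env_value.filter(|v| !v.is_empty()) {
        return PathBuf::from(value);
    }
    cache_dir.join(file_name)
}
```

A caller would pass `std::env::var("BRAINJEPA_WEIGHTS").ok()` as `env_value`; taking it as a parameter keeps the function free of global state and easy to test.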
Quick start (RLX — default)
# Download weights (once)
# CPU encode (weights from HF cache or data/)
# Apple Metal (macOS)
# RLX + Apple Accelerate BLAS
Explicit paths still work:
Classify and JEPA predict
# Downstream classification (optional --head-weights for trained head)
# JEPA encoder + predictor (deterministic masks, seed 42 by default)
Backends
| Feature | Backend | Build |
|---|---|---|
| rlx-cpu (default) | CPU + Rayon | cargo build --release |
| rlx-blas-accelerate | CPU + Apple Accelerate | --features rlx-engine,rlx-blas-accelerate |
| rlx-blas-openblas | CPU + OpenBLAS (Linux) | --features rlx-engine,rlx-blas-openblas |
| rlx-metal | Apple Metal | --features rlx-engine,rlx-metal |
| rlx-gpu | wgpu (cross-platform) | --features rlx-engine,rlx-gpu |
| rlx-mlx | Apple MLX | --features rlx-engine,rlx-mlx |
| rlx-cuda | NVIDIA CUDA | --features rlx-engine,rlx-cuda |
| rlx-apple-silicon | cpu + metal + accelerate | --features rlx-engine,rlx-apple-silicon |
CLI
Brain-JEPA fMRI encoder inference (RLX)
Usage: infer [OPTIONS] --input <INPUT>
Options:
--weights <WEIGHTS> Safetensors weights [env: BRAINJEPA_WEIGHTS]
--gradient <GRADIENT> Gradient CSV [env: BRAINJEPA_GRADIENT]
--input <INPUT> fMRI input (.safetensors or .csv)
--output <OUTPUT> Output embeddings [default: embeddings.safetensors]
--model <MODEL> vit_small | vit_base | vit_large [default: vit_base]
--device <DEVICE> RLX: cpu | metal | mlx | gpu | cuda | rocm | tpu (aliases: wgpu, mtl) [default: cpu]
--repo <REPO> HuggingFace repo [default: eugenehp/BrainJEPA]
--threads <THREADS> CPU threads [env: RAYON_NUM_THREADS]
-v, --verbose
Library usage (RLX default)
use *;
let = from_weights?;
let result = encoder.encode_safetensors?;
result.save_safetensors?;
Entry points
| Type | Use case |
|---|---|
| BrainJepaEncoder | latent embeddings |
| BrainJepaPredictor | JEPA evaluation |
| ClassificationHead | downstream classification |
Device errors
If you pick a backend that was not compiled in, infer prints a rebuild command using
brainjepa feature names (e.g. rlx-metal, not just metal):
cargo build --release --no-default-features --features rlx-engine,rlx-metal
cargo run --release --no-default-features --features rlx-engine,rlx-metal --bin infer -- --device metal ...
# Compare every compiled RLX backend on one fMRI sample:
cargo run --example backend_compare --release --features rlx-engine,rlx-metal,rlx-gpu
Requires rlx 0.2.6 or later from crates.io. Enable rlx-mlx for Apple MLX (--features rlx-engine,rlx-mlx).
On macOS, native --device metal is usually preferable to --device gpu (wgpu).
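The relationship between `--device` values and feature names can be sketched as a lookup. This is an illustration only: `feature_for_device` and `rebuild_hint` are hypothetical names, the alias pairing (wgpu with gpu, mtl with metal) is inferred from the CLI help, and `rocm` and `tpu` are left unmapped because this README documents no feature for them.

```rust
/// Map a `--device` value to the Cargo feature that enables it.
/// Covers only the device/feature pairs listed in this README.
/// `cpu` is omitted because it is the default backend.
fn feature_for_device(device: &str) -> Option<&'static str> {
    match device.to_ascii_lowercase().as_str() {
        "metal" | "mtl" => Some("rlx-metal"),
        "gpu" | "wgpu" => Some("rlx-gpu"),
        "mlx" => Some("rlx-mlx"),
        "cuda" => Some("rlx-cuda"),
        _ => None, // unknown or undocumented device
    }
}

/// Build a rebuild command in the same shape as the one shown above.
fn rebuild_hint(device: &str) -> Option<String> {
    let feature = feature_for_device(device)?;
    Some(format!(
        "cargo build --release --no-default-features --features rlx-engine,{feature}"
    ))
}
```

Returning `Option` keeps the undocumented cases explicit, so a caller can fall back to a generic error message instead of printing a guessed feature name.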
Tests
# RLX graph smoke test (no weights)
# RLX param load vs real checkpoint
# Cross-backend parity — see docs/PARITY.md
Weight conversion (from PyTorch)
Architecture
ViT-Base encoder for fMRI: 450 ROIs × 160 time points → 4500 × 768 embeddings. See the original Brain-JEPA paper for details.
| Variant | Embed dim | Depth | Heads |
|---|---|---|---|
| vit_small | 384 | 12 | 6 |
| vit_base | 768 | 12 | 12 |
| vit_large | 1024 | 24 | 16 |
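The variant table can be expressed as a small lookup, which also makes the per-head width easy to check. `VitConfig` and its methods are hypothetical names for illustration and may not match the types used inside the crate.

```rust
/// Hyperparameters for one encoder variant, taken from the table above.
/// Hypothetical type, not part of the brainjepa API.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct VitConfig {
    embed_dim: usize,
    depth: usize,
    heads: usize,
}

impl VitConfig {
    /// Look up a variant by the name accepted by `--model`.
    fn from_name(name: &str) -> Option<Self> {
        match name {
            "vit_small" => Some(Self { embed_dim: 384, depth: 12, heads: 6 }),
            "vit_base" => Some(Self { embed_dim: 768, depth: 12, heads: 12 }),
            "vit_large" => Some(Self { embed_dim: 1024, depth: 24, heads: 16 }),
            _ => None,
        }
    }

    /// Width of each attention head (embedding dimension split evenly).
    fn head_dim(&self) -> usize {
        self.embed_dim / self.heads
    }
}
```

All three variants work out to a head width of 64 (384/6, 768/12, 1024/16), so only the number of heads and the depth change between them.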
Code structure
src/
rlx/ # RLX graph, weights loader, encoder / predictor
hf_download.rs # HuggingFace cache + download
bin/
infer.rs # CLI
classify.rs
predict.rs
download_weights.rs
tests/
rlx_graph_compile.rs
rlx_weights_load.rs
parity_rlx_cross_backend.rs
docs/
PARITY.md
scripts/
parity.sh
Acknowledgement
Zijian Dong et al. Brain-JEPA (NeurIPS 2024 Spotlight). arXiv:2409.19407
Original: hzlab/Brain-JEPA. Patterns from zuna-rs.
License
MIT