Please check the build logs for more information.
See Builds for ideas on how to fix a failed build, or Metadata for how to configure docs.rs builds.
If you believe this is docs.rs' fault, open an issue.
tribev2-rs
TRIBE v2 — Multimodal fMRI Brain Encoding Model — Inference in Rust
Pure-Rust inference engine for TRIBE v2 (d'Ascoli et al., 2026), a deep multimodal brain encoding model that predicts fMRI brain responses to naturalistic stimuli (video, audio, text).
Same model, new runtime.
tribev2-rsloads the exact same pretrained weights asfacebook/tribev2— no fine-tuning, no quantisation, no architectural changes. Every layer has been independently verified for numerical parity with the Python reference implementation.
Predicted cortical activity on the fsaverage5 surface (20,484 vertices), rendered from the pretrained TRIBE v2 model with multi-modal input.
Workspace Structure
tribev2-rs/
├── crates/
│ ├── tribev2/ Core brain encoding model, CLI, features, plotting
│ ├── tribev2-audio/ Wav2Vec-BERT 2.0 audio feature extraction (burn)
│ └── tribev2-video/ V-JEPA2 ViT-G video feature extraction (burn)
└── scripts/
└── extract_llama_features.py True per-layer LLaMA extraction (HuggingFace)
Crate Overview
| Crate | Description |
|---|---|
tribev2 |
FmriEncoderModel (pure-Rust + burn backends), weight loading, segment-based inference, events pipeline, brain surface plotting, CLI |
tribev2-audio |
Wav2Vec-BERT 2.0 conformer encoder in burn — raw waveform → per-layer hidden states at 2 Hz |
tribev2-video |
V-JEPA2 ViT-Giant in burn — video frames → 3D patch embedding → ViT layers → per-layer features at 2 Hz |
Features
- 100% inference parity with the Python implementation — every operation verified
- Two backends — pure-Rust (CPU) and burn (CPU/GPU via NdArray, wgpu Metal, Vulkan)
- Both backends load pretrained weights from safetensors
- Multi-modal inference — text, audio, and video features simultaneously
- Text feature extraction — LLaMA 3.2-3B via llama-cpp (Rust) or HuggingFace (Python script for true per-layer extraction)
- Audio feature extraction — Wav2Vec-BERT 2.0 in burn (16 kHz waveform → conformer hidden states)
- Video feature extraction — V-JEPA2 ViT-G in burn (frames → 3D patch embedding → ViT hidden states)
- Segment-based batching — long-form inference with configurable overlap
- Brain surface visualization — SVG rendering on fsaverage5 cortical mesh (6 views, 6 colormaps, colorbars, RGB overlays, MP4 time series)
- Events pipeline — whisperX transcription, ffmpeg audio extraction, sentence/context annotation
- HuggingFace Hub download support
Architecture
The model combines feature extractors — LLaMA 3.2 (text), V-JEPA2 (video), and Wav2Vec-BERT (audio) — into a unified x-transformers Encoder that maps multimodal representations onto the fsaverage5 cortical surface (~20,484 vertices).
| Component | Python | Rust (pure) | Rust (burn) |
|---|---|---|---|
| Projector (Linear/MLP/SubjectLayers) | Mlp / SubjectLayersModel |
model::projector::Projector |
model_burn::projector::Projector<B> |
| Combiner | Mlp / nn.Identity |
Projector (optional) |
MlpProjector<B> (optional) |
| Temporal smoothing | depthwise Conv1d |
TemporalSmoothing |
depthwise conv kernel |
| Time positional embedding | nn.Parameter |
Tensor |
Param<Tensor<B,3>> |
| Subject embedding | nn.Embedding |
Tensor |
Param<Tensor<B,2>> |
| x-transformers Encoder | x_transformers.Encoder |
XTransformerEncoder |
XTransformerEncoder<B> |
| ScaleNorm + RoPE + Attention + FF | x_transformers | hand-written | burn ops (+ optional fused CubeCL) |
| Low-rank head | nn.Linear(bias=False) |
Tensor matmul |
Linear<B> |
| Subject layers | SubjectLayersModel |
SubjectLayers |
SubjectLayers<B> |
| AdaptiveAvgPool1d | nn.AdaptiveAvgPool1d |
floor/ceil matching PyTorch | floor/ceil matching PyTorch |
| Weight loading | PyTorch load_state_dict |
weights::load_weights() |
model_burn::weights::load_burn_weights() |
Quick Start
1. Download weights
2. Run inference
# Text-only with LLaMA
# Multi-modal with pre-extracted features + brain plots
3. True per-layer LLaMA features (exact Python parity)
The llama-cpp backend extracts final-layer embeddings only. For true per-layer hidden states matching the Python pipeline:
# Extract with HuggingFace (requires: pip install transformers torch)
# Use in Rust (auto-reads .json sidecar for shape metadata)
4. Library usage
use BTreeMap;
use TribeV2;
use Tensor;
// Load pretrained model
let model = from_pretrained?;
// Build features: [1, n_layers*dim, timesteps]
let mut features = new;
features.insert;
features.insert;
features.insert;
// Forward pass → [1, 20484, 100]
let output = model.forward;
5. Burn backend (GPU inference)
use ;
use TribeV2Burn;
use ;
type B = NdArray; // or burn::backend::Wgpu
let device = Defaultdefault;
let config: TribeV2Config = from_str?;
let dims = pretrained;
let mut model = new;
// Load pretrained weights into burn model
let mut ws = from_safetensors?;
load_burn_weights?;
// Forward pass
let text = zeros;
let audio = zeros;
let video = zeros;
let output = model.forward;
// output: [1, 20484, 100]
Audio Feature Extraction (tribev2-audio)
use ;
use load_audio;
use ;
type B = NdArray;
let device = Defaultdefault;
let config = default; // facebook/w2v-bert-2.0
let mut model = new;
// Load HuggingFace weights
let mut ws = from_safetensors?;
load_wav2vec_bert_weights?;
// Extract features
let waveform = load_audio?;
let features = model.extract_features;
// features: [3, 1024, 120] at 2 Hz
Video Feature Extraction (tribev2-video)
use ;
use video_io;
use ;
type B = NdArray;
let device = Defaultdefault;
let config = default; // facebook/vjepa2-vitg-fpc64-256
let mut model = new;
let mut ws = from_safetensors?;
load_vjepa2_weights?;
// Extract frames and run model
// (see tribev2-video docs for full frame preprocessing pipeline)
Pretrained Model Details
| Parameter | Value |
|---|---|
| Hidden dim | 1152 |
| Encoder depth | 8 layers (8 attn + 8 FF) |
| Attention heads | 8 |
| FF multiplier | 4× |
| Norm | ScaleNorm |
| Position encoding | Rotary (dim=72) |
| Text extractor | LLaMA-3.2-3B (3 layer groups × 3072) |
| Audio extractor | Wav2Vec-BERT 2.0 (3 layer groups × 1024) |
| Video extractor | V-JEPA2 ViT-G (3 layer groups × 1408) |
| Low-rank head | 2048 |
| Output | fsaverage5 (20,484 vertices), 100 TRs |
Feature Flags
| Flag | Description |
|---|---|
ndarray |
Burn NdArray CPU backend (default) |
blas-accelerate |
+ Apple Accelerate BLAS |
wgpu |
Burn wgpu backend (auto-detects Metal/Vulkan/DX12) |
wgpu-metal |
+ native Metal MSL shaders |
wgpu-metal-f16 |
+ Metal f16 dtype (WMMA) |
wgpu-kernels-metal |
+ fused CubeCL kernels (fastest macOS) |
wgpu-vulkan |
+ Vulkan SPIR-V shaders |
llama-metal |
Metal GPU for LLaMA (default) |
llama-cuda |
CUDA for LLaMA |
llama-vulkan |
Vulkan for LLaMA |
hf-download |
HuggingFace Hub download support |
Benchmarks
Full forward pass: 1152-d, 8-layer transformer, 20,484 outputs, 100 timesteps, 3 modalities.
| Backend | Mean (ms) | Speedup |
|---|---|---|
| Rust CPU (naive) | 14,516 | 1× |
| Burn NdArray | 316 | 46× |
| Burn NdArray + Accelerate | 143 | 102× |
| Rust CPU + Accelerate | 73 | 199× |
| Burn wgpu Metal + fused kernels | 16.8 | 864× |
Citation
License
| Component | License |
|---|---|
| Rust source code | Apache-2.0 |
| Pretrained model weights | CC BY-NC 4.0 |