brainharmony 0.1.0

Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML
//! FlexVisionTransformer Encoder (burn 0.20.1)
//!
//! Python: `FlexVisionTransformer` in flex_transformer.py.
//!
//! Architecture:
//!   1. FlexiPatchEmbed: Conv2d(1, embed_dim, (1, ps), (1, ps)) -> [B, N, D]
//!   2. Add positional embeddings (brain gradient + geometric harmonics)
//!   3. Optional CLS token prepend
//!   4. Apply masks (index-gather for JEPA context selection)
//!   5. Transformer blocks (pre-norm, GELU MLP) with optional attention masking
//!   6. Final LayerNorm
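//!
//! A minimal construction/forward sketch. This is illustrative only, not a
//! snippet from the crate: every dimension and the `pos_mode` string below are
//! placeholders, and `B` is assumed to be any `Backend` implementation.
//!
//! ```ignore
//! let vit = FlexVisionTransformer::<B>::new(
//!     (400, 512),  // signal_size (rows, timepoints) -- hypothetical
//!     16,          // patch_size
//!     1,           // in_chans
//!     768,         // embed_dim
//!     12,          // depth
//!     12,          // num_heads
//!     4.0,         // mlp_ratio
//!     true,        // qkv_bias
//!     1e-6,        // norm_eps
//!     3,           // grad_dim
//!     32,          // geoh_dim
//!     384,         // pred_embed_dim
//!     "harmonics", // pos_mode -- placeholder string
//!     true,        // use_cls_token
//!     false,       // use_decoder
//!     &device,
//! )?;
//! // x: [B, 1, H, W]; with a CLS token and no masks the output is [B, N + 1, 768].
//! let out = vit.forward(x, Some(&gradient), Some(&geoh), None, None, None);
//! ```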
use burn::module::{Param, ParamId};
use burn::prelude::*;

use crate::model::block::Block;
use crate::model::norm::LNorm;
use crate::model::patch_embed::FlexiPatchEmbed;
use crate::model::pos_embed::BrainHarmonyPosEmbed;

#[derive(Module, Debug)]
pub struct FlexVisionTransformer<B: Backend> {
    pub patch_embed: FlexiPatchEmbed<B>,
    pub pos_embed: BrainHarmonyPosEmbed<B>,
    pub cls_token: Option<Param<Tensor<B, 3>>>,
    pub blocks: Vec<Block<B>>,
    pub norm: LNorm<B>,
    pub embed_dim: usize,
    pub num_heads: usize,
}

impl<B: Backend> FlexVisionTransformer<B> {
    pub fn new(
        signal_size: (usize, usize),
        patch_size: usize,
        in_chans: usize,
        embed_dim: usize,
        depth: usize,
        num_heads: usize,
        mlp_ratio: f64,
        qkv_bias: bool,
        norm_eps: f64,
        grad_dim: usize,
        geoh_dim: usize,
        pred_embed_dim: usize,
        pos_mode: &str,
        use_cls_token: bool,
        use_decoder: bool,
        device: &B::Device,
    ) -> crate::error::Result<Self> {
        let patch_embed =
            FlexiPatchEmbed::new(signal_size, patch_size, in_chans, embed_dim, device);
        let grid_size = patch_embed.num_patches_2d;

        let pos_embed = BrainHarmonyPosEmbed::new(
            grad_dim,
            geoh_dim,
            embed_dim,
            pred_embed_dim,
            grid_size,
            pos_mode,
            use_cls_token,
            use_decoder,
            device,
        )?;

        let cls_token = if use_cls_token {
            Some(Param::initialized(
                ParamId::new(),
                Tensor::zeros([1, 1, embed_dim], device),
            ))
        } else {
            None
        };

        let blocks = (0..depth)
            .map(|_| Block::new(embed_dim, num_heads, mlp_ratio, qkv_bias, norm_eps, device))
            .collect();

        let norm = LNorm::new(embed_dim, norm_eps, device);

        Ok(Self {
            patch_embed,
            pos_embed,
            cls_token,
            blocks,
            norm,
            embed_dim,
            num_heads,
        })
    }

    /// Forward pass.
    ///
    /// x: [B, 1, H, W] raw fMRI signal
    /// gradient: [n_rois, grad_dim] brain gradient coordinates
    /// geoh: [n_rois, geoh_dim] geometric harmonics coordinates
    /// patch_size: optional runtime patch size override
    /// masks: optional list of index masks for JEPA context selection
    /// attn_mask: optional [B, N] binary attention mask
    ///
    /// Returns: [B, N_out, embed_dim]
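    ///
    /// A shape-only sketch (hedged: `vit`, `x`, `gradient`, and `geoh` are
    /// assumed to already exist with compatible dimensions):
    ///
    /// ```ignore
    /// // With use_cls_token = true and no masks: [B, 1, H, W] -> [B, N + 1, embed_dim].
    /// let y = vit.forward(x, Some(&gradient), Some(&geoh), None, None, None);
    /// assert_eq!(y.dims()[2], vit.embed_dim);
    /// ```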
    pub fn forward(
        &self,
        x: Tensor<B, 4>,
        gradient: Option<&Tensor<B, 2>>,
        geoh: Option<&Tensor<B, 2>>,
        patch_size: Option<usize>,
        masks: Option<&[Tensor<B, 2, Int>]>,
        attn_mask: Option<&Tensor<B, 2>>,
    ) -> Tensor<B, 3> {
        // 1. Patch embed: [B, 1, H, W] -> [B, N, D]
        let mut x = self.patch_embed.forward(x, patch_size);

        // 2. Add positional embeddings
        let (pos_emb_enc, _pos_emb_dec) = self.pos_embed.forward(gradient, geoh);
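        // The decoder-side embedding is not used in the encoder forward pass.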

        if self.cls_token.is_some() {
            // Add pos embed to patches (skip CLS position)
            let pos_patches = pos_emb_enc.clone().narrow(1, 1, x.dims()[1]);
            x = x + pos_patches;

            // Apply masks before adding CLS
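            // (mask indices address patch positions, so gathering must happen
            // before the CLS token shifts everything by one)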
            if let Some(mask_list) = masks {
                x = apply_masks(x, mask_list);
            }

            // Prepend CLS token with its positional embedding
            let cls = self.cls_token.as_ref().unwrap().val();
            let cls_pos = pos_emb_enc.narrow(1, 0, 1);
            let cls_with_pos = cls + cls_pos;
            let cls_expanded = cls_with_pos.expand([x.dims()[0], 1, self.embed_dim]);
            x = Tensor::cat(vec![cls_expanded, x], 1);
        } else {
            x = x + pos_emb_enc;
            if let Some(mask_list) = masks {
                x = apply_masks(x, mask_list);
            }
        }

        // 3. Transformer blocks
        for block in &self.blocks {
            x = block.forward(x, attn_mask);
        }

        // 4. Final norm
        self.norm.forward(x)
    }
}

/// Gather patches from x using mask indices.
///
/// Python: `apply_masks(x, masks)` in libs/masks/utils.py.
///
/// x: [B, N, D]
/// masks: list of [B, K] int tensors (indices into dim 1)
/// Returns: [B * len(masks), K, D]
pub fn apply_masks<B: Backend>(
    x: Tensor<B, 3>,
    masks: &[Tensor<B, 2, Int>],
) -> Tensor<B, 3> {
    let [_b, _n, d] = x.dims();
    let parts: Vec<Tensor<B, 3>> = masks
        .iter()
        .map(|m| {
            let [b_m, k] = m.dims();
            let mask_exp = m.clone().unsqueeze_dim::<3>(2).expand([b_m, k, d]);
            x.clone().gather(1, mask_exp)
        })
        .collect();
    Tensor::cat(parts, 0)
}
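
// Shape-level sanity check for `apply_masks`. This is an illustrative sketch,
// not part of the original test suite: the `NdArray` backend (behind the
// `ndarray` feature) and all sizes below are assumptions chosen for the example.
#[cfg(test)]
mod apply_masks_sketch {
    use super::*;
    use burn::backend::NdArray;

    type TB = NdArray<f32>;

    #[test]
    fn concatenates_gathered_patches_along_batch() {
        let device = Default::default();
        // x: [B = 2, N = 8, D = 4]
        let x = Tensor::<TB, 3>::zeros([2, 8, 4], &device);
        // Two masks, each selecting K = 3 patch indices per batch element.
        let m0 = Tensor::<TB, 2, Int>::from_ints([[0, 2, 5], [1, 3, 7]], &device);
        let m1 = Tensor::<TB, 2, Int>::from_ints([[4, 5, 6], [0, 1, 2]], &device);
        // The result stacks one gathered copy per mask: [2 * 2, 3, 4].
        let out = apply_masks(x, &[m0, m1]);
        assert_eq!(out.dims(), [4, 3, 4]);
    }
}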