//! brainharmony 0.1.0
//!
//! Brain-Harmony multimodal brain foundation model: inference in Rust with
//! Burn ML.
//!
//! Flexible patch embedding (Burn 0.20.1), mirroring `FlexiPatchEmbed` from
//! the Python reference `flex_patch_embed.py`:
//!   Conv2d(in_chans, embed_dim, kernel_size=(1, patch_size), stride=(1, patch_size))
//!
//! The Python module supports a dynamic patch size at runtime via
//! pseudo-inverse weight resampling. This port expresses the convolution as a
//! `Linear` over flattened patches and currently requires the runtime patch
//! size to match the one fixed at construction.
//!
//! Input: [B, 1, n_rois, signal_length] raw fMRI signal
//! Output: [B, n_rois * (signal_length / patch_size), embed_dim]
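//!
//! A minimal usage sketch (hypothetical sizes; assumes a backend `B` and its
//! `device` are in scope):
//!
//! ```ignore
//! // 200 ROIs, 400 time points, patch size 16 -> 25 time patches per ROI.
//! let embed = FlexiPatchEmbed::<B>::new((200, 400), 16, 1, 768, &device);
//! let x = Tensor::<B, 4>::zeros([2, 1, 200, 400], &device);
//! let tokens = embed.forward(x, None);
//! assert_eq!(tokens.dims(), [2, 200 * 25, 768]);
//! ```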
use burn::nn::Linear;
use burn::prelude::*;

use crate::model::linear_zeros;

#[derive(Module, Debug)]
pub struct FlexiPatchEmbed<B: Backend> {
    /// Linear projection simulating Conv2d(1, embed_dim, (1, ps), (1, ps)).
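    ///
    /// Because the Conv2d kernel (1, ps) with stride (1, ps) touches each
    /// non-overlapping window of `ps` samples exactly once, it reduces to a
    /// `Linear` of shape [in_chans * ps, embed_dim] applied per patch.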
    pub proj: Linear<B>,
    /// Patch size along the time axis, fixed at construction.
    pub patch_size: usize,
    /// Total number of patches: n_rois * (signal_length / patch_size).
    pub num_patches: usize,
    /// Patch grid as (n_rois, n_time_patches).
    pub num_patches_2d: (usize, usize),
}

impl<B: Backend> FlexiPatchEmbed<B> {
    pub fn new(
        signal_size: (usize, usize), // (n_rois, signal_length)
        patch_size: usize,
        in_chans: usize,
        embed_dim: usize,
        device: &B::Device,
    ) -> Self {
        let n_rois = signal_size.0;
        assert_eq!(
            signal_size.1 % patch_size,
            0,
            "signal_length must be divisible by patch_size"
        );
        let n_time_patches = signal_size.1 / patch_size;
        let num_patches = n_rois * n_time_patches;

        let proj = linear_zeros(in_chans * patch_size, embed_dim, true, device);

        Self {
            proj,
            patch_size,
            num_patches,
            num_patches_2d: (n_rois, n_time_patches),
        }
    }

    /// x: [B, 1, H, W] -> [B, H * (W/ps), embed_dim]
    ///
    /// `patch_size` optionally overrides the stored patch size, but it must
    /// equal `self.patch_size`: the stored `Linear` weights are not resampled
    /// for other sizes.
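    ///
    /// For example, an input of [2, 1, 200, 400] with ps = 16 yields
    /// [2, 200 * 25, embed_dim].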
    pub fn forward(&self, x: Tensor<B, 4>, patch_size: Option<usize>) -> Tensor<B, 3> {
        let [b, c, h, w] = x.dims();
        let ps = patch_size.unwrap_or(self.patch_size);

        debug_assert_eq!(c, 1, "expected input shape [B, 1, n_rois, signal_length]");
        debug_assert_eq!(w % ps, 0, "signal length must be divisible by the patch size");
        assert_eq!(
            ps, self.patch_size,
            "runtime patch size override requires weight resampling, which is not implemented"
        );

        let n_t = w / ps;

        // Split the time axis into patches: [B, 1, H, W] -> [B, H, n_t, ps].
        let x = x.reshape([b, h, n_t, ps]);

        // Flatten to one row per patch: [B, H, n_t, ps] -> [B, H * n_t, ps].
        let x = x.reshape([b, h * n_t, ps]);

        // Project: [B, H*n_t, ps] -> [B, H*n_t, embed_dim]
        self.proj.forward(x)
    }
}
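
// A minimal shape check, sketched under the assumption that a test backend is
// available; here burn's NdArray backend (the `burn/ndarray` feature).
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::prelude::*;

    #[test]
    fn forward_produces_expected_token_shape() {
        let device = Default::default();
        // 4 ROIs, 8 time points, patch size 2 -> 4 time patches per ROI.
        let embed = FlexiPatchEmbed::<NdArray>::new((4, 8), 2, 1, 16, &device);
        let x = Tensor::<NdArray, 4>::zeros([2, 1, 4, 8], &device);

        let out = embed.forward(x, None);
        assert_eq!(out.dims(), [2, 4 * 4, 16]);
    }
}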