brainharmony 0.1.0

Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML
//! Positional embeddings for Brain-Harmony (burn 0.20.1).
//!
//! Supports two modes:
//!
//! 1. "gradient_geoh" — brain gradient + geometric harmonics projection
//!    - Height (ROI): fixed sincos positional encoding
//!    - Width (temporal): learned projection from gradient + geoh coordinates
//!    - Python: `BrainGradient_GeometricHarmonics_Anatomical_400_PosEmbed`
//!
//! 2. "sincos" — standard 2D sine-cosine positional encoding
//!    - Both dimensions use fixed sincos on grid indices
//!    - Python: `SineCosine_PosEmbed`
//!
//! Output: `(encoder_pos_embed, decoder_pos_embed)`, where the encoder
//! embedding is `[1, N, embed_dim]` and the decoder embedding is
//! `[1, N, pred_embed_dim]` (or `None` when the decoder is disabled),
//! with `N = grid_h * grid_w`, plus one when a CLS token is used.
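//!
//! A minimal usage sketch (not a compiled doctest; the backend choice and all
//! dimensions below are illustrative assumptions, not requirements):
//!
//! ```rust,ignore
//! use burn::backend::NdArray;
//! use burn::prelude::*;
//!
//! type B = NdArray;
//! let device = Default::default();
//!
//! // 400 ROIs x 8 temporal patches, 512-dim encoder, 256-dim decoder.
//! let pos = BrainHarmonyPosEmbed::<B>::new(
//!     3,               // grad_dim
//!     3,               // geoh_dim
//!     512,             // embed_dim
//!     256,             // pred_embed_dim
//!     (400, 8),        // grid_size = (grid_h: ROIs, grid_w: time patches)
//!     "gradient_geoh",
//!     true,            // use_cls_token
//!     true,            // use_decoder
//!     &device,
//! ).unwrap();
//!
//! // Per-ROI coordinates: [n_rois, grad_dim] and [n_rois, geoh_dim],
//! // with n_rois == grid_h.
//! let gradient = Tensor::<B, 2>::zeros([400, 3], &device);
//! let geoh = Tensor::<B, 2>::zeros([400, 3], &device);
//!
//! let (enc, dec) = pos.forward(Some(&gradient), Some(&geoh));
//! assert_eq!(enc.dims(), [1, 400 * 8 + 1, 512]); // +1 for the CLS token
//! assert_eq!(dec.unwrap().dims(), [1, 400 * 8 + 1, 256]);
//! ```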
use burn::module::{Param, ParamId};
use burn::nn::Linear;
use burn::prelude::*;

use crate::error::BrainHarmonyError;
use crate::model::linear_zeros;

#[derive(Module, Debug)]
pub struct BrainHarmonyPosEmbed<B: Backend> {
    /// Fixed sincos embeddings for the height (ROI) dimension: [H*W, D/2]
    pub emb_h: Param<Tensor<B, 2>>,
    /// Gradient projection: gradient coords -> embed_dim/2 (for 'gradient_geoh' mode)
    pub grad_proj: Option<Linear<B>>,
    /// Geometric harmonics projection: geoh coords -> embed_dim/2 (for 'gradient_geoh' mode)
    pub geoh_proj: Option<Linear<B>>,
    /// Fixed sincos for width dimension (for 'sincos' mode): [H*W, D/2]
    pub emb_w: Option<Param<Tensor<B, 2>>>,
    /// Decoder height sincos: [H*W, pred_D/2]
    pub emb_h_decoder: Option<Param<Tensor<B, 2>>>,
    /// Decoder width projection: embed_dim/2 -> pred_dim/2 (for 'gradient_geoh')
    pub decoder_pos_embed_proj: Option<Linear<B>>,
    /// Decoder width sincos (for 'sincos' mode): [H*W, pred_D/2]
    pub emb_w_decoder: Option<Param<Tensor<B, 2>>>,
    pub embed_dim: usize,
    pub grid_h: usize,
    pub grid_w: usize,
    pub mode: String,
    pub use_cls_token: bool,
    pub use_decoder: bool,
}
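
// Which optional fields are active in each mode, as built by `new`
// (`emb_h` is always present):
//
//   mode              encoder                 decoder (when use_decoder)
//   "gradient_geoh"   grad_proj, geoh_proj    emb_h_decoder, decoder_pos_embed_proj
//   "sincos"          emb_w                   emb_h_decoder, emb_w_decoder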

impl<B: Backend> BrainHarmonyPosEmbed<B> {
    pub fn new(
        grad_dim: usize,
        geoh_dim: usize,
        embed_dim: usize,
        pred_embed_dim: usize,
        grid_size: (usize, usize),
        mode: &str,
        use_cls_token: bool,
        use_decoder: bool,
        device: &B::Device,
    ) -> crate::error::Result<Self> {
        let (gh, gw) = grid_size;
        let n = gh * gw;
        let half_dim = embed_dim / 2;

        // Height (ROI) positional embeddings: fixed sincos
        let emb_h_data = sincos_1d_grid(half_dim, gh, gw);
        let emb_h = Param::initialized(
            ParamId::new(),
            Tensor::<B, 2>::from_data(TensorData::new(emb_h_data, vec![n, half_dim]), device),
        );

        let (grad_proj, geoh_proj, emb_w) = match mode {
            "gradient_geoh" => {
                let gp = linear_zeros(grad_dim, half_dim, true, device);
                let ghp = linear_zeros(geoh_dim, half_dim, true, device);
                (Some(gp), Some(ghp), None)
            }
            "sincos" => {
                let emb_w_data = sincos_1d_width(half_dim, gh, gw);
                let t = Param::initialized(
                    ParamId::new(),
                    Tensor::<B, 2>::from_data(
                        TensorData::new(emb_w_data, vec![n, half_dim]),
                        device,
                    ),
                );
                (None, None, Some(t))
            }
            _ => {
                return Err(BrainHarmonyError::InvalidPosMode {
                    mode: mode.to_string(),
                })
            }
        };

        // Decoder embeddings (optional)
        let (emb_h_decoder, decoder_pos_embed_proj, emb_w_decoder) = if use_decoder {
            let pred_half = pred_embed_dim / 2;
            let emb_h_dec_data = sincos_1d_grid(pred_half, gh, gw);
            let emb_h_dec = Param::initialized(
                ParamId::new(),
                Tensor::<B, 2>::from_data(TensorData::new(emb_h_dec_data, vec![n, pred_half]), device),
            );
            match mode {
                "gradient_geoh" => {
                    let proj = linear_zeros(half_dim, pred_half, true, device);
                    (Some(emb_h_dec), Some(proj), None)
                }
                "sincos" => {
                    let emb_w_dec_data = sincos_1d_width(pred_half, gh, gw);
                    let t = Param::initialized(
                        ParamId::new(),
                        Tensor::<B, 2>::from_data(
                            TensorData::new(emb_w_dec_data, vec![n, pred_half]),
                            device,
                        ),
                    );
                    (Some(emb_h_dec), None, Some(t))
                }
                _ => (None, None, None),
            }
        } else {
            (None, None, None)
        };

        Ok(Self {
            emb_h,
            grad_proj,
            geoh_proj,
            emb_w,
            emb_h_decoder,
            decoder_pos_embed_proj,
            emb_w_decoder,
            embed_dim,
            grid_h: gh,
            grid_w: gw,
            mode: mode.to_string(),
            use_cls_token,
            use_decoder,
        })
    }

    /// Compute position embeddings.
    ///
    /// `gradient`: [n_rois, grad_dim] brain gradient coordinates
    /// `geoh`: [n_rois, geoh_dim] geometric harmonics coordinates
    ///
    /// Both inputs are required in "gradient_geoh" mode and ignored in
    /// "sincos" mode.
    ///
    /// Returns (encoder_pos_embed, decoder_pos_embed), where N = grid_h * grid_w:
    ///   encoder: [1, N, embed_dim] (or [1, N+1, embed_dim] with CLS token)
    ///   decoder: [1, N, pred_dim] (likewise CLS-prefixed) or None
    pub fn forward(
        &self,
        gradient: Option<&Tensor<B, 2>>,
        geoh: Option<&Tensor<B, 2>>,
    ) -> (Tensor<B, 3>, Option<Tensor<B, 3>>) {
        let emb_w = if self.mode == "gradient_geoh" {
            let grad = gradient.expect("BUG: gradient tensor required for gradient_geoh mode");
            let geoh_data = geoh.expect("BUG: geoh tensor required for gradient_geoh mode");
            let grad_proj = self.grad_proj.as_ref().unwrap();
            let geoh_proj = self.geoh_proj.as_ref().unwrap();

            let grad_emb = grad_proj.forward(grad.clone()); // [n_rois, D/2]
            let geoh_emb = geoh_proj.forward(geoh_data.clone()); // [n_rois, D/2]

            // Average gradient and geometric harmonics projections
            let pos_embed = (grad_emb + geoh_emb).mul_scalar(0.5f32);
            // Repeat for each time patch
            let repeated = repeat_interleave_dim0(pos_embed, self.grid_w);
            // Min-max normalize to [-1, 1]: x' = 2 * (x - min) / (max - min) - 1.
            let min_val: f32 = repeated.clone().min().into_scalar().elem();
            let max_val: f32 = repeated.clone().max().into_scalar().elem();
            // Clamp the range to avoid division by zero when all values are
            // equal (e.g. with zero-initialized projections).
            let range = (max_val - min_val).max(1e-8);
            repeated
                .sub_scalar(min_val)
                .div_scalar(range)
                .mul_scalar(2.0f32)
                .sub_scalar(1.0f32)
        } else {
            self.emb_w
                .as_ref()
                .expect("BUG: emb_w missing in sincos mode")
                .val()
        };

        // Encoder: [H*W, D]
        let emb_encoder = Tensor::cat(vec![self.emb_h.val(), emb_w.clone()], 1).unsqueeze_dim::<3>(0);

        // Optionally prepend zero CLS token position
        let pos_embed_encoder = if self.use_cls_token {
            let [_, _n, d] = emb_encoder.dims();
            let cls_zeros = Tensor::<B, 3>::zeros([1, 1, d], &emb_encoder.device());
            Tensor::cat(vec![cls_zeros, emb_encoder], 1)
        } else {
            emb_encoder
        };

        // Decoder position embeddings
        let pos_embed_decoder = if self.use_decoder {
            let emb_h_dec = self.emb_h_decoder.as_ref().unwrap().val();
            let emb_w_dec = if self.mode == "gradient_geoh" {
                let proj = self.decoder_pos_embed_proj.as_ref().unwrap();
                proj.forward(emb_w)
            } else {
                self.emb_w_decoder.as_ref().unwrap().val()
            };
            let emb_decoder = Tensor::cat(vec![emb_h_dec, emb_w_dec], 1).unsqueeze_dim::<3>(0);
            let dec = if self.use_cls_token {
                let [_, _n, d] = emb_decoder.dims();
                let cls_zeros = Tensor::<B, 3>::zeros([1, 1, d], &emb_decoder.device());
                Tensor::cat(vec![cls_zeros, emb_decoder], 1)
            } else {
                emb_decoder
            };
            Some(dec)
        } else {
            None
        };

        (pos_embed_encoder, pos_embed_decoder)
    }
}

/// Repeat each row of a 2D tensor `repeats` times along dim 0.
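/// For example, rows `[a, b]` with `repeats = 2` become `[a, a, b, b]`,
/// matching PyTorch's `repeat_interleave(dim=0)` ordering.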
fn repeat_interleave_dim0<B: Backend>(t: Tensor<B, 2>, repeats: usize) -> Tensor<B, 2> {
    let [n, d] = t.dims();
    t.unsqueeze_dim::<3>(1)
        .expand([n, repeats, d])
        .reshape([n * repeats, d])
}

/// Generate 1D sincos positional embeddings over a `(grid_h, grid_w)` grid,
/// taking each cell's scalar position from `pos_of(h, w)`. Per position the
/// layout is `half_dim / 2` sines followed by `half_dim / 2` cosines.
fn sincos_1d(
    half_dim: usize,
    grid_h: usize,
    grid_w: usize,
    pos_of: impl Fn(usize, usize) -> f64,
) -> Vec<f32> {
    let n = grid_h * grid_w;
    let quarter = half_dim / 2;
    let mut data = vec![0.0f32; n * half_dim];

    for h in 0..grid_h {
        for w in 0..grid_w {
            let pos = pos_of(h, w);
            let idx = h * grid_w + w;
            for k in 0..quarter {
                let omega = 1.0 / 10000.0_f64.powf(k as f64 / quarter as f64);
                let angle = pos * omega;
                data[idx * half_dim + k] = angle.sin() as f32;
                data[idx * half_dim + quarter + k] = angle.cos() as f32;
            }
        }
    }
    data
}

/// Generate 1D sincos positional embeddings for the height (ROI) dimension.
fn sincos_1d_grid(half_dim: usize, grid_h: usize, grid_w: usize) -> Vec<f32> {
    sincos_1d(half_dim, grid_h, grid_w, |h, _w| h as f64)
}

/// Generate 1D sincos positional embeddings for the width (temporal) dimension.
fn sincos_1d_width(half_dim: usize, grid_h: usize, grid_w: usize) -> Vec<f32> {
    sincos_1d(half_dim, grid_h, grid_w, |_h, w| w as f64)
}

/// Generate 1D sincos positional embeddings for a flat sequence.
/// Used for latent tokens and OneTokRegViT positional embeddings.
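/// Layout per position: `data[pos][k] = sin(pos * w_k)` and
/// `data[pos][half + k] = cos(pos * w_k)` with `w_k = 10000^(-k / half)`;
/// this is the block layout (all sines, then all cosines) used by MAE-style
/// `get_1d_sincos_pos_embed`, not the interleaved original-Transformer layout.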
pub fn sincos_1d_flat(embed_dim: usize, n_positions: usize) -> Vec<f32> {
    let half = embed_dim / 2;
    let mut data = vec![0.0f32; n_positions * embed_dim];

    for pos in 0..n_positions {
        for k in 0..half {
            let omega = 1.0 / 10000.0_f64.powf(k as f64 / half as f64);
            let angle = pos as f64 * omega;
            data[pos * embed_dim + k] = angle.sin() as f32;
            data[pos * embed_dim + half + k] = angle.cos() as f32;
        }
    }
    data
}
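
// Sanity-check sketches for the helpers above. The expected numbers follow
// directly from the sincos formula; the tensor test assumes the `NdArray`
// backend is available (e.g. as a dev-dependency feature).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sincos_1d_flat_matches_formula() {
        // embed_dim = 4, 2 positions: position 0 is [sin 0, sin 0, cos 0, cos 0].
        let data = sincos_1d_flat(4, 2);
        assert_eq!(data.len(), 8);
        assert_eq!(&data[0..4], &[0.0, 0.0, 1.0, 1.0]);
        // Position 1, k = 0: omega = 1, giving sin(1) and cos(1).
        assert!((data[4] - 1.0f32.sin()).abs() < 1e-6);
        assert!((data[6] - 1.0f32.cos()).abs() < 1e-6);
    }

    #[test]
    fn repeat_interleave_repeats_rows_consecutively() {
        type B = burn::backend::NdArray;
        let device = Default::default();
        let t = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
        let out = repeat_interleave_dim0(t, 2);
        assert_eq!(out.dims(), [4, 2]);
        let vals = out.into_data().to_vec::<f32>().unwrap();
        assert_eq!(vals, vec![1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0]);
    }
}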