brainharmony 0.1.0

Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML
//! Classification head for downstream tasks.
//!
//! Brain-Harmony stage 2 uses a 3-layer MLP head (`MLPHead`) instead of a
//! single linear layer; a simple linear head (`ClassificationHead`) is also
//! provided.
//!
//! Pipeline:
//!   embeddings = encoder(signal)         -> [B, N, embed_dim]
//!   pooled     = mean(embeddings, 1)     -> [B, embed_dim]
//!   normed     = LayerNorm(pooled)       -> [B, embed_dim]
//!   logits     = head(normed)            -> [B, num_classes]
use burn::module::{Param, ParamId};
use burn::nn::Linear;
use burn::prelude::*;

use crate::model::linear_zeros;

/// 3-layer MLP classification head (Python: `MLPHead` in stage2 models.py).
///
/// Architecture:
///   Linear(in, hidden) -> ReLU -> Linear(hidden, hidden) -> ReLU -> Linear(hidden, out)
#[derive(Module, Debug)]
pub struct MLPHead<B: Backend> {
    pub lin1: Linear<B>,
    pub lin2: Linear<B>,
    pub lin3: Linear<B>,
}

impl<B: Backend> MLPHead<B> {
    pub fn new(
        in_features: usize,
        hidden_dim: usize,
        out_features: usize,
        device: &B::Device,
    ) -> Self {
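        // Note: `linear_zeros` (crate-local) presumably builds zero-initialized
        // layers as placeholders; pretrained weights are loaded over them later.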
        Self {
            lin1: linear_zeros(in_features, hidden_dim, true, device),
            lin2: linear_zeros(hidden_dim, hidden_dim, true, device),
            lin3: linear_zeros(hidden_dim, out_features, true, device),
        }
    }

    /// x: [B, in_features] -> [B, out_features]
    pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = self.lin1.forward(x);
        let x = burn::tensor::activation::relu(x);
        let x = self.lin2.forward(x);
        let x = burn::tensor::activation::relu(x);
        self.lin3.forward(x)
    }
}
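
// Usage (a minimal sketch; assumes `burn` is built with the `ndarray` feature
// so that `burn::backend::NdArray` is available; the dimensions below are
// illustrative, not the crate's actual configuration):
//
//     use burn::backend::NdArray;
//
//     let device = Default::default();
//     let head = MLPHead::<NdArray>::new(768, 256, 4, &device);
//     let x = Tensor::<NdArray, 2>::zeros([8, 768], &device);
//     let logits = head.forward(x); // [8, 4]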

/// Linear classification head with global average pooling.
#[derive(Module, Debug)]
pub struct ClassificationHead<B: Backend> {
    pub fc_norm: burn::nn::LayerNorm<B>,
    pub head: Linear<B>,
    pub num_classes: usize,
}

impl<B: Backend> ClassificationHead<B> {
    /// Create a new linear classification head.
    pub fn new(embed_dim: usize, num_classes: usize, device: &B::Device) -> Self {
        Self {
            fc_norm: burn::nn::LayerNormConfig::new(embed_dim)
                .with_epsilon(1e-6)
                .init(device),
            head: linear_zeros(embed_dim, num_classes, true, device),
            num_classes,
        }
    }

    /// Forward pass: pool encoder output and classify.
    ///
    /// encoder_output: [B, N, embed_dim]
    /// Returns: [B, num_classes] logits
    pub fn forward(&self, encoder_output: Tensor<B, 3>) -> Tensor<B, 2> {
        let [b, _n, d] = encoder_output.dims();
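        // mean_dim keeps the reduced axis ([B, 1, D]), so reshape flattens it to [B, D].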
        let pooled = encoder_output.mean_dim(1).reshape([b, d]);
        let normed = self.fc_norm.forward(pooled);
        self.head.forward(normed)
    }

    /// Load classification head weights from a weight map.
    pub fn load_weights(
        &mut self,
        wm: &mut crate::weights::WeightMap,
        prefix: &str,
        device: &<B as Backend>::Device,
    ) -> anyhow::Result<()> {
        if wm.has(&format!("{prefix}.head.weight")) {
            let w: Tensor<B, 2> = wm.take(&format!("{prefix}.head.weight"), device)?;
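            // PyTorch stores Linear weights as [out, in]; Burn's Linear expects
            // [in, out], hence the transpose.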
            self.head.weight = Param::initialized(ParamId::new(), w.transpose());
        }
        if wm.has(&format!("{prefix}.head.bias")) {
            let b: Tensor<B, 1> = wm.take(&format!("{prefix}.head.bias"), device)?;
            self.head.bias = Some(Param::initialized(ParamId::new(), b));
        }
        if wm.has(&format!("{prefix}.fc_norm.weight")) {
            let w: Tensor<B, 1> = wm.take(&format!("{prefix}.fc_norm.weight"), device)?;
            self.fc_norm.gamma = Param::initialized(ParamId::new(), w);
        }
        if wm.has(&format!("{prefix}.fc_norm.bias")) {
            let b: Tensor<B, 1> = wm.take(&format!("{prefix}.fc_norm.bias"), device)?;
            self.fc_norm.beta = Some(Param::initialized(ParamId::new(), b));
        }
        Ok(())
    }
}
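
// Usage (a minimal sketch; `wm` and the `"head"` prefix are illustrative and
// depend on the actual checkpoint layout):
//
//     let mut head = ClassificationHead::<B>::new(embed_dim, num_classes, &device);
//     head.load_weights(&mut wm, "head", &device)?;
//     let logits = head.forward(encoder_output); // [B, num_classes]
//     let preds = predict_classes(logits);       // [B]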

/// Argmax along the last dimension — returns predicted class indices.
pub fn predict_classes<B: Backend>(logits: Tensor<B, 2>) -> Tensor<B, 1, Int> {
    let [b, _c] = logits.dims();
    logits.argmax(1).reshape([b])
}
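
// A shape-only smoke test (a sketch; assumes `burn` is built with the
// `ndarray` feature so the `NdArray` backend is available in tests).
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    #[test]
    fn classification_head_output_shapes() {
        let device = Default::default();
        let head = ClassificationHead::<NdArray>::new(16, 3, &device);
        // Fake encoder output: batch of 2, 5 tokens, 16-dim embeddings.
        let encoder_output = Tensor::<NdArray, 3>::zeros([2, 5, 16], &device);
        let logits = head.forward(encoder_output);
        assert_eq!(logits.dims(), [2, 3]);
        let preds = predict_classes(logits);
        assert_eq!(preds.dims(), [2]);
    }
}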