// brainharmony 0.1.0
//
// Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML
// Documentation
/// Multi-Head Self-Attention (burn 0.20.1)
///
/// Uses packed QKV projection and scaled dot-product attention.
/// Tiles queries into chunks of 1024 to improve GPU cache utilization
/// for long sequences (softmax on [H, 1024, N] instead of [H, N, N]).
use burn::nn::Linear;
use burn::prelude::*;
use burn::tensor::activation::softmax;

use crate::model::linear_zeros;

/// Query tile size — 1024 benchmarked as optimal on Apple M4 Pro Metal.
///
/// Upper bound on the number of query rows that `Attention::forward` scores
/// against the keys in one step; the last tile of a sequence may be shorter.
const ATTN_TILE: usize = 1024;

/// Multi-head self-attention layer built around a single packed QKV projection.
#[derive(Module, Debug)]
pub struct Attention<B: Backend> {
    /// Packed projection whose output is split into Q, K and V
    /// (constructed by `new` with sizes `dim` and `3 * dim`).
    pub qkv: Linear<B>,
    /// Output projection applied after the heads are merged back together
    /// (constructed by `new` with sizes `dim` and `dim`).
    pub proj: Linear<B>,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Width of one head; `new` sets it to `dim / num_heads`.
    pub head_dim: usize,
    /// Factor the queries are multiplied by before the dot product;
    /// `new` sets it to `head_dim^-0.5`.
    pub scale: f32,
}

impl<B: Backend> Attention<B> {
    /// Builds an attention layer for embeddings of width `dim` split across `num_heads` heads.
    ///
    /// Both projections are created with the crate's `linear_zeros` helper (presumably
    /// zero-initialised placeholders that are later overwritten by loaded weights — confirm
    /// against `crate::model`). `qkv_bias` toggles the bias of the packed QKV projection; the
    /// output projection always gets a bias. The query scale is `head_dim^-0.5`.
    ///
    /// # Panics
    ///
    /// Panics if `num_heads` is zero or does not evenly divide `dim`. Without this check the
    /// truncated `head_dim` would only surface later as a shape mismatch inside `forward`.
    pub fn new(dim: usize, num_heads: usize, qkv_bias: bool, device: &B::Device) -> Self {
        assert!(
            num_heads > 0 && dim % num_heads == 0,
            "Attention: dim ({dim}) must be divisible by a non-zero num_heads ({num_heads})"
        );

        let head_dim = dim / num_heads;
        let scale = (head_dim as f64).powf(-0.5) as f32;
        Self {
            qkv: linear_zeros(dim, dim * 3, qkv_bias, device),
            proj: linear_zeros(dim, dim, true, device),
            num_heads,
            head_dim,
            scale,
        }
    }

    /// Applies multi-head self-attention to `x`.
    ///
    /// * `x` — input of shape `[B, N, C]`; `C` must equal `num_heads * head_dim`, otherwise the
    ///   reshape of the packed projection fails.
    /// * `attn_mask` — optional `[B, N]` mask over key positions (1 = attend, 0 = masked). It is
    ///   broadcast over heads and query rows.
    ///
    /// Returns a tensor of shape `[B, N, C]`.
    ///
    /// Queries are processed in tiles of at most `ATTN_TILE` rows so the softmax runs on
    /// `[B, H, tile, N]` instead of the full `[B, H, N, N]` score matrix.
    pub fn forward(&self, x: Tensor<B, 3>, attn_mask: Option<&Tensor<B, 2>>) -> Tensor<B, 3> {
        let [b, n, c] = x.dims();
        let h = self.num_heads;
        let dh = self.head_dim;

        // Packed projection: [B, N, 3*C] viewed as [B, N, 3, H, Dh].
        let qkv = self.qkv.forward(x).reshape([b, n, 3, h, dh]);

        // Slice Q, K, V out of axis 2 and move heads ahead of the sequence axis: [B, H, N, Dh].
        // The scale is folded into Q once instead of into every score tile.
        let q = qkv
            .clone()
            .narrow(2, 0, 1)
            .reshape([b, n, h, dh])
            .swap_dims(1, 2)
            .mul_scalar(self.scale);
        let k = qkv
            .clone()
            .narrow(2, 1, 1)
            .reshape([b, n, h, dh])
            .swap_dims(1, 2);
        let v = qkv.narrow(2, 2, 1).reshape([b, n, h, dh]).swap_dims(1, 2);

        let k_t = k.transpose(); // [B, H, Dh, N]

        // Mask terms are identical for every tile, so build them once:
        //   keep = mask as [B, 1, 1, N]
        //   bias = (1 - keep) * -1e9   (0 where attended, -1e9 where masked)
        // A tile is then masked as `scores * keep + bias`.
        let mask_terms = attn_mask.map(|mask| {
            let keep = mask.clone().unsqueeze_dim::<3>(1).unsqueeze_dim::<4>(2);
            let bias = keep
                .clone()
                .mul_scalar(-1.0)
                .add_scalar(1.0)
                .mul_scalar(-1e9);
            (keep, bias)
        });

        // One output tile of shape [B, H, tile, Dh] per chunk of query rows.
        let tiles: Vec<Tensor<B, 4>> = (0..n)
            .step_by(ATTN_TILE)
            .map(|offset| {
                let tile_len = ATTN_TILE.min(n - offset);
                let scores = q
                    .clone()
                    .narrow(2, offset, tile_len)
                    .matmul(k_t.clone()); // [B, H, tile, N]

                let scores = match &mask_terms {
                    Some((keep, bias)) => scores * keep.clone() + bias.clone(),
                    None => scores,
                };

                softmax(scores, 3).matmul(v.clone())
            })
            .collect();

        // Stitch the tiles back along the sequence axis, then merge heads: [B, N, C].
        let out = Tensor::cat(tiles, 2).swap_dims(1, 2).reshape([b, n, c]);
        self.proj.forward(out)
    }
}