brainharmony 0.1.0

Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML
Documentation
/// Query-tiled attention for the cubecl-backed (`wgpu` / `wgpu-f16`) builds.
///
/// Non-causal scaled dot-product attention is computed by splitting the query
/// sequence into row tiles. Each tile runs ordinary Burn tensor ops
/// (`matmul` → `softmax` → `matmul`), so attention scores are only ever
/// materialized one `[B, H, tile, N]` slab at a time instead of as the full
/// `[B, H, N, N]` matrix.
///
/// This is *not* a single hand-written kernel and does not use the online
/// (FlashAttention-style) softmax: every tile is scored against all keys at
/// once, so a plain softmax per tile is already exact. Each tile issues its
/// own Burn operations; whether those get fused further is up to the backend.
#[cfg(any(feature = "wgpu", feature = "wgpu-f16"))]
pub mod gpu {
    use burn::prelude::*;
    use burn_cubecl::tensor::CubeTensor;
    use burn_cubecl::CubeRuntime;

    /// Non-causal scaled dot-product attention, tiled over query rows.
    ///
    /// Rather than materializing the full `[B, H, N, N]` attention matrix,
    /// `q` is split into chunks of at most `tile_size` rows. Each chunk is
    /// scored against *all* of `k`, soft-maxed over the key axis, and
    /// multiplied by `v`. Because every row's softmax sees every key, the
    /// result matches untiled attention.
    ///
    /// # Arguments
    /// * `q`, `k`, `v` — `[B, H, N, Dh]` tensors. `q` must already be scaled
    ///   by `1 / sqrt(Dh)`; no scaling is applied here.
    /// * `tile_size` — maximum number of query rows processed per tile.
    ///
    /// # Returns
    /// The `[B, H, N, Dh]` attention output.
    ///
    /// # Panics
    /// Panics if `tile_size == 0`.
    pub fn fused_attention<B, R>(
        q: Tensor<B, 4>,
        k: Tensor<B, 4>,
        v: Tensor<B, 4>,
        tile_size: usize,
    ) -> Tensor<B, 4>
    where
        B: Backend<FloatTensorPrimitive = CubeTensor<R>>,
        R: CubeRuntime,
    {
        // A zero tile size can never make progress through the query axis.
        assert!(
            tile_size > 0,
            "fused_attention: tile_size must be greater than zero"
        );

        let n = q.dims()[2];
        // `transpose` swaps the last two dims: [B, H, Dh, N].
        let k_t = k.transpose();

        // The whole query sequence fits in one tile (this includes n == 0).
        // Compute it directly: this skips a redundant narrow + cat, and for
        // n == 0 it avoids calling `Tensor::cat` on an empty list, which
        // Burn's checks reject.
        if n <= tile_size {
            return attend(q, k_t, v);
        }

        let tiles: Vec<Tensor<B, 4>> = (0..n)
            .step_by(tile_size)
            .map(|offset| {
                // The last tile may be shorter than `tile_size`.
                let len = tile_size.min(n - offset);
                attend(q.clone().narrow(2, offset, len), k_t.clone(), v.clone())
            })
            .collect();

        // Re-assemble along the query axis: [B, H, N, Dh].
        Tensor::cat(tiles, 2)
    }

    /// Computes `softmax(q_tile @ k_t, dim = 3) @ v` for one tile of query rows.
    ///
    /// `q_tile` is `[B, H, T, Dh]` and `k_t` is `[B, H, Dh, N]`, so the scores
    /// are `[B, H, T, N]`. Only this tile's score slab is materialized.
    fn attend<B: Backend>(
        q_tile: Tensor<B, 4>,
        k_t: Tensor<B, 4>,
        v: Tensor<B, 4>,
    ) -> Tensor<B, 4> {
        let scores = q_tile.matmul(k_t);
        burn::tensor::activation::softmax(scores, 3).matmul(v)
    }
}