// brainharmony 0.1.0
//
// Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML
// Documentation
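/// Test whether GPU pipeline caching helps: time the same forward pass over many runs and
/// check whether per-run latency converges. (Whether unrolling the block loop would let
/// Burn's fusion optimize better is an open question; the loop is not unrolled here.)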
use std::time::Instant;
use burn::prelude::*;
use burn::tensor::activation::softmax;
use burn::nn::LayerNormConfig;

// Half-precision WGPU backend. Selected whenever `wgpu-f16` is enabled, including when `wgpu`
// is enabled too. Cargo features are additive (e.g. `--all-features`), so requiring
// `not(feature = "wgpu")` here would leave no `backend` module compiled for that combination.
#[cfg(feature = "wgpu-f16")]
mod backend {
    /// WGPU backend instantiated with `half::f16` as its float element type.
    pub type B = burn::backend::wgpu::Wgpu<half::f16, i32, u32>;

    /// Returns the default WGPU device.
    pub fn device() -> burn::backend::wgpu::WgpuDevice {
        burn::backend::wgpu::WgpuDevice::DefaultDevice
    }
}
#[cfg(all(feature = "wgpu", not(feature = "wgpu-f16")))]
mod backend {
    pub use burn::backend::wgpu::WgpuDevice;
    pub use burn::backend::Wgpu as B;

    /// Returns the default WGPU device for the `Wgpu` backend (default element types).
    pub fn device() -> WgpuDevice {
        WgpuDevice::DefaultDevice
    }
}
#[cfg(not(any(feature = "wgpu", feature = "wgpu-f16")))]
mod backend {
    use burn::backend::ndarray::NdArrayDevice;
    pub use burn::backend::NdArray as B;

    /// CPU device used when no WGPU feature is enabled.
    pub fn device() -> NdArrayDevice {
        NdArrayDevice::Cpu
    }
}
use backend::{B, device};

/// Maximum number of query rows per attention tile (2^10). Longer sequences are split along
/// the query axis so each tile's score matrix has at most this many rows.
const TILE: usize = 1 << 10;

/// Parameters of one pre-norm transformer block; the linear layers carry no bias terms.
///
/// Each projection matrix is laid out as `[in_features, out_features]`, so the forward pass
/// multiplies activations on the left (`x.matmul(w)`) without a transpose.
struct BlockWeights {
    // Attention sub-layer.
    /// LayerNorm applied to the input of the attention sub-layer.
    norm1: burn::nn::LayerNorm<B>,
    /// Fused query/key/value projection.
    w_qkv: Tensor<B, 2>,
    /// Attention output projection.
    w_proj: Tensor<B, 2>,

    // MLP sub-layer.
    /// LayerNorm applied to the input of the MLP sub-layer.
    norm2: burn::nn::LayerNorm<B>,
    /// First MLP linear layer (its output goes through GELU).
    w_fc1: Tensor<B, 2>,
    /// Second MLP linear layer, projecting back to the model width.
    w_fc2: Tensor<B, 2>,
}

/// Run every block in `blocks`, in order, over `x` using pre-transposed (`[in, out]`) weights.
///
/// Each block is a pre-norm transformer layer without biases: `x = x + Attn(LN1(x))`
/// followed by `x = x + W2 · GELU(W1 · LN2(x))`.
///
/// * `x` — activations of shape `[batch, tokens, channels]`.
/// * `blocks` — any number of blocks. The name reflects the 12-block benchmark in `main`;
///   an empty slice returns `x` unchanged.
/// * `heads`, `dh` — number of attention heads and per-head width.
/// * `scale` — multiplier applied to `q` before the `q · Kᵀ` product.
///
/// `Kᵀ` is computed once per block and shared by all query tiles. When `tokens > TILE`,
/// queries are processed in tiles of at most `TILE` rows, so each tile's score matrix is
/// `[batch, heads, <=TILE, tokens]` instead of the full `tokens x tokens`.
///
/// The fused QKV output (`3 * channels` wide) is viewed as `[heads, 3 * dh]`, i.e. columns
/// are grouped per head as `[q_h | k_h | v_h]`. NOTE(review): timm-style checkpoints order
/// this axis as `[3, heads, dh]` and would need their columns permuted before use here.
///
/// # Panics
/// Panics if `heads * dh` differs from the channel dimension of `x` (checked per block).
fn forward_12_cached(
    x: Tensor<B, 3>,
    blocks: &[BlockWeights],
    heads: usize,
    dh: usize,
    scale: f32,
) -> Tensor<B, 3> {
    let mut x = x;

    for blk in blocks {
        let [b, n, c] = x.dims();
        // The head split/merge reshapes below rely on this; report it clearly instead of
        // failing inside a reshape.
        assert_eq!(
            heads * dh,
            c,
            "heads ({heads}) * head_dim ({dh}) must equal the channel dimension ({c})"
        );

        // ---- Attention ----
        let xn = blk.norm1.forward(x.clone()).reshape([b * n, c]);

        // Fused QKV projection viewed as [B, H, N, 3*Dh]; q/k/v are slices of the last axis.
        let qkv = xn
            .matmul(blk.w_qkv.clone())
            .reshape([b, n, heads, 3 * dh])
            .swap_dims(1, 2);

        let q = qkv.clone().narrow(3, 0, dh).mul_scalar(scale);
        let k = qkv.clone().narrow(3, dh, dh);
        let v = qkv.narrow(3, 2 * dh, dh);

        // Kᵀ ([B, H, Dh, N]) is computed once and shared by every query tile.
        let k_t = k.transpose();

        let attn_out = if n <= TILE {
            softmax(q.matmul(k_t), 3).matmul(v)
        } else {
            // Tile the query axis. Each tile takes its own clone of q, k_t and v because Burn
            // tensor ops consume `self`; the transpose itself is not recomputed.
            let tiles: Vec<_> = (0..n)
                .step_by(TILE)
                .map(|off| {
                    let tl = TILE.min(n - off);
                    let qt = q.clone().narrow(2, off, tl);
                    softmax(qt.matmul(k_t.clone()), 3).matmul(v.clone())
                })
                .collect();
            Tensor::cat(tiles, 2)
        };

        // Merge heads back to [B*N, C], apply the output projection, add the residual.
        let attn_out = attn_out
            .swap_dims(1, 2)
            .reshape([b * n, c])
            .matmul(blk.w_proj.clone())
            .reshape([b, n, c]);
        x = x + attn_out;

        // ---- MLP ----
        let hn = blk.norm2.forward(x.clone()).reshape([b * n, c]);
        let h = burn::tensor::activation::gelu(hn.matmul(blk.w_fc1.clone()));
        let mlp = h.matmul(blk.w_fc2.clone()).reshape([b, n, c]);
        x = x + mlp;
    }

    x
}

/// Benchmark entry point.
///
/// Builds 12 blocks with random projection weights and times 30 back-to-back forward passes.
/// For each run it prints the latency and the change from the previous run. The intent is to
/// show whether per-run cost settles once backend pipelines and caches are warm.
fn main() {
    let d = device();
    brainharmony::init_threads(None);

    // Geometry: 7200 tokens, 768 channels, 12 heads of width 64, 4x MLP expansion.
    let seq = 7200usize;
    let embed = 768usize;
    let heads = 12usize;
    let dh = embed / heads;
    let mlp_h = embed * 4;
    let scale = (dh as f32).powf(-0.5);

    println!("GPU cache convergence test\n");

    // Projections are drawn directly in [in, out] layout (no bias terms) so the forward pass
    // can use x.matmul(w) without transposing.
    let dense = |rows: usize, cols: usize| -> Tensor<B, 2> {
        Tensor::random(
            [rows, cols],
            burn::tensor::Distribution::Normal(0.0, 0.01),
            &d,
        )
    };
    let layer_norm = || LayerNormConfig::new(embed).with_epsilon(1e-6).init::<B>(&d);

    let mut blocks: Vec<BlockWeights> = Vec::with_capacity(12);
    for _ in 0..12 {
        blocks.push(BlockWeights {
            norm1: layer_norm(),
            norm2: layer_norm(),
            w_qkv: dense(embed, 3 * embed),
            w_proj: dense(embed, embed),
            w_fc1: dense(embed, mlp_h),
            w_fc2: dense(mlp_h, embed),
        });
    }

    let x: Tensor<B, 3> = Tensor::random(
        [1, seq, embed],
        burn::tensor::Distribution::Normal(0.0, 1.0),
        &d,
    );

    println!("  Run    Time    Delta");
    println!("  ---    ----    -----");
    let mut prev: Option<f64> = None;
    for i in 0..30 {
        let t0 = Instant::now();
        let out = forward_12_cached(x.clone(), &blocks, heads, dh, scale);
        // Pulling the result back to the host acts as the sync point before stopping the clock.
        let _ = out.into_data();
        let ms = t0.elapsed().as_secs_f64() * 1000.0;
        // The first run has no predecessor, so its delta is reported as 0.
        let delta = prev.map_or(0.0, |p| ms - p);
        println!("  {i:3}    {ms:>7.0}ms  {delta:>+7.0}ms");
        prev = Some(ms);
    }
}