brainharmony 0.1.0

Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML
Documentation
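//! Sweep the attention query-tile size over a 12-block encoder at seq=7200 — from tile=7200
//! (1 tile = no tiling) down to tile=720 (10 tiles), including tile=2400 (3 tiles) and
//! tile=1024 (8 tiles) — to measure the tile-dispatch-overhead vs memory-pressure tradeoff.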
use std::time::Instant;
use burn::prelude::*;
use burn::tensor::activation::softmax;
use burn::nn::{LayerNormConfig};

// Backend selection. Cargo features are additive, so `wgpu-f16` must win when both `wgpu` and
// `wgpu-f16` are enabled (e.g. `--all-features`); with mutually exclusive conditions on both
// GPU modules, that combination compiled no `backend` module at all and `use backend::…` failed.

/// WGPU backend parameterised with `half::f16` as its float element type.
#[cfg(feature = "wgpu-f16")]
mod backend {
    pub type B = burn::backend::wgpu::Wgpu<half::f16, i32, u32>;

    /// Returns `WgpuDevice::DefaultDevice`.
    pub fn device() -> burn::backend::wgpu::WgpuDevice {
        burn::backend::wgpu::WgpuDevice::DefaultDevice
    }
}

/// WGPU backend with Burn's default element types (only the `wgpu` feature enabled).
#[cfg(all(feature = "wgpu", not(feature = "wgpu-f16")))]
mod backend {
    pub use burn::backend::{wgpu::WgpuDevice, Wgpu as B};

    /// Returns `WgpuDevice::DefaultDevice`.
    pub fn device() -> WgpuDevice {
        WgpuDevice::DefaultDevice
    }
}

/// CPU fallback (NdArray) used when no GPU feature is enabled.
#[cfg(not(any(feature = "wgpu", feature = "wgpu-f16")))]
mod backend {
    pub use burn::backend::NdArray as B;

    /// Returns `NdArrayDevice::Cpu`.
    pub fn device() -> burn::backend::ndarray::NdArrayDevice {
        burn::backend::ndarray::NdArrayDevice::Cpu
    }
}
use backend::{B, device};

/// Times a stack of pre-norm attention + MLP blocks on `x_in`. Returns the best (minimum)
/// wall-clock time over `runs` timed iterations, in milliseconds.
///
/// One block is run per entry of `norms` / `weights` (the benchmark passes 12).
///
/// Attention queries are processed in chunks of at most `tile` rows along the sequence axis.
/// Each chunk attends to the full key/value set, and the chunk outputs are concatenated back
/// together. When `tile >= n` (the sequence length), a single un-tiled attention pass is used.
///
/// * `x_in` – input activations, destructured as `[b, n, c]`.
/// * `norms` – `(attention_norm, mlp_norm)` per block.
/// * `weights` – `(qkv, proj, fc1, fc2)` per block, each applied as `x.matmul(w)`.
/// * `heads`, `dh` – head count and per-head width; the reshapes require `heads * dh == c`
///   and a qkv output width of `3 * heads * dh`.
/// * `scale` – multiplier applied to queries before `QKᵀ`.
/// * `warmup` – untimed iterations run before measuring.
/// * `runs` – timed iterations; with `runs == 0` the result is `f64::INFINITY`.
///
/// Every iteration ends with `into_data()`, reading the result back so that queued backend
/// work should be complete before the timer stops.
///
/// # Panics
/// If `tile == 0`, if `norms` and `weights` have different lengths, or if the shapes above
/// are inconsistent (Burn reshape/matmul panics).
#[allow(clippy::too_many_arguments)] // flat benchmark helper; a config struct would add nothing
fn bench_12_blocks(
    x_in: &Tensor<B, 3>,
    norms: &[(burn::nn::LayerNorm<B>, burn::nn::LayerNorm<B>)],
    weights: &[(Tensor<B, 2>, Tensor<B, 2>, Tensor<B, 2>, Tensor<B, 2>)], // qkv, proj, fc1, fc2
    tile: usize,
    heads: usize,
    dh: usize,
    scale: f32,
    warmup: usize,
    runs: usize,
) -> f64 {
    // A zero tile never advances through the sequence: the tiling loop would spin forever.
    assert!(tile > 0, "tile size must be non-zero");
    assert_eq!(
        norms.len(),
        weights.len(),
        "norms and weights must describe the same number of blocks"
    );

    let do_forward = |mut x: Tensor<B, 3>| -> Tensor<B, 3> {
        for ((norm_attn, norm_mlp), (w_qkv, w_proj, w_fc1, w_fc2)) in norms.iter().zip(weights) {
            let [b, n, c] = x.dims();

            // --- Attention (pre-norm) ---
            let xn = norm_attn.forward(x.clone()).reshape([b * n, c]);
            // qkv columns are read per head as [q | k | v], each `dh` wide.
            let qkv = xn
                .matmul(w_qkv.clone())
                .reshape([b, n, heads, 3 * dh])
                .swap_dims(1, 2); // -> [b, heads, n, 3 * dh]
            let q = qkv.clone().narrow(3, 0, dh).mul_scalar(scale);
            let k_t = qkv.clone().narrow(3, dh, dh).transpose();
            let v = qkv.narrow(3, 2 * dh, dh);

            let attn_out = if tile >= n {
                // Whole sequence fits in one tile: a single QKᵀ / softmax / ·V pass.
                softmax(q.matmul(k_t), 3).matmul(v)
            } else {
                // Query tiles of `tile` rows (the last may be shorter), each against all keys.
                let tiles: Vec<Tensor<B, 4>> = (0..n)
                    .step_by(tile)
                    .map(|off| {
                        let len = tile.min(n - off);
                        let scores = q.clone().narrow(2, off, len).matmul(k_t.clone());
                        softmax(scores, 3).matmul(v.clone())
                    })
                    .collect();
                Tensor::cat(tiles, 2)
            };

            let attn_out = attn_out
                .swap_dims(1, 2)
                .reshape([b * n, c])
                .matmul(w_proj.clone())
                .reshape([b, n, c]);
            x = x + attn_out;

            // --- MLP (pre-norm) ---
            let hn = norm_mlp.forward(x.clone()).reshape([b * n, c]);
            let hidden = burn::tensor::activation::gelu(hn.matmul(w_fc1.clone()));
            let mlp = hidden.matmul(w_fc2.clone()).reshape([b, n, c]);
            x = x + mlp;
        }
        x
    };

    for _ in 0..warmup {
        let _ = do_forward(x_in.clone()).into_data();
    }

    (0..runs)
        .map(|_| {
            let start = Instant::now();
            let _ = do_forward(x_in.clone()).into_data();
            start.elapsed().as_secs_f64() * 1000.0
        })
        .fold(f64::INFINITY, f64::min)
}

/// Builds a randomly initialised 12-block encoder (embed 768, 12 heads, MLP hidden 4×embed).
/// For each attention query-tile size at seq = 7200, prints the best-of-`RUNS` forward time.
fn main() {
    const BLOCKS: usize = 12;
    const WARMUP: usize = 3;
    const RUNS: usize = 5;
    // Rough per-block dispatch estimate. The fixed ops are 2× layernorm, qkv, proj, fc1, gelu
    // and fc2. Each query tile adds three: QKᵀ matmul, softmax and ·V matmul.
    const FIXED_DISPATCHES: usize = 7;
    const DISPATCHES_PER_TILE: usize = 3;

    let d = device();
    brainharmony::init_threads(None);

    let seq = 7200usize;
    let embed = 768usize;
    let heads = 12usize;
    let dh = embed / heads;
    let mlp_h = embed * 4;
    let scale = (dh as f32).powf(-0.5);

    // Label the run with the backend actually compiled in (`wgpu-f16` checked first).
    let backend_name = if cfg!(feature = "wgpu-f16") {
        "wgpu f16"
    } else if cfg!(feature = "wgpu") {
        "wgpu"
    } else {
        "ndarray"
    };
    println!("{BLOCKS}-block encoder: tile size vs total time ({backend_name})\n");

    // Create weights. Per block, the generation order is unchanged: two norms, then qkv,
    // proj, fc1 and fc2.
    let layer_norm = || LayerNormConfig::new(embed).with_epsilon(1e-6).init::<B>(&d);
    let randn = |rows: usize, cols: usize| {
        Tensor::<B, 2>::random([rows, cols], burn::tensor::Distribution::Normal(0.0, 0.01), &d)
    };
    let (norms, weights): (Vec<_>, Vec<_>) = (0..BLOCKS)
        .map(|_| {
            let block_norms = (layer_norm(), layer_norm());
            let block_weights = (
                randn(embed, 3 * embed),
                randn(embed, embed),
                randn(embed, mlp_h),
                randn(mlp_h, embed),
            );
            (block_norms, block_weights)
        })
        .unzip();

    let x: Tensor<B, 3> =
        Tensor::random([1, seq, embed], burn::tensor::Distribution::Normal(0.0, 1.0), &d);

    for tile in [7200, 3600, 2400, 1800, 1200, 1024, 720] {
        let n_tiles = seq.div_ceil(tile);
        let dispatches_per_block = FIXED_DISPATCHES + n_tiles * DISPATCHES_PER_TILE;
        let t = bench_12_blocks(&x, &norms, &weights, tile, heads, dh, scale, WARMUP, RUNS);
        println!(
            "  tile={tile:5}  ({n_tiles:2} tiles, ~{dispatches_per_block:2} dispatches/block): {t:>7.0}ms"
        );
    }
}