//! brainharmony 0.1.0 — Brain-Harmony multimodal brain foundation model;
//! inference in Rust with the Burn ML framework.
//!
//! Benchmark a hand-inlined transformer block that minimizes tensor operations.
//! Compares: current Block::forward vs a flat inlined version.
use std::time::Instant;
use burn::prelude::*;
use burn::tensor::activation::softmax;
use burn::nn::{Linear, LayerNorm, LayerNormConfig};

#[cfg(all(feature = "wgpu-f16", not(feature = "wgpu")))]
mod backend {
    //! Half-precision wgpu backend (feature `wgpu-f16` without `wgpu`).
    use burn::backend::wgpu::{Wgpu, WgpuDevice};

    /// Wgpu backend with f16 float, i32 int, and u32 index element types.
    pub type B = Wgpu<half::f16, i32, u32>;

    /// Returns the default wgpu adapter/device.
    pub fn device() -> WgpuDevice {
        WgpuDevice::DefaultDevice
    }
}
#[cfg(all(feature = "wgpu", not(feature = "wgpu-f16")))]
mod backend {
    //! Default-precision (f32) wgpu backend (feature `wgpu`).
    pub use burn::backend::Wgpu as B;
    pub use burn::backend::wgpu::WgpuDevice;

    /// Returns the default wgpu adapter/device.
    pub fn device() -> WgpuDevice {
        WgpuDevice::DefaultDevice
    }
}
#[cfg(not(any(feature = "wgpu", feature = "wgpu-f16")))]
mod backend {
    //! CPU fallback backend when no wgpu feature is enabled.
    use burn::backend::ndarray::NdArrayDevice;

    pub use burn::backend::NdArray as B;

    /// Returns the single CPU device.
    pub fn device() -> NdArrayDevice {
        NdArrayDevice::Cpu
    }
}
use backend::{B, device};

/// Query-tile length for the tiled attention loop in `inlined_block`: scores
/// are computed for at most TILE query rows at a time to bound peak memory.
const TILE: usize = 1024;

/// Flat inlined block: minimizes reshape/transpose/narrow operations
/// by fusing QKV split and attention into fewer ops.
///
/// Computes a pre-norm transformer block:
///   `h = x + proj(attn(norm1(x)))`, then `out = h + mlp(norm2(h))`,
/// with multi-head attention tiled over the query dimension (TILE rows at a
/// time) so the score tensor never materializes for all N queries at once.
///
/// Shape conventions: B = batch, N = sequence length, C = embed dim,
/// H = `heads`, Dh = `dh`. Callers must ensure `heads * dh == c`.
/// `scale` is the attention scale (typically 1/sqrt(Dh)); it is folded into Q.
/// Returns the block output, [B, N, C].
fn inlined_block(
    x: Tensor<B, 3>,  // [1, N, C]
    norm1: &LayerNorm<B>,
    w_qkv: &Tensor<B, 2>,  // [C, 3C]
    b_qkv: &Tensor<B, 1>,  // [3C]
    w_proj: &Tensor<B, 2>, // [C, C]
    b_proj: &Tensor<B, 1>, // [C]
    norm2: &LayerNorm<B>,
    w_fc1: &Tensor<B, 2>,  // [C, 4C]
    b_fc1: &Tensor<B, 1>,
    w_fc2: &Tensor<B, 2>,  // [4C, C]
    b_fc2: &Tensor<B, 1>,
    scale: f32,
    heads: usize,
    dh: usize,
) -> Tensor<B, 3> {
    let [b, n, c] = x.dims();

    // --- Attention ---
    // LN1 + QKV in one chain (Burn fusion can fuse LN elementwise ops)
    let xn = norm1.forward(x.clone());
    let xn2 = xn.reshape([b * n, c]);

    // QKV projection: [BN, C] @ [C, 3C] + bias -> [BN, 3C]
    // (unsqueeze_dim(0) lifts the [3C] bias to [1, 3C] for broadcasting)
    let qkv = xn2.matmul(w_qkv.clone()) + b_qkv.clone().unsqueeze_dim::<2>(0);

    // Reshape to [B, N, H, 3Dh] then swap to [B, H, N, 3Dh]
    // Then narrow along last dim — this avoids 3 separate swap_dims
    //
    // NOTE(review): this reshape treats the 3C axis as head-major groups of
    // [q_h | k_h | v_h]. With the conventional [Wq | Wk | Wv] column layout of
    // w_qkv, the three narrows below would mix heads. That is harmless for a
    // random-weight timing benchmark, but numerical equivalence with the
    // standard block requires w_qkv's columns pre-interleaved per head —
    // confirm before reusing this outside the benchmark.
    let qkv = qkv.reshape([b, n, heads, 3 * dh]).swap_dims(1, 2); // [B, H, N, 3Dh]
    let q = qkv.clone().narrow(3, 0, dh).mul_scalar(scale); // scale folded into Q
    let k = qkv.clone().narrow(3, dh, dh);
    let v = qkv.narrow(3, 2 * dh, dh);

    // Tiled attention: softmax(Q Kᵀ) V computed TILE query rows at a time, so
    // the score tensor peaks at [B, H, TILE, N] instead of [B, H, N, N].
    let k_t = k.transpose(); // swaps the last two dims -> [B, H, Dh, N]
    let mut tiles: Vec<Tensor<B, 4>> = Vec::new();
    let mut off = 0;
    while off < n {
        let tl = (n - off).min(TILE); // last tile may be shorter than TILE
        let qt = q.clone().narrow(2, off, tl);
        tiles.push(softmax(qt.matmul(k_t.clone()), 3).matmul(v.clone()));
        off += tl;
    }
    let attn_out = Tensor::cat(tiles, 2); // [B, H, N, Dh]

    // Reshape back + output projection
    let attn_out = attn_out.swap_dims(1, 2).reshape([b * n, c]);
    let attn_out = attn_out.matmul(w_proj.clone()) + b_proj.clone().unsqueeze_dim::<2>(0);
    let attn_out = attn_out.reshape([b, n, c]);

    // Residual
    let h = x + attn_out;

    // --- MLP ---
    let hn = norm2.forward(h.clone());
    let hn2 = hn.reshape([b * n, c]);

    // fc1 + GELU + fc2
    let mlp = hn2.matmul(w_fc1.clone()) + b_fc1.clone().unsqueeze_dim::<2>(0);
    let mlp = fast_gelu(mlp);
    let mlp = mlp.matmul(w_fc2.clone()) + b_fc2.clone().unsqueeze_dim::<2>(0);
    let mlp = mlp.reshape([b, n, c]);

    // Second residual
    h + mlp
}

/// Tanh-based GELU approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`,
/// where 0.7978845608 ≈ sqrt(2/pi).
fn fast_gelu(x: Tensor<B, 2>) -> Tensor<B, 2> {
    let half_x = x.clone().mul_scalar(0.5f32);
    let cubic = x.clone() * x.clone() * x.clone();
    let arg = (cubic.mul_scalar(0.044715f32) + x).mul_scalar(0.7978845608f32);
    half_x * (arg.tanh() + 1.0)
}

fn main() {
    let d = device();
    brainharmony::init_threads(None);

    let seq = 7200usize;
    let embed = 768usize;
    let heads = 12usize;
    let dh = embed / heads;
    let mlp_h = embed * 4;
    let scale = (dh as f32).powf(-0.5);
    let warmup = 5;
    let runs = 10;

    println!("Fused block benchmark (wgpu f16)\n");

    // Create weights for inlined version
    let norm1 = LayerNormConfig::new(embed).with_epsilon(1e-6).init::<B>(&d);
    let norm2 = LayerNormConfig::new(embed).with_epsilon(1e-6).init::<B>(&d);
    let w_qkv: Tensor<B, 2> = Tensor::random([embed, 3*embed], burn::tensor::Distribution::Normal(0.0, 0.01), &d);
    let b_qkv: Tensor<B, 1> = Tensor::zeros([3*embed], &d);
    let w_proj: Tensor<B, 2> = Tensor::random([embed, embed], burn::tensor::Distribution::Normal(0.0, 0.01), &d);
    let b_proj: Tensor<B, 1> = Tensor::zeros([embed], &d);
    let w_fc1: Tensor<B, 2> = Tensor::random([embed, mlp_h], burn::tensor::Distribution::Normal(0.0, 0.01), &d);
    let b_fc1: Tensor<B, 1> = Tensor::zeros([mlp_h], &d);
    let w_fc2: Tensor<B, 2> = Tensor::random([mlp_h, embed], burn::tensor::Distribution::Normal(0.0, 0.01), &d);
    let b_fc2: Tensor<B, 1> = Tensor::zeros([embed], &d);

    let x: Tensor<B, 3> = Tensor::random([1, seq, embed], burn::tensor::Distribution::Normal(0.0, 1.0), &d);

    // Warmup
    for _ in 0..warmup {
        let out = inlined_block(
            x.clone(), &norm1, &w_qkv, &b_qkv, &w_proj, &b_proj,
            &norm2, &w_fc1, &b_fc1, &w_fc2, &b_fc2, scale, heads, dh,
        );
        let _ = out.into_data();
    }

    // Bench inlined
    let mut times = Vec::new();
    for _ in 0..runs {
        let t0 = Instant::now();
        let out = inlined_block(
            x.clone(), &norm1, &w_qkv, &b_qkv, &w_proj, &b_proj,
            &norm2, &w_fc1, &b_fc1, &w_fc2, &b_fc2, scale, heads, dh,
        );
        let _ = out.into_data();
        times.push(t0.elapsed().as_secs_f64() * 1000.0);
    }
    let best = times.iter().cloned().fold(f64::INFINITY, f64::min);
    let med = { let mut s = times.clone(); s.sort_by(|a,b| a.partial_cmp(b).unwrap()); s[s.len()/2] };

    // Bench standard Block
    let block = brainharmony::model::block::Block::<B>::new(embed, heads, 4.0, true, 1e-6, &d);
    for _ in 0..warmup {
        let _ = block.forward(x.clone(), None).into_data();
    }
    let mut times_std = Vec::new();
    for _ in 0..runs {
        let t0 = Instant::now();
        let out = block.forward(x.clone(), None);
        let _ = out.into_data();
        times_std.push(t0.elapsed().as_secs_f64() * 1000.0);
    }
    let best_std = times_std.iter().cloned().fold(f64::INFINITY, f64::min);
    let med_std = { let mut s = times_std.clone(); s.sort_by(|a,b| a.partial_cmp(b).unwrap()); s[s.len()/2] };

    println!("Standard Block:  best={best_std:.0}ms  med={med_std:.0}ms  (x12 = {:.0}ms)", best_std * 12.0);
    println!("Inlined Block:   best={best:.0}ms  med={med:.0}ms  (x12 = {:.0}ms)", best * 12.0);
    println!("Speedup:         {:.2}x", best_std / best);
}