brainharmony 0.1.0

Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML
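
This example benchmarks one transformer block three ways on the same
[1, 7200, 768] input: through the crate's Block module, through raw tensor ops
(same math, no Module/Param machinery), and through raw ops with query-tiled
attention. It prints per-variant timings plus a projected 12-block total.
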
//! A/B test: same computation, module-based vs raw tensors.
use std::time::Instant;
use burn::prelude::*;
use burn::tensor::activation::softmax;

// Exactly one backend module below is compiled in: wgpu-f16 (GPU, half precision),
// wgpu (GPU, f32), or the NdArray CPU fallback when neither feature is enabled.
#[cfg(all(feature = "wgpu-f16", not(feature = "wgpu")))]
mod backend {
    pub type B = burn::backend::wgpu::Wgpu<half::f16, i32, u32>;
    pub fn device() -> burn::backend::wgpu::WgpuDevice { burn::backend::wgpu::WgpuDevice::DefaultDevice }
}
#[cfg(all(feature = "wgpu", not(feature = "wgpu-f16")))]
mod backend {
    pub use burn::backend::{Wgpu as B, wgpu::WgpuDevice};
    pub fn device() -> WgpuDevice { WgpuDevice::DefaultDevice }
}
#[cfg(not(any(feature = "wgpu", feature = "wgpu-f16")))]
mod backend {
    pub use burn::backend::NdArray as B;
    pub fn device() -> burn::backend::ndarray::NdArrayDevice { burn::backend::ndarray::NdArrayDevice::Cpu }
}
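
// Backend selection happens at build time: pass `--features wgpu` or
// `--features wgpu-f16` to cargo for a GPU backend; with neither, the NdArray
// CPU backend is compiled in. The exact cargo invocation depends on how this
// file is wired into the crate (bin vs. example).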
use backend::{B, device};

/// Runs `f` untimed `warmup` times (lets lazy backends compile/cache kernels),
/// then times `runs` iterations; prints best and median in ms and returns the best.
fn bench<F: FnMut()>(label: &str, warmup: usize, runs: usize, mut f: F) -> f64 {
    for _ in 0..warmup { f(); }
    let mut times = Vec::new();
    for _ in 0..runs { let t0 = Instant::now(); f(); times.push(t0.elapsed().as_secs_f64() * 1000.0); }
    let best = times.iter().cloned().fold(f64::INFINITY, f64::min);
    let med = { let mut s = times.clone(); s.sort_by(|a,b| a.partial_cmp(b).unwrap()); s[s.len()/2] }; // upper median
    println!("  {label:40} best={best:>7.0}ms  med={med:>7.0}ms");
    best
}

fn main() {
    let d = device();
    brainharmony::init_threads(None);

    let seq = 7200usize;                 // sequence length
    let embed = 768usize;                // embedding dimension
    let heads = 12usize;                 // attention heads
    let dh = embed / heads;              // per-head dimension (64)
    let mlp_h = embed * 4;               // MLP hidden width
    let scale = (dh as f32).powf(-0.5);  // attention scale, 1/sqrt(dh)
    let warmup = 20;
    let runs = 10;

    println!("Module vs Raw comparison (20 warmup, 10 runs)\n");

    // Random input: [batch = 1, seq, embed].
    let x: Tensor<B, 3> = Tensor::random([1, seq, embed],
        burn::tensor::Distribution::Normal(0.0, 1.0), &d);

    // --- Module-based block ---
    // 4.0 is the MLP ratio (matches mlp_h = embed * 4 in the raw path); 1e-6 the LayerNorm eps.
    let block = brainharmony::model::block::Block::<B>::new(embed, heads, 4.0, true, 1e-6, &d);
    let t_module = bench("1 block (Module)", warmup, runs, || {
        // into_data() copies the result to the host, which forces lazy backends
        // to flush queued work, so the timing covers the whole forward pass.
        let _ = block.forward(x.clone(), None).into_data();
    });

    // --- Raw block (same math, no Module/Param/Linear) ---
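    // What follows is one pre-norm transformer block written out by hand:
    //   y   = x + Proj(softmax(Q K^T / sqrt(dh)) V)  with [Q|K|V] = LN1(x) W_qkv
    //   out = y + gelu(LN2(y) W_fc1) W_fc2
    // The 1/sqrt(dh) scale is folded into Q before the score matmul.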
    let w_qkv: Tensor<B, 2> = Tensor::random([embed, 3*embed], burn::tensor::Distribution::Normal(0.0, 0.01), &d);
    let w_proj: Tensor<B, 2> = Tensor::random([embed, embed], burn::tensor::Distribution::Normal(0.0, 0.01), &d);
    let w_fc1: Tensor<B, 2> = Tensor::random([embed, mlp_h], burn::tensor::Distribution::Normal(0.0, 0.01), &d);
    let w_fc2: Tensor<B, 2> = Tensor::random([mlp_h, embed], burn::tensor::Distribution::Normal(0.0, 0.01), &d);
    let ln1 = burn::nn::LayerNormConfig::new(embed).with_epsilon(1e-6).init::<B>(&d);
    let ln2 = burn::nn::LayerNormConfig::new(embed).with_epsilon(1e-6).init::<B>(&d);

    let t_raw = bench("1 block (Raw matmul, no bias)", warmup, runs, || {
        let xn = ln1.forward(x.clone()).reshape([seq, embed]);
        // [seq, 3*embed] -> [1, seq, heads, 3*dh] -> [1, heads, seq, 3*dh]
        let qkv = xn.matmul(w_qkv.clone()).reshape([1, seq, heads, 3 * dh]).swap_dims(1, 2);
        let q = qkv.clone().narrow(3, 0, dh).mul_scalar(scale);
        let k = qkv.clone().narrow(3, dh, dh);
        let v = qkv.narrow(3, 2 * dh, dh);
        let k_t = k.transpose();
        // Scores are [1, heads, seq, seq]: the memory hot spot at seq = 7200.
        let attn = softmax(q.matmul(k_t), 3).matmul(v);
        let attn = attn.swap_dims(1, 2).reshape([seq, embed]).matmul(w_proj.clone()).reshape([1, seq, embed]);
        let h = x.clone() + attn;
        let hn = ln2.forward(h.clone()).reshape([seq, embed]);
        let mlp = burn::tensor::activation::gelu(hn.matmul(w_fc1.clone()))
            .matmul(w_fc2.clone()).reshape([1, seq, embed]);
        let _ = (h + mlp).into_data();
    });

    // --- Raw block with tiling ---
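    // Why tile: the untiled path materializes the full [1, heads, seq, seq] score
    // tensor, roughly 12 * 7200 * 7200 * 4 B ≈ 2.3 GiB in f32 (half under wgpu-f16).
    // Tiling queries at 1024 caps the live scores at [1, heads, 1024, seq],
    // about 340 MiB in f32, trading peak memory for more kernel launches.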
    let t_raw_tiled = bench("1 block (Raw, tiled attn 1024)", warmup, runs, || {
        let xn = ln1.forward(x.clone()).reshape([seq, embed]);
        let qkv = xn.matmul(w_qkv.clone()).reshape([1, seq, heads, 3 * dh]).swap_dims(1, 2);
        let q = qkv.clone().narrow(3, 0, dh).mul_scalar(scale);
        let k = qkv.clone().narrow(3, dh, dh);
        let v = qkv.narrow(3, 2 * dh, dh);
        let k_t = k.transpose();
        // Process queries in tiles of 1024 rows; K and V stay whole, so each
        // softmax sees the full key range and results match the untiled path.
        let mut tiles = Vec::new();
        let mut off = 0;
        while off < seq {
            let tl = (seq - off).min(1024);
            let qt = q.clone().narrow(2, off, tl); // [1, heads, tl, dh]
            tiles.push(softmax(qt.matmul(k_t.clone()), 3).matmul(v.clone()));
            off += tl;
        }
        let attn = Tensor::<B, 4>::cat(tiles, 2)
            .swap_dims(1, 2).reshape([seq, embed]).matmul(w_proj.clone()).reshape([1, seq, embed]);
        let h = x.clone() + attn;
        let hn = ln2.forward(h.clone()).reshape([seq, embed]);
        let mlp = burn::tensor::activation::gelu(hn.matmul(w_fc1.clone()))
            .matmul(w_fc2.clone()).reshape([1, seq, embed]);
        let _ = (h + mlp).into_data();
    });

    println!();
    println!("Module vs Raw speedup: {:.2}x", t_module / t_raw);
    println!("Module vs Raw+Tiled:   {:.2}x", t_module / t_raw_tiled);
    println!();
    println!("Projected 12-block:");
    println!("  Module:     {:.0}ms", t_module * 12.0);
    println!("  Raw:        {:.0}ms", t_raw * 12.0);
    println!("  Raw+Tiled:  {:.0}ms", t_raw_tiled * 12.0);
}
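
Ratios above 1.0x mean the raw path beat the Module-based path on this machine.
The 12-block figures are straight-line extrapolations from a single block, so
read them as rough estimates rather than measured full-model times.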