brainharmony 0.1.0

Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML
Documentation
//! Profile the individual ops of one transformer block on the selected backend (wgpu f32/f16 or NdArray CPU) to find fusion opportunities.
use std::time::Instant;
use burn::prelude::*;
use burn::tensor::activation::softmax;

/// wgpu backend with half-precision (`half::f16`) float elements.
///
/// Gated on `wgpu-f16` alone so that it is also selected when `wgpu` is enabled
/// at the same time (e.g. `--all-features`). The f32 wgpu module is gated on
/// `wgpu` without `wgpu-f16`, and the NdArray module on neither. So exactly one
/// `backend` module is compiled for every feature combination.
#[cfg(feature = "wgpu-f16")]
mod backend {
    pub type B = burn::backend::wgpu::Wgpu<half::f16, i32, u32>;

    /// Returns `WgpuDevice::DefaultDevice`.
    pub fn device() -> burn::backend::wgpu::WgpuDevice {
        burn::backend::wgpu::WgpuDevice::DefaultDevice
    }

    /// Label printed in the profile header.
    pub const NAME: &str = "wgpu f16";
}
/// wgpu backend with the default `f32` float elements.
#[cfg(all(feature = "wgpu", not(feature = "wgpu-f16")))]
mod backend {
    pub use burn::backend::wgpu::WgpuDevice;
    pub use burn::backend::Wgpu as B;

    /// Label printed in the profile header.
    pub const NAME: &str = "wgpu f32";

    /// Returns `WgpuDevice::DefaultDevice`.
    pub fn device() -> WgpuDevice {
        WgpuDevice::DefaultDevice
    }
}
/// CPU fallback backend (`NdArray`), compiled when no wgpu feature is enabled.
#[cfg(not(any(feature = "wgpu", feature = "wgpu-f16")))]
mod backend {
    use burn::backend::ndarray::NdArrayDevice;

    pub use burn::backend::NdArray as B;

    /// Label printed in the profile header.
    pub const NAME: &str = "NdArray CPU";

    /// Returns the CPU device (`NdArrayDevice::Cpu`).
    pub fn device() -> NdArrayDevice {
        NdArrayDevice::Cpu
    }
}

use backend::{B, device};

/// Times `f` and prints its best and mean wall-clock latency in milliseconds.
///
/// `f` is first called `warmup` times without timing. It is then called `runs`
/// more times, each call timed on its own with `Instant`. Only the duration of
/// the call itself is measured. If `f` dispatches asynchronous work, it must
/// block until that work is complete (e.g. by reading the result back), or
/// only the dispatch is timed.
///
/// When `runs == 0` nothing is timed, and a placeholder line is printed
/// instead of the `inf`/`NaN` statistics an empty sample would produce.
fn bench<F: FnMut()>(label: &str, warmup: usize, runs: usize, mut f: F) {
    for _ in 0..warmup {
        f();
    }

    if runs == 0 {
        println!("  {label:36} (no timed runs)");
        return;
    }

    let times: Vec<f64> = (0..runs)
        .map(|_| {
            let t0 = Instant::now();
            f();
            t0.elapsed().as_secs_f64() * 1000.0
        })
        .collect();

    let best = times.iter().copied().fold(f64::INFINITY, f64::min);
    let avg = times.iter().sum::<f64>() / times.len() as f64;
    // Width 36 fits the longest label emitted by `main` (35 chars) so columns align.
    println!("  {label:36} best={best:>8.1}ms  avg={avg:>8.1}ms");
}

fn main() {
    let d = device();
    brainharmony::init_threads(None);

    let embed = 768usize;
    let heads = 12usize;
    let seq = 7200usize;
    let hdim = embed / heads;
    let mlp_h = embed * 4;
    let warmup = 3;
    let runs = 5;

    println!("Per-op GPU profile: {} seq={seq} embed={embed}", backend::NAME);
    println!();

    // LayerNorm
    let ln = burn::nn::LayerNormConfig::new(embed).with_epsilon(1e-6).init::<B>(&d);
    let x3: Tensor<B, 3> = Tensor::random([1, seq, embed], burn::tensor::Distribution::Normal(0.0, 1.0), &d);
    bench("LayerNorm [1,7200,768]", warmup, runs, || {
        let out = ln.forward(x3.clone());
        let _ = out.into_data();
    });

    // QKV projection
    let w_qkv: Tensor<B, 2> = Tensor::random([embed, 3*embed], burn::tensor::Distribution::Normal(0.0, 0.01), &d);
    let x2: Tensor<B, 2> = Tensor::random([seq, embed], burn::tensor::Distribution::Normal(0.0, 1.0), &d);
    bench("QKV matmul [7200,768]@[768,2304]", warmup, runs, || {
        let out = x2.clone().matmul(w_qkv.clone());
        let _ = out.into_data();
    });

    // Q@K^T
    let q: Tensor<B, 4> = Tensor::random([1, heads, seq, hdim], burn::tensor::Distribution::Normal(0.0, 1.0), &d);
    let k: Tensor<B, 4> = Tensor::random([1, heads, seq, hdim], burn::tensor::Distribution::Normal(0.0, 1.0), &d);
    bench("Q@K^T [1,12,7200,64]@[..64,7200]", warmup, runs, || {
        let out = q.clone().matmul(k.clone().transpose());
        let _ = out.into_data();
    });

    // Softmax
    let scores: Tensor<B, 4> = Tensor::random([1, heads, seq, seq], burn::tensor::Distribution::Normal(0.0, 1.0), &d);
    bench("Softmax [1,12,7200,7200]", warmup, runs, || {
        let out = softmax(scores.clone(), 3);
        let _ = out.into_data();
    });

    // Attn@V
    let attn: Tensor<B, 4> = Tensor::random([1, heads, seq, seq], burn::tensor::Distribution::Normal(0.0, 1.0), &d);
    let v: Tensor<B, 4> = Tensor::random([1, heads, seq, hdim], burn::tensor::Distribution::Normal(0.0, 1.0), &d);
    bench("Attn@V [1,12,7200,7200]@[..7200,64]", warmup, runs, || {
        let out = attn.clone().matmul(v.clone());
        let _ = out.into_data();
    });

    // Out projection
    let w_o: Tensor<B, 2> = Tensor::random([embed, embed], burn::tensor::Distribution::Normal(0.0, 0.01), &d);
    bench("OutProj [7200,768]@[768,768]", warmup, runs, || {
        let out = x2.clone().matmul(w_o.clone());
        let _ = out.into_data();
    });

    // MLP fc1
    let w1: Tensor<B, 2> = Tensor::random([embed, mlp_h], burn::tensor::Distribution::Normal(0.0, 0.01), &d);
    bench("MLP fc1 [7200,768]@[768,3072]", warmup, runs, || {
        let out = x2.clone().matmul(w1.clone());
        let _ = out.into_data();
    });

    // GELU
    let h: Tensor<B, 2> = Tensor::random([seq, mlp_h], burn::tensor::Distribution::Normal(0.0, 1.0), &d);
    bench("GELU [7200,3072]", warmup, runs, || {
        let out = burn::tensor::activation::gelu(h.clone());
        let _ = out.into_data();
    });

    // MLP fc2
    let w2: Tensor<B, 2> = Tensor::random([mlp_h, embed], burn::tensor::Distribution::Normal(0.0, 0.01), &d);
    bench("MLP fc2 [7200,3072]@[3072,768]", warmup, runs, || {
        let out = h.clone().matmul(w2.clone());
        let _ = out.into_data();
    });

    // Residual add
    bench("Residual add [1,7200,768]", warmup, runs, || {
        let out = x3.clone() + x3.clone();
        let _ = out.into_data();
    });

    // Full naive attention (Q@K->softmax->@V)
    bench("Full attn (naive fused)", warmup, runs, || {
        let s = q.clone().matmul(k.clone().transpose()).mul_scalar(1.0f32 / 8.0);
        let a = softmax(s, 3);
        let out = a.matmul(v.clone());
        let _ = out.into_data();
    });

    println!();
    println!("12-block estimate = 12 * (2*LN + QKV + Attn + OutProj + fc1 + GELU + fc2 + 2*ResAdd)");
}