//! brainharmony 0.1.0
//!
//! Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML.
//!
//! Profile per-layer costs to find bottlenecks.
use std::time::Instant;
use burn::prelude::*;
use burn::backend::NdArray;
use burn::tensor::activation::{softmax, gelu};

/// Backend alias used by every benchmark below: Burn's CPU `NdArray` backend.
type B = NdArray;

fn main() {
    let d = burn::backend::ndarray::NdArrayDevice::Cpu;
    let _ = brainharmony::init_threads(None);

    let embed = 768usize;
    let heads = 12usize;
    let seq = 7200usize; // 400 ROIs * 18 time patches
    let hdim = embed / heads;
    let mlp_h = embed * 4;
    let iters = 3usize;

    println!("Profile: seq={seq} embed={embed} heads={heads} hdim={hdim} mlp_h={mlp_h}  ({iters} iters)");
    println!();

    let x: Tensor<B, 2> = Tensor::random([seq, embed], burn::tensor::Distribution::Normal(0.0, 1.0), &d);

    let w_qkv: Tensor<B, 2> = Tensor::random([embed, 3*embed], burn::tensor::Distribution::Normal(0.0, 0.01), &d);
    let t0 = Instant::now();
    for _ in 0..iters { let _ = x.clone().matmul(w_qkv.clone()); }
    let ms = t0.elapsed().as_secs_f64() * 1000.0 / iters as f64;
    println!("QKV proj       [7200,768]@[768,2304]:  {ms:>8.1} ms");

    let q: Tensor<B, 3> = Tensor::random([heads, seq, hdim], burn::tensor::Distribution::Normal(0.0, 1.0), &d);
    let k: Tensor<B, 3> = Tensor::random([heads, seq, hdim], burn::tensor::Distribution::Normal(0.0, 1.0), &d);
    let t0 = Instant::now();
    for _ in 0..iters { let _ = q.clone().matmul(k.clone().transpose()); }
    let ms = t0.elapsed().as_secs_f64() * 1000.0 / iters as f64;
    println!("Q@K^T          [12,7200,64]@[12,64,7200]: {ms:>8.1} ms");

    let scores: Tensor<B, 3> = Tensor::random([heads, seq, seq], burn::tensor::Distribution::Normal(0.0, 1.0), &d);
    let t0 = Instant::now();
    for _ in 0..iters { let _ = softmax(scores.clone(), 2); }
    let ms = t0.elapsed().as_secs_f64() * 1000.0 / iters as f64;
    println!("Softmax        [12,7200,7200]:         {ms:>8.1} ms");

    let attn: Tensor<B, 3> = Tensor::random([heads, seq, seq], burn::tensor::Distribution::Normal(0.0, 1.0), &d);
    let v: Tensor<B, 3> = Tensor::random([heads, seq, hdim], burn::tensor::Distribution::Normal(0.0, 1.0), &d);
    let t0 = Instant::now();
    for _ in 0..iters { let _ = attn.clone().matmul(v.clone()); }
    let ms = t0.elapsed().as_secs_f64() * 1000.0 / iters as f64;
    println!("Attn@V         [12,7200,7200]@[12,7200,64]: {ms:>8.1} ms");

    let w_o: Tensor<B, 2> = Tensor::random([embed, embed], burn::tensor::Distribution::Normal(0.0, 0.01), &d);
    let t0 = Instant::now();
    for _ in 0..iters { let _ = x.clone().matmul(w_o.clone()); }
    let ms = t0.elapsed().as_secs_f64() * 1000.0 / iters as f64;
    println!("Out proj       [7200,768]@[768,768]:   {ms:>8.1} ms");

    let w1: Tensor<B, 2> = Tensor::random([embed, mlp_h], burn::tensor::Distribution::Normal(0.0, 0.01), &d);
    let t0 = Instant::now();
    for _ in 0..iters { let _ = x.clone().matmul(w1.clone()); }
    let ms = t0.elapsed().as_secs_f64() * 1000.0 / iters as f64;
    println!("MLP fc1        [7200,768]@[768,3072]:  {ms:>8.1} ms");

    let h: Tensor<B, 2> = Tensor::random([seq, mlp_h], burn::tensor::Distribution::Normal(0.0, 1.0), &d);
    let t0 = Instant::now();
    for _ in 0..iters { let _ = gelu(h.clone()); }
    let ms = t0.elapsed().as_secs_f64() * 1000.0 / iters as f64;
    println!("GELU           [7200,3072]:            {ms:>8.1} ms");

    let w2: Tensor<B, 2> = Tensor::random([mlp_h, embed], burn::tensor::Distribution::Normal(0.0, 0.01), &d);
    let t0 = Instant::now();
    for _ in 0..iters { let _ = h.clone().matmul(w2.clone()); }
    let ms = t0.elapsed().as_secs_f64() * 1000.0 / iters as f64;
    println!("MLP fc2        [7200,3072]@[3072,768]: {ms:>8.1} ms");

    let ln = burn::nn::LayerNormConfig::new(embed).with_epsilon(1e-6).init::<B>(&d);
    let x3: Tensor<B, 3> = Tensor::random([1, seq, embed], burn::tensor::Distribution::Normal(0.0, 1.0), &d);
    let t0 = Instant::now();
    for _ in 0..iters { let _ = ln.forward(x3.clone()); }
    let ms = t0.elapsed().as_secs_f64() * 1000.0 / iters as f64;
    println!("LayerNorm      [1,7200,768]:           {ms:>8.1} ms");

    println!();
    println!("Total for 12 blocks ~ 12 * (QKV + Q@K + Softmax + A@V + OutProj + fc1 + GELU + fc2 + 2*LN)");
}