brainharmony 0.1.0

Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML
//! Benchmark tiled attention vs. the naive full-matrix implementation on wgpu.
//!
//! Run with:
//!   cargo run --example bench_fused --release --no-default-features --features wgpu-f16
use std::time::Instant;
use burn::prelude::*;
use burn::tensor::activation::softmax;

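// Compile-time backend selection: wgpu storing f16, wgpu storing f32, or the
// NdArray CPU backend when neither wgpu feature is enabled.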
#[cfg(all(feature = "wgpu-f16", not(feature = "wgpu")))]
mod backend {
    pub type B = burn::backend::wgpu::Wgpu<half::f16, i32, u32>;
    pub fn device() -> burn::backend::wgpu::WgpuDevice { burn::backend::wgpu::WgpuDevice::DefaultDevice }
    pub const NAME: &str = "wgpu f16";
}
#[cfg(all(feature = "wgpu", not(feature = "wgpu-f16")))]
mod backend {
    pub use burn::backend::{Wgpu as B, wgpu::WgpuDevice};
    pub fn device() -> WgpuDevice { WgpuDevice::DefaultDevice }
    pub const NAME: &str = "wgpu f32";
}
#[cfg(not(any(feature = "wgpu", feature = "wgpu-f16")))]
mod backend {
    pub use burn::backend::NdArray as B;
    pub fn device() -> burn::backend::ndarray::NdArrayDevice { burn::backend::ndarray::NdArrayDevice::Cpu }
    pub const NAME: &str = "NdArray";
}

use backend::{B, device};

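/// Calls `f` `warmup` times untimed, then `runs` times timed; prints the best
/// and average wall-clock milliseconds and returns the best.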
fn bench<F: FnMut()>(label: &str, warmup: usize, runs: usize, mut f: F) -> f64 {
    for _ in 0..warmup { f(); }
    let mut times = Vec::with_capacity(runs);
    for _ in 0..runs {
        let t0 = Instant::now();
        f();
        times.push(t0.elapsed().as_secs_f64() * 1000.0);
    }
    let best = times.iter().cloned().fold(f64::INFINITY, f64::min);
    let avg: f64 = times.iter().sum::<f64>() / times.len() as f64;
    println!("  {label:35} best={best:>7.1}ms  avg={avg:>7.1}ms  [{runs} runs]");
    best
}

fn main() {
    let d = device();
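    // Passing `None` presumably lets the crate pick a default thread count.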
    brainharmony::init_threads(None);

    let heads = 12usize;
    let seq = 7200usize;
    let hdim = 64usize;
    let scale = 1.0 / (hdim as f32).sqrt();
    let warmup = 5;
    let runs = 10;

    println!("Tiled attention benchmark: {}", backend::NAME);
    println!("  Shape: [1, {heads}, {seq}, {hdim}]");
    println!();

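    // Random Q, K, V with shape [1, heads, seq, hdim].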
    let q: Tensor<B, 4> = Tensor::random([1, heads, seq, hdim],
        burn::tensor::Distribution::Normal(0.0, 1.0), &d);
    let k: Tensor<B, 4> = Tensor::random([1, heads, seq, hdim],
        burn::tensor::Distribution::Normal(0.0, 1.0), &d);
    let v: Tensor<B, 4> = Tensor::random([1, heads, seq, hdim],
        burn::tensor::Distribution::Normal(0.0, 1.0), &d);

    // Pre-scale Q by 1/sqrt(hdim) once so the score matmuls need no per-element rescaling
    let q_scaled = q.clone().mul_scalar(scale);

    // 1. Naive: materialize the full [seq, seq] attention matrix at once
    let t_naive = bench(&format!("Naive (full {seq}x{seq})"), warmup, runs, || {
        let s = q_scaled.clone().matmul(k.clone().transpose());
        let a = softmax(s, 3);
        let out = a.matmul(v.clone());
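        // Reading the data back to the host forces any queued device work to
        // finish, so the timing covers the full computation.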
        let _ = out.into_data();
    });

    // 2-6. Tiled at various tile sizes
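    // Each pass attends a [tile, hdim] slice of Q against the full K/V, so the
    // per-head softmax buffer is [tile, seq] rather than [seq, seq]; the tile
    // outputs are concatenated back along the sequence dimension.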
    for tile in [256, 512, 1024, 1800, 3600] {
        let label = format!("Tiled (tile={})", tile);
        let _ = bench(&label, warmup, runs, || {
            let k_t = k.clone().transpose();
            let mut tiles = Vec::new();
            let mut off = 0;
            while off < seq {
                let tl = (seq - off).min(tile);
                let qt = q_scaled.clone().narrow(2, off, tl);
                let s = qt.matmul(k_t.clone());
                let a = softmax(s, 3);
                tiles.push(a.matmul(v.clone()));
                off += tl;
            }
            let out = Tensor::<B, 4>::cat(tiles, 2);
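            // Same host read-back as the naive path, so timings are comparable.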
            let _ = out.into_data();
        });
    }

    println!();
    println!("Naive baseline: {t_naive:.1}ms  (= {:.1}ms * 12 blocks = {:.0}ms encoder)",
        t_naive, t_naive * 12.0);
}