brainharmony 0.1.0

Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML
Documentation
/// Benchmark the non-causal flash attention kernel (cubek) against naive matmul->softmax->matmul attention.
///
/// ```sh
/// cargo run --example bench_flash --release --no-default-features --features wgpu
/// cargo run --example bench_flash --release --no-default-features --features wgpu-f16
/// ```
use std::time::Instant;
use burn::prelude::*;

#[cfg(any(feature = "wgpu", feature = "wgpu-f16"))]
use brainharmony::model::flash_attn::gpu::flash_attention_tensor;

/// GPU backend using half-precision (`half::f16`) floats.
///
/// Gated on `wgpu-f16` alone, so it also wins when `wgpu` is enabled at the same
/// time (e.g. `--all-features`). Every feature combination therefore defines
/// exactly one `backend` module.
#[cfg(feature = "wgpu-f16")]
mod backend {
    /// Burn's wgpu backend with `half::f16` as the float element type.
    pub type B = burn::backend::wgpu::Wgpu<half::f16, i32, u32>;

    /// Device the benchmark tensors are allocated on.
    pub fn device() -> burn::backend::wgpu::WgpuDevice {
        burn::backend::wgpu::WgpuDevice::DefaultDevice
    }

    /// Backend label printed in the benchmark header.
    // NOTE(review): "(Metal)" is only accurate on macOS; wgpu's default device may use
    // another graphics API on other platforms — confirm before relying on this label.
    pub const NAME: &str = "wgpu f16 (Metal)";
}

/// GPU backend using single-precision (f32) floats; compiled when `wgpu` is
/// enabled and `wgpu-f16` is not.
#[cfg(all(feature = "wgpu", not(feature = "wgpu-f16")))]
mod backend {
    pub use burn::backend::{wgpu::WgpuDevice, Wgpu as B};

    /// Backend label printed in the benchmark header.
    // NOTE(review): "(Metal)" is only accurate on macOS; wgpu's default device may use
    // another graphics API on other platforms — confirm before relying on this label.
    pub const NAME: &str = "wgpu f32 (Metal)";

    /// Device the benchmark tensors are allocated on.
    pub fn device() -> WgpuDevice {
        WgpuDevice::DefaultDevice
    }
}

/// CPU fallback backend, compiled when neither GPU feature is enabled.
///
/// It lets the `use backend::{B, device}` import resolve in this configuration.
/// The stub `main` compiled alongside it exits without running the benchmark.
#[cfg(not(any(feature = "wgpu", feature = "wgpu-f16")))]
mod backend {
    use burn::backend::ndarray::NdArrayDevice;

    pub use burn::backend::NdArray as B;

    /// Backend label printed in the benchmark header.
    pub const NAME: &str = "NdArray";

    /// Device the benchmark tensors would be allocated on.
    pub fn device() -> NdArrayDevice {
        NdArrayDevice::Cpu
    }
}

use backend::{B, device};

/// Entry point for builds without a GPU feature: print a usage hint to stderr
/// and exit with status 1.
#[cfg(not(any(feature = "wgpu", feature = "wgpu-f16")))]
fn main() {
    const HINT: &str = "bench_flash requires --features wgpu or wgpu-f16";
    eprintln!("{HINT}");
    std::process::exit(1)
}

/// Benchmark entry point for GPU builds.
///
/// It times the cubek flash-attention kernel and a naive
/// `softmax(Q·Kᵀ / sqrt(head_dim))·V` reference on the same random Q/K/V tensors of
/// shape `[1, num_heads, seq, head_dim]`. It then prints the per-run times, the
/// best/median/mean of each, and the best-of-N speedup of flash over naive.
#[cfg(any(feature = "wgpu", feature = "wgpu-f16"))]
fn main() {
    // Best / median / mean of a set of wall-clock timings, in milliseconds.
    struct Stats {
        best: f64,
        median: f64,
        avg: f64,
    }

    // Run `f` `n_runs` times and return each run's wall-clock time in ms.
    // `f` must block until its GPU work has finished, or the times are meaningless.
    fn time_runs(n_runs: usize, f: impl Fn()) -> Vec<f64> {
        (0..n_runs)
            .map(|_| {
                let t0 = Instant::now();
                f();
                t0.elapsed().as_secs_f64() * 1000.0
            })
            .collect()
    }

    // Print the per-run times plus best / median / mean, and return those statistics.
    // Panics on an empty slice, since there are no statistics of zero runs.
    fn report(times: &[f64]) -> Stats {
        assert!(!times.is_empty(), "need at least one timed run");
        let mut sorted = times.to_vec();
        sorted.sort_by(f64::total_cmp);
        let mid = sorted.len() / 2;
        // For an even number of runs, the median is the mean of the two middle samples.
        let median = if sorted.len() % 2 == 0 {
            (sorted[mid - 1] + sorted[mid]) / 2.0
        } else {
            sorted[mid]
        };
        let stats = Stats {
            best: sorted[0],
            median,
            avg: times.iter().sum::<f64>() / times.len() as f64,
        };

        let list = times.iter().map(|t| format!("{t:.0}")).collect::<Vec<_>>().join(", ");
        println!("  Times (ms) : {list}");
        println!("  Best       : {:.0} ms", stats.best);
        println!("  Median     : {:.0} ms", stats.median);
        println!("  Average    : {:.0} ms", stats.avg);
        stats
    }

    let dev = device();
    let n_threads = brainharmony::init_threads(None);

    let embed_dim = 768usize;
    let num_heads = 12usize;
    let head_dim = embed_dim / num_heads;
    let seq = 7200usize;
    let n_warmup = 5usize;
    let n_runs = 10usize;

    println!("Flash Attention Benchmark (non-causal, cubek)");
    println!("  Backend    : {} ({n_threads} CPU threads)", backend::NAME);
    println!("  Heads      : {num_heads}  head_dim: {head_dim}");
    println!("  Seq length : {seq}");
    println!("  Shape      : [1, {num_heads}, {seq}, {head_dim}]");
    println!("  Warmup     : {n_warmup}  Runs: {n_runs}");
    println!();

    let shape = [1, num_heads, seq, head_dim];
    let randn = || {
        Tensor::<B, 4>::random(shape, burn::tensor::Distribution::Normal(0.0, 1.0), &dev)
    };
    let (q, k, v) = (randn(), randn(), randn());

    // Each run ends by reading its output back with `into_data()`. The readback cannot
    // finish before the queued GPU work does, so it serves as the timing sync point.
    // The transfer of the output tensor is therefore included in every measured time.
    let run_flash = || {
        let out = flash_attention_tensor::<B, _>(q.clone(), k.clone(), v.clone());
        let _ = out.into_data();
    };
    let scale = 1.0 / (head_dim as f32).sqrt();
    let run_naive = || {
        let scores = q.clone().matmul(k.clone().transpose()).mul_scalar(scale);
        let attn = burn::tensor::activation::softmax(scores, 3);
        let _ = attn.matmul(v.clone()).into_data();
    };

    println!("Warming up...");
    for _ in 0..n_warmup {
        run_flash();
    }
    println!("Benchmarking flash attention...");
    let flash_times = time_runs(n_runs, &run_flash);

    println!();
    println!("Results ({n_runs} runs):");
    let flash = report(&flash_times);

    // Compare against the naive matmul -> softmax -> matmul path.
    println!();
    println!("Benchmarking naive attention (matmul->softmax->matmul)...");
    for _ in 0..n_warmup {
        run_naive();
    }
    let naive_times = time_runs(n_runs, &run_naive);
    let naive = report(&naive_times);

    println!();
    // How many times faster flash is than naive, comparing best-of-N times.
    println!("Speedup (naive best / flash best): {:.2}x", naive.best / flash.best);
}