brainharmony 0.1.0

Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML
Documentation
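/// Measure raw per-op kernel dispatch overhead on the selected Burn backend
/// (WGPU when built with the `wgpu` or `wgpu-f16` feature, NdArray on CPU otherwise).
/// How much does each tensor op cost just to launch?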
use std::time::Instant;
use burn::prelude::*;

/// Half-precision WGPU backend, selected whenever the `wgpu-f16` feature is on.
///
/// The gate deliberately does not exclude `wgpu`: if both features end up
/// enabled (e.g. `--all-features`), this variant wins, so exactly one
/// `backend` module is always compiled and `use backend::{B, device}` resolves.
#[cfg(feature = "wgpu-f16")]
mod backend {
    use burn::backend::wgpu::{Wgpu, WgpuDevice};

    /// WGPU backend with `f16` floats, `i32` ints and `u32` as the third
    /// element type parameter.
    pub type B = Wgpu<half::f16, i32, u32>;

    /// Returns the default WGPU device.
    pub fn device() -> WgpuDevice {
        WgpuDevice::DefaultDevice
    }
}
/// WGPU backend with Burn's default element types, used when `wgpu` is
/// enabled and `wgpu-f16` is not.
#[cfg(all(feature = "wgpu", not(feature = "wgpu-f16")))]
mod backend {
    pub use burn::backend::wgpu::WgpuDevice;
    pub use burn::backend::Wgpu as B;

    /// Returns the default WGPU device.
    pub fn device() -> WgpuDevice {
        WgpuDevice::DefaultDevice
    }
}
/// Fallback backend used when neither `wgpu` nor `wgpu-f16` is enabled:
/// Burn's NdArray backend on the CPU device.
#[cfg(not(any(feature = "wgpu", feature = "wgpu-f16")))]
mod backend {
    use burn::backend::ndarray::NdArrayDevice;

    pub use burn::backend::NdArray as B;

    /// Returns the NdArray CPU device.
    pub fn device() -> NdArrayDevice {
        NdArrayDevice::Cpu
    }
}
use backend::{B, device};

/// Runs `work` `warmup` times untimed, then `runs` times under a wall-clock
/// timer, and returns the fastest timed run in milliseconds.
///
/// Returns `f64::INFINITY` when `runs` is zero.
fn best_of_ms(warmup: usize, runs: usize, mut work: impl FnMut()) -> f64 {
    for _ in 0..warmup {
        work();
    }
    (0..runs)
        .map(|_| {
            let started = Instant::now();
            work();
            started.elapsed().as_secs_f64() * 1000.0
        })
        .fold(f64::INFINITY, f64::min)
}

/// Benchmarks chains of tensor ops of increasing length and prints, for each
/// chain length, the best wall-clock time and that time divided by the number
/// of ops.
///
/// Two chains are measured on a `[7200, 768]` random tensor:
/// 1. `mul_scalar(1.0001)` repeated N times (3 warmups, best of 10 runs);
/// 2. `matmul` with a `[768, 768]` weight repeated N times (3 warmups, best of 5).
///
/// Each timed run covers `x.clone()`, the op chain and the final
/// `into_data()`, so the reported per-op figure also carries that fixed cost
/// and is most meaningful for the longer chains.
fn main() {
    let d = device();
    brainharmony::init_threads(None);

    // Input shared by both experiments; every run starts from a fresh clone.
    let x: Tensor<B, 2> = Tensor::random(
        [7200, 768],
        burn::tensor::Distribution::Normal(0.0, 1.0),
        &d,
    );

    println!("Dispatch overhead test\n");

    // Chain of N trivial element-wise ops: the longer the chain, the more the
    // total is dominated by per-op cost rather than the fixed clone/readback.
    for n_ops in [1, 5, 10, 20, 50, 100] {
        let best = best_of_ms(3, 10, || {
            let mut t = x.clone();
            for _ in 0..n_ops {
                t = t.mul_scalar(1.0001f32);
            }
            // NOTE(review): into_data() presumably forces the queued ops to
            // complete before the timer stops — confirm for the chosen backend.
            let _ = t.into_data();
        });
        let per_op = best / n_ops as f64;
        println!("  {n_ops:3} ops: {best:>6.1}ms total  ({per_op:.2}ms/op)");
    }

    println!("\n--- Non-fuseable ops (matmul) ---\n");
    let w: Tensor<B, 2> = Tensor::random(
        [768, 768],
        burn::tensor::Distribution::Normal(0.0, 0.01),
        &d,
    );

    // Same experiment with a chain of matmuls against a square weight, so the
    // tensor keeps its [7200, 768] shape from one op to the next.
    for n_ops in [1, 4, 8, 12] {
        let best = best_of_ms(3, 5, || {
            let mut t = x.clone();
            for _ in 0..n_ops {
                t = t.matmul(w.clone());
            }
            let _ = t.into_data();
        });
        let per_op = best / n_ops as f64;
        println!("  {n_ops:3} matmuls: {best:>7.1}ms total  ({per_op:.1}ms/op)");
    }
}