brainharmony 0.1.0

Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML
Documentation
/// Benchmark: 2D matmul vs 3D batched matmul on wgpu (falls back to NdArray when no wgpu feature is enabled).
use std::time::Instant;
use burn::prelude::*;

/// Backend selection when the `wgpu-f16` feature is enabled.
///
/// `B` is the wgpu backend with `half::f16` floats, `i32` ints and `u32`
/// bools. `wgpu-f16` takes precedence when the `wgpu` feature is enabled as
/// well, so that turning on both features (e.g. `--all-features`) still
/// yields exactly one `backend` module instead of none.
#[cfg(feature = "wgpu-f16")]
mod backend {
    pub type B = burn::backend::wgpu::Wgpu<half::f16, i32, u32>;

    /// Returns the default wgpu device.
    pub fn device() -> burn::backend::wgpu::WgpuDevice {
        burn::backend::wgpu::WgpuDevice::DefaultDevice
    }
}
/// Backend selection when only the `wgpu` feature is enabled.
///
/// `B` is burn's `Wgpu` backend with its default element types; the device
/// type is re-exported alongside it.
#[cfg(all(feature = "wgpu", not(feature = "wgpu-f16")))]
mod backend {
    pub use burn::backend::wgpu::WgpuDevice;
    pub use burn::backend::Wgpu as B;

    /// Returns the default wgpu device.
    pub fn device() -> WgpuDevice {
        WgpuDevice::DefaultDevice
    }
}
/// Fallback backend selection when neither `wgpu` nor `wgpu-f16` is enabled.
///
/// `B` is burn's `NdArray` backend and the device is always the CPU.
#[cfg(not(any(feature = "wgpu", feature = "wgpu-f16")))]
mod backend {
    use burn::backend::ndarray::NdArrayDevice;

    pub use burn::backend::NdArray as B;

    /// Returns the CPU device of the NdArray backend.
    pub fn device() -> NdArrayDevice {
        NdArrayDevice::Cpu
    }
}
use backend::{B, device};

/// Times the closure `f` and prints a one-line summary for `label`.
///
/// `f` is first called `warmup` times without timing, then `runs` times with
/// each call timed individually using [`Instant`]. The best (smallest) and
/// median wall-clock times, in milliseconds, are printed.
///
/// Returns the best time in milliseconds. When `runs` is 0 there are no
/// samples: the function returns `f64::INFINITY` and prints a NaN median
/// instead of panicking.
fn bench<F: FnMut()>(label: &str, warmup: usize, runs: usize, mut f: F) -> f64 {
    for _ in 0..warmup {
        f();
    }

    let mut samples_ms: Vec<f64> = Vec::with_capacity(runs);
    for _ in 0..runs {
        let start = Instant::now();
        f();
        samples_ms.push(start.elapsed().as_secs_f64() * 1000.0);
    }

    // Sorting once gives both statistics: the minimum is the first element and
    // the median is the element at index len/2 (upper middle for even counts).
    samples_ms.sort_by(f64::total_cmp);
    let best = samples_ms.first().copied().unwrap_or(f64::INFINITY);
    let med = samples_ms
        .get(samples_ms.len() / 2)
        .copied()
        .unwrap_or(f64::NAN);

    println!("  {label:45} best={best:>7.1}ms  med={med:>7.1}ms");
    best
}

/// Benchmarks 2D matmul against 3D (batched) matmul and `Linear` layers on the
/// selected backend, then prints the timing ratios.
///
/// Every timed closure ends with `into_data()`, which reads the result back
/// (presumably forcing the backend to finish the work before the timer stops —
/// confirm for asynchronous backends).
fn main() {
    let d = device();
    brainharmony::init_threads(None);

    // Report the backend that was actually compiled in rather than a fixed
    // label, so results from different feature sets are not confused.
    let backend_name = if cfg!(feature = "wgpu-f16") {
        "wgpu f16"
    } else if cfg!(feature = "wgpu") {
        "wgpu"
    } else {
        "ndarray"
    };
    println!("2D vs 3D matmul benchmark ({backend_name})\n");

    let w: Tensor<B, 2> = Tensor::random([768, 2304], burn::tensor::Distribution::Normal(0.0, 0.01), &d);
    let x2: Tensor<B, 2> = Tensor::random([7200, 768], burn::tensor::Distribution::Normal(0.0, 1.0), &d);
    let x3: Tensor<B, 3> = x2.clone().unsqueeze_dim::<3>(0);

    let t2 = bench("2D: [7200,768] @ [768,2304]", 20, 20, || {
        let _ = x2.clone().matmul(w.clone()).into_data();
    });

    // The unsqueeze of the weight is deliberately inside the timed region, as
    // the label says.
    let t3 = bench("3D: [1,7200,768] @ [1,768,2304] (unsqueeze)", 20, 20, || {
        let w3 = w.clone().unsqueeze_dim::<3>(0);
        let _ = x3.clone().matmul(w3).into_data();
    });

    // Linear module for comparison
    let linear = brainharmony::model::linear_zeros::<B>(768, 2304, true, &d);
    let t_lin = bench("Linear::forward [1,7200,768]", 20, 20, || {
        let _ = linear.forward(x3.clone()).into_data();
    });

    // 12-block chain: 2D only
    let norm = burn::nn::LayerNormConfig::new(768).with_epsilon(1e-6).init::<B>(&d);
    let w_fc2: Tensor<B, 2> = Tensor::random([2304, 768], burn::tensor::Distribution::Normal(0.0, 0.01), &d);

    let t_chain_2d = bench("12x (LN + 2D matmul + gelu + 2D matmul)", 10, 10, || {
        let mut t = x3.clone();
        for _ in 0..12 {
            let tn = norm.forward(t).reshape([7200, 768]);
            let h = burn::tensor::activation::gelu(tn.matmul(w.clone()));
            t = h.matmul(w_fc2.clone()).reshape([1, 7200, 768]);
        }
        let _ = t.into_data();
    });

    // 12-block chain: Linear modules. The second layer is built once, outside
    // the timed closure, so layer construction is not measured (mirroring how
    // `w_fc2` is prepared for the 2D chain).
    let linear2 = brainharmony::model::linear_zeros::<B>(2304, 768, true, &d);
    let t_chain_3d = bench("12x (LN + Linear + gelu + Linear)", 10, 10, || {
        let mut t = x3.clone();
        for _ in 0..12 {
            let h = burn::tensor::activation::gelu(linear.forward(norm.forward(t)));
            t = linear2.forward(h);
        }
        let _ = t.into_data();
    });

    // Ratios are (3D or Linear best time) / (2D best time): values above 1.0
    // mean the 2D path was faster.
    println!();
    println!("2D vs 3D matmul: {:.2}x", t3 / t2);
    println!("2D vs Linear:    {:.2}x", t_lin / t2);
    println!("12-chain 2D vs 3D: {:.2}x", t_chain_3d / t_chain_2d);
}