//! brainharmony 0.1.0
//!
//! Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML.
//!
//! Trace per-operation overhead in the encoder pipeline.
//! Compares: sum-of-parts vs actual full forward.
use std::time::Instant;
use burn::prelude::*;

/// Half-precision wgpu backend.
///
/// Gated on `feature = "wgpu-f16"` alone (not `all(wgpu-f16, not(wgpu))`):
/// cargo features are additive, so a consumer may enable both `wgpu` and
/// `wgpu-f16` at once (e.g. `--all-features`). With the old mutual-exclusion
/// gate, that combination activated NEITHER backend module and the build
/// failed on the missing `backend` module. With this gate, f16 wins when both
/// are enabled; the plain-wgpu module keeps its `not(feature = "wgpu-f16")`
/// guard, so exactly one `backend` module is ever compiled.
#[cfg(feature = "wgpu-f16")]
mod backend {
    use burn::backend::wgpu::{Wgpu, WgpuDevice};

    /// Wgpu backend parameterized for f16 floats, i32 ints, u32 bools.
    pub type B = Wgpu<half::f16, i32, u32>;

    /// Default wgpu device.
    pub fn device() -> WgpuDevice {
        WgpuDevice::DefaultDevice
    }
}
/// Full-precision (f32) wgpu backend; active when `wgpu` is enabled
/// without `wgpu-f16`.
#[cfg(all(feature = "wgpu", not(feature = "wgpu-f16")))]
mod backend {
    pub use burn::backend::{wgpu::WgpuDevice, Wgpu as B};

    /// Default wgpu device.
    pub fn device() -> WgpuDevice {
        WgpuDevice::DefaultDevice
    }
}
/// CPU fallback backend (NdArray); active when no wgpu feature is enabled.
#[cfg(not(any(feature = "wgpu", feature = "wgpu-f16")))]
mod backend {
    use burn::backend::ndarray::NdArrayDevice;

    pub use burn::backend::NdArray as B;

    /// The (only) CPU device for the ndarray backend.
    pub fn device() -> NdArrayDevice {
        NdArrayDevice::Cpu
    }
}
use backend::{B, device};

fn main() {
    let d = device();
    brainharmony::init_threads(None);

    // Problem sizes: 7200-token sequence, 768-dim embedding, 12 heads
    // (so Dh = 64). These match the encoder shapes printed below.
    let seq = 7200usize;
    let embed = 768usize;
    let heads = 12usize;
    let dh = embed / heads;
    let warmup = 3;
    let runs = 5;

    // Run `f` unrecorded `warmup` times, then `runs` recorded times, and
    // return the best (minimum) wall-clock time in milliseconds. Minimum —
    // not mean — is reported to suppress scheduler/allocator noise. This
    // replaces five hand-rolled copies of the warmup/Vec/fold pattern.
    fn bench(warmup: usize, runs: usize, mut f: impl FnMut()) -> f64 {
        for _ in 0..warmup {
            f();
        }
        let mut best = f64::INFINITY;
        for _ in 0..runs {
            let t0 = Instant::now();
            f();
            best = best.min(t0.elapsed().as_secs_f64() * 1000.0);
        }
        best
    }

    println!("=== Overhead analysis ===\n");

    // 1. QKV split overhead: the fused projection output [1, N, 3*E] must
    //    become per-head Q/K/V; three split strategies are timed.
    //    `into_data()` forces the lazy graph to execute so the timing covers
    //    the actual work, not just graph construction.
    let qkv: Tensor<B, 3> = Tensor::random([1, seq, 3 * embed],
        burn::tensor::Distribution::Normal(0.0, 1.0), &d);

    // Method A: current (narrow along dim 2, then reshape+swap)
    let best_a = bench(warmup, runs, || {
        let qkv2 = qkv.clone().reshape([1, seq, 3, heads, dh]);
        let q = qkv2.clone().narrow(2, 0, 1).reshape([1, seq, heads, dh]).swap_dims(1, 2);
        let k = qkv2.clone().narrow(2, 1, 1).reshape([1, seq, heads, dh]).swap_dims(1, 2);
        let v = qkv2.narrow(2, 2, 1).reshape([1, seq, heads, dh]).swap_dims(1, 2);
        let _ = (q + k + v).into_data();
    });

    // Method B: reshape to [B, N, H, 3*Dh], swap, then narrow along last dim
    let best_b = bench(warmup, runs, || {
        let qkv2 = qkv.clone().reshape([1, seq, heads, 3 * dh]).swap_dims(1, 2);
        let q = qkv2.clone().narrow(3, 0, dh);
        let k = qkv2.clone().narrow(3, dh, dh);
        let v = qkv2.narrow(3, 2 * dh, dh);
        let _ = (q + k + v).into_data();
    });

    // Method C: three separate projections (no narrow/split at all).
    // NOTE(review): the straight reshape [N, E] -> [1, H, N, Dh] does not
    // yield the same element layout as A/B's reshape+swap_dims; this arm
    // measures cost only, not numerical equivalence — confirm intended.
    let wq: Tensor<B, 2> = Tensor::random([embed, embed], burn::tensor::Distribution::Normal(0.0, 0.01), &d);
    let wk = wq.clone();
    let wv = wq.clone();
    let x2: Tensor<B, 2> = Tensor::random([seq, embed], burn::tensor::Distribution::Normal(0.0, 1.0), &d);

    let best_c = bench(warmup, runs, || {
        let q = x2.clone().matmul(wq.clone()).reshape([1, heads, seq, dh]);
        let k = x2.clone().matmul(wk.clone()).reshape([1, heads, seq, dh]);
        let v = x2.clone().matmul(wv.clone()).reshape([1, heads, seq, dh]);
        let _ = (q + k + v).into_data();
    });

    println!("QKV split methods:");
    println!("  A: narrow(dim2) + reshape + swap  : {best_a:.1}ms");
    println!("  B: reshape + swap + narrow(dim3)  : {best_b:.1}ms");
    println!("  C: 3 separate matmuls + reshape   : {best_c:.1}ms");

    // 2. Transpose overhead (swapping the last two dims of a K-shaped tensor).
    let k4: Tensor<B, 4> = Tensor::random([1, heads, seq, dh],
        burn::tensor::Distribution::Normal(0.0, 1.0), &d);
    let best_t = bench(warmup, runs, || {
        let _ = k4.clone().transpose().into_data();
    });
    println!("\nTranspose [1,12,7200,64] -> [1,12,64,7200]: {best_t:.1}ms");

    // 3. Full single-block forward vs the sum-of-parts estimate.
    let block = brainharmony::model::block::Block::<B>::new(embed, heads, 4.0, true, 1e-6, &d);
    let x3: Tensor<B, 3> = Tensor::random([1, seq, embed],
        burn::tensor::Distribution::Normal(0.0, 1.0), &d);

    let best_block = bench(warmup, runs, || {
        let _ = block.forward(x3.clone(), None).into_data();
    });
    println!("\nFull Block forward: {best_block:.1}ms (x12 = {:.0}ms)", best_block * 12.0);
    // The ~300ms figure is a hand-measured per-operation sum; keep it in
    // sync with the individual measurements above when shapes change.
    println!("  Per-op estimate:  ~300ms  (x12 = ~3600ms)");
    println!("  Overhead per block: {:.1}ms", best_block - 300.0);
}