brainharmony 0.1.0

Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML
Documentation
/// Test: does running 12 blocks as a single lazy chain (one sync at the end)
/// perform better than syncing per block?
use std::time::Instant;
use burn::prelude::*;
use burn::tensor::activation::softmax;
use burn::nn::LayerNormConfig;

#[cfg(all(feature = "wgpu-f16", not(feature = "wgpu")))]
mod backend {
    //! WGPU backend running in half precision (f16).
    use burn::backend::wgpu::{Wgpu, WgpuDevice};

    /// Backend alias used by the benchmark: f16 floats, i32/u32 int types.
    pub type B = Wgpu<half::f16, i32, u32>;

    /// Returns the default WGPU device.
    pub fn device() -> WgpuDevice {
        WgpuDevice::DefaultDevice
    }
}
#[cfg(all(feature = "wgpu", not(feature = "wgpu-f16")))]
mod backend {
    //! WGPU backend at default precision (f32).
    pub use burn::backend::Wgpu as B;
    pub use burn::backend::wgpu::WgpuDevice;

    /// Returns the default WGPU device.
    pub fn device() -> WgpuDevice {
        WgpuDevice::DefaultDevice
    }
}
#[cfg(not(any(feature = "wgpu", feature = "wgpu-f16")))]
mod backend {
    //! CPU fallback backend (NdArray) when no WGPU feature is enabled.
    use burn::backend::ndarray::NdArrayDevice;
    pub use burn::backend::NdArray as B;

    /// Returns the CPU device.
    pub fn device() -> NdArrayDevice {
        NdArrayDevice::Cpu
    }
}
use backend::{B, device};

fn main() {
    let d = device();
    brainharmony::init_threads(None);

    // Benchmark configuration. Shapes match the model's ViT-Base-like block:
    // 768-dim embedding, 12 heads, sequence length 7200.
    let seq = 7200usize;
    let embed = 768usize;
    let heads = 12usize;
    let warmup = 20;
    let runs = 10;

    // Use the config variables so the banner can't drift from the actual counts.
    println!("Chain vs per-block sync ({warmup} warmup, {runs} runs)\n");

    // Random input of shape [batch=1, seq, embed].
    let x: Tensor<B, 3> = Tensor::random([1, seq, embed],
        burn::tensor::Distribution::Normal(0.0, 1.0), &d);

    // 12 transformer blocks (mlp ratio 4.0, qkv bias, layer-norm eps 1e-6).
    let blocks: Vec<brainharmony::model::block::Block<B>> = (0..12)
        .map(|_| brainharmony::model::block::Block::new(embed, heads, 4.0, true, 1e-6, &d))
        .collect();

    // Runs the full 12-block chain lazily; caller decides when to sync.
    let forward_all = |input: &Tensor<B, 3>| {
        let mut t = input.clone();
        for blk in &blocks {
            t = blk.forward(t, None);
        }
        t
    };

    // Method A: 12 blocks, forcing a device sync after every block.
    for _ in 0..warmup {
        let _ = forward_all(&x).into_data();
    }
    let mut times_sync = Vec::with_capacity(runs);
    for _ in 0..runs {
        let t0 = Instant::now();
        let mut t = x.clone();
        for blk in &blocks {
            t = blk.forward(t, None);
            let _ = t.clone().into_data(); // sync per block
        }
        times_sync.push(t0.elapsed().as_secs_f64() * 1000.0);
    }

    // Method B: 12 blocks as a single lazy chain, one sync at the end.
    for _ in 0..warmup {
        let _ = forward_all(&x).into_data();
    }
    let mut times_chain = Vec::with_capacity(runs);
    for _ in 0..runs {
        let t0 = Instant::now();
        let _ = forward_all(&x).into_data(); // single sync
        times_chain.push(t0.elapsed().as_secs_f64() * 1000.0);
    }

    let (sync_best, sync_med) = best_and_median(&mut times_sync);
    let (chain_best, chain_med) = best_and_median(&mut times_chain);

    println!("  Per-block sync (12 syncs):  best={sync_best:.0}ms  med={sync_med:.0}ms");
    println!("  Single sync (1 sync):       best={chain_best:.0}ms  med={chain_med:.0}ms");
    println!("  Speedup:                    {:.2}x", sync_best / chain_best);
}

/// Returns (best, median) of a non-empty sample of timings in ms.
///
/// Sorts `times` in place; NaNs would panic in `partial_cmp`, but the
/// samples come from `Instant::elapsed` and are always finite.
fn best_and_median(times: &mut [f64]) -> (f64, f64) {
    times.sort_by(|a, b| a.partial_cmp(b).unwrap());
    (times[0], times[times.len() / 2])
}