brainharmony 0.1.0

Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML
Documentation
/// Benchmark Brain-Harmony encoder on GPU (wgpu/Metal).
///
/// cargo run --example bench_gpu --release --no-default-features --features wgpu-f16
use std::time::Instant;
use burn::prelude::*;

// `wgpu-f16` takes precedence when both `wgpu` and `wgpu-f16` are enabled.
// Cargo features are additive, so both can end up on at once; previously that
// combination matched none of the three `backend` modules and the build failed
// with an unresolved `backend` import.
#[cfg(feature = "wgpu-f16")]
/// Backend selection: wgpu with `half::f16` float elements.
mod backend {
    /// Burn backend used by the benchmark (f16 floats, i32 ints, u32 bools).
    pub type B = burn::backend::wgpu::Wgpu<half::f16, i32, u32>;
    /// Device handle type for this backend.
    pub type Device = burn::backend::wgpu::WgpuDevice;
    /// Returns the default wgpu device.
    pub fn device() -> Device { Device::DefaultDevice }
    /// Human-readable backend label printed in the benchmark header.
    // NOTE(review): the label hardcodes "Metal"; the adapter wgpu actually
    // picks is presumably platform-dependent — confirm before relying on it.
    pub const NAME: &str = "wgpu f16 (Metal)";
}

#[cfg(all(feature = "wgpu", not(feature = "wgpu-f16")))]
/// Backend selection: wgpu with Burn's default element types.
///
/// Compiled only when `wgpu` is enabled and `wgpu-f16` is not.
mod backend {
    /// Burn backend used by the benchmark.
    pub type B = burn::backend::Wgpu;

    /// Device handle type for this backend.
    pub type Device = burn::backend::wgpu::WgpuDevice;

    /// Human-readable backend label printed in the benchmark header.
    pub const NAME: &str = "wgpu f32 (Metal)";

    /// Returns the default wgpu device.
    pub fn device() -> Device {
        Device::DefaultDevice
    }
}

#[cfg(not(any(feature = "wgpu", feature = "wgpu-f16")))]
/// Backend selection: CPU fallback, used when neither GPU feature is enabled.
mod backend {
    /// Burn backend used by the benchmark (NdArray with its default element types).
    pub type B = burn::backend::NdArray;

    /// Device handle type for this backend.
    pub type Device = burn::backend::ndarray::NdArrayDevice;

    /// Human-readable backend label; doubles as a hint to enable a GPU feature.
    pub const NAME: &str = "NdArray (CPU) -- use --features wgpu for GPU";

    /// Returns the CPU device.
    pub fn device() -> Device {
        Device::Cpu
    }
}

use backend::{B, device};

/// Summarizes timing samples (in milliseconds) as `(best, median, average)`.
///
/// The median of an even number of samples is the mean of the two middle
/// values. An empty slice yields `(INFINITY, NaN, NaN)` instead of panicking.
fn timing_stats(samples_ms: &[f64]) -> (f64, f64, f64) {
    let best = samples_ms.iter().copied().fold(f64::INFINITY, f64::min);
    let avg = samples_ms.iter().sum::<f64>() / samples_ms.len() as f64;

    let mut sorted = samples_ms.to_vec();
    // `total_cmp` is a total order over f64, so sorting can never panic.
    sorted.sort_unstable_by(f64::total_cmp);
    let mid = sorted.len() / 2;
    let median = match sorted.len() {
        0 => f64::NAN,
        n if n % 2 == 0 => (sorted[mid - 1] + sorted[mid]) / 2.0,
        _ => sorted[mid],
    };

    (best, median, avg)
}

/// Benchmark entry point.
///
/// Builds a `FlexVisionTransformer` encoder on the selected backend, runs a
/// number of untimed warmup passes on a random input, then times `n_runs`
/// forward passes and prints best / median / average latency in milliseconds,
/// followed by a machine-readable `TIMING ...` line.
///
/// # Panics
///
/// Panics if the encoder cannot be constructed.
fn main() {
    let dev = device();
    let n_threads = brainharmony::init_threads(None);

    let embed_dim = 768usize;
    let depth = 12usize;
    let num_heads = 12usize;
    let patch_size = 48usize;
    let n_rois = 400usize;
    let n_time_patches = 18usize;
    let signal_length = n_time_patches * patch_size;
    let n_warmup = 20usize;  // enough to fully cache GPU pipelines
    let n_runs = 20usize;

    println!("Brain-Harmony Rust GPU Benchmark");
    println!("  Backend    : {} ({n_threads} CPU threads)", backend::NAME);
    println!("  Model      : embed={embed_dim} depth={depth} heads={num_heads}");
    println!("  Patch size : {patch_size}");
    println!("  Input      : [1, 1, {n_rois}, {signal_length}]");
    println!("  Seq length : {} patches", n_rois * n_time_patches);
    println!("  Warmup     : {n_warmup}  Runs: {n_runs}");
    println!();

    println!("Creating model...");
    // NOTE(review): the constructor takes positional arguments; the meaning of
    // the unnamed literals below is defined by `FlexVisionTransformer::new` and
    // is not visible from this file — check its signature before changing them.
    let encoder = brainharmony::model::encoder::FlexVisionTransformer::<B>::new(
        (n_rois, signal_length),
        patch_size,
        1,
        embed_dim,
        depth,
        num_heads,
        4.0,
        true,
        1e-6,
        30,
        200,
        384,
        "sincos",
        false,
        false,
        &dev,
    ).expect("failed to create encoder");

    let x: Tensor<B, 4> = Tensor::random(
        [1, 1, n_rois, signal_length],
        burn::tensor::Distribution::Normal(0.0, 1.0),
        &dev,
    );

    // Extended warmup to fully cache all GPU pipeline compilations
    println!("Warming up ({n_warmup} runs to cache GPU pipelines)...");
    for i in 0..n_warmup {
        let t0 = Instant::now();
        let out = encoder.forward(x.clone(), None, None, None, None, None);
        // Read the result back so the elapsed time covers the whole pass.
        let _ = out.into_data();
        // Only print the first few and the last warmup timings.
        if i < 5 || i == n_warmup - 1 {
            let ms = t0.elapsed().as_secs_f64() * 1000.0;
            println!("  warmup {i}: {ms:.0}ms");
        }
    }

    // Shape probe. `dims()` alone does not read any data back, so on an
    // asynchronous backend this pass could still be queued when the timed loop
    // starts and would inflate the first sample. Reading the output back here
    // drains it before timing begins.
    let probe = encoder.forward(x.clone(), None, None, None, None, None);
    let shape = probe.dims();
    let _ = probe.into_data();
    println!("  Output shape: [{}, {}, {}]", shape[0], shape[1], shape[2]);

    // Timed runs — pipelines are fully cached now
    println!("Benchmarking ({n_runs} runs, pipelines cached)...");
    let mut times = Vec::with_capacity(n_runs);
    for _ in 0..n_runs {
        let t0 = Instant::now();
        let out = encoder.forward(x.clone(), None, None, None, None, None);
        let _ = out.into_data();
        times.push(t0.elapsed().as_secs_f64() * 1000.0);
    }

    let (best, median, avg) = timing_stats(&times);

    println!();
    println!("Results ({n_runs} runs):");
    println!("  Times (ms) : {}", times.iter().map(|t| format!("{t:.0}")).collect::<Vec<_>>().join(", "));
    println!("  Best       : {best:.0} ms");
    println!("  Median     : {median:.0} ms");
    println!("  Average    : {avg:.0} ms");
    println!("TIMING encode_best={best:.1}ms encode_median={median:.1}ms encode_avg={avg:.1}ms");
}