brainharmony 0.1.0

Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML
Documentation
//! Benchmark Brain-Harmony FlexVisionTransformer encoder (Rust/Burn).
//!
//! Runs the encoder forward pass on random input data (no checkpoint is
//! loaded; weight values do not affect timing) and times it for a fair
//! comparison against the Python implementation.
//!
//! ```sh
//! cargo run --example bench --release --features accelerate
//! ```
use std::time::Instant;
use burn::prelude::*;
use burn::backend::NdArray;

/// Backend under test: Burn's `NdArray` backend, driven on the CPU device below.
type B = NdArray;

/// Benchmark entry point.
///
/// Builds the encoder, runs `n_warmup` untimed forward passes, then `n_runs`
/// timed ones. It prints the per-run, best and average wall-clock times in
/// milliseconds.
///
/// The final `TIMING encode_best=…ms encode_avg=…ms` line keeps a fixed,
/// single-line format. It looks intended for machine parsing by the Python
/// comparison side (NOTE(review): confirm), so do not change its shape.
///
/// # Panics
///
/// Panics if the encoder cannot be constructed, or if `n_runs` is edited to 0
/// (the statistics below divide by the number of runs).
fn main() {
    let device = burn::backend::ndarray::NdArrayDevice::Cpu;
    let n_threads = brainharmony::init_threads(None);

    let embed_dim = 768usize;
    let depth = 12usize;
    let num_heads = 12usize;
    let patch_size = 48usize;
    let n_rois = 400usize;
    let n_time_patches = 18usize;
    let signal_length = n_time_patches * patch_size; // 864
    let n_warmup = 3usize;
    let n_runs = 10usize;
    // Guard against someone setting n_runs to 0: avg would become NaN and best +inf.
    assert!(n_runs > 0, "n_runs must be at least 1");

    println!("Brain-Harmony Rust Benchmark");
    println!("  Backend    : NdArray ({n_threads} threads)");
    println!("  Model      : embed={embed_dim} depth={depth} heads={num_heads}");
    println!("  Patch size : {patch_size}");
    println!("  Input      : [1, 1, {n_rois}, {signal_length}]");
    println!("  Seq length : {} patches", n_rois * n_time_patches);
    println!("  Warmup     : {n_warmup}  Runs: {n_runs}");
    println!();

    // Build the encoder straight from its constructor (no checkpoint is loaded);
    // weight values don't affect timing.
    println!("Creating model...");
    let encoder = brainharmony::model::encoder::FlexVisionTransformer::<B>::new(
        (n_rois, signal_length),
        patch_size,
        1, // in_chans
        embed_dim,
        depth,
        num_heads,
        4.0,   // mlp_ratio
        true,  // qkv_bias
        1e-6,  // norm_eps
        30,    // grad_dim (unused for sincos mode)
        200,   // geoh_dim (unused for sincos mode)
        384,   // pred_embed_dim
        "sincos",
        false, // no cls token
        false, // no decoder
        &device,
    )
    .expect("failed to create encoder");

    // Create random input
    let x: Tensor<B, 4> = Tensor::random(
        [1, 1, n_rois, signal_length],
        burn::tensor::Distribution::Normal(0.0, 1.0),
        &device,
    );

    // Warmup. black_box marks the result as used so the pass cannot be elided.
    println!("Warming up...");
    for _ in 0..n_warmup {
        std::hint::black_box(encoder.forward(x.clone(), None, None, None, None, None));
    }

    // Timed runs. Only the forward pass is inside the timed region; printing
    // and dropping the output happen after `elapsed()` is read.
    println!("Benchmarking...");
    let mut times = Vec::with_capacity(n_runs);
    for run in 0..n_runs {
        let t0 = Instant::now();
        let out = std::hint::black_box(encoder.forward(x.clone(), None, None, None, None, None));
        let ms = t0.elapsed().as_secs_f64() * 1000.0;
        if run == 0 {
            let [d0, d1, d2] = out.dims();
            println!("  Output shape: [{d0}, {d1}, {d2}]");
        }
        times.push(ms);
    }

    let best = times.iter().copied().fold(f64::INFINITY, f64::min);
    let avg = times.iter().sum::<f64>() / times.len() as f64;

    println!();
    println!("Results ({n_runs} runs):");
    println!(
        "  Times (ms) : {}",
        times.iter().map(|t| format!("{t:.1}")).collect::<Vec<_>>().join(", ")
    );
    println!("  Best       : {best:.1} ms");
    println!("  Average    : {avg:.1} ms");
    println!("TIMING encode_best={best:.1}ms encode_avg={avg:.1}ms");
}