use std::time::Instant;
use burn::prelude::*;
use burn::backend::NdArray;
/// Backend type used throughout this benchmark: burn's `NdArray` backend.
type B = NdArray;
/// Benchmarks the forward pass of the Brain-Harmony encoder on the NdArray backend.
///
/// Steps:
/// 1. Print the benchmark configuration.
/// 2. Build the encoder and a random input of shape `[1, 1, n_rois, signal_length]`.
/// 3. Run `n_warmup` untimed forward passes.
/// 4. Time `n_runs` forward passes and print every time, the best and the
///    average in milliseconds, followed by a one-line `TIMING` summary.
///
/// # Panics
///
/// Panics if the encoder cannot be constructed.
fn main() {
    let device = burn::backend::ndarray::NdArrayDevice::Cpu;
    let n_threads = brainharmony::init_threads(None);

    // Model and input configuration.
    let embed_dim = 768usize;
    let depth = 12usize;
    let num_heads = 12usize;
    let patch_size = 48usize;
    let n_rois = 400usize;
    let n_time_patches = 18usize;
    let signal_length = n_time_patches * patch_size;

    // Benchmark configuration.
    let n_warmup = 3usize;
    let n_runs = 10usize;

    println!("Brain-Harmony Rust Benchmark");
    println!(" Backend : NdArray ({n_threads} threads)");
    println!(" Model : embed={embed_dim} depth={depth} heads={num_heads}");
    println!(" Patch size : {patch_size}");
    println!(" Input : [1, 1, {n_rois}, {signal_length}]");
    println!(" Seq length : {} patches", n_rois * n_time_patches);
    println!(" Warmup : {n_warmup} Runs: {n_runs}");
    println!();

    println!("Creating model...");
    // The trailing positional arguments are passed through unchanged; see
    // `FlexVisionTransformer::new` for what each of them configures.
    let encoder = brainharmony::model::encoder::FlexVisionTransformer::<B>::new(
        (n_rois, signal_length),
        patch_size,
        1,
        embed_dim,
        depth,
        num_heads,
        4.0,
        true,
        1e-6,
        30,
        200,
        384,
        "sincos",
        false,
        false,
        &device,
    )
    .expect("failed to create encoder");

    let x: Tensor<B, 4> = Tensor::random(
        [1, 1, n_rois, signal_length],
        burn::tensor::Distribution::Normal(0.0, 1.0),
        &device,
    );

    println!("Warming up...");
    for _ in 0..n_warmup {
        // `black_box` keeps the optimizer from treating the unused result as dead work.
        let _ = std::hint::black_box(encoder.forward(x.clone(), None, None, None, None, None));
    }

    println!("Benchmarking...");
    let mut times = Vec::with_capacity(n_runs);
    for i in 0..n_runs {
        // Clone the input before starting the clock so only `forward` is timed.
        let input = x.clone();
        let t0 = Instant::now();
        let out = std::hint::black_box(encoder.forward(input, None, None, None, None, None));
        let ms = t0.elapsed().as_secs_f64() * 1000.0;
        if i == 0 {
            // Report the output shape once, on the first timed run.
            let [d0, d1, d2] = out.dims();
            println!(" Output shape: [{d0}, {d1}, {d2}]");
        }
        times.push(ms);
    }

    let best = times.iter().copied().fold(f64::INFINITY, f64::min);
    let avg: f64 = times.iter().sum::<f64>() / times.len() as f64;
    let formatted_times = times
        .iter()
        .map(|t| format!("{t:.1}"))
        .collect::<Vec<_>>()
        .join(", ");

    println!();
    println!("Results ({n_runs} runs):");
    println!(" Times (ms) : {formatted_times}");
    println!(" Best : {best:.1} ms");
    println!(" Average : {avg:.1} ms");
    println!("TIMING encode_best={best:.1}ms encode_avg={avg:.1}ms");
}