//! Standalone benchmark for the Brain-Harmony encoder: builds a
//! FlexVisionTransformer, runs warmup passes so GPU pipelines get compiled
//! and cached, then reports per-forward timings over a fixed number of runs.
use std::time::Instant;

use burn::prelude::*;
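
// Compile-time backend selection: Cargo features pick exactly one `backend`
// module below (wgpu with f16 storage, wgpu with f32, or the NdArray CPU
// fallback). Each module exposes the same surface: `B`, `Device`,
// `device()`, and `NAME`.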
#[cfg(all(feature = "wgpu-f16", not(feature = "wgpu")))]
mod backend {
    pub type B = burn::backend::wgpu::Wgpu<half::f16, i32, u32>;
    pub type Device = burn::backend::wgpu::WgpuDevice;
    pub fn device() -> Device { Device::DefaultDevice }
    pub const NAME: &str = "wgpu f16 (Metal)";
}
#[cfg(all(feature = "wgpu", not(feature = "wgpu-f16")))]
mod backend {
    pub type B = burn::backend::Wgpu;
    pub type Device = burn::backend::wgpu::WgpuDevice;
    pub fn device() -> Device { Device::DefaultDevice }
    pub const NAME: &str = "wgpu f32 (Metal)";
}
#[cfg(not(any(feature = "wgpu", feature = "wgpu-f16")))]
mod backend {
    pub type B = burn::backend::NdArray;
    pub type Device = burn::backend::ndarray::NdArrayDevice;
    pub fn device() -> Device { Device::Cpu }
    pub const NAME: &str = "NdArray (CPU) -- use --features wgpu for GPU";
}
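
// If both `wgpu` and `wgpu-f16` are enabled at once, none of the cfg-gated
// modules above compiles and `backend` is undefined. A minimal guard
// (sketch) turns that into a readable error instead of an unresolved-module
// message:
#[cfg(all(feature = "wgpu", feature = "wgpu-f16"))]
compile_error!("Enable at most one of the `wgpu` and `wgpu-f16` features.");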
use backend::{B, device};
fn main() {
    let dev = device();
    let n_threads = brainharmony::init_threads(None);

    // ViT-Base-sized encoder: 768-dim embeddings, 12 blocks, 12 heads.
    let embed_dim = 768usize;
    let depth = 12usize;
    let num_heads = 12usize;
    // 400 ROIs x 864 time samples, split into 48-sample patches
    // (18 per ROI, 7200 patches total).
    let patch_size = 48usize;
    let n_rois = 400usize;
    let n_time_patches = 18usize;
    let signal_length = n_time_patches * patch_size;
    let n_warmup = 20usize;
    let n_runs = 20usize;
println!("Brain-Harmony Rust GPU Benchmark");
println!(" Backend : {} ({n_threads} CPU threads)", backend::NAME);
println!(" Model : embed={embed_dim} depth={depth} heads={num_heads}");
println!(" Patch size : {patch_size}");
println!(" Input : [1, 1, {n_rois}, {signal_length}]");
println!(" Seq length : {} patches", n_rois * n_time_patches);
println!(" Warmup : {n_warmup} Runs: {n_runs}");
println!();
println!("Creating model...");
    let encoder = brainharmony::model::encoder::FlexVisionTransformer::<B>::new(
        (n_rois, signal_length), // input size (ROIs x time samples)
        patch_size,
        1, // input channels, matching the [1, 1, ...] input below
        embed_dim,
        depth,
        num_heads,
        4.0,  // presumably the MLP expansion ratio (standard ViT convention)
        true,
        1e-6, // presumably the LayerNorm epsilon
        30,
        200,
        384,
        "sincos", // presumably the positional-embedding scheme
        false,
        false,
        &dev,
    ).expect("failed to create encoder");
    // Random stand-in input: [batch=1, channels=1, n_rois, signal_length].
    let x: Tensor<B, 4> = Tensor::random(
        [1, 1, n_rois, signal_length],
        burn::tensor::Distribution::Normal(0.0, 1.0),
        &dev,
    );
println!("Warming up ({n_warmup} runs to cache GPU pipelines)...");
for i in 0..n_warmup {
let t0 = Instant::now();
let out = encoder.forward(x.clone(), None, None, None, None, None);
let _ = out.into_data();
if i < 5 || i == n_warmup - 1 {
let ms = t0.elapsed().as_secs_f64() * 1000.0;
println!(" warmup {i}: {ms:.0}ms");
}
}
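
    // One extra, un-timed forward pass just to report the output shape.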
    let shape = encoder.forward(x.clone(), None, None, None, None, None).dims();
    println!("  Output shape: [{}, {}, {}]", shape[0], shape[1], shape[2]);

    println!("Benchmarking ({n_runs} runs, pipelines cached)...");
    let mut times = Vec::with_capacity(n_runs);
    for _ in 0..n_runs {
        let t0 = Instant::now();
        let out = encoder.forward(x.clone(), None, None, None, None, None);
        let _ = out.into_data();
        times.push(t0.elapsed().as_secs_f64() * 1000.0);
    }
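
    // Summary statistics; with an even run count this takes the upper of
    // the two middle values as the median, which is fine for reporting.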
    let best = times.iter().cloned().fold(f64::INFINITY, f64::min);
    let avg: f64 = times.iter().sum::<f64>() / times.len() as f64;
    let median = {
        let mut sorted = times.clone();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        sorted[sorted.len() / 2]
    };
    println!();
    println!("Results ({n_runs} runs):");
    println!("  Times (ms) : {}", times.iter().map(|t| format!("{t:.0}")).collect::<Vec<_>>().join(", "));
    println!("  Best       : {best:.0} ms");
    println!("  Median     : {median:.0} ms");
    println!("  Average    : {avg:.0} ms");
    println!("TIMING encode_best={best:.1}ms encode_median={median:.1}ms encode_avg={avg:.1}ms");
}