use std::time::Instant;
use burn::prelude::*;
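// Compile-time backend selection: the `wgpu-f16` feature picks the wgpu
// backend with f16 floats, `wgpu` picks wgpu with its default f32 floats,
// and with neither feature enabled we fall back to the CPU NdArray backend.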
#[cfg(all(feature = "wgpu-f16", not(feature = "wgpu")))]
mod backend {
pub type B = burn::backend::wgpu::Wgpu<half::f16, i32, u32>;
pub fn device() -> burn::backend::wgpu::WgpuDevice { burn::backend::wgpu::WgpuDevice::DefaultDevice }
}
#[cfg(all(feature = "wgpu", not(feature = "wgpu-f16")))]
mod backend {
pub use burn::backend::{Wgpu as B, wgpu::WgpuDevice};
pub fn device() -> WgpuDevice { WgpuDevice::DefaultDevice }
}
#[cfg(not(any(feature = "wgpu", feature = "wgpu-f16")))]
mod backend {
pub use burn::backend::NdArray as B;
pub fn device() -> burn::backend::ndarray::NdArrayDevice { burn::backend::ndarray::NdArrayDevice::Cpu }
}
use backend::{B, device};
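/// Runs `f` `warmup` times untimed, then times `runs` calls and prints the
/// best and median wall-clock time in milliseconds; returns the best time.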
fn bench<F: FnMut()>(label: &str, warmup: usize, runs: usize, mut f: F) -> f64 {
    for _ in 0..warmup {
        f();
    }
    let mut t = Vec::new();
    for _ in 0..runs {
        let t0 = Instant::now();
        f();
        t.push(t0.elapsed().as_secs_f64() * 1000.0);
    }
    let best = t.iter().cloned().fold(f64::INFINITY, f64::min);
    let med = {
        let mut s = t.clone();
        s.sort_by(|a, b| a.partial_cmp(b).unwrap());
        s[s.len() / 2]
    };
    println!(" {label:45} best={best:>7.1}ms med={med:>7.1}ms");
    best
}
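// Small helper (an addition, not part of brainharmony's API): reports which
// backend was compiled in, using `cfg!` on the same feature gates as the
// `backend` modules above, so the banner below stays accurate under any
// feature combination.
fn backend_name() -> &'static str {
    if cfg!(all(feature = "wgpu-f16", not(feature = "wgpu"))) {
        "wgpu f16"
    } else if cfg!(feature = "wgpu") {
        "wgpu f32"
    } else {
        "ndarray cpu"
    }
}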
fn main() {
    let d = device();
    brainharmony::init_threads(None);
    println!("2D vs 3D matmul benchmark ({})\n", backend_name());
    // Weights for a 768 -> 2304 projection and a batch of 7200 activation
    // rows; the 3D input is the same data with a leading batch dim of 1.
    let w: Tensor<B, 2> = Tensor::random([768, 2304], burn::tensor::Distribution::Normal(0.0, 0.01), &d);
    let x2: Tensor<B, 2> = Tensor::random([7200, 768], burn::tensor::Distribution::Normal(0.0, 1.0), &d);
    let x3: Tensor<B, 3> = x2.clone().unsqueeze_dim::<3>(0);
    // Each closure ends in `into_data()`, which reads the result back to the
    // host and blocks until the device finishes, so the timings include
    // kernel execution rather than just dispatch.
    let t2 = bench("2D: [7200,768] @ [768,2304]", 20, 20, || {
        let _ = x2.clone().matmul(w.clone()).into_data();
    });
    let t3 = bench("3D: [1,7200,768] @ [1,768,2304] (unsqueeze)", 20, 20, || {
        let w3 = w.clone().unsqueeze_dim::<3>(0);
        let _ = x3.clone().matmul(w3).into_data();
    });
    let linear = brainharmony::model::linear_zeros::<B>(768, 2304, true, &d);
    let t_lin = bench("Linear::forward [1,7200,768]", 20, 20, || {
        let _ = linear.forward(x3.clone()).into_data();
    });
    // A 12-block MLP-style chain written with raw 2D matmuls, reshaping the
    // batch dimension away before each matmul and restoring it afterwards.
    let norm = burn::nn::LayerNormConfig::new(768).with_epsilon(1e-6).init::<B>(&d);
    let w_fc2: Tensor<B, 2> = Tensor::random([2304, 768], burn::tensor::Distribution::Normal(0.0, 0.01), &d);
    let t_chain_2d = bench("12x (LN + 2D matmul + gelu + 2D matmul)", 10, 10, || {
        let mut t = x3.clone();
        for _ in 0..12 {
            let tn = norm.forward(t).reshape([7200, 768]);
            let h = burn::tensor::activation::gelu(tn.matmul(w.clone()));
            t = h.matmul(w_fc2.clone()).reshape([1, 7200, 768]);
        }
        let _ = t.into_data();
    });
    // The same 12-block chain through Linear modules. `linear2` is built once,
    // outside the timed closure, so module construction cost is excluded from
    // the measurement, matching how `w` and `w_fc2` are handled above.
    let linear2 = brainharmony::model::linear_zeros::<B>(2304, 768, true, &d);
    let t_chain_3d = bench("12x (LN + Linear + gelu + Linear)", 10, 10, || {
        let mut t = x3.clone();
        for _ in 0..12 {
            let h = burn::tensor::activation::gelu(linear.forward(norm.forward(t)));
            t = linear2.forward(h);
        }
        let _ = t.into_data();
    });
    println!();
    println!("3D / 2D matmul time:      {:.2}x", t3 / t2);
    println!("Linear / 2D matmul time:  {:.2}x", t_lin / t2);
    println!("3D chain / 2D chain time: {:.2}x", t_chain_3d / t_chain_2d);
}
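
// A minimal sanity check (a sketch added alongside the benchmark, not part of
// the original): with a batch dimension of 1, the unsqueeze-based 3D matmul
// should match the plain 2D matmul numerically, which is what makes the
// timing comparison above meaningful. The 1e-2 tolerance is an assumption
// chosen loosely enough for the f16 backend.
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::ElementConversion;

    #[test]
    fn matmul_2d_and_3d_agree() {
        let d = device();
        let a: Tensor<B, 2> = Tensor::random([4, 8], burn::tensor::Distribution::Normal(0.0, 1.0), &d);
        let b: Tensor<B, 2> = Tensor::random([8, 3], burn::tensor::Distribution::Normal(0.0, 1.0), &d);
        let r2 = a.clone().matmul(b.clone());
        let r3 = a.unsqueeze_dim::<3>(0).matmul(b.unsqueeze_dim::<3>(0));
        // Max absolute element-wise difference between the two results.
        let diff: f32 = (r2 - r3.squeeze::<2>(0)).abs().max().into_scalar().elem();
        assert!(diff < 1e-2, "2D and 3D matmul diverged: max abs diff = {diff}");
    }
}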