use std::time::Instant;
use burn::prelude::*;
// Backend selection, f16 variant. NOTE: this gate intentionally does NOT
// exclude `feature = "wgpu"` — if both `wgpu-f16` and `wgpu` are enabled,
// f16 wins. (With the old `all(wgpu-f16, not(wgpu))` gate, enabling both
// features left NO `backend` module defined and the crate failed to build.)
#[cfg(feature = "wgpu-f16")]
mod backend {
    /// wgpu backend using half-precision (f16) floats.
    pub type B = burn::backend::wgpu::Wgpu<half::f16, i32, u32>;
    /// Returns the default wgpu device.
    pub fn device() -> burn::backend::wgpu::WgpuDevice {
        burn::backend::wgpu::WgpuDevice::DefaultDevice
    }
}
// Backend selection: standard-precision (f32) wgpu backend.
#[cfg(all(feature = "wgpu", not(feature = "wgpu-f16")))]
mod backend {
    pub use burn::backend::wgpu::WgpuDevice;

    /// wgpu backend with default element types (f32 floats).
    pub type B = burn::backend::Wgpu;

    /// Returns the default wgpu device.
    pub fn device() -> WgpuDevice {
        WgpuDevice::DefaultDevice
    }
}
// Backend selection: CPU fallback when no wgpu feature is enabled.
#[cfg(not(any(feature = "wgpu", feature = "wgpu-f16")))]
mod backend {
    /// CPU (ndarray) backend.
    pub type B = burn::backend::NdArray;

    /// Returns the CPU device.
    pub fn device() -> burn::backend::ndarray::NdArrayDevice {
        burn::backend::ndarray::NdArrayDevice::Cpu
    }
}
use backend::{B, device};
/// Runs `f` `warmup` times untimed, then `runs` timed passes, and returns the
/// best (minimum) wall-clock time in milliseconds. Minimum-of-N is used rather
/// than the mean so one-off scheduler hiccups don't skew the result.
fn bench_best_ms(warmup: usize, runs: usize, mut f: impl FnMut()) -> f64 {
    for _ in 0..warmup {
        f();
    }
    let mut best = f64::INFINITY;
    for _ in 0..runs {
        let t0 = Instant::now();
        f();
        best = best.min(t0.elapsed().as_secs_f64() * 1000.0);
    }
    best
}

/// Micro-benchmarks three QKV split strategies, a large 4-D transpose, and a
/// full transformer block forward, printing best-of-N timings for each.
fn main() {
    let d = device();
    brainharmony::init_threads(None);

    // Attention shapes: 12 heads of 64 dims (embed = 768) at a long sequence.
    let seq = 7200usize;
    let embed = 768usize;
    let heads = 12usize;
    let dh = embed / heads; // per-head dimension
    let warmup = 3;
    let runs = 5;
    println!("=== Overhead analysis ===\n");

    // Fused QKV activations: [batch=1, seq, 3*embed].
    let qkv: Tensor<B, 3> = Tensor::random(
        [1, seq, 3 * embed],
        burn::tensor::Distribution::Normal(0.0, 1.0),
        &d,
    );

    // Method A: expose a "3" axis, narrow it per Q/K/V, then reshape + swap.
    // `into_data()` forces execution so the timing covers the actual work,
    // not just lazy graph construction.
    let best_a = bench_best_ms(warmup, runs, || {
        let qkv2 = qkv.clone().reshape([1, seq, 3, heads, dh]);
        let q = qkv2.clone().narrow(2, 0, 1).reshape([1, seq, heads, dh]).swap_dims(1, 2);
        let k = qkv2.clone().narrow(2, 1, 1).reshape([1, seq, heads, dh]).swap_dims(1, 2);
        let v = qkv2.narrow(2, 2, 1).reshape([1, seq, heads, dh]).swap_dims(1, 2);
        let _ = (q + k + v).into_data();
    });

    // Method B: reshape + swap first, then narrow along the last axis.
    let best_b = bench_best_ms(warmup, runs, || {
        let qkv2 = qkv.clone().reshape([1, seq, heads, 3 * dh]).swap_dims(1, 2);
        let q = qkv2.clone().narrow(3, 0, dh);
        let k = qkv2.clone().narrow(3, dh, dh);
        let v = qkv2.narrow(3, 2 * dh, dh);
        let _ = (q + k + v).into_data();
    });

    // Method C: three separate projection matmuls instead of a fused QKV
    // tensor (weights are shared clones; only the op pattern matters here).
    let wq: Tensor<B, 2> = Tensor::random([embed, embed], burn::tensor::Distribution::Normal(0.0, 0.01), &d);
    let wk = wq.clone();
    let wv = wq.clone();
    let x2: Tensor<B, 2> = Tensor::random([seq, embed], burn::tensor::Distribution::Normal(0.0, 1.0), &d);
    let best_c = bench_best_ms(warmup, runs, || {
        let q = x2.clone().matmul(wq.clone()).reshape([1, heads, seq, dh]);
        let k = x2.clone().matmul(wk.clone()).reshape([1, heads, seq, dh]);
        let v = x2.clone().matmul(wv.clone()).reshape([1, heads, seq, dh]);
        let _ = (q + k + v).into_data();
    });

    println!("QKV split methods:");
    println!(" A: narrow(dim2) + reshape + swap : {best_a:.1}ms");
    println!(" B: reshape + swap + narrow(dim3) : {best_b:.1}ms");
    println!(" C: 3 separate matmuls + reshape : {best_c:.1}ms");

    // Cost of materializing K^T (as used when forming attention scores).
    let k4: Tensor<B, 4> = Tensor::random(
        [1, heads, seq, dh],
        burn::tensor::Distribution::Normal(0.0, 1.0),
        &d,
    );
    let best_t = bench_best_ms(warmup, runs, || {
        let _ = k4.clone().transpose().into_data();
    });
    println!("\nTranspose [1,12,7200,64] -> [1,12,64,7200]: {best_t:.1}ms");

    // Full transformer block forward, to compare against the sum of per-op
    // measurements and expose framework overhead.
    let block = brainharmony::model::block::Block::<B>::new(embed, heads, 4.0, true, 1e-6, &d);
    let x3: Tensor<B, 3> = Tensor::random(
        [1, seq, embed],
        burn::tensor::Distribution::Normal(0.0, 1.0),
        &d,
    );
    let best_block = bench_best_ms(warmup, runs, || {
        let _ = block.forward(x3.clone(), None).into_data();
    });
    println!("\nFull Block forward: {best_block:.1}ms (x12 = {:.0}ms)", best_block * 12.0);
    // NOTE(review): the ~300ms figure below is a hard-coded estimate from a
    // separate per-op measurement, not computed by this program — confirm it
    // is still current before trusting the overhead line.
    println!(" Per-op estimate: ~300ms (x12 = ~3600ms)");
    println!(" Overhead per block: {:.1}ms", best_block - 300.0);
}