use std::time::Instant;
use burn::prelude::*;
#[cfg(any(feature = "wgpu", feature = "wgpu-f16"))]
use brainharmony::model::flash_attn::gpu::flash_attention_tensor;
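// Backend selection: the `wgpu-f16` feature picks an f16 wgpu backend, `wgpu`
// picks the default f32 wgpu backend, and with neither feature we fall back to
// the CPU NdArray backend (in which case `main` below exits early).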
#[cfg(all(feature = "wgpu-f16", not(feature = "wgpu")))]
mod backend {
    pub type B = burn::backend::wgpu::Wgpu<half::f16, i32, u32>;
    pub fn device() -> burn::backend::wgpu::WgpuDevice { burn::backend::wgpu::WgpuDevice::DefaultDevice }
    pub const NAME: &str = "wgpu f16 (Metal)";
}
#[cfg(all(feature = "wgpu", not(feature = "wgpu-f16")))]
mod backend {
    pub use burn::backend::{Wgpu as B, wgpu::WgpuDevice};
    pub fn device() -> WgpuDevice { WgpuDevice::DefaultDevice }
    pub const NAME: &str = "wgpu f32 (Metal)";
}
#[cfg(not(any(feature = "wgpu", feature = "wgpu-f16")))]
mod backend {
    pub use burn::backend::NdArray as B;
    pub fn device() -> burn::backend::ndarray::NdArrayDevice { burn::backend::ndarray::NdArrayDevice::Cpu }
    pub const NAME: &str = "NdArray";
}
use backend::{B, device};
#[cfg(not(any(feature = "wgpu", feature = "wgpu-f16")))]
fn main() {
    eprintln!("bench_flash requires --features wgpu or wgpu-f16");
    std::process::exit(1);
}
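// Typical invocation (assuming this file is built as the `bench_flash` binary
// target, as the error message above suggests):
//   cargo run --release --features wgpu --bin bench_flash
//   cargo run --release --features wgpu-f16 --bin bench_flash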
#[cfg(any(feature = "wgpu", feature = "wgpu-f16"))]
fn main() {
    let dev = device();
    let n_threads = brainharmony::init_threads(None);
    let embed_dim = 768usize;
    let num_heads = 12usize;
    let head_dim = embed_dim / num_heads;
    let seq = 7200usize;
    let n_warmup = 5usize;
    let n_runs = 10usize;
    println!("Flash Attention Benchmark (non-causal, cubek)");
    println!(" Backend    : {} ({n_threads} CPU threads)", backend::NAME);
    println!(" Heads      : {num_heads}  head_dim: {head_dim}");
    println!(" Seq length : {seq}");
    println!(" Shape      : [1, {num_heads}, {seq}, {head_dim}]");
    println!(" Warmup     : {n_warmup}  Runs: {n_runs}");
    println!();
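    // Random Q, K, V tensors with shape [batch=1, num_heads, seq, head_dim].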
    let q: Tensor<B, 4> = Tensor::random([1, num_heads, seq, head_dim],
        burn::tensor::Distribution::Normal(0.0, 1.0), &dev);
    let k: Tensor<B, 4> = Tensor::random([1, num_heads, seq, head_dim],
        burn::tensor::Distribution::Normal(0.0, 1.0), &dev);
    let v: Tensor<B, 4> = Tensor::random([1, num_heads, seq, head_dim],
        burn::tensor::Distribution::Normal(0.0, 1.0), &dev);
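    // Warmup runs: `into_data()` reads the result back to the host, so each
    // iteration waits for the device work to finish before the next one starts.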
println!("Warming up...");
for _ in 0..n_warmup {
let out = flash_attention_tensor::<B, _>(q.clone(), k.clone(), v.clone());
let _ = out.into_data();
}
println!("Benchmarking flash attention...");
let mut times = Vec::with_capacity(n_runs);
for _ in 0..n_runs {
let t0 = Instant::now();
let out = flash_attention_tensor::<B, _>(q.clone(), k.clone(), v.clone());
let _ = out.into_data(); let ms = t0.elapsed().as_secs_f64() * 1000.0;
times.push(ms);
}
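    // Summary statistics over the timed runs: best (min), median, and mean.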
    let best = times.iter().cloned().fold(f64::INFINITY, f64::min);
    let avg: f64 = times.iter().sum::<f64>() / times.len() as f64;
    let median = {
        let mut s = times.clone();
        s.sort_by(|a, b| a.partial_cmp(b).unwrap());
        s[s.len() / 2]
    };
    println!();
    println!("Results ({n_runs} runs):");
    println!(" Times (ms) : {}", times.iter().map(|t| format!("{t:.0}")).collect::<Vec<_>>().join(", "));
    println!(" Best       : {best:.0} ms");
    println!(" Median     : {median:.0} ms");
    println!(" Average    : {avg:.0} ms");
    println!();
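    // Naive baseline: materializes the full [seq, seq] attention matrix
    // (scores = Q·Kᵀ / sqrt(head_dim), softmax over the last dim, then ·V),
    // so its memory footprint grows quadratically with sequence length.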
println!("Benchmarking naive attention (matmul->softmax->matmul)...");
for _ in 0..n_warmup {
let scores = q.clone().matmul(k.clone().transpose()).mul_scalar(1.0 / (head_dim as f32).sqrt());
let attn = burn::tensor::activation::softmax(scores, 3);
let out = attn.matmul(v.clone());
let _ = out.into_data();
}
    let mut naive_times = Vec::with_capacity(n_runs);
    for _ in 0..n_runs {
        let t0 = Instant::now();
        let scores = q.clone().matmul(k.clone().transpose()).mul_scalar(1.0 / (head_dim as f32).sqrt());
        let attn = burn::tensor::activation::softmax(scores, 3);
        let out = attn.matmul(v.clone());
        let _ = out.into_data();
        let ms = t0.elapsed().as_secs_f64() * 1000.0;
        naive_times.push(ms);
    }
    let naive_best = naive_times.iter().cloned().fold(f64::INFINITY, f64::min);
    let naive_avg: f64 = naive_times.iter().sum::<f64>() / naive_times.len() as f64;
    println!(" Times (ms) : {}", naive_times.iter().map(|t| format!("{t:.0}")).collect::<Vec<_>>().join(", "));
    println!(" Best       : {naive_best:.0} ms");
    println!(" Average    : {naive_avg:.0} ms");
    println!();
println!("Speedup (flash / naive): {:.2}x", naive_best / best);
}