use std::time::Instant;
use burn::prelude::*;
use burn::tensor::activation::softmax;
use burn::nn::{Linear, LayerNorm, LayerNormConfig};
// GPU backend with f16 storage. This variant takes precedence when both the
// `wgpu` and `wgpu-f16` features are enabled: the previous gate
// `all(feature = "wgpu-f16", not(feature = "wgpu"))` combined with the
// `all(feature = "wgpu", not(feature = "wgpu-f16"))` gate below meant that
// enabling BOTH features left no `backend` module defined at all, breaking
// the build. A plain `feature = "wgpu-f16"` gate resolves that ambiguity in
// favor of f16.
#[cfg(feature = "wgpu-f16")]
mod backend {
    /// Wgpu backend parameterized with f16 floats and 32-bit int/index types.
    pub type B = burn::backend::wgpu::Wgpu<half::f16, i32, u32>;
    /// Returns the default wgpu device.
    pub fn device() -> burn::backend::wgpu::WgpuDevice { burn::backend::wgpu::WgpuDevice::DefaultDevice }
}
// f32 GPU backend, selected when only the plain `wgpu` feature is enabled.
#[cfg(all(feature = "wgpu", not(feature = "wgpu-f16")))]
mod backend {
    pub use burn::backend::wgpu::WgpuDevice;

    /// Wgpu backend with its default element types.
    pub type B = burn::backend::Wgpu;

    /// Returns the default wgpu device.
    pub fn device() -> WgpuDevice {
        WgpuDevice::DefaultDevice
    }
}
// CPU fallback backend, used when no wgpu feature is enabled.
#[cfg(not(any(feature = "wgpu", feature = "wgpu-f16")))]
mod backend {
    use burn::backend::ndarray::NdArrayDevice;

    /// NdArray (pure-CPU) backend with its default element types.
    pub type B = burn::backend::NdArray;

    /// Returns the CPU device.
    pub fn device() -> NdArrayDevice {
        NdArrayDevice::Cpu
    }
}
use backend::{B, device};
// Number of query rows handled per attention tile in `inlined_block`; the
// sequence dimension is chunked into slices of at most this many rows before
// the softmax(QK^T)V product.
const TILE: usize = 1024;
/// One pre-norm transformer block, fully inlined:
/// `h = x + proj(attention(norm1(x)))`, then `out = h + fc2(gelu(fc1(norm2(h))))`.
///
/// The attention score/softmax/value product is computed in query-row tiles of
/// at most `TILE` rows to bound the size of the intermediate score tensor.
///
/// * `x`      — input activations, shape `[batch, seq, embed]`.
/// * `w_qkv`/`b_qkv` — fused QKV projection (`[embed, 3*embed]` / `[3*embed]`).
/// * `w_proj`/`b_proj` — attention output projection.
/// * `w_fc1`/`b_fc1`, `w_fc2`/`b_fc2` — MLP weights.
/// * `scale`  — query scaling factor (applied to Q before the matmul).
/// * `heads`, `dh` — head count and per-head dimension (`heads * dh == embed`).
fn inlined_block(
    x: Tensor<B, 3>, norm1: &LayerNorm<B>,
    w_qkv: &Tensor<B, 2>, b_qkv: &Tensor<B, 1>, w_proj: &Tensor<B, 2>, b_proj: &Tensor<B, 1>, norm2: &LayerNorm<B>,
    w_fc1: &Tensor<B, 2>, b_fc1: &Tensor<B, 1>,
    w_fc2: &Tensor<B, 2>, b_fc2: &Tensor<B, 1>,
    scale: f32,
    heads: usize,
    dh: usize,
) -> Tensor<B, 3> {
    let [batch, seq, chan] = x.dims();

    // --- Attention half: x + proj(attn(norm1(x))) ---
    let normed = norm1.forward(x.clone()).reshape([batch * seq, chan]);
    let qkv = normed.matmul(w_qkv.clone()) + b_qkv.clone().unsqueeze_dim::<2>(0);
    // [batch, heads, seq, 3*dh]; each head owns a contiguous 3*dh column slice
    // (fine here since the weights are random; a trained checkpoint would need
    // its layout to match — TODO confirm if ever loaded from real weights).
    let qkv = qkv.reshape([batch, seq, heads, 3 * dh]).swap_dims(1, 2);
    let q = qkv.clone().narrow(3, 0, dh).mul_scalar(scale); // pre-scale queries
    let k_t = qkv.clone().narrow(3, dh, dh).transpose();
    let v = qkv.narrow(3, 2 * dh, dh);

    // Tile over query rows: each chunk computes softmax(Q_tile K^T) V.
    let tile_outputs: Vec<Tensor<B, 4>> = (0..seq)
        .step_by(TILE)
        .map(|start| {
            let len = TILE.min(seq - start);
            let q_tile = q.clone().narrow(2, start, len);
            softmax(q_tile.matmul(k_t.clone()), 3).matmul(v.clone())
        })
        .collect();
    let attn = Tensor::cat(tile_outputs, 2)
        .swap_dims(1, 2)
        .reshape([batch * seq, chan]);
    let attn = (attn.matmul(w_proj.clone()) + b_proj.clone().unsqueeze_dim::<2>(0))
        .reshape([batch, seq, chan]);
    let h = x + attn;

    // --- MLP half: h + fc2(gelu(fc1(norm2(h)))) ---
    let hidden = norm2.forward(h.clone()).reshape([batch * seq, chan]);
    let hidden = fast_gelu(hidden.matmul(w_fc1.clone()) + b_fc1.clone().unsqueeze_dim::<2>(0));
    let mlp = (hidden.matmul(w_fc2.clone()) + b_fc2.clone().unsqueeze_dim::<2>(0))
        .reshape([batch, seq, chan]);
    h + mlp
}
/// Tanh-based GELU approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`,
/// where `0.7978845608 ≈ sqrt(2/pi)`.
fn fast_gelu(x: Tensor<B, 2>) -> Tensor<B, 2> {
    let cubed = x.clone() * x.clone() * x.clone();
    let tanh_arg = (x.clone() + cubed.mul_scalar(0.044715f32)).mul_scalar(0.7978845608f32);
    x.mul_scalar(0.5f32) * (tanh_arg.tanh() + 1.0)
}
fn main() {
let d = device();
brainharmony::init_threads(None);
let seq = 7200usize;
let embed = 768usize;
let heads = 12usize;
let dh = embed / heads;
let mlp_h = embed * 4;
let scale = (dh as f32).powf(-0.5);
let warmup = 5;
let runs = 10;
println!("Fused block benchmark (wgpu f16)\n");
let norm1 = LayerNormConfig::new(embed).with_epsilon(1e-6).init::<B>(&d);
let norm2 = LayerNormConfig::new(embed).with_epsilon(1e-6).init::<B>(&d);
let w_qkv: Tensor<B, 2> = Tensor::random([embed, 3*embed], burn::tensor::Distribution::Normal(0.0, 0.01), &d);
let b_qkv: Tensor<B, 1> = Tensor::zeros([3*embed], &d);
let w_proj: Tensor<B, 2> = Tensor::random([embed, embed], burn::tensor::Distribution::Normal(0.0, 0.01), &d);
let b_proj: Tensor<B, 1> = Tensor::zeros([embed], &d);
let w_fc1: Tensor<B, 2> = Tensor::random([embed, mlp_h], burn::tensor::Distribution::Normal(0.0, 0.01), &d);
let b_fc1: Tensor<B, 1> = Tensor::zeros([mlp_h], &d);
let w_fc2: Tensor<B, 2> = Tensor::random([mlp_h, embed], burn::tensor::Distribution::Normal(0.0, 0.01), &d);
let b_fc2: Tensor<B, 1> = Tensor::zeros([embed], &d);
let x: Tensor<B, 3> = Tensor::random([1, seq, embed], burn::tensor::Distribution::Normal(0.0, 1.0), &d);
for _ in 0..warmup {
let out = inlined_block(
x.clone(), &norm1, &w_qkv, &b_qkv, &w_proj, &b_proj,
&norm2, &w_fc1, &b_fc1, &w_fc2, &b_fc2, scale, heads, dh,
);
let _ = out.into_data();
}
let mut times = Vec::new();
for _ in 0..runs {
let t0 = Instant::now();
let out = inlined_block(
x.clone(), &norm1, &w_qkv, &b_qkv, &w_proj, &b_proj,
&norm2, &w_fc1, &b_fc1, &w_fc2, &b_fc2, scale, heads, dh,
);
let _ = out.into_data();
times.push(t0.elapsed().as_secs_f64() * 1000.0);
}
let best = times.iter().cloned().fold(f64::INFINITY, f64::min);
let med = { let mut s = times.clone(); s.sort_by(|a,b| a.partial_cmp(b).unwrap()); s[s.len()/2] };
let block = brainharmony::model::block::Block::<B>::new(embed, heads, 4.0, true, 1e-6, &d);
for _ in 0..warmup {
let _ = block.forward(x.clone(), None).into_data();
}
let mut times_std = Vec::new();
for _ in 0..runs {
let t0 = Instant::now();
let out = block.forward(x.clone(), None);
let _ = out.into_data();
times_std.push(t0.elapsed().as_secs_f64() * 1000.0);
}
let best_std = times_std.iter().cloned().fold(f64::INFINITY, f64::min);
let med_std = { let mut s = times_std.clone(); s.sort_by(|a,b| a.partial_cmp(b).unwrap()); s[s.len()/2] };
println!("Standard Block: best={best_std:.0}ms med={med_std:.0}ms (x12 = {:.0}ms)", best_std * 12.0);
println!("Inlined Block: best={best:.0}ms med={med:.0}ms (x12 = {:.0}ms)", best * 12.0);
println!("Speedup: {:.2}x", best_std / best);
}