use std::time::Instant;
use burn::prelude::*;
use burn::tensor::activation::softmax;
use burn::nn::LayerNormConfig;
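// Backend selection via Cargo features: wgpu with an f16 element type, wgpu
// with the default f32, or the ndarray CPU fallback when neither GPU feature
// is enabled. Each module exposes a backend alias `B` and a `device()`
// constructor so the rest of the file stays backend-agnostic.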
#[cfg(all(feature = "wgpu-f16", not(feature = "wgpu")))]
mod backend {
pub type B = burn::backend::wgpu::Wgpu<half::f16, i32, u32>;
pub fn device() -> burn::backend::wgpu::WgpuDevice { burn::backend::wgpu::WgpuDevice::DefaultDevice }
}
#[cfg(all(feature = "wgpu", not(feature = "wgpu-f16")))]
mod backend {
pub use burn::backend::{Wgpu as B, wgpu::WgpuDevice};
pub fn device() -> WgpuDevice { WgpuDevice::DefaultDevice }
}
#[cfg(not(any(feature = "wgpu", feature = "wgpu-f16")))]
mod backend {
pub use burn::backend::NdArray as B;
pub fn device() -> burn::backend::ndarray::NdArrayDevice { burn::backend::ndarray::NdArrayDevice::Cpu }
}
use backend::{B, device};
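/// Runs a 12-block pre-norm transformer encoder forward pass and times it.
/// Attention scores are computed in `tile`-row chunks along the query
/// dimension (a single full [n, n] score matrix when `tile >= n`). Returns
/// the best wall-clock time in milliseconds over `runs` timed repetitions.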
fn bench_12_blocks(
    x_in: &Tensor<B, 3>,
    norms: &[(burn::nn::LayerNorm<B>, burn::nn::LayerNorm<B>)],
    weights: &[(Tensor<B, 2>, Tensor<B, 2>, Tensor<B, 2>, Tensor<B, 2>)],
    tile: usize,
    heads: usize,
    dh: usize,
    scale: f32,
    warmup: usize,
    runs: usize,
) -> f64 {
    let do_forward = |x: Tensor<B, 3>| -> Tensor<B, 3> {
        let mut x = x;
        for i in 0..12 {
            let [b, n, c] = x.dims();
            // Pre-norm attention: fused QKV projection, then split per head.
            let xn = norms[i].0.forward(x.clone()).reshape([b * n, c]);
            let qkv = xn
                .matmul(weights[i].0.clone())
                .reshape([b, n, heads, 3 * dh])
                .swap_dims(1, 2); // [b, heads, n, 3 * dh]
            let q = qkv.clone().narrow(3, 0, dh).mul_scalar(scale);
            let k = qkv.clone().narrow(3, dh, dh);
            let v = qkv.narrow(3, 2 * dh, dh);
            let k_t = k.transpose(); // [b, heads, dh, n]
            let attn_out = if tile >= n {
                // Untiled path: one full [n, n] score matrix per head.
                softmax(q.matmul(k_t), 3).matmul(v)
            } else {
                // Tiled path: process `tile` query rows at a time, so each
                // intermediate score matrix is [tile, n] instead of [n, n].
                let mut tiles = Vec::new();
                let mut off = 0;
                while off < n {
                    let tl = (n - off).min(tile);
                    let qt = q.clone().narrow(2, off, tl);
                    tiles.push(softmax(qt.matmul(k_t.clone()), 3).matmul(v.clone()));
                    off += tl;
                }
                Tensor::cat(tiles, 2)
            };
            // Merge heads, project, and apply the residual connection.
            let attn_out = attn_out.swap_dims(1, 2).reshape([b * n, c]);
            let attn_out = attn_out.matmul(weights[i].1.clone()).reshape([b, n, c]);
            x = x + attn_out;
            // Pre-norm MLP: two matmuls with a GELU in between, plus residual.
            let hn = norms[i].1.forward(x.clone()).reshape([b * n, c]);
            let mlp = burn::tensor::activation::gelu(hn.matmul(weights[i].2.clone()));
            let mlp = mlp.matmul(weights[i].3.clone()).reshape([b, n, c]);
            x = x + mlp;
        }
        x
    };
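    // Warm-up iterations let the backend compile and cache kernels;
    // `into_data()` reads the result back to the host, which should force any
    // asynchronous backend to finish before the timer stops.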
    for _ in 0..warmup {
        let _ = do_forward(x_in.clone()).into_data();
    }
    let mut times = Vec::new();
    for _ in 0..runs {
        let t0 = Instant::now();
        let _ = do_forward(x_in.clone()).into_data();
        times.push(t0.elapsed().as_secs_f64() * 1000.0);
    }
    // Report the fastest run as the summary statistic.
    times.iter().cloned().fold(f64::INFINITY, f64::min)
}
fn main() {
    let d = device();
    // Project-local helper; `None` presumably selects a default thread count.
    brainharmony::init_threads(None);
    // BERT/ViT-Base-like dimensions: embed 768, 12 heads of 64, 4x MLP width.
    let seq = 7200usize;
    let embed = 768usize;
    let heads = 12usize;
    let dh = embed / heads;
    let mlp_h = embed * 4;
    let scale = (dh as f32).powf(-0.5); // 1/sqrt(d_h) attention scaling
    println!("12-block encoder: tile size vs total time\n");
    let mut norms = Vec::new();
    let mut weights = Vec::new();
    for _ in 0..12 {
        norms.push((
            LayerNormConfig::new(embed).with_epsilon(1e-6).init::<B>(&d),
            LayerNormConfig::new(embed).with_epsilon(1e-6).init::<B>(&d),
        ));
        weights.push((
            Tensor::<B, 2>::random([embed, 3 * embed], burn::tensor::Distribution::Normal(0.0, 0.01), &d),
            Tensor::<B, 2>::random([embed, embed], burn::tensor::Distribution::Normal(0.0, 0.01), &d),
            Tensor::<B, 2>::random([embed, mlp_h], burn::tensor::Distribution::Normal(0.0, 0.01), &d),
            Tensor::<B, 2>::random([mlp_h, embed], burn::tensor::Distribution::Normal(0.0, 0.01), &d),
        ));
    }
    let x: Tensor<B, 3> =
        Tensor::random([1, seq, embed], burn::tensor::Distribution::Normal(0.0, 1.0), &d);
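    // Sweep tile sizes from "no tiling" (tile == seq) down to 720. The
    // dispatch estimate counts 4 dense matmuls per block (QKV, output
    // projection, two MLP layers) plus 3 ops per attention tile (scores
    // matmul, softmax, weighted sum).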
    for tile in [7200, 3600, 2400, 1800, 1200, 1024, 720] {
        let n_tiles = (seq + tile - 1) / tile;
        let dispatches_per_block = 4 + n_tiles * 3;
        let t = bench_12_blocks(&x, &norms, &weights, tile, heads, dh, scale, 3, 5);
        println!(" tile={tile:5} ({n_tiles:2} tiles, ~{dispatches_per_block:2} dispatches/block): {t:>7.0}ms");
    }
}