use std::time::Instant;
use burn::prelude::*;
/// Half-precision WGPU backend.
///
/// Selected whenever the `wgpu-f16` feature is enabled, including when `wgpu` is also
/// enabled (e.g. `--all-features`). Previously that combination matched none of the
/// `backend` variants and the crate failed to compile; now `wgpu-f16` takes precedence,
/// so exactly one `backend` module is always defined.
#[cfg(feature = "wgpu-f16")]
mod backend {
    /// Tensor backend: WGPU instantiated with `half::f16`, `i32` and `u32` element types.
    pub type B = burn::backend::wgpu::Wgpu<half::f16, i32, u32>;

    /// Returns the default WGPU device.
    pub fn device() -> burn::backend::wgpu::WgpuDevice {
        burn::backend::wgpu::WgpuDevice::DefaultDevice
    }
}
/// WGPU backend with burn's default element types; used when only `wgpu` is enabled.
#[cfg(all(feature = "wgpu", not(feature = "wgpu-f16")))]
mod backend {
    pub use burn::backend::wgpu::WgpuDevice;
    pub use burn::backend::Wgpu as B;

    /// Returns the default WGPU device.
    pub fn device() -> WgpuDevice {
        WgpuDevice::DefaultDevice
    }
}
/// CPU fallback backend (`NdArray`), used when no WGPU feature is enabled.
#[cfg(not(any(feature = "wgpu", feature = "wgpu-f16")))]
mod backend {
    use burn::backend::ndarray::NdArrayDevice;

    pub use burn::backend::NdArray as B;

    /// Returns the `NdArray` CPU device.
    pub fn device() -> NdArrayDevice {
        NdArrayDevice::Cpu
    }
}
use backend::{B, device};
/// Micro-benchmark of per-dispatch overhead on the selected backend.
///
/// Two sweeps run over a `[7200, 768]` tensor drawn from N(0, 1):
/// 1. chains of `n_ops` element-wise `mul_scalar` calls;
/// 2. chains of `n_ops` matmuls against a `[768, 768]` weight drawn from N(0, 0.01).
///
/// Each configuration gets untimed warm-up runs followed by timed runs. The fastest timed
/// run is printed together with the implied cost per op. Every run ends with
/// `into_data()`, so the reported time includes whatever work the backend performs to
/// materialise the result, not just op submission.
fn main() {
    let d = device();
    brainharmony::init_threads(None);
    let x: Tensor<B, 2> = Tensor::random(
        [7200, 768],
        burn::tensor::Distribution::Normal(0.0, 1.0),
        &d,
    );

    println!("Dispatch overhead test\n");
    for n_ops in [1, 5, 10, 20, 50, 100] {
        let best = best_of_ms(3, 10, || {
            let mut t = x.clone();
            for _ in 0..n_ops {
                t = t.mul_scalar(1.0001f32);
            }
            let _ = t.into_data();
        });
        let per_op = best / n_ops as f64;
        println!(" {n_ops:3} ops: {best:>6.1}ms total ({per_op:.2}ms/op)");
    }

    println!("\n--- Non-fuseable ops (matmul) ---\n");
    let w: Tensor<B, 2> = Tensor::random(
        [768, 768],
        burn::tensor::Distribution::Normal(0.0, 0.01),
        &d,
    );
    for n_ops in [1, 4, 8, 12] {
        let best = best_of_ms(3, 5, || {
            let mut t = x.clone();
            for _ in 0..n_ops {
                t = t.matmul(w.clone());
            }
            let _ = t.into_data();
        });
        let per_op = best / n_ops as f64;
        println!(" {n_ops:3} matmuls: {best:>7.1}ms total ({per_op:.1}ms/op)");
    }
}

/// Runs `run` `warmup` times untimed, then `iters` times timed, and returns the fastest
/// timed run in milliseconds.
///
/// Returns `f64::INFINITY` if `iters` is 0, because no timed run exists.
fn best_of_ms(warmup: usize, iters: usize, mut run: impl FnMut()) -> f64 {
    for _ in 0..warmup {
        run();
    }
    (0..iters)
        .map(|_| {
            let t0 = Instant::now();
            run();
            t0.elapsed().as_secs_f64() * 1000.0
        })
        .fold(f64::INFINITY, f64::min)
}