use std::time::Instant;
use burn::prelude::*;
use burn::tensor::activation::softmax;
use burn::nn::LayerNormConfig;
// Half-precision WGPU backend. Gated on `wgpu-f16` alone so that enabling both
// `wgpu` and `wgpu-f16` (e.g. `--all-features`, or a feature that implies `wgpu`)
// still selects exactly one `backend` module: the other two variants are compiled
// only when `wgpu-f16` is off, so without this any such combination leaves
// `use backend::{B, device}` unresolved.
#[cfg(feature = "wgpu-f16")]
mod backend {
    /// WGPU backend with `half::f16` floats. The remaining generic arguments
    /// (`i32`, `u32`) are presumably the int and bool element types — confirm
    /// against the pinned burn version.
    pub type B = burn::backend::wgpu::Wgpu<half::f16, i32, u32>;

    /// Returns the device the benchmark allocates on: `WgpuDevice::DefaultDevice`.
    pub fn device() -> burn::backend::wgpu::WgpuDevice {
        burn::backend::wgpu::WgpuDevice::DefaultDevice
    }
}
#[cfg(all(feature = "wgpu", not(feature = "wgpu-f16")))]
mod backend {
    //! Full-precision WGPU backend, selected when `wgpu` is enabled without `wgpu-f16`.

    pub use burn::backend::wgpu::WgpuDevice;
    pub use burn::backend::Wgpu as B;

    /// Returns the device the benchmark allocates on: `WgpuDevice::DefaultDevice`.
    pub fn device() -> WgpuDevice {
        WgpuDevice::DefaultDevice
    }
}
#[cfg(all(not(feature = "wgpu"), not(feature = "wgpu-f16")))]
mod backend {
    //! CPU fallback (`NdArray`), selected when neither WGPU feature is enabled.

    use burn::backend::ndarray::NdArrayDevice;

    pub use burn::backend::NdArray as B;

    /// Returns the device the benchmark allocates on: `NdArrayDevice::Cpu`.
    pub fn device() -> NdArrayDevice {
        NdArrayDevice::Cpu
    }
}
use backend::{B, device};
/// Benchmarks a stack of `brainharmony` transformer blocks on the backend chosen by
/// the `backend` module, comparing two host-readback strategies:
///
/// * per-block sync: the activation is read back with `into_data()` after every block;
/// * single sync: all blocks are chained and only the final output is read back.
///
/// Each timed phase is preceded by `warmup` chained passes and then measured over
/// `runs` passes. Best and median wall-clock times (ms) are printed for both
/// strategies, followed by the ratio of the best per-block time to the best
/// single-sync time.
fn main() {
    let d = device();
    brainharmony::init_threads(None);

    // Benchmark shape: one sequence of `seq` tokens with `embed` channels, pushed
    // through `depth` blocks with `heads` attention heads each.
    let seq = 7200usize;
    let embed = 768usize;
    let heads = 12usize;
    let depth = 12usize;
    let warmup = 20;
    let runs = 10;

    println!("Chain vs per-block sync ({warmup} warmup, {runs} runs)\n");

    let x: Tensor<B, 3> = Tensor::random(
        [1, seq, embed],
        burn::tensor::Distribution::Normal(0.0, 1.0),
        &d,
    );
    // NOTE(review): the positional args after `heads` (4.0, true, 1e-6) look like
    // MLP ratio, QKV bias and norm epsilon — confirm against `Block::new`.
    let blocks: Vec<brainharmony::model::block::Block<B>> = (0..depth)
        .map(|_| brainharmony::model::block::Block::new(embed, heads, 4.0, true, 1e-6, &d))
        .collect();

    // Chained warmup, then time the per-block-sync strategy.
    for _ in 0..warmup {
        run_stack(&blocks, &x, false);
    }
    let times_sync = time_runs_ms(runs, || run_stack(&blocks, &x, true));

    // Repeat the same chained warmup before timing the single-sync strategy.
    for _ in 0..warmup {
        run_stack(&blocks, &x, false);
    }
    let times_chain = time_runs_ms(runs, || run_stack(&blocks, &x, false));

    let (sync_best, sync_med) = (best_ms(&times_sync), median_ms(&times_sync));
    let (chain_best, chain_med) = (best_ms(&times_chain), median_ms(&times_chain));
    println!(" Per-block sync ({depth} syncs): best={sync_best:.0}ms med={sync_med:.0}ms");
    println!(" Single sync (1 sync): best={chain_best:.0}ms med={chain_med:.0}ms");
    println!(" Speedup: {:.2}x", sync_best / chain_best);
}

/// Runs one forward pass of a clone of `x` through `blocks`, in order.
///
/// With `sync_each_block` set, each intermediate activation (the last one included)
/// is read back via `into_data()` right after its block. Otherwise only the final
/// output is read back, once.
fn run_stack(
    blocks: &[brainharmony::model::block::Block<B>],
    x: &Tensor<B, 3>,
    sync_each_block: bool,
) {
    let mut t = x.clone();
    for blk in blocks {
        t = blk.forward(t, None);
        if sync_each_block {
            let _ = t.clone().into_data();
        }
    }
    // In per-block mode the final activation was already read back inside the loop.
    if !sync_each_block {
        let _ = t.into_data();
    }
}

/// Calls `pass` `runs` times and returns each call's wall-clock time in milliseconds.
fn time_runs_ms(runs: usize, mut pass: impl FnMut()) -> Vec<f64> {
    (0..runs)
        .map(|_| {
            let t0 = Instant::now();
            pass();
            t0.elapsed().as_secs_f64() * 1000.0
        })
        .collect()
}

/// Smallest sample in `samples_ms`; `f64::INFINITY` when the slice is empty.
fn best_ms(samples_ms: &[f64]) -> f64 {
    samples_ms.iter().copied().fold(f64::INFINITY, f64::min)
}

/// Median of `samples_ms`: the middle value for odd counts, the mean of the two
/// middle values for even counts, and NaN (rather than a panic) when empty.
/// `total_cmp` gives a total order, so NaN samples cannot panic the sort.
fn median_ms(samples_ms: &[f64]) -> f64 {
    let mut sorted = samples_ms.to_vec();
    sorted.sort_by(f64::total_cmp);
    let n = sorted.len();
    match n {
        0 => f64::NAN,
        _ if n % 2 == 1 => sorted[n / 2],
        _ => (sorted[n / 2 - 1] + sorted[n / 2]) / 2.0,
    }
}