use candle_core::{DType, Device, Tensor};
use cortex_rust::kernels::matmul_4bit::gemm_4bit;
use std::time::Instant;
fn main() -> anyhow::Result<()> {
println!("=== gemm_4bit Benchmark ===\n");
let device = Device::cuda_if_available(0)?;
println!("Device: {:?}\n", device);
let configs = [
("Embedding lookup equiv", 1, 5120, 32000), ("Q/K/V proj", 1, 5120, 5120), ("O proj", 1, 5120, 5120), ("Gate proj", 1, 5120, 13824), ("Up proj", 1, 5120, 13824), ("Down proj", 1, 13824, 5120), ];
let group_size = 128usize;
for (name, batch, in_dim, out_dim) in configs {
println!("--- {} ---", name);
println!(
"Shape: [{}, {}] x [{}, {}]",
batch,
in_dim,
out_dim,
in_dim / 2
);
let x = Tensor::randn(0.0f32, 1.0, (batch, in_dim), &device)?;
let packed_dim = in_dim / 2;
let w_packed_data: Vec<u8> = (0..(out_dim * packed_dim))
.map(|i| ((i % 16) as u8) | (((i / 16) % 16) as u8) << 4)
.collect();
let w_packed = Tensor::from_vec(w_packed_data, (out_dim, packed_dim), &device)?;
let n_groups = in_dim.div_ceil(group_size);
let scales = Tensor::ones((out_dim, n_groups), DType::F32, &device)?;
for _ in 0..3 {
let _ = gemm_4bit(&x, &w_packed, &scales, group_size)?;
}
let iterations = 10;
let start = Instant::now();
for _ in 0..iterations {
let output = gemm_4bit(&x, &w_packed, &scales, group_size)?;
let _ = output.flatten_all()?.to_vec1::<f32>();
}
let elapsed = start.elapsed();
let avg_ms = elapsed.as_secs_f64() * 1000.0 / iterations as f64;
let flops = 2.0 * batch as f64 * in_dim as f64 * out_dim as f64;
let gflops = flops / (avg_ms / 1000.0) / 1e9;
let bytes = (batch * in_dim * 4) as f64 + (out_dim * packed_dim) as f64 + (out_dim * n_groups * 4) as f64 + (batch * out_dim * 4) as f64; let gbps = bytes / (avg_ms / 1000.0) / 1e9;
println!(" Time: {:.3} ms", avg_ms);
println!(" GFLOPS: {:.2}", gflops);
println!(" Memory BW: {:.2} GB/s", gbps);
println!();
}
println!("--- Sync Overhead Test ---");
let x = Tensor::randn(0.0f32, 1.0, (1, 5120), &device)?;
let packed_dim = 5120 / 2;
let w_packed_data: Vec<u8> = (0..(5120 * packed_dim)).map(|i| (i % 16) as u8).collect();
let w_packed = Tensor::from_vec(w_packed_data, (5120, packed_dim), &device)?;
let n_groups = 5120_usize.div_ceil(group_size);
let scales = Tensor::ones((5120, n_groups), DType::F32, &device)?;
let iterations = 100;
let start = Instant::now();
for _ in 0..iterations {
let _ = gemm_4bit(&x, &w_packed, &scales, group_size)?;
}
let no_sync_ms = start.elapsed().as_secs_f64() * 1000.0 / iterations as f64;
let start = Instant::now();
for _ in 0..iterations {
let output = gemm_4bit(&x, &w_packed, &scales, group_size)?;
let _ = output.flatten_all()?.to_vec1::<f32>();
}
let with_sync_ms = start.elapsed().as_secs_f64() * 1000.0 / iterations as f64;
println!(" Without sync: {:.3} ms/call", no_sync_ms);
println!(" With sync (to_vec1): {:.3} ms/call", with_sync_ms);
println!(
" Sync overhead: {:.3} ms ({:.1}%)",
with_sync_ms - no_sync_ms,
(with_sync_ms - no_sync_ms) / with_sync_ms * 100.0
);
Ok(())
}