use crate::analysis::roofline::{Precision, RooflineModel};
use anyhow::Result;
use serde::Serialize;
/// One row of the cross-backend comparison; also the record serialized by
/// `run_compare` when JSON output is requested.
#[derive(Debug, Clone, Serialize)]
pub struct BackendResult {
    /// Backend identifier as given in the comma-separated backend list
    /// (e.g. `"scalar"`, `"avx2"`, `"cublas"`).
    pub name: String,
    /// Wall-clock time of one GEMM, in microseconds (measured or estimated).
    pub wall_time_us: f64,
    /// Throughput derived from `wall_time_us`, in TFLOP/s.
    pub tflops: f64,
    /// Memory bandwidth in GB/s. `run_compare` in this module always fills in `0.0`.
    pub bandwidth_gbps: f64,
    /// Whether the backend appears usable on this machine (CPU feature detection or
    /// an `nvidia-smi` lookup in `run_compare`).
    pub available: bool,
    /// `true` when `wall_time_us` came from a real run rather than a model estimate.
    /// Omitted from serialized output when `false`.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub measured: bool,
}
/// Converts the wall time of one square `size`×`size`×`size` GEMM into TFLOP/s.
///
/// Such a GEMM performs `2 * size³` floating-point operations (a multiply and an add
/// per inner-product term). Non-positive times yield `0.0` instead of an infinite or
/// negative rate.
fn gemm_tflops(size: u32, time_us: f64) -> f64 {
    match time_us {
        t if t <= 0.0 => 0.0,
        t => {
            let total_flops = 2.0 * f64::from(size).powi(3);
            let seconds = t * 1e-6;
            total_flops / seconds / 1e12
        }
    }
}
/// Runs the external `benchmark_matrix_suite` example binary and extracts its measured
/// timing for a square `size`×`size`×`size` GEMM.
///
/// The binary is looked up, in order, at:
/// 1. `$CARGO_TARGET_DIR/release/examples/benchmark_matrix_suite` (when the variable is set),
/// 2. `/mnt/nvme-raid0/targets/trueno/release/examples/benchmark_matrix_suite`,
/// 3. `./target/release/examples/benchmark_matrix_suite`.
///
/// Only regular files are accepted. The first one found is executed, and its stdout is
/// scanned for lines containing `Matrix Multiplication (NxNxN)` followed by
/// `... <time> ms (<rate> GFLOPS)`. Matching lines that do not parse are skipped.
///
/// Returns `(time in microseconds, GFLOP/s)` from the first parseable matching line.
/// Returns `None` in any of these cases:
/// - no binary is found,
/// - it cannot be launched or exits unsuccessfully,
/// - no matching line parses.
///
/// This runs the whole suite, so it may take a while.
fn get_actual_gemm_timing(size: u32) -> Option<(f64, f64)> {
    use std::path::PathBuf;
    use std::process::{Command, Stdio};

    let mut candidates: Vec<PathBuf> = Vec::with_capacity(3);
    // An explicitly configured cargo target directory takes precedence over the
    // built-in guesses.
    if let Some(target_dir) = std::env::var_os("CARGO_TARGET_DIR") {
        candidates.push(
            PathBuf::from(target_dir)
                .join("release")
                .join("examples")
                .join("benchmark_matrix_suite"),
        );
    }
    candidates.push(PathBuf::from(
        "/mnt/nvme-raid0/targets/trueno/release/examples/benchmark_matrix_suite",
    ));
    candidates.push(PathBuf::from("./target/release/examples/benchmark_matrix_suite"));

    // `is_file` (not `exists`) so a directory at a candidate path falls through to the
    // next candidate instead of failing to execute.
    let binary_path = candidates.iter().find(|p| p.is_file())?;
    let output = Command::new(binary_path)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .output()
        .ok()?;
    if !output.status.success() {
        return None;
    }

    let stdout = String::from_utf8_lossy(&output.stdout);
    let pattern = format!("Matrix Multiplication ({}x{}x{})", size, size, size);
    stdout
        .lines()
        .filter(|line| line.contains(&pattern))
        .find_map(|line| {
            // Expected shape: "<label>... <ms> ms (<gflops> GFLOPS)".
            let after_dots = line.split("...").nth(1)?;
            let time_ms = after_dots.split("ms").next()?.trim().parse::<f64>().ok()?;
            let gflops = after_dots
                .split('(')
                .nth(1)?
                .split(" GFLOPS")
                .next()?
                .trim()
                .parse::<f64>()
                .ok()?;
            Some((time_ms * 1000.0, gflops))
        })
}
/// Model estimate, in microseconds, of a scalar (non-SIMD) `size`³ GEMM.
///
/// Anchored at 11.7 ms for a 256³ GEMM and scaled with the cubic growth in work.
fn estimate_scalar_time_us(size: u32) -> f64 {
    const ANCHOR_SIZE: f64 = 256.0;
    const ANCHOR_US: f64 = 11_700.0;
    let scale = (f64::from(size) / ANCHOR_SIZE).powi(3);
    ANCHOR_US * scale
}
/// Model estimate, in microseconds, of an AVX2 `size`³ GEMM running at an assumed
/// sustained 72 GFLOP/s.
fn estimate_avx2_time_us(size: u32) -> f64 {
    const SUSTAINED_GFLOPS: f64 = 72.0;
    let work = 2.0 * f64::from(size).powi(3);
    let seconds = work / (SUSTAINED_GFLOPS * 1e9);
    seconds * 1e6
}
/// Model estimate, in microseconds, of an AVX-512 `size`³ GEMM running at an assumed
/// sustained 80 GFLOP/s.
fn estimate_avx512_time_us(size: u32) -> f64 {
    const SUSTAINED_GFLOPS: f64 = 80.0;
    let work = 2.0 * f64::from(size).powi(3);
    let seconds = work / (SUSTAINED_GFLOPS * 1e9);
    seconds * 1e6
}
/// Model estimate, in microseconds, of the hand-written PTX GEMM for a `size`³ problem.
///
/// Anchored at 23.2 µs for a 512³ GEMM and scaled with the cubic growth in work.
fn estimate_cuda_time_us(size: u32) -> f64 {
    const ANCHOR_SIZE: f64 = 512.0;
    const ANCHOR_US: f64 = 23.2;
    ANCHOR_US * (f64::from(size) / ANCHOR_SIZE).powi(3)
}

/// Model estimate, in microseconds, of cuBLAS for a `size`³ GEMM, assumed to be 3x
/// faster than the PTX kernel estimate.
fn estimate_cublas_time_us(size: u32) -> f64 {
    const CUBLAS_SPEEDUP_OVER_PTX: f64 = 3.0;
    let ptx_us = estimate_cuda_time_us(size);
    ptx_us / CUBLAS_SPEEDUP_OVER_PTX
}
/// Times a square `size`×`size`×`size` FP16 GEMM through cuBLAS on CUDA device 0.
///
/// Inputs are all-ones FP16 matrices. After 5 warm-up calls, a size-dependent number of
/// calls (200 / 100 / 30) is timed with a host wall clock, bracketed by stream
/// synchronizations.
///
/// Returns `(microseconds per call, TFLOP/s)`, or `None` if any CUDA/cuBLAS step fails.
/// That includes every individual GEMM call, so a run with failed launches is never
/// reported as a measurement.
#[cfg(feature = "cuda")]
fn measure_cublas_gemm(size: u32) -> Option<(f64, f64)> {
    use trueno_gpu::driver::{CublasHandle, CudaContext, CudaStream, GemmOp, GpuBuffer};

    // cuBLAS takes (leading) dimensions as i32; reject sizes that don't fit instead of
    // silently wrapping with `as`.
    let dim = i32::try_from(size).ok()?;
    let n = size as usize;

    let ctx = CudaContext::new(0).ok()?;
    let stream = CudaStream::new(&ctx).ok()?;
    let handle = CublasHandle::new(&ctx).ok()?;
    handle.set_stream(&stream).ok()?;

    // 0x3C00 is the IEEE 754 half-precision encoding of 1.0.
    let ones = vec![0x3C00u16; n * n];
    let zeros = vec![0u16; n * n];
    let a_buf = GpuBuffer::from_host(&ctx, &ones).ok()?;
    let b_buf = GpuBuffer::from_host(&ctx, &ones).ok()?;
    let c_buf = GpuBuffer::from_host(&ctx, &zeros).ok()?;

    // C = 1.0 * A * B + 0.0 * C, square and untransposed.
    let run_gemm = || {
        handle.gemm_f16(
            GemmOp::NoTrans,
            GemmOp::NoTrans,
            dim,
            dim,
            dim,
            1.0,
            a_buf.as_ptr(),
            dim,
            b_buf.as_ptr(),
            dim,
            0.0,
            c_buf.as_ptr(),
            dim,
        )
    };

    for _ in 0..5 {
        run_gemm().ok()?;
    }
    stream.synchronize().ok()?;

    // Fewer iterations for larger problems keep total runtime bounded.
    let iters: u32 = match n {
        0..=512 => 200,
        513..=1024 => 100,
        _ => 30,
    };
    let start = std::time::Instant::now();
    for _ in 0..iters {
        run_gemm().ok()?;
    }
    stream.synchronize().ok()?;

    let per_call_us = start.elapsed().as_secs_f64() * 1e6 / f64::from(iters);
    Some((per_call_us, gemm_tflops(size, per_call_us)))
}
/// Times the hand-written PTX FP16 GEMM (`gemm_cta64x128_mma_pipeline_fp16`) for a
/// square `size`×`size`×`size` problem on CUDA device 0.
///
/// The kernel is generated by `build_cta64x128_mma_pipeline_fp16` and emitted as PTX
/// for `sm_80`. It is then loaded through `CudaModule::from_ptx`. After 5 warm-up
/// launches, a size-dependent number of launches (100 / 50 / 20) is timed with a host
/// wall clock, bracketed by stream synchronizations.
///
/// Returns `(microseconds per launch, TFLOP/s)`, or `None` if any CUDA step fails.
#[cfg(feature = "cuda")]
fn measure_ptx_gemm(size: u32) -> Option<(f64, f64)> {
    use std::ffi::c_void;
    use trueno_gpu::driver::{CudaContext, CudaModule, CudaStream, GpuBuffer, LaunchConfig};
    use trueno_gpu::kernels::build_cta64x128_mma_pipeline_fp16;
    use trueno_gpu::ptx::PtxModule;
    let ctx = CudaContext::new(0).ok()?;
    let stream = CudaStream::new(&ctx).ok()?;
    let n = size as usize;
    // All-ones FP16 inputs (0x3C00 is the IEEE 754 half-precision encoding of 1.0);
    // the output buffer is FP32.
    let a16 = vec![0x3C00u16; n * n];
    let b16 = vec![0x3C00u16; n * n];
    let c32 = vec![0.0f32; n * n];
    let a_buf = GpuBuffer::from_host(&ctx, &a16).ok()?;
    let b_buf = GpuBuffer::from_host(&ctx, &b16).ok()?;
    let c_buf = GpuBuffer::from_host(&ctx, &c32).ok()?;
    let kernel = build_cta64x128_mma_pipeline_fp16(n as u32, n as u32, n as u32);
    let ptx = PtxModule::new().target("sm_80").add_kernel(kernel).emit();
    let mut module = CudaModule::from_ptx(&ctx, &ptx).ok()?;
    // grid.x covers `n` in steps of 128 and grid.y in steps of 64, rounded up.
    // NOTE(review): block size 512 and 18432 bytes of shared memory are assumed to
    // match what the generated kernel expects — confirm against the kernel builder,
    // including which grid axis maps to M vs N.
    let cfg = LaunchConfig {
        grid: (((n + 127) / 128) as u32, ((n + 63) / 64) as u32, 1),
        block: (512, 1, 1),
        shared_mem: 18432,
    };
    // Kernel parameters are passed as an array of pointers to each argument value,
    // so these locals must stay alive (and unmoved) for every launch using `args`.
    let mut a_ptr = a_buf.as_ptr();
    let mut b_ptr = b_buf.as_ptr();
    let mut c_ptr = c_buf.as_ptr();
    let mut m_v = n as u32;
    let mut n_v = n as u32;
    let mut k_v = n as u32;
    let mut args: Vec<*mut c_void> = vec![
        &mut a_ptr as *mut _ as *mut c_void,
        &mut b_ptr as *mut _ as *mut c_void,
        &mut c_ptr as *mut _ as *mut c_void,
        &mut m_v as *mut _ as *mut c_void,
        &mut n_v as *mut _ as *mut c_void,
        &mut k_v as *mut _ as *mut c_void,
    ];
    // Warm-up launches (not timed).
    for _ in 0..5 {
        // SAFETY: every pointer in `args` refers to a local above, and those locals
        // and the device buffers stay alive until this function returns.
        // NOTE(review): soundness also requires the argument order/types to match
        // the generated kernel's signature — not checkable from this file.
        unsafe {
            stream
                .launch_kernel(
                    &mut module,
                    "gemm_cta64x128_mma_pipeline_fp16",
                    &cfg,
                    &mut args,
                )
                .ok()?;
        }
    }
    stream.synchronize().ok()?;
    // Fewer iterations for larger problems keep total runtime bounded.
    let iters: u32 = if n <= 512 {
        100
    } else if n <= 1024 {
        50
    } else {
        20
    };
    let start = std::time::Instant::now();
    for _ in 0..iters {
        // SAFETY: same invariants as the warm-up launches above.
        unsafe {
            stream
                .launch_kernel(
                    &mut module,
                    "gemm_cta64x128_mma_pipeline_fp16",
                    &cfg,
                    &mut args,
                )
                .ok()?;
        }
    }
    // Wait for all queued launches so the elapsed time covers their execution.
    stream.synchronize().ok()?;
    // `as_micros` truncates to whole microseconds of total elapsed time.
    let per_call_us = start.elapsed().as_micros() as f64 / iters as f64;
    let flops = 2.0 * (n as f64).powi(3);
    // flops / (µs * 1e-6) / 1e12 == flops / (µs * 1e6): TFLOP/s.
    let tflops = flops / (per_call_us * 1e6);
    Some((per_call_us, tflops))
}
pub fn run_compare(kernel: &str, size: u32, backends_str: &str, json: bool) -> Result<()> {
let backends: Vec<&str> = backends_str.split(',').map(|s| s.trim()).collect();
if !json {
println!("\n=== CGP Cross-Backend Comparison: {kernel} ({size}x{size}x{size}) ===\n");
}
let mut results: Vec<BackendResult> = Vec::new();
let actual = get_actual_gemm_timing(size);
for backend in &backends {
let (time_us, available, measured) = match *backend {
"scalar" => (estimate_scalar_time_us(size), true, false),
"avx2" | "avx512" => {
#[cfg(target_arch = "x86_64")]
let avail = if *backend == "avx512" {
std::arch::is_x86_feature_detected!("avx512f")
} else {
std::arch::is_x86_feature_detected!("avx2")
};
#[cfg(not(target_arch = "x86_64"))]
let avail = false;
if let Some((actual_us, _)) = actual {
(actual_us, avail, true)
} else if *backend == "avx512" {
(estimate_avx512_time_us(size), avail, false)
} else {
(estimate_avx2_time_us(size), avail, false)
}
}
"neon" => {
let avail = cfg!(target_arch = "aarch64");
(estimate_scalar_time_us(size) / 4.0, avail, false)
}
"cuda" => {
let avail = which::which("nvidia-smi").is_ok();
#[cfg(feature = "cuda")]
if avail {
if let Some((time_us, _tflops)) = measure_ptx_gemm(size) {
(time_us, true, true)
} else {
(estimate_cuda_time_us(size), avail, false)
}
} else {
(estimate_cuda_time_us(size), avail, false)
}
#[cfg(not(feature = "cuda"))]
{
(estimate_cuda_time_us(size), avail, false)
}
}
"cublas" => {
let avail = which::which("nvidia-smi").is_ok();
#[cfg(feature = "cuda")]
if avail {
if let Some((time_us, _tflops)) = measure_cublas_gemm(size) {
(time_us, true, true)
} else {
(estimate_cublas_time_us(size), avail, false)
}
} else {
(estimate_cublas_time_us(size), avail, false)
}
#[cfg(not(feature = "cuda"))]
{
(estimate_cublas_time_us(size), avail, false)
}
}
"wgpu" => {
let avail = which::which("nvidia-smi").is_ok();
(estimate_cuda_time_us(size) * 2.0, avail, false)
}
other => {
eprintln!(" Warning: unknown backend '{other}', skipping");
continue;
}
};
let tflops = gemm_tflops(size, time_us);
results.push(BackendResult {
name: backend.to_string(),
wall_time_us: time_us,
tflops,
bandwidth_gbps: 0.0,
available,
measured,
});
}
results.sort_by(|a, b| {
a.wall_time_us
.partial_cmp(&b.wall_time_us)
.unwrap_or(std::cmp::Ordering::Equal)
});
if json {
println!("{}", serde_json::to_string_pretty(&results)?);
return Ok(());
}
let best_time = results.first().map(|r| r.wall_time_us).unwrap_or(1.0);
println!(
" {:12} {:>12} {:>12} {:>10} {:>10} {:>8} {:>5}",
"Backend", "Time (us)", "TFLOP/s", "Efficiency", "vs Best", "Avail", "Src"
);
println!(" {}", "-".repeat(75));
let model = RooflineModel::rtx_4090();
let gpu_peak = model
.peak_compute
.get(&Precision::Fp16)
.copied()
.unwrap_or(330.0e12);
let cores = num_cpus::get_physical();
let cpu_peak = 2.0 * 8.0 * 2.0 * 3.5e9 * cores as f64;
for r in &results {
let peak = if r.name.contains("cuda") || r.name.contains("cublas") || r.name == "wgpu" {
gpu_peak / 1e12
} else {
cpu_peak / 1e12
};
let efficiency = if peak > 0.0 {
r.tflops / peak * 100.0
} else {
0.0
};
let ratio = format!("{:.2}x", r.wall_time_us / best_time);
let avail = if r.available { "yes" } else { "no" };
let time_str = if r.wall_time_us >= 1000.0 {
format!("{:.1} ms", r.wall_time_us / 1000.0)
} else {
format!("{:.1}", r.wall_time_us)
};
let src = if r.measured { "M" } else { "E" };
println!(
" {:12} {:>12} {:>12.1} {:>9.1}% {:>10} {:>8} {:>5}",
r.name, time_str, r.tflops, efficiency, ratio, avail, src
);
}
let has_measured = results.iter().any(|r| r.measured);
let has_estimated = results.iter().any(|r| !r.measured);
if has_measured || has_estimated {
print!(" Src: ");
if has_measured {
print!("M=measured ");
}
if has_estimated {
print!("E=estimated ");
}
println!();
}
if let Some(best) = results.first() {
if let Some(worst) = results.last() {
let speedup = worst.wall_time_us / best.wall_time_us;
println!(
"\n Best: {} ({:.1}x faster than {})",
best.name, speedup, worst.name
);
}
}
let has_cpu = results
.iter()
.any(|r| matches!(r.name.as_str(), "scalar" | "avx2" | "avx512"));
let has_gpu = results
.iter()
.any(|r| matches!(r.name.as_str(), "cuda" | "cublas" | "wgpu"));
if has_cpu && has_gpu {
let best_cpu = results
.iter()
.filter(|r| matches!(r.name.as_str(), "scalar" | "avx2" | "avx512"))
.map(|r| r.wall_time_us)
.fold(f64::INFINITY, f64::min);
let best_gpu = results
.iter()
.filter(|r| matches!(r.name.as_str(), "cuda" | "cublas" | "wgpu"))
.map(|r| r.wall_time_us)
.fold(f64::INFINITY, f64::min);
if best_gpu > 0.0 {
println!(
" CPU→GPU gap: {:.0}x (expected for large GEMM)",
best_cpu / best_gpu
);
}
}
println!();
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The 512³ CUDA anchor (23.2 µs) corresponds to roughly 11.6 TFLOP/s.
    #[test]
    fn test_gemm_tflops() {
        let tflops = gemm_tflops(512, 23.2);
        let close_enough = (tflops - 11.56).abs() < 0.1;
        assert!(close_enough, "Expected ~11.6 TFLOP/s, got {tflops:.2}");
    }

    #[test]
    fn test_scalar_slower_than_avx2() {
        let (scalar_us, avx2_us) = (estimate_scalar_time_us(512), estimate_avx2_time_us(512));
        assert!(
            scalar_us > avx2_us * 3.0,
            "Scalar should be >3x slower than AVX2"
        );
    }

    #[test]
    fn test_cuda_faster_than_cpu() {
        let cpu_us = estimate_avx2_time_us(4096);
        let cuda_us = estimate_cuda_time_us(4096);
        assert!(
            cpu_us > cuda_us * 10.0,
            "CPU should be >10x slower than CUDA for 4096"
        );
    }

    #[test]
    fn test_cuda_faster_than_scalar_at_256() {
        assert!(
            estimate_cuda_time_us(256) < estimate_scalar_time_us(256),
            "CUDA should be faster than scalar at 256"
        );
    }

    #[test]
    fn test_simd_faster_than_scalar() {
        let speedup = estimate_scalar_time_us(1024) / estimate_avx2_time_us(1024);
        assert!(speedup >= 3.0, "AVX2 speedup {:.1}x should be >= 3x", speedup);
    }

    #[test]
    fn test_cublas_faster_than_ptx() {
        let ptx_us = estimate_cuda_time_us(4096);
        let cublas_us = estimate_cublas_time_us(4096);
        assert!(
            cublas_us < ptx_us,
            "cuBLAS should be faster than PTX at 4096"
        );
    }

    #[test]
    fn test_run_compare_basic() {
        assert!(run_compare("gemm", 256, "scalar,avx2", false).is_ok());
    }

    #[test]
    fn test_run_compare_json() {
        assert!(run_compare("gemm", 256, "scalar,avx2", true).is_ok());
    }

    /// Sanity-checks real suite output when the benchmark binary is present; otherwise
    /// only notes that measured data is unavailable.
    #[test]
    fn test_get_actual_gemm_timing() {
        match get_actual_gemm_timing(1024) {
            Some((time_us, gflops)) => {
                assert!(time_us > 0.0, "time should be positive");
                assert!(gflops > 10.0, "GFLOPS should be > 10 for 1024 GEMM");
                assert!(gflops < 2000.0, "GFLOPS should be < 2000");
                eprintln!(
                    "Actual GEMM 1024: {:.1} us = {:.0} GFLOPS [MEASURED]",
                    time_us, gflops
                );
            }
            None => {
                eprintln!("benchmark_matrix_suite binary not found — actual data unavailable");
            }
        }
    }
}