use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Instant;
use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, sys as cublas_sys};
use cudarc::driver::CudaContext;
/// Throughput figures measured for a single CUDA device (see `measure_device`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DeviceCalibration {
    /// Double-precision GEMM throughput, in GFLOP/s.
    pub fp64_gflops: f64,
    /// Host-to-device copy bandwidth, in GB/s.
    pub h2d_gb_s: f64,
    /// Device-to-host copy bandwidth, in GB/s.
    pub d2h_gb_s: f64,
}

impl DeviceCalibration {
    /// Returns `true` only when every rate is a finite, strictly positive number.
    ///
    /// A zero, negative, NaN, or infinite field marks the calibration as unusable.
    pub fn is_usable(&self) -> bool {
        let rates = [self.fp64_gflops, self.h2d_gb_s, self.d2h_gb_s];
        rates.iter().all(|&rate| rate.is_finite() && rate > 0.0)
    }
}
/// Benchmarks the CUDA device behind `ctx` and returns its [`DeviceCalibration`].
///
/// All work runs on a fresh stream created from `ctx`:
/// - host-to-device and device-to-host bandwidth: one untimed warm-up copy,
///   then the best of 3 timed 32 MiB copies, reported in GB/s (1e9 bytes/s);
/// - FP64 throughput: two untimed warm-up GEMMs, then the best of 5 timed
///   1024x1024x1024 cuBLAS DGEMMs, reported in GFLOP/s using 2*M*N*K flops.
///
/// Each sample is wall-clock time from enqueue through `stream.synchronize()`,
/// so it includes launch/driver overhead as well as device execution. The host
/// buffers are plain `Vec`s (not pinned allocations), so the transfer figures
/// describe copies from/to ordinary host memory.
///
/// The call blocks the current thread for the whole benchmark and binds `ctx`
/// to that thread.
///
/// # Errors
///
/// Returns a descriptive `String` if binding the context, creating the stream
/// or cuBLAS handle, any allocation, copy, GEMM launch, or synchronization
/// fails, or if the resulting figures are not [`DeviceCalibration::is_usable`].
pub fn measure_device(ctx: Arc<CudaContext>) -> Result<DeviceCalibration, String> {
    ctx.bind_to_thread()
        .map_err(|e| format!("bind_to_thread: {e}"))?;
    let stream = ctx.new_stream().map_err(|e| format!("new_stream: {e}"))?;
    let blas = CudaBlas::new(stream.clone()).map_err(|e| format!("cublas_init: {e}"))?;

    const M: usize = 1024;
    const N: usize = 1024;
    const K: usize = 1024;
    const TRANSFER_BYTES: usize = 32 * 1024 * 1024;
    const TRANSFER_F64: usize = TRANSFER_BYTES / std::mem::size_of::<f64>();

    let a_host: Vec<f64> = (0..M * K).map(|i| (i as f64).sin()).collect();
    let b_host: Vec<f64> = (0..K * N).map(|i| (i as f64).cos()).collect();
    let mut a_dev = stream
        .alloc_zeros::<f64>(M * K)
        .map_err(|e| format!("alloc A {}x{}: {e}", M, K))?;
    let mut b_dev = stream
        .alloc_zeros::<f64>(K * N)
        .map_err(|e| format!("alloc B {}x{}: {e}", K, N))?;
    let mut c_dev = stream
        .alloc_zeros::<f64>(M * N)
        .map_err(|e| format!("alloc C {}x{}: {e}", M, N))?;

    // Fill the H2D source with a non-zero value. `vec![0.0; n]` is lowered to
    // a zeroed allocation. Its never-written pages may all alias the OS zero
    // page (e.g. on Linux). The host-side read of a pageable copy would then
    // hit cache instead of memory and overstate h2d bandwidth. Writing 1.0
    // backs every page with real memory before timing.
    let transfer_host: Vec<f64> = vec![1.0_f64; TRANSFER_F64];
    let mut transfer_dev = stream
        .alloc_zeros::<f64>(TRANSFER_F64)
        .map_err(|e| format!("alloc transfer buffer {} bytes: {e}", TRANSFER_BYTES))?;

    // Host -> device: one untimed warm-up, then best of 3 timed copies.
    stream
        .memcpy_htod(transfer_host.as_slice(), &mut transfer_dev)
        .map_err(|e| format!("h2d warmup: {e}"))?;
    stream
        .synchronize()
        .map_err(|e| format!("h2d warmup sync: {e}"))?;
    let h2d_gb_s = best_of_n_transfer(3, || {
        let start = Instant::now();
        stream
            .memcpy_htod(transfer_host.as_slice(), &mut transfer_dev)
            .map_err(|e| format!("h2d copy: {e}"))?;
        stream.synchronize().map_err(|e| format!("h2d sync: {e}"))?;
        Ok(bytes_per_sec(TRANSFER_BYTES, start.elapsed().as_secs_f64()))
    })?;

    // Device -> host. A zeroed destination is fine here because the untimed
    // warm-up copy writes every page before the timed samples run.
    let mut transfer_back: Vec<f64> = vec![0.0_f64; TRANSFER_F64];
    stream
        .memcpy_dtoh(&transfer_dev, transfer_back.as_mut_slice())
        .map_err(|e| format!("d2h warmup: {e}"))?;
    stream
        .synchronize()
        .map_err(|e| format!("d2h warmup sync: {e}"))?;
    let d2h_gb_s = best_of_n_transfer(3, || {
        let start = Instant::now();
        stream
            .memcpy_dtoh(&transfer_dev, transfer_back.as_mut_slice())
            .map_err(|e| format!("d2h copy: {e}"))?;
        stream.synchronize().map_err(|e| format!("d2h sync: {e}"))?;
        Ok(bytes_per_sec(TRANSFER_BYTES, start.elapsed().as_secs_f64()))
    })?;

    // Upload the GEMM operands.
    stream
        .memcpy_htod(a_host.as_slice(), &mut a_dev)
        .map_err(|e| format!("h2d A: {e}"))?;
    stream
        .memcpy_htod(b_host.as_slice(), &mut b_dev)
        .map_err(|e| format!("h2d B: {e}"))?;
    stream
        .synchronize()
        .map_err(|e| format!("h2d AB sync: {e}"))?;

    // cuBLAS takes i32 dimensions.
    let m_i = i32::try_from(M).map_err(|e| format!("M overflow i32: {e}"))?;
    let n_i = i32::try_from(N).map_err(|e| format!("N overflow i32: {e}"))?;
    let k_i = i32::try_from(K).map_err(|e| format!("K overflow i32: {e}"))?;
    // C = 1.0 * A * B with no transposes. In cuBLAS's column-major layout the
    // leading dimensions are the row counts: A is m x k, B is k x n, C is m x n.
    let cfg = GemmConfig::<f64> {
        transa: cublas_sys::cublasOperation_t::CUBLAS_OP_N,
        transb: cublas_sys::cublasOperation_t::CUBLAS_OP_N,
        m: m_i,
        n: n_i,
        k: k_i,
        alpha: 1.0_f64,
        lda: m_i,
        ldb: k_i,
        beta: 0.0_f64,
        ldc: m_i,
    };

    for i in 0..2 {
        // SAFETY (review): buffers were allocated as M*K, K*N and M*N f64s,
        // matching cfg's m/n/k and leading dimensions.
        unsafe { blas.gemm(cfg, &a_dev, &b_dev, &mut c_dev) }
            .map_err(|e| format!("dgemm warmup {i}: {e}"))?;
        stream
            .synchronize()
            .map_err(|e| format!("dgemm warmup {i} sync: {e}"))?;
    }

    let flops = 2.0_f64 * (M as f64) * (N as f64) * (K as f64);
    let fp64_gflops = best_of_n_transfer(5, || {
        let start = Instant::now();
        // SAFETY (review): same buffer/config pairing as the warm-up above.
        unsafe { blas.gemm(cfg, &a_dev, &b_dev, &mut c_dev) }
            .map_err(|e| format!("dgemm timed: {e}"))?;
        stream
            .synchronize()
            .map_err(|e| format!("dgemm timed sync: {e}"))?;
        let elapsed = start.elapsed().as_secs_f64();
        if elapsed <= 0.0 {
            return Err(format!("dgemm timing nonpositive elapsed: {elapsed}"));
        }
        Ok(flops / elapsed / 1e9)
    })?;

    // Read C back (the value is discarded) so a failure of this final copy or
    // sync is reported as an error instead of going unnoticed.
    let _c_host = stream
        .clone_dtoh(&c_dev)
        .map_err(|e| format!("d2h C result: {e}"))?;
    stream
        .synchronize()
        .map_err(|e| format!("d2h C result sync: {e}"))?;

    let calibration = DeviceCalibration {
        fp64_gflops,
        h2d_gb_s,
        d2h_gb_s,
    };
    if calibration.is_usable() {
        Ok(calibration)
    } else {
        Err(format!(
            "calibration result not usable: fp64_gflops={} h2d_gb_s={} d2h_gb_s={}",
            calibration.fp64_gflops, calibration.h2d_gb_s, calibration.d2h_gb_s
        ))
    }
}
/// Returns the host CPU's FP64 matmul throughput in GFLOP/s, measured at most once.
///
/// The first call runs the benchmark in `measure_cpu_fp64_inner` and blocks
/// until it finishes. Every later call returns the value cached in a
/// process-wide `OnceLock`.
pub fn measured_cpu_fp64_gflops() -> f64 {
    static CPU_FP64_GFLOPS: OnceLock<f64> = OnceLock::new();
    let gflops = CPU_FP64_GFLOPS.get_or_init(measure_cpu_fp64_inner);
    *gflops
}
/// Benchmarks a 512x512 f64 matrix product with faer and returns GFLOP/s.
///
/// Two untimed warm-up products run first. Then the best of five timed runs
/// is reported, using 2*N^3 flops per product. The products run with
/// `faer::Par::rayon(0)` (presumably "use rayon's default thread count";
/// confirm against the faer version in use). If no timed run yields a
/// positive, finite rate, the function falls back to 10.0.
fn measure_cpu_fp64_inner() -> f64 {
    use faer::{Mat, linalg::matmul::matmul};

    const N: usize = 512;
    const WARMUP_RUNS: usize = 2;
    const TIMED_RUNS: usize = 5;
    const FALLBACK_GFLOPS: f64 = 10.0;

    let lhs = Mat::<f64>::from_fn(N, N, |i, j| ((i + j) as f64).sin());
    let rhs = Mat::<f64>::from_fn(N, N, |i, j| ((i * 3 + j) as f64).cos());
    let mut out = Mat::<f64>::zeros(N, N);

    // out = 1.0 * lhs * rhs, overwriting the previous contents of `out`.
    let mut multiply = || {
        matmul(
            out.as_mut(),
            faer::Accum::Replace,
            lhs.as_ref(),
            rhs.as_ref(),
            1.0,
            faer::Par::rayon(0),
        )
    };

    for _ in 0..WARMUP_RUNS {
        multiply();
    }

    let flops = 2.0_f64 * (N as f64).powi(3);
    // Samples with a non-positive elapsed time are skipped. Every kept sample
    // is positive and finite, so folding with `max` from 0.0 keeps the best.
    let best_gflops = (0..TIMED_RUNS)
        .filter_map(|_| {
            let start = Instant::now();
            multiply();
            let elapsed = start.elapsed().as_secs_f64();
            if elapsed > 0.0 {
                Some(flops / elapsed / 1e9)
            } else {
                None
            }
        })
        .fold(0.0_f64, f64::max);

    let measured_ok = best_gflops.is_finite() && best_gflops > 0.0;
    if measured_ok { best_gflops } else { FALLBACK_GFLOPS }
}
/// Converts `bytes` moved in `seconds` into a decimal GB/s rate.
///
/// Despite the name, the result is bytes per second divided by 1e9. A zero or
/// negative duration yields 0.0 rather than an infinite or negative rate. A
/// NaN duration is not caught by that guard and propagates as NaN.
fn bytes_per_sec(bytes: usize, seconds: f64) -> f64 {
    if seconds <= 0.0 {
        return 0.0;
    }
    let raw_rate = (bytes as f64) / seconds;
    raw_rate / 1e9
}
/// Calls `f` exactly `n` times and returns the largest finite `Ok` sample.
///
/// Non-finite `Ok` samples (NaN, +/-inf) are discarded. Errors are tolerated
/// as long as at least one usable sample is produced. Otherwise the last
/// error is returned, or a generic "no usable sample" message if no call
/// failed (e.g. every sample was non-finite, or `n == 0`). Despite the name,
/// this is also used to pick the best GEMM throughput sample.
fn best_of_n_transfer<F>(n: usize, mut f: F) -> Result<f64, String>
where
    F: FnMut() -> Result<f64, String>,
{
    let mut best: Option<f64> = None;
    let mut last_err: Option<String> = None;

    for _ in 0..n {
        match f() {
            Ok(sample) if sample.is_finite() => {
                // Replace the current best only on strict improvement.
                best = match best {
                    Some(current) if current >= sample => Some(current),
                    _ => Some(sample),
                };
            }
            Ok(_) => {}
            Err(e) => last_err = Some(e),
        }
    }

    match best {
        Some(value) => Ok(value),
        None => Err(last_err.unwrap_or_else(|| format!("no usable sample across {n} iterations"))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a calibration from its three rates (fp64 GFLOP/s, h2d GB/s, d2h GB/s).
    fn calibration(fp64_gflops: f64, h2d_gb_s: f64, d2h_gb_s: f64) -> DeviceCalibration {
        DeviceCalibration {
            fp64_gflops,
            h2d_gb_s,
            d2h_gb_s,
        }
    }

    #[test]
    fn calibration_usable_rejects_nonpositive_values() {
        assert!(calibration(200.0, 12.0, 12.0).is_usable());
        assert!(!calibration(0.0, 12.0, 12.0).is_usable());
        assert!(!calibration(f64::NAN, 12.0, 12.0).is_usable());
    }

    // The original test only exercised `fp64_gflops`. Check that every field
    // is validated and that negative and infinite values are rejected too.
    #[test]
    fn calibration_usable_checks_every_field() {
        for bad in [0.0, -1.0, f64::NAN, f64::INFINITY] {
            assert!(!calibration(bad, 12.0, 12.0).is_usable(), "fp64_gflops={bad}");
            assert!(!calibration(200.0, bad, 12.0).is_usable(), "h2d_gb_s={bad}");
            assert!(!calibration(200.0, 12.0, bad).is_usable(), "d2h_gb_s={bad}");
        }
    }

    #[test]
    fn bytes_per_sec_reports_decimal_gigabytes_per_second() {
        assert_eq!(bytes_per_sec(2_000_000_000, 0.5), 4.0);
        assert_eq!(bytes_per_sec(1024, 0.0), 0.0);
        assert_eq!(bytes_per_sec(1024, -1.0), 0.0);
    }

    #[test]
    fn best_of_n_keeps_largest_finite_sample() {
        let mut samples = vec![
            Ok(1.0),
            Ok(f64::NAN),
            Ok(3.0),
            Err("boom".to_string()),
            Ok(2.0),
        ]
        .into_iter();
        let best = best_of_n_transfer(5, || samples.next().expect("5 samples queued"));
        assert_eq!(best, Ok(3.0));
    }

    #[test]
    fn best_of_n_reports_last_error_when_nothing_usable() {
        let mut samples = vec![
            Err("first".to_string()),
            Ok(f64::INFINITY),
            Err("second".to_string()),
        ]
        .into_iter();
        assert_eq!(
            best_of_n_transfer(3, || samples.next().expect("3 samples queued")),
            Err("second".to_string())
        );
        assert_eq!(
            best_of_n_transfer(0, || Ok(1.0)),
            Err("no usable sample across 0 iterations".to_string())
        );
    }

    #[test]
    fn cpu_fp64_calibration_runs() {
        let gflops = measure_cpu_fp64_inner();
        assert!(gflops.is_finite());
        assert!(gflops > 0.0);
    }
}