use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Instant;
use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, sys as cublas_sys};
use cudarc::driver::CudaContext;
/// Throughput figures for a single CUDA device, as produced by [`measure_device`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct DeviceCalibration {
    /// FP64 matrix-multiply throughput, in GFLOP/s.
    pub fp64_gflops: f64,
    /// Host-to-device copy bandwidth, in GB/s (10^9 bytes per second).
    pub h2d_gb_s: f64,
    /// Device-to-host copy bandwidth, in GB/s (10^9 bytes per second).
    pub d2h_gb_s: f64,
}

impl DeviceCalibration {
    /// Returns `true` only if all three figures are finite and strictly positive.
    ///
    /// Zero, negative, infinite and NaN values all make the calibration unusable.
    pub fn is_usable(&self) -> bool {
        let figures = [self.fp64_gflops, self.h2d_gb_s, self.d2h_gb_s];
        figures
            .iter()
            .all(|&figure| figure.is_finite() && figure > 0.0)
    }
}
/// Benchmarks the device behind `ctx` and returns its FP64 GEMM throughput and
/// host<->device copy bandwidth.
///
/// All work is issued on the context's default stream:
/// - host-to-device copy of a 32 MiB buffer: one untimed warm-up, then the best
///   of 3 timed copies;
/// - device-to-host copy of the same buffer: one untimed warm-up, then the best
///   of 3 timed copies;
/// - a 1024 x 1024 x 1024 `f64` cuBLAS GEMM: two untimed warm-ups, then the best
///   of 5 timed runs.
///
/// Each timed sample is host wall-clock time from issuing the work until
/// `stream.synchronize()` returns, so launch and synchronisation overhead are
/// included. NOTE(review): on devices with very high FP64 throughput a single
/// ~2 GFLOP GEMM is short, so that overhead may noticeably lower the reported
/// figure. Consider timing several GEMMs per sample if accuracy matters there.
///
/// The host buffers are ordinary `Vec`s, not page-locked allocations, so the
/// bandwidth figures describe copies to and from pageable host memory.
/// NOTE(review): confirm this matches how callers actually move data.
///
/// Returns `None` in any of these cases:
/// - binding the context, creating the cuBLAS handle, or any allocation, upload
///   or synchronisation outside the timed loops fails;
/// - every timed sample of a phase fails;
/// - the resulting figures are not [`DeviceCalibration::is_usable`].
pub fn measure_device(ctx: Arc<CudaContext>) -> Option<DeviceCalibration> {
    ctx.bind_to_thread().ok()?;
    let stream = ctx.default_stream();
    let blas = CudaBlas::new(stream.clone()).ok()?;

    // GEMM shape: C (M x N) = A (M x K) * B (K x N).
    const M: usize = 1024;
    const N: usize = 1024;
    const K: usize = 1024;
    // Bandwidth test buffer: 32 MiB worth of f64 values.
    const TRANSFER_BYTES: usize = 32 * 1024 * 1024;
    const TRANSFER_F64: usize = TRANSFER_BYTES / std::mem::size_of::<f64>();

    let a_host: Vec<f64> = (0..M * K).map(|i| (i as f64).sin()).collect();
    let b_host: Vec<f64> = (0..K * N).map(|i| (i as f64).cos()).collect();
    let mut a_dev = stream.alloc_zeros::<f64>(M * K).ok()?;
    let mut b_dev = stream.alloc_zeros::<f64>(K * N).ok()?;
    let mut c_dev = stream.alloc_zeros::<f64>(M * N).ok()?;
    let transfer_host: Vec<f64> = vec![0.0_f64; TRANSFER_F64];
    let mut transfer_dev = stream.alloc_zeros::<f64>(TRANSFER_F64).ok()?;

    // Host -> device bandwidth. The untimed copy keeps any one-off first-use
    // cost out of the timed samples.
    stream
        .memcpy_htod(transfer_host.as_slice(), &mut transfer_dev)
        .ok()?;
    stream.synchronize().ok()?;
    let h2d_gb_s = best_of_n_transfer(3, || {
        let start = Instant::now();
        stream
            .memcpy_htod(transfer_host.as_slice(), &mut transfer_dev)
            .ok()?;
        stream.synchronize().ok()?;
        Some(bytes_per_sec(TRANSFER_BYTES, start.elapsed().as_secs_f64()))
    })?;

    // Device -> host bandwidth, with the same warm-up scheme.
    let mut transfer_back: Vec<f64> = vec![0.0_f64; TRANSFER_F64];
    stream
        .memcpy_dtoh(&transfer_dev, transfer_back.as_mut_slice())
        .ok()?;
    stream.synchronize().ok()?;
    let d2h_gb_s = best_of_n_transfer(3, || {
        let start = Instant::now();
        stream
            .memcpy_dtoh(&transfer_dev, transfer_back.as_mut_slice())
            .ok()?;
        stream.synchronize().ok()?;
        Some(bytes_per_sec(TRANSFER_BYTES, start.elapsed().as_secs_f64()))
    })?;

    // FP64 GEMM throughput: C = 1.0 * A * B + 0.0 * C, no transposes.
    stream.memcpy_htod(a_host.as_slice(), &mut a_dev).ok()?;
    stream.memcpy_htod(b_host.as_slice(), &mut b_dev).ok()?;
    stream.synchronize().ok()?;
    let m_i = i32::try_from(M).ok()?;
    let n_i = i32::try_from(N).ok()?;
    let k_i = i32::try_from(K).ok()?;
    let cfg = GemmConfig::<f64> {
        transa: cublas_sys::cublasOperation_t::CUBLAS_OP_N,
        transb: cublas_sys::cublasOperation_t::CUBLAS_OP_N,
        m: m_i,
        n: n_i,
        k: k_i,
        alpha: 1.0_f64,
        lda: m_i,
        ldb: k_i,
        beta: 0.0_f64,
        ldc: m_i,
    };
    // Untimed warm-up GEMMs.
    for _ in 0..2 {
        // SAFETY: a_dev (M*K), b_dev (K*N) and c_dev (M*N) are allocated with
        // exactly the element counts implied by cfg's m/n/k and leading
        // dimensions (lda = M, ldb = K, ldc = M, no transposes).
        unsafe { blas.gemm(cfg, &a_dev, &b_dev, &mut c_dev) }.ok()?;
        stream.synchronize().ok()?;
    }
    let flops = 2.0_f64 * (M as f64) * (N as f64) * (K as f64);
    let fp64_gflops = best_of_n_transfer(5, || {
        let start = Instant::now();
        // SAFETY: same buffers and configuration as the warm-up GEMMs above.
        unsafe { blas.gemm(cfg, &a_dev, &b_dev, &mut c_dev) }.ok()?;
        stream.synchronize().ok()?;
        let elapsed = start.elapsed().as_secs_f64();
        if elapsed <= 0.0 {
            return None;
        }
        Some(flops / elapsed / 1e9)
    })?;

    // C is never read, so it is not copied back to the host. A final
    // synchronize still reports any error pending on the stream, for example
    // from a sample that `best_of_n_transfer` skipped.
    stream.synchronize().ok()?;

    let calibration = DeviceCalibration {
        fp64_gflops,
        h2d_gb_s,
        d2h_gb_s,
    };
    calibration.is_usable().then_some(calibration)
}
/// Returns this machine's measured FP64 matrix-multiply throughput in GFLOP/s.
///
/// The CPU benchmark (`measure_cpu_fp64_inner`) runs on the first call only.
/// Every later call, from any thread, returns that first result from a
/// process-wide `OnceLock`.
pub fn measured_cpu_fp64_gflops() -> f64 {
    static CPU_FP64_GFLOPS: OnceLock<f64> = OnceLock::new();
    let gflops: &f64 = CPU_FP64_GFLOPS.get_or_init(measure_cpu_fp64_inner);
    *gflops
}
/// Benchmarks a 512 x 512 `f64` matrix multiply with faer and returns the best
/// observed throughput in GFLOP/s.
///
/// The multiply runs in parallel via `faer::Par::rayon(0)`. Two untimed
/// warm-up runs come first, then five timed runs whose fastest rate is kept.
/// If no timed run yields a finite, positive rate, a fixed fallback of
/// 10.0 GFLOP/s is returned instead.
fn measure_cpu_fp64_inner() -> f64 {
    use faer::{Mat, linalg::matmul::matmul};

    // Square matrix dimension of the benchmark product.
    const DIM: usize = 512;
    const WARMUP_RUNS: usize = 2;
    const TIMED_RUNS: usize = 5;
    // Reported when no usable timing could be obtained.
    const FALLBACK_GFLOPS: f64 = 10.0;

    let lhs = Mat::<f64>::from_fn(DIM, DIM, |row, col| ((row + col) as f64).sin());
    let rhs = Mat::<f64>::from_fn(DIM, DIM, |row, col| ((row * 3 + col) as f64).cos());
    let mut product = Mat::<f64>::zeros(DIM, DIM);

    // One full `product = 1.0 * lhs * rhs`, overwriting the previous contents.
    let mut multiply = || {
        matmul(
            product.as_mut(),
            faer::Accum::Replace,
            lhs.as_ref(),
            rhs.as_ref(),
            1.0,
            faer::Par::rayon(0),
        );
    };

    for _ in 0..WARMUP_RUNS {
        multiply();
    }

    let flops_per_run = 2.0_f64 * (DIM as f64).powi(3);
    let best_gflops = (0..TIMED_RUNS)
        .filter_map(|_| {
            let started = Instant::now();
            multiply();
            let seconds = started.elapsed().as_secs_f64();
            // A zero-length measurement carries no rate information.
            if seconds > 0.0 {
                Some(flops_per_run / seconds / 1e9)
            } else {
                None
            }
        })
        .fold(0.0_f64, |best, rate| if rate > best { rate } else { best });

    if best_gflops.is_finite() && best_gflops > 0.0 {
        best_gflops
    } else {
        FALLBACK_GFLOPS
    }
}
/// Converts `bytes` moved in `seconds` into a transfer rate in GB/s
/// (10^9 bytes per second). Despite the name, the result is in gigabytes.
///
/// A zero or negative duration yields `0.0`. A NaN duration yields NaN.
fn bytes_per_sec(bytes: usize, seconds: f64) -> f64 {
    if seconds <= 0.0 {
        return 0.0;
    }
    let byte_count = bytes as f64;
    byte_count / seconds / 1e9
}
/// Calls `f` exactly `n` times and returns the largest finite value it produced.
///
/// Samples for which `f` returns `None`, NaN or an infinity are ignored. If
/// there are no finite samples, including when `n == 0`, the result is `None`.
/// Despite the name, the helper is generic: it is also used to pick the best
/// GEMM throughput sample.
fn best_of_n_transfer<F>(n: usize, mut f: F) -> Option<f64>
where
    F: FnMut() -> Option<f64>,
{
    (0..n)
        .filter_map(|_| f())
        .filter(|sample| sample.is_finite())
        .fold(None, |best: Option<f64>, sample| match best {
            // Keep the current best on ties, matching a strict `>` update rule.
            Some(current) if current >= sample => Some(current),
            _ => Some(sample),
        })
}
/// Unit tests for the host-side parts of the calibration code. The GPU path in
/// `measure_device` needs real hardware and is not exercised here.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn calibration_usable_rejects_nonpositive_values() {
        assert!(
            DeviceCalibration {
                fp64_gflops: 200.0,
                h2d_gb_s: 12.0,
                d2h_gb_s: 12.0,
            }
            .is_usable()
        );
        assert!(
            !DeviceCalibration {
                fp64_gflops: 0.0,
                h2d_gb_s: 12.0,
                d2h_gb_s: 12.0,
            }
            .is_usable()
        );
        assert!(
            !DeviceCalibration {
                fp64_gflops: f64::NAN,
                h2d_gb_s: 12.0,
                d2h_gb_s: 12.0,
            }
            .is_usable()
        );
    }

    #[test]
    fn calibration_usable_rejects_bad_transfer_rates() {
        // Each transfer field must be checked too, not just fp64_gflops.
        for (h2d, d2h) in [
            (0.0, 12.0),
            (12.0, 0.0),
            (-1.0, 12.0),
            (12.0, -1.0),
            (f64::INFINITY, 12.0),
            (12.0, f64::NAN),
        ] {
            assert!(
                !DeviceCalibration {
                    fp64_gflops: 200.0,
                    h2d_gb_s: h2d,
                    d2h_gb_s: d2h,
                }
                .is_usable(),
                "h2d={h2d}, d2h={d2h} should be rejected"
            );
        }
    }

    #[test]
    fn bytes_per_sec_reports_gigabytes_per_second() {
        assert_eq!(bytes_per_sec(2_000_000_000, 2.0), 1.0);
        assert_eq!(bytes_per_sec(1024, 0.0), 0.0);
        assert_eq!(bytes_per_sec(1024, -1.0), 0.0);
    }

    #[test]
    fn best_of_n_transfer_keeps_largest_finite_sample() {
        let mut samples = vec![
            Some(1.0),
            None,
            Some(f64::INFINITY),
            Some(3.0),
            Some(f64::NAN),
            Some(2.0),
        ]
        .into_iter();
        assert_eq!(
            best_of_n_transfer(6, || samples.next().flatten()),
            Some(3.0)
        );
    }

    #[test]
    fn best_of_n_transfer_returns_none_without_finite_samples() {
        assert_eq!(best_of_n_transfer(0, || Some(1.0)), None);
        assert_eq!(best_of_n_transfer(3, || None), None);
        assert_eq!(best_of_n_transfer(3, || Some(f64::NAN)), None);
    }

    #[test]
    fn cpu_fp64_calibration_runs() {
        let gflops = measure_cpu_fp64_inner();
        assert!(gflops.is_finite());
        assert!(gflops > 0.0);
    }
}