use super::calibration::measured_cpu_fp64_gflops;
use super::device::GpuDeviceInfo;
/// Lower bound, in GFLOP/s, applied to the measured CPU FP64 throughput before
/// it is used to derive dispatch thresholds.
const CPU_FP64_GFLOPS_FLOOR: f64 = 5.0;
/// Lower bound, in GB/s, applied to a device's PCIe bandwidth before it is used
/// to derive dispatch thresholds.
const PCIE_GB_PER_S_FLOOR: f64 = 4.0;
/// Minimum GPU-over-CPU speedup assumed whenever a GPU is available: the
/// effective GPU throughput is never taken to be less than this multiple of the
/// CPU throughput, so thresholds stay finite even for a slow device.
const AVAILABLE_GPU_EFFECTIVE_SPEEDUP_FLOOR: f64 = 1.05;
/// Size thresholds that decide whether a linear-algebra workload is routed to
/// the GPU.
///
/// Each `route_*` method returns `true` once the workload reaches the matching
/// threshold. The CPU-only policy sets every threshold to the type's `MAX`
/// value.
#[derive(Clone, Debug)]
pub struct DispatchPolicy {
/// Minimum row count required by `route_xt_diag_y`.
pub xtwx_min_rows: usize,
/// Minimum estimated flop count used by `route_gemm`, `route_gemm_batched`,
/// `route_chol_batched` and `route_xt_diag_y`.
pub gemm_min_flops: u64,
/// Minimum estimated flop count used by `route_gemv`.
pub gemv_min_flops: u64,
/// Minimum number of non-zeros required by `route_csr_spmv`.
pub spmv_min_nnz: usize,
/// Minimum row count required by `route_csr_spmv`.
pub spmv_min_rows: usize,
/// Minimum dimension `p` required by `route_chol_solve`.
pub chol_min_p: usize,
/// Minimum dimension `p` required by `route_syevd`.
pub syevd_min_p: usize,
/// Minimum estimated flop count used by `route_trsm`.
pub trsm_min_flops: u64,
}
impl DispatchPolicy {
    /// Derives a policy for a group of GPUs.
    ///
    /// An empty slice yields the CPU-only policy. Otherwise the devices' peak
    /// FP64 throughputs are summed, the smallest PCIe bandwidth among them
    /// (floored at `PCIE_GB_PER_S_FLOOR`) is used as the transfer rate, and the
    /// measured CPU throughput is floored at `CPU_FP64_GFLOPS_FLOOR`.
    pub fn for_devices(devices: &[GpuDeviceInfo]) -> Self {
        if devices.is_empty() {
            return Self::cpu_only();
        }
        let combined_gflops: f64 = devices.iter().map(|gpu| gpu.peak_fp64_gflops()).sum();
        let slowest_link = devices
            .iter()
            .map(|gpu| gpu.pcie_gb_per_s())
            .fold(f64::INFINITY, |slowest, link| slowest.min(link))
            .max(PCIE_GB_PER_S_FLOOR);
        let host_gflops = measured_cpu_fp64_gflops().max(CPU_FP64_GFLOPS_FLOOR);
        Self::from_measurements(combined_gflops, host_gflops, slowest_link)
    }

    /// Derives a policy for a single optional GPU.
    ///
    /// `None` yields the CPU-only policy; otherwise the device's peak FP64
    /// throughput and floored PCIe bandwidth are combined with the floored
    /// measured CPU throughput.
    pub fn for_device(device: Option<&GpuDeviceInfo>) -> Self {
        match device {
            None => Self::cpu_only(),
            Some(gpu) => {
                let gpu_gflops = gpu.peak_fp64_gflops();
                let host_gflops = measured_cpu_fp64_gflops().max(CPU_FP64_GFLOPS_FLOOR);
                let link = gpu.pcie_gb_per_s().max(PCIE_GB_PER_S_FLOOR);
                Self::from_measurements(gpu_gflops, host_gflops, link)
            }
        }
    }

    /// Computes every threshold from raw throughput and bandwidth figures.
    ///
    /// The GPU throughput is raised to at least
    /// `AVAILABLE_GPU_EFFECTIVE_SPEEDUP_FLOOR` times the CPU throughput, so the
    /// flop thresholds produced by `flops_threshold` stay finite for a slow
    /// device. Dimension thresholds are a base size divided by the speedup and
    /// then bounded.
    fn from_measurements(peak_gpu_gflops: f64, cpu_gflops: f64, pcie_gb_per_s: f64) -> Self {
        const MIB: f64 = 1024.0 * 1024.0;

        let gpu_gflops = peak_gpu_gflops.max(cpu_gflops * AVAILABLE_GPU_EFFECTIVE_SPEEDUP_FLOOR);
        let speedup = (gpu_gflops / cpu_gflops).max(AVAILABLE_GPU_EFFECTIVE_SPEEDUP_FLOOR);

        // Flop threshold for a given transfer payload size.
        let flops_for_payload = |payload_bytes: f64| {
            flops_threshold(payload_bytes, gpu_gflops, cpu_gflops, pcie_gb_per_s)
        };
        // Dimension threshold: base size scaled down by the speedup, then bounded.
        let scaled_dim =
            |base: f64, lo: f64, hi: f64| usize_threshold((base / speedup).clamp(lo, hi));

        Self {
            xtwx_min_rows: scaled_dim(4096.0, 512.0, 65_536.0),
            gemm_min_flops: flops_for_payload(32.0 * MIB),
            gemv_min_flops: flops_for_payload(16.0 * MIB),
            spmv_min_nnz: usize_threshold((1_000_000.0 / speedup).max(100_000.0)),
            spmv_min_rows: 1_024,
            chol_min_p: scaled_dim(4096.0, 128.0, 8_192.0),
            syevd_min_p: scaled_dim(2048.0, 64.0, 4_096.0),
            trsm_min_flops: flops_for_payload(16.0 * MIB),
        }
    }

    /// Policy whose thresholds are all `MAX`, used when no GPU is available.
    fn cpu_only() -> Self {
        let (never_dim, never_flops) = (usize::MAX, u64::MAX);
        Self {
            xtwx_min_rows: never_dim,
            gemm_min_flops: never_flops,
            gemv_min_flops: never_flops,
            spmv_min_nnz: never_dim,
            spmv_min_rows: never_dim,
            chol_min_p: never_dim,
            syevd_min_p: never_dim,
            trsm_min_flops: never_flops,
        }
    }

    /// Saturating flop estimate `2 * a * b * max(c, 1)` shared by the
    /// matrix-product style routes.
    fn product_flops(a: usize, b: usize, c: usize) -> u64 {
        (a as u64)
            .saturating_mul(b as u64)
            .saturating_mul(c.max(1) as u64)
            .saturating_mul(2)
    }

    /// Returns `true` when `p` reaches `chol_min_p`.
    pub fn route_chol_solve(&self, p: usize) -> bool {
        self.chol_min_p <= p
    }

    /// Returns `true` when a batch of `batch_size` problems of dimension `p`
    /// reaches `gemm_min_flops`, estimating the work as `batch_size * (p^3 / 3)`
    /// with saturating integer arithmetic. Empty batches and `p == 0` never
    /// route.
    pub fn route_chol_batched(&self, p: usize, batch_size: usize) -> bool {
        let side = p as u64;
        let cube = side.saturating_mul(side).saturating_mul(side);
        let work = (batch_size as u64).saturating_mul(cube / 3);
        p != 0 && batch_size != 0 && work >= self.gemm_min_flops
    }

    /// Returns `true` when `p` reaches `syevd_min_p`.
    pub fn route_syevd(&self, p: usize) -> bool {
        self.syevd_min_p <= p
    }

    /// Returns `true` when the estimate `p * p * max(rhs_cols, 1)` reaches
    /// `trsm_min_flops`.
    pub fn route_trsm(&self, p: usize, rhs_cols: usize) -> bool {
        let side = p as u64;
        let work = side
            .saturating_mul(side)
            .saturating_mul(rhs_cols.max(1) as u64);
        work >= self.trsm_min_flops
    }

    /// Returns `true` when `rows` reaches `xtwx_min_rows` and the estimate
    /// `2 * rows * lhs_cols * max(rhs_cols, 1)` reaches `gemm_min_flops`.
    pub fn route_xt_diag_y(&self, rows: usize, lhs_cols: usize, rhs_cols: usize) -> bool {
        let work = Self::product_flops(rows, lhs_cols, rhs_cols);
        rows >= self.xtwx_min_rows && work >= self.gemm_min_flops
    }

    /// Returns `true` when the estimate `2 * m * n * max(k, 1)` reaches
    /// `gemm_min_flops`.
    pub fn route_gemm(&self, m: usize, n: usize, k: usize) -> bool {
        Self::product_flops(m, n, k) >= self.gemm_min_flops
    }

    /// Batched variant of `route_gemm`: the per-problem estimate is multiplied
    /// by `batch_size`; an empty batch never routes.
    pub fn route_gemm_batched(&self, m: usize, n: usize, k: usize, batch_size: usize) -> bool {
        let work = Self::product_flops(m, n, k).saturating_mul(batch_size as u64);
        batch_size != 0 && work >= self.gemm_min_flops
    }

    /// Returns `true` when the estimate `2 * rows * cols` reaches
    /// `gemv_min_flops`.
    pub fn route_gemv(&self, rows: usize, cols: usize) -> bool {
        let work = (rows as u64).saturating_mul(cols as u64).saturating_mul(2);
        work >= self.gemv_min_flops
    }

    /// Returns `true` when both `rows` and `nnz` reach their thresholds; the
    /// column count is not consulted.
    pub fn route_csr_spmv(&self, rows: usize, _cols: usize, nnz: usize) -> bool {
        let enough_rows = rows >= self.spmv_min_rows;
        let enough_nnz = nnz >= self.spmv_min_nnz;
        enough_rows && enough_nnz
    }
}
/// Flop count at which GPU execution (including the transfer of
/// `payload_bytes`) takes as long as CPU execution.
///
/// With rates converted from giga-units to units per second, the value returned
/// is the `F` solving `F / cpu = F / gpu + payload_bytes / pcie`. When the GPU
/// is not strictly faster than the CPU there is no crossover and
/// `f64::INFINITY` is returned.
fn crossover_flops(
    payload_bytes: f64,
    peak_gpu_gflops: f64,
    cpu_gflops: f64,
    pcie_gb_per_s: f64,
) -> f64 {
    const GIGA: f64 = 1e9;

    if peak_gpu_gflops <= cpu_gflops {
        f64::INFINITY
    } else {
        let host_rate = cpu_gflops * GIGA;
        let device_rate = peak_gpu_gflops * GIGA;
        let link_rate = pcie_gb_per_s * GIGA;
        let numerator = payload_bytes * host_rate * device_rate;
        let denominator = link_rate * (device_rate - host_rate);
        numerator / denominator
    }
}
/// Converts the crossover flop count for the given measurements into an integer
/// threshold.
///
/// The crossover is rounded up. Non-finite values and values at or beyond the
/// `u64` range become `u64::MAX`; values at or below zero become `0`.
fn flops_threshold(
    payload_bytes: f64,
    peak_gpu_gflops: f64,
    cpu_gflops: f64,
    pcie_gb_per_s: f64,
) -> u64 {
    let rounded = crossover_flops(payload_bytes, peak_gpu_gflops, cpu_gflops, pcie_gb_per_s).ceil();
    match rounded {
        t if !t.is_finite() || t >= u64::MAX as f64 => u64::MAX,
        t if t <= 0.0 => 0,
        t => t as u64,
    }
}
/// Rounds `value` up and converts it to a `usize` threshold.
///
/// Non-finite inputs (NaN or either infinity) and values at or beyond the
/// `usize` range map to `usize::MAX`; values at or below zero map to `0`.
fn usize_threshold(value: f64) -> usize {
    let rounded = value.ceil();
    if rounded.is_finite() && rounded < usize::MAX as f64 {
        if rounded > 0.0 {
            rounded as usize
        } else {
            0
        }
    } else {
        usize::MAX
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a synthetic device whose FP64 throughput is `sms` times a per-SM
    /// rate chosen from the compute-capability major version.
    fn device(major: i32, sms: i32) -> GpuDeviceInfo {
        use super::super::calibration::DeviceCalibration;

        let gflops_per_sm = match major {
            m if m >= 9 => 200.0,
            m if m >= 8 => 80.0,
            _ => 6.0,
        };
        let calibration = DeviceCalibration {
            fp64_gflops: f64::from(sms) * gflops_per_sm,
            h2d_gb_s: 25.0,
            d2h_gb_s: 25.0,
        };
        GpuDeviceInfo {
            ordinal: 0,
            name: "test-device".to_string(),
            compute_capability_major: major,
            compute_capability_minor: 0,
            sm_count: sms,
            total_memory_bytes: 16 * 1024 * 1024 * 1024,
            calibration,
        }
    }

    /// Policy for a single synthetic device.
    fn policy_for(major: i32, sms: i32) -> DispatchPolicy {
        DispatchPolicy::for_device(Some(&device(major, sms)))
    }

    #[test]
    fn faster_device_lowers_thresholds() {
        let slower = policy_for(7, 40);
        let faster = policy_for(9, 132);
        assert!(faster.gemm_min_flops < slower.gemm_min_flops);
        assert!(faster.gemv_min_flops < slower.gemv_min_flops);
        assert!(faster.xtwx_min_rows <= slower.xtwx_min_rows);
    }

    #[test]
    fn aggregate_devices_lower_batched_thresholds() {
        let single = DispatchPolicy::for_devices(&[device(7, 40)]);
        // Four identical devices, distinguished only by their ordinals 0..=3.
        let gpus: Vec<GpuDeviceInfo> = (0..4)
            .map(|ordinal| GpuDeviceInfo {
                ordinal,
                ..device(7, 40)
            })
            .collect();
        let fleet = DispatchPolicy::for_devices(&gpus);
        assert!(fleet.gemm_min_flops < single.gemm_min_flops);
        assert!(fleet.route_gemm_batched(512, 512, 512, 16));
    }

    #[test]
    fn cpu_only_policy_never_routes() {
        let policy = DispatchPolicy::for_device(None);
        assert!(!policy.route_gemm(1_000_000, 1_000_000, 1_000_000));
        assert!(!policy.route_gemv(1_000_000, 1_000_000));
        assert!(!policy.route_xt_diag_y(1_000_000, 1_000, 1_000));
        assert!(!policy.route_csr_spmv(1_000_000, 1_000_000, 1_000_000_000));
        assert!(!policy.route_chol_solve(1_000_000));
        assert!(!policy.route_syevd(1_000_000));
        assert!(!policy.route_trsm(1_000_000, 1_000_000));
    }

    #[test]
    fn slow_available_gpu_still_routes_bulk_work() {
        // GPU peak below the CPU rate: the speedup floor must keep thresholds finite.
        let (gpu_gflops, cpu_gflops, pcie_gb_per_s) = (20.0, 200.0, 16.0);
        let policy = DispatchPolicy::from_measurements(gpu_gflops, cpu_gflops, pcie_gb_per_s);
        assert!(policy.gemm_min_flops < u64::MAX);
        assert!(policy.gemv_min_flops < u64::MAX);
        assert!(policy.trsm_min_flops < u64::MAX);
        assert!(!policy.route_gemm(128, 128, 128));
        assert!(policy.route_gemm(8_192, 8_192, 8_192));
        assert!(policy.route_xt_diag_y(1_000_000, 512, 512));
    }

    #[test]
    fn route_xt_diag_y_uses_shape_only() {
        let policy = policy_for(8, 108);
        assert!(!policy.route_xt_diag_y(128, 16, 16));
        assert!(policy.route_xt_diag_y(1_000_000, 512, 512));
    }

    #[test]
    fn route_gemm_and_gemv_use_separate_thresholds() {
        let policy = policy_for(8, 108);
        assert!(!policy.route_gemm(128, 128, 128));
        assert!(policy.route_gemm(4_096, 4_096, 4_096));
        assert!(!policy.route_gemv(1_024, 1_024));
        assert!(policy.route_gemv(16_384, 16_384));
    }

    #[test]
    fn route_csr_spmv_uses_device_threshold() {
        let policy = policy_for(8, 108);
        assert!(!policy.route_csr_spmv(10_000, 1_000, 1_024));
        assert!(policy.route_csr_spmv(10_000, 1_000, 1_000_000));
    }

    #[test]
    fn route_cusolver_uses_device_thresholds() {
        let policy = policy_for(8, 108);
        let (chol_min, syevd_min) = (policy.chol_min_p, policy.syevd_min_p);
        assert!(!policy.route_chol_solve(chol_min.saturating_sub(1)));
        assert!(policy.route_chol_solve(chol_min));
        assert!(!policy.route_syevd(syevd_min.saturating_sub(1)));
        assert!(policy.route_syevd(syevd_min));
        assert!(!policy.route_trsm(128, 128));
        assert!(policy.route_trsm(8_192, 8_192));
    }
}