use super::{GpuDeviceInfo, GpuDeviceType};
use crate::error::LinalgResult;
/// Feature flags and launch limits describing what a GPU device can do.
///
/// Values are filled in by the per-backend `estimate_*_specs` helpers in this
/// module; they are heuristics rather than values queried from a driver.
#[derive(Debug, Clone, Copy)]
pub struct DeviceCapabilities {
    /// Whether double-precision (fp64) arithmetic is available.
    pub supports_fp64: bool,
    /// Whether half-precision (fp16) arithmetic is available.
    pub supports_fp16: bool,
    /// Whether the device is treated as supporting unified memory.
    pub supports_unified_memory: bool,
    /// Whether the device is treated as supporting peer-to-peer transfers.
    pub supports_p2p: bool,
    /// Upper bound on threads (work items) per block / work group.
    pub max_threads_per_block: usize,
    /// Shared (local) memory limit; presumably in bytes — TODO confirm.
    pub max_shared_memory: usize,
    /// Warp / wavefront / subgroup width in threads.
    pub warpsize: usize,
}
/// Estimated performance characteristics of a GPU device.
///
/// The figures are produced by the per-backend `estimate_*_specs` helpers in
/// this module from simple formulas; they are not measured.
#[derive(Debug, Clone)]
pub struct DevicePerformance {
    /// Estimated memory bandwidth; presumably GB/s — TODO confirm units.
    pub memory_bandwidth: f64,
    /// Estimated peak single-precision throughput in GFLOPS.
    pub peak_gflops_fp32: f64,
    /// Estimated peak double-precision throughput in GFLOPS.
    pub peak_gflops_fp64: f64,
    /// Estimated memory access latency in nanoseconds.
    pub memory_latency_ns: f64,
    /// Estimated cache size; presumably in bytes — TODO confirm.
    pub cachesize: usize,
}
/// A [`GpuDeviceInfo`] bundled with estimated capabilities and performance.
///
/// Construct it with `ExtendedDeviceInfo::from_basic`, which derives the two
/// estimate structs from the basic info.
#[derive(Debug, Clone)]
pub struct ExtendedDeviceInfo {
    /// The device description this value was built from.
    pub basic_info: GpuDeviceInfo,
    /// Estimated feature flags and launch limits for the device.
    pub capabilities: DeviceCapabilities,
    /// Estimated throughput, bandwidth and latency figures for the device.
    pub performance: DevicePerformance,
}
impl ExtendedDeviceInfo {
    /// Builds extended device information from basic device info.
    ///
    /// The capability and performance estimates are produced by the
    /// backend-specific estimator selected by `basicinfo.device_type`.
    pub fn from_basic(basicinfo: GpuDeviceInfo) -> Self {
        let (capabilities, performance) = match basicinfo.device_type {
            GpuDeviceType::Cuda => estimate_cuda_specs(&basicinfo),
            GpuDeviceType::OpenCl => estimate_opencl_specs(&basicinfo),
            GpuDeviceType::Rocm => estimate_rocm_specs(&basicinfo),
            GpuDeviceType::Vulkan => estimate_vulkan_specs(&basicinfo),
            GpuDeviceType::Metal => estimate_metal_specs(&basicinfo),
            GpuDeviceType::OneApi => estimate_oneapi_specs(&basicinfo),
            GpuDeviceType::WebGpu => estimate_webgpu_specs(&basicinfo),
        };
        Self {
            basic_info: basicinfo,
            capabilities,
            performance,
        }
    }

    /// Reports whether this device can take a workload of `elements` values.
    ///
    /// A workload is rejected when either
    /// * its estimated footprint (`elements` × 8 bytes) exceeds half of
    ///   `basic_info.total_memory`, or
    /// * `requiresfp64` is set but the device lacks fp64 support.
    ///
    /// Element counts so large that the byte footprint does not fit in a
    /// `usize` are always rejected instead of overflowing.
    pub fn is_suitable_for_workload(&self, elements: usize, requiresfp64: bool) -> bool {
        // Every element is budgeted at 8 bytes regardless of `requiresfp64`
        // (presumably the size of an f64 as a conservative bound — TODO confirm).
        const BYTES_PER_ELEMENT: usize = 8;

        // Saturate rather than multiply directly: a plain `elements * 8` panics
        // in debug builds and wraps in release builds, where a wrapped (small)
        // value could make an impossibly large workload look acceptable.
        let memory_required = elements.saturating_mul(BYTES_PER_ELEMENT);

        // Only half of the device memory is considered usable for the workload.
        let memory_budget = self.basic_info.total_memory / 2;
        if memory_required > memory_budget {
            return false;
        }

        !requiresfp64 || self.capabilities.supports_fp64
    }

    /// Estimates the time, in seconds, of an `m × k` by `k × n` matrix multiply.
    ///
    /// The operation count is `2 * m * n * k`, and the device is assumed to
    /// sustain half of its estimated peak fp32 throughput. Despite the name,
    /// the result is a duration (smaller is better), not a throughput figure.
    ///
    /// If the estimated peak fp32 throughput is zero the result is not finite.
    pub fn estimate_matmul_performance(&self, m: usize, n: usize, k: usize) -> f64 {
        let ops = 2.0 * m as f64 * n as f64 * k as f64;
        let peak_ops_per_sec = self.performance.peak_gflops_fp32 * 1e9;
        // 0.5 is the assumed fraction of peak throughput actually achieved.
        ops / (peak_ops_per_sec * 0.5)
    }

    /// Returns the recommended two-dimensional block (tile) size for the
    /// device's backend.
    pub fn recommended_blocksize(&self) -> (usize, usize) {
        // Deliberately no wildcard arm: adding a backend variant must force a
        // decision here at compile time.
        match self.basic_info.device_type {
            GpuDeviceType::Cuda | GpuDeviceType::Rocm | GpuDeviceType::Vulkan => (32, 32),
            GpuDeviceType::OpenCl | GpuDeviceType::Metal | GpuDeviceType::OneApi => (16, 16),
            GpuDeviceType::WebGpu => (8, 8),
        }
    }
}
/// Produces heuristic capability and performance figures for a CUDA device.
///
/// fp64/fp16 support is copied from `info`; every other capability is a fixed
/// value. Peak fp32 throughput is modelled as
/// `compute_units * 64 * clock_frequency * 2 / 1000`
/// (presumably `clock_frequency` is in MHz so the result is GFLOPS — TODO confirm).
#[allow(dead_code)]
fn estimate_cuda_specs(info: &GpuDeviceInfo) -> (DeviceCapabilities, DevicePerformance) {
    let lanes = info.compute_units * 64;
    let fp32_gflops = lanes as f64 * info.clock_frequency as f64 * 2.0 / 1000.0;
    // Bandwidth is scaled from memory size: 10 units per 1e9 of `total_memory`.
    let memory_scaled = info.total_memory as f64 / 1e9;

    (
        DeviceCapabilities {
            supports_fp64: info.supports_fp64,
            supports_fp16: info.supports_fp16,
            supports_unified_memory: true,
            supports_p2p: true,
            max_threads_per_block: 1024,
            max_shared_memory: 48 * 1024,
            warpsize: 32,
        },
        DevicePerformance {
            memory_bandwidth: memory_scaled * 10.0,
            peak_gflops_fp32: fp32_gflops,
            // fp64 is modelled at half the fp32 rate.
            peak_gflops_fp64: fp32_gflops * 0.5,
            memory_latency_ns: 400.0,
            cachesize: 256 * 1024,
        },
    )
}
/// Produces heuristic capability and performance figures for an OpenCL device.
///
/// fp64/fp16 support and the work-group limit are copied from `info`; the
/// remaining capabilities are fixed values. Peak fp32 throughput is modelled as
/// `compute_units * 64 * clock_frequency * 2 / 1000`
/// (presumably `clock_frequency` is in MHz so the result is GFLOPS — TODO confirm).
#[allow(dead_code)]
fn estimate_opencl_specs(info: &GpuDeviceInfo) -> (DeviceCapabilities, DevicePerformance) {
    let lanes = info.compute_units * 64;
    let fp32_gflops = lanes as f64 * info.clock_frequency as f64 * 2.0 / 1000.0;
    // Bandwidth is scaled from memory size: 8 units per 1e9 of `total_memory`.
    let memory_scaled = info.total_memory as f64 / 1e9;

    (
        DeviceCapabilities {
            supports_fp64: info.supports_fp64,
            supports_fp16: info.supports_fp16,
            supports_unified_memory: false,
            supports_p2p: false,
            max_threads_per_block: info.max_work_groupsize,
            max_shared_memory: 32 * 1024,
            warpsize: 64,
        },
        DevicePerformance {
            memory_bandwidth: memory_scaled * 8.0,
            peak_gflops_fp32: fp32_gflops,
            // fp64 is modelled at a quarter of the fp32 rate.
            peak_gflops_fp64: fp32_gflops * 0.25,
            memory_latency_ns: 500.0,
            cachesize: 128 * 1024,
        },
    )
}
/// Produces heuristic capability and performance figures for a ROCm device.
///
/// fp64/fp16 support is copied from `info`; every other capability is a fixed
/// value. Peak fp32 throughput is modelled as
/// `compute_units * 64 * clock_frequency * 2 / 1000`
/// (presumably `clock_frequency` is in MHz so the result is GFLOPS — TODO confirm).
#[allow(dead_code)]
fn estimate_rocm_specs(info: &GpuDeviceInfo) -> (DeviceCapabilities, DevicePerformance) {
    let lanes = info.compute_units * 64;
    let fp32_gflops = lanes as f64 * info.clock_frequency as f64 * 2.0 / 1000.0;
    // Bandwidth is scaled from memory size: 12 units per 1e9 of `total_memory`.
    let memory_scaled = info.total_memory as f64 / 1e9;

    (
        DeviceCapabilities {
            supports_fp64: info.supports_fp64,
            supports_fp16: info.supports_fp16,
            supports_unified_memory: true,
            supports_p2p: true,
            max_threads_per_block: 1024,
            max_shared_memory: 64 * 1024,
            warpsize: 64,
        },
        DevicePerformance {
            memory_bandwidth: memory_scaled * 12.0,
            peak_gflops_fp32: fp32_gflops,
            // fp64 is modelled at half the fp32 rate.
            peak_gflops_fp64: fp32_gflops * 0.5,
            memory_latency_ns: 350.0,
            cachesize: 512 * 1024,
        },
    )
}
/// Produces heuristic capability and performance figures for a Vulkan device.
///
/// fp64/fp16 support and the work-group limit are copied from `info`; the
/// remaining capabilities are fixed values. Peak fp32 throughput is modelled as
/// `compute_units * 32 * clock_frequency * 2 / 1000`
/// (presumably `clock_frequency` is in MHz so the result is GFLOPS — TODO confirm).
#[allow(dead_code)]
fn estimate_vulkan_specs(info: &GpuDeviceInfo) -> (DeviceCapabilities, DevicePerformance) {
    let lanes = info.compute_units * 32;
    let fp32_gflops = lanes as f64 * info.clock_frequency as f64 * 2.0 / 1000.0;
    // Bandwidth is scaled from memory size: 6 units per 1e9 of `total_memory`.
    let memory_scaled = info.total_memory as f64 / 1e9;

    (
        DeviceCapabilities {
            supports_fp64: info.supports_fp64,
            supports_fp16: info.supports_fp16,
            supports_unified_memory: false,
            supports_p2p: false,
            max_threads_per_block: info.max_work_groupsize,
            max_shared_memory: 32 * 1024,
            warpsize: 32,
        },
        DevicePerformance {
            memory_bandwidth: memory_scaled * 6.0,
            peak_gflops_fp32: fp32_gflops,
            // fp64 is modelled at a quarter of the fp32 rate.
            peak_gflops_fp64: fp32_gflops * 0.25,
            memory_latency_ns: 600.0,
            cachesize: 64 * 1024,
        },
    )
}
/// Produces heuristic capability and performance figures for a Metal device.
///
/// Only fp64 support is copied from `info`; fp16 and unified memory are always
/// reported as supported and the other capabilities are fixed values. Peak fp32
/// throughput is modelled as `compute_units * 32 * clock_frequency * 2 / 1000`
/// (presumably `clock_frequency` is in MHz so the result is GFLOPS — TODO confirm).
#[allow(dead_code)]
fn estimate_metal_specs(info: &GpuDeviceInfo) -> (DeviceCapabilities, DevicePerformance) {
    let lanes = info.compute_units * 32;
    let fp32_gflops = lanes as f64 * info.clock_frequency as f64 * 2.0 / 1000.0;
    // Bandwidth is scaled from memory size: 15 units per 1e9 of `total_memory`.
    let memory_scaled = info.total_memory as f64 / 1e9;

    (
        DeviceCapabilities {
            supports_fp64: info.supports_fp64,
            supports_fp16: true,
            supports_unified_memory: true,
            supports_p2p: false,
            max_threads_per_block: 1024,
            max_shared_memory: 32 * 1024,
            warpsize: 32,
        },
        DevicePerformance {
            memory_bandwidth: memory_scaled * 15.0,
            peak_gflops_fp32: fp32_gflops,
            // fp64 is modelled at half the fp32 rate.
            peak_gflops_fp64: fp32_gflops * 0.5,
            memory_latency_ns: 200.0,
            cachesize: 128 * 1024,
        },
    )
}
/// Produces heuristic capability and performance figures for a oneAPI device.
///
/// Only fp64 support is copied from `info`; fp16 and unified memory are always
/// reported as supported and the other capabilities are fixed values. Peak fp32
/// throughput is modelled as `compute_units * 16 * clock_frequency * 2 / 1000`
/// (presumably `clock_frequency` is in MHz so the result is GFLOPS — TODO confirm).
#[allow(dead_code)]
fn estimate_oneapi_specs(info: &GpuDeviceInfo) -> (DeviceCapabilities, DevicePerformance) {
    let lanes = info.compute_units * 16;
    let fp32_gflops = lanes as f64 * info.clock_frequency as f64 * 2.0 / 1000.0;
    // Bandwidth is scaled from memory size: 8 units per 1e9 of `total_memory`.
    let memory_scaled = info.total_memory as f64 / 1e9;

    (
        DeviceCapabilities {
            supports_fp64: info.supports_fp64,
            supports_fp16: true,
            supports_unified_memory: true,
            supports_p2p: false,
            max_threads_per_block: 256,
            max_shared_memory: 64 * 1024,
            warpsize: 16,
        },
        DevicePerformance {
            memory_bandwidth: memory_scaled * 8.0,
            peak_gflops_fp32: fp32_gflops,
            // fp64 is modelled at half the fp32 rate.
            peak_gflops_fp64: fp32_gflops * 0.5,
            memory_latency_ns: 300.0,
            cachesize: 32 * 1024,
        },
    )
}
/// Produces heuristic capability and performance figures for a WebGPU device.
///
/// All optional features (fp64, fp16, unified memory, peer-to-peer) are
/// reported as unsupported regardless of `info`, and fp64 throughput is zero.
/// Peak fp32 throughput is modelled as
/// `compute_units * 8 * clock_frequency * 1 / 1000`
/// (presumably `clock_frequency` is in MHz so the result is GFLOPS — TODO confirm).
#[allow(dead_code)]
fn estimate_webgpu_specs(info: &GpuDeviceInfo) -> (DeviceCapabilities, DevicePerformance) {
    let lanes = info.compute_units * 8;
    // The per-clock factor is 1.0 here, unlike the 2.0 used by the other backends.
    let fp32_gflops = lanes as f64 * info.clock_frequency as f64 * 1.0 / 1000.0;
    // Bandwidth is scaled from memory size: 2 units per 1e9 of `total_memory`.
    let memory_scaled = info.total_memory as f64 / 1e9;

    (
        DeviceCapabilities {
            supports_fp64: false,
            supports_fp16: false,
            supports_unified_memory: false,
            supports_p2p: false,
            max_threads_per_block: 256,
            max_shared_memory: 16 * 1024,
            warpsize: 32,
        },
        DevicePerformance {
            memory_bandwidth: memory_scaled * 2.0,
            peak_gflops_fp32: fp32_gflops,
            peak_gflops_fp64: 0.0,
            memory_latency_ns: 1000.0,
            cachesize: 8 * 1024,
        },
    )
}
#[allow(dead_code)]
pub fn benchmark_device_performance(
device_info: &GpuDeviceInfo,
) -> LinalgResult<DevicePerformance> {
let extended_info = ExtendedDeviceInfo::from_basic(device_info.clone());
Ok(extended_info.performance)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Fixture memory sizes are chosen per pointer width so the constants fit
    // in `usize` on 32-bit targets.
    #[cfg(target_pointer_width = "32")]
    const LARGE_DEVICE_MEMORY: usize = 512 * 1024 * 1024;
    #[cfg(target_pointer_width = "64")]
    const LARGE_DEVICE_MEMORY: usize = 8 * 1024 * 1024 * 1024;

    #[cfg(target_pointer_width = "32")]
    const SMALL_DEVICE_MEMORY: usize = 256 * 1024 * 1024;
    #[cfg(target_pointer_width = "64")]
    const SMALL_DEVICE_MEMORY: usize = 1024 * 1024 * 1024;

    /// An 80-unit, 1500-clock CUDA device with fp64 support.
    fn large_cuda_device(vendor: &str) -> GpuDeviceInfo {
        GpuDeviceInfo {
            device_type: GpuDeviceType::Cuda,
            name: "Test GPU".to_string(),
            total_memory: LARGE_DEVICE_MEMORY,
            compute_units: 80,
            clock_frequency: 1500,
            supports_fp64: true,
            supports_fp16: true,
            max_work_groupsize: 1024,
            memory_bandwidth: 900.0,
            l2_cachesize: 6 * 1024 * 1024,
            shared_memory_per_block: 48 * 1024,
            registers_per_block: 65536,
            warpsize: 32,
            max_threads_per_mp: 2048,
            multiprocessor_count: 80,
            supports_tensor_cores: true,
            supports_mixed_precision: true,
            vendor: vendor.to_string(),
        }
    }

    /// A 10-unit, 1000-clock CUDA device without fp64 support.
    fn small_cuda_device() -> GpuDeviceInfo {
        GpuDeviceInfo {
            device_type: GpuDeviceType::Cuda,
            name: "Test GPU".to_string(),
            total_memory: SMALL_DEVICE_MEMORY,
            compute_units: 10,
            clock_frequency: 1000,
            supports_fp64: false,
            supports_fp16: true,
            max_work_groupsize: 1024,
            memory_bandwidth: 400.0,
            l2_cachesize: 1024 * 1024,
            shared_memory_per_block: 32 * 1024,
            registers_per_block: 32768,
            warpsize: 32,
            max_threads_per_mp: 1536,
            multiprocessor_count: 10,
            supports_tensor_cores: false,
            supports_mixed_precision: false,
            vendor: "Test".to_string(),
        }
    }

    #[test]
    fn test_extended_device_info_creation() {
        let extended_info = ExtendedDeviceInfo::from_basic(large_cuda_device("NVIDIA"));

        assert_eq!(extended_info.basic_info.device_type, GpuDeviceType::Cuda);
        assert_eq!(extended_info.capabilities.warpsize, 32);
        assert!(extended_info.performance.peak_gflops_fp32 > 0.0);
    }

    #[test]
    fn test_workload_suitability() {
        let extended_info = ExtendedDeviceInfo::from_basic(small_cuda_device());

        // Small fp32 workload fits comfortably.
        assert!(extended_info.is_suitable_for_workload(1000, false));
        // 100M elements * 8 bytes exceeds half of the device memory.
        assert!(!extended_info.is_suitable_for_workload(100_000_000, false));
        // The device has no fp64 support, so fp64 workloads are rejected.
        assert!(!extended_info.is_suitable_for_workload(1000, true));
    }

    #[test]
    fn test_performance_estimation() {
        let extended_info = ExtendedDeviceInfo::from_basic(large_cuda_device("Test"));

        let time_estimate = extended_info.estimate_matmul_performance(1000, 1000, 1000);
        assert!(time_estimate > 0.0);
        assert!(time_estimate < 1.0);
    }
}