use crate::error::{ClusteringError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Compute backend that drives a [`GpuDevice`].
///
/// `Display` renders a human-readable name (e.g. `"CUDA"`, `"Apple Metal"`), and
/// `Default` yields [`GpuBackend::CpuFallback`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum GpuBackend {
/// CUDA backend (displayed as `"CUDA"`).
Cuda,
/// OpenCL backend (displayed as `"OpenCL"`).
OpenCl,
/// ROCm backend (displayed as `"ROCm"`).
Rocm,
/// oneAPI backend (displayed as `"Intel OneAPI"`).
OneApi,
/// Metal backend (displayed as `"Apple Metal"`).
Metal,
/// No GPU backend; `GpuContext` never reports a device with this backend as
/// GPU-accelerated.
CpuFallback,
}
impl Default for GpuBackend {
fn default() -> Self {
GpuBackend::CpuFallback
}
}
impl std::fmt::Display for GpuBackend {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
GpuBackend::Cuda => write!(f, "CUDA"),
GpuBackend::OpenCl => write!(f, "OpenCL"),
GpuBackend::Rocm => write!(f, "ROCm"),
GpuBackend::OneApi => write!(f, "Intel OneAPI"),
GpuBackend::Metal => write!(f, "Apple Metal"),
GpuBackend::CpuFallback => write!(f, "CPU Fallback"),
}
}
}
/// Snapshot description of one compute device and its capabilities.
///
/// NOTE(review): the memory fields are presumably in bytes (`get_device_score`
/// divides `available_memory` by 1e9 and the tests use values like 8_000_000_000) —
/// confirm against the code that enumerates devices.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuDevice {
/// Numeric identifier of the device.
pub device_id: u32,
/// Human-readable device name.
pub name: String,
/// Total device memory.
pub total_memory: usize,
/// Currently free device memory; `memory_utilization` compares it to `total_memory`.
pub available_memory: usize,
/// Compute capability as a free-form string (the tests use `"7.5"`).
pub compute_capability: String,
/// Number of compute units; added directly into `get_device_score`.
pub compute_units: u32,
/// Backend that drives this device.
pub backend: GpuBackend,
/// Whether the device supports double-precision arithmetic.
pub supports_double_precision: bool,
}
impl GpuDevice {
pub fn new(
device_id: u32,
name: String,
total_memory: usize,
available_memory: usize,
compute_capability: String,
compute_units: u32,
backend: GpuBackend,
supports_double_precision: bool,
) -> Self {
Self {
device_id,
name,
total_memory,
available_memory,
compute_capability,
compute_units,
backend,
supports_double_precision,
}
}
pub fn memory_utilization(&self) -> f64 {
if self.total_memory == 0 {
0.0
} else {
100.0 * (1.0 - (self.available_memory as f64 / self.total_memory as f64))
}
}
pub fn is_suitable_for_double_precision(&self) -> bool {
self.supports_double_precision
}
pub fn get_device_score(&self) -> f64 {
let memory_score = self.available_memory as f64 / 1_000_000_000.0; let compute_score = self.compute_units as f64;
let precision_bonus = if self.supports_double_precision {
1.5
} else {
1.0
};
(memory_score + compute_score) * precision_bonus
}
}
/// Strategy for choosing which device to use.
///
/// NOTE(review): the code that interprets these variants is not in this file; the
/// per-variant notes below are inferred from the names — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DeviceSelection {
/// Presumably the first enumerated device.
First,
/// Presumably the device with the most memory.
MostMemory,
/// Presumably the device with the most compute capability/units.
HighestCompute,
/// The device with the given ID; `GpuConfig::validate` rejects IDs above 64.
Specific(u32),
/// Automatic choice; this is the `Default` variant.
Auto,
/// Presumably the device expected to be fastest (possibly ranked by
/// `GpuDevice::get_device_score` — confirm).
Fastest,
}
impl Default for DeviceSelection {
fn default() -> Self {
DeviceSelection::Auto
}
}
/// User-facing GPU configuration, validated by `GpuConfig::validate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuConfig {
/// Backend the caller would like to use; `GpuContext` uses a device whose backend
/// matches this, or any non-CPU device when `auto_fallback` is set.
pub preferred_backend: GpuBackend,
/// Strategy for picking a device; `Specific` IDs above 64 fail validation.
pub device_selection: DeviceSelection,
/// When true, a device whose backend differs from `preferred_backend` is still used.
pub auto_fallback: bool,
/// Optional memory pool size; validation rejects values below 1024 * 1024.
/// Units presumably bytes (the error message says "1MB") — confirm.
pub memory_pool_size: Option<usize>,
/// Memory-optimization flag. NOTE(review): not read anywhere in this file.
pub optimize_memory: bool,
/// Free-form backend-specific key/value options.
/// NOTE(review): not read anywhere in this file.
pub backend_options: HashMap<String, String>,
}
impl Default for GpuConfig {
    /// Builds the baseline configuration with these values:
    ///
    /// - CPU-fallback backend.
    /// - Automatic device selection.
    /// - Fallback allowed.
    /// - No explicit memory pool.
    /// - Memory optimization on.
    /// - No backend options.
    fn default() -> Self {
        let backend_options = HashMap::default();
        Self {
            backend_options,
            memory_pool_size: None,
            auto_fallback: true,
            optimize_memory: true,
            device_selection: DeviceSelection::Auto,
            preferred_backend: GpuBackend::CpuFallback,
        }
    }
}
impl GpuConfig {
    /// Largest device ID accepted for [`DeviceSelection::Specific`] by [`GpuConfig::validate`].
    const MAX_DEVICE_ID: u32 = 64;
    /// Smallest memory pool size accepted by [`GpuConfig::validate`] (1024 * 1024).
    const MIN_MEMORY_POOL_SIZE: usize = 1024 * 1024;

    /// Creates a configuration preferring `backend`; every other field takes its
    /// `Default` value.
    #[must_use]
    pub fn new(backend: GpuBackend) -> Self {
        Self {
            preferred_backend: backend,
            ..Default::default()
        }
    }

    /// Sets the device-selection strategy (builder style).
    #[must_use]
    pub fn with_device_selection(mut self, strategy: DeviceSelection) -> Self {
        self.device_selection = strategy;
        self
    }

    /// Sets the memory pool size (builder style).
    ///
    /// The size is not checked here; [`GpuConfig::validate`] enforces the minimum.
    #[must_use]
    pub fn with_memory_pool_size(mut self, size: usize) -> Self {
        self.memory_pool_size = Some(size);
        self
    }

    /// Disables `auto_fallback` (builder style).
    #[must_use]
    pub fn without_fallback(mut self) -> Self {
        self.auto_fallback = false;
        self
    }

    /// Inserts a backend-specific option, replacing any previous value for `key`.
    #[must_use]
    pub fn with_backend_option(mut self, key: String, value: String) -> Self {
        self.backend_options.insert(key, value);
        self
    }

    /// Shortcut for `GpuConfig::new(GpuBackend::Cuda)`.
    #[must_use]
    pub fn cuda() -> Self {
        Self::new(GpuBackend::Cuda)
    }

    /// Shortcut for `GpuConfig::new(GpuBackend::OpenCl)`.
    #[must_use]
    pub fn opencl() -> Self {
        Self::new(GpuBackend::OpenCl)
    }

    /// Shortcut for `GpuConfig::new(GpuBackend::Rocm)`.
    #[must_use]
    pub fn rocm() -> Self {
        Self::new(GpuBackend::Rocm)
    }

    /// Shortcut for `GpuConfig::new(GpuBackend::OneApi)`.
    ///
    /// Completes the set so every GPU backend has a shortcut constructor.
    #[must_use]
    pub fn oneapi() -> Self {
        Self::new(GpuBackend::OneApi)
    }

    /// Shortcut for `GpuConfig::new(GpuBackend::Metal)`.
    #[must_use]
    pub fn metal() -> Self {
        Self::new(GpuBackend::Metal)
    }

    /// Checks the configuration for out-of-range values.
    ///
    /// # Errors
    ///
    /// Returns [`ClusteringError::InvalidInput`] in either of these cases:
    ///
    /// - `device_selection` is `Specific(id)` with `id > 64`.
    /// - `memory_pool_size` is set below 1024 * 1024.
    pub fn validate(&self) -> Result<()> {
        if let DeviceSelection::Specific(id) = self.device_selection {
            if id > Self::MAX_DEVICE_ID {
                return Err(ClusteringError::InvalidInput(
                    "Device ID too high".to_string(),
                ));
            }
        }
        if let Some(pool_size) = self.memory_pool_size {
            if pool_size < Self::MIN_MEMORY_POOL_SIZE {
                return Err(ClusteringError::InvalidInput(
                    "Memory pool size too small (minimum 1MB)".to_string(),
                ));
            }
        }
        Ok(())
    }
}
/// A device paired with the configuration it was opened under.
#[derive(Debug)]
pub struct GpuContext {
/// The device this context was built for.
pub device: GpuDevice,
/// The validated configuration.
pub config: GpuConfig,
/// Result of the availability check at construction; when false,
/// `effective_backend` reports `CpuFallback`.
pub gpu_available: bool,
/// Backend-specific context created by `GpuContext::new`.
pub backend_context: BackendContext,
}
impl GpuContext {
    /// Validates `config` and builds a context for `device`.
    ///
    /// The backend context is created for the *effective* backend. When the device
    /// cannot be used under `config`, a [`BackendContext::CpuFallback`] is created
    /// rather than a context for the unused GPU backend. This keeps
    /// `backend_context` in agreement with [`GpuContext::effective_backend`].
    ///
    /// # Errors
    ///
    /// Returns the error from [`GpuConfig::validate`] when `config` is invalid, or
    /// any error from [`BackendContext::new`].
    pub fn new(device: GpuDevice, config: GpuConfig) -> Result<Self> {
        config.validate()?;
        let gpu_available = Self::check_gpu_availability(&device, &config);
        // Don't initialise a GPU context that `effective_backend()` says won't be used.
        let context_backend = if gpu_available {
            device.backend
        } else {
            GpuBackend::CpuFallback
        };
        let backend_context = BackendContext::new(&context_backend)?;
        Ok(Self {
            device,
            config,
            gpu_available,
            backend_context,
        })
    }

    /// Decides whether `device` is used under `config`.
    ///
    /// The rules are applied in this order:
    ///
    /// 1. A CPU-fallback device is never used.
    /// 2. A device on the preferred backend is always used.
    /// 3. A device on any other backend is used only when `config.auto_fallback` is set.
    fn check_gpu_availability(device: &GpuDevice, config: &GpuConfig) -> bool {
        match (device.backend, config.preferred_backend) {
            (GpuBackend::CpuFallback, _) => false,
            (backend1, backend2) if backend1 == backend2 => true,
            _ => config.auto_fallback,
        }
    }

    /// Backend computations should run on: the device's backend when it is
    /// available, otherwise [`GpuBackend::CpuFallback`].
    pub fn effective_backend(&self) -> GpuBackend {
        if self.gpu_available {
            self.device.backend
        } else {
            GpuBackend::CpuFallback
        }
    }

    /// Whether work will actually run on a GPU.
    ///
    /// The explicit backend check guards against `gpu_available` having been set by
    /// hand on a CPU-fallback device (it is a public field).
    pub fn is_gpu_accelerated(&self) -> bool {
        self.gpu_available && self.device.backend != GpuBackend::CpuFallback
    }

    /// Returns `(total_memory, available_memory)` from the stored device description.
    pub fn memory_info(&self) -> (usize, usize) {
        (self.device.total_memory, self.device.available_memory)
    }
}
/// Backend-specific handles held by a `GpuContext`, one variant per `GpuBackend`.
///
/// NOTE(review): `BackendContext::new` fills every handle with 0. No driver object
/// is created in this file, so the handles are placeholders.
#[derive(Debug)]
pub enum BackendContext {
/// CUDA context and stream handles.
Cuda {
context_handle: u64,
stream_handle: u64,
},
/// OpenCL context and command-queue handles.
OpenCl {
context_handle: u64,
queue_handle: u64,
},
/// ROCm context handle.
Rocm {
context_handle: u64,
},
/// oneAPI context handle.
OneApi {
context_handle: u64,
},
/// Metal device and queue handles.
Metal {
device_handle: u64,
queue_handle: u64,
},
/// No backend resources.
CpuFallback,
}
impl BackendContext {
    /// Builds the context variant matching `backend`.
    ///
    /// NOTE(review): no driver call is made; every handle is a placeholder `0`.
    ///
    /// # Errors
    ///
    /// Currently never fails. Presumably the `Result` leaves room for real,
    /// fallible initialization — confirm.
    pub fn new(backend: &GpuBackend) -> Result<Self> {
        let context = match *backend {
            GpuBackend::CpuFallback => Self::CpuFallback,
            GpuBackend::Cuda => Self::Cuda {
                context_handle: 0,
                stream_handle: 0,
            },
            GpuBackend::OpenCl => Self::OpenCl {
                context_handle: 0,
                queue_handle: 0,
            },
            GpuBackend::Rocm => Self::Rocm { context_handle: 0 },
            GpuBackend::OneApi => Self::OneApi { context_handle: 0 },
            GpuBackend::Metal => Self::Metal {
                device_handle: 0,
                queue_handle: 0,
            },
        };
        Ok(context)
    }

    /// Reports whether the context is usable.
    ///
    /// NOTE(review): this is a placeholder that reports every variant as valid,
    /// including GPU variants whose handles are 0. The match lists every variant
    /// so that adding one forces this decision to be revisited.
    pub fn is_valid(&self) -> bool {
        match self {
            Self::Cuda { .. }
            | Self::OpenCl { .. }
            | Self::Rocm { .. }
            | Self::OneApi { .. }
            | Self::Metal { .. }
            | Self::CpuFallback => true,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The 8e9-total / 6e9-free CUDA device shared by several tests.
    fn sample_cuda_device() -> GpuDevice {
        GpuDevice::new(
            0,
            "Test GPU".to_string(),
            8_000_000_000,
            6_000_000_000,
            "7.5".to_string(),
            2048,
            GpuBackend::Cuda,
            true,
        )
    }

    #[test]
    fn test_gpu_device_creation() {
        let device = sample_cuda_device();
        assert_eq!(device.device_id, 0);
        assert_eq!(device.name, "Test GPU");
        // 6 of 8 free -> 25 % used; compare with a tolerance rather than exact float equality.
        assert!((device.memory_utilization() - 25.0).abs() < 1e-9);
        assert!(device.is_suitable_for_double_precision());
    }

    #[test]
    fn test_memory_utilization_zero_total() {
        let mut device = sample_cuda_device();
        device.total_memory = 0;
        assert_eq!(device.memory_utilization(), 0.0);
    }

    #[test]
    fn test_gpu_config_validation() {
        assert!(GpuConfig::default().validate().is_ok());
        // 1 KiB is below the 1024 * 1024 minimum; exactly the minimum is accepted.
        assert!(GpuConfig::default()
            .with_memory_pool_size(1024)
            .validate()
            .is_err());
        assert!(GpuConfig::default()
            .with_memory_pool_size(1024 * 1024)
            .validate()
            .is_ok());
        // Device IDs up to and including 64 are accepted.
        assert!(GpuConfig::default()
            .with_device_selection(DeviceSelection::Specific(64))
            .validate()
            .is_ok());
        assert!(GpuConfig::default()
            .with_device_selection(DeviceSelection::Specific(65))
            .validate()
            .is_err());
    }

    #[test]
    fn test_device_selection_strategies() {
        assert_eq!(DeviceSelection::default(), DeviceSelection::Auto);
        // A bare `if let` would silently pass on a mismatch; assert the pattern instead.
        let specific = DeviceSelection::Specific(0);
        assert!(matches!(specific, DeviceSelection::Specific(0)));
    }

    #[test]
    fn test_backend_display_and_default() {
        assert_eq!(GpuBackend::Cuda.to_string(), "CUDA");
        assert_eq!(GpuBackend::OpenCl.to_string(), "OpenCL");
        assert_eq!(GpuBackend::OneApi.to_string(), "Intel OneAPI");
        assert_eq!(GpuBackend::CpuFallback.to_string(), "CPU Fallback");
        assert_eq!(GpuBackend::default(), GpuBackend::CpuFallback);
    }

    #[test]
    fn test_context_fallback_logic() {
        // Default config prefers CpuFallback but allows fallback, so the CUDA device is used.
        let ctx = GpuContext::new(sample_cuda_device(), GpuConfig::default())
            .expect("default config is valid");
        assert!(ctx.is_gpu_accelerated());
        assert_eq!(ctx.effective_backend(), GpuBackend::Cuda);

        // Backend mismatch with fallback disabled -> CPU.
        let ctx = GpuContext::new(sample_cuda_device(), GpuConfig::opencl().without_fallback())
            .expect("opencl config is valid");
        assert!(!ctx.is_gpu_accelerated());
        assert_eq!(ctx.effective_backend(), GpuBackend::CpuFallback);

        // Matching backend with fallback disabled -> GPU.
        let ctx = GpuContext::new(sample_cuda_device(), GpuConfig::cuda().without_fallback())
            .expect("cuda config is valid");
        assert!(ctx.is_gpu_accelerated());
        assert_eq!(ctx.effective_backend(), GpuBackend::Cuda);
    }

    #[test]
    fn test_backend_context_creation() {
        let cuda_context = BackendContext::new(&GpuBackend::Cuda)
            .expect("CUDA context construction is infallible");
        assert!(cuda_context.is_valid());
        let cpu_context = BackendContext::new(&GpuBackend::CpuFallback)
            .expect("CPU fallback context construction is infallible");
        assert!(cpu_context.is_valid());
    }
}