use std::fmt;
/// The kind of compute backend a `Device` runs on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceType {
    /// Host CPU.
    Cpu,
    /// CUDA backend.
    Cuda,
    /// Metal backend.
    Metal,
    /// Vulkan backend.
    Vulkan,
    /// ROCm backend.
    Rocm,
}

impl DeviceType {
    /// Returns `true` for the GPU-style backends: CUDA, Metal, Vulkan and ROCm.
    ///
    /// This is every variant except [`DeviceType::Cpu`].
    pub fn is_gpu(&self) -> bool {
        matches!(
            self,
            DeviceType::Cuda | DeviceType::Metal | DeviceType::Vulkan | DeviceType::Rocm
        )
    }

    /// Returns `true` only for [`DeviceType::Cpu`].
    pub fn is_cpu(&self) -> bool {
        matches!(self, DeviceType::Cpu)
    }

    /// Canonical human-readable name of the backend, e.g. `"CPU"`, `"CUDA"`, `"ROCm"`.
    pub fn name(&self) -> &'static str {
        match self {
            DeviceType::Cpu => "CPU",
            DeviceType::Cuda => "CUDA",
            DeviceType::Metal => "Metal",
            DeviceType::Vulkan => "Vulkan",
            DeviceType::Rocm => "ROCm",
        }
    }
}

impl fmt::Display for DeviceType {
    /// Writes [`DeviceType::name`].
    ///
    /// `Formatter::pad` is used instead of `write!` so that width, fill, alignment and
    /// precision flags are honored. For example, `format!("{:>6}", DeviceType::Cpu)`
    /// yields `"   CPU"`. `write!(f, "{}", ..)` would start a fresh formatter and
    /// silently drop those flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.name())
    }
}
/// A concrete compute device: a backend kind plus an ordinal within that backend.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Device {
    /// Which backend this device belongs to.
    pub device_type: DeviceType,
    /// Ordinal of the device within its backend.
    ///
    /// The `cpu()` and `metal()` constructors always use 0.
    pub index: usize,
}

impl Device {
    /// Common constructor shared by the public backend-specific helpers.
    fn from_parts(device_type: DeviceType, index: usize) -> Self {
        Device { device_type, index }
    }

    /// The host CPU, always at index 0.
    pub fn cpu() -> Self {
        Self::from_parts(DeviceType::Cpu, 0)
    }

    /// The CUDA device with ordinal `index`.
    pub fn cuda(index: usize) -> Self {
        Self::from_parts(DeviceType::Cuda, index)
    }

    /// The Metal device. This constructor always uses index 0.
    pub fn metal() -> Self {
        Self::from_parts(DeviceType::Metal, 0)
    }

    /// The Vulkan device with ordinal `index`.
    pub fn vulkan(index: usize) -> Self {
        Self::from_parts(DeviceType::Vulkan, index)
    }

    /// The ROCm device with ordinal `index`.
    pub fn rocm(index: usize) -> Self {
        Self::from_parts(DeviceType::Rocm, index)
    }

    /// Whether this device's backend is the CPU.
    pub fn is_cpu(&self) -> bool {
        self.device_type().is_cpu()
    }

    /// Whether this device's backend is a GPU backend.
    pub fn is_gpu(&self) -> bool {
        self.device_type().is_gpu()
    }

    /// The backend kind of this device.
    pub fn device_type(&self) -> DeviceType {
        self.device_type
    }

    /// The ordinal of this device within its backend.
    pub fn index(&self) -> usize {
        self.index
    }
}

impl Default for Device {
    /// The default device is the CPU ([`Device::cpu`]).
    fn default() -> Self {
        Device::cpu()
    }
}

impl fmt::Display for Device {
    /// Renders how this device appears to users.
    ///
    /// - The index-0 CPU renders as just its device type.
    /// - Every other device renders as `<type>:<index>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match (self.device_type, self.index) {
            (DeviceType::Cpu, 0) => write!(f, "{}", self.device_type),
            (kind, ordinal) => write!(f, "{}:{}", kind, ordinal),
        }
    }
}
/// Registry of the devices visible to this process, plus the currently selected default.
#[derive(Debug, Clone)]
pub struct SystemDeviceManager {
    /// Devices enumerated by [`SystemDeviceManager::new`]. The CPU is always first.
    available_devices: Vec<Device>,
    /// Current default device.
    ///
    /// `new` and `set_default_device` only store a member of `available_devices` here.
    default_device: Device,
}

impl SystemDeviceManager {
    /// Builds a manager by enumerating the available devices.
    ///
    /// The CPU is always listed first and becomes the initial default.
    /// - Non-test builds append one CUDA device per entry returned by
    ///   `crate::cuda_detect::detect_cuda_devices()`.
    /// - Test builds skip detection and report only the CPU.
    pub fn new() -> Self {
        let available_devices = Self::detect_devices();
        // Read the default out of the list before moving the list into `Self`. This avoids
        // cloning the whole Vec. The list always starts with the CPU, so the CPU fallback
        // (`Device::default()`) is never hit in practice. It only removes a panic path.
        let default_device = available_devices.first().cloned().unwrap_or_default();
        Self {
            available_devices,
            default_device,
        }
    }

    /// Device enumeration for test builds: only the CPU, with no hardware probing.
    #[cfg(test)]
    fn detect_devices() -> Vec<Device> {
        vec![Device::cpu()]
    }

    /// Device enumeration for regular builds.
    ///
    /// Returns the CPU followed by one CUDA device per entry reported by `crate::cuda_detect`.
    #[cfg(not(test))]
    fn detect_devices() -> Vec<Device> {
        let mut devices = vec![Device::cpu()];
        devices.extend(
            crate::cuda_detect::detect_cuda_devices()
                .into_iter()
                .map(|cuda_info| Device::cuda(cuda_info.index)),
        );
        devices
    }

    /// All devices discovered when this manager was created.
    pub fn available_devices(&self) -> &[Device] {
        &self.available_devices
    }

    /// The currently selected default device.
    pub fn default_device(&self) -> &Device {
        &self.default_device
    }

    /// Makes `device` the default.
    ///
    /// # Errors
    ///
    /// Returns [`DeviceError::DeviceNotAvailable`] (carrying `device`) if `device` is not one
    /// of [`available_devices`](Self::available_devices). The current default is then left
    /// unchanged.
    pub fn set_default_device(&mut self, device: Device) -> Result<(), DeviceError> {
        if !self.is_available(&device) {
            return Err(DeviceError::DeviceNotAvailable(device));
        }
        self.default_device = device;
        Ok(())
    }

    /// Whether `device` is one of the available devices.
    pub fn is_available(&self, device: &Device) -> bool {
        self.available_devices.contains(device)
    }

    /// Looks up the available device with the given backend and index.
    ///
    /// Returns `None` if there is none.
    pub fn get_device(&self, device_type: DeviceType, index: usize) -> Option<&Device> {
        self.available_devices
            .iter()
            .find(|d| d.device_type == device_type && d.index == index)
    }

    /// Number of available devices whose backend is `device_type`.
    pub fn count_devices(&self, device_type: DeviceType) -> usize {
        self.available_devices
            .iter()
            .filter(|d| d.device_type == device_type)
            .count()
    }

    /// Whether at least one available device is a GPU.
    pub fn has_gpu(&self) -> bool {
        self.available_devices.iter().any(|d| d.is_gpu())
    }
}

impl Default for SystemDeviceManager {
    /// Equivalent to [`SystemDeviceManager::new`]; note this probes for devices.
    fn default() -> Self {
        Self::new()
    }
}
/// Errors related to device selection and device operations.
///
/// NOTE(review): within this file only `DeviceNotAvailable` is constructed (by
/// `SystemDeviceManager::set_default_device`). The other variants are presumably raised by
/// backend code elsewhere; confirm against callers.
#[derive(Debug, Clone, thiserror::Error)]
pub enum DeviceError {
    /// The requested device is not in the manager's list of available devices.
    #[error("Device not available: {0}")]
    DeviceNotAvailable(Device),
    /// A device memory allocation failed. The `String` payload is appended to the message.
    #[error("Device memory allocation failed: {0}")]
    AllocationFailed(String),
    /// Device synchronization failed. The `String` payload is appended to the message.
    #[error("Device synchronization failed: {0}")]
    SyncFailed(String),
    /// `operation` is not supported on `device`. Both are included in the message.
    #[error("Unsupported operation on device {device}: {operation}")]
    UnsupportedOperation { device: Device, operation: String },
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every backend that must classify itself as a GPU.
    const GPU_KINDS: [DeviceType; 4] = [
        DeviceType::Cuda,
        DeviceType::Metal,
        DeviceType::Vulkan,
        DeviceType::Rocm,
    ];

    #[test]
    fn test_device_type_properties() {
        let host = DeviceType::Cpu;
        assert!(host.is_cpu());
        assert!(!host.is_gpu());
        assert!(!DeviceType::Cuda.is_cpu());
        assert!(GPU_KINDS.iter().all(DeviceType::is_gpu));
    }

    #[test]
    fn test_device_type_display() {
        let expectations = [
            (DeviceType::Cpu, "CPU"),
            (DeviceType::Cuda, "CUDA"),
            (DeviceType::Metal, "Metal"),
        ];
        for &(kind, expected) in expectations.iter() {
            assert_eq!(kind.to_string(), expected);
        }
    }

    #[test]
    fn test_device_creation() {
        let host = Device::cpu();
        assert!(host.is_cpu());
        assert_eq!(host.index(), 0);

        let gpu = Device::cuda(1);
        assert!(gpu.is_gpu());
        assert_eq!((gpu.device_type(), gpu.index()), (DeviceType::Cuda, 1));
    }

    #[test]
    fn test_device_default() {
        let fallback = Device::default();
        assert!(fallback.is_cpu());
        assert_eq!(fallback.index(), 0);
    }

    #[test]
    fn test_device_display() {
        let cases = [
            (Device::cpu(), "CPU"),
            (Device::cuda(0), "CUDA:0"),
            (Device::cuda(1), "CUDA:1"),
            (Device::metal(), "Metal:0"),
        ];
        for (device, rendered) in cases.iter() {
            assert_eq!(device.to_string(), *rendered);
        }
    }

    #[test]
    fn test_device_manager_creation() {
        let registry = SystemDeviceManager::new();
        assert!(!registry.available_devices().is_empty());
        assert!(registry.default_device().is_cpu());
    }

    #[test]
    fn test_device_manager_queries() {
        let registry = SystemDeviceManager::new();
        let host = Device::cpu();
        assert!(registry.is_available(&host));
        assert_eq!(registry.count_devices(DeviceType::Cpu), 1);
        assert_eq!(registry.default_device(), &host);
    }

    #[test]
    fn test_device_manager_set_default() {
        let mut registry = SystemDeviceManager::new();
        let host = Device::cpu();
        assert!(registry.set_default_device(host.clone()).is_ok());
        assert_eq!(registry.default_device(), &host);
        assert!(registry.set_default_device(Device::cuda(99)).is_err());
    }

    #[test]
    fn test_device_manager_get_device() {
        let registry = SystemDeviceManager::new();
        let host = registry
            .get_device(DeviceType::Cpu, 0)
            .expect("cpu device expected");
        assert_eq!(host, &Device::cpu());
        assert!(registry.get_device(DeviceType::Cuda, 0).is_none());
    }

    #[test]
    fn test_device_error_display() {
        let missing = DeviceError::DeviceNotAvailable(Device::cuda(0)).to_string();
        assert!(missing.contains("not available"));
        let oom = DeviceError::AllocationFailed("out of memory".to_string()).to_string();
        assert!(oom.contains("allocation failed"));
    }
}