use crate::runtime::Device;
/// Lightweight handle naming a CUDA device by its driver ordinal.
///
/// The handle holds no driver resources; every query goes to the driver using
/// `index` as the device ordinal. Equality and hashing compare the ordinal, so
/// two handles for the same device are interchangeable (e.g. as map keys).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct CudaDevice {
    /// Zero-based device ordinal passed to the CUDA driver.
    pub(crate) index: usize,
}
impl CudaDevice {
    /// Creates a handle for the CUDA device with ordinal `index`.
    ///
    /// No driver call is made here. An out-of-range ordinal is only reported
    /// by the methods that query the driver.
    pub const fn new(index: usize) -> Self {
        Self { index }
    }

    /// Returns this device's compute capability as `(major, minor)`.
    ///
    /// # Errors
    ///
    /// Returns [`CudaError::DeviceError`] if:
    /// - `index` does not fit the driver's `i32` ordinal,
    /// - the device handle cannot be obtained, or
    /// - either attribute query fails.
    // NOTE(review): the CUDA driver API must already be initialized (`cuInit`)
    // for the device lookup to succeed; presumably that happens elsewhere — confirm.
    pub fn compute_capability(&self) -> Result<(u32, u32), CudaError> {
        use cudarc::driver::result::device;
        use cudarc::driver::sys::CUdevice_attribute;

        // An `as i32` cast would silently wrap large indices onto a different
        // (or negative) ordinal and query the wrong device; refuse them instead.
        let ordinal = i32::try_from(self.index).map_err(|_| {
            CudaError::DeviceError(format!(
                "Failed to get CUDA device {}: index does not fit in an i32 ordinal",
                self.index
            ))
        })?;
        let handle = device::get(ordinal).map_err(|e| {
            CudaError::DeviceError(format!("Failed to get CUDA device {}: {:?}", self.index, e))
        })?;

        // Reads one integer attribute of `handle`; `which` names it in the error message.
        let query = |attribute: CUdevice_attribute, which: &str| {
            // SAFETY: `handle` was just returned by the driver for this ordinal and
            // `attribute` is a valid `CUdevice_attribute` variant.
            let raw = unsafe { device::get_attribute(handle, attribute) };
            raw.map(|value| value as u32).map_err(|e| {
                CudaError::DeviceError(format!(
                    "Failed to get compute capability {}: {:?}",
                    which, e
                ))
            })
        };

        let major = query(
            CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
            "major",
        )?;
        let minor = query(
            CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
            "minor",
        )?;
        Ok((major, minor))
    }

    /// Blocks until all previously issued work in the calling thread's current
    /// CUDA context has finished.
    ///
    /// # Errors
    ///
    /// Returns [`CudaError::SyncError`] if the driver reports a failure.
    // NOTE(review): `ctx::synchronize` acts on whatever context is current on the
    // calling thread, not on a context derived from `self.index`. Callers
    // presumably make this device's context current first — confirm; on a
    // multi-GPU host this could otherwise synchronize a different device.
    pub fn sync(&self) -> Result<(), CudaError> {
        cudarc::driver::result::ctx::synchronize().map_err(|e| {
            CudaError::SyncError(format!(
                "Failed to synchronize CUDA context for device {}: {:?}",
                self.index, e
            ))
        })
    }

    /// Returns `(free, total)` device memory in bytes.
    ///
    /// # Errors
    ///
    /// Returns [`CudaError::DeviceError`] if the driver query fails.
    // NOTE(review): `mem_get_info` reports for the calling thread's current
    // context, not necessarily the device at `self.index` — same caveat as `sync`.
    pub fn memory_info(&self) -> Result<(u64, u64), CudaError> {
        let (free, total) = cudarc::driver::result::mem_get_info().map_err(|e| {
            CudaError::DeviceError(format!(
                "Failed to get memory info for device {}: {:?}",
                self.index, e
            ))
        })?;
        // `usize` -> `u64` is lossless on every target Rust supports (usize <= 64 bits).
        Ok((free as u64, total as u64))
    }

    /// Returns the free device memory in bytes.
    ///
    /// This is the `free` half of [`CudaDevice::memory_info`] and fails under
    /// the same conditions.
    pub fn available_memory(&self) -> Result<u64, CudaError> {
        let (free, _) = self.memory_info()?;
        Ok(free)
    }

    /// Returns the total device memory in bytes.
    ///
    /// This is the `total` half of [`CudaDevice::memory_info`] and fails under
    /// the same conditions.
    pub fn total_memory(&self) -> Result<u64, CudaError> {
        let (_, total) = self.memory_info()?;
        Ok(total)
    }
}
impl Device for CudaDevice {
    /// The identifier of a CUDA device is simply its driver ordinal.
    fn id(&self) -> usize {
        let Self { index } = self;
        *index
    }

    /// Human-readable device label of the form `cuda:<ordinal>`, e.g. `cuda:0`.
    fn name(&self) -> String {
        format!("cuda:{ordinal}", ordinal = self.index)
    }
}
impl Default for CudaDevice {
fn default() -> Self {
Self::new(0)
}
}
/// Failures reported by the CUDA backend.
///
/// Every variant carries a free-form detail message. `Display` prefixes that
/// message with the failure category (e.g. `CUDA sync error: …`).
#[derive(Clone, Debug)]
pub enum CudaError {
    /// Device lookup, attribute, or memory-info query failed.
    DeviceError(String),
    /// Displayed as `CUDA allocation error: …`.
    AllocationError(String),
    /// Displayed as `CUDA copy error: …`.
    CopyError(String),
    /// Displayed as `CUDA kernel error: …`.
    KernelError(String),
    /// Synchronizing the current CUDA context failed.
    SyncError(String),
    /// Displayed as `cuBLAS error: …`.
    CublasError(String),
    /// Displayed as `CUDA context error: …`.
    ContextError(String),
}

impl std::fmt::Display for CudaError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the category label first, then emit `<label>: <detail>` once.
        let (category, detail) = match self {
            Self::DeviceError(detail) => ("CUDA device error", detail),
            Self::AllocationError(detail) => ("CUDA allocation error", detail),
            Self::CopyError(detail) => ("CUDA copy error", detail),
            Self::KernelError(detail) => ("CUDA kernel error", detail),
            Self::SyncError(detail) => ("CUDA sync error", detail),
            Self::CublasError(detail) => ("cuBLAS error", detail),
            Self::ContextError(detail) => ("CUDA context error", detail),
        };
        write!(f, "{}: {}", category, detail)
    }
}

impl std::error::Error for CudaError {}