#![cfg(feature = "cuda")]
use std::sync::Arc;
use cudarc::driver::CudaContext;
use crate::GpuError;
/// Handle to a CUDA device, backed by a reference-counted [`CudaContext`].
///
/// Deriving `Clone` clones the inner `Arc`, so every clone refers to the same
/// underlying context rather than creating a new one.
#[derive(Clone)]
pub struct CudaDevice {
// Shared context handle; crate-visible so sibling modules can use it directly.
pub(crate) ctx: Arc<CudaContext>,
}
impl CudaDevice {
    /// Creates a device handle for ordinal `0`.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::CudaError`] if the context cannot be created.
    pub fn new() -> Result<Self, GpuError> {
        Self::with_ordinal(0)
    }

    /// Creates a device handle for the device with the given `ordinal`.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::CudaError`] carrying the driver's error message if
    /// `CudaContext::new` fails.
    pub fn with_ordinal(ordinal: usize) -> Result<Self, GpuError> {
        let ctx = CudaContext::new(ordinal).map_err(map_cuda_err)?;
        Ok(Self { ctx })
    }

    /// Borrows the underlying shared [`CudaContext`].
    pub fn cuda_context(&self) -> &Arc<CudaContext> {
        &self.ctx
    }

    /// Returns the device name reported by the context.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::CudaError`] if the query fails.
    pub fn name(&self) -> Result<String, GpuError> {
        self.ctx.name().map_err(map_cuda_err)
    }

    /// Returns the device's compute capability as a pair of integers
    /// (presumably `(major, minor)` — confirm against the cudarc docs).
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::CudaError`] if the query fails.
    pub fn compute_capability(&self) -> Result<(i32, i32), GpuError> {
        self.ctx.compute_capability().map_err(map_cuda_err)
    }

    /// Returns the total device memory reported by the context
    /// (in bytes, going by this method's name).
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::CudaError`] if the query fails.
    pub fn total_memory_bytes(&self) -> Result<usize, GpuError> {
        self.ctx.total_mem().map_err(map_cuda_err)
    }

    /// Returns the device's 16-byte UUID.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::CudaError`] if the query fails.
    pub fn uuid(&self) -> Result<[u8; 16], GpuError> {
        let cu_uuid = self.ctx.uuid().map_err(map_cuda_err)?;
        // The driver exposes the UUID as a 16-element array of C chars, whose
        // signedness is platform-dependent. An element-wise `as u8` cast keeps
        // every bit pattern intact on either kind of platform and needs no
        // `unsafe` (previously this was a `transmute`).
        Ok(cu_uuid.bytes.map(|b| b as u8))
    }
}
/// Converts any displayable error into a [`GpuError::CudaError`].
///
/// The error is rendered through its `Display` implementation and the
/// resulting text becomes the payload of the variant.
fn map_cuda_err<E: std::fmt::Display>(e: E) -> GpuError {
    let message = e.to_string();
    GpuError::CudaError(message)
}