use std::cell::Cell;
use std::ffi::c_int;
use oxicuda_driver::loader::try_driver;
use crate::error::{CudaRtError, CudaRtResult};
thread_local! {
    // Per-thread "current device" ordinal, mirroring the CUDA runtime's
    // thread-local current-device semantics. `None` until `set_device`
    // succeeds on this thread; cleared again by `device_reset`.
    static CURRENT_DEVICE: Cell<Option<c_int>> = const { Cell::new(None) };
}
/// Properties of a single CUDA device, modeled after the CUDA runtime's
/// `cudaDeviceProp` struct. Each field is populated from a driver attribute
/// query (or direct driver call) in [`CudaDeviceProp::from_device`]; units
/// follow the driver's reporting conventions.
#[derive(Debug, Clone)]
pub struct CudaDeviceProp {
    /// Device name string reported by the driver.
    pub name: String,
    /// Total global device memory, in bytes.
    pub total_global_mem: usize,
    /// Shared memory available per block (MaxSharedMemoryPerBlock attribute).
    pub shared_mem_per_block: usize,
    /// 32-bit registers available per block.
    pub regs_per_block: u32,
    /// Threads per warp.
    pub warp_size: u32,
    /// Maximum pitch allowed by memory-copy functions (MaxPitch attribute).
    pub mem_pitch: usize,
    /// Maximum number of threads per block.
    pub max_threads_per_block: u32,
    /// Maximum block dimensions [x, y, z].
    pub max_threads_dim: [u32; 3],
    /// Maximum grid dimensions [x, y, z].
    pub max_grid_size: [u32; 3],
    /// Core clock rate as reported by the driver's ClockRate attribute
    /// (CUDA reports kHz — confirm against the driver binding).
    pub clock_rate: u32,
    /// Constant memory available on the device, in bytes.
    pub total_const_mem: usize,
    /// Compute capability major version.
    pub major: u32,
    /// Compute capability minor version.
    pub minor: u32,
    /// Alignment requirement for textures.
    pub texture_alignment: usize,
    /// Pitch alignment requirement for 2D texture references.
    pub texture_pitch_alignment: usize,
    /// Whether the device can concurrently copy memory and execute a kernel
    /// (GpuOverlap attribute).
    pub device_overlap: bool,
    /// Number of streaming multiprocessors.
    pub multi_processor_count: u32,
    /// Whether ECC is enabled.
    pub ecc_enabled: bool,
    /// Whether the device is integrated with the host memory system.
    pub integrated: bool,
    /// Whether the device can map host memory into its address space.
    pub can_map_host_memory: bool,
    /// Whether the device shares a unified address space with the host.
    pub unified_addressing: bool,
    /// Memory clock rate (driver MemoryClockRate attribute).
    pub memory_clock_rate: u32,
    /// Global memory bus width (driver GlobalMemoryBusWidth attribute).
    pub memory_bus_width: u32,
    /// L2 cache size (driver L2CacheSize attribute).
    pub l2_cache_size: u32,
    /// Maximum resident threads per multiprocessor.
    pub max_threads_per_multi_processor: u32,
    /// Whether the device supports stream priorities.
    pub stream_priorities_supported: bool,
    /// Shared memory available per multiprocessor.
    pub shared_mem_per_multiprocessor: usize,
    /// 32-bit registers available per multiprocessor.
    pub regs_per_multiprocessor: u32,
    /// Whether the device supports managed (unified) memory allocation.
    pub managed_memory: bool,
    /// Whether the device sits on a multi-GPU board.
    pub is_multi_gpu_board: bool,
    /// Identifier shared by devices on the same multi-GPU board.
    pub multi_gpu_board_group_id: u32,
    /// Whether host-native atomic operations are supported on the link.
    pub host_native_atomic_supported: bool,
    /// Whether cooperative kernel launch is supported.
    pub cooperative_launch: bool,
    /// Whether multi-device cooperative kernel launch is supported.
    pub cooperative_multi_device_launch: bool,
    /// Maximum resident blocks per multiprocessor.
    pub max_blocks_per_multi_processor: u32,
    /// Maximum opt-in shared memory per block.
    pub shared_mem_per_block_optin: usize,
    /// Whether thread-block cluster launch is supported.
    pub cluster_launch: bool,
}
impl CudaDeviceProp {
    /// Builds a full property record for the device at `ordinal` by issuing
    /// one driver attribute query per field, plus name and total-memory
    /// queries.
    ///
    /// # Errors
    /// - [`CudaRtError::DriverNotAvailable`] if the driver cannot be loaded.
    /// - The mapped driver error — falling back to
    ///   [`CudaRtError::InvalidDevice`] for unknown codes — if any driver
    ///   call returns nonzero.
    pub fn from_device(ordinal: c_int) -> CudaRtResult<Self> {
        let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
        // Fetch a single integer attribute, mapping a nonzero CUresult to a
        // runtime error.
        let attr = |a: oxicuda_driver::ffi::CUdevice_attribute| -> CudaRtResult<u32> {
            let mut v: c_int = 0;
            let rc = unsafe { (api.cu_device_get_attribute)(&raw mut v, a, ordinal) };
            if rc != 0 {
                return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevice));
            }
            Ok(v as u32)
        };
        use oxicuda_driver::ffi::CUdevice_attribute as A;
        // Device name: fixed buffer, NUL-terminated by the driver.
        // Fix: the return codes of cu_device_get_name and
        // cu_device_total_mem_v2 were previously ignored, so a failing call
        // (e.g. bad ordinal) silently produced an empty name and a zero
        // total_global_mem instead of an error.
        let mut name_buf = [0u8; 256];
        let rc = unsafe {
            (api.cu_device_get_name)(
                name_buf.as_mut_ptr() as *mut std::ffi::c_char,
                name_buf.len() as c_int,
                ordinal,
            )
        };
        if rc != 0 {
            return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevice));
        }
        let name = {
            // Truncate at the first NUL; fall back to the whole buffer if the
            // driver filled it without terminating.
            let nul = name_buf
                .iter()
                .position(|&b| b == 0)
                .unwrap_or(name_buf.len());
            String::from_utf8_lossy(&name_buf[..nul]).into_owned()
        };
        let mut total_global_mem: usize = 0;
        let rc = unsafe { (api.cu_device_total_mem_v2)(&raw mut total_global_mem, ordinal) };
        if rc != 0 {
            return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevice));
        }
        Ok(Self {
            name,
            total_global_mem,
            shared_mem_per_block: attr(A::MaxSharedMemoryPerBlock)? as usize,
            regs_per_block: attr(A::MaxRegistersPerBlock)?,
            warp_size: attr(A::WarpSize)?,
            mem_pitch: attr(A::MaxPitch)? as usize,
            max_threads_per_block: attr(A::MaxThreadsPerBlock)?,
            max_threads_dim: [
                attr(A::MaxBlockDimX)?,
                attr(A::MaxBlockDimY)?,
                attr(A::MaxBlockDimZ)?,
            ],
            max_grid_size: [
                attr(A::MaxGridDimX)?,
                attr(A::MaxGridDimY)?,
                attr(A::MaxGridDimZ)?,
            ],
            clock_rate: attr(A::ClockRate)?,
            total_const_mem: attr(A::TotalConstantMemory)? as usize,
            major: attr(A::ComputeCapabilityMajor)?,
            minor: attr(A::ComputeCapabilityMinor)?,
            texture_alignment: attr(A::TextureAlignment)? as usize,
            texture_pitch_alignment: attr(A::TexturePitchAlignment)? as usize,
            device_overlap: attr(A::GpuOverlap)? != 0,
            multi_processor_count: attr(A::MultiprocessorCount)?,
            ecc_enabled: attr(A::EccEnabled)? != 0,
            integrated: attr(A::Integrated)? != 0,
            can_map_host_memory: attr(A::CanMapHostMemory)? != 0,
            unified_addressing: attr(A::UnifiedAddressing)? != 0,
            memory_clock_rate: attr(A::MemoryClockRate)?,
            memory_bus_width: attr(A::GlobalMemoryBusWidth)?,
            l2_cache_size: attr(A::L2CacheSize)?,
            max_threads_per_multi_processor: attr(A::MaxThreadsPerMultiprocessor)?,
            stream_priorities_supported: attr(A::StreamPrioritiesSupported)? != 0,
            shared_mem_per_multiprocessor: attr(A::MaxSharedMemoryPerMultiprocessor)? as usize,
            regs_per_multiprocessor: attr(A::MaxRegistersPerMultiprocessor)?,
            managed_memory: attr(A::ManagedMemory)? != 0,
            is_multi_gpu_board: attr(A::IsMultiGpuBoard)? != 0,
            multi_gpu_board_group_id: attr(A::MultiGpuBoardGroupId)?,
            host_native_atomic_supported: attr(A::HostNativeAtomicSupported)? != 0,
            cooperative_launch: attr(A::CooperativeLaunch)? != 0,
            cooperative_multi_device_launch: attr(A::CooperativeMultiDeviceLaunch)? != 0,
            max_blocks_per_multi_processor: attr(A::MaxBlocksPerMultiprocessor)?,
            shared_mem_per_block_optin: attr(A::MaxSharedMemoryPerBlockOptin)? as usize,
            cluster_launch: attr(A::ClusterLaunch)? != 0,
        })
    }
}
/// Returns the number of CUDA devices visible to the driver.
///
/// # Errors
/// - [`CudaRtError::DriverNotAvailable`] if the driver cannot be loaded.
/// - [`CudaRtError::NoGpu`] (or the mapped driver error) if the query fails
///   or reports zero devices.
pub fn get_device_count() -> CudaRtResult<u32> {
    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
    let mut devices: c_int = 0;
    let rc = unsafe { (api.cu_device_get_count)(&raw mut devices) };
    if rc != 0 {
        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::NoGpu));
    }
    // A machine with a driver but zero devices is treated the same as no GPU.
    match devices {
        0 => Err(CudaRtError::NoGpu),
        n => Ok(n as u32),
    }
}
/// Selects `device` as the current device for the calling thread by retaining
/// its primary context and binding it to the thread, mirroring the CUDA
/// runtime's implicit-context behavior.
///
/// # Errors
/// - [`CudaRtError::InvalidDevice`] if `device` is out of range.
/// - [`CudaRtError::DriverNotAvailable`] if the driver cannot be loaded.
/// - The mapped driver error if retaining or binding the context fails.
pub fn set_device(device: u32) -> CudaRtResult<()> {
    // Validate the ordinal before touching any context state.
    if device >= get_device_count()? {
        return Err(CudaRtError::InvalidDevice);
    }
    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
    let map_rc = |rc| CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevice);
    let mut ctx = oxicuda_driver::ffi::CUcontext::default();
    let rc = unsafe { (api.cu_device_primary_ctx_retain)(&raw mut ctx, device as c_int) };
    if rc != 0 {
        return Err(map_rc(rc));
    }
    // NOTE(review): if binding fails below, the retained primary context is
    // never released — confirm whether a cuDevicePrimaryCtxRelease call is
    // needed on this error path.
    let rc = unsafe { (api.cu_ctx_set_current)(ctx) };
    if rc != 0 {
        return Err(map_rc(rc));
    }
    // Record the selection only after the driver calls succeed.
    CURRENT_DEVICE.with(|cell| cell.set(Some(device as c_int)));
    Ok(())
}
pub fn get_device() -> CudaRtResult<u32> {
CURRENT_DEVICE.with(|cell| {
cell.get()
.map(|d| d as u32)
.ok_or(CudaRtError::DeviceNotSet)
})
}
/// Queries the full property record for device `device`.
///
/// # Errors
/// - [`CudaRtError::InvalidDevice`] if `device` is out of range.
/// - Any error from [`get_device_count`] or [`CudaDeviceProp::from_device`].
pub fn get_device_properties(device: u32) -> CudaRtResult<CudaDeviceProp> {
    if device < get_device_count()? {
        CudaDeviceProp::from_device(device as c_int)
    } else {
        Err(CudaRtError::InvalidDevice)
    }
}
/// Blocks until all work on the current context has completed.
///
/// # Errors
/// - [`CudaRtError::DeviceNotSet`] if no device is set on this thread.
/// - [`CudaRtError::DriverNotAvailable`] if the driver cannot be loaded.
/// - The mapped driver error if synchronization fails.
pub fn device_synchronize() -> CudaRtResult<()> {
    // Require a current device so we synchronize a real context.
    let _device = get_device()?;
    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
    // Fix: the return code was previously discarded, so a failed (e.g.
    // asynchronously-erroring) synchronize was reported as success.
    let rc = unsafe { (api.cu_ctx_synchronize)() };
    if rc != 0 {
        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevice));
    }
    Ok(())
}
/// Resets the current device's primary context and clears this thread's
/// current-device selection.
///
/// # Errors
/// - [`CudaRtError::DeviceNotSet`] if no device is set on this thread.
/// - [`CudaRtError::DriverNotAvailable`] if the driver cannot be loaded.
/// - The mapped driver error if the primary-context reset fails (the
///   thread-local selection is left intact in that case).
pub fn device_reset() -> CudaRtResult<()> {
    // Fix: previously the ordinal was re-read from the thread-local with a
    // silent `0` default after get_device() had already validated it, and the
    // reset's return code was discarded. Use the validated ordinal directly
    // and surface driver failures.
    let device = get_device()? as c_int;
    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
    let rc = unsafe { (api.cu_device_primary_ctx_reset_v2)(device) };
    if rc != 0 {
        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevice));
    }
    // Only forget the selection once the reset actually succeeded.
    CURRENT_DEVICE.with(|cell| cell.set(None));
    Ok(())
}
/// Returns the `(major, minor)` compute capability of device `device`.
///
/// # Errors
/// Propagates any error from [`get_device_properties`].
pub fn get_compute_capability(device: u32) -> CudaRtResult<(u32, u32)> {
    get_device_properties(device).map(|props| (props.major, props.minor))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Without a prior `set_device` on this thread, `get_device` must report
    /// `DeviceNotSet`. Compiled out for gpu-tests builds.
    #[test]
    #[cfg(not(feature = "gpu-tests"))]
    fn get_device_without_set_errors() {
        assert!(matches!(get_device(), Err(CudaRtError::DeviceNotSet)));
    }

    /// `set_device` should stick for the calling thread, and an out-of-range
    /// ordinal must be rejected; degrades gracefully when no driver or GPU
    /// is present.
    #[test]
    fn set_device_persists_in_thread() {
        match get_device_count() {
            Ok(n) => {
                set_device(0).expect("set_device(0) failed");
                assert_eq!(get_device().unwrap(), 0);
                // `n` itself is one past the last valid ordinal.
                assert!(matches!(set_device(n), Err(CudaRtError::InvalidDevice)));
            }
            Err(CudaRtError::DriverNotAvailable) | Err(CudaRtError::NoGpu) => {
                assert!(get_device().is_err());
            }
            Err(other) => panic!("unexpected error: {other}"),
        }
    }

    /// Driver error codes 100/101 map to the matching runtime errors.
    #[test]
    fn from_code_round_trip() {
        assert_eq!(CudaRtError::from_code(100), Some(CudaRtError::NoDevice));
        assert_eq!(
            CudaRtError::from_code(101),
            Some(CudaRtError::InvalidDevice)
        );
    }
}