use std::os::raw::c_char;
use std::ptr;
use std::sync::atomic::{AtomicBool, Ordering};
use super::sys::{CUcontext, CUdevice, CudaDriver, CUDA_SUCCESS};
use crate::GpuError;
/// Set to `true` only after `cuInit` has returned `CUDA_SUCCESS`.
static CUDA_INITIALIZED: AtomicBool = AtomicBool::new(false);
/// Serializes the one-time `cuInit` call so that no caller can observe
/// `CUDA_INITIALIZED == true` before initialization has actually completed.
static CUDA_INIT_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

/// Loads the CUDA driver and makes sure `cuInit(0)` has succeeded once.
///
/// Returns the `'static` driver handle produced by `CudaDriver::load`.
/// Concurrent first callers block until a single `cuInit` attempt finishes;
/// if that attempt fails, the initialized flag stays clear so a later call
/// retries.
///
/// # Errors
///
/// - [`GpuError::CudaNotAvailable`] if `CudaDriver::load` returns `None`.
/// - [`GpuError::DeviceInit`] if `cuInit` returns anything other than
///   `CUDA_SUCCESS`.
pub(crate) fn get_driver() -> Result<&'static CudaDriver, GpuError> {
    let driver = CudaDriver::load()
        .ok_or_else(|| GpuError::CudaNotAvailable("CUDA driver not found".to_string()))?;

    // Fast path: initialization already finished successfully.
    if CUDA_INITIALIZED.load(Ordering::SeqCst) {
        return Ok(driver);
    }

    // Slow path: hold the lock across `cuInit` so other threads wait for the
    // outcome instead of racing ahead with an uninitialized driver. The lock
    // guards no data, so a poisoned lock is still safe to use.
    let _guard = CUDA_INIT_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    if !CUDA_INITIALIZED.load(Ordering::SeqCst) {
        // SAFETY: `cuInit` takes only a flags word; the driver API requires 0.
        let result = unsafe { (driver.cuInit)(0) };
        if result != CUDA_SUCCESS {
            // Flag is left clear, so a later call can retry initialization.
            return Err(GpuError::DeviceInit(format!("cuInit failed with code {}", result)));
        }
        // Publish success only after `cuInit` has actually completed.
        CUDA_INITIALIZED.store(true, Ordering::SeqCst);
    }
    Ok(driver)
}
/// Handle to the primary CUDA context of a single device.
///
/// [`CudaContext::new`] retains the device's primary context with
/// `cuDevicePrimaryCtxRetain`; dropping the handle releases that reference
/// with `cuDevicePrimaryCtxRelease`.
pub struct CudaContext {
    /// `CUdevice` handle obtained from `cuDeviceGet`.
    device: CUdevice,
    /// Raw primary-context handle returned by `cuDevicePrimaryCtxRetain`.
    context: CUcontext,
}

// SAFETY: the struct holds only a device handle and an opaque driver context
// handle, with no thread-local Rust state. Every method that needs the context
// binds it to the calling thread via `cuCtxSetCurrent`, so moving the handle
// to another thread is sound. NOTE(review): relies on the CUDA driver allowing
// a context to be made current on any host thread — confirm against the
// driver API docs if this type's usage changes.
unsafe impl Send for CudaContext {}
// SAFETY: no `&self` method mutates the struct; shared references only pass
// the stored handles to driver calls, so there is no Rust-level data race.
unsafe impl Sync for CudaContext {}
impl CudaContext {
    /// `CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT` in the CUDA driver API.
    const ATTR_MULTIPROCESSOR_COUNT: i32 = 16;
    /// `CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR` in the CUDA driver API.
    const ATTR_COMPUTE_CAPABILITY_MAJOR: i32 = 75;
    /// `CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR` in the CUDA driver API.
    const ATTR_COMPUTE_CAPABILITY_MINOR: i32 = 76;

    /// Retains the primary context of device `device_ordinal` and makes it
    /// current on the calling thread.
    ///
    /// # Errors
    ///
    /// - [`GpuError::DeviceNotFound`] if `device_ordinal` is negative or not
    ///   less than the number of devices reported by `cuDeviceGetCount`.
    /// - [`GpuError::DeviceInit`] if `cuCtxSetCurrent` fails; the retained
    ///   primary context is released (via `Drop`) before returning.
    /// - Errors from `get_driver`, or from `CudaDriver::check` on the
    ///   device-count, device-lookup and primary-context-retain calls.
    pub fn new(device_ordinal: i32) -> Result<Self, GpuError> {
        let driver = get_driver()?;

        let mut count: i32 = 0;
        // SAFETY: `count` is a valid, writable i32 for the duration of the call.
        let result = unsafe { (driver.cuDeviceGetCount)(&mut count) };
        CudaDriver::check(result)?;
        if device_ordinal < 0 || device_ordinal >= count {
            return Err(GpuError::DeviceNotFound(
                device_ordinal,
                usize::try_from(count).unwrap_or(0),
            ));
        }

        let mut device: CUdevice = 0;
        // SAFETY: `device` is writable and `device_ordinal` was range-checked above.
        let result = unsafe { (driver.cuDeviceGet)(&mut device, device_ordinal) };
        CudaDriver::check(result)?;

        let mut context: CUcontext = ptr::null_mut();
        // SAFETY: `context` is writable and `device` came from `cuDeviceGet`.
        let result = unsafe { (driver.cuDevicePrimaryCtxRetain)(&mut context, device) };
        CudaDriver::check(result)?;

        // Take ownership of the retained reference before anything else can
        // fail: if binding it below errors out, `Drop` releases it.
        let ctx = Self { device, context };
        ctx.make_current()?;
        Ok(ctx)
    }

    /// Returns the `CUdevice` handle this context was created for.
    #[must_use]
    pub fn device(&self) -> i32 {
        self.device
    }

    /// Returns the raw `CUcontext` handle.
    ///
    /// The handle remains owned by `self`: callers must not release it, and
    /// must not use it after `self` is dropped.
    #[must_use]
    pub fn raw(&self) -> CUcontext {
        self.context
    }

    /// Returns `(free, total)` device memory as reported by `cuMemGetInfo`.
    ///
    /// `cuMemGetInfo` queries the calling thread's *current* context, so this
    /// first makes `self` current on the calling thread. That binding stays
    /// in place after the call returns.
    ///
    /// # Errors
    ///
    /// [`GpuError::DeviceInit`] if `self` cannot be made current, or the error
    /// from `CudaDriver::check` if the query fails.
    pub fn memory_info(&self) -> Result<(usize, usize), GpuError> {
        self.make_current()?;
        let driver = get_driver()?;
        let mut free: usize = 0;
        let mut total: usize = 0;
        // SAFETY: both out-pointers refer to valid, writable usize locals.
        let result = unsafe { (driver.cuMemGetInfo)(&mut free, &mut total) };
        CudaDriver::check(result)?;
        Ok((free, total))
    }

    /// Binds this context to the calling thread with `cuCtxSetCurrent`.
    ///
    /// # Errors
    ///
    /// [`GpuError::DeviceInit`] carrying the driver status code on failure.
    pub fn make_current(&self) -> Result<(), GpuError> {
        let driver = get_driver()?;
        // SAFETY: `self.context` was retained in `new` and stays live until `self` drops.
        let result = unsafe { (driver.cuCtxSetCurrent)(self.context) };
        if result != CUDA_SUCCESS {
            return Err(GpuError::DeviceInit(format!(
                "cuCtxSetCurrent failed with code {}",
                result
            )));
        }
        Ok(())
    }

    /// Waits for all work submitted to this context to finish.
    ///
    /// `cuCtxSynchronize` acts on the calling thread's current context, so
    /// this first makes `self` current on the calling thread.
    ///
    /// # Errors
    ///
    /// [`GpuError::DeviceInit`] if `self` cannot be made current;
    /// [`GpuError::StreamSync`] if `cuCtxSynchronize` reports an error.
    pub fn synchronize(&self) -> Result<(), GpuError> {
        self.make_current()?;
        let driver = get_driver()?;
        // SAFETY: no arguments; `self` was made current just above.
        let result = unsafe { (driver.cuCtxSynchronize)() };
        CudaDriver::check(result).map_err(|e| GpuError::StreamSync(e.to_string()))
    }

    /// Returns the device name reported by `cuDeviceGetName`.
    ///
    /// Non-UTF-8 bytes are replaced lossily. Names longer than 255 bytes are
    /// truncated.
    ///
    /// # Errors
    ///
    /// Errors from `get_driver` or `CudaDriver::check`.
    pub fn device_name(&self) -> Result<String, GpuError> {
        const NAME_CAPACITY: usize = 256;
        let driver = get_driver()?;
        let mut name = [0 as c_char; NAME_CAPACITY];
        // SAFETY: `name` provides NAME_CAPACITY writable bytes, and that same
        // capacity is passed as the length limit.
        let result = unsafe {
            (driver.cuDeviceGetName)(name.as_mut_ptr(), NAME_CAPACITY as _, self.device)
        };
        CudaDriver::check(result)?;
        // Force a terminator so `CStr::from_ptr` can never read past the
        // buffer, even if the driver filled it completely.
        name[NAME_CAPACITY - 1] = 0;
        // SAFETY: `name` is NUL-terminated (enforced above) and outlives the borrow.
        let name_str = unsafe { std::ffi::CStr::from_ptr(name.as_ptr()) }
            .to_string_lossy()
            .into_owned();
        Ok(name_str)
    }

    /// Returns the `(major, minor)` compute capability of the device.
    ///
    /// # Errors
    ///
    /// Errors from `get_driver` or `CudaDriver::check`.
    pub fn compute_capability(&self) -> Result<(i32, i32), GpuError> {
        let major = self.device_attribute(Self::ATTR_COMPUTE_CAPABILITY_MAJOR)?;
        let minor = self.device_attribute(Self::ATTR_COMPUTE_CAPABILITY_MINOR)?;
        Ok((major, minor))
    }

    /// Returns the compute capability formatted as `sm_<major><minor>`
    /// (e.g. `sm_86`).
    ///
    /// # Errors
    ///
    /// Same as [`CudaContext::compute_capability`].
    pub fn sm_target(&self) -> Result<String, GpuError> {
        let (major, minor) = self.compute_capability()?;
        Ok(format!("sm_{major}{minor}"))
    }

    /// Returns the number of streaming multiprocessors on the device.
    ///
    /// # Errors
    ///
    /// Errors from `get_driver` or `CudaDriver::check`.
    pub fn multiprocessor_count(&self) -> Result<i32, GpuError> {
        self.device_attribute(Self::ATTR_MULTIPROCESSOR_COUNT)
    }

    /// Returns the total device memory reported by `cuDeviceTotalMem`.
    ///
    /// # Errors
    ///
    /// Errors from `get_driver` or `CudaDriver::check`.
    pub fn total_memory(&self) -> Result<usize, GpuError> {
        let driver = get_driver()?;
        let mut bytes: usize = 0;
        // SAFETY: `bytes` is a valid, writable usize for the duration of the call.
        let result = unsafe { (driver.cuDeviceTotalMem)(&mut bytes, self.device) };
        CudaDriver::check(result)?;
        Ok(bytes)
    }

    /// Reads one integer device attribute with `cuDeviceGetAttribute`.
    fn device_attribute(&self, attr: i32) -> Result<i32, GpuError> {
        let driver = get_driver()?;
        let mut value: i32 = 0;
        // SAFETY: `value` is a valid, writable i32 and `self.device` came
        // from `cuDeviceGet`.
        let result = unsafe {
            (driver.cuDeviceGetAttribute)(&mut value, attr as _, self.device)
        };
        CudaDriver::check(result)?;
        Ok(value)
    }
}
impl Drop for CudaContext {
    /// Gives back the primary-context reference acquired by [`CudaContext::new`].
    ///
    /// Best effort: `drop` cannot report failure, so both an unavailable
    /// driver and a non-success release status are ignored.
    fn drop(&mut self) {
        let Ok(driver) = get_driver() else {
            return;
        };
        // SAFETY: `self.device` identifies the device whose primary context
        // was retained when `self` was built; this is the matching release.
        let _ = unsafe { (driver.cuDevicePrimaryCtxRelease)(self.device) };
    }
}
/// Returns how many CUDA devices the driver reports.
///
/// # Errors
///
/// Propagates any error from `get_driver`, and from checking the
/// `cuDeviceGetCount` status with `CudaDriver::check`.
pub fn device_count() -> Result<usize, GpuError> {
    let mut devices: i32 = 0;
    let status = {
        let driver = get_driver()?;
        // SAFETY: `devices` is a live, writable i32 for the whole call.
        unsafe { (driver.cuDeviceGetCount)(&mut devices) }
    };
    CudaDriver::check(status)?;
    Ok(devices as usize)
}
/// Reports whether at least one CUDA device can be used.
///
/// Any error from [`device_count`] is treated as "not available".
#[must_use]
pub fn cuda_available() -> bool {
    matches!(device_count(), Ok(n) if n > 0)
}
#[cfg(test)]
mod tests {
    // Each test is compiled only when the `cuda` feature is disabled. In that
    // configuration every entry point below is expected to report failure
    // (or `false`) rather than succeed.

    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_get_driver_without_feature() {
        assert!(super::get_driver().is_err());
    }

    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_cuda_available_without_feature() {
        let available = super::cuda_available();
        assert!(!available);
    }

    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_device_count_without_feature() {
        assert!(super::device_count().is_err());
    }

    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_context_new_without_feature() {
        assert!(super::CudaContext::new(0).is_err());
    }
}