use std::ptr;
use std::sync::atomic::{AtomicBool, Ordering};
use super::sys::{CUcontext, CUdevice, CudaDriver, CUDA_SUCCESS};
use crate::GpuError;
/// Set to `true` only after `cuInit` has returned `CUDA_SUCCESS` in this process.
static CUDA_INITIALIZED: AtomicBool = AtomicBool::new(false);

/// Returns the loaded CUDA driver, running `cuInit(0)` the first time it is needed.
///
/// Initialization is serialized by a lock, so a caller never receives the
/// driver while another thread is still inside `cuInit`. The flag is set only
/// after `cuInit` succeeds; a failed attempt leaves it clear so that a later
/// call tries again.
///
/// # Errors
///
/// * [`GpuError::CudaNotAvailable`] when `CudaDriver::load` yields no driver.
/// * [`GpuError::DeviceInit`] when `cuInit` returns anything other than
///   `CUDA_SUCCESS`.
pub fn get_driver() -> Result<&'static CudaDriver, GpuError> {
    // Serializes the one-time `cuInit` call across threads.
    static INIT_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    let driver = CudaDriver::load()
        .ok_or_else(|| GpuError::CudaNotAvailable("CUDA driver not found".to_string()))?;

    // Fast path: initialization already finished, no lock required.
    if CUDA_INITIALIZED.load(Ordering::SeqCst) {
        return Ok(driver);
    }

    // The lock protects no data (`()`), so a poisoned guard is still usable.
    let _guard = INIT_LOCK
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());

    // Re-check under the lock: another thread may have completed the
    // initialization while this one was waiting.
    if !CUDA_INITIALIZED.load(Ordering::SeqCst) {
        // SAFETY: NOTE(review): relies on `CudaDriver::load` having resolved
        // `cuInit` to a valid driver entry point — confirm in `sys`.
        let result = unsafe { (driver.cuInit)(0) };
        if result != CUDA_SUCCESS {
            return Err(GpuError::DeviceInit(format!(
                "cuInit failed with code {}",
                result
            )));
        }
        CUDA_INITIALIZED.store(true, Ordering::SeqCst);
    }

    Ok(driver)
}
/// A CUDA device handle paired with that device's primary context.
///
/// `CudaContext::new` obtains the context through `cuDevicePrimaryCtxRetain`,
/// and the `Drop` implementation hands it back through
/// `cuDevicePrimaryCtxRelease`.
pub struct CudaContext {
/// Device handle filled in by `cuDeviceGet`.
device: CUdevice,
/// Primary-context handle filled in by `cuDevicePrimaryCtxRetain`.
context: CUcontext,
}
// SAFETY: NOTE(review): assumes the driver allows a context handle to be moved
// to, and released from, a thread other than the one that retained it — confirm
// against the CUDA driver API documentation.
unsafe impl Send for CudaContext {}
// SAFETY: NOTE(review): assumes the handles may be used through `&self` from
// several threads at once; the methods take no lock of their own — confirm.
unsafe impl Sync for CudaContext {}
impl CudaContext {
    /// Retains the primary context of device `device_ordinal` and makes it
    /// current on the calling thread.
    ///
    /// # Errors
    ///
    /// * Any error from `get_driver`.
    /// * [`GpuError::DeviceNotFound`] when `device_ordinal` is negative or not
    ///   below the device count reported by the driver.
    /// * The error produced by `CudaDriver::check` when `cuDeviceGetCount`,
    ///   `cuDeviceGet` or `cuDevicePrimaryCtxRetain` fails.
    /// * [`GpuError::DeviceInit`] when `cuCtxSetCurrent` fails; the retained
    ///   context is released before the error is returned.
    pub fn new(device_ordinal: i32) -> Result<Self, GpuError> {
        let driver = get_driver()?;

        let mut count: i32 = 0;
        // SAFETY: `count` is a live local for the duration of the call.
        let result = unsafe { (driver.cuDeviceGetCount)(&mut count) };
        CudaDriver::check(result)?;
        if device_ordinal < 0 || device_ordinal >= count {
            // Clamp before casting so a negative count can never wrap around.
            return Err(GpuError::DeviceNotFound(
                device_ordinal,
                count.max(0) as usize,
            ));
        }

        let mut device: CUdevice = 0;
        // SAFETY: `device` is a live local; the ordinal was range-checked above.
        let result = unsafe { (driver.cuDeviceGet)(&mut device, device_ordinal) };
        CudaDriver::check(result)?;

        let mut context: CUcontext = ptr::null_mut();
        // SAFETY: `context` is a live local and `device` came from `cuDeviceGet`.
        let result = unsafe { (driver.cuDevicePrimaryCtxRetain)(&mut context, device) };
        CudaDriver::check(result)?;

        // From here on the value owns the retained context: if binding it to
        // the thread fails, dropping `ctx` on the early return releases it.
        let ctx = Self { device, context };
        ctx.make_current()?;
        Ok(ctx)
    }

    /// Returns the device handle stored at construction.
    #[must_use]
    pub fn device(&self) -> i32 {
        self.device
    }

    /// Returns the raw context handle stored at construction.
    #[must_use]
    pub fn raw(&self) -> CUcontext {
        self.context
    }

    /// Returns `(free, total)` as reported by `cuMemGetInfo`.
    ///
    /// NOTE(review): the call does not pass `self.context`, so the figures
    /// presumably describe whichever context is current on the calling
    /// thread — call [`CudaContext::make_current`] first when in doubt.
    ///
    /// # Errors
    ///
    /// Any error from `get_driver`, or the error produced by
    /// `CudaDriver::check` for a failing `cuMemGetInfo`.
    pub fn memory_info(&self) -> Result<(usize, usize), GpuError> {
        let driver = get_driver()?;
        let mut free: usize = 0;
        let mut total: usize = 0;
        // SAFETY: both out-pointers refer to live locals.
        let result = unsafe { (driver.cuMemGetInfo)(&mut free, &mut total) };
        CudaDriver::check(result)?;
        Ok((free, total))
    }

    /// Makes this context current on the calling thread via `cuCtxSetCurrent`.
    ///
    /// # Errors
    ///
    /// Any error from `get_driver`, or [`GpuError::DeviceInit`] carrying the
    /// driver status code when `cuCtxSetCurrent` fails.
    pub fn make_current(&self) -> Result<(), GpuError> {
        let driver = get_driver()?;
        // SAFETY: `self.context` is the handle stored by `new`.
        let result = unsafe { (driver.cuCtxSetCurrent)(self.context) };
        if result != CUDA_SUCCESS {
            return Err(GpuError::DeviceInit(format!(
                "cuCtxSetCurrent failed with code {}",
                result
            )));
        }
        Ok(())
    }

    /// Calls `cuCtxSynchronize`.
    ///
    /// NOTE(review): the call takes no context argument, so it presumably
    /// waits on the context current on the calling thread rather than on
    /// `self.context` specifically — confirm against the driver API.
    ///
    /// # Errors
    ///
    /// Any error from `get_driver`, or [`GpuError::StreamSync`] wrapping the
    /// text of the error reported by `CudaDriver::check`.
    pub fn synchronize(&self) -> Result<(), GpuError> {
        let driver = get_driver()?;
        // SAFETY: the call takes no arguments.
        let result = unsafe { (driver.cuCtxSynchronize)() };
        CudaDriver::check(result).map_err(|e| GpuError::StreamSync(e.to_string()))
    }

    /// Returns the device name reported by `cuDeviceGetName`.
    ///
    /// The name is read from a fixed 256-byte buffer up to the first NUL, or
    /// up to the end of the buffer if no NUL is present; bytes that are not
    /// valid UTF-8 are replaced lossily.
    ///
    /// # Errors
    ///
    /// Any error from `get_driver`, or the error produced by
    /// `CudaDriver::check` for a failing `cuDeviceGetName`.
    pub fn device_name(&self) -> Result<String, GpuError> {
        let driver = get_driver()?;
        let mut name = [0i8; 256];
        // SAFETY: the pointer covers exactly the 256 bytes announced as the
        // length argument.
        let result = unsafe { (driver.cuDeviceGetName)(name.as_mut_ptr(), 256, self.device) };
        CudaDriver::check(result)?;

        // Decode without `CStr::from_ptr`: the scan is bounded by the array,
        // so a missing terminator cannot cause an out-of-bounds read, and no
        // `i8` pointer has to be passed where `c_char` (which is `u8` on some
        // targets) is expected.
        let bytes: Vec<u8> = name
            .iter()
            .take_while(|&&c| c != 0)
            .map(|&c| c as u8)
            .collect();
        Ok(String::from_utf8_lossy(&bytes).into_owned())
    }

    /// Returns the byte count reported by `cuDeviceTotalMem` for this device.
    ///
    /// # Errors
    ///
    /// Any error from `get_driver`, or the error produced by
    /// `CudaDriver::check` for a failing `cuDeviceTotalMem`.
    pub fn total_memory(&self) -> Result<usize, GpuError> {
        let driver = get_driver()?;
        let mut bytes: usize = 0;
        // SAFETY: `bytes` is a live local; `self.device` was stored by `new`.
        let result = unsafe { (driver.cuDeviceTotalMem)(&mut bytes, self.device) };
        CudaDriver::check(result)?;
        Ok(bytes)
    }
}
impl Drop for CudaContext {
    /// Gives the primary context back with `cuDevicePrimaryCtxRelease`.
    ///
    /// Best effort: if the driver cannot be obtained nothing happens, and the
    /// status returned by the release call is discarded.
    fn drop(&mut self) {
        let driver = match get_driver() {
            Ok(driver) => driver,
            Err(_) => return,
        };
        // SAFETY: `self.device` is the handle stored when the value was built.
        let _ = unsafe { (driver.cuDevicePrimaryCtxRelease)(self.device) };
    }
}
/// Returns the number of devices reported by `cuDeviceGetCount`.
///
/// # Errors
///
/// Any error from `get_driver`, or the error produced by `CudaDriver::check`
/// when the driver call fails.
pub fn device_count() -> Result<usize, GpuError> {
    let mut n_devices = 0i32;
    let driver = get_driver()?;
    // SAFETY: `n_devices` is a live local for the duration of the call.
    let status = unsafe { (driver.cuDeviceGetCount)(&mut n_devices) };
    CudaDriver::check(status)?;
    Ok(n_devices as usize)
}
/// Returns `true` when `device_count` succeeds and reports at least one
/// device; any error counts as "not available".
#[must_use]
pub fn cuda_available() -> bool {
    matches!(device_count(), Ok(found) if found > 0)
}
#[cfg(test)]
mod tests {
    // Every test here is compiled only when the `cuda` feature is disabled
    // and asserts that the corresponding entry point reports failure.

    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_get_driver_without_feature() {
        assert!(super::get_driver().is_err());
    }

    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_cuda_available_without_feature() {
        let available = super::cuda_available();
        assert!(!available);
    }

    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_device_count_without_feature() {
        assert!(super::device_count().is_err());
    }

    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_context_new_without_feature() {
        assert!(super::CudaContext::new(0).is_err());
    }
}