use cudarc::driver::sys as driver_sys;
use cudarc::runtime::sys as runtime_sys;
use crate::error::GpuError;
/// Library tag attached to errors produced by CUDA runtime-API calls (see `runtime_check`).
const LIB_RUNTIME: &str = "runtime";
/// Library tag attached to errors produced by CUDA driver-API calls (see `driver_check`).
const LIB_DRIVER: &str = "driver";
/// Converts a CUDA driver-API status code into a `Result`.
///
/// `CUDA_SUCCESS` maps to `Ok(())`. Any other code becomes
/// `GpuError::LibraryError` tagged with `LIB_DRIVER`; its message is the
/// operation name `op` followed by the `Debug` form of the status.
fn driver_check(s: driver_sys::CUresult, op: &str) -> Result<(), GpuError> {
    match s {
        driver_sys::cudaError_enum::CUDA_SUCCESS => Ok(()),
        failure => Err(GpuError::LibraryError {
            lib: LIB_DRIVER,
            msg: format!("{op}: {failure:?}"),
        }),
    }
}
/// Converts a CUDA runtime-API status code into a `Result`.
///
/// `cudaSuccess` maps to `Ok(())`. Any other code becomes
/// `GpuError::LibraryError` tagged with `LIB_RUNTIME`; its message is the
/// operation name `op` followed by the `Debug` form of the status.
fn runtime_check(s: runtime_sys::cudaError_t, op: &str) -> Result<(), GpuError> {
    match s {
        runtime_sys::cudaError::cudaSuccess => Ok(()),
        failure => Err(GpuError::LibraryError {
            lib: LIB_RUNTIME,
            msg: format!("{op}: {failure:?}"),
        }),
    }
}
/// Runs `f` and converts any panic that unwinds out of it into
/// `GpuError::Unrecoverable`, so the panic does not escape the FFI wrappers.
///
/// `Ok`/`Err` values that `f` returns normally are passed through unchanged.
///
/// A caught panic is reported as `"{op}: CUDA driver not loadable (<panic message>)"`.
/// The "not loadable" wording is the failure this guard targets. NOTE(review):
/// presumably cudarc's lazy library loader panicking; confirm. The panic's own
/// message is appended for two reasons: unrelated panics are then not misdiagnosed,
/// and runtime-API calls (which also go through here) show which library actually
/// failed to load.
fn guarded<F, R>(op: &'static str, f: F) -> Result<R, GpuError>
where
    F: FnOnce() -> Result<R, GpuError>,
{
    match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) {
        Ok(r) => r,
        Err(payload) => {
            // `&*payload` (not `&payload`): downcast the payload itself, not the Box holding it.
            let detail = panic_message(&*payload).unwrap_or("non-string panic payload");
            Err(GpuError::Unrecoverable(format!(
                "{op}: CUDA driver not loadable ({detail})"
            )))
        }
    }
}

/// Extracts the text of a panic payload. The payload is a `&'static str` for
/// `panic!("literal")` and a `String` for formatted panics; other payload types yield `None`.
fn panic_message(payload: &(dyn std::any::Any + Send)) -> Option<&str> {
    match payload.downcast_ref::<&'static str>() {
        Some(s) => Some(*s),
        None => payload.downcast_ref::<String>().map(String::as_str),
    }
}
/// Calls `cuMemPrefetchAsync_v2` for the `count` bytes at `dev_ptr`, targeting
/// `location` with `flags`, enqueued on `stream`.
///
/// # Errors
/// Returns `GpuError::LibraryError` if the driver reports a non-success status.
/// Returns `GpuError::Unrecoverable` if the call panics (see `guarded`).
///
/// NOTE(review): this is a safe fn forwarding raw handles to FFI unchecked.
/// Callers are presumably responsible for a valid pointer range and stream; confirm.
pub fn mem_prefetch_async_v2(
    dev_ptr: driver_sys::CUdeviceptr,
    count: usize,
    location: driver_sys::CUmemLocation,
    flags: u32,
    stream: driver_sys::CUstream,
) -> Result<(), GpuError> {
    const OP: &str = "cuMemPrefetchAsync_v2";
    guarded(OP, || {
        // SAFETY: every argument is passed by value; validity of the range and
        // stream is the caller's responsibility (see the NOTE above).
        driver_check(
            unsafe { driver_sys::cuMemPrefetchAsync_v2(dev_ptr, count, location, flags, stream) },
            OP,
        )
    })
}
/// Calls `cuMemAdvise_v2` to apply `advice` to the `count` bytes at `dev_ptr`
/// for `location`.
///
/// # Errors
/// Returns `GpuError::LibraryError` on a non-success driver status, or
/// `GpuError::Unrecoverable` if the call panics (see `guarded`).
///
/// NOTE(review): safe fn forwarding a raw device pointer unchecked; caller presumably
/// guarantees the range is valid — confirm.
pub fn mem_advise_v2(
    dev_ptr: driver_sys::CUdeviceptr,
    count: usize,
    advice: driver_sys::CUmem_advise,
    location: driver_sys::CUmemLocation,
) -> Result<(), GpuError> {
    const OP: &str = "cuMemAdvise_v2";
    guarded(OP, || {
        // SAFETY: all arguments are plain values forwarded unchanged to the driver.
        driver_check(
            unsafe { driver_sys::cuMemAdvise_v2(dev_ptr, count, advice, location) },
            OP,
        )
    })
}
/// Obtains an IPC memory handle for `dev_ptr` via `cuIpcGetMemHandle`.
///
/// The handle starts zero-filled and is returned only if the driver reports success.
///
/// # Errors
/// Returns `GpuError::LibraryError` on a non-success status, or
/// `GpuError::Unrecoverable` if the call panics (see `guarded`).
#[cfg(feature = "cuda-ipc")]
pub fn ipc_get_mem_handle(
    dev_ptr: driver_sys::CUdeviceptr,
) -> Result<driver_sys::CUipcMemHandle, GpuError> {
    const OP: &str = "cuIpcGetMemHandle";
    guarded(OP, || {
        let mut out = driver_sys::CUipcMemHandle_st { reserved: [0; 64] };
        // SAFETY: `out` is a live local, so the out-pointer is valid and aligned for the call.
        let status = unsafe { driver_sys::cuIpcGetMemHandle(&mut out, dev_ptr) };
        driver_check(status, OP).map(|()| out)
    })
}
/// Opens an IPC memory `handle` with `flags` via `cuIpcOpenMemHandle_v2`.
///
/// On success, returns the device pointer the driver wrote.
///
/// # Errors
/// Returns `GpuError::LibraryError` on a non-success status, or
/// `GpuError::Unrecoverable` if the call panics (see `guarded`).
#[cfg(feature = "cuda-ipc")]
pub fn ipc_open_mem_handle_v2(
    handle: driver_sys::CUipcMemHandle,
    flags: u32,
) -> Result<driver_sys::CUdeviceptr, GpuError> {
    const OP: &str = "cuIpcOpenMemHandle_v2";
    guarded(OP, || {
        let mut mapped: driver_sys::CUdeviceptr = 0;
        // SAFETY: `mapped` is a live local used as the out-parameter for this call only.
        let status = unsafe { driver_sys::cuIpcOpenMemHandle_v2(&mut mapped, handle, flags) };
        driver_check(status, OP).map(|()| mapped)
    })
}
/// Calls `cuIpcCloseMemHandle` on `dev_ptr`.
///
/// # Errors
/// Returns `GpuError::LibraryError` on a non-success status, or
/// `GpuError::Unrecoverable` if the call panics (see `guarded`).
///
/// NOTE(review): presumably `dev_ptr` must come from `ipc_open_mem_handle_v2`; confirm.
#[cfg(feature = "cuda-ipc")]
pub fn ipc_close_mem_handle(dev_ptr: driver_sys::CUdeviceptr) -> Result<(), GpuError> {
    const OP: &str = "cuIpcCloseMemHandle";
    guarded(OP, || {
        // SAFETY: the device pointer is forwarded by value; its validity is the caller's concern.
        driver_check(unsafe { driver_sys::cuIpcCloseMemHandle(dev_ptr) }, OP)
    })
}
/// Obtains an IPC handle for `event` via `cuIpcGetEventHandle`.
///
/// The handle starts zero-filled and is returned only if the driver reports success.
///
/// # Errors
/// Returns `GpuError::LibraryError` on a non-success status, or
/// `GpuError::Unrecoverable` if the call panics (see `guarded`).
#[cfg(feature = "cuda-ipc")]
pub fn ipc_get_event_handle(
    event: driver_sys::CUevent,
) -> Result<driver_sys::CUipcEventHandle, GpuError> {
    const OP: &str = "cuIpcGetEventHandle";
    guarded(OP, || {
        let mut out = driver_sys::CUipcEventHandle_st { reserved: [0; 64] };
        // SAFETY: `out` is a live local, so the out-pointer is valid and aligned for the call.
        let status = unsafe { driver_sys::cuIpcGetEventHandle(&mut out, event) };
        driver_check(status, OP).map(|()| out)
    })
}
/// Opens an IPC event `handle` via `cuIpcOpenEventHandle`.
///
/// On success, returns the event the driver wrote; the out-parameter starts null.
///
/// # Errors
/// Returns `GpuError::LibraryError` on a non-success status, or
/// `GpuError::Unrecoverable` if the call panics (see `guarded`).
#[cfg(feature = "cuda-ipc")]
pub fn ipc_open_event_handle(
    handle: driver_sys::CUipcEventHandle,
) -> Result<driver_sys::CUevent, GpuError> {
    const OP: &str = "cuIpcOpenEventHandle";
    guarded(OP, || {
        let mut opened: driver_sys::CUevent = std::ptr::null_mut();
        // SAFETY: `opened` is a live local used as the out-parameter for this call only.
        let status = unsafe { driver_sys::cuIpcOpenEventHandle(&mut opened, handle) };
        driver_check(status, OP).map(|()| opened)
    })
}
/// Loads a module from the in-memory `image` via `cuModuleLoadData`.
///
/// On success, returns the module handle the driver wrote.
///
/// # Errors
/// Returns `GpuError::LibraryError` on a non-success status, or
/// `GpuError::Unrecoverable` if the call panics (see `guarded`).
///
/// NOTE(review): `image` is a raw pointer forwarded unchecked from a safe fn. Callers
/// must presumably pass a pointer to a valid, live module image; confirm (or make `unsafe fn`).
pub fn module_load_data(image: *const std::ffi::c_void) -> Result<driver_sys::CUmodule, GpuError> {
    const OP: &str = "cuModuleLoadData";
    guarded(OP, || {
        let mut module: driver_sys::CUmodule = std::ptr::null_mut();
        // SAFETY: `module` is a live local out-parameter; `image` validity is the caller's concern.
        let status = unsafe { driver_sys::cuModuleLoadData(&mut module, image) };
        driver_check(status, OP).map(|()| module)
    })
}
/// Unloads the module `m` via `cuModuleUnload`.
///
/// # Errors
/// Returns `GpuError::LibraryError` on a non-success status, or
/// `GpuError::Unrecoverable` if the call panics (see `guarded`).
///
/// NOTE(review): nothing here prevents unloading the same handle twice; confirm the
/// callers' ownership discipline.
pub fn module_unload(m: driver_sys::CUmodule) -> Result<(), GpuError> {
    const OP: &str = "cuModuleUnload";
    guarded(OP, || {
        // SAFETY: the handle is forwarded by value; its validity is the caller's concern.
        driver_check(unsafe { driver_sys::cuModuleUnload(m) }, OP)
    })
}
/// Looks up the function called `name` in module `m` via `cuModuleGetFunction`.
///
/// On success, returns the function handle the driver wrote.
///
/// # Errors
/// Returns `GpuError::LibraryError` on a non-success status, or
/// `GpuError::Unrecoverable` if the call panics (see `guarded`).
pub fn module_get_function(
    m: driver_sys::CUmodule,
    name: &std::ffi::CStr,
) -> Result<driver_sys::CUfunction, GpuError> {
    const OP: &str = "cuModuleGetFunction";
    guarded(OP, || {
        let mut function: driver_sys::CUfunction = std::ptr::null_mut();
        // SAFETY: `function` is a live local out-parameter, and `name` is a borrowed
        // NUL-terminated `CStr` that outlives the call.
        let status = unsafe { driver_sys::cuModuleGetFunction(&mut function, m, name.as_ptr()) };
        driver_check(status, OP).map(|()| function)
    })
}
/// Launches `f` via `cuLaunchKernel` on `stream`.
///
/// `grid` and `block` are (x, y, z) dimensions and `shared_bytes` is passed as the
/// shared-memory byte count. Kernel arguments go only through `kernel_params`; the
/// `extra` argument is always null.
///
/// # Errors
/// Returns `GpuError::LibraryError` on a non-success status, or
/// `GpuError::Unrecoverable` if the call panics (see `guarded`).
///
/// NOTE(review): `kernel_params` is forwarded unchecked from a safe fn. It presumably
/// must point to one valid argument pointer per kernel parameter; confirm.
#[allow(clippy::too_many_arguments)]
pub fn launch_kernel(
    f: driver_sys::CUfunction,
    grid: (u32, u32, u32),
    block: (u32, u32, u32),
    shared_bytes: u32,
    stream: driver_sys::CUstream,
    kernel_params: *mut *mut std::ffi::c_void,
) -> Result<(), GpuError> {
    const OP: &str = "cuLaunchKernel";
    let (grid_x, grid_y, grid_z) = grid;
    let (block_x, block_y, block_z) = block;
    let no_extra: *mut *mut std::ffi::c_void = std::ptr::null_mut();
    guarded(OP, || {
        // SAFETY: all arguments are forwarded unchanged; handle and parameter-array
        // validity is the caller's responsibility (see the NOTE above).
        let status = unsafe {
            driver_sys::cuLaunchKernel(
                f,
                grid_x,
                grid_y,
                grid_z,
                block_x,
                block_y,
                block_z,
                shared_bytes,
                stream,
                kernel_params,
                no_extra,
            )
        };
        driver_check(status, OP)
    })
}
/// Launches `f` via `cuLaunchCooperativeKernel` on `stream`.
///
/// `grid` and `block` are (x, y, z) dimensions and `shared_bytes` is passed as the
/// shared-memory byte count.
///
/// # Errors
/// Returns `GpuError::LibraryError` on a non-success status, or
/// `GpuError::Unrecoverable` if the call panics (see `guarded`).
///
/// NOTE(review): `kernel_params` is forwarded unchecked from a safe fn. It presumably
/// must point to one valid argument pointer per kernel parameter; confirm.
#[allow(clippy::too_many_arguments)]
pub fn launch_cooperative_kernel(
    f: driver_sys::CUfunction,
    grid: (u32, u32, u32),
    block: (u32, u32, u32),
    shared_bytes: u32,
    stream: driver_sys::CUstream,
    kernel_params: *mut *mut std::ffi::c_void,
) -> Result<(), GpuError> {
    const OP: &str = "cuLaunchCooperativeKernel";
    let (grid_x, grid_y, grid_z) = grid;
    let (block_x, block_y, block_z) = block;
    guarded(OP, || {
        // SAFETY: all arguments are forwarded unchanged; handle and parameter-array
        // validity is the caller's responsibility (see the NOTE above).
        let status = unsafe {
            driver_sys::cuLaunchCooperativeKernel(
                f,
                grid_x,
                grid_y,
                grid_z,
                block_x,
                block_y,
                block_z,
                shared_bytes,
                stream,
                kernel_params,
            )
        };
        driver_check(status, OP)
    })
}
/// Obtains an IPC memory handle for `dev_ptr` via the runtime API's `cudaIpcGetMemHandle`.
///
/// The handle starts zero-filled and is returned only if the runtime reports success.
///
/// # Errors
/// Returns `GpuError::LibraryError` tagged as the runtime library on a non-success
/// status, or `GpuError::Unrecoverable` if the call panics (see `guarded`).
///
/// NOTE(review): `dev_ptr` is a raw pointer forwarded unchecked from a safe fn; confirm
/// callers guarantee it is a valid device allocation.
#[cfg(feature = "cuda-ipc")]
pub fn runtime_ipc_get_mem_handle(
    dev_ptr: *mut std::ffi::c_void,
) -> Result<runtime_sys::cudaIpcMemHandle_t, GpuError> {
    const OP: &str = "cudaIpcGetMemHandle";
    guarded(OP, || {
        let mut out = runtime_sys::cudaIpcMemHandle_st { reserved: [0; 64] };
        // SAFETY: `out` is a live local, so the out-pointer is valid and aligned for the call.
        let status = unsafe { runtime_sys::cudaIpcGetMemHandle(&mut out, dev_ptr) };
        runtime_check(status, OP).map(|()| out)
    })
}