use std::ffi::c_void;
use std::os::raw::{c_char, c_int, c_uint};
use crate::GpuError;
// ---------------------------------------------------------------------------
// Scalar and handle types mirroring the typedefs in `cuda.h`.
// ---------------------------------------------------------------------------

/// Status code returned by every driver API call (`CUDA_SUCCESS` on success).
pub type CUresult = c_int;
/// Device handle as produced by `cuDeviceGet`.
pub type CUdevice = c_int;
/// Opaque context handle.
pub type CUcontext = *mut c_void;
/// Opaque handle to a loaded module.
pub type CUmodule = *mut c_void;
/// Opaque handle to a kernel function inside a module.
pub type CUfunction = *mut c_void;
/// Opaque stream handle.
pub type CUstream = *mut c_void;
/// Device memory address, always 64 bits wide here.
pub type CUdeviceptr = u64;
/// Opaque handle to a graph being built or captured.
pub type CUgraph = *mut c_void;
/// Opaque handle to an instantiated, launchable graph.
pub type CUgraphExec = *mut c_void;

// ---------------------------------------------------------------------------
// `CUresult` values this module names explicitly. The numbering follows the
// `cudaError_enum` definition in `cuda.h`.
// ---------------------------------------------------------------------------

/// The call completed without error.
pub const CUDA_SUCCESS: CUresult = 0;
/// One or more parameters were out of range.
pub const CUDA_ERROR_INVALID_VALUE: CUresult = 1;
/// The driver could not allocate enough memory.
pub const CUDA_ERROR_OUT_OF_MEMORY: CUresult = 2;
/// The driver has not been initialised with `cuInit`.
pub const CUDA_ERROR_NOT_INITIALIZED: CUresult = 3;
/// The driver is shutting down.
pub const CUDA_ERROR_DEINITIALIZED: CUresult = 4;
/// No CUDA-capable device was detected.
pub const CUDA_ERROR_NO_DEVICE: CUresult = 100;
/// The device ordinal does not name a valid device.
pub const CUDA_ERROR_INVALID_DEVICE: CUresult = 101;
/// PTX JIT compilation failed.
pub const CUDA_ERROR_INVALID_PTX: CUresult = 218;
/// A named symbol (e.g. a kernel) was not found.
pub const CUDA_ERROR_NOT_FOUND: CUresult = 500;

// ---------------------------------------------------------------------------
// Flags accepted by `cuStreamCreate`.
// ---------------------------------------------------------------------------

/// Default stream-creation behaviour.
pub const CU_STREAM_DEFAULT: c_uint = 0;
/// Stream that does not implicitly synchronise with the legacy default stream.
pub const CU_STREAM_NON_BLOCKING: c_uint = 1;
/// Function-pointer table for the subset of the CUDA driver API used here.
///
/// Every field mirrors the C prototype of the driver function of the same
/// name in `cuda.h`. The C spelling is kept for the field names, hence the
/// `non_snake_case` allowance. A populated table is obtained through
/// `CudaDriver::load`.
#[allow(non_snake_case)]
pub struct CudaDriver {
    // --- Initialisation and device queries -------------------------------
    /// `cuInit`: initialises the driver API.
    pub cuInit: unsafe extern "C" fn(flags: c_uint) -> CUresult,
    /// `cuDeviceGetCount`: writes the number of CUDA devices.
    pub cuDeviceGetCount: unsafe extern "C" fn(out_count: *mut c_int) -> CUresult,
    /// `cuDeviceGet`: writes the device handle for `ordinal`.
    pub cuDeviceGet: unsafe extern "C" fn(out_device: *mut CUdevice, ordinal: c_int) -> CUresult,
    /// `cuDeviceGetName`: writes the device name into a buffer of `len` bytes.
    pub cuDeviceGetName:
        unsafe extern "C" fn(out_name: *mut c_char, len: c_int, dev: CUdevice) -> CUresult,
    /// `cuDeviceTotalMem`: writes the total device memory in bytes.
    pub cuDeviceTotalMem: unsafe extern "C" fn(out_bytes: *mut usize, dev: CUdevice) -> CUresult,

    // --- Contexts ---------------------------------------------------------
    /// `cuDevicePrimaryCtxRetain`: retains the device's primary context.
    pub cuDevicePrimaryCtxRetain:
        unsafe extern "C" fn(out_ctx: *mut CUcontext, dev: CUdevice) -> CUresult,
    /// `cuDevicePrimaryCtxRelease`: releases the device's primary context.
    pub cuDevicePrimaryCtxRelease: unsafe extern "C" fn(dev: CUdevice) -> CUresult,
    /// `cuCtxSetCurrent`: binds `ctx` to the calling CPU thread.
    pub cuCtxSetCurrent: unsafe extern "C" fn(ctx: CUcontext) -> CUresult,
    /// `cuCtxSynchronize`: blocks until the current context's work completes.
    pub cuCtxSynchronize: unsafe extern "C" fn() -> CUresult,

    // --- Modules and kernels ----------------------------------------------
    /// `cuModuleLoadData`: loads a module from an in-memory image.
    pub cuModuleLoadData:
        unsafe extern "C" fn(out_module: *mut CUmodule, image: *const c_void) -> CUresult,
    /// `cuModuleUnload`: unloads a module.
    pub cuModuleUnload: unsafe extern "C" fn(hmod: CUmodule) -> CUresult,
    /// `cuModuleGetFunction`: looks up a kernel by NUL-terminated name.
    pub cuModuleGetFunction: unsafe extern "C" fn(
        out_func: *mut CUfunction,
        hmod: CUmodule,
        kernel_name: *const c_char,
    ) -> CUresult,

    // --- Device memory ----------------------------------------------------
    /// `cuMemAlloc`: allocates `byte_count` bytes of device memory.
    pub cuMemAlloc: unsafe extern "C" fn(out_dptr: *mut CUdeviceptr, byte_count: usize) -> CUresult,
    /// `cuMemFree`: frees device memory.
    pub cuMemFree: unsafe extern "C" fn(dptr: CUdeviceptr) -> CUresult,
    /// `cuMemcpyHtoD`: synchronous host-to-device copy.
    pub cuMemcpyHtoD: unsafe extern "C" fn(
        dst_device: CUdeviceptr,
        src_host: *const c_void,
        byte_count: usize,
    ) -> CUresult,
    /// `cuMemcpyDtoH`: synchronous device-to-host copy.
    pub cuMemcpyDtoH: unsafe extern "C" fn(
        dst_host: *mut c_void,
        src_device: CUdeviceptr,
        byte_count: usize,
    ) -> CUresult,
    /// `cuMemcpyHtoDAsync`: host-to-device copy enqueued on `h_stream`.
    pub cuMemcpyHtoDAsync: unsafe extern "C" fn(
        dst_device: CUdeviceptr,
        src_host: *const c_void,
        byte_count: usize,
        h_stream: CUstream,
    ) -> CUresult,
    /// `cuMemcpyDtoHAsync`: device-to-host copy enqueued on `h_stream`.
    pub cuMemcpyDtoHAsync: unsafe extern "C" fn(
        dst_host: *mut c_void,
        src_device: CUdeviceptr,
        byte_count: usize,
        h_stream: CUstream,
    ) -> CUresult,
    /// `cuMemcpyDtoD`: synchronous device-to-device copy.
    pub cuMemcpyDtoD: unsafe extern "C" fn(
        dst_device: CUdeviceptr,
        src_device: CUdeviceptr,
        byte_count: usize,
    ) -> CUresult,
    /// `cuMemcpyDtoDAsync`: device-to-device copy enqueued on `h_stream`.
    pub cuMemcpyDtoDAsync: unsafe extern "C" fn(
        dst_device: CUdeviceptr,
        src_device: CUdeviceptr,
        byte_count: usize,
        h_stream: CUstream,
    ) -> CUresult,
    /// `cuMemGetInfo`: writes free and total device memory in bytes.
    pub cuMemGetInfo: unsafe extern "C" fn(out_free: *mut usize, out_total: *mut usize) -> CUresult,

    // --- Streams ----------------------------------------------------------
    /// `cuStreamCreate`: creates a stream (see `CU_STREAM_*` flags).
    pub cuStreamCreate: unsafe extern "C" fn(out_stream: *mut CUstream, flags: c_uint) -> CUresult,
    /// `cuStreamDestroy`: destroys a stream.
    pub cuStreamDestroy: unsafe extern "C" fn(h_stream: CUstream) -> CUresult,
    /// `cuStreamSynchronize`: blocks until all work on `h_stream` completes.
    pub cuStreamSynchronize: unsafe extern "C" fn(h_stream: CUstream) -> CUresult,

    // --- Kernel launch ----------------------------------------------------
    /// `cuLaunchKernel`: launches `func` with the given grid/block shape,
    /// dynamic shared memory size, stream and argument arrays.
    #[allow(clippy::type_complexity)]
    pub cuLaunchKernel: unsafe extern "C" fn(
        func: CUfunction,
        grid_dim_x: c_uint,
        grid_dim_y: c_uint,
        grid_dim_z: c_uint,
        block_dim_x: c_uint,
        block_dim_y: c_uint,
        block_dim_z: c_uint,
        shared_mem_bytes: c_uint,
        h_stream: CUstream,
        kernel_params: *mut *mut c_void,
        extra: *mut *mut c_void,
    ) -> CUresult,

    // --- Graphs and stream capture ----------------------------------------
    /// `cuGraphCreate`: creates an empty graph.
    pub cuGraphCreate: unsafe extern "C" fn(out_graph: *mut CUgraph, flags: c_uint) -> CUresult,
    /// `cuGraphDestroy`: destroys a graph.
    pub cuGraphDestroy: unsafe extern "C" fn(graph: CUgraph) -> CUresult,
    /// `cuGraphInstantiateWithFlags`: instantiates `graph` into an executable graph.
    pub cuGraphInstantiateWithFlags:
        unsafe extern "C" fn(out_exec: *mut CUgraphExec, graph: CUgraph, flags: u64) -> CUresult,
    /// `cuGraphExecDestroy`: destroys an executable graph.
    pub cuGraphExecDestroy: unsafe extern "C" fn(graph_exec: CUgraphExec) -> CUresult,
    /// `cuGraphLaunch`: launches an executable graph on `h_stream`.
    pub cuGraphLaunch: unsafe extern "C" fn(graph_exec: CUgraphExec, h_stream: CUstream) -> CUresult,
    /// `cuStreamBeginCapture`: starts capturing `h_stream` (`mode` is a
    /// `CUstreamCaptureMode` value).
    pub cuStreamBeginCapture: unsafe extern "C" fn(h_stream: CUstream, mode: c_uint) -> CUresult,
    /// `cuStreamEndCapture`: ends capture and writes the captured graph.
    pub cuStreamEndCapture:
        unsafe extern "C" fn(h_stream: CUstream, out_graph: *mut CUgraph) -> CUresult,
}
#[cfg(feature = "cuda")]
mod loading {
    use super::*;
    use libloading::{Library, Symbol};
    use std::sync::OnceLock;

    /// Resolved driver table; `None` if the library or any symbol is missing.
    static DRIVER: OnceLock<Option<CudaDriver>> = OnceLock::new();

    /// The opened driver library, or `None` if no candidate could be opened.
    ///
    /// Held in a `static` (never dropped) so the function pointers copied into
    /// `DRIVER` remain valid for the life of the process.
    static LIBRARY: OnceLock<Option<Library>> = OnceLock::new();

    /// Candidate driver library names for this platform, tried in order.
    #[cfg(target_os = "linux")]
    const LIB_NAMES: &[&str] = &["libcuda.so.1", "libcuda.so"];
    /// Candidate driver library names for this platform, tried in order.
    #[cfg(target_os = "windows")]
    const LIB_NAMES: &[&str] = &["nvcuda.dll"];
    /// No driver library is searched for on other targets (macOS included);
    /// an empty list makes `load` return `None` instead of failing to compile.
    #[cfg(not(any(target_os = "linux", target_os = "windows")))]
    const LIB_NAMES: &[&str] = &[];

    /// Opens the first library in `LIB_NAMES` that loads successfully.
    fn open_library() -> Option<Library> {
        LIB_NAMES.iter().find_map(|&name| {
            // SAFETY: loading a shared library runs its initialisers; the
            // NVIDIA driver library is trusted to be safe to load.
            let opened = unsafe { Library::new(name) };
            opened.ok()
        })
    }

    impl CudaDriver {
        /// Loads the CUDA driver library and resolves every entry point.
        ///
        /// The work happens at most once per process; the outcome (success
        /// or failure) is cached and returned on every later call.
        ///
        /// Returns `None` when no driver library could be opened or when any
        /// required symbol is missing from it (e.g. a driver that is too old).
        #[must_use]
        pub fn load() -> Option<&'static Self> {
            DRIVER
                .get_or_init(|| {
                    let lib = LIBRARY.get_or_init(open_library).as_ref()?;
                    Self::load_from_library(lib)
                })
                .as_ref()
        }

        /// Resolves every driver entry point from `lib`.
        ///
        /// Returns `None` as soon as a single symbol cannot be found.
        /// Wherever `cuda.h` `#define`s a public name to a versioned export
        /// (`_v2`), the versioned symbol is looked up so the resolved function
        /// really has the signature declared on [`CudaDriver`].
        fn load_from_library(lib: &Library) -> Option<Self> {
            // Looks up `$name` as a `$ty` and copies the function pointer out.
            // The name is NUL-terminated at compile time so `libloading` does
            // not need to allocate a C string per lookup.
            macro_rules! load_sym {
                ($name:ident, $ty:ty) => {{
                    // SAFETY: `$ty` is written to match the C prototype of the
                    // exported symbol. The copied pointer stays valid because
                    // the only caller passes the library held in `LIBRARY`,
                    // which is never dropped.
                    let sym: Symbol<'_, $ty> =
                        unsafe { lib.get(concat!(stringify!($name), "\0").as_bytes()) }.ok()?;
                    *sym
                }};
            }

            type FnInit = unsafe extern "C" fn(c_uint) -> CUresult;
            type FnDeviceGetCount = unsafe extern "C" fn(*mut c_int) -> CUresult;
            type FnDeviceGet = unsafe extern "C" fn(*mut CUdevice, c_int) -> CUresult;
            type FnDeviceGetName = unsafe extern "C" fn(*mut c_char, c_int, CUdevice) -> CUresult;
            type FnDeviceTotalMem = unsafe extern "C" fn(*mut usize, CUdevice) -> CUresult;
            type FnPrimaryCtxRetain = unsafe extern "C" fn(*mut CUcontext, CUdevice) -> CUresult;
            type FnPrimaryCtxRelease = unsafe extern "C" fn(CUdevice) -> CUresult;
            type FnCtxSetCurrent = unsafe extern "C" fn(CUcontext) -> CUresult;
            type FnCtxSync = unsafe extern "C" fn() -> CUresult;
            type FnModuleLoadData = unsafe extern "C" fn(*mut CUmodule, *const c_void) -> CUresult;
            type FnModuleUnload = unsafe extern "C" fn(CUmodule) -> CUresult;
            type FnModuleGetFunction =
                unsafe extern "C" fn(*mut CUfunction, CUmodule, *const c_char) -> CUresult;
            type FnMemAlloc = unsafe extern "C" fn(*mut CUdeviceptr, usize) -> CUresult;
            type FnMemFree = unsafe extern "C" fn(CUdeviceptr) -> CUresult;
            type FnMemcpyHtoD = unsafe extern "C" fn(CUdeviceptr, *const c_void, usize) -> CUresult;
            type FnMemcpyDtoH = unsafe extern "C" fn(*mut c_void, CUdeviceptr, usize) -> CUresult;
            type FnMemcpyHtoDAsync =
                unsafe extern "C" fn(CUdeviceptr, *const c_void, usize, CUstream) -> CUresult;
            type FnMemcpyDtoHAsync =
                unsafe extern "C" fn(*mut c_void, CUdeviceptr, usize, CUstream) -> CUresult;
            type FnMemcpyDtoD = unsafe extern "C" fn(CUdeviceptr, CUdeviceptr, usize) -> CUresult;
            type FnMemcpyDtoDAsync =
                unsafe extern "C" fn(CUdeviceptr, CUdeviceptr, usize, CUstream) -> CUresult;
            type FnMemGetInfo = unsafe extern "C" fn(*mut usize, *mut usize) -> CUresult;
            type FnStreamCreate = unsafe extern "C" fn(*mut CUstream, c_uint) -> CUresult;
            type FnStreamDestroy = unsafe extern "C" fn(CUstream) -> CUresult;
            type FnStreamSync = unsafe extern "C" fn(CUstream) -> CUresult;
            type FnLaunchKernel = unsafe extern "C" fn(
                CUfunction,
                c_uint,
                c_uint,
                c_uint,
                c_uint,
                c_uint,
                c_uint,
                c_uint,
                CUstream,
                *mut *mut c_void,
                *mut *mut c_void,
            ) -> CUresult;
            type FnGraphCreate = unsafe extern "C" fn(*mut CUgraph, c_uint) -> CUresult;
            type FnGraphDestroy = unsafe extern "C" fn(CUgraph) -> CUresult;
            type FnGraphInstantiate =
                unsafe extern "C" fn(*mut CUgraphExec, CUgraph, u64) -> CUresult;
            type FnGraphExecDestroy = unsafe extern "C" fn(CUgraphExec) -> CUresult;
            type FnGraphLaunch = unsafe extern "C" fn(CUgraphExec, CUstream) -> CUresult;
            // Two-argument (stream, mode) form. `cuda.h` maps the public name
            // `cuStreamBeginCapture` to `cuStreamBeginCapture_v2`; the
            // unversioned export is the legacy one-argument entry point that
            // takes no capture mode, so it must not be bound to this type.
            type FnStreamBeginCapture = unsafe extern "C" fn(CUstream, c_uint) -> CUresult;
            type FnStreamEndCapture = unsafe extern "C" fn(CUstream, *mut CUgraph) -> CUresult;

            Some(Self {
                cuInit: load_sym!(cuInit, FnInit),
                cuDeviceGetCount: load_sym!(cuDeviceGetCount, FnDeviceGetCount),
                cuDeviceGet: load_sym!(cuDeviceGet, FnDeviceGet),
                cuDeviceGetName: load_sym!(cuDeviceGetName, FnDeviceGetName),
                cuDeviceTotalMem: load_sym!(cuDeviceTotalMem_v2, FnDeviceTotalMem),
                cuDevicePrimaryCtxRetain: load_sym!(cuDevicePrimaryCtxRetain, FnPrimaryCtxRetain),
                cuDevicePrimaryCtxRelease: load_sym!(
                    cuDevicePrimaryCtxRelease_v2,
                    FnPrimaryCtxRelease
                ),
                cuCtxSetCurrent: load_sym!(cuCtxSetCurrent, FnCtxSetCurrent),
                cuCtxSynchronize: load_sym!(cuCtxSynchronize, FnCtxSync),
                cuModuleLoadData: load_sym!(cuModuleLoadData, FnModuleLoadData),
                cuModuleUnload: load_sym!(cuModuleUnload, FnModuleUnload),
                cuModuleGetFunction: load_sym!(cuModuleGetFunction, FnModuleGetFunction),
                cuMemAlloc: load_sym!(cuMemAlloc_v2, FnMemAlloc),
                cuMemFree: load_sym!(cuMemFree_v2, FnMemFree),
                cuMemcpyHtoD: load_sym!(cuMemcpyHtoD_v2, FnMemcpyHtoD),
                cuMemcpyDtoH: load_sym!(cuMemcpyDtoH_v2, FnMemcpyDtoH),
                cuMemcpyHtoDAsync: load_sym!(cuMemcpyHtoDAsync_v2, FnMemcpyHtoDAsync),
                cuMemcpyDtoHAsync: load_sym!(cuMemcpyDtoHAsync_v2, FnMemcpyDtoHAsync),
                cuMemcpyDtoD: load_sym!(cuMemcpyDtoD_v2, FnMemcpyDtoD),
                cuMemcpyDtoDAsync: load_sym!(cuMemcpyDtoDAsync_v2, FnMemcpyDtoDAsync),
                cuMemGetInfo: load_sym!(cuMemGetInfo_v2, FnMemGetInfo),
                cuStreamCreate: load_sym!(cuStreamCreate, FnStreamCreate),
                cuStreamDestroy: load_sym!(cuStreamDestroy_v2, FnStreamDestroy),
                cuStreamSynchronize: load_sym!(cuStreamSynchronize, FnStreamSync),
                cuLaunchKernel: load_sym!(cuLaunchKernel, FnLaunchKernel),
                cuGraphCreate: load_sym!(cuGraphCreate, FnGraphCreate),
                cuGraphDestroy: load_sym!(cuGraphDestroy, FnGraphDestroy),
                cuGraphInstantiateWithFlags: load_sym!(
                    cuGraphInstantiateWithFlags,
                    FnGraphInstantiate
                ),
                cuGraphExecDestroy: load_sym!(cuGraphExecDestroy, FnGraphExecDestroy),
                cuGraphLaunch: load_sym!(cuGraphLaunch, FnGraphLaunch),
                cuStreamBeginCapture: load_sym!(cuStreamBeginCapture_v2, FnStreamBeginCapture),
                cuStreamEndCapture: load_sym!(cuStreamEndCapture, FnStreamEndCapture),
            })
        }

        /// Converts a driver status code into a `Result`.
        ///
        /// # Errors
        ///
        /// Any code other than `CUDA_SUCCESS` yields `GpuError::CudaDriver`,
        /// which carries the code's symbolic name (from `cuda_error_string`)
        /// and the raw code.
        pub fn check(result: CUresult) -> Result<(), GpuError> {
            if result == CUDA_SUCCESS {
                Ok(())
            } else {
                Err(GpuError::CudaDriver(
                    cuda_error_string(result).to_string(),
                    result,
                ))
            }
        }
    }
}
#[cfg(not(feature = "cuda"))]
mod loading {
    use super::*;

    // Stand-in implementation compiled when the `cuda` feature is disabled:
    // the driver is never available and every status check fails.
    impl CudaDriver {
        /// Always returns `None`; the driver cannot be loaded in this build.
        #[must_use]
        pub fn load() -> Option<&'static Self> {
            Option::None
        }

        /// Always fails, regardless of `_result`, because no driver exists.
        ///
        /// # Errors
        ///
        /// Unconditionally returns `GpuError::CudaNotAvailable`, even for
        /// `CUDA_SUCCESS`.
        pub fn check(_result: CUresult) -> Result<(), GpuError> {
            let reason = String::from("cuda feature not enabled");
            Err(GpuError::CudaNotAvailable(reason))
        }
    }
}
#[must_use]
pub fn cuda_error_string(code: CUresult) -> &'static str {
match code {
CUDA_SUCCESS => "CUDA_SUCCESS",
CUDA_ERROR_INVALID_VALUE => "CUDA_ERROR_INVALID_VALUE",
CUDA_ERROR_OUT_OF_MEMORY => "CUDA_ERROR_OUT_OF_MEMORY",
CUDA_ERROR_NOT_INITIALIZED => "CUDA_ERROR_NOT_INITIALIZED",
CUDA_ERROR_DEINITIALIZED => "CUDA_ERROR_DEINITIALIZED",
CUDA_ERROR_NO_DEVICE => "CUDA_ERROR_NO_DEVICE",
CUDA_ERROR_INVALID_DEVICE => "CUDA_ERROR_INVALID_DEVICE",
CUDA_ERROR_INVALID_PTX => "CUDA_ERROR_INVALID_PTX",
CUDA_ERROR_NOT_FOUND => "CUDA_ERROR_NOT_FOUND",
_ => "CUDA_ERROR_UNKNOWN",
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `code` is rendered as exactly `expected`.
    fn assert_name(code: CUresult, expected: &str) {
        assert_eq!(cuda_error_string(code), expected);
    }

    #[test]
    fn test_cuda_error_string_success() {
        assert_name(CUDA_SUCCESS, "CUDA_SUCCESS");
    }

    #[test]
    fn test_cuda_error_string_oom() {
        assert_name(CUDA_ERROR_OUT_OF_MEMORY, "CUDA_ERROR_OUT_OF_MEMORY");
    }

    #[test]
    fn test_cuda_error_string_unknown() {
        assert_name(99999, "CUDA_ERROR_UNKNOWN");
    }

    #[test]
    fn test_cuda_constants() {
        assert_eq!(
            (CUDA_SUCCESS, CUDA_ERROR_NO_DEVICE, CUDA_ERROR_INVALID_PTX),
            (0, 100, 218)
        );
    }

    #[test]
    fn test_custream_flags() {
        assert_eq!([CU_STREAM_DEFAULT, CU_STREAM_NON_BLOCKING], [0, 1]);
    }

    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_driver_load_without_feature() {
        let driver = CudaDriver::load();
        assert!(driver.is_none());
    }

    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_check_without_feature() {
        // Without the feature even a success code is reported as an error.
        assert!(CudaDriver::check(CUDA_SUCCESS).is_err());
    }

    #[test]
    fn test_all_error_strings() {
        let cases = [
            (CUDA_ERROR_INVALID_VALUE, "CUDA_ERROR_INVALID_VALUE"),
            (CUDA_ERROR_NOT_INITIALIZED, "CUDA_ERROR_NOT_INITIALIZED"),
            (CUDA_ERROR_DEINITIALIZED, "CUDA_ERROR_DEINITIALIZED"),
            (CUDA_ERROR_INVALID_DEVICE, "CUDA_ERROR_INVALID_DEVICE"),
            (CUDA_ERROR_NOT_FOUND, "CUDA_ERROR_NOT_FOUND"),
        ];
        for (code, expected) in cases {
            assert_name(code, expected);
        }
    }

    #[test]
    fn test_error_codes_are_distinct() {
        let codes = [
            CUDA_SUCCESS,
            CUDA_ERROR_INVALID_VALUE,
            CUDA_ERROR_OUT_OF_MEMORY,
            CUDA_ERROR_NOT_INITIALIZED,
            CUDA_ERROR_DEINITIALIZED,
            CUDA_ERROR_NO_DEVICE,
            CUDA_ERROR_INVALID_DEVICE,
            CUDA_ERROR_INVALID_PTX,
            CUDA_ERROR_NOT_FOUND,
        ];
        // Compare every unordered pair exactly once.
        for (i, first) in codes.iter().enumerate() {
            for (offset, second) in codes[i + 1..].iter().enumerate() {
                let j = i + 1 + offset;
                assert_ne!(first, second, "Error codes at {} and {} are equal", i, j);
            }
        }
    }

    #[test]
    fn test_type_sizes() {
        use std::mem::size_of;

        assert_eq!(size_of::<CUresult>(), size_of::<i32>());
        assert_eq!(size_of::<CUdevice>(), size_of::<i32>());
        assert_eq!(size_of::<CUdeviceptr>(), size_of::<u64>());

        let pointer_size = size_of::<*mut ()>();
        let handles = [
            ("CUcontext", size_of::<CUcontext>()),
            ("CUmodule", size_of::<CUmodule>()),
            ("CUfunction", size_of::<CUfunction>()),
            ("CUstream", size_of::<CUstream>()),
        ];
        for (label, size) in handles {
            assert_eq!(size, pointer_size, "{label} should be pointer-sized");
        }
    }

    #[test]
    fn test_null_pointers() {
        let ctx: CUcontext = std::ptr::null_mut();
        let module: CUmodule = std::ptr::null_mut();
        let func: CUfunction = std::ptr::null_mut();
        let stream: CUstream = std::ptr::null_mut();
        assert!([ctx, module, func, stream].iter().all(|p| p.is_null()));
    }
}
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Any i32, known or not, must be handled without panicking.
        #[test]
        fn prop_error_string_never_panics(code in any::<i32>()) {
            let _ = cuda_error_string(code);
        }

        // Every rendered name, including the fallback, uses the `CUDA_` prefix.
        #[test]
        fn prop_error_string_valid(code in any::<i32>()) {
            let name = cuda_error_string(code);
            prop_assert!(name.starts_with("CUDA_"));
            prop_assert!(!name.is_empty());
        }

        // Named constants must never fall through to the unknown bucket.
        #[test]
        fn prop_known_errors_have_specific_string(
            code in prop_oneof![
                Just(CUDA_SUCCESS),
                Just(CUDA_ERROR_INVALID_VALUE),
                Just(CUDA_ERROR_OUT_OF_MEMORY),
                Just(CUDA_ERROR_NO_DEVICE),
                Just(CUDA_ERROR_INVALID_PTX),
            ]
        ) {
            let name = cuda_error_string(code);
            prop_assert_ne!(name, "CUDA_ERROR_UNKNOWN");
        }
    }
}