use std::ffi::c_void;
use std::os::raw::{c_char, c_int, c_uint};
use crate::GpuError;
// Plain value types used across the driver API.

/// Status code returned by every driver entry point; compare against
/// `CUDA_SUCCESS` and the `CUDA_ERROR_*` constants.
pub type CUresult = c_int;
/// Device handle written by `cuDeviceGet`.
pub type CUdevice = c_int;
/// Device memory address; 64 bits wide regardless of the host pointer width.
pub type CUdeviceptr = u64;

// Opaque driver handles, represented as untyped raw pointers.

/// Context handle (e.g. the primary context written by `cuDevicePrimaryCtxRetain`).
pub type CUcontext = *mut c_void;
/// Handle of a loaded module.
pub type CUmodule = *mut c_void;
/// Handle of a kernel function looked up in a module.
pub type CUfunction = *mut c_void;
/// Stream handle.
pub type CUstream = *mut c_void;
/// Handle of a (not yet instantiated) graph.
pub type CUgraph = *mut c_void;
/// Handle of an instantiated, launchable graph.
pub type CUgraphExec = *mut c_void;
// Driver status codes (`CUresult` values). The numbers follow the
// `cudaError_enum` numbering from `cuda.h`; only the codes this crate names
// are listed here.

/// The call completed without error.
pub const CUDA_SUCCESS: CUresult = 0;

// -- Argument, resource and initialisation errors --

/// An argument was out of range or otherwise invalid.
pub const CUDA_ERROR_INVALID_VALUE: CUresult = 1;
/// Memory for the requested operation could not be allocated.
pub const CUDA_ERROR_OUT_OF_MEMORY: CUresult = 2;
/// The driver has not been initialised (via `cuInit`) or initialisation failed.
pub const CUDA_ERROR_NOT_INITIALIZED: CUresult = 3;
/// The driver is shutting down.
pub const CUDA_ERROR_DEINITIALIZED: CUresult = 4;

// -- Device errors --

/// No CUDA-capable device was detected.
pub const CUDA_ERROR_NO_DEVICE: CUresult = 100;
/// The device ordinal or handle does not refer to a valid device.
pub const CUDA_ERROR_INVALID_DEVICE: CUresult = 101;

// -- Module and context errors --

/// The module image is not valid.
pub const CUDA_ERROR_INVALID_IMAGE: CUresult = 200;
/// No valid context is current, or the context handle passed is invalid.
pub const CUDA_ERROR_INVALID_CONTEXT: CUresult = 201;
/// The module contains no kernel image suitable for the device.
pub const CUDA_ERROR_NO_BINARY_FOR_GPU: CUresult = 209;
/// PTX JIT compilation failed.
pub const CUDA_ERROR_INVALID_PTX: CUresult = 218;

// -- Lookup and status --

/// A named symbol (e.g. a kernel) was not found.
pub const CUDA_ERROR_NOT_FOUND: CUresult = 500;
/// Previously issued asynchronous work has not completed yet.
pub const CUDA_ERROR_NOT_READY: CUresult = 600;

// -- Kernel execution faults --

/// A kernel accessed an invalid memory address.
pub const CUDA_ERROR_ILLEGAL_ADDRESS: CUresult = 700;
/// A kernel executed an illegal instruction.
pub const CUDA_ERROR_ILLEGAL_INSTRUCTION: CUresult = 715;
/// An exception occurred on the device while executing a kernel.
pub const CUDA_ERROR_LAUNCH_FAILED: CUresult = 719;
// `CUjit_option` keys placed in the `options` array of `cuModuleLoadDataEx`
// (listed in ascending order of value).

/// Option key: buffer that receives the JIT error log.
pub const CU_JIT_ERROR_LOG_BUFFER: c_uint = 5;
/// Option key: size in bytes of the error-log buffer.
pub const CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES: c_uint = 6;
/// Option key: compile for an explicit `CU_TARGET_COMPUTE_*` target.
pub const CU_JIT_TARGET: c_uint = 9;

// `CUjit_target` values used with `CU_JIT_TARGET`. Each value equals
// `major * 10 + minor` of the compute capability in its name.

/// Compute capability 7.0.
pub const CU_TARGET_COMPUTE_70: c_uint = 70;
/// Compute capability 7.5.
pub const CU_TARGET_COMPUTE_75: c_uint = 75;
/// Compute capability 8.0.
pub const CU_TARGET_COMPUTE_80: c_uint = 80;
/// Compute capability 8.6.
pub const CU_TARGET_COMPUTE_86: c_uint = 86;
/// Compute capability 8.7.
pub const CU_TARGET_COMPUTE_87: c_uint = 87;
/// Compute capability 8.9.
pub const CU_TARGET_COMPUTE_89: c_uint = 89;
/// Compute capability 9.0.
pub const CU_TARGET_COMPUTE_90: c_uint = 90;

// Flags accepted by `cuStreamCreate`.

/// Default stream-creation flag.
pub const CU_STREAM_DEFAULT: c_uint = 0;
/// The created stream does not synchronise with the NULL (default) stream.
pub const CU_STREAM_NON_BLOCKING: c_uint = 1;
/// Table of CUDA driver API entry points resolved at runtime.
///
/// Each field holds a function pointer whose signature mirrors the C prototype of
/// the driver function it is named after. The loader may bind a field to a
/// versioned export (e.g. `cuMemAlloc` is resolved from `cuMemAlloc_v2`). Obtain
/// an instance through `CudaDriver::load()`.
///
/// The pointers are only valid while the library they were resolved from stays
/// loaded.
// Field names deliberately match the C API symbol names.
#[allow(non_snake_case)]
pub struct CudaDriver {
/// Initialises the driver API; `flags` is passed straight through.
pub cuInit: unsafe extern "C" fn(flags: c_uint) -> CUresult,
/// Writes the number of available devices to `count`.
pub cuDeviceGetCount: unsafe extern "C" fn(count: *mut c_int) -> CUresult,
/// Writes the device handle for `ordinal` to `device`.
pub cuDeviceGet: unsafe extern "C" fn(device: *mut CUdevice, ordinal: c_int) -> CUresult,
/// Writes the device name as a NUL-terminated string of at most `len` bytes into `name`.
pub cuDeviceGetName:
unsafe extern "C" fn(name: *mut c_char, len: c_int, device: CUdevice) -> CUresult,
/// Writes the total device memory, in bytes, to `bytes`.
pub cuDeviceTotalMem: unsafe extern "C" fn(bytes: *mut usize, device: CUdevice) -> CUresult,
/// Writes the integer value of device attribute `attrib` to `pi`.
pub cuDeviceGetAttribute:
unsafe extern "C" fn(pi: *mut c_int, attrib: c_int, device: CUdevice) -> CUresult,
/// Retains the device's primary context and writes it to `ctx`; pair with
/// `cuDevicePrimaryCtxRelease`.
pub cuDevicePrimaryCtxRetain:
unsafe extern "C" fn(ctx: *mut CUcontext, device: CUdevice) -> CUresult,
/// Releases one retain of the device's primary context.
pub cuDevicePrimaryCtxRelease: unsafe extern "C" fn(device: CUdevice) -> CUresult,
/// Binds `ctx` to the calling CPU thread.
pub cuCtxSetCurrent: unsafe extern "C" fn(ctx: CUcontext) -> CUresult,
/// Blocks until all work in the current context has completed.
pub cuCtxSynchronize: unsafe extern "C" fn() -> CUresult,
/// Loads a module from the in-memory `image` and writes its handle to `module`.
pub cuModuleLoadData:
unsafe extern "C" fn(module: *mut CUmodule, image: *const c_void) -> CUresult,
/// Like `cuModuleLoadData`, plus `num_options` JIT option keys (e.g. `CU_JIT_*`)
/// in `options` with matching entries in `option_values`.
// The signature must mirror the C prototype exactly, so it is spelled out inline.
#[allow(clippy::type_complexity)]
pub cuModuleLoadDataEx: unsafe extern "C" fn(
module: *mut CUmodule,
image: *const c_void,
num_options: c_uint,
options: *mut c_uint,
option_values: *mut *mut c_void,
) -> CUresult,
/// Unloads `module`.
pub cuModuleUnload: unsafe extern "C" fn(module: CUmodule) -> CUresult,
/// Looks up the kernel called `name` (NUL-terminated) in `module`.
pub cuModuleGetFunction: unsafe extern "C" fn(
func: *mut CUfunction,
module: CUmodule,
name: *const c_char,
) -> CUresult,
/// Allocates `size` bytes of device memory and writes the address to `ptr`.
pub cuMemAlloc: unsafe extern "C" fn(ptr: *mut CUdeviceptr, size: usize) -> CUresult,
/// Frees device memory previously returned by `cuMemAlloc`.
pub cuMemFree: unsafe extern "C" fn(ptr: CUdeviceptr) -> CUresult,
/// Copies `size` bytes from host `src` to device `dst`.
pub cuMemcpyHtoD:
unsafe extern "C" fn(dst: CUdeviceptr, src: *const c_void, size: usize) -> CUresult,
/// Copies `size` bytes from device `src` to host `dst`.
pub cuMemcpyDtoH:
unsafe extern "C" fn(dst: *mut c_void, src: CUdeviceptr, size: usize) -> CUresult,
/// Enqueues a host-to-device copy of `size` bytes on `stream`.
pub cuMemcpyHtoDAsync: unsafe extern "C" fn(
dst: CUdeviceptr,
src: *const c_void,
size: usize,
stream: CUstream,
) -> CUresult,
/// Enqueues a device-to-host copy of `size` bytes on `stream`.
pub cuMemcpyDtoHAsync: unsafe extern "C" fn(
dst: *mut c_void,
src: CUdeviceptr,
size: usize,
stream: CUstream,
) -> CUresult,
/// Copies `size` bytes between two device buffers.
pub cuMemcpyDtoD:
unsafe extern "C" fn(dst: CUdeviceptr, src: CUdeviceptr, size: usize) -> CUresult,
/// Enqueues a device-to-device copy of `size` bytes on `stream`.
pub cuMemcpyDtoDAsync: unsafe extern "C" fn(
dst: CUdeviceptr,
src: CUdeviceptr,
size: usize,
stream: CUstream,
) -> CUresult,
/// Writes the free and total device memory, in bytes, to `free` and `total`.
pub cuMemGetInfo: unsafe extern "C" fn(free: *mut usize, total: *mut usize) -> CUresult,
/// Creates a stream with `flags` (e.g. `CU_STREAM_DEFAULT` / `CU_STREAM_NON_BLOCKING`).
pub cuStreamCreate: unsafe extern "C" fn(stream: *mut CUstream, flags: c_uint) -> CUresult,
/// Destroys `stream`.
pub cuStreamDestroy: unsafe extern "C" fn(stream: CUstream) -> CUresult,
/// Blocks until all work queued on `stream` has completed.
pub cuStreamSynchronize: unsafe extern "C" fn(stream: CUstream) -> CUresult,
/// Launches `func` on `stream` with the given grid and block dimensions and
/// `shared_mem_bytes` of dynamic shared memory. `kernel_params` is an array of
/// pointers to the argument values; `extra` is the alternative packed-argument
/// list (typically null).
// The signature must mirror the C prototype exactly, so it is spelled out inline.
#[allow(clippy::type_complexity)]
pub cuLaunchKernel: unsafe extern "C" fn(
func: CUfunction,
grid_dim_x: c_uint,
grid_dim_y: c_uint,
grid_dim_z: c_uint,
block_dim_x: c_uint,
block_dim_y: c_uint,
block_dim_z: c_uint,
shared_mem_bytes: c_uint,
stream: CUstream,
kernel_params: *mut *mut c_void,
extra: *mut *mut c_void,
) -> CUresult,
/// Creates an empty graph and writes it to `graph`.
pub cuGraphCreate: unsafe extern "C" fn(graph: *mut CUgraph, flags: c_uint) -> CUresult,
/// Destroys `graph`.
pub cuGraphDestroy: unsafe extern "C" fn(graph: CUgraph) -> CUresult,
/// Instantiates `graph` into a launchable executable graph, written to `exec`.
pub cuGraphInstantiateWithFlags:
unsafe extern "C" fn(exec: *mut CUgraphExec, graph: CUgraph, flags: u64) -> CUresult,
/// Destroys an executable graph.
pub cuGraphExecDestroy: unsafe extern "C" fn(exec: CUgraphExec) -> CUresult,
/// Launches the executable graph `exec` on `stream`.
pub cuGraphLaunch: unsafe extern "C" fn(exec: CUgraphExec, stream: CUstream) -> CUresult,
/// Starts capturing the work issued to `stream` into a graph; `mode` is a
/// `CUstreamCaptureMode` value.
///
/// NOTE(review): this two-argument form corresponds to the
/// `cuStreamBeginCapture_v2` export. Confirm that the loader resolves that symbol
/// rather than the legacy unversioned one.
pub cuStreamBeginCapture: unsafe extern "C" fn(stream: CUstream, mode: c_uint) -> CUresult,
/// Ends capture on `stream` and writes the captured graph to `graph`.
pub cuStreamEndCapture: unsafe extern "C" fn(stream: CUstream, graph: *mut CUgraph) -> CUresult,
}
#[cfg(feature = "cuda")]
mod loading {
    use super::*;
    use libloading::{Library, Symbol};
    use std::sync::OnceLock;

    /// Resolved driver entry points, filled at most once by [`CudaDriver::load`].
    static DRIVER: OnceLock<Option<CudaDriver>> = OnceLock::new();

    /// The opened driver library. It is kept in a `static` so it is never dropped
    /// (and therefore never unloaded); every function pointer in `DRIVER` points
    /// into it.
    static LIBRARY: OnceLock<Option<Library>> = OnceLock::new();

    /// Driver library file names to try, in order.
    #[cfg(target_os = "linux")]
    const LIB_NAMES: &[&str] = &["libcuda.so.1", "libcuda.so"];

    /// Driver library file names to try, in order.
    #[cfg(target_os = "windows")]
    const LIB_NAMES: &[&str] = &["nvcuda.dll"];

    /// Fallback for every other target, including macOS: nothing is probed, so
    /// `load` returns `None`. This also keeps the crate compiling on targets
    /// other than Linux and Windows.
    #[cfg(not(any(target_os = "linux", target_os = "windows")))]
    const LIB_NAMES: &[&str] = &[];

    impl CudaDriver {
        /// Opens the CUDA driver library and resolves all entry points.
        ///
        /// Both steps run at most once per process; later calls return the cached
        /// outcome. Returns `None` when no driver library can be opened, or when
        /// any required symbol is missing from it. This does not call `cuInit`.
        #[must_use]
        pub fn load() -> Option<&'static Self> {
            let lib = LIBRARY
                .get_or_init(|| {
                    LIB_NAMES.iter().find_map(|&name| {
                        // SAFETY: opening the system CUDA driver runs its library
                        // initialisers; the vendor-provided library is trusted.
                        unsafe { Library::new(name) }.ok()
                    })
                })
                .as_ref()?;
            DRIVER
                .get_or_init(|| Self::load_from_library(lib))
                .as_ref()
        }

        /// Resolves every field of [`CudaDriver`] from an opened driver library.
        ///
        /// Returns `None` as soon as one symbol cannot be found. Where `cuda.h`
        /// remaps an API name to a versioned export (`_v2`), the versioned symbol
        /// is requested explicitly, because the unversioned export keeps the
        /// legacy ABI.
        fn load_from_library(lib: &Library) -> Option<Self> {
            // Looks up `$name` and copies the function pointer out of the `Symbol`.
            // The only caller passes the library held by the never-dropped
            // `LIBRARY` static, so the copied pointer stays valid.
            macro_rules! load_sym {
                ($name:ident, $ty:ty) => {{
                    // SAFETY: `$ty` mirrors the C prototype of `$name`. The name is
                    // NUL-terminated, so libloading does not need to allocate a copy.
                    let sym: Symbol<'_, $ty> =
                        unsafe { lib.get(concat!(stringify!($name), "\0").as_bytes()) }
                            .ok()?;
                    *sym
                }};
            }
            type FnInit = unsafe extern "C" fn(c_uint) -> CUresult;
            type FnDeviceGetCount = unsafe extern "C" fn(*mut c_int) -> CUresult;
            type FnDeviceGet = unsafe extern "C" fn(*mut CUdevice, c_int) -> CUresult;
            type FnDeviceGetName =
                unsafe extern "C" fn(*mut c_char, c_int, CUdevice) -> CUresult;
            type FnDeviceTotalMem = unsafe extern "C" fn(*mut usize, CUdevice) -> CUresult;
            type FnDeviceGetAttribute =
                unsafe extern "C" fn(*mut c_int, c_int, CUdevice) -> CUresult;
            type FnPrimaryCtxRetain =
                unsafe extern "C" fn(*mut CUcontext, CUdevice) -> CUresult;
            type FnPrimaryCtxRelease = unsafe extern "C" fn(CUdevice) -> CUresult;
            type FnCtxSetCurrent = unsafe extern "C" fn(CUcontext) -> CUresult;
            type FnCtxSync = unsafe extern "C" fn() -> CUresult;
            type FnModuleLoadData =
                unsafe extern "C" fn(*mut CUmodule, *const c_void) -> CUresult;
            type FnModuleLoadDataEx = unsafe extern "C" fn(
                *mut CUmodule,
                *const c_void,
                c_uint,
                *mut c_uint,
                *mut *mut c_void,
            ) -> CUresult;
            type FnModuleUnload = unsafe extern "C" fn(CUmodule) -> CUresult;
            type FnModuleGetFunction =
                unsafe extern "C" fn(*mut CUfunction, CUmodule, *const c_char) -> CUresult;
            type FnMemAlloc = unsafe extern "C" fn(*mut CUdeviceptr, usize) -> CUresult;
            type FnMemFree = unsafe extern "C" fn(CUdeviceptr) -> CUresult;
            type FnMemcpyHtoD =
                unsafe extern "C" fn(CUdeviceptr, *const c_void, usize) -> CUresult;
            type FnMemcpyDtoH =
                unsafe extern "C" fn(*mut c_void, CUdeviceptr, usize) -> CUresult;
            type FnMemcpyHtoDAsync =
                unsafe extern "C" fn(CUdeviceptr, *const c_void, usize, CUstream) -> CUresult;
            type FnMemcpyDtoHAsync =
                unsafe extern "C" fn(*mut c_void, CUdeviceptr, usize, CUstream) -> CUresult;
            type FnMemcpyDtoD =
                unsafe extern "C" fn(CUdeviceptr, CUdeviceptr, usize) -> CUresult;
            type FnMemcpyDtoDAsync =
                unsafe extern "C" fn(CUdeviceptr, CUdeviceptr, usize, CUstream) -> CUresult;
            type FnMemGetInfo = unsafe extern "C" fn(*mut usize, *mut usize) -> CUresult;
            type FnStreamCreate = unsafe extern "C" fn(*mut CUstream, c_uint) -> CUresult;
            type FnStreamDestroy = unsafe extern "C" fn(CUstream) -> CUresult;
            type FnStreamSync = unsafe extern "C" fn(CUstream) -> CUresult;
            type FnLaunchKernel = unsafe extern "C" fn(
                CUfunction,
                c_uint,
                c_uint,
                c_uint,
                c_uint,
                c_uint,
                c_uint,
                c_uint,
                CUstream,
                *mut *mut c_void,
                *mut *mut c_void,
            ) -> CUresult;
            type FnGraphCreate = unsafe extern "C" fn(*mut CUgraph, c_uint) -> CUresult;
            type FnGraphDestroy = unsafe extern "C" fn(CUgraph) -> CUresult;
            type FnGraphInstantiate =
                unsafe extern "C" fn(*mut CUgraphExec, CUgraph, u64) -> CUresult;
            type FnGraphExecDestroy = unsafe extern "C" fn(CUgraphExec) -> CUresult;
            type FnGraphLaunch = unsafe extern "C" fn(CUgraphExec, CUstream) -> CUresult;
            type FnStreamBeginCapture = unsafe extern "C" fn(CUstream, c_uint) -> CUresult;
            type FnStreamEndCapture = unsafe extern "C" fn(CUstream, *mut CUgraph) -> CUresult;
            Some(CudaDriver {
                cuInit: load_sym!(cuInit, FnInit),
                cuDeviceGetCount: load_sym!(cuDeviceGetCount, FnDeviceGetCount),
                cuDeviceGet: load_sym!(cuDeviceGet, FnDeviceGet),
                cuDeviceGetName: load_sym!(cuDeviceGetName, FnDeviceGetName),
                cuDeviceTotalMem: load_sym!(cuDeviceTotalMem_v2, FnDeviceTotalMem),
                cuDeviceGetAttribute: load_sym!(cuDeviceGetAttribute, FnDeviceGetAttribute),
                cuDevicePrimaryCtxRetain: load_sym!(
                    cuDevicePrimaryCtxRetain,
                    FnPrimaryCtxRetain
                ),
                cuDevicePrimaryCtxRelease: load_sym!(
                    cuDevicePrimaryCtxRelease_v2,
                    FnPrimaryCtxRelease
                ),
                cuCtxSetCurrent: load_sym!(cuCtxSetCurrent, FnCtxSetCurrent),
                cuCtxSynchronize: load_sym!(cuCtxSynchronize, FnCtxSync),
                cuModuleLoadData: load_sym!(cuModuleLoadData, FnModuleLoadData),
                cuModuleLoadDataEx: load_sym!(cuModuleLoadDataEx, FnModuleLoadDataEx),
                cuModuleUnload: load_sym!(cuModuleUnload, FnModuleUnload),
                cuModuleGetFunction: load_sym!(cuModuleGetFunction, FnModuleGetFunction),
                cuMemAlloc: load_sym!(cuMemAlloc_v2, FnMemAlloc),
                cuMemFree: load_sym!(cuMemFree_v2, FnMemFree),
                cuMemcpyHtoD: load_sym!(cuMemcpyHtoD_v2, FnMemcpyHtoD),
                cuMemcpyDtoH: load_sym!(cuMemcpyDtoH_v2, FnMemcpyDtoH),
                cuMemcpyHtoDAsync: load_sym!(cuMemcpyHtoDAsync_v2, FnMemcpyHtoDAsync),
                cuMemcpyDtoHAsync: load_sym!(cuMemcpyDtoHAsync_v2, FnMemcpyDtoHAsync),
                cuMemcpyDtoD: load_sym!(cuMemcpyDtoD_v2, FnMemcpyDtoD),
                cuMemcpyDtoDAsync: load_sym!(cuMemcpyDtoDAsync_v2, FnMemcpyDtoDAsync),
                cuMemGetInfo: load_sym!(cuMemGetInfo_v2, FnMemGetInfo),
                cuStreamCreate: load_sym!(cuStreamCreate, FnStreamCreate),
                cuStreamDestroy: load_sym!(cuStreamDestroy_v2, FnStreamDestroy),
                cuStreamSynchronize: load_sym!(cuStreamSynchronize, FnStreamSync),
                cuLaunchKernel: load_sym!(cuLaunchKernel, FnLaunchKernel),
                cuGraphCreate: load_sym!(cuGraphCreate, FnGraphCreate),
                cuGraphDestroy: load_sym!(cuGraphDestroy, FnGraphDestroy),
                cuGraphInstantiateWithFlags: load_sym!(
                    cuGraphInstantiateWithFlags,
                    FnGraphInstantiate
                ),
                cuGraphExecDestroy: load_sym!(cuGraphExecDestroy, FnGraphExecDestroy),
                cuGraphLaunch: load_sym!(cuGraphLaunch, FnGraphLaunch),
                // The unversioned `cuStreamBeginCapture` export is the legacy entry
                // point without a `mode` parameter, so the mode would be ignored.
                cuStreamBeginCapture: load_sym!(cuStreamBeginCapture_v2, FnStreamBeginCapture),
                cuStreamEndCapture: load_sym!(cuStreamEndCapture, FnStreamEndCapture),
            })
        }

        /// Converts a raw driver status code into a `Result`.
        ///
        /// # Errors
        ///
        /// When `result` is not `CUDA_SUCCESS`, returns [`GpuError::CudaDriver`]
        /// carrying the symbolic name from [`cuda_error_string`] and the raw code.
        pub fn check(result: CUresult) -> Result<(), GpuError> {
            if result == CUDA_SUCCESS {
                Ok(())
            } else {
                Err(GpuError::CudaDriver(cuda_error_string(result).to_string(), result))
            }
        }
    }
}
#[cfg(not(feature = "cuda"))]
mod loading {
    use super::*;

    // Stub implementation used when the crate is built without the `cuda` feature.
    impl CudaDriver {
        /// Always returns `None`: no driver is ever loaded without the `cuda` feature.
        #[must_use]
        pub fn load() -> Option<&'static Self> {
            Option::None
        }

        /// Always fails, whatever `_result` is, because no driver is available.
        ///
        /// # Errors
        ///
        /// Always returns [`GpuError::CudaNotAvailable`].
        pub fn check(_result: CUresult) -> Result<(), GpuError> {
            let reason = String::from("cuda feature not enabled");
            Err(GpuError::CudaNotAvailable(reason))
        }
    }
}
#[must_use]
pub fn cuda_error_string(code: CUresult) -> &'static str {
match code {
CUDA_SUCCESS => "CUDA_SUCCESS",
CUDA_ERROR_INVALID_VALUE => "CUDA_ERROR_INVALID_VALUE",
CUDA_ERROR_OUT_OF_MEMORY => "CUDA_ERROR_OUT_OF_MEMORY",
CUDA_ERROR_NOT_INITIALIZED => "CUDA_ERROR_NOT_INITIALIZED",
CUDA_ERROR_DEINITIALIZED => "CUDA_ERROR_DEINITIALIZED",
CUDA_ERROR_NO_DEVICE => "CUDA_ERROR_NO_DEVICE",
CUDA_ERROR_INVALID_DEVICE => "CUDA_ERROR_INVALID_DEVICE",
CUDA_ERROR_INVALID_IMAGE => "CUDA_ERROR_INVALID_IMAGE",
CUDA_ERROR_INVALID_CONTEXT => "CUDA_ERROR_INVALID_CONTEXT",
CUDA_ERROR_NO_BINARY_FOR_GPU => "CUDA_ERROR_NO_BINARY_FOR_GPU",
CUDA_ERROR_INVALID_PTX => "CUDA_ERROR_INVALID_PTX",
CUDA_ERROR_NOT_FOUND => "CUDA_ERROR_NOT_FOUND",
CUDA_ERROR_NOT_READY => "CUDA_ERROR_NOT_READY",
CUDA_ERROR_ILLEGAL_ADDRESS => "CUDA_ERROR_ILLEGAL_ADDRESS",
CUDA_ERROR_ILLEGAL_INSTRUCTION => "CUDA_ERROR_ILLEGAL_INSTRUCTION",
CUDA_ERROR_LAUNCH_FAILED => "CUDA_ERROR_LAUNCH_FAILED",
_ => "CUDA_ERROR_UNKNOWN",
}
}
#[cfg(test)]
mod tests;
#[cfg(test)]
mod proptests;