use std::ffi::{c_char, c_int, c_void};
use std::sync::OnceLock;
use libloading::Library;
use crate::error::{CudaError, CudaResult, DriverLoadError};
use crate::ffi::*;
/// Process-wide cache of the driver load outcome.
///
/// Filled on first use by [`try_driver`] via `get_or_init(DriverApi::load)`;
/// the stored value is the full `Result`, so a failed load is cached as well
/// and is not retried by this file.
static DRIVER: OnceLock<Result<DriverApi, DriverLoadError>> = OnceLock::new();
/// Resolves a *required* symbol from `$lib` and yields it as a function pointer.
///
/// The pointer's concrete type is not named here: it is inferred from the
/// place the macro result is assigned to (a `DriverApi` field). A failed
/// lookup is propagated with `?` as `DriverLoadError::SymbolNotFound`, so the
/// macro may only be used inside a function returning a compatible `Result`.
#[cfg(not(target_os = "macos"))]
macro_rules! load_sym {
    ($lib:expr, $name:literal) => {{
        // SAFETY: the symbol is looked up as an opaque `fn()`; nothing is called here.
        let lookup = unsafe { $lib.get::<unsafe extern "C" fn()>($name.as_bytes()) };
        let symbol = lookup.map_err(|err| DriverLoadError::SymbolNotFound {
            symbol: $name,
            reason: err.to_string(),
        })?;
        // SAFETY: relies on the destination field declaring the true C signature of
        // `$name`; the transmute target type is inferred, not checked.
        #[allow(clippy::missing_transmute_annotations)]
        let entry_point = unsafe { std::mem::transmute(*symbol) };
        entry_point
    }};
}
/// Resolves an *optional* symbol from `$lib`.
///
/// Evaluates to `Some(function_pointer)` when the symbol is exported and to
/// `None` otherwise; a miss is only reported through a `tracing::debug!`
/// event and never fails the surrounding load. The pointer type is inferred
/// from the `Option<fn>` field the result is assigned to.
#[cfg(not(target_os = "macos"))]
macro_rules! load_sym_optional {
    ($lib:expr, $name:literal) => {{
        // SAFETY: the symbol is looked up as an opaque `fn()`; nothing is called here.
        let lookup = unsafe { $lib.get::<unsafe extern "C" fn()>($name.as_bytes()) };
        if let Ok(symbol) = lookup {
            // SAFETY: relies on the destination field declaring the true C signature
            // of `$name`; the transmute target type is inferred, not checked.
            #[allow(clippy::missing_transmute_annotations)]
            let entry_point = unsafe { std::mem::transmute(*symbol) };
            Some(entry_point)
        } else {
            tracing::debug!(concat!("optional symbol not found: ", $name));
            None
        }
    }};
}
/// Table of CUDA driver entry points resolved at runtime from the driver
/// shared library.
///
/// Fields typed as a bare function pointer are filled by `load_sym!`, which
/// fails the whole load when the symbol is missing. Fields typed as
/// `Option<fn>` are filled by `load_sym_optional!` and are `None` when the
/// loaded library does not export the symbol.
///
/// Every pointer is `unsafe` to call: the Rust signature declared here is
/// trusted to match the driver's C ABI and is not verified at load time.
pub struct DriverApi {
// Owning handle to the loaded driver library. It is never read; presumably it
// is held so the library stays loaded while the pointers below are in use.
_lib: Library,
// --- Initialization and version (required) ---
pub cu_init: unsafe extern "C" fn(flags: u32) -> CUresult,
pub cu_driver_get_version: unsafe extern "C" fn(version: *mut c_int) -> CUresult,
// --- Device queries, `cuDevice*` (required) ---
pub cu_device_get: unsafe extern "C" fn(device: *mut CUdevice, ordinal: c_int) -> CUresult,
pub cu_device_get_count: unsafe extern "C" fn(count: *mut c_int) -> CUresult,
pub cu_device_get_name:
unsafe extern "C" fn(name: *mut c_char, len: c_int, dev: CUdevice) -> CUresult,
pub cu_device_get_attribute:
unsafe extern "C" fn(pi: *mut c_int, attrib: CUdevice_attribute, dev: CUdevice) -> CUresult,
pub cu_device_total_mem_v2: unsafe extern "C" fn(bytes: *mut usize, dev: CUdevice) -> CUresult,
pub cu_device_can_access_peer:
unsafe extern "C" fn(can_access: *mut c_int, dev: CUdevice, peer_dev: CUdevice) -> CUresult,
// --- Primary context, `cuDevicePrimaryCtx*` (required) ---
pub cu_device_primary_ctx_retain:
unsafe extern "C" fn(pctx: *mut CUcontext, dev: CUdevice) -> CUresult,
pub cu_device_primary_ctx_release_v2: unsafe extern "C" fn(dev: CUdevice) -> CUresult,
pub cu_device_primary_ctx_set_flags_v2:
unsafe extern "C" fn(dev: CUdevice, flags: u32) -> CUresult,
pub cu_device_primary_ctx_get_state:
unsafe extern "C" fn(dev: CUdevice, flags: *mut u32, active: *mut c_int) -> CUresult,
pub cu_device_primary_ctx_reset_v2: unsafe extern "C" fn(dev: CUdevice) -> CUresult,
// --- Contexts, `cuCtx*` (required) ---
pub cu_ctx_create_v2:
unsafe extern "C" fn(pctx: *mut CUcontext, flags: u32, dev: CUdevice) -> CUresult,
pub cu_ctx_destroy_v2: unsafe extern "C" fn(ctx: CUcontext) -> CUresult,
pub cu_ctx_set_current: unsafe extern "C" fn(ctx: CUcontext) -> CUresult,
pub cu_ctx_get_current: unsafe extern "C" fn(pctx: *mut CUcontext) -> CUresult,
pub cu_ctx_synchronize: unsafe extern "C" fn() -> CUresult,
// --- Modules, `cuModule*` (required) ---
pub cu_module_load_data:
unsafe extern "C" fn(module: *mut CUmodule, image: *const c_void) -> CUresult,
pub cu_module_load_data_ex: unsafe extern "C" fn(
module: *mut CUmodule,
image: *const c_void,
num_options: u32,
options: *mut CUjit_option,
option_values: *mut *mut c_void,
) -> CUresult,
pub cu_module_get_function: unsafe extern "C" fn(
hfunc: *mut CUfunction,
hmod: CUmodule,
name: *const c_char,
) -> CUresult,
pub cu_module_unload: unsafe extern "C" fn(hmod: CUmodule) -> CUresult,
// --- Memory allocation, copies and memsets, `cuMem*` (required) ---
pub cu_mem_alloc_v2: unsafe extern "C" fn(dptr: *mut CUdeviceptr, bytesize: usize) -> CUresult,
pub cu_mem_free_v2: unsafe extern "C" fn(dptr: CUdeviceptr) -> CUresult,
pub cu_memcpy_htod_v2:
unsafe extern "C" fn(dst: CUdeviceptr, src: *const c_void, bytesize: usize) -> CUresult,
pub cu_memcpy_dtoh_v2:
unsafe extern "C" fn(dst: *mut c_void, src: CUdeviceptr, bytesize: usize) -> CUresult,
pub cu_memcpy_dtod_v2:
unsafe extern "C" fn(dst: CUdeviceptr, src: CUdeviceptr, bytesize: usize) -> CUresult,
pub cu_memcpy_htod_async_v2: unsafe extern "C" fn(
dst: CUdeviceptr,
src: *const c_void,
bytesize: usize,
stream: CUstream,
) -> CUresult,
pub cu_memcpy_dtoh_async_v2: unsafe extern "C" fn(
dst: *mut c_void,
src: CUdeviceptr,
bytesize: usize,
stream: CUstream,
) -> CUresult,
pub cu_mem_alloc_host_v2:
unsafe extern "C" fn(pp: *mut *mut c_void, bytesize: usize) -> CUresult,
pub cu_mem_free_host: unsafe extern "C" fn(p: *mut c_void) -> CUresult,
pub cu_mem_alloc_managed:
unsafe extern "C" fn(dptr: *mut CUdeviceptr, bytesize: usize, flags: u32) -> CUresult,
pub cu_memset_d8_v2:
unsafe extern "C" fn(dst: CUdeviceptr, value: u8, count: usize) -> CUresult,
pub cu_memset_d32_v2:
unsafe extern "C" fn(dst: CUdeviceptr, value: u32, count: usize) -> CUresult,
pub cu_mem_get_info_v2: unsafe extern "C" fn(free: *mut usize, total: *mut usize) -> CUresult,
pub cu_mem_host_register_v2:
unsafe extern "C" fn(p: *mut c_void, bytesize: usize, flags: u32) -> CUresult,
pub cu_mem_host_unregister: unsafe extern "C" fn(p: *mut c_void) -> CUresult,
pub cu_mem_host_get_device_pointer_v2:
unsafe extern "C" fn(pdptr: *mut CUdeviceptr, p: *mut c_void, flags: u32) -> CUresult,
pub cu_pointer_get_attribute:
unsafe extern "C" fn(data: *mut c_void, attribute: u32, ptr: CUdeviceptr) -> CUresult,
pub cu_mem_advise: unsafe extern "C" fn(
dev_ptr: CUdeviceptr,
count: usize,
advice: u32,
device: CUdevice,
) -> CUresult,
pub cu_mem_prefetch_async: unsafe extern "C" fn(
dev_ptr: CUdeviceptr,
count: usize,
dst_device: CUdevice,
hstream: CUstream,
) -> CUresult,
// --- Streams, `cuStream*` (required) ---
pub cu_stream_create: unsafe extern "C" fn(phstream: *mut CUstream, flags: u32) -> CUresult,
pub cu_stream_create_with_priority:
unsafe extern "C" fn(phstream: *mut CUstream, flags: u32, priority: c_int) -> CUresult,
pub cu_stream_destroy_v2: unsafe extern "C" fn(hstream: CUstream) -> CUresult,
pub cu_stream_synchronize: unsafe extern "C" fn(hstream: CUstream) -> CUresult,
pub cu_stream_wait_event:
unsafe extern "C" fn(hstream: CUstream, hevent: CUevent, flags: u32) -> CUresult,
pub cu_stream_query: unsafe extern "C" fn(hstream: CUstream) -> CUresult,
pub cu_stream_get_priority:
unsafe extern "C" fn(hstream: CUstream, priority: *mut std::ffi::c_int) -> CUresult,
pub cu_stream_get_flags: unsafe extern "C" fn(hstream: CUstream, flags: *mut u32) -> CUresult,
// --- Events, `cuEvent*` (required) ---
pub cu_event_create: unsafe extern "C" fn(phevent: *mut CUevent, flags: u32) -> CUresult,
pub cu_event_destroy_v2: unsafe extern "C" fn(hevent: CUevent) -> CUresult,
pub cu_event_record: unsafe extern "C" fn(hevent: CUevent, hstream: CUstream) -> CUresult,
pub cu_event_query: unsafe extern "C" fn(hevent: CUevent) -> CUresult,
pub cu_event_synchronize: unsafe extern "C" fn(hevent: CUevent) -> CUresult,
pub cu_event_elapsed_time:
unsafe extern "C" fn(pmilliseconds: *mut f32, hstart: CUevent, hend: CUevent) -> CUresult,
// --- Peer-to-peer copies and access (required) ---
// NOTE(review): `dst_device` / `src_device` are declared as raw `u64` here while
// other copy entry points use the `CUdeviceptr` alias — presumably the same
// width; confirm against `crate::ffi`.
pub cu_memcpy_peer: unsafe extern "C" fn(
dst_device: u64,
dst_ctx: CUcontext,
src_device: u64,
src_ctx: CUcontext,
count: usize,
) -> CUresult,
pub cu_memcpy_peer_async: unsafe extern "C" fn(
dst_device: u64,
dst_ctx: CUcontext,
src_device: u64,
src_ctx: CUcontext,
count: usize,
stream: CUstream,
) -> CUresult,
pub cu_ctx_enable_peer_access:
unsafe extern "C" fn(peer_context: CUcontext, flags: u32) -> CUresult,
pub cu_ctx_disable_peer_access: unsafe extern "C" fn(peer_context: CUcontext) -> CUresult,
// --- Kernel launch (required) ---
#[allow(clippy::type_complexity)]
pub cu_launch_kernel: unsafe extern "C" fn(
f: CUfunction,
grid_dim_x: u32,
grid_dim_y: u32,
grid_dim_z: u32,
block_dim_x: u32,
block_dim_y: u32,
block_dim_z: u32,
shared_mem_bytes: u32,
hstream: CUstream,
kernel_params: *mut *mut c_void,
extra: *mut *mut c_void,
) -> CUresult,
#[allow(clippy::type_complexity)]
pub cu_launch_cooperative_kernel: unsafe extern "C" fn(
f: CUfunction,
grid_dim_x: u32,
grid_dim_y: u32,
grid_dim_z: u32,
block_dim_x: u32,
block_dim_y: u32,
block_dim_z: u32,
shared_mem_bytes: u32,
hstream: CUstream,
kernel_params: *mut *mut c_void,
) -> CUresult,
pub cu_launch_cooperative_kernel_multi_device: unsafe extern "C" fn(
launch_params_list: *mut c_void,
num_devices: u32,
flags: u32,
) -> CUresult,
// --- Occupancy queries, `cuOccupancy*` (required) ---
pub cu_occupancy_max_active_blocks_per_multiprocessor: unsafe extern "C" fn(
num_blocks: *mut c_int,
func: CUfunction,
block_size: c_int,
dynamic_smem_size: usize,
) -> CUresult,
#[allow(clippy::type_complexity)]
pub cu_occupancy_max_potential_block_size: unsafe extern "C" fn(
min_grid_size: *mut c_int,
block_size: *mut c_int,
func: CUfunction,
block_size_to_dynamic_smem_size: Option<unsafe extern "C" fn(c_int) -> usize>,
dynamic_smem_size: usize,
block_size_limit: c_int,
) -> CUresult,
pub cu_occupancy_max_active_blocks_per_multiprocessor_with_flags:
unsafe extern "C" fn(
num_blocks: *mut c_int,
func: CUfunction,
block_size: c_int,
dynamic_smem_size: usize,
flags: u32,
) -> CUresult,
// --- Optional entry points: `None` when the symbol lookup fails ---
// Additional copies / memsets.
pub cu_memcpy_dtod_async_v2: Option<
unsafe extern "C" fn(
dst: CUdeviceptr,
src: CUdeviceptr,
bytesize: usize,
stream: CUstream,
) -> CUresult,
>,
pub cu_memset_d16_v2:
Option<unsafe extern "C" fn(dst: CUdeviceptr, value: u16, count: usize) -> CUresult>,
pub cu_memset_d32_async: Option<
unsafe extern "C" fn(
dst: CUdeviceptr,
value: u32,
count: usize,
stream: CUstream,
) -> CUresult,
>,
// Context limits and cache / shared-memory configuration.
pub cu_ctx_get_limit: Option<unsafe extern "C" fn(value: *mut usize, limit: u32) -> CUresult>,
pub cu_ctx_set_limit: Option<unsafe extern "C" fn(limit: u32, value: usize) -> CUresult>,
pub cu_ctx_get_cache_config: Option<unsafe extern "C" fn(config: *mut u32) -> CUresult>,
pub cu_ctx_set_cache_config: Option<unsafe extern "C" fn(config: u32) -> CUresult>,
pub cu_ctx_get_shared_mem_config: Option<unsafe extern "C" fn(config: *mut u32) -> CUresult>,
pub cu_ctx_set_shared_mem_config: Option<unsafe extern "C" fn(config: u32) -> CUresult>,
pub cu_event_record_with_flags:
Option<unsafe extern "C" fn(hevent: CUevent, hstream: CUstream, flags: u32) -> CUresult>,
// Per-function attributes and configuration.
pub cu_func_get_attribute: Option<
unsafe extern "C" fn(value: *mut c_int, attrib: c_int, func: CUfunction) -> CUresult,
>,
pub cu_func_set_cache_config:
Option<unsafe extern "C" fn(func: CUfunction, config: u32) -> CUresult>,
pub cu_func_set_shared_mem_config:
Option<unsafe extern "C" fn(func: CUfunction, config: u32) -> CUresult>,
pub cu_func_set_attribute:
Option<unsafe extern "C" fn(func: CUfunction, attrib: c_int, value: c_int) -> CUresult>,
// Profiler control.
pub cu_profiler_start: Option<unsafe extern "C" fn() -> CUresult>,
pub cu_profiler_stop: Option<unsafe extern "C" fn() -> CUresult>,
// Extended launch and tensor-map encoding.
#[allow(clippy::type_complexity)]
pub cu_launch_kernel_ex: Option<
unsafe extern "C" fn(
config: *const CuLaunchConfig,
f: CUfunction,
kernel_params: *mut *mut std::ffi::c_void,
extra: *mut *mut std::ffi::c_void,
) -> CUresult,
>,
#[allow(clippy::type_complexity)]
pub cu_tensor_map_encode_tiled: Option<
unsafe extern "C" fn(
tensor_map: *mut std::ffi::c_void,
tensor_data_type: u32,
tensor_rank: u32,
global_address: *mut std::ffi::c_void,
global_dim: *const u64,
global_strides: *const u64,
box_dim: *const u32,
element_strides: *const u32,
interleave: u32,
swizzle: u32,
l2_promotion: u32,
oob_fill: u32,
) -> CUresult,
>,
// Same parameter list as `cu_tensor_map_encode_tiled` plus a trailing `flags`.
// NOTE(review): loaded from the symbol "cuTensorMapEncodeTiledMemref" — confirm
// that name and this signature against the driver API reference.
#[allow(clippy::type_complexity)]
pub cu_tensor_map_encode_tiled_memref: Option<
unsafe extern "C" fn(
tensor_map: *mut c_void,
tensor_data_type: u32,
tensor_rank: u32,
global_address: *mut c_void,
global_dim: *const u64,
global_strides: *const u64,
box_dim: *const u32,
element_strides: *const u32,
interleave: u32,
swizzle: u32,
l2_promotion: u32,
oob_fill: u32,
flags: u64,
) -> CUresult,
>,
// Library / multicast objects.
pub cu_kernel_get_library:
Option<unsafe extern "C" fn(p_lib: *mut CUlibrary, kernel: CUkernel) -> CUresult>,
pub cu_multicast_get_granularity: Option<
unsafe extern "C" fn(granularity: *mut usize, desc: *const c_void, option: u32) -> CUresult,
>,
pub cu_multicast_create: Option<
unsafe extern "C" fn(mc_handle: *mut CUmulticastObject, desc: *const c_void) -> CUresult,
>,
pub cu_multicast_add_device:
Option<unsafe extern "C" fn(mc_handle: CUmulticastObject, dev: CUdevice) -> CUresult>,
// NOTE(review): the published driver reference appears to list
// `cuMemcpyBatchAsync` with extra attribute / failure-index parameters; verify
// this six-argument signature matches the ABI before calling through it.
#[allow(clippy::type_complexity)]
pub cu_memcpy_batch_async: Option<
unsafe extern "C" fn(
dsts: *const *mut c_void,
srcs: *const *const c_void,
sizes: *const usize,
count: u64,
flags: u64,
stream: CUstream,
) -> CUresult,
>,
// Arrays, `cuArray*` / `cuArray3D*`.
pub cu_array_create_v2: Option<
unsafe extern "C" fn(
p_handle: *mut CUarray,
p_allocate_array: *const CUDA_ARRAY_DESCRIPTOR,
) -> CUresult,
>,
pub cu_array_destroy: Option<unsafe extern "C" fn(h_array: CUarray) -> CUresult>,
pub cu_array_get_descriptor_v2: Option<
unsafe extern "C" fn(
p_array_descriptor: *mut CUDA_ARRAY_DESCRIPTOR,
h_array: CUarray,
) -> CUresult,
>,
pub cu_array3d_create_v2: Option<
unsafe extern "C" fn(
p_handle: *mut CUarray,
p_allocate_array: *const CUDA_ARRAY3D_DESCRIPTOR,
) -> CUresult,
>,
pub cu_array3d_get_descriptor_v2: Option<
unsafe extern "C" fn(
p_array_descriptor: *mut CUDA_ARRAY3D_DESCRIPTOR,
h_array: CUarray,
) -> CUresult,
>,
// Host <-> array copies.
pub cu_memcpy_htoa_v2: Option<
unsafe extern "C" fn(
dst_array: CUarray,
dst_offset: usize,
src_host: *const c_void,
byte_count: usize,
) -> CUresult,
>,
pub cu_memcpy_atoh_v2: Option<
unsafe extern "C" fn(
dst_host: *mut c_void,
src_array: CUarray,
src_offset: usize,
byte_count: usize,
) -> CUresult,
>,
pub cu_memcpy_htoa_async_v2: Option<
unsafe extern "C" fn(
dst_array: CUarray,
dst_offset: usize,
src_host: *const c_void,
byte_count: usize,
stream: CUstream,
) -> CUresult,
>,
pub cu_memcpy_atoh_async_v2: Option<
unsafe extern "C" fn(
dst_host: *mut c_void,
src_array: CUarray,
src_offset: usize,
byte_count: usize,
stream: CUstream,
) -> CUresult,
>,
// Texture and surface objects.
pub cu_tex_object_create: Option<
unsafe extern "C" fn(
p_tex_object: *mut CUtexObject,
p_res_desc: *const CUDA_RESOURCE_DESC,
p_tex_desc: *const CUDA_TEXTURE_DESC,
p_res_view_desc: *const CUDA_RESOURCE_VIEW_DESC,
) -> CUresult,
>,
pub cu_tex_object_destroy: Option<unsafe extern "C" fn(tex_object: CUtexObject) -> CUresult>,
pub cu_tex_object_get_resource_desc: Option<
unsafe extern "C" fn(
p_res_desc: *mut CUDA_RESOURCE_DESC,
tex_object: CUtexObject,
) -> CUresult,
>,
pub cu_surf_object_create: Option<
unsafe extern "C" fn(
p_surf_object: *mut CUsurfObject,
p_res_desc: *const CUDA_RESOURCE_DESC,
) -> CUresult,
>,
pub cu_surf_object_destroy: Option<unsafe extern "C" fn(surf_object: CUsurfObject) -> CUresult>,
}
// SAFETY (as far as this file shows): apart from `_lib`, every field of
// `DriverApi` is a function pointer or an `Option` of one, so the struct has
// no interior mutability of its own.
// NOTE(review): soundness additionally assumes that `libloading::Library` may
// be shared between threads and that the driver entry points may be invoked
// from any thread — confirm against the libloading and CUDA driver docs.
unsafe impl Send for DriverApi {}
unsafe impl Sync for DriverApi {}
impl DriverApi {
    /// Opens the CUDA driver shared library, resolves every entry point and
    /// calls `cuInit(0)`.
    ///
    /// Candidate library names are `nvcuda.dll` on Windows and
    /// `libcuda.so.1` / `libcuda.so` on every other non-macOS target, so an
    /// unlisted Unix-like platform reports a runtime
    /// [`DriverLoadError::LibraryNotFound`] instead of failing to compile.
    ///
    /// # Errors
    ///
    /// * [`DriverLoadError::UnsupportedPlatform`] — always, on macOS.
    /// * [`DriverLoadError::LibraryNotFound`] — no candidate library could be opened.
    /// * [`DriverLoadError::SymbolNotFound`] — a required symbol is missing.
    /// * [`DriverLoadError::InitializationFailed`] — `cuInit` returned a non-zero code.
    pub fn load() -> Result<Self, DriverLoadError> {
        #[cfg(target_os = "macos")]
        {
            Err(DriverLoadError::UnsupportedPlatform)
        }
        #[cfg(target_os = "windows")]
        let lib_names: &[&str] = &["nvcuda.dll"];
        // Default for everything that is neither Windows nor macOS. Gating this on
        // `target_os = "linux"` alone left `lib_names` undefined on other targets.
        #[cfg(not(any(target_os = "windows", target_os = "macos")))]
        let lib_names: &[&str] = &["libcuda.so.1", "libcuda.so"];
        #[cfg(not(target_os = "macos"))]
        {
            let lib = Self::load_library(lib_names)?;
            let api = Self::load_symbols(lib)?;
            // SAFETY: `cu_init` was resolved from the `cuInit` symbol of the library
            // that `api` keeps alive; its declared signature is trusted to match.
            let rc = unsafe { (api.cu_init)(0) };
            if rc != 0 {
                return Err(DriverLoadError::InitializationFailed { code: rc as u32 });
            }
            Ok(api)
        }
    }

    /// Tries each name in `names` in order and returns the first library that opens.
    ///
    /// # Errors
    ///
    /// Returns [`DriverLoadError::LibraryNotFound`] carrying all candidate names
    /// and the text of the last loader error (empty when `names` is empty).
    #[cfg(not(target_os = "macos"))]
    fn load_library(names: &[&str]) -> Result<Library, DriverLoadError> {
        let mut last_error = String::new();
        for name in names {
            // SAFETY: opening a shared library runs its initialisation routines; the
            // candidates passed by `load` are the system CUDA driver names only.
            match unsafe { Library::new(*name) } {
                Ok(lib) => {
                    tracing::debug!("loaded CUDA driver library: {name}");
                    return Ok(lib);
                }
                Err(e) => {
                    tracing::debug!("failed to load {name}: {e}");
                    last_error = e.to_string();
                }
            }
        }
        Err(DriverLoadError::LibraryNotFound {
            candidates: names.iter().map(|s| (*s).to_string()).collect(),
            last_error,
        })
    }

    /// Resolves all driver symbols from `lib` and moves `lib` into the result.
    ///
    /// Required symbols go through `load_sym!` and abort with
    /// [`DriverLoadError::SymbolNotFound`] on the first miss (in the order
    /// written below); optional ones go through `load_sym_optional!` and
    /// become `None` when absent.
    #[cfg(not(target_os = "macos"))]
    fn load_symbols(lib: Library) -> Result<Self, DriverLoadError> {
        Ok(Self {
            // Initialization, devices and primary contexts.
            cu_init: load_sym!(lib, "cuInit"),
            cu_driver_get_version: load_sym!(lib, "cuDriverGetVersion"),
            cu_device_get: load_sym!(lib, "cuDeviceGet"),
            cu_device_get_count: load_sym!(lib, "cuDeviceGetCount"),
            cu_device_get_name: load_sym!(lib, "cuDeviceGetName"),
            cu_device_get_attribute: load_sym!(lib, "cuDeviceGetAttribute"),
            cu_device_total_mem_v2: load_sym!(lib, "cuDeviceTotalMem_v2"),
            cu_device_can_access_peer: load_sym!(lib, "cuDeviceCanAccessPeer"),
            cu_device_primary_ctx_retain: load_sym!(lib, "cuDevicePrimaryCtxRetain"),
            cu_device_primary_ctx_release_v2: load_sym!(lib, "cuDevicePrimaryCtxRelease_v2"),
            cu_device_primary_ctx_set_flags_v2: load_sym!(lib, "cuDevicePrimaryCtxSetFlags_v2"),
            cu_device_primary_ctx_get_state: load_sym!(lib, "cuDevicePrimaryCtxGetState"),
            cu_device_primary_ctx_reset_v2: load_sym!(lib, "cuDevicePrimaryCtxReset_v2"),
            // Contexts and modules.
            cu_ctx_create_v2: load_sym!(lib, "cuCtxCreate_v2"),
            cu_ctx_destroy_v2: load_sym!(lib, "cuCtxDestroy_v2"),
            cu_ctx_set_current: load_sym!(lib, "cuCtxSetCurrent"),
            cu_ctx_get_current: load_sym!(lib, "cuCtxGetCurrent"),
            cu_ctx_synchronize: load_sym!(lib, "cuCtxSynchronize"),
            cu_module_load_data: load_sym!(lib, "cuModuleLoadData"),
            cu_module_load_data_ex: load_sym!(lib, "cuModuleLoadDataEx"),
            cu_module_get_function: load_sym!(lib, "cuModuleGetFunction"),
            cu_module_unload: load_sym!(lib, "cuModuleUnload"),
            // Memory.
            cu_mem_alloc_v2: load_sym!(lib, "cuMemAlloc_v2"),
            cu_mem_free_v2: load_sym!(lib, "cuMemFree_v2"),
            cu_memcpy_htod_v2: load_sym!(lib, "cuMemcpyHtoD_v2"),
            cu_memcpy_dtoh_v2: load_sym!(lib, "cuMemcpyDtoH_v2"),
            cu_memcpy_dtod_v2: load_sym!(lib, "cuMemcpyDtoD_v2"),
            cu_memcpy_htod_async_v2: load_sym!(lib, "cuMemcpyHtoDAsync_v2"),
            cu_memcpy_dtoh_async_v2: load_sym!(lib, "cuMemcpyDtoHAsync_v2"),
            cu_mem_alloc_host_v2: load_sym!(lib, "cuMemAllocHost_v2"),
            cu_mem_free_host: load_sym!(lib, "cuMemFreeHost"),
            cu_mem_alloc_managed: load_sym!(lib, "cuMemAllocManaged"),
            cu_memset_d8_v2: load_sym!(lib, "cuMemsetD8_v2"),
            cu_memset_d32_v2: load_sym!(lib, "cuMemsetD32_v2"),
            cu_mem_get_info_v2: load_sym!(lib, "cuMemGetInfo_v2"),
            cu_mem_host_register_v2: load_sym!(lib, "cuMemHostRegister_v2"),
            cu_mem_host_unregister: load_sym!(lib, "cuMemHostUnregister"),
            cu_mem_host_get_device_pointer_v2: load_sym!(lib, "cuMemHostGetDevicePointer_v2"),
            cu_pointer_get_attribute: load_sym!(lib, "cuPointerGetAttribute"),
            cu_mem_advise: load_sym!(lib, "cuMemAdvise"),
            cu_mem_prefetch_async: load_sym!(lib, "cuMemPrefetchAsync"),
            // Streams and events.
            cu_stream_create: load_sym!(lib, "cuStreamCreate"),
            cu_stream_create_with_priority: load_sym!(lib, "cuStreamCreateWithPriority"),
            cu_stream_destroy_v2: load_sym!(lib, "cuStreamDestroy_v2"),
            cu_stream_synchronize: load_sym!(lib, "cuStreamSynchronize"),
            cu_stream_wait_event: load_sym!(lib, "cuStreamWaitEvent"),
            cu_stream_query: load_sym!(lib, "cuStreamQuery"),
            cu_stream_get_priority: load_sym!(lib, "cuStreamGetPriority"),
            cu_stream_get_flags: load_sym!(lib, "cuStreamGetFlags"),
            cu_event_create: load_sym!(lib, "cuEventCreate"),
            cu_event_destroy_v2: load_sym!(lib, "cuEventDestroy_v2"),
            cu_event_record: load_sym!(lib, "cuEventRecord"),
            cu_event_query: load_sym!(lib, "cuEventQuery"),
            cu_event_synchronize: load_sym!(lib, "cuEventSynchronize"),
            cu_event_elapsed_time: load_sym!(lib, "cuEventElapsedTime"),
            cu_event_record_with_flags: load_sym_optional!(lib, "cuEventRecordWithFlags"),
            // Peer access, launches and occupancy.
            cu_memcpy_peer: load_sym!(lib, "cuMemcpyPeer"),
            cu_memcpy_peer_async: load_sym!(lib, "cuMemcpyPeerAsync"),
            cu_ctx_enable_peer_access: load_sym!(lib, "cuCtxEnablePeerAccess"),
            cu_ctx_disable_peer_access: load_sym!(lib, "cuCtxDisablePeerAccess"),
            cu_launch_kernel: load_sym!(lib, "cuLaunchKernel"),
            cu_launch_cooperative_kernel: load_sym!(lib, "cuLaunchCooperativeKernel"),
            cu_launch_cooperative_kernel_multi_device: load_sym!(
                lib,
                "cuLaunchCooperativeKernelMultiDevice"
            ),
            cu_occupancy_max_active_blocks_per_multiprocessor: load_sym!(
                lib,
                "cuOccupancyMaxActiveBlocksPerMultiprocessor"
            ),
            cu_occupancy_max_potential_block_size: load_sym!(
                lib,
                "cuOccupancyMaxPotentialBlockSize"
            ),
            cu_occupancy_max_active_blocks_per_multiprocessor_with_flags: load_sym!(
                lib,
                "cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags"
            ),
            // Optional entry points: absence is tolerated and recorded as `None`.
            cu_memcpy_dtod_async_v2: load_sym_optional!(lib, "cuMemcpyDtoDAsync_v2"),
            cu_memset_d16_v2: load_sym_optional!(lib, "cuMemsetD16_v2"),
            cu_memset_d32_async: load_sym_optional!(lib, "cuMemsetD32Async"),
            cu_ctx_get_limit: load_sym_optional!(lib, "cuCtxGetLimit"),
            cu_ctx_set_limit: load_sym_optional!(lib, "cuCtxSetLimit"),
            cu_ctx_get_cache_config: load_sym_optional!(lib, "cuCtxGetCacheConfig"),
            cu_ctx_set_cache_config: load_sym_optional!(lib, "cuCtxSetCacheConfig"),
            cu_ctx_get_shared_mem_config: load_sym_optional!(lib, "cuCtxGetSharedMemConfig"),
            cu_ctx_set_shared_mem_config: load_sym_optional!(lib, "cuCtxSetSharedMemConfig"),
            cu_func_get_attribute: load_sym_optional!(lib, "cuFuncGetAttribute"),
            cu_func_set_cache_config: load_sym_optional!(lib, "cuFuncSetCacheConfig"),
            cu_func_set_shared_mem_config: load_sym_optional!(lib, "cuFuncSetSharedMemConfig"),
            cu_func_set_attribute: load_sym_optional!(lib, "cuFuncSetAttribute"),
            cu_profiler_start: load_sym_optional!(lib, "cuProfilerStart"),
            cu_profiler_stop: load_sym_optional!(lib, "cuProfilerStop"),
            cu_launch_kernel_ex: load_sym_optional!(lib, "cuLaunchKernelEx"),
            cu_tensor_map_encode_tiled: load_sym_optional!(lib, "cuTensorMapEncodeTiled"),
            cu_tensor_map_encode_tiled_memref: load_sym_optional!(
                lib,
                "cuTensorMapEncodeTiledMemref"
            ),
            cu_kernel_get_library: load_sym_optional!(lib, "cuKernelGetLibrary"),
            cu_multicast_get_granularity: load_sym_optional!(lib, "cuMulticastGetGranularity"),
            cu_multicast_create: load_sym_optional!(lib, "cuMulticastCreate"),
            cu_multicast_add_device: load_sym_optional!(lib, "cuMulticastAddDevice"),
            cu_memcpy_batch_async: load_sym_optional!(lib, "cuMemcpyBatchAsync"),
            cu_array_create_v2: load_sym_optional!(lib, "cuArrayCreate_v2"),
            cu_array_destroy: load_sym_optional!(lib, "cuArrayDestroy"),
            cu_array_get_descriptor_v2: load_sym_optional!(lib, "cuArrayGetDescriptor_v2"),
            cu_array3d_create_v2: load_sym_optional!(lib, "cuArray3DCreate_v2"),
            cu_array3d_get_descriptor_v2: load_sym_optional!(lib, "cuArray3DGetDescriptor_v2"),
            cu_memcpy_htoa_v2: load_sym_optional!(lib, "cuMemcpyHtoA_v2"),
            cu_memcpy_atoh_v2: load_sym_optional!(lib, "cuMemcpyAtoH_v2"),
            cu_memcpy_htoa_async_v2: load_sym_optional!(lib, "cuMemcpyHtoAAsync_v2"),
            cu_memcpy_atoh_async_v2: load_sym_optional!(lib, "cuMemcpyAtoHAsync_v2"),
            cu_tex_object_create: load_sym_optional!(lib, "cuTexObjectCreate"),
            cu_tex_object_destroy: load_sym_optional!(lib, "cuTexObjectDestroy"),
            cu_tex_object_get_resource_desc: load_sym_optional!(lib, "cuTexObjectGetResourceDesc"),
            cu_surf_object_create: load_sym_optional!(lib, "cuSurfObjectCreate"),
            cu_surf_object_destroy: load_sym_optional!(lib, "cuSurfObjectDestroy"),
            // Moved in last: the macros above only borrow `lib` while resolving.
            _lib: lib,
        })
    }
}
/// Returns the process-wide [`DriverApi`], running [`DriverApi::load`] the
/// first time it is needed and reusing the cached outcome afterwards.
///
/// # Errors
///
/// Yields [`CudaError::NotInitialized`] whenever the cached load result is an
/// error; the underlying `DriverLoadError` is not exposed to the caller.
pub fn try_driver() -> CudaResult<&'static DriverApi> {
    DRIVER
        .get_or_init(DriverApi::load)
        .as_ref()
        .map_err(|_| CudaError::NotInitialized)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `DriverApi::load` must report `UnsupportedPlatform` on macOS.
    #[cfg(target_os = "macos")]
    #[test]
    fn load_returns_unsupported_on_macos() {
        let Err(err) = DriverApi::load() else {
            panic!("expected Err on macOS");
        };
        assert!(
            matches!(err, DriverLoadError::UnsupportedPlatform),
            "expected UnsupportedPlatform, got {err:?}"
        );
    }

    /// `try_driver` must collapse the load failure into `NotInitialized` on macOS.
    #[cfg(target_os = "macos")]
    #[test]
    fn try_driver_returns_not_initialized_on_macos() {
        let Err(err) = try_driver() else {
            panic!("expected Err on macOS");
        };
        assert!(
            matches!(err, CudaError::NotInitialized),
            "expected NotInitialized, got {err:?}"
        );
    }

    /// Compile-time check only: the memref tensor-map field exists and is an
    /// `Option`. Nothing is loaded or called.
    #[test]
    fn driver_v12_8_api_fields_present() {
        type TensorMapEncodeTiledFn = unsafe extern "C" fn(
            tensor_map: *mut std::ffi::c_void,
            tensor_data_type: u32,
            tensor_rank: u32,
            global_address: *mut std::ffi::c_void,
            global_dim: *const u64,
            global_strides: *const u64,
            box_dim: *const u32,
            element_strides: *const u32,
            interleave: u32,
            swizzle: u32,
            l2_promotion: u32,
            oob_fill: u32,
            flags: u64,
        ) -> CUresult;

        fn memref_field_unset(api: &DriverApi) -> bool {
            api.cu_tensor_map_encode_tiled_memref.is_none()
        }

        let unset: Option<TensorMapEncodeTiledFn> = None;
        let _ = unset;
        let _ = memref_field_unset;
    }

    /// Compile-time check only: the three multicast fields exist and are `Option`s.
    #[test]
    fn driver_v12_8_multicast_fields_present() {
        fn create_unset(api: &DriverApi) -> bool {
            api.cu_multicast_create.is_none()
        }
        fn add_device_unset(api: &DriverApi) -> bool {
            api.cu_multicast_add_device.is_none()
        }
        fn granularity_unset(api: &DriverApi) -> bool {
            api.cu_multicast_get_granularity.is_none()
        }
        let _ = (create_unset, add_device_unset, granularity_unset);
    }

    /// Compile-time check only: the batch-memcpy field exists and is an `Option`.
    #[test]
    fn driver_v12_8_batch_memcpy_field_present() {
        fn batch_memcpy_unset(api: &DriverApi) -> bool {
            api.cu_memcpy_batch_async.is_none()
        }
        let _ = batch_memcpy_unset;
    }

    /// Compile-time check only: the kernel-get-library field exists and is an `Option`.
    #[test]
    fn driver_v12_8_kernel_get_library_field_present() {
        fn kernel_get_library_unset(api: &DriverApi) -> bool {
            api.cu_kernel_get_library.is_none()
        }
        let _ = kernel_get_library_unset;
    }
}