use super::sys::{self};
use core::ffi::{c_uchar, c_uint, c_void, CStr};
use std::mem::MaybeUninit;
/// Error type carrying the raw [`sys::CUresult`] status of a failed driver call.
///
/// Values are normally produced by `CUresult::result` for any non-success code. Use
/// [`DriverError::error_name`] / [`DriverError::error_string`] for readable text.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct DriverError(pub sys::CUresult);
impl sys::CUresult {
    /// Converts a raw driver status into a `Result`.
    ///
    /// `CUDA_SUCCESS` maps to `Ok(())`; every other code is wrapped in [`DriverError`].
    #[inline]
    pub fn result(self) -> Result<(), DriverError> {
        if let sys::CUresult::CUDA_SUCCESS = self {
            Ok(())
        } else {
            Err(DriverError(self))
        }
    }
}
impl DriverError {
    /// Returns the symbolic name of this error code, as reported by `cuGetErrorName`.
    ///
    /// # Errors
    /// Returns the driver's error if `cuGetErrorName` itself fails (per the CUDA docs,
    /// e.g. for an unrecognized code).
    pub fn error_name(&self) -> Result<&CStr, DriverError> {
        let mut text = MaybeUninit::uninit();
        unsafe {
            let status = sys::cuGetErrorName(self.0, text.as_mut_ptr());
            status.result()?;
            let raw = text.assume_init();
            Ok(CStr::from_ptr(raw))
        }
    }

    /// Returns the human-readable description of this error code, as reported by
    /// `cuGetErrorString`.
    ///
    /// # Errors
    /// Returns the driver's error if `cuGetErrorString` itself fails.
    pub fn error_string(&self) -> Result<&CStr, DriverError> {
        let mut text = MaybeUninit::uninit();
        unsafe {
            let status = sys::cuGetErrorString(self.0, text.as_mut_ptr());
            status.result()?;
            let raw = text.assume_init();
            Ok(CStr::from_ptr(raw))
        }
    }
}
impl std::fmt::Debug for DriverError {
    /// Formats as `DriverError(<code>, <description>)`.
    ///
    /// If the description cannot be fetched, a fixed placeholder string is used instead.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tuple = f.debug_tuple("DriverError");
        tuple.field(&self.0);
        match self.error_string() {
            Ok(description) => tuple.field(&description),
            Err(_) => tuple.field(&"<Failure when calling cuGetErrorString()>"),
        };
        tuple.finish()
    }
}
/// User-facing formatting; renders exactly the `Debug` representation.
#[cfg(feature = "std")]
impl std::fmt::Display for DriverError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Formats through a fresh `format_args!`, so outer flags such as `{:#}` are not
        // forwarded to the `Debug` impl.
        f.write_fmt(format_args!("{:?}", self))
    }
}
/// Lets `DriverError` be used with `?` / `Box<dyn Error>` when `std` is enabled.
/// No source error is chained; all defaults apply.
#[cfg(feature = "std")]
impl std::error::Error for DriverError {}
/// Initializes the CUDA driver API by calling `cuInit(0)`.
///
/// Per the CUDA driver API docs, this must succeed before other driver calls are made.
pub fn init() -> Result<(), DriverError> {
    let status = unsafe { sys::cuInit(0) };
    status.result()
}
/// Device enumeration and per-device property queries (`cuDevice*` driver calls).
pub mod device {
    use super::{
        sys::{self},
        DriverError,
    };
    use std::{
        ffi::{c_int, CStr},
        mem::MaybeUninit,
        string::String,
    };

    /// Returns the handle of the device with index `ordinal` (`cuDeviceGet`).
    pub fn get(ordinal: c_int) -> Result<sys::CUdevice, DriverError> {
        let mut dev = MaybeUninit::uninit();
        unsafe {
            sys::cuDeviceGet(dev.as_mut_ptr(), ordinal).result()?;
            Ok(dev.assume_init())
        }
    }

    /// Returns the number of devices reported by `cuDeviceGetCount`.
    pub fn get_count() -> Result<c_int, DriverError> {
        let mut count = MaybeUninit::uninit();
        unsafe {
            sys::cuDeviceGetCount(count.as_mut_ptr()).result()?;
            Ok(count.assume_init())
        }
    }

    /// Returns the total memory of `dev` in bytes (`cuDeviceTotalMem_v2`).
    ///
    /// # Safety
    /// `dev` must be a valid device handle, e.g. one returned by [`get`].
    pub unsafe fn total_mem(dev: sys::CUdevice) -> Result<usize, DriverError> {
        let mut bytes = MaybeUninit::uninit();
        sys::cuDeviceTotalMem_v2(bytes.as_mut_ptr(), dev).result()?;
        Ok(bytes.assume_init())
    }

    /// Returns the integer value of device attribute `attrib` for `dev`
    /// (`cuDeviceGetAttribute`).
    ///
    /// # Safety
    /// `dev` must be a valid device handle, e.g. one returned by [`get`].
    pub unsafe fn get_attribute(
        dev: sys::CUdevice,
        attrib: sys::CUdevice_attribute,
    ) -> Result<i32, DriverError> {
        let mut value = MaybeUninit::uninit();
        sys::cuDeviceGetAttribute(value.as_mut_ptr(), attrib, dev).result()?;
        Ok(value.assume_init())
    }

    /// Returns the name of `dev`, as reported by `cuDeviceGetName`.
    ///
    /// The name is read into a 128-byte buffer and decoded lossily (invalid UTF-8 becomes
    /// U+FFFD). If the buffer holds no NUL terminator, all 128 bytes are used rather than
    /// panicking.
    pub fn get_name(dev: sys::CUdevice) -> Result<String, DriverError> {
        const BUF_SIZE: usize = 128;
        let mut buf = [0u8; BUF_SIZE];
        unsafe {
            sys::cuDeviceGetName(buf.as_mut_ptr() as _, BUF_SIZE as _, dev).result()?;
        }
        // The driver is expected to NUL-terminate the name. If it ever fills the buffer
        // completely, use the whole buffer instead of aborting the process.
        let name = match CStr::from_bytes_until_nul(&buf) {
            Ok(name) => name.to_bytes(),
            Err(_) => &buf[..],
        };
        Ok(String::from_utf8_lossy(name).into())
    }

    /// Returns the UUID of `dev`.
    ///
    /// Calls `cuDeviceGetUuid_v2` when a CUDA 13 feature is enabled and `cuDeviceGetUuid`
    /// otherwise.
    pub fn get_uuid(dev: sys::CUdevice) -> Result<sys::CUuuid, DriverError> {
        let id: sys::CUuuid;
        unsafe {
            let mut uuid = MaybeUninit::uninit();
            #[cfg(not(any(
                feature = "cuda-13000",
                feature = "cuda-13010",
                feature = "cuda-13020"
            )))]
            sys::cuDeviceGetUuid(uuid.as_mut_ptr(), dev).result()?;
            #[cfg(any(feature = "cuda-13000", feature = "cuda-13010", feature = "cuda-13020"))]
            sys::cuDeviceGetUuid_v2(uuid.as_mut_ptr(), dev).result()?;
            id = uuid.assume_init();
        }
        Ok(id)
    }

    /// Returns the default memory pool of `dev` (`cuDeviceGetDefaultMemPool`).
    ///
    /// # Safety
    /// `dev` must be a valid device handle.
    pub unsafe fn get_default_mem_pool(
        dev: sys::CUdevice,
    ) -> Result<sys::CUmemoryPool, DriverError> {
        let mut pool = MaybeUninit::uninit();
        sys::cuDeviceGetDefaultMemPool(pool.as_mut_ptr(), dev).result()?;
        Ok(pool.assume_init())
    }

    /// Returns the memory pool currently associated with `dev` (`cuDeviceGetMemPool`).
    ///
    /// # Safety
    /// `dev` must be a valid device handle.
    pub unsafe fn get_mem_pool(dev: sys::CUdevice) -> Result<sys::CUmemoryPool, DriverError> {
        let mut pool = MaybeUninit::uninit();
        sys::cuDeviceGetMemPool(pool.as_mut_ptr(), dev).result()?;
        Ok(pool.assume_init())
    }

    /// Sets `pool` as the current memory pool of `dev` (`cuDeviceSetMemPool`).
    ///
    /// # Safety
    /// `dev` must be a valid device handle and `pool` a valid pool usable on that device.
    pub unsafe fn set_mem_pool(
        dev: sys::CUdevice,
        pool: sys::CUmemoryPool,
    ) -> Result<(), DriverError> {
        sys::cuDeviceSetMemPool(dev, pool).result()
    }
}
/// Queries and configuration of individual kernel functions (`CUfunction`).
pub mod function {
    use super::sys::{self, CUfunc_cache_enum, CUfunction_attribute_enum};
    use std::mem::MaybeUninit;

    /// Reads integer attribute `attribute` of kernel `f` (`cuFuncGetAttribute`).
    ///
    /// # Safety
    /// `f` must be a valid function handle, e.g. from `module::get_function`.
    pub unsafe fn get_function_attribute(
        f: sys::CUfunction,
        attribute: CUfunction_attribute_enum,
    ) -> Result<i32, super::DriverError> {
        let mut out = MaybeUninit::uninit();
        let status = unsafe { sys::cuFuncGetAttribute(out.as_mut_ptr(), attribute, f) };
        status.result()?;
        Ok(unsafe { out.assume_init() })
    }

    /// Sets integer attribute `attribute` of kernel `f` to `value` (`cuFuncSetAttribute`).
    ///
    /// # Safety
    /// `f` must be a valid function handle.
    pub unsafe fn set_function_attribute(
        f: sys::CUfunction,
        attribute: CUfunction_attribute_enum,
        value: i32,
    ) -> Result<(), super::DriverError> {
        let status = unsafe { sys::cuFuncSetAttribute(f, attribute, value) };
        status.result()
    }

    /// Sets the preferred cache configuration of kernel `f` (`cuFuncSetCacheConfig`).
    ///
    /// # Safety
    /// `f` must be a valid function handle.
    pub unsafe fn set_function_cache_config(
        f: sys::CUfunction,
        attribute: CUfunc_cache_enum,
    ) -> Result<(), super::DriverError> {
        let status = unsafe { sys::cuFuncSetCacheConfig(f, attribute) };
        status.result()
    }
}
/// Occupancy calculator wrappers (`cuOccupancy*` driver calls).
pub mod occupancy {
    use core::{
        ffi::{c_int, c_uint},
        mem::MaybeUninit,
    };

    use super::{
        sys::{self},
        DriverError,
    };

    /// Returns the dynamic shared memory (bytes) available per block when `num_blocks`
    /// blocks of `block_size` threads of `f` are resident
    /// (`cuOccupancyAvailableDynamicSMemPerBlock`).
    ///
    /// # Safety
    /// `f` must be a valid function handle.
    pub unsafe fn available_dynamic_shared_mem_per_block(
        f: sys::CUfunction,
        num_blocks: c_int,
        block_size: c_int,
    ) -> Result<usize, DriverError> {
        let mut smem_bytes = MaybeUninit::uninit();
        let status = unsafe {
            sys::cuOccupancyAvailableDynamicSMemPerBlock(
                smem_bytes.as_mut_ptr(),
                f,
                num_blocks,
                block_size,
            )
        };
        status.result()?;
        Ok(smem_bytes.assume_init())
    }

    /// Returns the maximum number of active blocks of `f` per multiprocessor
    /// (`cuOccupancyMaxActiveBlocksPerMultiprocessor`).
    ///
    /// # Safety
    /// `f` must be a valid function handle.
    pub unsafe fn max_active_block_per_multiprocessor(
        f: sys::CUfunction,
        block_size: c_int,
        dynamic_smem_size: usize,
    ) -> Result<i32, DriverError> {
        let mut blocks = MaybeUninit::uninit();
        let status = unsafe {
            sys::cuOccupancyMaxActiveBlocksPerMultiprocessor(
                blocks.as_mut_ptr(),
                f,
                block_size,
                dynamic_smem_size,
            )
        };
        status.result()?;
        Ok(blocks.assume_init())
    }

    /// Same as [`max_active_block_per_multiprocessor`] but forwards occupancy `flags`
    /// (`cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags`).
    ///
    /// # Safety
    /// `f` must be a valid function handle.
    pub unsafe fn max_active_block_per_multiprocessor_with_flags(
        f: sys::CUfunction,
        block_size: c_int,
        dynamic_smem_size: usize,
        flags: c_uint,
    ) -> Result<i32, DriverError> {
        let mut blocks = MaybeUninit::uninit();
        let status = unsafe {
            sys::cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
                blocks.as_mut_ptr(),
                f,
                block_size,
                dynamic_smem_size,
                flags,
            )
        };
        status.result()?;
        Ok(blocks.assume_init())
    }

    /// Suggests a launch configuration for `f` (`cuOccupancyMaxPotentialBlockSize`).
    ///
    /// Returns `(min_grid_size, block_size)` as written by the driver.
    ///
    /// # Safety
    /// `f` must be a valid function handle. If `block_size_to_dynamic_smem_size` is a
    /// callback, it must be safe for the driver to call.
    pub unsafe fn max_potential_block_size(
        f: sys::CUfunction,
        block_size_to_dynamic_smem_size: sys::CUoccupancyB2DSize,
        dynamic_smem_size: usize,
        block_size_limit: c_int,
    ) -> Result<(i32, i32), DriverError> {
        let mut grid_min = MaybeUninit::uninit();
        let mut block = MaybeUninit::uninit();
        let status = unsafe {
            sys::cuOccupancyMaxPotentialBlockSize(
                grid_min.as_mut_ptr(),
                block.as_mut_ptr(),
                f,
                block_size_to_dynamic_smem_size,
                dynamic_smem_size,
                block_size_limit,
            )
        };
        status.result()?;
        let suggestion = (grid_min.assume_init(), block.assume_init());
        Ok(suggestion)
    }

    /// Same as [`max_potential_block_size`] but forwards occupancy `flags`
    /// (`cuOccupancyMaxPotentialBlockSizeWithFlags`).
    ///
    /// # Safety
    /// Same requirements as [`max_potential_block_size`].
    pub unsafe fn max_potential_block_size_with_flags(
        f: sys::CUfunction,
        block_size_to_dynamic_smem_size: sys::CUoccupancyB2DSize,
        dynamic_smem_size: usize,
        block_size_limit: c_int,
        flags: c_uint,
    ) -> Result<(i32, i32), DriverError> {
        let mut grid_min = MaybeUninit::uninit();
        let mut block = MaybeUninit::uninit();
        let status = unsafe {
            sys::cuOccupancyMaxPotentialBlockSizeWithFlags(
                grid_min.as_mut_ptr(),
                block.as_mut_ptr(),
                f,
                block_size_to_dynamic_smem_size,
                dynamic_smem_size,
                block_size_limit,
                flags,
            )
        };
        status.result()?;
        let suggestion = (grid_min.assume_init(), block.assume_init());
        Ok(suggestion)
    }
}
/// Retain/release of a device's primary context.
pub mod primary_ctx {
    use super::{
        sys::{self},
        DriverError,
    };
    use std::mem::MaybeUninit;

    /// Retains the primary context of `dev` (`cuDevicePrimaryCtxRetain`) and returns it.
    ///
    /// # Safety
    /// `dev` must be a valid device handle. Each successful retain should be balanced by
    /// a call to [`release`].
    pub unsafe fn retain(dev: sys::CUdevice) -> Result<sys::CUcontext, DriverError> {
        let mut handle = MaybeUninit::uninit();
        let status = sys::cuDevicePrimaryCtxRetain(handle.as_mut_ptr(), dev);
        status.result()?;
        Ok(handle.assume_init())
    }

    /// Releases one retain of the primary context of `dev` (`cuDevicePrimaryCtxRelease_v2`).
    ///
    /// # Safety
    /// Must pair with an earlier successful [`retain`] on the same `dev`.
    pub unsafe fn release(dev: sys::CUdevice) -> Result<(), DriverError> {
        let status = sys::cuDevicePrimaryCtxRelease_v2(dev);
        status.result()
    }
}
/// Context creation, plus queries and settings on the calling thread's current context.
pub mod ctx {
    use super::{
        sys::{self},
        DriverError,
    };
    use std::mem::MaybeUninit;

    /// Creates a context on `dev` with `flags` via `cuCtxCreate_v3`.
    ///
    /// No execution-affinity parameters are passed (null array, count 0).
    ///
    /// # Safety
    /// `dev` must be a valid device handle.
    #[cfg(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000",
        feature = "cuda-12010",
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090"
    ))]
    pub unsafe fn create_v3(
        flags: ::core::ffi::c_uint,
        dev: sys::CUdevice,
    ) -> Result<sys::CUcontext, DriverError> {
        let mut ctx = MaybeUninit::uninit();
        sys::cuCtxCreate_v3(ctx.as_mut_ptr(), std::ptr::null_mut(), 0, flags, dev).result()?;
        Ok(ctx.assume_init())
    }

    /// Creates a context on `dev` via `cuCtxCreate_v4`.
    ///
    /// Available for CUDA 12.5+ feature sets, including every CUDA 13 feature this file
    /// supports (`cuda-13000`, `cuda-13010`, `cuda-13020`). Without it, 13.x builds would
    /// have no context-creation wrapper, because `create_v3` is not compiled for them.
    ///
    /// # Safety
    /// `dev` must be a valid device handle, and `ctx_create_params` must be a pointer that
    /// `cuCtxCreate_v4` accepts (see the CUDA driver API docs).
    #[cfg(any(
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
        feature = "cuda-13000",
        feature = "cuda-13010",
        feature = "cuda-13020"
    ))]
    pub unsafe fn create_v4(
        ctx_create_params: *mut sys::CUctxCreateParams,
        flags: ::core::ffi::c_uint,
        dev: sys::CUdevice,
    ) -> Result<sys::CUcontext, DriverError> {
        let mut ctx = MaybeUninit::uninit();
        sys::cuCtxCreate_v4(ctx.as_mut_ptr(), ctx_create_params, flags, dev).result()?;
        Ok(ctx.assume_init())
    }

    /// Binds `ctx` to the calling thread (`cuCtxSetCurrent`).
    ///
    /// # Safety
    /// `ctx` must be a valid context handle (or null to unbind, per the CUDA docs).
    pub unsafe fn set_current(ctx: sys::CUcontext) -> Result<(), DriverError> {
        sys::cuCtxSetCurrent(ctx).result()
    }

    /// Returns the calling thread's current context, or `None` if the driver reports a
    /// null handle (`cuCtxGetCurrent`).
    pub fn get_current() -> Result<Option<sys::CUcontext>, DriverError> {
        let mut ctx = MaybeUninit::uninit();
        unsafe {
            sys::cuCtxGetCurrent(ctx.as_mut_ptr()).result()?;
            let ctx: sys::CUcontext = ctx.assume_init();
            if ctx.is_null() {
                Ok(None)
            } else {
                Ok(Some(ctx))
            }
        }
    }

    /// Sets `flags` on the current context (`cuCtxSetFlags`).
    ///
    /// Not compiled for the CUDA 11.x and 12.0 feature sets.
    #[cfg(not(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000"
    )))]
    pub fn set_flags(flags: sys::CUctx_flags) -> Result<(), DriverError> {
        unsafe { sys::cuCtxSetFlags(flags as u32).result() }
    }

    /// Blocks until the current context has completed all preceding work
    /// (`cuCtxSynchronize`).
    pub fn synchronize() -> Result<(), DriverError> {
        unsafe { sys::cuCtxSynchronize() }.result()
    }

    /// Returns the value of resource limit `limit` for the current context
    /// (`cuCtxGetLimit`).
    pub fn get_limit(limit: sys::CUlimit) -> Result<usize, DriverError> {
        let mut value = MaybeUninit::uninit();
        unsafe {
            sys::cuCtxGetLimit(value.as_mut_ptr(), limit).result()?;
            Ok(value.assume_init())
        }
    }

    /// Sets resource limit `limit` of the current context to `value` (`cuCtxSetLimit`).
    pub fn set_limit(limit: sys::CUlimit, value: usize) -> Result<(), DriverError> {
        unsafe { sys::cuCtxSetLimit(limit, value).result() }
    }

    /// Returns the current context's preferred cache configuration (`cuCtxGetCacheConfig`).
    pub fn get_cache_config() -> Result<sys::CUfunc_cache, DriverError> {
        let mut config = MaybeUninit::uninit();
        unsafe {
            sys::cuCtxGetCacheConfig(config.as_mut_ptr()).result()?;
            Ok(config.assume_init())
        }
    }

    /// Sets the current context's preferred cache configuration (`cuCtxSetCacheConfig`).
    pub fn set_cache_config(config: sys::CUfunc_cache) -> Result<(), DriverError> {
        unsafe { sys::cuCtxSetCacheConfig(config).result() }
    }
}
/// Stream creation, synchronization, event waits and graph capture.
pub mod stream {
    use super::{
        sys::{self},
        DriverError,
    };
    use std::mem::MaybeUninit;

    /// Kind of stream to create with [`create`].
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum StreamKind {
        /// Created with `CU_STREAM_DEFAULT`.
        Default,
        /// Created with `CU_STREAM_NON_BLOCKING`.
        NonBlocking,
    }

    impl StreamKind {
        /// Maps the kind to the driver's stream-creation flag.
        fn flags(self) -> sys::CUstream_flags {
            match self {
                Self::Default => sys::CUstream_flags::CU_STREAM_DEFAULT,
                Self::NonBlocking => sys::CUstream_flags::CU_STREAM_NON_BLOCKING,
            }
        }
    }

    /// Returns the null stream handle.
    pub fn null() -> sys::CUstream {
        std::ptr::null_mut()
    }

    /// Creates a new stream of the given `kind` in the current context (`cuStreamCreate`).
    pub fn create(kind: StreamKind) -> Result<sys::CUstream, DriverError> {
        let mut stream = MaybeUninit::uninit();
        unsafe {
            sys::cuStreamCreate(stream.as_mut_ptr(), kind.flags() as u32).result()?;
            Ok(stream.assume_init())
        }
    }

    /// Blocks until all work queued on `stream` has completed (`cuStreamSynchronize`).
    ///
    /// # Safety
    /// `stream` must be a valid stream handle (or the null stream).
    pub unsafe fn synchronize(stream: sys::CUstream) -> Result<(), DriverError> {
        sys::cuStreamSynchronize(stream).result()
    }

    /// Destroys `stream` (`cuStreamDestroy_v2`).
    ///
    /// # Safety
    /// `stream` must be a stream created by [`create`] that has not already been destroyed.
    pub unsafe fn destroy(stream: sys::CUstream) -> Result<(), DriverError> {
        sys::cuStreamDestroy_v2(stream).result()
    }

    /// Makes future work on `stream` wait for `event` (`cuStreamWaitEvent`).
    ///
    /// # Safety
    /// `stream` and `event` must be valid handles.
    pub unsafe fn wait_event(
        stream: sys::CUstream,
        event: sys::CUevent,
        flags: sys::CUevent_wait_flags,
    ) -> Result<(), DriverError> {
        sys::cuStreamWaitEvent(stream, event, flags as u32).result()
    }

    /// Attaches the managed range `[dptr, dptr + num_bytes)` to `stream`
    /// (`cuStreamAttachMemAsync`).
    ///
    /// # Safety
    /// `stream` must be valid, and the range must satisfy the requirements of
    /// `cuStreamAttachMemAsync` (see the CUDA docs).
    pub unsafe fn attach_mem_async(
        stream: sys::CUstream,
        dptr: sys::CUdeviceptr,
        num_bytes: usize,
        flags: sys::CUmemAttach_flags,
    ) -> Result<(), DriverError> {
        sys::cuStreamAttachMemAsync(stream, dptr, num_bytes, flags as u32).result()
    }

    /// Enqueues host callback `func(arg)` on `stream` (`cuLaunchHostFunc`).
    ///
    /// # Safety
    /// `arg` must stay valid until `func` runs, and `func` must be sound to call with it.
    pub unsafe fn launch_host_function(
        stream: sys::CUstream,
        func: unsafe extern "C" fn(*mut ::core::ffi::c_void),
        arg: *mut std::ffi::c_void,
    ) -> Result<(), DriverError> {
        sys::cuLaunchHostFunc(stream, Some(func), arg).result()
    }

    /// Begins graph capture on `stream` in the given `mode` (`cuStreamBeginCapture_v2`).
    ///
    /// # Safety
    /// `stream` must be a valid stream handle.
    pub unsafe fn begin_capture(
        stream: sys::CUstream,
        mode: sys::CUstreamCaptureMode,
    ) -> Result<(), DriverError> {
        sys::cuStreamBeginCapture_v2(stream, mode).result()
    }

    /// Ends capture on `stream` and returns the captured graph (`cuStreamEndCapture`).
    ///
    /// # Safety
    /// `stream` must be a valid stream that is currently capturing.
    pub unsafe fn end_capture(stream: sys::CUstream) -> Result<sys::CUgraph, DriverError> {
        let mut graph = MaybeUninit::uninit();
        sys::cuStreamEndCapture(stream, graph.as_mut_ptr()).result()?;
        Ok(graph.assume_init())
    }

    /// Returns the capture status of `stream` (`cuStreamIsCapturing`).
    ///
    /// # Safety
    /// `stream` must be a valid stream handle.
    pub unsafe fn is_capturing(
        stream: sys::CUstream,
    ) -> Result<sys::CUstreamCaptureStatus, DriverError> {
        let mut status = MaybeUninit::uninit();
        sys::cuStreamIsCapturing(stream, status.as_mut_ptr()).result()?;
        Ok(status.assume_init())
    }
}
/// Allocates `num_bytes` of device memory, stream-ordered on `stream` (`cuMemAllocAsync`).
///
/// # Safety
/// `stream` must be a valid stream handle. The returned pointer must only be used by work
/// ordered after the allocation, and must be freed exactly once (e.g. with [`free_async`]).
pub unsafe fn malloc_async(
    stream: sys::CUstream,
    num_bytes: usize,
) -> Result<sys::CUdeviceptr, DriverError> {
    let mut out = MaybeUninit::<sys::CUdeviceptr>::uninit();
    let status = sys::cuMemAllocAsync(out.as_mut_ptr(), num_bytes, stream);
    status.result()?;
    Ok(out.assume_init())
}
/// Allocates `num_bytes` of device memory in the current context (`cuMemAlloc_v2`).
///
/// # Safety
/// A context must be current. The returned pointer must be freed exactly once (e.g. with
/// [`free_sync`]).
pub unsafe fn malloc_sync(num_bytes: usize) -> Result<sys::CUdeviceptr, DriverError> {
    let mut out = MaybeUninit::<sys::CUdeviceptr>::uninit();
    let status = sys::cuMemAlloc_v2(out.as_mut_ptr(), num_bytes);
    status.result()?;
    Ok(out.assume_init())
}
/// Allocates `num_bytes` of managed memory with attach `flags` (`cuMemAllocManaged`).
///
/// # Safety
/// A context must be current. The returned pointer must be freed exactly once.
pub unsafe fn malloc_managed(
    num_bytes: usize,
    flags: sys::CUmemAttach_flags,
) -> Result<sys::CUdeviceptr, DriverError> {
    let raw_flags = flags as u32;
    let mut out = MaybeUninit::<sys::CUdeviceptr>::uninit();
    let status = sys::cuMemAllocManaged(out.as_mut_ptr(), num_bytes, raw_flags);
    status.result()?;
    Ok(out.assume_init())
}
/// Allocates `num_bytes` of driver-managed host memory with `flags` (`cuMemHostAlloc`).
///
/// # Safety
/// A context must be current. The returned pointer must be released with [`free_host`].
pub unsafe fn malloc_host(num_bytes: usize, flags: c_uint) -> Result<*mut c_void, DriverError> {
    let mut out = MaybeUninit::<*mut c_void>::uninit();
    let status = sys::cuMemHostAlloc(out.as_mut_ptr(), num_bytes, flags);
    status.result()?;
    Ok(out.assume_init())
}
/// Frees host memory obtained from [`malloc_host`] (`cuMemFreeHost`).
///
/// # Safety
/// `host_ptr` must come from [`malloc_host`] and must not already have been freed.
pub unsafe fn free_host(host_ptr: *mut c_void) -> Result<(), DriverError> {
    let status = sys::cuMemFreeHost(host_ptr);
    status.result()
}
/// Applies memory-usage `advice` for `location` to the managed range
/// `[dptr, dptr + num_bytes)` (`cuMemAdvise_v2`).
///
/// Not compiled for the CUDA 11.x, 12.0 and 12.1 feature sets.
///
/// # Safety
/// The range must be valid managed memory, as required by `cuMemAdvise_v2`.
#[cfg(not(any(
    feature = "cuda-11040",
    feature = "cuda-11050",
    feature = "cuda-11060",
    feature = "cuda-11070",
    feature = "cuda-11080",
    feature = "cuda-12000",
    feature = "cuda-12010"
)))]
pub unsafe fn mem_advise(
    dptr: sys::CUdeviceptr,
    num_bytes: usize,
    advice: sys::CUmem_advise,
    location: sys::CUmemLocation,
) -> Result<(), DriverError> {
    let status = sys::cuMemAdvise_v2(dptr, num_bytes, advice, location);
    status.result()
}
/// Prefetches the managed range `[dptr, dptr + num_bytes)` to `location`, ordered on
/// `stream` (`cuMemPrefetchAsync_v2`, flags = 0).
///
/// Not compiled for the CUDA 11.x, 12.0 and 12.1 feature sets.
///
/// # Safety
/// The range must be valid managed memory and `stream` a valid stream handle.
#[cfg(not(any(
    feature = "cuda-11040",
    feature = "cuda-11050",
    feature = "cuda-11060",
    feature = "cuda-11070",
    feature = "cuda-11080",
    feature = "cuda-12000",
    feature = "cuda-12010"
)))]
pub unsafe fn mem_prefetch_async(
    dptr: sys::CUdeviceptr,
    num_bytes: usize,
    location: sys::CUmemLocation,
    stream: sys::CUstream,
) -> Result<(), DriverError> {
    // No prefetch flags are exposed by this wrapper.
    const NO_FLAGS: u32 = 0;
    let status = sys::cuMemPrefetchAsync_v2(dptr, num_bytes, location, NO_FLAGS as _, stream);
    status.result()
}
/// Frees `dptr`, stream-ordered on `stream` (`cuMemFreeAsync`).
///
/// # Safety
/// `dptr` must be a live driver allocation that no later work on any stream will touch,
/// and it must not be freed again.
pub unsafe fn free_async(dptr: sys::CUdeviceptr, stream: sys::CUstream) -> Result<(), DriverError> {
    let status = sys::cuMemFreeAsync(dptr, stream);
    status.result()
}
/// Frees device memory `dptr` (`cuMemFree_v2`).
///
/// # Safety
/// `dptr` must be a live driver allocation (e.g. from [`malloc_sync`]) that is not
/// freed again.
pub unsafe fn free_sync(dptr: sys::CUdeviceptr) -> Result<(), DriverError> {
    let status = sys::cuMemFree_v2(dptr);
    status.result()
}
pub unsafe fn memory_free(device_ptr: sys::CUdeviceptr) -> Result<(), DriverError> {
sys::cuMemFree_v2(device_ptr).result()
}
/// Sets `num_bytes` bytes starting at `dptr` to `uc`, ordered on `stream`
/// (`cuMemsetD8Async`).
///
/// # Safety
/// `[dptr, dptr + num_bytes)` must be valid device memory and `stream` a valid stream.
pub unsafe fn memset_d8_async(
    dptr: sys::CUdeviceptr,
    uc: c_uchar,
    num_bytes: usize,
    stream: sys::CUstream,
) -> Result<(), DriverError> {
    let status = sys::cuMemsetD8Async(dptr, uc, num_bytes, stream);
    status.result()
}
/// Sets `num_bytes` bytes starting at `dptr` to `uc` (`cuMemsetD8_v2`).
///
/// # Safety
/// `[dptr, dptr + num_bytes)` must be valid device memory.
pub unsafe fn memset_d8_sync(
    dptr: sys::CUdeviceptr,
    uc: c_uchar,
    num_bytes: usize,
) -> Result<(), DriverError> {
    let status = sys::cuMemsetD8_v2(dptr, uc, num_bytes);
    status.result()
}
/// Enqueues a copy of the whole host slice `src` to device address `dst` on `stream`
/// (`cuMemcpyHtoDAsync_v2`).
///
/// # Safety
/// `dst` must have room for the `src.len() * size_of::<T>()` bytes copied. Because the
/// copy is enqueued on `stream`, `src` must stay alive and unmodified until that work
/// completes.
pub unsafe fn memcpy_htod_async<T>(
    dst: sys::CUdeviceptr,
    src: &[T],
    stream: sys::CUstream,
) -> Result<(), DriverError> {
    let byte_count = src.len() * std::mem::size_of::<T>();
    let host_ptr = src.as_ptr().cast();
    let status = sys::cuMemcpyHtoDAsync_v2(dst, host_ptr, byte_count, stream);
    status.result()
}
/// Copies the whole host slice `src` to device address `dst` (`cuMemcpyHtoD_v2`).
///
/// # Safety
/// `dst` must have room for the `src.len() * size_of::<T>()` bytes copied.
pub unsafe fn memcpy_htod_sync<T>(dst: sys::CUdeviceptr, src: &[T]) -> Result<(), DriverError> {
    let byte_count = src.len() * std::mem::size_of::<T>();
    let host_ptr = src.as_ptr().cast();
    let status = sys::cuMemcpyHtoD_v2(dst, host_ptr, byte_count);
    status.result()
}
/// Enqueues a copy from device address `src` into the whole host slice `dst` on `stream`
/// (`cuMemcpyDtoHAsync_v2`).
///
/// # Safety
/// `src` must hold at least `dst.len() * size_of::<T>()` readable bytes, and those bytes
/// must be a valid `T` bit pattern. Because the copy is enqueued on `stream`, `dst` must
/// stay alive and otherwise unaccessed until that work completes.
pub unsafe fn memcpy_dtoh_async<T>(
    dst: &mut [T],
    src: sys::CUdeviceptr,
    stream: sys::CUstream,
) -> Result<(), DriverError> {
    let byte_count = dst.len() * std::mem::size_of::<T>();
    let host_ptr = dst.as_mut_ptr().cast();
    let status = sys::cuMemcpyDtoHAsync_v2(host_ptr, src, byte_count, stream);
    status.result()
}
/// Copies from device address `src` into the whole host slice `dst` (`cuMemcpyDtoH_v2`).
///
/// # Safety
/// `src` must hold at least `dst.len() * size_of::<T>()` readable bytes, and those bytes
/// must be a valid `T` bit pattern.
pub unsafe fn memcpy_dtoh_sync<T>(dst: &mut [T], src: sys::CUdeviceptr) -> Result<(), DriverError> {
    let byte_count = dst.len() * std::mem::size_of::<T>();
    let host_ptr = dst.as_mut_ptr().cast();
    let status = sys::cuMemcpyDtoH_v2(host_ptr, src, byte_count);
    status.result()
}
/// Enqueues a device-to-device copy of `num_bytes` from `src` to `dst` on `stream`
/// (`cuMemcpyDtoDAsync_v2`).
///
/// # Safety
/// Both ranges must be valid device memory of at least `num_bytes`, and `stream` must be
/// a valid stream handle.
pub unsafe fn memcpy_dtod_async(
    dst: sys::CUdeviceptr,
    src: sys::CUdeviceptr,
    num_bytes: usize,
    stream: sys::CUstream,
) -> Result<(), DriverError> {
    let status = sys::cuMemcpyDtoDAsync_v2(dst, src, num_bytes, stream);
    status.result()
}
/// Copies `num_bytes` from device address `src` to device address `dst`
/// (`cuMemcpyDtoD_v2`).
///
/// # Safety
/// Both ranges must be valid device memory of at least `num_bytes`.
pub unsafe fn memcpy_dtod_sync(
    dst: sys::CUdeviceptr,
    src: sys::CUdeviceptr,
    num_bytes: usize,
) -> Result<(), DriverError> {
    let status = sys::cuMemcpyDtoD_v2(dst, src, num_bytes);
    status.result()
}
/// Enqueues a copy of `num_bytes` from `src` (owned by `src_ctx`) to `dst` (owned by
/// `dst_ctx`) on `stream` (`cuMemcpyPeerAsync`).
///
/// Note that this wrapper takes each context *before* its pointer, while the driver call
/// takes each pointer before its context; the arguments are reordered below.
///
/// # Safety
/// Both contexts and ranges must be valid, and `stream` must be a valid stream handle.
pub unsafe fn memcpy_peer_async(
    dst_ctx: sys::CUcontext,
    dst: sys::CUdeviceptr,
    src_ctx: sys::CUcontext,
    src: sys::CUdeviceptr,
    num_bytes: usize,
    stream: sys::CUstream,
) -> Result<(), DriverError> {
    let status = sys::cuMemcpyPeerAsync(dst, dst_ctx, src, src_ctx, num_bytes, stream);
    status.result()
}
/// Returns `(free, total)` device memory in bytes for the current context
/// (`cuMemGetInfo_v2`).
pub fn mem_get_info() -> Result<(usize, usize), DriverError> {
    let mut free_bytes: usize = 0;
    let mut total_bytes: usize = 0;
    let status = unsafe { sys::cuMemGetInfo_v2(&mut free_bytes, &mut total_bytes) };
    status.result()?;
    Ok((free_bytes, total_bytes))
}
/// Loading and unloading modules and looking up their functions and globals.
pub mod module {
    use super::{
        sys::{self},
        DriverError,
    };
    use core::ffi::c_void;
    use std::ffi::CString;
    use std::mem::MaybeUninit;

    /// Loads a module from the file named by `fname` (`cuModuleLoad`).
    pub fn load(fname: CString) -> Result<sys::CUmodule, DriverError> {
        let mut loaded = MaybeUninit::uninit();
        // `fname` is owned here and outlives the call, so its pointer stays valid.
        let status = unsafe { sys::cuModuleLoad(loaded.as_mut_ptr(), fname.as_ptr()) };
        status.result()?;
        Ok(unsafe { loaded.assume_init() })
    }

    /// Loads a module from an in-memory `image` (`cuModuleLoadData`).
    ///
    /// # Safety
    /// `image` must point to a module image in a format accepted by `cuModuleLoadData`.
    pub unsafe fn load_data(image: *const c_void) -> Result<sys::CUmodule, DriverError> {
        let mut loaded = MaybeUninit::uninit();
        let status = sys::cuModuleLoadData(loaded.as_mut_ptr(), image);
        status.result()?;
        Ok(loaded.assume_init())
    }

    /// Looks up kernel `name` in `module` (`cuModuleGetFunction`).
    ///
    /// # Safety
    /// `module` must be a valid, loaded module handle.
    pub unsafe fn get_function(
        module: sys::CUmodule,
        name: CString,
    ) -> Result<sys::CUfunction, DriverError> {
        let mut handle = MaybeUninit::uninit();
        let status = sys::cuModuleGetFunction(handle.as_mut_ptr(), module, name.as_ptr());
        status.result()?;
        Ok(handle.assume_init())
    }

    /// Looks up global `name` in `module` and returns `(device pointer, size in bytes)`
    /// (`cuModuleGetGlobal_v2`).
    ///
    /// # Safety
    /// `module` must be a valid, loaded module handle.
    pub unsafe fn get_global(
        module: sys::CUmodule,
        name: CString,
    ) -> Result<(sys::CUdeviceptr, usize), DriverError> {
        let mut address = MaybeUninit::uninit();
        let mut size = MaybeUninit::uninit();
        let status = sys::cuModuleGetGlobal_v2(
            address.as_mut_ptr(),
            size.as_mut_ptr(),
            module,
            name.as_ptr(),
        );
        status.result()?;
        Ok((address.assume_init(), size.assume_init()))
    }

    /// Unloads `module` (`cuModuleUnload`).
    ///
    /// # Safety
    /// `module` must be a loaded module that is not used (or unloaded) afterwards.
    pub unsafe fn unload(module: sys::CUmodule) -> Result<(), DriverError> {
        let status = sys::cuModuleUnload(module);
        status.result()
    }
}
/// Event creation, recording, timing and synchronization.
pub mod event {
    use super::{
        sys::{self},
        DriverError,
    };
    use std::mem::MaybeUninit;

    /// Creates an event with `flags` (`cuEventCreate`).
    pub fn create(flags: sys::CUevent_flags) -> Result<sys::CUevent, DriverError> {
        let mut handle = MaybeUninit::uninit();
        let status = unsafe { sys::cuEventCreate(handle.as_mut_ptr(), flags as u32) };
        status.result()?;
        Ok(unsafe { handle.assume_init() })
    }

    /// Records `event` on `stream` (`cuEventRecord`).
    ///
    /// # Safety
    /// `event` and `stream` must be valid handles.
    pub unsafe fn record(event: sys::CUevent, stream: sys::CUstream) -> Result<(), DriverError> {
        let status = unsafe { sys::cuEventRecord(event, stream) };
        status.result()
    }

    /// Returns the elapsed time in milliseconds between `start` and `end`.
    ///
    /// Uses `cuEventElapsedTime_v2` when a CUDA 13 feature is enabled and
    /// `cuEventElapsedTime` otherwise.
    ///
    /// # Safety
    /// `start` and `end` must be valid events.
    pub unsafe fn elapsed(start: sys::CUevent, end: sys::CUevent) -> Result<f32, DriverError> {
        let mut millis: f32 = 0.0;
        let out: *mut f32 = &mut millis;
        #[cfg(not(any(
            feature = "cuda-13000",
            feature = "cuda-13010",
            feature = "cuda-13020"
        )))]
        let status = unsafe { sys::cuEventElapsedTime(out, start, end) };
        #[cfg(any(feature = "cuda-13000", feature = "cuda-13010", feature = "cuda-13020"))]
        let status = unsafe { sys::cuEventElapsedTime_v2(out, start, end) };
        status.result()?;
        Ok(millis)
    }

    /// Queries whether `event` has completed (`cuEventQuery`).
    ///
    /// Any non-success status, including "not ready", is returned as an `Err`.
    ///
    /// # Safety
    /// `event` must be a valid event handle.
    pub unsafe fn query(event: sys::CUevent) -> Result<(), DriverError> {
        let status = unsafe { sys::cuEventQuery(event) };
        status.result()
    }

    /// Blocks until `event` has completed (`cuEventSynchronize`).
    ///
    /// # Safety
    /// `event` must be a valid event handle.
    pub unsafe fn synchronize(event: sys::CUevent) -> Result<(), DriverError> {
        let status = unsafe { sys::cuEventSynchronize(event) };
        status.result()
    }

    /// Destroys `event` (`cuEventDestroy_v2`).
    ///
    /// # Safety
    /// `event` must be a live event that is not used afterwards.
    pub unsafe fn destroy(event: sys::CUevent) -> Result<(), DriverError> {
        let status = sys::cuEventDestroy_v2(event);
        status.result()
    }
}
/// Launches kernel `f` on `stream` with the given grid and block dimensions and dynamic
/// shared memory (`cuLaunchKernel`).
///
/// Arguments are passed only through `kernel_params`; the driver's `extra` argument is null.
///
/// # Safety
/// `f` must be a valid function handle and `stream` a valid stream. `kernel_params` must
/// hold one pointer per kernel parameter, each pointing to a value of the matching type.
#[inline]
pub unsafe fn launch_kernel(
    f: sys::CUfunction,
    grid_dim: (c_uint, c_uint, c_uint),
    block_dim: (c_uint, c_uint, c_uint),
    shared_mem_bytes: c_uint,
    stream: sys::CUstream,
    kernel_params: &mut [*mut c_void],
) -> Result<(), DriverError> {
    let (grid_x, grid_y, grid_z) = grid_dim;
    let (block_x, block_y, block_z) = block_dim;
    let params = kernel_params.as_mut_ptr();
    let extra = std::ptr::null_mut();
    let status = sys::cuLaunchKernel(
        f,
        grid_x,
        grid_y,
        grid_z,
        block_x,
        block_y,
        block_z,
        shared_mem_bytes,
        stream,
        params,
        extra,
    );
    status.result()
}
/// Launches kernel `f` as a cooperative kernel on `stream` (`cuLaunchCooperativeKernel`).
///
/// # Safety
/// Same requirements as [`launch_kernel`]. In addition, the launch must meet the
/// cooperative-launch constraints described in the CUDA docs.
#[inline]
pub unsafe fn launch_cooperative_kernel(
    f: sys::CUfunction,
    grid_dim: (c_uint, c_uint, c_uint),
    block_dim: (c_uint, c_uint, c_uint),
    shared_mem_bytes: c_uint,
    stream: sys::CUstream,
    kernel_params: &mut [*mut c_void],
) -> Result<(), DriverError> {
    let (grid_x, grid_y, grid_z) = grid_dim;
    let (block_x, block_y, block_z) = block_dim;
    let params = kernel_params.as_mut_ptr();
    let status = sys::cuLaunchCooperativeKernel(
        f,
        grid_x,
        grid_y,
        grid_z,
        block_x,
        block_y,
        block_z,
        shared_mem_bytes,
        stream,
        params,
    );
    status.result()
}
/// Importing externally allocated memory and mapping it as device buffers.
pub mod external_memory {
    use std::mem::MaybeUninit;

    use super::{
        sys::{self},
        DriverError,
    };

    /// Imports `size` bytes of memory exported as an opaque POSIX file descriptor
    /// (`cuImportExternalMemory`, `CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD`).
    ///
    /// # Safety
    /// `fd` must be a valid exported handle describing at least `size` bytes. Per the CUDA
    /// docs, ownership of the descriptor passes to the driver on success.
    #[cfg(unix)]
    pub unsafe fn import_external_memory_opaque_fd(
        fd: std::os::fd::RawFd,
        size: u64,
    ) -> Result<sys::CUexternalMemory, DriverError> {
        let desc = sys::CUDA_EXTERNAL_MEMORY_HANDLE_DESC {
            type_: sys::CUexternalMemoryHandleType::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD,
            handle: sys::CUDA_EXTERNAL_MEMORY_HANDLE_DESC_st__bindgen_ty_1 { fd },
            size,
            flags: 0,
            reserved: [0; 16],
        };
        let mut imported = MaybeUninit::uninit();
        let status = sys::cuImportExternalMemory(imported.as_mut_ptr(), &desc);
        status.result()?;
        Ok(imported.assume_init())
    }

    /// Imports `size` bytes of memory exported as an opaque Win32 handle, with no name
    /// (`cuImportExternalMemory`, `CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32`).
    ///
    /// # Safety
    /// `handle` must be a valid exported handle describing at least `size` bytes.
    #[cfg(windows)]
    pub unsafe fn import_external_memory_opaque_win32(
        handle: std::os::windows::io::RawHandle,
        size: u64,
    ) -> Result<sys::CUexternalMemory, DriverError> {
        let win32 = sys::CUDA_EXTERNAL_MEMORY_HANDLE_DESC_st__bindgen_ty_1__bindgen_ty_1 {
            handle,
            name: std::ptr::null(),
        };
        let desc = sys::CUDA_EXTERNAL_MEMORY_HANDLE_DESC {
            type_: sys::CUexternalMemoryHandleType::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32,
            handle: sys::CUDA_EXTERNAL_MEMORY_HANDLE_DESC_st__bindgen_ty_1 { win32 },
            size,
            flags: 0,
            reserved: [0; 16],
        };
        let mut imported = MaybeUninit::uninit();
        let status = sys::cuImportExternalMemory(imported.as_mut_ptr(), &desc);
        status.result()?;
        Ok(imported.assume_init())
    }

    /// Destroys an imported external-memory object (`cuDestroyExternalMemory`).
    ///
    /// # Safety
    /// `external_memory` must be a live imported object that is not used afterwards.
    pub unsafe fn destroy_external_memory(
        external_memory: sys::CUexternalMemory,
    ) -> Result<(), DriverError> {
        let status = sys::cuDestroyExternalMemory(external_memory);
        status.result()
    }

    /// Maps the `[offset, offset + size)` range of `external_memory` as a device buffer
    /// (`cuExternalMemoryGetMappedBuffer`).
    ///
    /// # Safety
    /// `external_memory` must be a live imported object, and the range must lie within it.
    pub unsafe fn get_mapped_buffer(
        external_memory: sys::CUexternalMemory,
        offset: u64,
        size: u64,
    ) -> Result<sys::CUdeviceptr, DriverError> {
        let desc = sys::CUDA_EXTERNAL_MEMORY_BUFFER_DESC {
            offset,
            size,
            flags: 0,
            reserved: [0; 16],
        };
        let mut mapped = MaybeUninit::uninit();
        let status = sys::cuExternalMemoryGetMappedBuffer(mapped.as_mut_ptr(), external_memory, &desc);
        status.result()?;
        Ok(mapped.assume_init())
    }
}
/// Instantiating, launching and destroying CUDA graphs.
pub mod graph {
    use super::*;

    /// Instantiates `graph` into an executable graph (`cuGraphInstantiateWithFlags`).
    ///
    /// # Safety
    /// `graph` must be a valid graph handle.
    pub unsafe fn instantiate(
        graph: sys::CUgraph,
        flags: sys::CUgraphInstantiate_flags,
    ) -> Result<sys::CUgraphExec, DriverError> {
        let raw_flags = flags as u32 as u64;
        let mut exec = MaybeUninit::uninit();
        let status = sys::cuGraphInstantiateWithFlags(exec.as_mut_ptr(), graph, raw_flags);
        status.result()?;
        Ok(exec.assume_init())
    }

    /// Destroys an executable graph (`cuGraphExecDestroy`).
    ///
    /// # Safety
    /// `graph_exec` must be live and must not be used afterwards.
    pub unsafe fn exec_destroy(graph_exec: sys::CUgraphExec) -> Result<(), DriverError> {
        let status = sys::cuGraphExecDestroy(graph_exec);
        status.result()
    }

    /// Destroys a graph (`cuGraphDestroy`).
    ///
    /// # Safety
    /// `graph` must be live and must not be used afterwards.
    pub unsafe fn destroy(graph: sys::CUgraph) -> Result<(), DriverError> {
        let status = sys::cuGraphDestroy(graph);
        status.result()
    }

    /// Launches `graph_exec` on `stream` (`cuGraphLaunch`).
    ///
    /// # Safety
    /// Both handles must be valid.
    pub unsafe fn launch(
        graph_exec: sys::CUgraphExec,
        stream: sys::CUstream,
    ) -> Result<(), DriverError> {
        let status = sys::cuGraphLaunch(graph_exec, stream);
        status.result()
    }

    /// Uploads `graph_exec` to the device, ordered on `stream` (`cuGraphUpload`).
    ///
    /// # Safety
    /// Both handles must be valid.
    pub unsafe fn upload(
        graph_exec: sys::CUgraphExec,
        stream: sys::CUstream,
    ) -> Result<(), DriverError> {
        let status = sys::cuGraphUpload(graph_exec, stream);
        status.result()
    }
}
/// Creation, configuration and allocation from CUDA memory pools.
pub mod mem_pool {
    use super::*;

    /// Creates a memory pool described by `pool_props` (`cuMemPoolCreate`).
    ///
    /// # Safety
    /// `pool_props` must point to a valid, initialized `CUmemPoolProps`.
    pub unsafe fn create(
        pool_props: *const sys::CUmemPoolProps,
    ) -> Result<sys::CUmemoryPool, DriverError> {
        let mut handle = MaybeUninit::uninit();
        let status = sys::cuMemPoolCreate(handle.as_mut_ptr(), pool_props);
        status.result()?;
        Ok(handle.assume_init())
    }

    /// Destroys `pool` (`cuMemPoolDestroy`).
    ///
    /// # Safety
    /// `pool` must be a pool created by [`create`] that is not used afterwards.
    pub unsafe fn destroy(pool: sys::CUmemoryPool) -> Result<(), DriverError> {
        let status = sys::cuMemPoolDestroy(pool);
        status.result()
    }

    /// Releases cached memory held by `pool`, keeping at least `min_bytes_to_keep`
    /// reserved (`cuMemPoolTrimTo`).
    ///
    /// # Safety
    /// `pool` must be a valid pool handle.
    pub unsafe fn trim_to(
        pool: sys::CUmemoryPool,
        min_bytes_to_keep: usize,
    ) -> Result<(), DriverError> {
        let status = sys::cuMemPoolTrimTo(pool, min_bytes_to_keep);
        status.result()
    }

    /// Reads attribute `attr` of `pool` into `value` (`cuMemPoolGetAttribute`).
    ///
    /// # Safety
    /// `value` must point to writable storage of the type the attribute requires.
    pub unsafe fn get_attribute(
        pool: sys::CUmemoryPool,
        attr: sys::CUmemPool_attribute,
        value: *mut core::ffi::c_void,
    ) -> Result<(), DriverError> {
        let status = sys::cuMemPoolGetAttribute(pool, attr, value);
        status.result()
    }

    /// Sets attribute `attr` of `pool` from `value` (`cuMemPoolSetAttribute`).
    ///
    /// # Safety
    /// `value` must point to a value of the type the attribute requires.
    pub unsafe fn set_attribute(
        pool: sys::CUmemoryPool,
        attr: sys::CUmemPool_attribute,
        value: *mut core::ffi::c_void,
    ) -> Result<(), DriverError> {
        let status = sys::cuMemPoolSetAttribute(pool, attr, value);
        status.result()
    }

    /// Allocates `num_bytes` from `pool`, stream-ordered on `stream`
    /// (`cuMemAllocFromPoolAsync`).
    ///
    /// # Safety
    /// `pool` and `stream` must be valid. The result must be freed exactly once, and only
    /// used by work ordered after the allocation.
    pub unsafe fn alloc_async(
        pool: sys::CUmemoryPool,
        num_bytes: usize,
        stream: sys::CUstream,
    ) -> Result<sys::CUdeviceptr, DriverError> {
        let mut out = MaybeUninit::uninit();
        let status = sys::cuMemAllocFromPoolAsync(out.as_mut_ptr(), num_bytes, pool, stream);
        status.result()?;
        Ok(out.assume_init())
    }
}
#[cfg(test)]
mod tests {
    use super::super::safe::{CudaContext, CudaSlice};
    use super::*;
    use std::println;

    /// Returns `true`, after printing a skip notice, when fewer than two CUDA devices exist.
    fn lacks_second_device() -> Result<bool, DriverError> {
        if device::get_count()? >= 2 {
            return Ok(false);
        }
        println!("Skip test because not enough cuda devices");
        Ok(true)
    }

    /// Device-to-device copy between contexts on device 0 and device 1.
    #[test]
    fn peer_transfer_contexts() -> Result<(), DriverError> {
        // Declaration order is kept so that resources are dropped in the same order.
        let first_ctx = CudaContext::new(0)?;
        if lacks_second_device()? {
            return Ok(());
        }
        let first_stream = first_ctx.default_stream();
        let src: CudaSlice<f64> = first_stream.alloc_zeros::<f64>(10)?;
        let second_ctx = CudaContext::new(1)?;
        let second_stream = second_ctx.default_stream();
        let dst = second_stream.clone_dtod(&src)?;
        let _ = first_stream.clone_dtoh(&src)?;
        let _ = second_stream.clone_dtoh(&dst)?;
        Ok(())
    }

    /// Device-to-device copy between two contexts that are both created for device 0.
    #[test]
    fn peer_transfer_self() -> Result<(), DriverError> {
        let first_ctx = CudaContext::new(0)?;
        let first_stream = first_ctx.default_stream();
        let src: CudaSlice<f64> = first_stream.alloc_zeros::<f64>(10)?;
        let second_ctx = CudaContext::new(0)?;
        let second_stream = second_ctx.default_stream();
        let dst = second_stream.clone_dtod(&src)?;
        let _ = first_stream.clone_dtoh(&src)?;
        let _ = second_stream.clone_dtoh(&dst)?;
        Ok(())
    }

    /// A memory op on device 0's stream after creating a context on device 1.
    /// Presumably this checks that the op re-binds the correct context.
    #[test]
    fn re_associate_context_for_memory_op() -> Result<(), DriverError> {
        let first_ctx = CudaContext::new(0)?;
        if lacks_second_device()? {
            return Ok(());
        }
        let first_stream = first_ctx.default_stream();
        let buf: CudaSlice<f64> = first_stream.alloc_zeros::<f64>(10)?;
        let _second_ctx = CudaContext::new(1)?;
        first_stream.clone_dtoh(&buf).map(|_| ())
    }
}