use std::ffi::{c_int, c_void, CString};
use std::sync::Arc;
use crate::cudarc_shim::{ctx, device, module, pool, primary_ctx, stream};
use crate::error::*;
use crate::init;
/// Launch geometry for a kernel: grid size, block size and dynamic shared memory.
///
/// NOTE(review): this type is not consumed in this part of the file; the field meanings below
/// follow the CUDA `cuLaunchKernel` convention — confirm at the launch site.
#[derive(Clone, Copy, Debug)]
pub struct LaunchConfig {
/// Number of thread blocks along (x, y, z).
pub grid_dim: (u32, u32, u32),
/// Number of threads per block along (x, y, z).
pub block_dim: (u32, u32, u32),
/// Bytes of dynamic shared memory requested per block.
pub shared_mem_bytes: u32,
}
/// Handle to a CUDA device together with the context used to issue work on it.
///
/// Obtained from [`Device::new`] (which retains and owns the device's primary context) or
/// [`Device::borrow_raw`] (which wraps externally managed handles without owning them).
#[derive(Debug)]
pub struct Device {
/// Raw driver handle of the device.
pub(crate) cu_device: cuda_bindings::CUdevice,
/// Context made current by `bind_to_thread`; for owned handles, the retained primary context.
pub(crate) cu_ctx: cuda_bindings::CUcontext,
/// Ordinal this handle was created with; also used as the location id of new memory pools.
pub(crate) ordinal: usize,
/// When `true`, dropping this handle releases the device's primary context.
owned: bool,
}
// SAFETY: NOTE(review): the struct only stores raw driver handles. Sharing it across threads
// relies on the CUDA driver API being thread-safe and on context-dependent operations binding
// `cu_ctx` to the calling thread first (see `bind_to_thread`) — confirm for every caller.
unsafe impl Send for Device {}
unsafe impl Sync for Device {}
impl Drop for Device {
    /// Releases the primary-context reference held by owned handles.
    ///
    /// Borrowed handles (`owned == false`) leave the context untouched. Driver errors are
    /// discarded because `drop` has no way to report them; cleanup is best-effort.
    fn drop(&mut self) {
        if self.owned {
            let _ = self.bind_to_thread();
            // Move the context out of the field (leaving null behind) and only release when a
            // context was actually held.
            let released = std::mem::replace(&mut self.cu_ctx, std::ptr::null_mut());
            if released.is_null() {
                return;
            }
            let _ = unsafe { primary_ctx::release(self.cu_device) };
        }
    }
}
impl PartialEq for Device {
    /// Two handles are equal when they refer to the same device, context and ordinal.
    ///
    /// Ownership (`owned`) is deliberately not part of the comparison.
    fn eq(&self, other: &Self) -> bool {
        let lhs = (self.cu_device, self.cu_ctx, self.ordinal);
        let rhs = (other.cu_device, other.cu_ctx, other.ordinal);
        lhs == rhs
    }
}
impl Eq for Device {}
impl Device {
    /// Opens the CUDA device at `ordinal` and makes its primary context current.
    ///
    /// Initializes the driver, retains the device's primary context and binds it to the calling
    /// thread. The returned handle owns that primary-context reference and releases it on drop.
    ///
    /// # Errors
    ///
    /// Propagates any [`DriverError`] from driver initialization, device lookup,
    /// primary-context retention, or binding the context to the current thread.
    pub fn new(ordinal: usize) -> Result<Arc<Self>, DriverError> {
        unsafe { init(0)? };
        // NOTE(review): `as` silently truncates ordinals above `c_int::MAX`.
        let cu_device = device::get(ordinal as c_int)?;
        let cu_ctx = unsafe { primary_ctx::retain(cu_device) }?;
        let device = Arc::new(Device {
            cu_device,
            cu_ctx,
            ordinal,
            owned: true,
        });
        // If binding fails, `device` is dropped here, which releases the retained context.
        device.bind_to_thread()?;
        Ok(device)
    }

    /// Wraps an externally managed context/device pair without taking ownership.
    ///
    /// The returned handle never releases `cu_ctx` when dropped.
    ///
    /// # Safety
    ///
    /// `cu_ctx` must be a valid context for the device `cu_device`, `ordinal` must identify that
    /// same device (it is used as the location id for [`Device::new_mem_pool`]), and the context
    /// must outlive the returned handle and everything created from it.
    pub unsafe fn borrow_raw(cu_ctx: *mut c_void, cu_device: c_int, ordinal: usize) -> Arc<Self> {
        Arc::new(Device {
            cu_device: cu_device as cuda_bindings::CUdevice,
            cu_ctx: cu_ctx as cuda_bindings::CUcontext,
            ordinal,
            owned: false,
        })
    }

    /// Returns the number of CUDA devices reported by the driver.
    ///
    /// # Errors
    ///
    /// Propagates driver initialization or query failures.
    pub fn device_count() -> Result<i32, DriverError> {
        unsafe { init(0)? };
        device::get_count()
    }

    /// Returns the raw driver handle for the device at `ordinal` without retaining a context.
    ///
    /// # Errors
    ///
    /// Propagates driver initialization or lookup failures.
    pub fn raw_device(ordinal: usize) -> Result<cuda_bindings::CUdevice, DriverError> {
        unsafe { init(0)? };
        device::get(ordinal as c_int)
    }

    /// Ordinal this handle was created with.
    pub fn ordinal(&self) -> usize {
        self.ordinal
    }

    /// Device name as reported by the driver.
    pub fn name(&self) -> Result<String, DriverError> {
        device::get_name(self.cu_device)
    }

    /// Raw driver device handle.
    pub fn cu_device(&self) -> cuda_bindings::CUdevice {
        self.cu_device
    }

    /// Raw driver context handle used by this device.
    pub fn cu_ctx(&self) -> cuda_bindings::CUcontext {
        self.cu_ctx
    }

    /// Makes this device's context current on the calling thread.
    ///
    /// The driver call is skipped when the context is already current.
    pub fn bind_to_thread(&self) -> Result<(), DriverError> {
        let already_current =
            matches!(ctx::get_current()?, Some(current) if current == self.cu_ctx);
        if !already_current {
            unsafe { ctx::set_current(self.cu_ctx) }?;
        }
        Ok(())
    }

    /// Binds this device's context to the calling thread and synchronizes it.
    ///
    /// `ctx::synchronize` acts on whichever context is current on the calling thread. Without
    /// the bind, it would wait on another device's context, or fail when none is current.
    ///
    /// # Safety
    ///
    /// Kept `unsafe` for API compatibility with existing callers.
    pub unsafe fn synchronize(&self) -> Result<(), DriverError> {
        self.bind_to_thread()?;
        ctx::synchronize()
    }

    /// Creates a stream on this device with `StreamKind::NonBlocking`.
    ///
    /// The returned handle owns the stream and destroys it when dropped.
    pub fn new_stream(self: &Arc<Self>) -> Result<Arc<Stream>, DriverError> {
        self.bind_to_thread()?;
        let cu_stream = stream::create(stream::StreamKind::NonBlocking)?;
        Ok(Arc::new(Stream {
            cu_stream,
            device: self.clone(),
            owned: true,
        }))
    }

    /// Loads a module from PTX source text into this device's context.
    ///
    /// The returned handle owns the module and unloads it when dropped.
    ///
    /// # Panics
    ///
    /// Panics if `ptx_src` contains an interior NUL byte, since it must be passed to the driver
    /// as a C string.
    pub fn load_module_from_ptx_src(
        self: &Arc<Self>,
        ptx_src: &str,
    ) -> Result<Arc<Module>, DriverError> {
        self.bind_to_thread()?;
        let cu_module = {
            // `c_src` must stay alive for the duration of the driver call.
            let c_src = CString::new(ptx_src)
                .expect("PTX source passed to load_module_from_ptx_src contains a NUL byte");
            unsafe { module::load_data(c_src.as_ptr() as *const _) }
        }?;
        Ok(Arc::new(Module {
            cu_module,
            device: self.clone(),
            owned: true,
        }))
    }

    /// Loads a module from the file at `filename` into this device's context.
    ///
    /// The returned handle owns the module and unloads it when dropped.
    pub fn load_module_from_file(
        self: &Arc<Self>,
        filename: &str,
    ) -> Result<Arc<Module>, DriverError> {
        self.bind_to_thread()?;
        let cu_module = module::load(filename)?;
        Ok(Arc::new(Module {
            cu_module,
            device: self.clone(),
            owned: true,
        }))
    }

    /// Creates a new memory pool located on this device.
    ///
    /// The pool uses the pinned allocation type, no exportable handle type, and
    /// `self.ordinal` as its location id. The returned handle owns the pool and destroys it
    /// when dropped.
    pub fn new_mem_pool(self: &Arc<Self>) -> Result<Arc<MemPool>, DriverError> {
        self.bind_to_thread()?;
        let mut props: cuda_bindings::CUmemPoolProps = unsafe { std::mem::zeroed() };
        props.allocType = cuda_bindings::CUmemAllocationType_enum_CU_MEM_ALLOCATION_TYPE_PINNED;
        props.handleTypes = cuda_bindings::CUmemAllocationHandleType_enum_CU_MEM_HANDLE_TYPE_NONE;
        props.location.type_ = cuda_bindings::CUmemLocationType_enum_CU_MEM_LOCATION_TYPE_DEVICE;
        cuda_bindings::set_mem_location_id(&mut props.location, self.ordinal as c_int);
        let cu_pool = unsafe { pool::create(&props) }?;
        Ok(Arc::new(MemPool {
            cu_pool,
            device: self.clone(),
            owned: true,
        }))
    }

    /// Returns a handle to this device's default memory pool.
    ///
    /// The handle does not own the pool, so dropping it never destroys the pool.
    pub fn default_mem_pool(self: &Arc<Self>) -> Result<Arc<MemPool>, DriverError> {
        self.bind_to_thread()?;
        let cu_pool = unsafe { pool::get_default(self.cu_device) }?;
        Ok(Arc::new(MemPool {
            cu_pool,
            device: self.clone(),
            owned: false,
        }))
    }
}
/// Handle to a CUDA memory pool on a particular [`Device`].
#[derive(Debug)]
pub struct MemPool {
/// Raw driver pool handle.
pub(crate) cu_pool: cuda_bindings::CUmemoryPool,
/// Device whose context is bound before each pool operation; also keeps it alive.
pub(crate) device: Arc<Device>,
/// When `true`, dropping this handle destroys the pool (false for the default pool).
owned: bool,
}
// SAFETY: NOTE(review): holds only a raw pool handle plus an `Arc<Device>`; relies on the driver
// pool API being thread-safe and on each operation binding the device context first — confirm.
unsafe impl Send for MemPool {}
unsafe impl Sync for MemPool {}
impl Drop for MemPool {
    /// Destroys the pool when this handle owns it; non-owned pools are left alone.
    ///
    /// Errors are discarded because `drop` cannot report them.
    fn drop(&mut self) {
        if self.owned {
            let _ = self.device.bind_to_thread();
            let handle = self.cu_pool;
            let _ = unsafe { pool::destroy(handle) };
        }
    }
}
impl MemPool {
    /// Raw driver pool handle.
    pub fn cu_pool(&self) -> cuda_bindings::CUmemoryPool {
        self.cu_pool
    }

    /// Device this pool belongs to.
    pub fn device(&self) -> &Arc<Device> {
        &self.device
    }

    /// Binds the device context, then sets the pool's release threshold to `threshold`.
    pub fn set_release_threshold(&self, threshold: u64) -> Result<(), DriverError> {
        self.device.bind_to_thread()?;
        let handle = self.cu_pool;
        unsafe { pool::set_release_threshold(handle, threshold) }
    }

    /// Reads the pool's current and high-water "used" and "reserved" counters.
    ///
    /// Attributes are queried in field order; the first failing query aborts and is returned.
    pub fn mem_stats(&self) -> Result<PoolMemStats, DriverError> {
        self.device.bind_to_thread()?;
        let read_attr = |attr| unsafe { pool::get_attribute_u64(self.cu_pool, attr) };
        Ok(PoolMemStats {
            used_current: read_attr(
                cuda_bindings::CUmemPool_attribute_enum_CU_MEMPOOL_ATTR_USED_MEM_CURRENT,
            )?,
            used_high: read_attr(
                cuda_bindings::CUmemPool_attribute_enum_CU_MEMPOOL_ATTR_USED_MEM_HIGH,
            )?,
            reserved_current: read_attr(
                cuda_bindings::CUmemPool_attribute_enum_CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT,
            )?,
            reserved_high: read_attr(
                cuda_bindings::CUmemPool_attribute_enum_CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH,
            )?,
        })
    }

    /// Resets the "used memory" high-water mark of this pool.
    pub fn reset_used_high(&self) -> Result<(), DriverError> {
        self.device.bind_to_thread()?;
        let attr = cuda_bindings::CUmemPool_attribute_enum_CU_MEMPOOL_ATTR_USED_MEM_HIGH;
        unsafe { pool::reset_high_watermark(self.cu_pool, attr) }
    }

    /// Resets the "reserved memory" high-water mark of this pool.
    pub fn reset_reserved_high(&self) -> Result<(), DriverError> {
        self.device.bind_to_thread()?;
        let attr = cuda_bindings::CUmemPool_attribute_enum_CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH;
        unsafe { pool::reset_high_watermark(self.cu_pool, attr) }
    }
}
/// Snapshot of a memory pool's usage counters, as returned by `MemPool::mem_stats`.
///
/// NOTE(review): values are in bytes per the CUDA `CUmemPool_attribute` documentation — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolMemStats {
/// `CU_MEMPOOL_ATTR_USED_MEM_CURRENT`.
pub used_current: u64,
/// `CU_MEMPOOL_ATTR_USED_MEM_HIGH` (high-water mark; see `MemPool::reset_used_high`).
pub used_high: u64,
/// `CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT`.
pub reserved_current: u64,
/// `CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH` (high-water mark; see `MemPool::reset_reserved_high`).
pub reserved_high: u64,
}
/// Handle to a CUDA stream belonging to a [`Device`].
#[derive(Debug, PartialEq, Eq)]
pub struct Stream {
/// Raw driver stream handle.
pub(crate) cu_stream: cuda_bindings::CUstream,
/// Device whose context owns the stream; kept alive while the stream exists.
pub(crate) device: Arc<Device>,
/// When `true`, dropping this handle destroys the (non-null) stream.
owned: bool,
}
// SAFETY: NOTE(review): holds only a raw stream handle plus an `Arc<Device>`; relies on the driver
// stream API being thread-safe — confirm.
unsafe impl Send for Stream {}
unsafe impl Sync for Stream {}
impl Drop for Stream {
    /// Destroys owned, non-null streams; borrowed streams are left alone.
    ///
    /// The device context is bound before the null check, and errors are discarded because
    /// `drop` cannot report them.
    fn drop(&mut self) {
        if self.owned {
            let _ = self.device.bind_to_thread();
            let handle = self.cu_stream;
            if !handle.is_null() {
                let _ = unsafe { stream::destroy(handle) };
            }
        }
    }
}
impl Stream {
    /// Wraps an externally created stream without taking ownership.
    ///
    /// The returned handle never destroys `cu_stream` when dropped.
    ///
    /// # Safety
    ///
    /// `cu_stream` must be a stream handle valid in `device`'s context and must outlive the
    /// returned handle.
    pub unsafe fn borrow_raw(cu_stream: *mut c_void, device: &Arc<Device>) -> Arc<Self> {
        Arc::new(Stream {
            cu_stream: cu_stream as cuda_bindings::CUstream,
            device: device.clone(),
            owned: false,
        })
    }

    /// Raw driver stream handle.
    pub fn cu_stream(&self) -> cuda_bindings::CUstream {
        self.cu_stream
    }

    /// Device this stream belongs to.
    pub fn device(&self) -> &Arc<Device> {
        &self.device
    }

    /// Waits for work queued on this stream via `stream::synchronize`.
    pub unsafe fn synchronize(&self) -> Result<(), DriverError> {
        stream::synchronize(self.cu_stream)
    }

    /// Enqueues `host_func` to be run on the host after preceding work in this stream.
    ///
    /// The closure is boxed and handed to the driver as user data. On success, ownership
    /// passes to `callback_wrapper`, which runs it once.
    ///
    /// # Errors
    ///
    /// Returns the [`DriverError`] from enqueuing. In that case the driver never took the
    /// closure, so it is dropped here without being called.
    ///
    /// # Safety
    ///
    /// `host_func` runs asynchronously, possibly after this call returns and on another
    /// thread. Because `F` is not required to be `'static`, the caller must ensure everything
    /// it borrows outlives its execution. Per the CUDA driver documentation for
    /// `cuLaunchHostFunc`, the callback must not make CUDA API calls.
    pub unsafe fn launch_host_function<F: FnOnce() + Send>(
        &self,
        host_func: F,
    ) -> Result<(), DriverError> {
        let user_data: *mut F = Box::into_raw(Box::new(host_func));
        let result = stream::launch_host_function(
            self.cu_stream,
            Self::callback_wrapper::<F>,
            user_data as *mut c_void,
        );
        if result.is_err() {
            // The callback was not enqueued and will never reclaim the box; do it here so the
            // closure (and anything it captured) is not leaked.
            drop(Box::from_raw(user_data));
        }
        result
    }

    /// Trampoline invoked by the driver for `launch_host_function`.
    ///
    /// Reclaims the `Box<F>` leaked by `launch_host_function` and calls it exactly once. Panics
    /// are caught and discarded so they never unwind across the `extern "C"` boundary.
    unsafe extern "C" fn callback_wrapper<F: FnOnce() + Send>(callback: *mut c_void) {
        let _ = std::panic::catch_unwind(|| {
            let callback: Box<F> = Box::from_raw(callback as *mut F);
            callback();
        });
    }

    /// Begins graph capture on this stream in the given capture `mode`.
    pub unsafe fn begin_capture(
        &self,
        mode: cuda_bindings::CUstreamCaptureMode,
    ) -> Result<(), DriverError> {
        stream::begin_capture(self.cu_stream, mode)
    }

    /// Ends graph capture on this stream and returns the captured graph handle.
    pub unsafe fn end_capture(&self) -> Result<cuda_bindings::CUgraph, DriverError> {
        stream::end_capture(self.cu_stream)
    }
}
/// Handle to a loaded CUDA module belonging to a [`Device`].
#[derive(Debug)]
pub struct Module {
/// Raw driver module handle.
pub(crate) cu_module: cuda_bindings::CUmodule,
/// Device whose context the module was loaded into; kept alive while the module exists.
pub(crate) device: Arc<Device>,
/// When `true`, dropping this handle unloads the module.
owned: bool,
}
// SAFETY: NOTE(review): holds only a raw module handle plus an `Arc<Device>`; relies on the driver
// module API being thread-safe — confirm.
unsafe impl Send for Module {}
unsafe impl Sync for Module {}
impl Drop for Module {
    /// Unloads the module when this handle owns it; borrowed modules are left loaded.
    ///
    /// Errors are discarded because `drop` cannot report them.
    fn drop(&mut self) {
        if self.owned {
            let _ = self.device.bind_to_thread();
            let handle = self.cu_module;
            let _ = unsafe { module::unload(handle) };
        }
    }
}
impl Module {
    /// Wraps an externally loaded module without taking ownership.
    ///
    /// The returned handle never unloads `cu_module` when dropped.
    ///
    /// # Safety
    ///
    /// `cu_module` must be a module handle loaded in `device`'s context and must outlive the
    /// returned handle and any [`Function`] obtained from it.
    pub unsafe fn borrow_raw(cu_module: *mut c_void, device: &Arc<Device>) -> Arc<Self> {
        let wrapped = Module {
            cu_module: cu_module as cuda_bindings::CUmodule,
            device: Arc::clone(device),
            owned: false,
        };
        Arc::new(wrapped)
    }

    /// Raw driver module handle.
    pub fn cu_module(&self) -> cuda_bindings::CUmodule {
        self.cu_module
    }

    /// Looks up the kernel named `fn_name` in this module.
    ///
    /// The returned [`Function`] holds a clone of this `Arc`, so the module stays loaded for
    /// as long as the function handle exists.
    pub fn load_function(self: &Arc<Self>, fn_name: &str) -> Result<Function, DriverError> {
        unsafe { module::get_function(self.cu_module, fn_name) }.map(|cu_function| Function {
            cu_function,
            module: Arc::clone(self),
        })
    }
}
/// Handle to a kernel function inside a [`Module`].
#[derive(Debug, Clone)]
pub struct Function {
/// Raw driver function handle.
pub(crate) cu_function: cuda_bindings::CUfunction,
// Never read here. Holding the `Arc<Module>` keeps the module from being dropped (and thereby
// unloaded) while this function handle exists, hence the `allow(unused)`.
#[allow(unused)]
pub(crate) module: Arc<Module>,
}
// SAFETY: NOTE(review): holds only a raw function handle plus an `Arc<Module>`; relies on the driver
// API being thread-safe for function handles — confirm.
unsafe impl Send for Function {}
unsafe impl Sync for Function {}
impl Function {
    /// Wraps an externally obtained function handle, tying its lifetime to `module`.
    ///
    /// # Safety
    ///
    /// `cu_function` must be a function handle obtained from `module`'s underlying module.
    pub unsafe fn borrow_raw(cu_function: *mut c_void, module: &Arc<Module>) -> Function {
        let raw = cu_function as cuda_bindings::CUfunction;
        Function {
            cu_function: raw,
            module: Arc::clone(module),
        }
    }

    /// Raw driver function handle.
    ///
    /// # Safety
    ///
    /// NOTE(review): reading the handle is harmless in itself; presumably marked `unsafe`
    /// because the raw handle allows unchecked kernel launches — confirm the intended contract.
    pub unsafe fn cu_function(&self) -> cuda_bindings::CUfunction {
        let Self { cu_function, .. } = self;
        *cu_function
    }
}