use core::ffi::{c_char, c_void};
use std::ffi::CString;
use std::sync::Arc;
use baracuda_cuda_sys::{driver, CUdeviceptr, CUfunction, CUmodule};
use crate::context::Context;
use crate::error::{check, Result};
/// A CUDA module loaded into a [`Context`].
///
/// Cloning is cheap: every clone shares the same reference-counted inner
/// state, and the module is unloaded when the last clone is dropped.
#[derive(Clone)]
pub struct Module {
    inner: Arc<ModuleInner>,
}
/// Shared state behind [`Module`]: the raw module handle and the context it
/// was loaded into.
struct ModuleInner {
    // Raw driver handle; unloaded by this type's `Drop` impl.
    handle: CUmodule,
    // Clone of the context passed to `Module::load_raw`.
    // NOTE(review): presumably held so the context outlives the module until
    // it is unloaded — confirm against `Context`'s drop behavior.
    context: Context,
}
// SAFETY: `ModuleInner` is not auto-`Send`/`Sync` only because `CUmodule` is a
// raw pointer. Its fields are never mutated after construction (only read via
// `&self`, then released in `Drop`).
// NOTE(review): soundness relies on the CUDA driver allowing a module handle to
// be used and unloaded from any thread, and on `Context` being safe to share
// across threads — confirm.
unsafe impl Send for ModuleInner {}
unsafe impl Sync for ModuleInner {}
impl core::fmt::Debug for ModuleInner {
    // Rendered under the public name `Module`. Only the raw handle is shown;
    // the remaining fields are elided via `finish_non_exhaustive`.
    fn fmt(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = formatter.debug_struct("Module");
        out.field("handle", &self.handle);
        out.finish_non_exhaustive()
    }
}
impl core::fmt::Debug for Module {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
self.inner.fmt(f)
}
}
impl Module {
pub fn load_raw(context: &Context, image: &[u8]) -> Result<Self> {
context.set_current()?;
let d = driver()?;
let cu = d.cu_module_load_data()?;
let mut module: CUmodule = core::ptr::null_mut();
check(unsafe { cu(&mut module, image.as_ptr() as *const c_void) })?;
Ok(Self {
inner: Arc::new(ModuleInner {
handle: module,
context: context.clone(),
}),
})
}
pub fn load_ptx(context: &Context, ptx_source: &str) -> Result<Self> {
let c_src = CString::new(ptx_source).map_err(|_| {
crate::error::Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
library: "cuda-driver",
symbol: "cuModuleLoadData(PTX input contained a NUL byte)",
})
})?;
Self::load_raw(context, c_src.as_bytes_with_nul())
}
pub fn get_global(&self, name: &str) -> Result<(CUdeviceptr, usize)> {
let d = driver()?;
let cu = d.cu_module_get_global()?;
let c_name = CString::new(name).map_err(|_| {
crate::error::Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
library: "cuda-driver",
symbol: "cuModuleGetGlobal(name contained a NUL byte)",
})
})?;
let mut dptr = CUdeviceptr(0);
let mut bytes: usize = 0;
check(unsafe {
cu(
&mut dptr,
&mut bytes,
self.inner.handle,
c_name.as_ptr() as *const c_char,
)
})?;
Ok((dptr, bytes))
}
pub fn get_function(&self, name: &str) -> Result<Function> {
let d = driver()?;
let cu = d.cu_module_get_function()?;
let c_name = CString::new(name).map_err(|_| {
crate::error::Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
library: "cuda-driver",
symbol: "cuModuleGetFunction(kernel name contained a NUL byte)",
})
})?;
let mut func: CUfunction = core::ptr::null_mut();
check(unsafe {
cu(
&mut func,
self.inner.handle,
c_name.as_ptr() as *const c_char,
)
})?;
Ok(Function {
handle: func,
_owner: FunctionOwner::Module(self.clone()),
})
}
#[inline]
pub fn as_raw(&self) -> CUmodule {
self.inner.handle
}
pub fn loading_mode() -> Result<i32> {
let d = driver()?;
let cu = d.cu_module_get_loading_mode()?;
let mut mode: core::ffi::c_int = 0;
check(unsafe { cu(&mut mode) })?;
Ok(mode)
}
#[inline]
pub fn context(&self) -> &Context {
&self.inner.context
}
}
impl Drop for ModuleInner {
    /// Unloads the module on a best-effort basis.
    ///
    /// `cuModuleUnload` operates on the calling thread's current context. A
    /// `Module` is `Send`/`Sync` and may be dropped on any thread, so the
    /// owning context is made current first; otherwise the unload could fail
    /// and leak the module.
    ///
    /// Errors cannot be propagated out of `drop`, so every failure here is
    /// deliberately ignored. As with `Module::load_raw`, this leaves the
    /// owning context current on the dropping thread.
    fn drop(&mut self) {
        // Best-effort: if this fails, the unload below may fail too, but there
        // is nothing useful to do with the error inside `drop`.
        let _ = self.context.set_current();
        if let Ok(d) = driver() {
            if let Ok(cu) = d.cu_module_unload() {
                // SAFETY: in this file `ModuleInner` is only constructed by
                // `Module::load_raw` from a successful `cuModuleLoadData`. It
                // lives behind an `Arc`, so this runs exactly once per handle.
                let _ = unsafe { cu(self.handle) };
            }
        }
    }
}
/// A kernel function handle.
///
/// Stores the raw `CUfunction` together with the module or library it was
/// obtained from, so that owner stays alive as long as any clone of this value.
#[derive(Clone, Debug)]
pub struct Function {
    handle: CUfunction,
    // Held so the owning module/library lives at least as long as `handle`.
    _owner: FunctionOwner,
}
/// What a [`Function`] was obtained from; stored so that owner stays alive.
#[derive(Clone, Debug)]
// The `Library` payload is held only for lifetime purposes and is never matched
// out in this file (`Function::module` uses `Library(_)`), so `dead_code` would
// otherwise flag it.
#[allow(dead_code)]
enum FunctionOwner {
    /// Obtained via `Module::get_function`.
    Module(Module),
    /// Obtained via `Function::from_raw_with_library`.
    Library(crate::library::Library),
}
impl Function {
    /// Wraps a raw function handle that was obtained from `library`.
    ///
    /// `library` is stored as the owner, so the library value lives as long
    /// as the returned `Function` or any clone of it.
    pub(crate) fn from_raw_with_library(
        handle: CUfunction,
        library: crate::library::Library,
    ) -> Self {
        let owner = FunctionOwner::Library(library);
        Self {
            _owner: owner,
            handle,
        }
    }
}
// SAFETY: `Function` is not auto-`Send`/`Sync` only because `CUfunction` is a
// raw pointer. The Rust-side fields are never mutated after construction; all
// methods take `&self`.
// NOTE(review): soundness relies on the CUDA driver allowing a function handle
// to be used from any thread, and on `FunctionOwner` (including
// `crate::library::Library`) being safe to share — confirm.
unsafe impl Send for Function {}
unsafe impl Sync for Function {}
impl Function {
#[inline]
pub fn as_raw(&self) -> CUfunction {
self.handle
}
#[inline]
pub fn module(&self) -> Option<&Module> {
match &self._owner {
FunctionOwner::Module(m) => Some(m),
FunctionOwner::Library(_) => None,
}
}
pub fn get_attribute(&self, attribute: i32) -> Result<i32> {
let d = driver()?;
let cu = d.cu_func_get_attribute()?;
let mut v: core::ffi::c_int = 0;
check(unsafe { cu(&mut v, attribute, self.handle) })?;
Ok(v)
}
pub fn name(&self) -> Result<String> {
let d = driver()?;
let cu = d.cu_func_get_name()?;
let mut p: *const core::ffi::c_char = core::ptr::null();
check(unsafe { cu(&mut p, self.handle) })?;
if p.is_null() {
return Ok(String::new());
}
let cstr = unsafe { core::ffi::CStr::from_ptr(p) };
Ok(cstr.to_string_lossy().into_owned())
}
pub fn param_info(&self, index: usize) -> Result<(usize, usize)> {
let d = driver()?;
let cu = d.cu_func_get_param_info()?;
let mut off: usize = 0;
let mut sz: usize = 0;
check(unsafe { cu(self.handle, index, &mut off, &mut sz) })?;
Ok((off, sz))
}
pub fn module_raw(&self) -> Result<baracuda_cuda_sys::CUmodule> {
let d = driver()?;
let cu = d.cu_func_get_module()?;
let mut m: baracuda_cuda_sys::CUmodule = core::ptr::null_mut();
check(unsafe { cu(&mut m, self.handle) })?;
Ok(m)
}
pub fn set_attribute(&self, attribute: i32, value: i32) -> Result<()> {
let d = driver()?;
let cu = d.cu_func_set_attribute()?;
check(unsafe { cu(self.handle, attribute, value) })
}
pub fn max_threads_per_block(&self) -> Result<i32> {
use baracuda_cuda_sys::types::CUfunction_attribute as A;
self.get_attribute(A::MAX_THREADS_PER_BLOCK)
}
pub fn shared_size_bytes(&self) -> Result<i32> {
use baracuda_cuda_sys::types::CUfunction_attribute as A;
self.get_attribute(A::SHARED_SIZE_BYTES)
}
pub fn num_regs(&self) -> Result<i32> {
use baracuda_cuda_sys::types::CUfunction_attribute as A;
self.get_attribute(A::NUM_REGS)
}
pub fn local_size_bytes(&self) -> Result<i32> {
use baracuda_cuda_sys::types::CUfunction_attribute as A;
self.get_attribute(A::LOCAL_SIZE_BYTES)
}
pub fn ptx_version(&self) -> Result<i32> {
use baracuda_cuda_sys::types::CUfunction_attribute as A;
self.get_attribute(A::PTX_VERSION)
}
pub fn binary_version(&self) -> Result<i32> {
use baracuda_cuda_sys::types::CUfunction_attribute as A;
self.get_attribute(A::BINARY_VERSION)
}
}