use core::ffi::{c_char, c_void};
use std::ffi::CString;
use std::sync::Arc;
use baracuda_cuda_sys::runtime::{cudaKernel_t, cudaLibrary_t, runtime};
use baracuda_types::{supports, CudaVersion, Feature};
use crate::error::{check, Error, Result};
/// Cheaply clonable handle to a loaded CUDA library.
///
/// Cloning copies the inner `Arc`, so every clone refers to the same `LibraryInner`
/// (and therefore the same raw `cudaLibrary_t`).
#[derive(Clone)]
pub struct Library {
/// Reference-counted holder of the raw library handle, shared by all clones.
inner: Arc<LibraryInner>,
}
/// Holder of the raw `cudaLibrary_t` that sits behind a `Library`.
struct LibraryInner {
/// Raw library handle; `Library::as_raw` returns this value unchanged.
handle: cudaLibrary_t,
}
// SAFETY: NOTE(review): these impls assert that the raw `cudaLibrary_t` may be moved to and
// shared between threads. Nothing in this file establishes that on its own; it presumably
// relies on the CUDA runtime allowing library handles to be used from any thread — confirm.
unsafe impl Send for LibraryInner {}
unsafe impl Sync for LibraryInner {}
impl core::fmt::Debug for LibraryInner {
    /// Renders the value under the public name `Library`, showing only the raw handle and
    /// marking the output as non-exhaustive.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("Library");
        out.field("handle", &self.handle);
        out.finish_non_exhaustive()
    }
}
impl core::fmt::Debug for Library {
    /// Delegates to the `Debug` output of the shared inner value.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Named explicitly so the call cannot be confused with any other `fmt` trait method.
        core::fmt::Debug::fmt(&self.inner, f)
    }
}
impl Library {
pub fn load_raw(image: &[u8]) -> Result<Self> {
let installed = crate::init::driver_version()?;
if !supports(installed, Feature::LibraryManagement) {
return Err(Error::FeatureNotSupported {
api: "cudaLibraryLoadData",
since: Feature::LibraryManagement.required_version(),
});
}
let r = runtime()?;
let cu = r.cuda_library_load_data()?;
let mut lib: cudaLibrary_t = core::ptr::null_mut();
check(unsafe {
cu(
&mut lib,
image.as_ptr() as *const c_void,
core::ptr::null_mut(), core::ptr::null_mut(), 0, core::ptr::null_mut(), core::ptr::null_mut(), 0, )
})?;
Ok(Self {
inner: Arc::new(LibraryInner { handle: lib }),
})
}
pub fn load_ptx(ptx_source: &str) -> Result<Self> {
let c_src = CString::new(ptx_source).map_err(|_| {
Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
library: "cuda-runtime",
symbol: "cudaLibraryLoadData(PTX input contained a NUL byte)",
})
})?;
Self::load_raw(c_src.as_bytes_with_nul())
}
pub fn get_kernel(&self, name: &str) -> Result<Kernel> {
let r = runtime()?;
let cu = r.cuda_library_get_kernel()?;
let c_name = CString::new(name).map_err(|_| {
Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
library: "cuda-runtime",
symbol: "cudaLibraryGetKernel(kernel name contained a NUL byte)",
})
})?;
let mut kernel: cudaKernel_t = core::ptr::null_mut();
check(unsafe {
cu(
&mut kernel,
self.inner.handle,
c_name.as_ptr() as *const c_char,
)
})?;
Ok(Kernel {
handle: kernel,
_library: self.clone(),
})
}
#[inline]
pub fn as_raw(&self) -> cudaLibrary_t {
self.inner.handle
}
}
impl Drop for LibraryInner {
    /// Unloads the library on a best-effort basis.
    ///
    /// If the runtime or its unload entry point cannot be resolved, nothing happens; the
    /// status returned by the unload call is deliberately discarded because `drop` cannot
    /// report failure.
    fn drop(&mut self) {
        let rt = match runtime() {
            Ok(rt) => rt,
            Err(_) => return,
        };
        let unload = match rt.cuda_library_unload() {
            Ok(unload) => unload,
            Err(_) => return,
        };
        // SAFETY: `self.handle` is the handle this value was created with, and `drop` runs
        // at most once, so the handle is unloaded here a single time.
        let _ = unsafe { unload(self.handle) };
    }
}
/// Handle to a kernel obtained from a `Library`.
///
/// Clones copy the raw handle and clone the owning `Library`.
#[derive(Clone, Debug)]
pub struct Kernel {
/// Raw kernel handle as written by the lookup call in `Library::get_kernel`.
handle: cudaKernel_t,
/// Clone of the library the kernel was looked up in. Holding it keeps the shared
/// `LibraryInner` alive, so its `Drop` (which unloads the library) cannot run while
/// this kernel exists.
_library: Library,
}
// SAFETY: NOTE(review): these impls assert that the raw `cudaKernel_t` may be moved to and
// shared between threads. Nothing in this file establishes that on its own; it presumably
// relies on the CUDA runtime allowing kernel handles to be used from any thread — confirm.
unsafe impl Send for Kernel {}
unsafe impl Sync for Kernel {}
impl Kernel {
    /// Returns the raw `cudaKernel_t` handle.
    #[inline]
    pub fn as_raw(&self) -> cudaKernel_t {
        self.handle
    }

    /// Returns the kernel handle reinterpreted as an untyped `*const c_void`.
    ///
    /// The occupancy and attribute methods of this type pass this pointer as the function
    /// argument of their runtime calls.
    #[inline]
    pub fn as_launch_ptr(&self) -> *const c_void {
        self.handle as *const c_void
    }

    /// Asks the runtime's `cuda_occupancy_max_active_blocks_per_multiprocessor` entry point
    /// for this kernel, given `block_size` and `dynamic_smem_bytes`, and returns the count
    /// it writes back.
    ///
    /// # Errors
    ///
    /// Propagates failures to resolve the entry point and any error status from the call.
    pub fn max_active_blocks_per_multiprocessor(
        &self,
        block_size: i32,
        dynamic_smem_bytes: usize,
    ) -> Result<i32> {
        let rt = runtime()?;
        let query = rt.cuda_occupancy_max_active_blocks_per_multiprocessor()?;
        let func = self.as_launch_ptr();
        let mut blocks: core::ffi::c_int = 0;
        // SAFETY: `blocks` is a live local used as the out-pointer; `func` is this kernel's
        // handle, whose owning library is kept alive by `self`.
        let status = unsafe { query(&mut blocks, func, block_size, dynamic_smem_bytes) };
        check(status).map(|()| blocks)
    }

    /// Same query as [`Kernel::max_active_blocks_per_multiprocessor`], made through the
    /// `..._with_flags` entry point; `flags` is passed through unchanged.
    ///
    /// # Errors
    ///
    /// Propagates failures to resolve the entry point and any error status from the call.
    pub fn max_active_blocks_per_multiprocessor_with_flags(
        &self,
        block_size: i32,
        dynamic_smem_bytes: usize,
        flags: u32,
    ) -> Result<i32> {
        let rt = runtime()?;
        let query = rt.cuda_occupancy_max_active_blocks_per_multiprocessor_with_flags()?;
        let func = self.as_launch_ptr();
        let mut blocks: core::ffi::c_int = 0;
        // SAFETY: `blocks` is a live local used as the out-pointer; `func` is this kernel's
        // handle, whose owning library is kept alive by `self`.
        let status =
            unsafe { query(&mut blocks, func, block_size, dynamic_smem_bytes, flags) };
        check(status).map(|()| blocks)
    }

    /// Asks the runtime's `cuda_occupancy_available_dynamic_smem_per_block` entry point for
    /// this kernel, given `num_blocks` and `block_size`, and returns the byte count it
    /// writes back.
    ///
    /// # Errors
    ///
    /// Propagates failures to resolve the entry point and any error status from the call.
    pub fn available_dynamic_smem_per_block(
        &self,
        num_blocks: i32,
        block_size: i32,
    ) -> Result<usize> {
        let rt = runtime()?;
        let query = rt.cuda_occupancy_available_dynamic_smem_per_block()?;
        let func = self.as_launch_ptr();
        let mut bytes: usize = 0;
        // SAFETY: `bytes` is a live local used as the out-pointer; `func` is this kernel's
        // handle, whose owning library is kept alive by `self`.
        let status = unsafe { query(&mut bytes, func, num_blocks, block_size) };
        check(status).map(|()| bytes)
    }

    /// Sets attribute `attr` of this kernel to `value` through the runtime's
    /// `cuda_func_set_attribute` entry point. Both integers are passed through unchanged.
    ///
    /// # Errors
    ///
    /// Propagates failures to resolve the entry point and any error status from the call.
    pub fn set_attribute(&self, attr: i32, value: i32) -> Result<()> {
        let rt = runtime()?;
        let set = rt.cuda_func_set_attribute()?;
        let func = self.as_launch_ptr();
        // SAFETY: `func` is this kernel's handle, whose owning library is kept alive by
        // `self`; the remaining arguments are plain integers.
        let status = unsafe { set(func, attr, value) };
        check(status)
    }
}
#[allow(dead_code)]
fn require_library_management(installed: CudaVersion) -> Result<()> {
if supports(installed, Feature::LibraryManagement) {
Ok(())
} else {
Err(Error::FeatureNotSupported {
api: "cudaLibraryLoadData",
since: Feature::LibraryManagement.required_version(),
})
}
}