use std::{ffi::CString, ptr, sync::Arc};
use singe_cuda_sys::driver;
use crate::{
context::Context,
error::{Error, Result},
graph::{ExecutableGraph, Graph, GraphNode},
kernel::{self, LibraryKernelHandle},
module::{KernelFunction, KernelParameters, LaunchConfig, Module},
try_cuda,
types::{DeviceFunction, FunctionAttribute, FunctionCache},
};
/// Wrapper around a CUDA driver library handle and the context it is used with.
///
/// Dropping a `Library` binds the stored context and passes the handle to
/// `cuLibraryUnload`.
#[derive(Debug)]
pub struct Library {
    /// Raw driver handle handed to every `cuLibrary*` call made by this type.
    handle: driver::CUlibrary,
    /// Context that is bound before each driver call.
    ctx: Arc<Context>,
}
/// Address and size of a symbol looked up in a [`Library`].
///
/// The value borrows the library it was obtained from, so it cannot outlive it.
#[derive(Clone, Copy, Debug)]
pub struct LibraryGlobal<'a> {
    /// Address reported by the driver, stored as an untyped pointer.
    ptr: *mut (),
    /// Size reported by the driver for the symbol.
    size: usize,
    /// Borrow of the originating library; held only for its lifetime.
    _library: &'a Library,
}
/// Kernel handle looked up in a [`Library`], borrowed from that library.
#[derive(Clone, Copy, Debug)]
pub struct LibraryKernel<'a> {
    /// Raw driver kernel handle passed to the `cuKernel*` calls.
    handle: driver::CUkernel,
    /// Library the kernel was obtained from; also supplies the context to bind.
    library: &'a Library,
}
/// Offset and size of a single kernel parameter, as reported by the driver.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct KernelParamInfo {
    /// Offset value the driver reported for the parameter.
    pub offset: usize,
    /// Size value the driver reported for the parameter.
    pub size: usize,
}
impl Library {
pub const unsafe fn from_raw(handle: driver::CUlibrary, ctx: Arc<Context>) -> Self {
Self { handle, ctx }
}
pub fn kernel(&self, name: &str) -> Result<LibraryKernel<'_>> {
let c_name = CString::new(name)?;
let mut handle = ptr::null_mut();
self.ctx.bind()?;
unsafe {
try_cuda!(driver::cuLibraryGetKernel(
&raw mut handle,
self.handle,
c_name.as_ptr(),
))?;
}
if handle.is_null() {
return Err(Error::NullHandle);
}
Ok(LibraryKernel {
handle,
library: self,
})
}
pub fn kernel_count(&self) -> Result<usize> {
let mut count = 0;
self.ctx.bind()?;
unsafe {
try_cuda!(driver::cuLibraryGetKernelCount(&raw mut count, self.handle))?;
}
Ok(count as usize)
}
pub fn module(&self) -> Result<Module> {
let mut handle = ptr::null_mut();
self.ctx.bind()?;
unsafe {
try_cuda!(driver::cuLibraryGetModule(&raw mut handle, self.handle))?;
}
if handle.is_null() {
return Err(Error::NullHandle);
}
Ok(unsafe { Module::from_borrowed_raw(handle, Arc::clone(&self.ctx)) })
}
pub fn global(&self, name: &str) -> Result<LibraryGlobal<'_>> {
let c_name = CString::new(name)?;
let mut ptr = 0;
let mut size = 0;
self.ctx.bind()?;
unsafe {
try_cuda!(driver::cuLibraryGetGlobal(
&raw mut ptr,
&raw mut size,
self.handle,
c_name.as_ptr(),
))?;
}
Ok(LibraryGlobal {
ptr: ptr as *mut (),
size: size as usize,
_library: self,
})
}
pub fn managed(&self, name: &str) -> Result<LibraryGlobal<'_>> {
let c_name = CString::new(name)?;
let mut ptr = 0;
let mut size = 0;
self.ctx.bind()?;
unsafe {
try_cuda!(driver::cuLibraryGetManaged(
&raw mut ptr,
&raw mut size,
self.handle,
c_name.as_ptr(),
))?;
}
Ok(LibraryGlobal {
ptr: ptr as *mut (),
size: size as usize,
_library: self,
})
}
pub fn unified_function(&self, symbol: &str) -> Result<*mut ()> {
let c_symbol = CString::new(symbol)?;
let mut ptr = ptr::null_mut();
self.ctx.bind()?;
unsafe {
try_cuda!(driver::cuLibraryGetUnifiedFunction(
&raw mut ptr,
self.handle,
c_symbol.as_ptr(),
))?;
}
if ptr.is_null() {
return Err(Error::NullHandle);
}
Ok(ptr.cast())
}
pub const unsafe fn as_raw(&self) -> driver::CUlibrary {
self.handle
}
}
impl Drop for Library {
    /// Binds the stored context and unloads the library with `cuLibraryUnload`.
    ///
    /// Failures cannot be propagated out of `drop`, so they are only reported on stderr
    /// in builds with debug assertions enabled. If the context cannot be bound the
    /// unload call is skipped entirely.
    fn drop(&mut self) {
        if let Err(err) = self.ctx.bind() {
            // `cfg!` (rather than `#[cfg]` on the statement) keeps `err` used in every
            // build profile, so release builds do not emit an unused-variable warning;
            // the branch is a constant `false` there and prints nothing.
            if cfg!(debug_assertions) {
                eprintln!("failed to bind context before unloading library: {err}");
            }
            return;
        }

        // SAFETY: `self.handle` is the handle this wrapper was built with, and `drop`
        // runs at most once, so this wrapper does not pass it to the driver again.
        let unloaded = unsafe { try_cuda!(driver::cuLibraryUnload(self.handle)) };
        if let Err(err) = unloaded {
            if cfg!(debug_assertions) {
                eprintln!("failed to unload cuda library: {err}");
            }
        }
    }
}
impl LibraryGlobal<'_> {
    /// Returns the address the driver reported for this symbol as an untyped pointer.
    pub const fn as_ptr(&self) -> *mut () {
        self.ptr
    }

    /// Returns the size the driver reported for this symbol.
    pub const fn size(&self) -> usize {
        self.size
    }
}
impl LibraryKernel<'_> {
pub fn name(&self) -> Result<String> {
kernel::name::<LibraryKernelHandle>(self.library.ctx.as_ref(), self.handle)
}
pub fn function(&self) -> Result<DeviceFunction> {
self.library.ctx.bind()?;
let mut handle = ptr::null_mut();
unsafe {
try_cuda!(driver::cuKernelGetFunction(&raw mut handle, self.handle))?;
}
if handle.is_null() {
return Err(Error::NullHandle);
}
Ok(handle.into())
}
pub fn add_to_graph(
&self,
graph: &mut Graph,
dependencies: &[GraphNode],
config: &LaunchConfig,
params: &mut KernelParameters,
) -> Result<GraphNode> {
let function = self.function()?;
let module = self.library.module()?;
let function = unsafe { KernelFunction::from_raw(function, &module) };
function.add_to_graph(graph, dependencies, config, params)
}
pub fn set_graph_node_params(
&self,
executable: &mut ExecutableGraph,
node: GraphNode,
config: &LaunchConfig,
params: &mut KernelParameters,
) -> Result<()> {
let function = self.function()?;
let module = self.library.module()?;
let function = unsafe { KernelFunction::from_raw(function, &module) };
function.set_graph_node_params(executable, node, config, params)
}
pub fn attribute(&self, attribute: FunctionAttribute) -> Result<i32> {
kernel::attribute::<LibraryKernelHandle>(self.library.ctx.as_ref(), self.handle, attribute)
}
pub fn set_attribute(&self, attribute: FunctionAttribute, value: i32) -> Result<()> {
kernel::set_attribute::<LibraryKernelHandle>(
self.library.ctx.as_ref(),
self.handle,
attribute,
value,
)
}
pub fn set_cache_config(&self, config: FunctionCache) -> Result<()> {
self.library.ctx.bind()?;
unsafe {
try_cuda!(driver::cuKernelSetCacheConfig(
self.handle,
config.into(),
self.library.ctx.device().id() as _,
))?;
}
Ok(())
}
pub fn param_info(&self, index: usize) -> Result<KernelParamInfo> {
self.library.ctx.bind()?;
let mut offset = 0;
let mut size = 0;
unsafe {
try_cuda!(driver::cuKernelGetParamInfo(
self.handle,
index as _,
&raw mut offset,
&raw mut size,
))?;
}
Ok(KernelParamInfo {
offset: offset as usize,
size: size as usize,
})
}
pub const unsafe fn as_raw(&self) -> driver::CUkernel {
self.handle
}
}