use std::{ffi::CStr, ptr};
use singe_cuda_sys::driver;
use crate::{
context::Context,
error::{Error, Result},
try_ffi,
types::FunctionAttribute,
};
/// Abstraction over the two CUDA driver kernel handle flavours (`CUfunction`
/// via the `cuFunc*` calls and `CUkernel` via the `cuKernel*` calls), so the
/// free functions in this module can query names and attributes generically.
///
/// Each method forwards to the matching driver entry point and returns the raw
/// `CUresult` unchanged; error translation is left to the caller.
///
/// C strings are typed as `c_char` rather than `i8` because `c_char` is `u8`
/// on some targets (e.g. aarch64 Linux). `CStr::from_ptr` takes `c_char`, and
/// bindgen-generated driver bindings typically do too.
pub trait KernelHandle {
    /// The raw driver handle type this implementation operates on.
    type RawHandle: Copy;

    /// Writes a pointer to the kernel's NUL-terminated name into `*name`.
    ///
    /// `device_id` is available for implementations whose driver call is
    /// per-device; implementations may ignore it.
    ///
    /// # Safety
    ///
    /// `raw` must be a live handle of this kind, and `name` must be valid for
    /// a write of one pointer. Per the CUDA driver docs, the returned string is
    /// only valid while the handle is, so callers should copy it out promptly.
    unsafe fn raw_name(
        raw: Self::RawHandle,
        name: *mut *const std::ffi::c_char,
        device_id: i32,
    ) -> driver::CUresult;

    /// Writes the current value of `attribute` for `raw` into `*value`.
    ///
    /// # Safety
    ///
    /// `raw` must be a live handle of this kind, and `value` must be valid for
    /// a write of one `i32`. `device_id` must be acceptable to the underlying
    /// driver call if the implementation forwards it.
    unsafe fn raw_attribute(
        value: *mut i32,
        attribute: driver::CUfunction_attribute,
        raw: Self::RawHandle,
        device_id: i32,
    ) -> driver::CUresult;

    /// Sets `attribute` of `raw` to `value`.
    ///
    /// # Safety
    ///
    /// `raw` must be a live handle of this kind. `device_id` must be acceptable
    /// to the underlying driver call if the implementation forwards it.
    unsafe fn set_attribute(
        raw: Self::RawHandle,
        attribute: driver::CUfunction_attribute,
        value: i32,
        device_id: i32,
    ) -> driver::CUresult;
}
/// Marker selecting the `CUfunction` (`cuFunc*`) flavour of the driver API.
///
/// None of the `cuFunc*` entry points used here take a device argument, so the
/// `device_id` parameter is ignored by every method of this implementation.
pub struct ModuleKernelHandle;

impl KernelHandle for ModuleKernelHandle {
    type RawHandle = driver::CUfunction;

    unsafe fn raw_name(
        raw: Self::RawHandle,
        name: *mut *const std::ffi::c_char,
        _device_id: i32,
    ) -> driver::CUresult {
        // SAFETY: the caller upholds `KernelHandle::raw_name`'s contract (live
        // `raw`, writable `name`), which is what `cuFuncGetName` requires.
        unsafe { driver::cuFuncGetName(name, raw) }
    }

    unsafe fn raw_attribute(
        value: *mut i32,
        attribute: driver::CUfunction_attribute,
        raw: Self::RawHandle,
        _device_id: i32,
    ) -> driver::CUresult {
        // SAFETY: the caller upholds `KernelHandle::raw_attribute`'s contract
        // (live `raw`, writable `value`).
        unsafe { driver::cuFuncGetAttribute(value, attribute, raw) }
    }

    unsafe fn set_attribute(
        raw: Self::RawHandle,
        attribute: driver::CUfunction_attribute,
        value: i32,
        _device_id: i32,
    ) -> driver::CUresult {
        // SAFETY: the caller upholds `KernelHandle::set_attribute`'s contract
        // (live `raw`).
        unsafe { driver::cuFuncSetAttribute(raw, attribute, value) }
    }
}
/// Marker selecting the `CUkernel` (`cuKernel*`) flavour of the driver API.
///
/// The attribute getters and setters here are per-device, so `device_id` is
/// forwarded as the driver's device argument. The name lookup is not
/// per-device and ignores it.
/// NOTE(review): this assumes the device id supplied by callers is a valid
/// `CUdevice` value (as returned by `cuDeviceGet`); confirm.
pub struct LibraryKernelHandle;

impl KernelHandle for LibraryKernelHandle {
    type RawHandle = driver::CUkernel;

    unsafe fn raw_name(
        raw: Self::RawHandle,
        name: *mut *const std::ffi::c_char,
        _device_id: i32,
    ) -> driver::CUresult {
        // SAFETY: the caller upholds `KernelHandle::raw_name`'s contract (live
        // `raw`, writable `name`), which is what `cuKernelGetName` requires.
        unsafe { driver::cuKernelGetName(name, raw) }
    }

    unsafe fn raw_attribute(
        value: *mut i32,
        attribute: driver::CUfunction_attribute,
        raw: Self::RawHandle,
        device_id: i32,
    ) -> driver::CUresult {
        // SAFETY: the caller upholds `KernelHandle::raw_attribute`'s contract
        // (live `raw`, writable `value`, acceptable `device_id`).
        unsafe { driver::cuKernelGetAttribute(value, attribute, raw, device_id) }
    }

    unsafe fn set_attribute(
        raw: Self::RawHandle,
        attribute: driver::CUfunction_attribute,
        value: i32,
        device_id: i32,
    ) -> driver::CUresult {
        // SAFETY: the caller upholds `KernelHandle::set_attribute`'s contract
        // (live `raw`, acceptable `device_id`).
        unsafe { driver::cuKernelSetAttribute(attribute, value, raw, device_id) }
    }
}
/// Returns the name of the kernel behind `raw` as an owned `String`.
///
/// `ctx` is bound before the driver is queried. The driver-provided bytes are
/// converted with `to_string_lossy`, so invalid UTF-8 becomes U+FFFD.
///
/// # Errors
///
/// Propagates failures from `ctx.bind()` and from the driver call (through
/// `try_ffi!`). Returns [`Error::NullHandle`] if the call succeeds but leaves
/// the name pointer null.
///
/// NOTE(review): this is a safe fn that trusts `raw` to be a live handle of
/// kind `H`; confirm callers cannot pass arbitrary raw values.
pub fn name<H: KernelHandle>(ctx: &Context, raw: H::RawHandle) -> Result<String> {
    ctx.bind()?;
    let device_id = ctx.device().id();
    let mut name_ptr = ptr::null();

    // SAFETY: `ctx` is bound, `raw` is trusted to be a live `H` handle, and
    // `name_ptr` is a live local the driver can write one pointer into.
    unsafe {
        try_ffi!(H::raw_name(raw, &raw mut name_ptr, device_id))?;
    }

    if name_ptr.is_null() {
        return Err(Error::NullHandle);
    }

    // SAFETY: the pointer is non-null and, per the driver contract, points to
    // a NUL-terminated string that stays valid while `raw` does. It is copied
    // into an owned `String` before returning.
    let text = unsafe { CStr::from_ptr(name_ptr) }
        .to_string_lossy()
        .into_owned();
    Ok(text)
}
pub fn attribute<H: KernelHandle>(
ctx: &Context,
raw: H::RawHandle,
attribute: FunctionAttribute,
) -> Result<i32> {
ctx.bind()?;
let mut value = 0;
unsafe {
try_ffi!(H::raw_attribute(
&raw mut value,
attribute.into(),
raw,
ctx.device().id(),
))?;
}
Ok(value)
}
pub fn set_attribute<H: KernelHandle>(
ctx: &Context,
raw: H::RawHandle,
attribute: FunctionAttribute,
value: i32,
) -> Result<()> {
ctx.bind()?;
unsafe {
try_ffi!(H::set_attribute(
raw,
attribute.into(),
value,
ctx.device().id(),
))?;
}
Ok(())
}