use std::ffi::c_void;
use oxicuda_driver::ffi::{CUfunction, CUmodule};
use oxicuda_driver::loader::try_driver;
use crate::error::{CudaRtError, CudaRtResult};
use crate::stream::CudaStream;
/// Kernel function handle used by this module's lookup, launch and attribute
/// calls; a plain alias for the driver-level [`CUfunction`].
pub type CudaFunction = CUfunction;
/// Loaded-module handle produced by [`module_load_ptx`] and consumed by
/// [`module_get_function`] / [`module_unload`]; a plain alias for the
/// driver-level [`CUmodule`].
pub type CudaModule = CUmodule;
/// Three-component extent, used for the `grid` and `block` arguments of
/// [`launch_kernel`].
///
/// Convertible from a bare `u32`, a `(u32, u32)` pair or a `(u32, u32, u32)`
/// triple; axes that are not supplied are set to 1.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Dim3 {
    /// Extent along the x axis.
    pub x: u32,
    /// Extent along the y axis.
    pub y: u32,
    /// Extent along the z axis.
    pub z: u32,
}

impl Dim3 {
    /// Creates the one-dimensional extent `(x, 1, 1)`.
    #[must_use]
    pub const fn one_d(x: u32) -> Self {
        Self { x, y: 1, z: 1 }
    }

    /// Creates the two-dimensional extent `(x, y, 1)`.
    #[must_use]
    pub const fn two_d(x: u32, y: u32) -> Self {
        Self { x, y, z: 1 }
    }

    /// Creates the three-dimensional extent `(x, y, z)`.
    #[must_use]
    pub const fn three_d(x: u32, y: u32, z: u32) -> Self {
        Self { x, y, z }
    }

    /// Returns the product `x * y * z` as a `u64`.
    ///
    /// The product of three `u32` values can exceed `u64::MAX`. In that case
    /// the result saturates at `u64::MAX` instead of panicking (debug builds)
    /// or silently wrapping to a small number (release builds), so size
    /// comparisons against a limit stay conservative.
    #[must_use]
    pub const fn volume(self) -> u64 {
        // `as` is a lossless u32 -> u64 widening here; `u64::from` cannot be
        // called from a `const fn`.
        let x = self.x as u64;
        let y = self.y as u64;
        let z = self.z as u64;
        // `x * y` always fits in a u64; only the third factor can overflow.
        x.saturating_mul(y).saturating_mul(z)
    }
}
/// A bare `u32` converts to a one-dimensional extent: the value becomes `x`
/// and both `y` and `z` are 1.
impl From<u32> for Dim3 {
    fn from(width: u32) -> Self {
        Self {
            x: width,
            y: 1,
            z: 1,
        }
    }
}
/// An `(x, y)` pair converts to a two-dimensional extent; `z` is set to 1.
impl From<(u32, u32)> for Dim3 {
    fn from(pair: (u32, u32)) -> Self {
        let (x, y) = pair;
        Self { x, y, z: 1 }
    }
}
/// An `(x, y, z)` triple converts to the extent with exactly those components.
impl From<(u32, u32, u32)> for Dim3 {
    fn from(triple: (u32, u32, u32)) -> Self {
        let (x, y, z) = triple;
        Self { x, y, z }
    }
}
/// Snapshot of a kernel function's attributes, as assembled by
/// [`func_get_attributes`].
///
/// Every field holds the value the driver reported for the like-named
/// `CUfunction_attribute`. The precise meaning of each attribute is defined
/// by the CUDA driver documentation, not by this file; the per-field notes
/// below only state what the code here does with the value.
#[derive(Debug, Clone, Copy, Default)]
pub struct FuncAttributes {
/// Driver value for `SharedSizeBytes`, cast from `i32` to `usize`.
pub shared_size_bytes: usize,
/// Driver value for `ConstSizeBytes`, cast from `i32` to `usize`.
pub const_size_bytes: usize,
/// Driver value for `LocalSizeBytes`, cast from `i32` to `usize`.
pub local_size_bytes: usize,
/// Driver value for `MaxThreadsPerBlock`, cast from `i32` to `u32`.
pub max_threads_per_block: u32,
/// Driver value for `NumRegs`, cast from `i32` to `u32`.
pub num_regs: u32,
/// Driver value for `PtxVersion`, cast from `i32` to `u32`.
pub ptx_version: u32,
/// Driver value for `BinaryVersion`, cast from `i32` to `u32`.
pub binary_version: u32,
/// `true` when the driver value for `CacheModeCa` is non-zero.
pub cache_mode_ca: bool,
/// Driver value for `MaxDynamicSharedSizeBytes`, cast from `i32` to `usize`.
pub max_dynamic_shared_size_bytes: usize,
/// Driver value for `PreferredSharedMemoryCarveout`, stored as reported
/// (presumably a percentage — confirm against the CUDA documentation).
pub preferred_shared_memory_carveout: i32,
}
/// Function attributes that can be written through [`func_set_attribute`].
///
/// The explicit discriminants matter: `func_set_attribute` casts the variant
/// to `c_int` and passes that number to the driver unchanged, so they must
/// stay equal to the driver's numeric attribute identifiers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FuncAttribute {
/// Attribute id 8 (presumably the maximum dynamic shared memory size in
/// bytes — confirm against the CUDA documentation).
MaxDynamicSharedMemorySize = 8,
/// Attribute id 9 (presumably the preferred shared-memory carveout —
/// confirm against the CUDA documentation).
PreferredSharedMemoryCarveout = 9,
}
/// Loads a module from an in-memory PTX image.
///
/// The buffer handed to the driver always ends in a `0` byte: if `ptx` is
/// already terminated it is passed through as-is, otherwise a terminated copy
/// is made and kept alive for the duration of the call. No load options are
/// supplied (option count 0, both option arrays null).
///
/// # Errors
///
/// * [`CudaRtError::DriverNotAvailable`] when the driver API table cannot be
///   obtained.
/// * When the driver returns a non-zero status: the error produced by
///   `CudaRtError::from_code`, or [`CudaRtError::InvalidPtx`] if that yields
///   nothing.
pub fn module_load_ptx(ptx: &[u8]) -> CudaRtResult<CudaModule> {
    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;

    // Only allocate when the caller's slice lacks the trailing NUL
    // (an empty slice counts as unterminated).
    let terminated_copy: Option<Vec<u8>> = if ptx.ends_with(&[0]) {
        None
    } else {
        let mut copy = ptx.to_vec();
        copy.push(0);
        Some(copy)
    };
    let image: &[u8] = terminated_copy.as_deref().unwrap_or(ptx);

    let mut module = CUmodule::default();
    // SAFETY: `image` borrows either `ptx` or `terminated_copy`, both of which
    // outlive the call, and `module` is a live local the driver writes into.
    let status = unsafe {
        (api.cu_module_load_data_ex)(
            &raw mut module,
            image.as_ptr().cast::<c_void>(),
            0,
            std::ptr::null_mut(),
            std::ptr::null_mut(),
        )
    };

    if status == 0 {
        Ok(module)
    } else {
        Err(CudaRtError::from_code(status).unwrap_or(CudaRtError::InvalidPtx))
    }
}
/// Looks up the function called `name` inside a loaded `module`.
///
/// # Errors
///
/// * [`CudaRtError::DriverNotAvailable`] when the driver API table cannot be
///   obtained (checked first).
/// * [`CudaRtError::InvalidSymbol`] when `name` cannot be turned into a C
///   string, i.e. it contains an interior NUL byte.
/// * When the driver returns a non-zero status: the error produced by
///   `CudaRtError::from_code`, or [`CudaRtError::SymbolNotFound`] if that
///   yields nothing.
pub fn module_get_function(module: CudaModule, name: &str) -> CudaRtResult<CudaFunction> {
    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;

    let Ok(symbol) = std::ffi::CString::new(name) else {
        return Err(CudaRtError::InvalidSymbol);
    };

    let mut handle = CUfunction::default();
    // SAFETY: `handle` is a live local for the driver to write into and
    // `symbol` is a NUL-terminated string that outlives the call.
    let status =
        unsafe { (api.cu_module_get_function)(&raw mut handle, module, symbol.as_ptr()) };

    if status == 0 {
        Ok(handle)
    } else {
        Err(CudaRtError::from_code(status).unwrap_or(CudaRtError::SymbolNotFound))
    }
}
/// Asks the driver to unload `module`.
///
/// # Errors
///
/// * [`CudaRtError::DriverNotAvailable`] when the driver API table cannot be
///   obtained.
/// * When the driver returns a non-zero status: the error produced by
///   `CudaRtError::from_code`, or [`CudaRtError::InvalidResourceHandle`] if
///   that yields nothing.
pub fn module_unload(module: CudaModule) -> CudaRtResult<()> {
    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;

    // SAFETY: the handle is passed by value; whether it is still valid is
    // judged by the driver, which reports problems through the status code.
    let status = unsafe { (api.cu_module_unload)(module) };

    if status == 0 {
        return Ok(());
    }
    Err(CudaRtError::from_code(status).unwrap_or(CudaRtError::InvalidResourceHandle))
}
pub unsafe fn launch_kernel(
func: CudaFunction,
grid: Dim3,
block: Dim3,
args: &mut [*mut c_void],
shared_mem: u32,
stream: CudaStream,
) -> CudaRtResult<()> {
let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
let rc = unsafe {
(api.cu_launch_kernel)(
func,
grid.x,
grid.y,
grid.z,
block.x,
block.y,
block.z,
shared_mem,
stream.raw(),
args.as_mut_ptr(),
std::ptr::null_mut(), )
};
if rc != 0 {
return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::LaunchFailure));
}
Ok(())
}
/// Queries every attribute in [`FuncAttributes`] for `func`, one driver call
/// per field, in field order.
///
/// Integer results are converted with `as`, so a negative value reported by
/// the driver would wrap rather than be rejected.
///
/// # Errors
///
/// * [`CudaRtError::DriverNotAvailable`] when the driver API table cannot be
///   obtained.
/// * [`CudaRtError::NotSupported`] when the driver table has no
///   get-attribute entry point.
/// * For the first query that returns a non-zero status: the error produced
///   by `CudaRtError::from_code`, or [`CudaRtError::InvalidDeviceFunction`]
///   if that yields nothing. Later attributes are then not queried.
pub fn func_get_attributes(func: CudaFunction) -> CudaRtResult<FuncAttributes> {
    use oxicuda_driver::ffi::CUfunction_attribute as Attr;

    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
    let Some(raw_get) = api.cu_func_get_attribute else {
        return Err(CudaRtError::NotSupported);
    };

    // Fetches a single integer attribute of `func`.
    let query = |which: Attr| -> CudaRtResult<i32> {
        let mut value: std::ffi::c_int = 0;
        // SAFETY: `value` is a live local the driver writes the result into.
        let status = unsafe { raw_get(&raw mut value, which as std::ffi::c_int, func) };
        if status == 0 {
            Ok(value)
        } else {
            Err(CudaRtError::from_code(status).unwrap_or(CudaRtError::InvalidDeviceFunction))
        }
    };

    // Queried in struct-field order so the first failure reported is the same
    // one a field-by-field struct literal would hit.
    let shared_size_bytes = query(Attr::SharedSizeBytes)? as usize;
    let const_size_bytes = query(Attr::ConstSizeBytes)? as usize;
    let local_size_bytes = query(Attr::LocalSizeBytes)? as usize;
    let max_threads_per_block = query(Attr::MaxThreadsPerBlock)? as u32;
    let num_regs = query(Attr::NumRegs)? as u32;
    let ptx_version = query(Attr::PtxVersion)? as u32;
    let binary_version = query(Attr::BinaryVersion)? as u32;
    let cache_mode_ca = query(Attr::CacheModeCa)? != 0;
    let max_dynamic_shared_size_bytes = query(Attr::MaxDynamicSharedSizeBytes)? as usize;
    let preferred_shared_memory_carveout = query(Attr::PreferredSharedMemoryCarveout)?;

    Ok(FuncAttributes {
        shared_size_bytes,
        const_size_bytes,
        local_size_bytes,
        max_threads_per_block,
        num_regs,
        ptx_version,
        binary_version,
        cache_mode_ca,
        max_dynamic_shared_size_bytes,
        preferred_shared_memory_carveout,
    })
}
/// Sets one writable attribute of `func` to `value`.
///
/// The variant of `attr` is cast to its numeric discriminant and passed to
/// the driver together with `value`, both unchanged.
///
/// # Errors
///
/// * [`CudaRtError::DriverNotAvailable`] when the driver API table cannot be
///   obtained.
/// * [`CudaRtError::NotSupported`] when the driver table has no
///   set-attribute entry point.
/// * When the driver returns a non-zero status: the error produced by
///   `CudaRtError::from_code`, or [`CudaRtError::InvalidDeviceFunction`] if
///   that yields nothing.
pub fn func_set_attribute(func: CudaFunction, attr: FuncAttribute, value: i32) -> CudaRtResult<()> {
    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
    let Some(raw_set) = api.cu_func_set_attribute else {
        return Err(CudaRtError::NotSupported);
    };

    let attr_id = attr as std::ffi::c_int;
    // SAFETY: only plain values are passed; the driver validates the handle
    // and reports problems through the status code.
    let status = unsafe { raw_set(func, attr_id, value) };

    if status == 0 {
        Ok(())
    } else {
        Err(CudaRtError::from_code(status).unwrap_or(CudaRtError::InvalidDeviceFunction))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `one_d` keeps `x` and fills the other two axes with 1.
    #[test]
    fn dim3_one_d() {
        let dims = Dim3::one_d(128);
        assert_eq!((dims.x, dims.y, dims.z), (128, 1, 1));
        assert_eq!(dims.volume(), 128);
    }

    /// A bare `u32` converts into an extent whose `x` is that value.
    #[test]
    fn dim3_from_u32() {
        let dims = Dim3::from(256u32);
        assert_eq!(dims.x, 256);
    }

    /// Pairs and triples convert into extents with the expected volume.
    #[test]
    fn dim3_from_tuple() {
        let planar = Dim3::from((32u32, 8u32));
        assert_eq!(planar.volume(), 256);

        let cubic = Dim3::from((4u32, 4u32, 4u32));
        assert_eq!(cubic.volume(), 64);
    }

    /// `volume` is the product of all three components.
    #[test]
    fn dim3_volume() {
        let dims = Dim3::three_d(2, 3, 4);
        assert_eq!(dims.volume(), 24);
    }

    /// Smoke test: calling the loader must not panic. The result is
    /// deliberately ignored, so the test passes whether or not the call
    /// succeeds on the machine running it.
    #[test]
    fn module_load_ptx_without_gpu_errors() {
        let ptx = b"// empty\n\0";
        let _ = module_load_ptx(ptx);
    }
}