use crate::context::{CacheConfig, SharedMemoryConfig};
use crate::error::{CudaResult, ToResult};
use crate::module::Module;
use crate::sys::{self as cuda, CUfunction};
use std::marker::PhantomData;
use std::mem::{transmute, MaybeUninit};
/// Extent of a launch grid along up to three axes.
///
/// The occupancy helpers in this file multiply the three components together to
/// obtain a total block count, so each component counts thread blocks on one axis.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GridSize {
    /// Extent along the x axis.
    pub x: u32,
    /// Extent along the y axis.
    pub y: u32,
    /// Extent along the z axis.
    pub z: u32,
}

impl GridSize {
    /// Creates a one-dimensional size; the y and z extents are set to `1`.
    #[inline]
    pub fn x(x: u32) -> GridSize {
        Self::xyz(x, 1, 1)
    }

    /// Creates a two-dimensional size; the z extent is set to `1`.
    #[inline]
    pub fn xy(x: u32, y: u32) -> GridSize {
        Self::xyz(x, y, 1)
    }

    /// Creates a three-dimensional size from explicit extents.
    #[inline]
    pub fn xyz(x: u32, y: u32, z: u32) -> GridSize {
        Self { x, y, z }
    }
}

/// A bare `u32` is interpreted as the x extent.
impl From<u32> for GridSize {
    fn from(x: u32) -> GridSize {
        Self { x, y: 1, z: 1 }
    }
}

/// A pair is interpreted as `(x, y)`.
impl From<(u32, u32)> for GridSize {
    fn from(dims: (u32, u32)) -> GridSize {
        let (x, y) = dims;
        Self { x, y, z: 1 }
    }
}

/// A triple is interpreted as `(x, y, z)`.
impl From<(u32, u32, u32)> for GridSize {
    fn from(dims: (u32, u32, u32)) -> GridSize {
        let (x, y, z) = dims;
        Self { x, y, z }
    }
}

/// Copies the referenced value, so `&GridSize` can be passed wherever
/// `Into<GridSize>` is accepted.
impl<'a> From<&'a GridSize> for GridSize {
    fn from(other: &GridSize) -> GridSize {
        Self {
            x: other.x,
            y: other.y,
            z: other.z,
        }
    }
}
/// Maps a two-component vector onto the x and y extents; z is set to `1`.
#[cfg(feature = "vek")]
impl From<vek::Vec2<u32>> for GridSize {
    fn from(vec: vek::Vec2<u32>) -> Self {
        let (x, y) = (vec.x, vec.y);
        Self { x, y, z: 1 }
    }
}
/// Maps a three-component vector onto the x, y and z extents.
#[cfg(feature = "vek")]
impl From<vek::Vec3<u32>> for GridSize {
    fn from(vec: vek::Vec3<u32>) -> Self {
        let (x, y, z) = (vec.x, vec.y, vec.z);
        Self { x, y, z }
    }
}
/// Maps a two-component `usize` vector onto the x and y extents; z is set to `1`.
///
/// Components larger than `u32::MAX` are clamped to `u32::MAX` instead of being
/// truncated, so an oversized request never silently turns into a small grid.
#[cfg(feature = "vek")]
impl From<vek::Vec2<usize>> for GridSize {
    fn from(vec: vek::Vec2<usize>) -> Self {
        // A plain `as u32` keeps only the low 32 bits (e.g. 2^32 + 1 becomes 1).
        let clamp = |v: usize| v.min(u32::MAX as usize) as u32;
        GridSize::xy(clamp(vec.x), clamp(vec.y))
    }
}
/// Maps a three-component `usize` vector onto the x, y and z extents.
///
/// Components larger than `u32::MAX` are clamped to `u32::MAX` instead of being
/// truncated, so an oversized request never silently turns into a small grid.
#[cfg(feature = "vek")]
impl From<vek::Vec3<usize>> for GridSize {
    fn from(vec: vek::Vec3<usize>) -> Self {
        // A plain `as u32` keeps only the low 32 bits (e.g. 2^32 + 1 becomes 1).
        let clamp = |v: usize| v.min(u32::MAX as usize) as u32;
        GridSize::xyz(clamp(vec.x), clamp(vec.y), clamp(vec.z))
    }
}
/// Extent of a single thread block along up to three axes.
///
/// The occupancy helpers in this file multiply the three components together to
/// obtain the total block size handed to the driver.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BlockSize {
    /// Extent along the x axis.
    pub x: u32,
    /// Extent along the y axis.
    pub y: u32,
    /// Extent along the z axis.
    pub z: u32,
}

impl BlockSize {
    /// Creates a one-dimensional size; the y and z extents are set to `1`.
    #[inline]
    pub fn x(x: u32) -> BlockSize {
        Self::xyz(x, 1, 1)
    }

    /// Creates a two-dimensional size; the z extent is set to `1`.
    #[inline]
    pub fn xy(x: u32, y: u32) -> BlockSize {
        Self::xyz(x, y, 1)
    }

    /// Creates a three-dimensional size from explicit extents.
    #[inline]
    pub fn xyz(x: u32, y: u32, z: u32) -> BlockSize {
        Self { x, y, z }
    }
}

/// A bare `u32` is interpreted as the x extent.
impl From<u32> for BlockSize {
    fn from(x: u32) -> BlockSize {
        Self { x, y: 1, z: 1 }
    }
}

/// A pair is interpreted as `(x, y)`.
impl From<(u32, u32)> for BlockSize {
    fn from(dims: (u32, u32)) -> BlockSize {
        let (x, y) = dims;
        Self { x, y, z: 1 }
    }
}

/// A triple is interpreted as `(x, y, z)`.
impl From<(u32, u32, u32)> for BlockSize {
    fn from(dims: (u32, u32, u32)) -> BlockSize {
        let (x, y, z) = dims;
        Self { x, y, z }
    }
}

/// Copies the referenced value, so `&BlockSize` can be passed wherever
/// `Into<BlockSize>` is accepted.
impl<'a> From<&'a BlockSize> for BlockSize {
    fn from(other: &BlockSize) -> BlockSize {
        Self {
            x: other.x,
            y: other.y,
            z: other.z,
        }
    }
}
/// Maps a two-component vector onto the x and y extents; z is set to `1`.
#[cfg(feature = "vek")]
impl From<vek::Vec2<u32>> for BlockSize {
    fn from(vec: vek::Vec2<u32>) -> Self {
        let (x, y) = (vec.x, vec.y);
        Self { x, y, z: 1 }
    }
}
/// Maps a three-component vector onto the x, y and z extents.
#[cfg(feature = "vek")]
impl From<vek::Vec3<u32>> for BlockSize {
    fn from(vec: vek::Vec3<u32>) -> Self {
        let (x, y, z) = (vec.x, vec.y, vec.z);
        Self { x, y, z }
    }
}
/// Maps a two-component `usize` vector onto the x and y extents; z is set to `1`.
///
/// Components larger than `u32::MAX` are clamped to `u32::MAX` instead of being
/// truncated, so an oversized request never silently turns into a small block.
#[cfg(feature = "vek")]
impl From<vek::Vec2<usize>> for BlockSize {
    fn from(vec: vek::Vec2<usize>) -> Self {
        // A plain `as u32` keeps only the low 32 bits (e.g. 2^32 + 1 becomes 1).
        let clamp = |v: usize| v.min(u32::MAX as usize) as u32;
        BlockSize::xy(clamp(vec.x), clamp(vec.y))
    }
}
/// Maps a three-component `usize` vector onto the x, y and z extents.
///
/// Components larger than `u32::MAX` are clamped to `u32::MAX` instead of being
/// truncated, so an oversized request never silently turns into a small block.
#[cfg(feature = "vek")]
impl From<vek::Vec3<usize>> for BlockSize {
    fn from(vec: vek::Vec3<usize>) -> Self {
        // A plain `as u32` keeps only the low 32 bits (e.g. 2^32 + 1 becomes 1).
        let clamp = |v: usize| v.min(u32::MAX as usize) as u32;
        BlockSize::xyz(clamp(vec.x), clamp(vec.y), clamp(vec.z))
    }
}
/// Selector for the per-function value returned by `Function::get_attribute`.
///
/// The enum is `#[repr(u32)]` and `get_attribute` transmutes it straight into the
/// driver's attribute argument, so the numeric discriminants below are part of the
/// FFI contract and must not be renumbered.
///
/// NOTE(review): the per-variant descriptions are inferred from the variant names
/// and the matching `CU_FUNC_ATTRIBUTE_*` driver constants; confirm against the
/// CUDA driver documentation.
#[repr(u32)]
#[non_exhaustive]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum FunctionAttribute {
/// Presumably the maximum number of threads per block the function can be launched with.
MaxThreadsPerBlock = 0,
/// Presumably the size in bytes of the shared memory the function requires.
SharedMemorySizeBytes = 1,
/// Presumably the size in bytes of the constant memory the function requires.
ConstSizeBytes = 2,
/// Presumably the size in bytes of the local memory used by each thread.
LocalSizeBytes = 3,
/// Presumably the number of registers used by each thread.
NumRegisters = 4,
/// Presumably the PTX virtual architecture version the function was compiled for.
PtxVersion = 5,
/// Presumably the binary architecture version the function was compiled for.
BinaryVersion = 6,
/// Presumably a flag describing the cache mode the function was compiled with.
CacheModeCa = 7,
}
/// Handle to a function belonging to a loaded `Module`.
///
/// Stores the raw `CUfunction` together with a `PhantomData<&'a Module>`, so the
/// handle is treated by the borrow checker as borrowing a `Module` for `'a` even
/// though no reference is actually stored.
#[derive(Debug)]
pub struct Function<'a> {
// Raw driver handle passed to every FFI call made by this type.
inner: CUfunction,
// Zero-sized marker carrying the `'a` borrow of the owning module.
module: PhantomData<&'a Module>,
}
// NOTE(review): these impls assert that the raw handle may be moved to and shared
// between threads. The justification is not visible in this file; confirm that the
// driver permits using a `CUfunction` from any thread.
unsafe impl Send for Function<'_> {}
unsafe impl Sync for Function<'_> {}
impl<'a> Function<'a> {
    /// Wraps a raw function handle.
    ///
    /// `_module` is never read; it only lends its lifetime to the returned handle.
    pub(crate) fn new(inner: CUfunction, _module: &Module) -> Function<'_> {
        Function {
            inner,
            module: PhantomData,
        }
    }

    /// Multiplies three launch extents and clamps the result to the `i32` range
    /// taken by the occupancy entry points.
    ///
    /// A plain `x * y * z` on `u32` panics in debug builds and wraps in release
    /// builds when the product overflows, and a following `as i32` turns anything
    /// above `i32::MAX` into a negative number. Saturating keeps an oversized
    /// request oversized instead of turning it into an unrelated small value.
    fn clamped_extent_product(x: u32, y: u32, z: u32) -> i32 {
        let product = x.saturating_mul(y).saturating_mul(z);
        product.min(i32::MAX as u32) as i32
    }

    /// Queries a single attribute of this function via `cuFuncGetAttribute`.
    ///
    /// # Errors
    ///
    /// Returns the converted driver status if the call does not succeed.
    pub fn get_attribute(&self, attr: FunctionAttribute) -> CudaResult<i32> {
        let mut val = 0i32;
        // SAFETY: the out-pointer refers to a live local. NOTE(review): the transmute
        // relies on `FunctionAttribute` having the same representation and values as
        // the driver's attribute enum; confirm against `crate::sys`.
        unsafe {
            cuda::cuFuncGetAttribute(&mut val as *mut i32, transmute(attr), self.inner)
                .to_result()?;
        }
        Ok(val)
    }

    /// Sets the cache configuration for this function via `cuFuncSetCacheConfig`.
    ///
    /// # Errors
    ///
    /// Returns the converted driver status if the call does not succeed.
    pub fn set_cache_config(&mut self, config: CacheConfig) -> CudaResult<()> {
        // SAFETY: NOTE(review): relies on `CacheConfig` being layout-compatible with
        // the driver's cache-config enum; confirm against `crate::sys`.
        unsafe { cuda::cuFuncSetCacheConfig(self.inner, transmute(config)).to_result() }
    }

    /// Sets the shared memory configuration for this function via
    /// `cuFuncSetSharedMemConfig`.
    ///
    /// # Errors
    ///
    /// Returns the converted driver status if the call does not succeed.
    pub fn set_shared_memory_config(&mut self, cfg: SharedMemoryConfig) -> CudaResult<()> {
        // SAFETY: NOTE(review): relies on `SharedMemoryConfig` being layout-compatible
        // with the driver's shared-memory-config enum; confirm against `crate::sys`.
        unsafe { cuda::cuFuncSetSharedMemConfig(self.inner, transmute(cfg)).to_result() }
    }

    /// Returns the raw driver handle without transferring ownership.
    pub fn to_raw(&self) -> CUfunction {
        self.inner
    }

    /// Asks `cuOccupancyAvailableDynamicSMemPerBlock` how much dynamic shared
    /// memory is available per block.
    ///
    /// The driver receives the product of the `blocks` extents as the block count
    /// and the product of the `block_size` extents as the block size. Products that
    /// do not fit in an `i32` are clamped to `i32::MAX`.
    ///
    /// # Errors
    ///
    /// Returns the converted driver status if the call does not succeed.
    pub fn available_dynamic_shared_memory_per_block(
        &self,
        blocks: GridSize,
        block_size: BlockSize,
    ) -> CudaResult<usize> {
        let num_blocks = Self::clamped_extent_product(blocks.x, blocks.y, blocks.z);
        let total_block_size =
            Self::clamped_extent_product(block_size.x, block_size.y, block_size.z);
        let mut result = MaybeUninit::uninit();
        // SAFETY: the out-pointer refers to a live local, and the value is only read
        // after the driver reported success (the `?` returns early otherwise).
        unsafe {
            cuda::cuOccupancyAvailableDynamicSMemPerBlock(
                result.as_mut_ptr(),
                self.to_raw(),
                num_blocks,
                total_block_size,
            )
            .to_result()?;
            Ok(result.assume_init())
        }
    }

    /// Asks `cuOccupancyMaxActiveBlocksPerMultiprocessor` for the block count
    /// matching the given block size and dynamic shared memory size.
    ///
    /// The driver receives the product of the `block_size` extents, clamped to
    /// `i32::MAX` if it does not fit.
    ///
    /// # Errors
    ///
    /// Returns the converted driver status if the call does not succeed.
    pub fn max_active_blocks_per_multiprocessor(
        &self,
        block_size: BlockSize,
        dynamic_smem_size: usize,
    ) -> CudaResult<u32> {
        let total_block_size =
            Self::clamped_extent_product(block_size.x, block_size.y, block_size.z);
        let mut num_blocks = MaybeUninit::uninit();
        // SAFETY: the out-pointer refers to a live local, and the value is only read
        // after the driver reported success (the `?` returns early otherwise).
        unsafe {
            cuda::cuOccupancyMaxActiveBlocksPerMultiprocessor(
                num_blocks.as_mut_ptr(),
                self.to_raw(),
                total_block_size,
                dynamic_smem_size,
            )
            .to_result()?;
            // NOTE(review): assumes the driver never reports a negative count.
            Ok(num_blocks.assume_init() as u32)
        }
    }

    /// Asks `cuOccupancyMaxPotentialBlockSize` for a launch configuration.
    ///
    /// Returns `(min_grid_size, block_size)` as written by the driver. The product
    /// of the `block_size_limit` extents is passed as the block size limit, clamped
    /// to `i32::MAX` if it does not fit. `None` is passed for the driver's callback
    /// argument.
    ///
    /// # Errors
    ///
    /// Returns the converted driver status if the call does not succeed.
    pub fn suggested_launch_configuration(
        &self,
        dynamic_smem_size: usize,
        block_size_limit: BlockSize,
    ) -> CudaResult<(u32, u32)> {
        let total_block_size_limit = Self::clamped_extent_product(
            block_size_limit.x,
            block_size_limit.y,
            block_size_limit.z,
        );
        let mut min_grid_size = MaybeUninit::uninit();
        let mut block_size = MaybeUninit::uninit();
        // SAFETY: both out-pointers refer to live locals, and the values are only
        // read after the driver reported success (the `?` returns early otherwise).
        unsafe {
            cuda::cuOccupancyMaxPotentialBlockSize(
                min_grid_size.as_mut_ptr(),
                block_size.as_mut_ptr(),
                self.to_raw(),
                None,
                dynamic_smem_size,
                total_block_size_limit,
            )
            .to_result()?;
            // NOTE(review): assumes the driver never reports negative sizes.
            Ok((
                min_grid_size.assume_init() as u32,
                block_size.assume_init() as u32,
            ))
        }
    }
}
/// Launches a function on a stream using `<<<grid, block, shared, stream>>>` syntax.
///
/// Two forms are accepted:
///
/// * `launch!(module.function<<<grid, block, shared, stream>>>(args...))` first calls
///   `module.get_function("function")` and evaluates to that `Err` if the lookup
///   fails; otherwise it continues as the second form.
/// * `launch!(function<<<grid, block, shared, stream>>>(args...))` expands to
///   `stream.launch(&function, grid, block, shared, &[...])`, where the slice holds
///   one type-erased pointer per argument expression.
///
/// Every argument must implement `DeviceCopy`; this is checked at compile time.
/// A trailing comma after the last argument is allowed.
///
/// NOTE(review): whether the expansion must be wrapped in `unsafe` depends on how
/// the stream's `launch` method is declared, which is not visible in this file.
#[macro_export]
macro_rules! launch {
    ($module:ident . $function:ident <<<$grid:expr, $block:expr, $shared:expr, $stream:ident>>>( $( $arg:expr),* $(,)?)) => {
        {
            let function = $module.get_function(stringify!($function));
            match function {
                // `$crate::` keeps the recursive call resolvable even when the caller
                // invoked the macro by path without importing it into scope.
                Ok(f) => $crate::launch!(f<<<$grid, $block, $shared, $stream>>>( $($arg),* ) ),
                Err(e) => Err(e),
            }
        }
    };
    ($function:ident <<<$grid:expr, $block:expr, $shared:expr, $stream:ident>>>( $( $arg:expr),* $(,)?)) => {
        {
            // Never executed: the dead branch only makes the compiler check that each
            // argument type implements `DeviceCopy`.
            fn assert_impl_devicecopy<T: $crate::memory::DeviceCopy>(_val: T) {}
            if false {
                $(
                    assert_impl_devicecopy($arg);
                )*
            };
            $stream.launch(&$function, $grid, $block, $shared,
                &[
                    $(
                        &$arg as *const _ as *mut ::std::ffi::c_void,
                    )*
                ]
            )
        }
    };
}