use crate::context::{CacheConfig, SharedMemoryConfig};
use crate::error::{CudaResult, ToResult};
use crate::module::Module;
use cuda_driver_sys::CUfunction;
use std::marker::PhantomData;
use std::mem::transmute;
/// Grid dimensions for a kernel launch (the first `<<<...>>>` argument of `launch!`).
///
/// Dimensions that are not specified by a constructor or conversion default to `1`.
/// A `GridSize` can be produced from a `u32`, a `(u32, u32)` pair, a `(u32, u32, u32)`
/// triple, or a reference to an existing `GridSize`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GridSize {
    /// Extent along the X axis.
    pub x: u32,
    /// Extent along the Y axis.
    pub y: u32,
    /// Extent along the Z axis.
    pub z: u32,
}

impl GridSize {
    /// One-dimensional grid: `x` along X, `1` along Y and Z.
    #[inline]
    pub fn x(x: u32) -> GridSize {
        GridSize::xyz(x, 1, 1)
    }

    /// Two-dimensional grid: `x` along X, `y` along Y, `1` along Z.
    #[inline]
    pub fn xy(x: u32, y: u32) -> GridSize {
        GridSize::xyz(x, y, 1)
    }

    /// Three-dimensional grid with every extent given explicitly.
    #[inline]
    pub fn xyz(x: u32, y: u32, z: u32) -> GridSize {
        GridSize { x, y, z }
    }
}

/// `n` becomes an `n x 1 x 1` grid.
impl From<u32> for GridSize {
    fn from(width: u32) -> GridSize {
        GridSize { x: width, y: 1, z: 1 }
    }
}

/// `(x, y)` becomes an `x x y x 1` grid.
impl From<(u32, u32)> for GridSize {
    fn from(pair: (u32, u32)) -> GridSize {
        let (width, height) = pair;
        GridSize { x: width, y: height, z: 1 }
    }
}

/// `(x, y, z)` maps field-for-field.
impl From<(u32, u32, u32)> for GridSize {
    fn from(triple: (u32, u32, u32)) -> GridSize {
        let (width, height, depth) = triple;
        GridSize { x: width, y: height, z: depth }
    }
}

/// Copies an existing `GridSize`, so `&GridSize` is accepted wherever `Into<GridSize>` is.
impl<'a> From<&'a GridSize> for GridSize {
    fn from(source: &'a GridSize) -> GridSize {
        GridSize::xyz(source.x, source.y, source.z)
    }
}
/// Block dimensions for a kernel launch (the second `<<<...>>>` argument of `launch!`).
///
/// Dimensions that are not specified by a constructor or conversion default to `1`.
/// A `BlockSize` can be produced from a `u32`, a `(u32, u32)` pair, a `(u32, u32, u32)`
/// triple, or a reference to an existing `BlockSize`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BlockSize {
    /// Extent along the X axis.
    pub x: u32,
    /// Extent along the Y axis.
    pub y: u32,
    /// Extent along the Z axis.
    pub z: u32,
}

impl BlockSize {
    /// One-dimensional block: `x` along X, `1` along Y and Z.
    #[inline]
    pub fn x(x: u32) -> BlockSize {
        Self::xyz(x, 1, 1)
    }

    /// Two-dimensional block: `x` along X, `y` along Y, `1` along Z.
    #[inline]
    pub fn xy(x: u32, y: u32) -> BlockSize {
        Self::xyz(x, y, 1)
    }

    /// Three-dimensional block with every extent given explicitly.
    #[inline]
    pub fn xyz(x: u32, y: u32, z: u32) -> BlockSize {
        Self { x, y, z }
    }
}

/// `n` becomes an `n x 1 x 1` block.
impl From<u32> for BlockSize {
    fn from(n: u32) -> BlockSize {
        Self::x(n)
    }
}

/// `(x, y)` becomes an `x x y x 1` block.
impl From<(u32, u32)> for BlockSize {
    fn from(dims: (u32, u32)) -> BlockSize {
        Self::xy(dims.0, dims.1)
    }
}

/// `(x, y, z)` maps field-for-field.
impl From<(u32, u32, u32)> for BlockSize {
    fn from(dims: (u32, u32, u32)) -> BlockSize {
        Self::xyz(dims.0, dims.1, dims.2)
    }
}

/// Copies an existing `BlockSize`, so `&BlockSize` is accepted wherever `Into<BlockSize>` is.
impl<'a> From<&'a BlockSize> for BlockSize {
    fn from(src: &'a BlockSize) -> BlockSize {
        Self {
            x: src.x,
            y: src.y,
            z: src.z,
        }
    }
}
/// Attributes that can be queried with `Function::get_attribute`.
///
/// `get_attribute` passes a variant to `cuFuncGetAttribute` via `transmute`, so each
/// discriminant must equal the driver's `CUfunction_attribute` value. Do not renumber.
/// The driver names listed per variant are the presumed counterparts; confirm them
/// against the CUDA headers in use.
#[repr(u32)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum FunctionAttribute {
    /// Presumed driver counterpart: `CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK`.
    MaxThreadsPerBlock = 0,
    /// Presumed driver counterpart: `CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES`.
    SharedMemorySizeBytes = 1,
    /// Presumed driver counterpart: `CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES`.
    ConstSizeBytes = 2,
    /// Presumed driver counterpart: `CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES`.
    LocalSizeBytes = 3,
    /// Presumed driver counterpart: `CU_FUNC_ATTRIBUTE_NUM_REGS`.
    NumRegisters = 4,
    /// Presumed driver counterpart: `CU_FUNC_ATTRIBUTE_PTX_VERSION`.
    PtxVersion = 5,
    /// Presumed driver counterpart: `CU_FUNC_ATTRIBUTE_BINARY_VERSION`.
    BinaryVersion = 6,
    /// Presumed driver counterpart: `CU_FUNC_ATTRIBUTE_CACHE_MODE_CA`.
    CacheModeCa = 7,
    // Hidden placeholder; looks like the pre-`#[non_exhaustive]` idiom for discouraging
    // exhaustive matches downstream.
    // NOTE(review): 8 is not a sentinel to the driver. In newer CUDA headers it is
    // `CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES`, so passing this variant to
    // `get_attribute` would presumably query that attribute. Confirm.
    #[doc(hidden)]
    __Nonexhaustive = 8,
}
/// Handle to a kernel function obtained from a [`Module`].
///
/// The `'a` lifetime is carried by a `PhantomData<&'a Module>`, so a `Function` cannot
/// outlive the `Module` borrow it was created from.
#[derive(Debug)]
pub struct Function<'a> {
    /// Raw driver handle for the function.
    inner: CUfunction,
    /// Zero-sized marker tying this handle to the lifetime of its module borrow.
    module: PhantomData<&'a Module>,
}

impl<'a> Function<'a> {
    /// Wraps a raw `CUfunction` handle.
    ///
    /// `_module` is not stored. It is taken only so that the returned `Function` borrows
    /// from it (elided lifetime).
    /// NOTE(review): callers are expected to pass the module that `inner` was loaded from;
    /// nothing here checks that.
    pub(crate) fn new(inner: CUfunction, _module: &Module) -> Function<'_> {
        Function {
            inner,
            module: PhantomData,
        }
    }

    /// Queries `attr` for this function with `cuFuncGetAttribute` and returns the raw
    /// `i32` the driver writes.
    ///
    /// # Errors
    ///
    /// Propagates any error that `to_result()` produces from the driver's status code.
    pub fn get_attribute(&self, attr: FunctionAttribute) -> CudaResult<i32> {
        let mut val = 0i32;
        // SAFETY: `val` is a live local for the whole call, so the out-pointer is valid and
        // writable. `FunctionAttribute` is `#[repr(u32)]`, and the transmute relies on its
        // discriminants matching the driver's attribute enum. `transmute` itself enforces
        // equal sizes at compile time.
        let status = unsafe {
            cuda_driver_sys::cuFuncGetAttribute(&mut val as *mut i32, transmute(attr), self.inner)
        };
        status.to_result()?;
        Ok(val)
    }

    /// Sets this function's cache configuration via `cuFuncSetCacheConfig`.
    ///
    /// # Errors
    ///
    /// Propagates any error that `to_result()` produces from the driver's status code.
    pub fn set_cache_config(&mut self, config: CacheConfig) -> CudaResult<()> {
        // SAFETY: `self.inner` is the handle this wrapper was built with. The transmute relies
        // on `CacheConfig` (defined in `crate::context`, not visible here) having the same
        // size and discriminants as the driver's cache-config enum.
        let status = unsafe { cuda_driver_sys::cuFuncSetCacheConfig(self.inner, transmute(config)) };
        status.to_result()
    }

    /// Sets this function's shared-memory configuration via `cuFuncSetSharedMemConfig`.
    ///
    /// # Errors
    ///
    /// Propagates any error that `to_result()` produces from the driver's status code.
    pub fn set_shared_memory_config(&mut self, cfg: SharedMemoryConfig) -> CudaResult<()> {
        // SAFETY: `self.inner` is the handle this wrapper was built with. The transmute relies
        // on `SharedMemoryConfig` (defined in `crate::context`, not visible here) having the
        // same size and discriminants as the driver's shared-memory-config enum.
        let status = unsafe { cuda_driver_sys::cuFuncSetSharedMemConfig(self.inner, transmute(cfg)) };
        status.to_result()
    }

    /// Returns the raw `CUfunction` handle for crate-internal FFI calls.
    pub(crate) fn to_inner(&self) -> CUfunction {
        self.inner
    }
}
/// Launches a kernel using a CUDA-C-like `<<<grid, block, shared, stream>>>` syntax.
///
/// Two forms are accepted:
///
/// * `launch!(module.name<<<grid, block, shared, stream>>>(args...))` calls
///   `module.get_function` with the identifier text `name` as a `CString`. On success it
///   launches the result; a lookup error is returned unchanged.
/// * `launch!(function<<<grid, block, shared, stream>>>(args...))` launches an already
///   obtained function.
///
/// A trailing comma after the last argument is allowed. Each argument must implement
/// `$crate::memory::DeviceCopy`, which is checked at compile time. It is evaluated once at
/// runtime and passed to `stream.launch` as a type-erased pointer to its value. Nothing here
/// checks that the arguments match the kernel's parameter list; that is the caller's
/// responsibility. The macro evaluates to whatever `stream.launch` returns.
#[macro_export]
macro_rules! launch {
    ($module:ident . $function:ident <<<$grid:expr, $block:expr, $shared:expr, $stream:ident>>>( $( $arg:expr ),* $(,)? )) => {
        {
            // `stringify!` of an identifier never contains a NUL byte, so this cannot fail.
            let name = ::std::ffi::CString::new(stringify!($function)).unwrap();
            let function = $module.get_function(&name);
            match function {
                // `$crate::` makes the recursion resolve even when callers invoke this macro
                // by its full path without importing it.
                Ok(f) => $crate::launch!(f<<<$grid, $block, $shared, $stream>>>( $($arg),* )),
                Err(e) => Err(e),
            }
        }
    };
    ($function:ident <<<$grid:expr, $block:expr, $shared:expr, $stream:ident>>>( $( $arg:expr ),* $(,)? )) => {
        {
            // Compile-time only: this branch never runs, but it forces every argument to
            // satisfy the `DeviceCopy` bound.
            fn assert_impl_devicecopy<T: $crate::memory::DeviceCopy>(_val: T) {}
            if false {
                $(
                    assert_impl_devicecopy($arg);
                )*
            };
            // Any temporaries created by `$arg` live at least until `launch` returns, so
            // these pointers stay valid for the duration of the call.
            $stream.launch(&$function, $grid, $block, $shared,
                &[
                    $(
                        &$arg as *const _ as *mut ::std::ffi::c_void,
                    )*
                ]
            )
        }
    };
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::memory::CopyDestination;
    use crate::memory::DeviceBuffer;
    use crate::quick_init;
    use crate::stream::{Stream, StreamFlags};
    use std::error::Error;
    use std::ffi::CString;

    /// Loads `add.ptx`, launches its `sum` kernel over 128 elements (inputs 2.0 and 1.0),
    /// and checks that every output element is exactly 3.0.
    #[test]
    fn test_launch() -> Result<(), Box<dyn Error>> {
        // Bound to a named `_context` (not `_`) so it is not dropped until the test ends.
        // NOTE(review): if `quick_init` returns a `Result`, consider `?` so that init
        // failures surface here rather than at a later call.
        let _context = quick_init();
        let ptx_text = CString::new(include_str!("../resources/add.ptx"))?;
        let module = Module::load_from_string(&ptx_text)?;
        unsafe {
            let mut in_x = DeviceBuffer::from_slice(&[2.0f32; 128])?;
            let mut in_y = DeviceBuffer::from_slice(&[1.0f32; 128])?;
            let mut out: DeviceBuffer<f32> = DeviceBuffer::uninitialized(128)?;
            let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;
            launch!(module.sum<<<1, 128, 0, stream>>>(in_x.as_device_ptr(), in_y.as_device_ptr(), out.as_device_ptr(), out.len()))?;
            stream.synchronize()?;
            let mut out_host = [0f32; 128];
            out.copy_to(&mut out_host[..])?;
            // 3.0 is exactly representable in f32, so compare exactly. The previous
            // `*x as u32` truncation also accepted any value in [3.0, 4.0).
            for (i, &x) in out_host.iter().enumerate() {
                assert_eq!(3.0f32, x, "unexpected output at index {}", i);
            }
        }
        Ok(())
    }
}