use super::{result, sys};
use crate::nvrtc::Ptx;
use alloc::ffi::{CString, NulError};
use spin::RwLock;
use std::{collections::BTreeMap, marker::Unpin, pin::Pin, sync::Arc, vec::Vec};
pub use result::DriverError;
/// An owned, device-side buffer of `len` elements of `T`.
///
/// The underlying device allocation is freed asynchronously when the slice is
/// dropped (see the `Drop` impl below). `host_buf` optionally keeps alive the
/// pinned host `Vec` that an in-flight async host-to-device copy reads from
/// (set by `CudaDevice::copy_into_async`).
#[derive(Debug)]
pub struct CudaSlice<T> {
    // Raw device pointer to the start of the allocation.
    pub(crate) cu_device_ptr: sys::CUdeviceptr,
    // Number of elements (not bytes) in the allocation.
    pub(crate) len: usize,
    // Keeps the owning device (context/stream) alive as long as this buffer.
    pub(crate) device: Arc<CudaDevice>,
    // Pinned host source of a pending async copy, if any.
    pub(crate) host_buf: Option<Pin<Vec<T>>>,
}
// SAFETY: the device pointer is only used through CUDA driver calls and
// ownership of the allocation moves with the struct, so the wrapper may cross
// threads whenever `T` itself may. NOTE(review): this assumes the driver
// handles reached via `device` are thread-safe — confirm against the CUDA
// driver API's threading guarantees.
unsafe impl<T: Send> Send for CudaSlice<T> {}
unsafe impl<T: Sync> Sync for CudaSlice<T> {}
impl<T> CudaSlice<T> {
    /// Number of elements held by this device buffer.
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` when the buffer holds zero elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Total size of the allocation in bytes.
    pub fn num_bytes(&self) -> usize {
        std::mem::size_of::<T>() * self.len
    }

    /// Allocates a new buffer on the same device and schedules an
    /// asynchronous device-to-device copy of this buffer's contents onto the
    /// device's stream. The copy may still be in flight when this returns.
    pub fn clone_async(&self) -> Result<Self, DriverError> {
        // SAFETY: the fresh allocation is fully overwritten by the copy below.
        let copy = unsafe { self.device.alloc(self.len) }?;
        unsafe {
            result::memcpy_dtod_async(
                copy.cu_device_ptr,
                self.cu_device_ptr,
                self.num_bytes(),
                self.device.cu_stream,
            )
        }?;
        Ok(copy)
    }
}
impl<T> Clone for CudaSlice<T> {
    /// Device-to-device copy via [`CudaSlice::clone_async`].
    ///
    /// # Panics
    /// Panics if the device allocation or the async copy fails.
    fn clone(&self) -> Self {
        self.clone_async().unwrap()
    }
}
impl<T> Drop for CudaSlice<T> {
    /// Schedules an asynchronous free of the device allocation on the owning
    /// device's stream.
    ///
    /// NOTE(review): the `unwrap` means a failing driver call during a
    /// panic-unwind becomes a double panic (process abort) — confirm this is
    /// the intended trade-off versus silently ignoring the error.
    fn drop(&mut self) {
        unsafe { result::free_async(self.cu_device_ptr, self.device.cu_stream) }.unwrap();
    }
}
impl<T: Clone + Default + Unpin> TryFrom<CudaSlice<T>> for Vec<T> {
    type Error = DriverError;
    /// Consumes the device buffer and copies its contents back to the host,
    /// synchronizing the stream (delegates to `CudaDevice::sync_release`).
    fn try_from(value: CudaSlice<T>) -> Result<Self, Self::Error> {
        // Clone the Arc first so `value` can be moved into `sync_release`.
        value.device.clone().sync_release(value)
    }
}
/// A read-only view into a sub-range of a [`CudaSlice`]; `ptr` is the device
/// pointer with the element offset already applied (see `try_slice`).
#[allow(unused)]
pub struct CudaView<'a, T> {
    slice: &'a CudaSlice<T>,
    ptr: sys::CUdeviceptr,
}
/// Mutable counterpart of [`CudaView`]; the exclusive borrow prevents any
/// other view of the same slice from existing while this one is alive.
#[allow(unused)]
pub struct CudaViewMut<'a, T> {
    slice: &'a mut CudaSlice<T>,
    ptr: sys::CUdeviceptr,
}
impl<T> CudaSlice<T> {
    /// Returns a read-only view beginning `range.start` elements into the
    /// buffer, or `None` when the start index is out of bounds.
    pub fn try_slice(&self, range: std::ops::RangeFrom<usize>) -> Option<CudaView<'_, T>> {
        (range.start < self.len).then(|| CudaView {
            ptr: self.cu_device_ptr + (range.start * std::mem::size_of::<T>()) as u64,
            slice: self,
        })
    }

    /// Mutable counterpart of [`CudaSlice::try_slice`].
    pub fn try_slice_mut(
        &mut self,
        range: std::ops::RangeFrom<usize>,
    ) -> Option<CudaViewMut<'_, T>> {
        (range.start < self.len).then(|| CudaViewMut {
            ptr: self.cu_device_ptr + (range.start * std::mem::size_of::<T>()) as u64,
            slice: self,
        })
    }
}
/// Read access to the raw device pointer of a buffer or view, for passing to
/// driver calls.
pub trait DevicePtr<T> {
    fn device_ptr(&self) -> &sys::CUdeviceptr;
}
impl<T> DevicePtr<T> for CudaSlice<T> {
    fn device_ptr(&self) -> &sys::CUdeviceptr {
        &self.cu_device_ptr
    }
}
impl<'a, T> DevicePtr<T> for CudaView<'a, T> {
    fn device_ptr(&self) -> &sys::CUdeviceptr {
        // `ptr` already includes the view's element offset.
        &self.ptr
    }
}
/// Mutable access to the raw device pointer of a buffer or view.
pub trait DevicePtrMut<T> {
    fn device_ptr_mut(&mut self) -> &mut sys::CUdeviceptr;
}
impl<T> DevicePtrMut<T> for CudaSlice<T> {
    fn device_ptr_mut(&mut self) -> &mut sys::CUdeviceptr {
        &mut self.cu_device_ptr
    }
}
impl<'a, T> DevicePtrMut<T> for CudaViewMut<'a, T> {
    fn device_ptr_mut(&mut self) -> &mut sys::CUdeviceptr {
        // `ptr` already includes the view's element offset.
        &mut self.ptr
    }
}
/// A handle to one CUDA device: its retained primary context, a dedicated
/// non-blocking stream (created in `CudaDeviceBuilder::build`), and the set of
/// loaded PTX modules keyed by module name.
#[derive(Debug)]
pub struct CudaDevice {
    pub(crate) cu_device: sys::CUdevice,
    pub(crate) cu_primary_ctx: sys::CUcontext,
    pub(crate) cu_stream: sys::CUstream,
    // Guarded so modules can be loaded/looked up from multiple threads.
    pub(crate) modules: RwLock<BTreeMap<&'static str, CudaModule>>,
}
// SAFETY: the raw driver handles are only used through `result::*` wrappers,
// and `modules` is protected by the RwLock. NOTE(review): assumes the CUDA
// driver API is thread-safe for these handles — confirm.
unsafe impl Send for CudaDevice {}
unsafe impl Sync for CudaDevice {}
impl Drop for CudaDevice {
    /// Tears down driver state in dependency order: modules first, then the
    /// stream, and finally the primary context.
    fn drop(&mut self) {
        // Drop implies exclusive access, so `get_mut` bypasses locking.
        let modules = RwLock::get_mut(&mut self.modules);
        for (_, module) in modules.iter() {
            unsafe { result::module::unload(module.cu_module) }.unwrap();
        }
        modules.clear();
        // Null the handle before destroying so a re-entrant/double drop path
        // cannot free the same resource twice.
        let stream = std::mem::replace(&mut self.cu_stream, std::ptr::null_mut());
        if !stream.is_null() {
            unsafe { result::stream::destroy(stream) }.unwrap();
        }
        let ctx = std::mem::replace(&mut self.cu_primary_ctx, std::ptr::null_mut());
        if !ctx.is_null() {
            // Releases the retain taken in `CudaDeviceBuilder::build`.
            unsafe { result::primary_ctx::release(self.cu_device) }.unwrap();
        }
    }
}
impl CudaDevice {
    /// Allocates space for `len` elements of `T` on this device's stream.
    ///
    /// # Safety
    /// The returned buffer is uninitialized device memory; callers must write
    /// it before reading, or only use it for `T` where arbitrary bytes are
    /// valid.
    unsafe fn alloc<T>(self: &Arc<Self>, len: usize) -> Result<CudaSlice<T>, DriverError> {
        let cu_device_ptr = result::malloc_async(self.cu_stream, len * std::mem::size_of::<T>())?;
        Ok(CudaSlice {
            cu_device_ptr,
            len,
            device: self.clone(),
            host_buf: None,
        })
    }

    /// Allocates `len` elements and zero-fills them asynchronously.
    /// `T: ValidAsZeroBits` guarantees the all-zero pattern is a valid value.
    pub fn alloc_zeros_async<T: ValidAsZeroBits>(
        self: &Arc<Self>,
        len: usize,
    ) -> Result<CudaSlice<T>, DriverError> {
        let dst = unsafe { self.alloc(len) }?;
        unsafe { result::memset_d8_async(dst.cu_device_ptr, 0, dst.num_bytes(), self.cu_stream) }?;
        Ok(dst)
    }

    /// Moves `src` to the device asynchronously; the host `Vec` is pinned
    /// inside the returned slice so it outlives the in-flight copy.
    pub fn take_async<T: Unpin>(
        self: &Arc<Self>,
        src: Vec<T>,
    ) -> Result<CudaSlice<T>, DriverError> {
        let mut dst = unsafe { self.alloc(src.len()) }?;
        self.copy_into_async(src, &mut dst)?;
        Ok(dst)
    }

    /// Copies `src` to a fresh device buffer, synchronizing before returning.
    pub fn sync_copy<T>(self: &Arc<Self>, src: &[T]) -> Result<CudaSlice<T>, DriverError> {
        let mut dst = unsafe { self.alloc(src.len()) }?;
        self.sync_copy_into(src, &mut dst)?;
        Ok(dst)
    }

    /// Host-to-device copy into an existing buffer, then stream synchronize
    /// (so the borrowed `src` need not outlive the call).
    ///
    /// # Panics
    /// Panics if `src.len() != dst.len()`.
    pub fn sync_copy_into<T>(
        self: &Arc<Self>,
        src: &[T],
        dst: &mut CudaSlice<T>,
    ) -> Result<(), DriverError> {
        assert_eq!(src.len(), dst.len());
        unsafe { result::memcpy_htod_async(dst.cu_device_ptr, src, self.cu_stream) }?;
        self.synchronize()
    }

    /// Asynchronous host-to-device copy: `src` is stored pinned in
    /// `dst.host_buf` first so the host memory stays alive while the copy is
    /// in flight.
    ///
    /// # Panics
    /// Panics if `src.len() != dst.len()`.
    pub fn copy_into_async<T: Unpin>(
        self: &Arc<Self>,
        src: Vec<T>,
        dst: &mut CudaSlice<T>,
    ) -> Result<(), DriverError> {
        assert_eq!(src.len(), dst.len());
        dst.host_buf = Some(Pin::new(src));
        unsafe {
            result::memcpy_htod_async(
                dst.cu_device_ptr,
                dst.host_buf.as_ref().unwrap(),
                self.cu_stream,
            )
        }?;
        Ok(())
    }

    /// Device-to-host copy into `dst`, then stream synchronize.
    ///
    /// # Panics
    /// Panics if `src.len() != dst.len()`.
    pub fn sync_copy_from<T>(
        self: &Arc<Self>,
        src: &CudaSlice<T>,
        dst: &mut [T],
    ) -> Result<(), DriverError> {
        assert_eq!(src.len(), dst.len());
        unsafe { result::memcpy_dtoh_async(dst, src.cu_device_ptr, self.cu_stream) }?;
        self.synchronize()
    }

    /// Consumes a device buffer and returns its contents as a host `Vec`,
    /// reusing the pinned host buffer from `take_async` when present to avoid
    /// a fresh allocation.
    pub fn sync_release<T: Clone + Default + Unpin>(
        self: &Arc<Self>,
        mut src: CudaSlice<T>,
    ) -> Result<Vec<T>, DriverError> {
        let buf = src.host_buf.take();
        let mut buf = buf.unwrap_or_else(|| {
            // No pinned buffer available: build a default-filled host Vec.
            let mut b = Vec::with_capacity(src.len);
            b.resize(src.len, Default::default());
            Pin::new(b)
        });
        self.sync_copy_from(&src, &mut buf)?;
        Ok(Pin::into_inner(buf))
    }

    /// Blocks until all work queued on this device's stream has completed.
    pub fn synchronize(self: &Arc<Self>) -> Result<(), DriverError> {
        unsafe { result::stream::synchronize(self.cu_stream) }
    }

    /// Whether `func_name` was registered under module `module_name`.
    pub fn has_func(self: &Arc<Self>, module_name: &str, func_name: &str) -> bool {
        let modules = self.modules.read();
        modules
            .get(module_name)
            .map_or(false, |module| module.has_func(func_name))
    }

    /// Looks up a registered kernel, bundling it with a clone of this device
    /// handle so the context outlives the function.
    pub fn get_func(self: &Arc<Self>, module_name: &str, func_name: &str) -> Option<CudaFunction> {
        let modules = self.modules.read();
        modules
            .get(module_name)
            .and_then(|m| m.get_func(func_name))
            .map(|cu_function| CudaFunction {
                cu_function,
                device: self.clone(),
            })
    }

    /// Loads a PTX module from a file path and registers `func_names` under
    /// `module_name`, replacing any previously loaded module with that name.
    pub fn load_ptx_from_file(
        self: &Arc<Self>,
        ptx_path: &'static str,
        module_name: &'static str,
        func_names: &[&'static str],
    ) -> Result<(), BuildError> {
        let m = CudaDeviceBuilder::build_module_from_ptx_file(ptx_path, module_name, func_names)?;
        {
            // Scoped so the write lock is released before returning.
            let mut modules = self.modules.write();
            modules.insert(module_name, m);
        }
        Ok(())
    }

    /// Loads an in-memory PTX module and registers `func_names` under
    /// `module_name`, replacing any previously loaded module with that name.
    pub fn load_ptx(
        self: &Arc<Self>,
        ptx: Ptx,
        module_name: &'static str,
        func_names: &[&'static str],
    ) -> Result<(), BuildError> {
        let m = CudaDeviceBuilder::build_module_from_ptx(ptx, module_name, func_names)?;
        {
            // Scoped so the write lock is released before returning.
            let mut modules = self.modules.write();
            modules.insert(module_name, m);
        }
        Ok(())
    }
}
/// A loaded CUDA module together with the function handles resolved at load
/// time (only functions named at load are available later).
#[derive(Debug)]
pub(crate) struct CudaModule {
    pub(crate) cu_module: sys::CUmodule,
    pub(crate) functions: BTreeMap<&'static str, sys::CUfunction>,
}
// SAFETY: the handles are opaque driver pointers only used via `result::*`.
// NOTE(review): assumes driver-level thread-safety for module handles — confirm.
unsafe impl Send for CudaModule {}
unsafe impl Sync for CudaModule {}
impl CudaModule {
    /// Returns the cached function handle if `name` was registered at load.
    pub(crate) fn get_func(&self, name: &str) -> Option<sys::CUfunction> {
        self.functions.get(name).cloned()
    }
    /// Whether `name` was registered when this module was built.
    pub(crate) fn has_func(&self, name: &str) -> bool {
        self.functions.contains_key(name)
    }
}
/// A kernel handle paired with its owning device, so the context and stream
/// stay alive for the duration of any launch (see `LaunchAsync`).
#[derive(Debug, Clone)]
pub struct CudaFunction {
    pub(crate) cu_function: sys::CUfunction,
    pub(crate) device: Arc<CudaDevice>,
}
// SAFETY: the raw handle is only used through `result::launch_kernel`.
// NOTE(review): assumes driver-level thread-safety for function handles — confirm.
unsafe impl Send for CudaFunction {}
unsafe impl Sync for CudaFunction {}
/// Grid/block geometry and dynamic shared memory size for a kernel launch.
#[derive(Clone, Copy)]
pub struct LaunchConfig {
    pub grid_dim: (u32, u32, u32),
    pub block_dim: (u32, u32, u32),
    pub shared_mem_bytes: u32,
}
impl LaunchConfig {
    /// A 1-D launch covering `n` elements with 1024-thread blocks: enough
    /// blocks so every element gets a thread (`n == 0` yields zero blocks).
    pub fn for_num_elems(n: u32) -> Self {
        const NUM_THREADS: u32 = 1024;
        // Ceiling division written without `n + NUM_THREADS - 1`, which
        // overflows u32 for n > u32::MAX - 1023 (debug panic / release wrap
        // to a zero-block launch).
        let num_blocks = n / NUM_THREADS + u32::from(n % NUM_THREADS != 0);
        Self {
            grid_dim: (num_blocks, 1, 1),
            block_dim: (NUM_THREADS, 1, 1),
            shared_mem_bytes: 0,
        }
    }
}
/// Types that can be handed to `result::launch_kernel` as a kernel argument.
///
/// # Safety
/// The returned pointer must reference data whose layout matches the kernel
/// parameter it is bound to; the default impl passes a pointer to `self`.
pub unsafe trait AsKernelParam {
    #[inline(always)]
    fn as_kernel_param(&self) -> *mut std::ffi::c_void {
        self as *const Self as *mut _
    }
}
// Primitive scalars are passed by value via the default impl.
unsafe impl AsKernelParam for i8 {}
unsafe impl AsKernelParam for i16 {}
unsafe impl AsKernelParam for i32 {}
unsafe impl AsKernelParam for i64 {}
unsafe impl AsKernelParam for isize {}
unsafe impl AsKernelParam for u8 {}
unsafe impl AsKernelParam for u16 {}
unsafe impl AsKernelParam for u32 {}
unsafe impl AsKernelParam for u64 {}
unsafe impl AsKernelParam for usize {}
unsafe impl AsKernelParam for f32 {}
unsafe impl AsKernelParam for f64 {}
// Device buffers/views pass a pointer to the *device pointer* field — the
// kernel receives the CUdeviceptr, not the Rust struct.
unsafe impl<T> AsKernelParam for &mut CudaSlice<T> {
    #[inline(always)]
    fn as_kernel_param(&self) -> *mut std::ffi::c_void {
        (&self.cu_device_ptr) as *const sys::CUdeviceptr as *mut std::ffi::c_void
    }
}
unsafe impl<T> AsKernelParam for &CudaSlice<T> {
    #[inline(always)]
    fn as_kernel_param(&self) -> *mut std::ffi::c_void {
        (&self.cu_device_ptr) as *const sys::CUdeviceptr as *mut std::ffi::c_void
    }
}
unsafe impl<'a, T> AsKernelParam for &CudaView<'a, T> {
    #[inline(always)]
    fn as_kernel_param(&self) -> *mut std::ffi::c_void {
        (&self.ptr) as *const sys::CUdeviceptr as *mut std::ffi::c_void
    }
}
unsafe impl<'a, T> AsKernelParam for &mut CudaViewMut<'a, T> {
    #[inline(always)]
    fn as_kernel_param(&self) -> *mut std::ffi::c_void {
        (&self.ptr) as *const sys::CUdeviceptr as *mut std::ffi::c_void
    }
}
/// Launching a kernel with a tuple of arguments; implementations for tuple
/// arities 1..=12 are generated by `impl_launch!` below.
pub unsafe trait LaunchAsync<Params> {
    /// Queues the kernel on the device stream and returns without waiting.
    ///
    /// # Safety
    /// The number, order, and layout of `params` must match the kernel's
    /// actual signature — the driver cannot check this.
    unsafe fn launch_async(self, cfg: LaunchConfig, params: Params) -> Result<(), DriverError>;
}
// Generates a `LaunchAsync` impl for one tuple arity of `AsKernelParam`s.
// `$Idx` mirrors `$Vars` one-to-one so each tuple element can be indexed
// (`args.0`, `args.1`, …) when building the kernel parameter array.
macro_rules! impl_launch {
    ([$($Vars:tt),*], [$($Idx:tt),*]) => {
        unsafe impl<$($Vars: AsKernelParam),*> LaunchAsync<($($Vars, )*)> for CudaFunction {
            unsafe fn launch_async(
                self,
                cfg: LaunchConfig,
                args: ($($Vars, )*)
            ) -> Result<(), DriverError> {
                // Raw pointer per argument, in declaration order.
                let params = &mut [$(args.$Idx.as_kernel_param(), )*];
                result::launch_kernel(
                    self.cu_function,
                    cfg.grid_dim,
                    cfg.block_dim,
                    cfg.shared_mem_bytes,
                    self.device.cu_stream,
                    params,
                )
            }
        }
    };
}
// Implementations for argument tuples of arity 1 through 12.
impl_launch!([A], [0]);
impl_launch!([A, B], [0, 1]);
impl_launch!([A, B, C], [0, 1, 2]);
impl_launch!([A, B, C, D], [0, 1, 2, 3]);
impl_launch!([A, B, C, D, E], [0, 1, 2, 3, 4]);
impl_launch!([A, B, C, D, E, F], [0, 1, 2, 3, 4, 5]);
impl_launch!([A, B, C, D, E, F, G], [0, 1, 2, 3, 4, 5, 6]);
impl_launch!([A, B, C, D, E, F, G, H], [0, 1, 2, 3, 4, 5, 6, 7]);
impl_launch!([A, B, C, D, E, F, G, H, I], [0, 1, 2, 3, 4, 5, 6, 7, 8]);
impl_launch!(
    [A, B, C, D, E, F, G, H, I, J],
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
);
impl_launch!(
    [A, B, C, D, E, F, G, H, I, J, K],
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
);
impl_launch!(
    [A, B, C, D, E, F, G, H, I, J, K, L],
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
);
/// Builder that initializes a device (context + stream) and loads any
/// registered PTX modules; see [`CudaDeviceBuilder::build`].
#[derive(Debug)]
pub struct CudaDeviceBuilder {
    // Device index passed to the driver at build time.
    pub(crate) ordinal: usize,
    pub(crate) ptx_files: Vec<PtxFileConfig>,
    pub(crate) ptxs: Vec<PtxConfig>,
}
/// A PTX module to load from a file path at build time.
#[derive(Debug)]
pub(crate) struct PtxFileConfig {
    // Module name the functions are registered under.
    pub(crate) key: &'static str,
    pub(crate) fname: &'static str,
    pub(crate) fn_names: Vec<&'static str>,
}
/// An in-memory PTX module to load at build time.
#[derive(Debug)]
pub(crate) struct PtxConfig {
    // Module name the functions are registered under.
    pub(crate) key: &'static str,
    pub(crate) ptx: Ptx,
    pub(crate) fn_names: Vec<&'static str>,
}
impl CudaDeviceBuilder {
pub fn new(ordinal: usize) -> Self {
Self {
ordinal,
ptx_files: Vec::new(),
ptxs: Vec::new(),
}
}
pub fn with_ptx_from_file(
mut self,
ptx_path: &'static str,
key: &'static str,
fn_names: &[&'static str],
) -> Self {
self.ptx_files.push(PtxFileConfig {
key,
fname: ptx_path,
fn_names: fn_names.to_vec(),
});
self
}
pub fn with_ptx(mut self, ptx: Ptx, key: &'static str, fn_names: &[&'static str]) -> Self {
self.ptxs.push(PtxConfig {
key,
ptx,
fn_names: fn_names.to_vec(),
});
self
}
pub fn build(mut self) -> Result<Arc<CudaDevice>, BuildError> {
result::init().map_err(BuildError::InitError)?;
let cu_device =
result::device::get(self.ordinal as i32).map_err(BuildError::DeviceError)?;
let cu_primary_ctx =
unsafe { result::primary_ctx::retain(cu_device) }.map_err(BuildError::ContextError)?;
unsafe { result::ctx::set_current(cu_primary_ctx) }.map_err(BuildError::ContextError)?;
let cu_stream = result::stream::create(result::stream::StreamKind::NonBlocking)
.map_err(BuildError::StreamError)?;
let mut modules = BTreeMap::new();
for cu in self.ptx_files.drain(..) {
modules.insert(
cu.key,
Self::build_module_from_ptx_file(cu.fname, cu.key, &cu.fn_names)?,
);
}
for ptx in self.ptxs.drain(..) {
modules.insert(
ptx.key,
Self::build_module_from_ptx(ptx.ptx, ptx.key, &ptx.fn_names)?,
);
}
let device = CudaDevice {
cu_device,
cu_primary_ctx,
cu_stream,
modules: RwLock::new(modules),
};
Ok(Arc::new(device))
}
fn build_module_from_ptx_file(
ptx_path: &'static str,
key: &'static str,
func_names: &[&'static str],
) -> Result<CudaModule, BuildError> {
let name_c = CString::new(ptx_path).map_err(BuildError::CStringError)?;
let cu_module = result::module::load(name_c)
.map_err(|cuda| BuildError::PtxLoadingError { key, cuda })?;
let mut functions = BTreeMap::new();
for &fn_name in func_names.iter() {
let fn_name_c = CString::new(fn_name).map_err(BuildError::CStringError)?;
let cu_function = unsafe { result::module::get_function(cu_module, fn_name_c) }
.map_err(|e| BuildError::GetFunctionError {
key,
symbol: fn_name,
cuda: e,
})?;
functions.insert(fn_name, cu_function);
}
Ok(CudaModule {
cu_module,
functions,
})
}
fn build_module_from_ptx(
ptx: Ptx,
key: &'static str,
fn_names: &[&'static str],
) -> Result<CudaModule, BuildError> {
let cu_module = match ptx {
Ptx::Image(image) => unsafe { result::module::load_data(image.as_ptr() as *const _) },
Ptx::Src(src) => {
let c_src = CString::new(src).unwrap();
unsafe { result::module::load_data(c_src.as_ptr() as *const _) }
}
}
.map_err(|cuda| BuildError::NvrtcLoadingError { key, cuda })?;
let mut functions = BTreeMap::new();
for &fn_name in fn_names.iter() {
let fn_name_c = CString::new(fn_name).map_err(BuildError::CStringError)?;
let cu_function = unsafe { result::module::get_function(cu_module, fn_name_c) }
.map_err(|e| BuildError::GetFunctionError {
key,
symbol: fn_name,
cuda: e,
})?;
functions.insert(fn_name, cu_function);
}
Ok(CudaModule {
cu_module,
functions,
})
}
}
/// Errors from building a [`CudaDevice`] or loading PTX modules; each variant
/// identifies the failing stage and carries the underlying driver error.
#[derive(Debug)]
pub enum BuildError {
    InitError(DriverError),
    DeviceError(DriverError),
    ContextError(DriverError),
    StreamError(DriverError),
    // `key` is the module name being loaded when the failure occurred.
    PtxLoadingError {
        key: &'static str,
        cuda: DriverError,
    },
    NvrtcLoadingError {
        key: &'static str,
        cuda: DriverError,
    },
    // `symbol` is the kernel name that could not be found in module `key`.
    GetFunctionError {
        key: &'static str,
        symbol: &'static str,
        cuda: DriverError,
    },
    /// A path or function name contained an interior NUL byte.
    CStringError(NulError),
}
#[cfg(feature = "std")]
impl std::fmt::Display for BuildError {
    // Reuses Debug formatting; there is no separate user-facing text.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{self:?}")
    }
}
#[cfg(feature = "std")]
impl std::error::Error for BuildError {}
/// Marker for types where the all-zero bit pattern is a valid value, making
/// them safe targets for `memset`-style zero initialization
/// (see `CudaDevice::alloc_zeros_async`).
///
/// # Safety
/// Implementors guarantee that a value whose bytes are all zero is valid.
pub unsafe trait ValidAsZeroBits {}
unsafe impl ValidAsZeroBits for i8 {}
unsafe impl ValidAsZeroBits for i16 {}
unsafe impl ValidAsZeroBits for i32 {}
unsafe impl ValidAsZeroBits for i64 {}
unsafe impl ValidAsZeroBits for isize {}
unsafe impl ValidAsZeroBits for u8 {}
unsafe impl ValidAsZeroBits for u16 {}
unsafe impl ValidAsZeroBits for u32 {}
unsafe impl ValidAsZeroBits for u64 {}
unsafe impl ValidAsZeroBits for usize {}
unsafe impl ValidAsZeroBits for f32 {}
unsafe impl ValidAsZeroBits for f64 {}
// Arrays of zeroable elements are themselves zeroable.
unsafe impl<T: ValidAsZeroBits, const M: usize> ValidAsZeroBits for [T; M] {}
#[cfg(test)]
mod tests {
    //! Integration tests. NOTE(review): every test here requires a
    //! CUDA-capable device at ordinal 0 — presumably run on dedicated
    //! hardware; confirm CI environment.
    use crate::nvrtc::compile_ptx_with_opts;
    use super::*;

    // Building a device should not hand out extra Arc clones.
    #[test]
    fn test_post_build_arc_count() {
        let device = CudaDeviceBuilder::new(0).build().unwrap();
        assert_eq!(Arc::strong_count(&device), 1);
    }

    // Each allocated slice holds exactly one Arc clone of its device.
    #[test]
    fn test_post_alloc_arc_counts() {
        let device = CudaDeviceBuilder::new(0).build().unwrap();
        let t = device.alloc_zeros_async::<f32>(1).unwrap();
        assert!(t.host_buf.is_none());
        assert_eq!(Arc::strong_count(&device), 2);
    }

    // `take_async` pins the host Vec; dropping the slice releases its device ref.
    #[test]
    fn test_post_take_arc_counts() {
        let device = CudaDeviceBuilder::new(0).build().unwrap();
        let t = device.take_async([0.0f32; 5].to_vec()).unwrap();
        assert!(t.host_buf.is_some());
        assert_eq!(Arc::strong_count(&device), 2);
        drop(t);
        assert_eq!(Arc::strong_count(&device), 1);
    }

    // Cloning a CudaSlice allocates a new device buffer, adding a device ref.
    #[test]
    fn test_post_clone_counts() {
        let device = CudaDeviceBuilder::new(0).build().unwrap();
        let t = device.take_async([0.0f64; 10].to_vec()).unwrap();
        let r = t.clone();
        assert_eq!(Arc::strong_count(&device), 3);
        drop(t);
        assert_eq!(Arc::strong_count(&device), 2);
        drop(r);
        assert_eq!(Arc::strong_count(&device), 1);
    }

    // Cloning an Arc<CudaSlice> must NOT allocate or add device refs.
    #[test]
    fn test_post_clone_arc_slice_counts() {
        let device = CudaDeviceBuilder::new(0).build().unwrap();
        let t = Arc::new(device.take_async([0.0f64; 10].to_vec()).unwrap());
        let r = t.clone();
        assert_eq!(Arc::strong_count(&device), 2);
        drop(t);
        assert_eq!(Arc::strong_count(&device), 2);
        drop(r);
        assert_eq!(Arc::strong_count(&device), 1);
    }

    // `sync_release` returns the host data and drops the slice's device ref.
    #[test]
    fn test_post_release_counts() {
        let device = CudaDeviceBuilder::new(0).build().unwrap();
        let t = device.take_async([1.0f32, 2.0, 3.0].to_vec()).unwrap();
        #[allow(clippy::redundant_clone)]
        let r = t.clone();
        assert_eq!(Arc::strong_count(&device), 3);
        let r_host = device.sync_release(r).unwrap();
        assert_eq!(&r_host, &[1.0, 2.0, 3.0]);
        assert_eq!(Arc::strong_count(&device), 2);
        drop(r_host);
        assert_eq!(Arc::strong_count(&device), 2);
    }

    // Free-memory accounting across alloc/drop; interferes with other tests,
    // hence the ignore.
    #[test]
    #[ignore = "must be executed by itself"]
    fn test_post_alloc_memory() {
        let device = CudaDeviceBuilder::new(0).build().unwrap();
        let (free1, total1) = result::mem_get_info().unwrap();
        let t = device.take_async([0.0f32; 5].to_vec()).unwrap();
        let (free2, total2) = result::mem_get_info().unwrap();
        assert_eq!(total1, total2);
        assert!(free2 < free1);
        drop(t);
        device.synchronize().unwrap();
        let (free3, total3) = result::mem_get_info().unwrap();
        assert_eq!(total2, total3);
        assert!(free3 > free2);
        assert_eq!(free3, free1);
    }

    // Converting a slice to a kernel param must not touch the device refcount.
    #[test]
    fn test_mut_into_kernel_param_no_inc_rc() {
        let device = CudaDeviceBuilder::new(0).build().unwrap();
        let t = device.take_async([0.0f32; 1].to_vec()).unwrap();
        let _r = t.clone();
        assert_eq!(Arc::strong_count(&device), 3);
        let _ = (&t).as_kernel_param();
        assert_eq!(Arc::strong_count(&device), 3);
    }

    // NOTE(review): this body is byte-identical to the test above even though
    // the name claims the refcount should increase — looks like a copy/paste
    // leftover; confirm the intended distinction between `&` and `&mut` params.
    #[test]
    fn test_ref_into_kernel_param_inc_rc() {
        let device = CudaDeviceBuilder::new(0).build().unwrap();
        let t = device.take_async([0.0f32; 1].to_vec()).unwrap();
        let _r = t.clone();
        assert_eq!(Arc::strong_count(&device), 3);
        let _ = (&t).as_kernel_param();
        assert_eq!(Arc::strong_count(&device), 3);
    }

    const SIN_CU: &str = "
extern \"C\" __global__ void sin_kernel(float *out, const float *inp, size_t numel) {
size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < numel) {
out[i] = sin(inp[i]);
}
}";

    // End-to-end: compile PTX, launch, read back, compare against host sin().
    #[test]
    fn test_launch_with_mut_and_ref_cudarc() {
        let ptx = compile_ptx_with_opts(SIN_CU, Default::default()).unwrap();
        let dev = CudaDeviceBuilder::new(0)
            .with_ptx(ptx, "sin", &["sin_kernel"])
            .build()
            .unwrap();
        let sin_kernel = dev.get_func("sin", "sin_kernel").unwrap();
        let a_host = [-1.0f32, -0.8, -0.6, -0.4, -0.2, 0.0, 0.2, 0.4, 0.6, 0.8];
        let a_dev = dev.take_async(a_host.clone().to_vec()).unwrap();
        let mut b_dev = a_dev.clone();
        unsafe {
            sin_kernel.launch_async(
                LaunchConfig::for_num_elems(10),
                (&mut b_dev, &a_dev, 10usize),
            )
        }
        .unwrap();
        let b_host = dev.sync_release(b_dev).unwrap();
        for (a_i, b_i) in a_host.iter().zip(b_host.iter()) {
            let expected = a_i.sin();
            assert!((b_i - expected).abs() <= 1e-6);
        }
        drop(a_dev);
    }

    // Exercises multi-block grids across a range of element counts.
    #[test]
    fn test_large_launches() {
        let ptx = compile_ptx_with_opts(SIN_CU, Default::default()).unwrap();
        let dev = CudaDeviceBuilder::new(0)
            .with_ptx(ptx, "sin", &["sin_kernel"])
            .build()
            .unwrap();
        for numel in [256, 512, 1024, 1280, 1536, 2048] {
            let mut a = Vec::with_capacity(numel);
            a.resize(numel, 1.0f32);
            let a = dev.take_async(a).unwrap();
            let mut b = dev.alloc_zeros_async::<f32>(numel).unwrap();
            let sin_kernel = dev.get_func("sin", "sin_kernel").unwrap();
            let cfg = LaunchConfig::for_num_elems(numel as u32);
            unsafe { sin_kernel.launch_async(cfg, (&mut b, &a, numel)) }.unwrap();
            let b = dev.sync_release(b).unwrap();
            for v in b {
                // sin(1.0) truncated to f32 precision.
                assert_eq!(v, 0.841471);
            }
        }
    }

    // Launches over 2-element sub-views of a 10-element buffer.
    #[test]
    fn test_launch_with_views() {
        let ptx = compile_ptx_with_opts(SIN_CU, Default::default()).unwrap();
        let dev = CudaDeviceBuilder::new(0)
            .with_ptx(ptx, "sin", &["sin_kernel"])
            .build()
            .unwrap();
        let a_host = [-1.0f32, -0.8, -0.6, -0.4, -0.2, 0.0, 0.2, 0.4, 0.6, 0.8];
        let a_dev = dev.take_async(a_host.clone().to_vec()).unwrap();
        let mut b_dev = a_dev.clone();
        for i in 0..5 {
            let a_sub = a_dev.try_slice(i * 2..).unwrap();
            let mut b_sub = b_dev.try_slice_mut(i * 2..).unwrap();
            let f = dev.get_func("sin", "sin_kernel").unwrap();
            unsafe { f.launch_async(LaunchConfig::for_num_elems(2), (&mut b_sub, &a_sub, 2usize)) }
                .unwrap();
        }
        let b_host = dev.sync_release(b_dev).unwrap();
        for (a_i, b_i) in a_host.iter().zip(b_host.iter()) {
            let expected = a_i.sin();
            assert!((b_i - expected).abs() <= 1e-6);
        }
        drop(a_dev);
    }

    const TEST_KERNELS: &str = "
extern \"C\" __global__ void int_8bit(signed char s_min, char s_max, unsigned char u_min, unsigned char u_max) {
assert(s_min == -128);
assert(s_max == 127);
assert(u_min == 0);
assert(u_max == 255);
}
extern \"C\" __global__ void int_16bit(signed short s_min, short s_max, unsigned short u_min, unsigned short u_max) {
assert(s_min == -32768);
assert(s_max == 32767);
assert(u_min == 0);
assert(u_max == 65535);
}
extern \"C\" __global__ void int_32bit(signed int s_min, int s_max, unsigned int u_min, unsigned int u_max) {
assert(s_min == -2147483648);
assert(s_max == 2147483647);
assert(u_min == 0);
assert(u_max == 4294967295);
}
extern \"C\" __global__ void int_64bit(signed long s_min, long s_max, unsigned long u_min, unsigned long u_max) {
assert(s_min == -9223372036854775808);
assert(s_max == 9223372036854775807);
assert(u_min == 0);
assert(u_max == 18446744073709551615);
}
extern \"C\" __global__ void floating(float f, double d) {
assert(fabs(f - 1.2345678) <= 1e-7);
assert(fabs(d - -10.123456789876543) <= 1e-16);
}
";

    // The next five tests verify scalar kernel arguments arrive bit-exact on
    // device; the device-side asserts fire if marshalling is wrong.
    #[test]
    fn test_launch_with_8bit() {
        let ptx = compile_ptx_with_opts(TEST_KERNELS, Default::default()).unwrap();
        let dev = CudaDeviceBuilder::new(0)
            .with_ptx(ptx, "tests", &["int_8bit"])
            .build()
            .unwrap();
        let f = dev.get_func("tests", "int_8bit").unwrap();
        unsafe {
            f.launch_async(
                LaunchConfig::for_num_elems(1),
                (i8::MIN, i8::MAX, u8::MIN, u8::MAX),
            )
        }
        .unwrap();
        dev.synchronize().unwrap();
    }

    #[test]
    fn test_launch_with_16bit() {
        let ptx = compile_ptx_with_opts(TEST_KERNELS, Default::default()).unwrap();
        let dev = CudaDeviceBuilder::new(0)
            .with_ptx(ptx, "tests", &["int_16bit"])
            .build()
            .unwrap();
        let f = dev.get_func("tests", "int_16bit").unwrap();
        unsafe {
            f.launch_async(
                LaunchConfig::for_num_elems(1),
                (i16::MIN, i16::MAX, u16::MIN, u16::MAX),
            )
        }
        .unwrap();
        dev.synchronize().unwrap();
    }

    #[test]
    fn test_launch_with_32bit() {
        let ptx = compile_ptx_with_opts(TEST_KERNELS, Default::default()).unwrap();
        let dev = CudaDeviceBuilder::new(0)
            .with_ptx(ptx, "tests", &["int_32bit"])
            .build()
            .unwrap();
        let f = dev.get_func("tests", "int_32bit").unwrap();
        unsafe {
            f.launch_async(
                LaunchConfig::for_num_elems(1),
                (i32::MIN, i32::MAX, u32::MIN, u32::MAX),
            )
        }
        .unwrap();
        dev.synchronize().unwrap();
    }

    #[test]
    fn test_launch_with_64bit() {
        let ptx = compile_ptx_with_opts(TEST_KERNELS, Default::default()).unwrap();
        let dev = CudaDeviceBuilder::new(0)
            .with_ptx(ptx, "tests", &["int_64bit"])
            .build()
            .unwrap();
        let f = dev.get_func("tests", "int_64bit").unwrap();
        unsafe {
            f.launch_async(
                LaunchConfig::for_num_elems(1),
                (i64::MIN, i64::MAX, u64::MIN, u64::MAX),
            )
        }
        .unwrap();
        dev.synchronize().unwrap();
    }

    #[test]
    fn test_launch_with_floats() {
        let ptx = compile_ptx_with_opts(TEST_KERNELS, Default::default()).unwrap();
        let dev = CudaDeviceBuilder::new(0)
            .with_ptx(ptx, "tests", &["floating"])
            .build()
            .unwrap();
        let f = dev.get_func("tests", "floating").unwrap();
        unsafe {
            f.launch_async(
                LaunchConfig::for_num_elems(1),
                (1.2345678f32, -10.123456789876543f64),
            )
        }
        .unwrap();
        dev.synchronize().unwrap();
    }
}