use crate::driver::{result, sys};
use super::core::{CudaDevice, CudaSlice, CudaView, CudaViewMut};
use super::device_ptr::{DevicePtr, DevicePtrMut, DeviceSlice};
use std::{marker::Unpin, pin::Pin, sync::Arc, vec::Vec};
/// Values that can be handed to a CUDA kernel as a launch argument.
///
/// # Safety
/// Implementors promise that the memory addressed by [`DeviceRepr::as_kernel_param`] holds a
/// representation the device code can read as the corresponding kernel parameter. The default
/// implementation points at `self`, which is only correct when the type's host layout is the
/// layout the device code expects.
pub unsafe trait DeviceRepr {
    /// Returns an untyped pointer to the bytes that represent this value as a kernel argument.
    ///
    /// By default this is simply the address of `self`. Implementations for device-memory
    /// handles override it to point at the raw device pointer instead.
    #[inline(always)]
    fn as_kernel_param(&self) -> *mut std::ffi::c_void {
        let this: *const Self = self;
        this as *mut std::ffi::c_void
    }
}
unsafe impl DeviceRepr for bool {}
unsafe impl DeviceRepr for i8 {}
unsafe impl DeviceRepr for i16 {}
unsafe impl DeviceRepr for i32 {}
unsafe impl DeviceRepr for i64 {}
unsafe impl DeviceRepr for i128 {}
unsafe impl DeviceRepr for isize {}
unsafe impl DeviceRepr for u8 {}
unsafe impl DeviceRepr for u16 {}
unsafe impl DeviceRepr for u32 {}
unsafe impl DeviceRepr for u64 {}
unsafe impl DeviceRepr for u128 {}
unsafe impl DeviceRepr for usize {}
unsafe impl DeviceRepr for f32 {}
unsafe impl DeviceRepr for f64 {}
#[cfg(feature = "f16")]
unsafe impl DeviceRepr for half::f16 {}
#[cfg(feature = "f16")]
unsafe impl DeviceRepr for half::bf16 {}
unsafe impl<T: DeviceRepr> DeviceRepr for &mut CudaSlice<T> {
    /// Passes the slice to a kernel as its raw device pointer.
    ///
    /// The returned pointer addresses the slice's `cu_device_ptr` field, not the slice struct.
    #[inline(always)]
    fn as_kernel_param(&self) -> *mut std::ffi::c_void {
        let field: *const sys::CUdeviceptr = &self.cu_device_ptr;
        field as *mut std::ffi::c_void
    }
}
unsafe impl<T: DeviceRepr> DeviceRepr for &CudaSlice<T> {
    /// Passes a shared slice to a kernel as its raw device pointer.
    ///
    /// The returned pointer addresses the slice's `cu_device_ptr` field, not the slice struct.
    #[inline(always)]
    fn as_kernel_param(&self) -> *mut std::ffi::c_void {
        let field: *const sys::CUdeviceptr = &self.cu_device_ptr;
        field as *mut std::ffi::c_void
    }
}
unsafe impl<'a, T: DeviceRepr> DeviceRepr for &CudaView<'a, T> {
    /// Passes a read-only view to a kernel as the device pointer stored in its `ptr` field.
    #[inline(always)]
    fn as_kernel_param(&self) -> *mut std::ffi::c_void {
        let field: *const sys::CUdeviceptr = &self.ptr;
        field as *mut std::ffi::c_void
    }
}
unsafe impl<'a, T: DeviceRepr> DeviceRepr for &mut CudaViewMut<'a, T> {
    /// Passes a mutable view to a kernel as the device pointer stored in its `ptr` field.
    #[inline(always)]
    fn as_kernel_param(&self) -> *mut std::ffi::c_void {
        let field: *const sys::CUdeviceptr = &self.ptr;
        field as *mut std::ffi::c_void
    }
}
impl<T> CudaSlice<T> {
    /// Gives up ownership of the device allocation and returns its raw [`sys::CUdeviceptr`].
    ///
    /// The slice's destructor does not run, so the device memory is not released here.
    /// Whoever receives the pointer is responsible for it, for example by re-wrapping it with
    /// `CudaDevice::upgrade_device_ptr`.
    ///
    /// Any staged host buffer (`host_buf`) is dropped first. All other fields are skipped by
    /// `mem::forget`, including the slice's `device` handle. As a result, the device's `Arc`
    /// strong count is not decremented by this call.
    pub fn leak(mut self) -> sys::CUdeviceptr {
        // Release the host-side buffer explicitly; `mem::forget` below would otherwise leak it.
        self.host_buf = None;
        let raw = self.cu_device_ptr;
        std::mem::forget(self);
        raw
    }
}
impl CudaDevice {
pub unsafe fn upgrade_device_ptr<T>(
self: &Arc<Self>,
cu_device_ptr: sys::CUdeviceptr,
len: usize,
) -> CudaSlice<T> {
CudaSlice {
cu_device_ptr,
len,
device: self.clone(),
host_buf: None,
}
}
}
impl CudaDevice {
pub fn null<T>(self: &Arc<Self>) -> Result<CudaSlice<T>, result::DriverError> {
self.bind_to_thread()?;
let cu_device_ptr = unsafe { result::malloc_async(self.stream, 0) }?;
Ok(CudaSlice {
cu_device_ptr,
len: 0,
device: self.clone(),
host_buf: None,
})
}
pub unsafe fn alloc<T: DeviceRepr>(
self: &Arc<Self>,
len: usize,
) -> Result<CudaSlice<T>, result::DriverError> {
self.bind_to_thread()?;
let cu_device_ptr = result::malloc_async(self.stream, len * std::mem::size_of::<T>())?;
Ok(CudaSlice {
cu_device_ptr,
len,
device: self.clone(),
host_buf: None,
})
}
pub fn alloc_zeros<T: ValidAsZeroBits + DeviceRepr>(
self: &Arc<Self>,
len: usize,
) -> Result<CudaSlice<T>, result::DriverError> {
let mut dst = unsafe { self.alloc(len) }?;
self.memset_zeros(&mut dst)?;
Ok(dst)
}
pub fn memset_zeros<T: ValidAsZeroBits + DeviceRepr, Dst: DevicePtrMut<T>>(
self: &Arc<Self>,
dst: &mut Dst,
) -> Result<(), result::DriverError> {
self.bind_to_thread()?;
unsafe { result::memset_d8_async(*dst.device_ptr_mut(), 0, dst.num_bytes(), self.stream) }
}
pub fn dtod_copy<T: DeviceRepr, Src: DevicePtr<T>, Dst: DevicePtrMut<T>>(
self: &Arc<Self>,
src: &Src,
dst: &mut Dst,
) -> Result<(), result::DriverError> {
assert_eq!(src.len(), dst.len());
self.bind_to_thread()?;
unsafe {
result::memcpy_dtod_async(
*dst.device_ptr_mut(),
*src.device_ptr(),
src.len() * std::mem::size_of::<T>(),
self.stream,
)
}
}
pub fn htod_copy<T: Unpin + DeviceRepr>(
self: &Arc<Self>,
src: Vec<T>,
) -> Result<CudaSlice<T>, result::DriverError> {
let mut dst = unsafe { self.alloc(src.len()) }?;
self.htod_copy_into(src, &mut dst)?;
Ok(dst)
}
pub fn htod_copy_into<T: DeviceRepr + Unpin>(
self: &Arc<Self>,
src: Vec<T>,
dst: &mut CudaSlice<T>,
) -> Result<(), result::DriverError> {
assert_eq!(src.len(), dst.len());
dst.host_buf = Some(Pin::new(src));
self.bind_to_thread()?;
unsafe {
result::memcpy_htod_async(
dst.cu_device_ptr,
dst.host_buf.as_ref().unwrap(),
self.stream,
)
}?;
Ok(())
}
pub fn htod_sync_copy<T: DeviceRepr>(
self: &Arc<Self>,
src: &[T],
) -> Result<CudaSlice<T>, result::DriverError> {
let mut dst = unsafe { self.alloc(src.len()) }?;
self.htod_sync_copy_into(src, &mut dst)?;
Ok(dst)
}
pub fn htod_sync_copy_into<T: DeviceRepr, Dst: DevicePtrMut<T>>(
self: &Arc<Self>,
src: &[T],
dst: &mut Dst,
) -> Result<(), result::DriverError> {
assert_eq!(src.len(), dst.len());
self.bind_to_thread()?;
unsafe { result::memcpy_htod_async(*dst.device_ptr_mut(), src, self.stream) }?;
self.synchronize()
}
#[allow(clippy::uninit_vec)]
pub fn dtoh_sync_copy<T: DeviceRepr>(
self: &Arc<Self>,
src: &CudaSlice<T>,
) -> Result<Vec<T>, result::DriverError> {
let mut dst = Vec::with_capacity(src.len());
unsafe { dst.set_len(src.len()) };
self.dtoh_sync_copy_into(src, &mut dst)?;
Ok(dst)
}
pub fn dtoh_sync_copy_into<T: DeviceRepr, Src: DevicePtr<T>>(
self: &Arc<Self>,
src: &Src,
dst: &mut [T],
) -> Result<(), result::DriverError> {
assert_eq!(src.len(), dst.len());
self.bind_to_thread()?;
unsafe { result::memcpy_dtoh_async(dst, *src.device_ptr(), self.stream) }?;
self.synchronize()
}
pub fn sync_reclaim<T: Clone + Default + DeviceRepr + Unpin>(
self: &Arc<Self>,
mut src: CudaSlice<T>,
) -> Result<Vec<T>, result::DriverError> {
let buf = src.host_buf.take();
let mut buf = buf.unwrap_or_else(|| {
let mut b = Vec::with_capacity(src.len);
b.resize(src.len, Default::default());
Pin::new(b)
});
self.dtoh_sync_copy_into(&src, &mut buf)?;
Ok(Pin::into_inner(buf))
}
pub fn synchronize(self: &Arc<Self>) -> Result<(), result::DriverError> {
self.bind_to_thread()?;
unsafe { result::stream::synchronize(self.stream) }
}
}
/// Types for which an all-zero bit pattern is a valid value.
///
/// # Safety
/// Implementors guarantee that memory consisting entirely of zero bytes is a valid instance of
/// the type. Zero-filling APIs such as `CudaDevice::alloc_zeros` rely on this.
pub unsafe trait ValidAsZeroBits {}

// Implements `ValidAsZeroBits` for each listed type.
macro_rules! impl_valid_as_zero_bits {
    ($($ty:ty),+ $(,)?) => {
        $(
            unsafe impl ValidAsZeroBits for $ty {}
        )+
    };
}

// SAFETY: zero is a valid value for every integer and float type, and `false` is `0u8`.
impl_valid_as_zero_bits!(
    bool, i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64,
);

// Half-precision floats are only available when the `f16` feature pulls in the `half` crate.
#[cfg(feature = "f16")]
impl_valid_as_zero_bits!(half::f16, half::bf16);

// An array is all-zero exactly when each element is, so it inherits the property.
unsafe impl<T: ValidAsZeroBits, const M: usize> ValidAsZeroBits for [T; M] {}
// Implements `ValidAsZeroBits` for every tuple arity from 1 up to the number of identifiers
// given. A tuple is all-zero exactly when each of its fields is.
macro_rules! impl_tuples {
    // Internal rule: emit the impl for exactly this list of type parameters.
    (@impl $($name:ident),+) => {
        unsafe impl<$($name: ValidAsZeroBits,)+> ValidAsZeroBits for ($($name,)+) {}
    };
    // Recursion base: nothing left to implement.
    () => {};
    // Implement the full list, then recurse on everything after the first identifier.
    ($head:ident $(, $tail:ident)*) => {
        impl_tuples!(@impl $head $(, $tail)*);
        impl_tuples!($($tail),*);
    };
}
impl_tuples!(A, B, C, D, E, F, G, H, I, J, K, L);
// Device-level tests. All of them need a CUDA device at runtime, and the two
// `*_correct_context` tests call `CudaDevice::new(1).unwrap()`, so they need a second device.
// Most tests check `Arc::strong_count(&device)`: every live `CudaSlice` holds one clone of the
// device handle, so the counts depend on the exact drop order below.
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly built device is referenced only by the returned `Arc`.
    #[test]
    fn test_post_build_arc_count() {
        let device = CudaDevice::new(0).unwrap();
        assert_eq!(Arc::strong_count(&device), 1);
    }

    // `alloc_zeros` stages no host buffer, and the slice holds one device `Arc` clone.
    #[test]
    fn test_post_alloc_arc_counts() {
        let device = CudaDevice::new(0).unwrap();
        let t = device.alloc_zeros::<f32>(1).unwrap();
        assert!(t.host_buf.is_none());
        assert_eq!(Arc::strong_count(&device), 2);
    }

    // `htod_copy` keeps the source `Vec` as the slice's host buffer.
    // Dropping the slice releases its device `Arc` clone.
    #[test]
    fn test_post_take_arc_counts() {
        let device = CudaDevice::new(0).unwrap();
        let t = device.htod_copy([0.0f32; 5].to_vec()).unwrap();
        assert!(t.host_buf.is_some());
        assert_eq!(Arc::strong_count(&device), 2);
        drop(t);
        assert_eq!(Arc::strong_count(&device), 1);
    }

    // Cloning a `CudaSlice` also clones its device handle (count goes to 3).
    #[test]
    fn test_post_clone_counts() {
        let device = CudaDevice::new(0).unwrap();
        let t = device.htod_copy([0.0f64; 10].to_vec()).unwrap();
        let r = t.clone();
        assert_eq!(Arc::strong_count(&device), 3);
        drop(t);
        assert_eq!(Arc::strong_count(&device), 2);
        drop(r);
        assert_eq!(Arc::strong_count(&device), 1);
    }

    // Cloning an `Arc<CudaSlice>` shares one slice, so the device count stays at 2 until the
    // last `Arc` to the slice is dropped.
    #[test]
    fn test_post_clone_arc_slice_counts() {
        let device = CudaDevice::new(0).unwrap();
        let t = Arc::new(device.htod_copy::<f64>([0.0; 10].to_vec()).unwrap());
        let r = t.clone();
        assert_eq!(Arc::strong_count(&device), 2);
        drop(t);
        assert_eq!(Arc::strong_count(&device), 2);
        drop(r);
        assert_eq!(Arc::strong_count(&device), 1);
    }

    // `sync_reclaim` consumes the cloned slice (releasing its device handle) and returns plain
    // host data. `t` is still alive, hence the remaining count of 2.
    #[test]
    fn test_post_release_counts() {
        let device = CudaDevice::new(0).unwrap();
        let t = device.htod_copy([1.0f32, 2.0, 3.0].to_vec()).unwrap();
        #[allow(clippy::redundant_clone)]
        let r = t.clone();
        assert_eq!(Arc::strong_count(&device), 3);
        let r_host = device.sync_reclaim(r).unwrap();
        assert_eq!(&r_host, &[1.0, 2.0, 3.0]);
        assert_eq!(Arc::strong_count(&device), 2);
        drop(r_host);
        assert_eq!(Arc::strong_count(&device), 2);
    }

    // Checks that dropping a slice returns its memory, as reported by `mem_get_info`.
    // Ignored by default. Presumably other tests running concurrently would change the free
    // memory being measured.
    #[test]
    #[ignore = "must be executed by itself"]
    fn test_post_alloc_memory() {
        let device = CudaDevice::new(0).unwrap();
        let (free1, total1) = result::mem_get_info().unwrap();
        let t = device.htod_copy([0.0f32; 5].to_vec()).unwrap();
        let (free2, total2) = result::mem_get_info().unwrap();
        assert_eq!(total1, total2);
        assert!(free2 < free1);
        drop(t);
        // The free is issued on the stream; wait for it before measuring again.
        device.synchronize().unwrap();
        let (free3, total3) = result::mem_get_info().unwrap();
        assert_eq!(total2, total3);
        assert!(free3 > free2);
        assert_eq!(free3, free1);
    }

    // Device-to-device copies into consecutive mutable sub-views of one larger buffer.
    #[test]
    fn test_device_copy_to_views() {
        let dev = CudaDevice::new(0).unwrap();
        let smalls = [
            dev.htod_copy(std::vec![-1.0f32, -0.8]).unwrap(),
            dev.htod_copy(std::vec![-0.6, -0.4]).unwrap(),
            dev.htod_copy(std::vec![-0.2, 0.0]).unwrap(),
            dev.htod_copy(std::vec![0.2, 0.4]).unwrap(),
            dev.htod_copy(std::vec![0.6, 0.8]).unwrap(),
        ];
        let mut big = dev.alloc_zeros::<f32>(10).unwrap();
        let mut offset = 0;
        for small in smalls.iter() {
            let mut sub = big.try_slice_mut(offset..offset + small.len()).unwrap();
            dev.dtod_copy(small, &mut sub).unwrap();
            offset += small.len();
        }
        assert_eq!(
            dev.sync_reclaim(big).unwrap(),
            [-1.0, -0.8, -0.6, -0.4, -0.2, 0.0, 0.2, 0.4, 0.6, 0.8]
        );
    }

    // Round-trips one allocation through `leak`/`upgrade_device_ptr` twice.
    // Upgrading with a shorter length exposes a prefix of the data. Only `c` is dropped
    // normally, so the allocation is released once.
    #[test]
    fn test_leak_and_upgrade() {
        let dev = CudaDevice::new(0).unwrap();
        let a = dev
            .htod_copy(std::vec![1.0f32, 2.0, 3.0, 4.0, 5.0])
            .unwrap();
        let ptr = a.leak();
        let b = unsafe { dev.upgrade_device_ptr::<f32>(ptr, 3) };
        assert_eq!(dev.dtoh_sync_copy(&b).unwrap(), &[1.0, 2.0, 3.0]);
        let ptr = b.leak();
        let c = unsafe { dev.upgrade_device_ptr::<f32>(ptr, 5) };
        assert_eq!(dev.dtoh_sync_copy(&c).unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    // Drops the slice after both device handles have been dropped (requires a second device).
    // Per the test name, this checks the slice is still freed in device 0's context.
    #[test]
    fn test_slice_is_freed_with_correct_context() {
        let dev0 = CudaDevice::new(0).unwrap();
        let slice = dev0.htod_copy(vec![1.0; 10]).unwrap();
        let dev1 = CudaDevice::new(1).unwrap();
        drop(dev1);
        drop(dev0);
        drop(slice);
    }

    // Creates a second device after `dev0`, then uses `dev0` for copies (requires a second
    // device). Per the test name, this checks that `dev0`'s operations use `dev0`'s context.
    #[test]
    fn test_copy_uses_correct_context() {
        let dev0 = CudaDevice::new(0).unwrap();
        let _dev1 = CudaDevice::new(1).unwrap();
        let slice = dev0.htod_copy(vec![1.0; 10]).unwrap();
        let _out = dev0.dtoh_sync_copy(&slice).unwrap();
    }
}