use core::marker::PhantomData;
use std::ops::RangeBounds;
use std::sync::Arc;
use crate::driver::{result, sys};
use super::{
core::to_range, CudaContext, CudaEvent, CudaStream, DevicePtr, DevicePtrMut, DeviceRepr,
DeviceSlice, DriverError, HostSlice, LaunchArgs, PushKernelArg, ValidAsZeroBits,
};
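/// A length-tracked allocation of CUDA unified (managed) memory, created by
/// [CudaContext::alloc_unified()]. The same pointer is addressable from both
/// host and device; access is coordinated through the owning stream, the
/// slice's [CudaEvent], and its [sys::CUmemAttach_flags].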
#[derive(Debug)]
pub struct UnifiedSlice<T> {
pub(crate) cu_device_ptr: sys::CUdeviceptr,
pub(crate) len: usize,
pub(crate) stream: Arc<CudaStream>,
pub(crate) event: CudaEvent,
pub(crate) attach_mode: sys::CUmemAttach_flags,
pub(crate) concurrent_managed_access: bool,
pub(crate) marker: PhantomData<*const T>,
}
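// SAFETY: the raw pointer is only dereferenced after synchronizing the
// slice's event (host side) or ordering against it (device side), so the
// handle can be moved and shared across threads.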
unsafe impl<T> Send for UnifiedSlice<T> {}
unsafe impl<T> Sync for UnifiedSlice<T> {}
impl<T> Drop for UnifiedSlice<T> {
    fn drop(&mut self) {
        // Wait for all work recorded on this slice's event before handing the
        // managed allocation back to the driver.
        self.stream.ctx.record_err(self.event.synchronize());
        self.stream
            .ctx
            .record_err(unsafe { result::memory_free(self.cu_device_ptr) });
    }
}
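/// An immutable, non-owning view into a sub-range of a [UnifiedSlice],
/// created by [UnifiedSlice::as_view()] or [UnifiedSlice::slice()].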
#[derive(Debug, Copy, Clone)]
pub struct UnifiedView<'a, T> {
pub(crate) ptr: sys::CUdeviceptr,
pub(crate) len: usize,
pub(crate) event: &'a CudaEvent,
pub(crate) stream: &'a Arc<CudaStream>,
pub(crate) attach_mode: sys::CUmemAttach_flags,
pub(crate) concurrent_managed_access: bool,
marker: PhantomData<&'a [T]>,
}
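/// A mutable, non-owning view into a sub-range of a [UnifiedSlice],
/// created by [UnifiedSlice::as_view_mut()] or [UnifiedSlice::slice_mut()].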
#[derive(Debug)]
pub struct UnifiedViewMut<'a, T> {
pub(crate) ptr: sys::CUdeviceptr,
pub(crate) len: usize,
pub(crate) event: &'a CudaEvent,
pub(crate) stream: &'a Arc<CudaStream>,
pub(crate) attach_mode: sys::CUmemAttach_flags,
pub(crate) concurrent_managed_access: bool,
marker: PhantomData<&'a mut [T]>,
}
impl CudaContext {
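    /// Allocates unified (managed) memory for `len` elements of `T` with
    /// `cuMemAllocManaged`, failing with `CUDA_ERROR_NOT_PERMITTED` on devices
    /// that lack managed-memory support. `attach_global` chooses
    /// [sys::CUmemAttach_flags::CU_MEM_ATTACH_GLOBAL] (accessible from any
    /// stream) over [sys::CUmemAttach_flags::CU_MEM_ATTACH_HOST].
    ///
    /// A minimal usage sketch:
    /// ```no_run
    /// # use cudarc::driver::*;
    /// let ctx = CudaContext::new(0)?;
    /// let mut a = unsafe { ctx.alloc_unified::<f32>(100, true) }?;
    /// a.as_mut_slice()?.fill(1.0);
    /// # Ok::<(), DriverError>(())
    /// ```
    ///
    /// # Safety
    /// The returned memory is uninitialized; it must be written before it is
    /// read from.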
pub unsafe fn alloc_unified<T: DeviceRepr>(
self: &Arc<Self>,
len: usize,
attach_global: bool,
) -> Result<UnifiedSlice<T>, DriverError> {
        // Unified memory requires device support for managed memory.
        if self.attribute(sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY)? == 0 {
            return Err(DriverError(sys::cudaError_enum::CUDA_ERROR_NOT_PERMITTED));
        }
let attach_mode = if attach_global {
sys::CUmemAttach_flags::CU_MEM_ATTACH_GLOBAL
} else {
sys::CUmemAttach_flags::CU_MEM_ATTACH_HOST
};
let cu_device_ptr = result::malloc_managed(len * std::mem::size_of::<T>(), attach_mode)?;
let concurrent_managed_access = self
.attribute(sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS)?
!= 0;
let stream = self.default_stream();
let event = self.new_event(Some(sys::CUevent_flags::CU_EVENT_BLOCKING_SYNC))?;
Ok(UnifiedSlice {
cu_device_ptr,
len,
stream,
event,
attach_mode,
concurrent_managed_access,
marker: PhantomData,
})
}
}
impl<T> UnifiedSlice<T> {
pub fn len(&self) -> usize {
self.len
}
pub fn is_empty(&self) -> bool {
self.len == 0
}
pub fn attach_mode(&self) -> sys::CUmemAttach_flags {
self.attach_mode
}
pub fn num_bytes(&self) -> usize {
self.len * std::mem::size_of::<T>()
}
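    /// Re-attaches this allocation to `stream` with the given flags via
    /// `cuStreamAttachMemAsync`, synchronizing the slice's event first. With
    /// [sys::CUmemAttach_flags::CU_MEM_ATTACH_SINGLE], device access becomes
    /// restricted to `stream`.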
pub fn attach(
&mut self,
stream: &Arc<CudaStream>,
flags: sys::CUmemAttach_flags,
) -> Result<(), DriverError> {
self.event.synchronize()?;
self.stream = stream.clone();
self.attach_mode = flags;
unsafe {
result::stream::attach_mem_async(
self.stream.cu_stream,
self.cu_device_ptr,
self.num_bytes(),
self.attach_mode,
)
}
}
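    /// Asynchronously prefetches the allocation via `cuMemPrefetchAsync`: to
    /// the device owning this slice's stream for GLOBAL/SINGLE attach modes
    /// (which requires concurrent managed access), or to the current host NUMA
    /// node for HOST attach mode. Only available on CUDA 12.2 and newer.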
#[cfg(not(any(
feature = "cuda-11040",
feature = "cuda-11050",
feature = "cuda-11060",
feature = "cuda-11070",
feature = "cuda-11080",
feature = "cuda-12000",
feature = "cuda-12010"
)))]
pub fn prefetch(&self) -> Result<(), DriverError> {
let location = match self.attach_mode {
sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_GLOBAL
| sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_SINGLE => {
if !self.concurrent_managed_access {
return Err(DriverError(sys::cudaError_enum::CUDA_ERROR_NOT_PERMITTED));
}
sys::CUmemLocation {
type_: sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
#[cfg(not(feature = "cuda-13020"))]
id: self.stream.ctx.ordinal as i32,
#[cfg(feature = "cuda-13020")]
__bindgen_anon_1: sys::CUmemLocation_st__bindgen_ty_1 {
id: self.stream.ctx.ordinal as i32,
},
}
}
sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_HOST => {
sys::CUmemLocation {
type_: sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_HOST_NUMA_CURRENT,
#[cfg(not(feature = "cuda-13020"))]
                    id: 0,
                    #[cfg(feature = "cuda-13020")]
__bindgen_anon_1: sys::CUmemLocation_st__bindgen_ty_1 { id: 0 },
}
}
};
unsafe {
result::mem_prefetch_async(
self.cu_device_ptr,
                self.num_bytes(),
location,
self.stream.cu_stream,
)
}
}
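    /// Checks whether the host may access this memory right now, synchronizing
    /// the owning stream when the attach mode requires it.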
pub fn check_host_access(&self) -> Result<(), DriverError> {
        match self.attach_mode {
            // GLOBAL and HOST attached memory is host accessible without
            // stream-level synchronization; callers additionally synchronize
            // this slice's event before dereferencing.
            sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_GLOBAL => {}
            sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_HOST => {}
            // SINGLE attached memory belongs to one stream; the host may only
            // touch it once that stream has drained its queued work.
            sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_SINGLE => {
                self.stream.synchronize()?;
            }
        };
Ok(())
}
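    /// Checks whether `stream` may access this memory given its attach mode.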
pub fn check_device_access(&self, stream: &CudaStream) -> Result<(), DriverError> {
check_device_access(
self.attach_mode,
&self.stream,
self.concurrent_managed_access,
stream,
)
}
}
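/// Validates device-side access rules for managed memory: GLOBAL is always
/// allowed, HOST requires concurrent managed access on the accessing device,
/// and SINGLE is restricted to the owning stream.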
fn check_device_access(
attach_mode: sys::CUmemAttach_flags,
owner_stream: &Arc<CudaStream>,
concurrent_managed_access: bool,
stream: &CudaStream,
) -> Result<(), DriverError> {
    match attach_mode {
        // GLOBAL attached memory is accessible from any stream on any device.
        sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_GLOBAL => {}
        // HOST attached memory is only device accessible when the accessing
        // device supports concurrent managed access; if the stream belongs to
        // a different context, query that device's capability instead.
        sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_HOST => {
            let concurrent_managed_access = if owner_stream.context() != stream.context() {
                stream.context().attribute(
                    sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS,
                )? != 0
            } else {
                concurrent_managed_access
            };
            if !concurrent_managed_access {
                return Err(DriverError(sys::cudaError_enum::CUDA_ERROR_NOT_PERMITTED));
            }
        }
        // SINGLE attached memory may only be accessed from its owning stream.
        sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_SINGLE => {
            if owner_stream.as_ref() != stream {
                return Err(DriverError(sys::cudaError_enum::CUDA_ERROR_NOT_PERMITTED));
            }
        }
    };
};
Ok(())
}
impl<T> UnifiedSlice<T> {
pub fn as_view(&self) -> UnifiedView<'_, T> {
UnifiedView {
ptr: self.cu_device_ptr,
len: self.len,
event: &self.event,
stream: &self.stream,
attach_mode: self.attach_mode,
concurrent_managed_access: self.concurrent_managed_access,
marker: PhantomData,
}
}
pub fn as_view_mut(&mut self) -> UnifiedViewMut<'_, T> {
UnifiedViewMut {
ptr: self.cu_device_ptr,
len: self.len,
event: &self.event,
stream: &self.stream,
attach_mode: self.attach_mode,
concurrent_managed_access: self.concurrent_managed_access,
marker: PhantomData,
}
}
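    /// Creates an immutable view over `bounds`. Panics if the bounds are out
    /// of range; see [UnifiedSlice::try_slice()] for a fallible version.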
pub fn slice(&self, bounds: impl RangeBounds<usize>) -> UnifiedView<'_, T> {
self.as_view().slice(bounds)
}
pub fn try_slice(&self, bounds: impl RangeBounds<usize>) -> Option<UnifiedView<'_, T>> {
self.as_view().try_slice(bounds)
}
pub fn slice_mut(&mut self, bounds: impl RangeBounds<usize>) -> UnifiedViewMut<'_, T> {
self.try_slice_mut(bounds).unwrap()
}
pub fn try_slice_mut(
&mut self,
bounds: impl RangeBounds<usize>,
) -> Option<UnifiedViewMut<'_, T>> {
to_range(bounds, self.len).map(|(start, end)| UnifiedViewMut {
ptr: self.cu_device_ptr + (start * std::mem::size_of::<T>()) as u64,
len: end - start,
event: &self.event,
stream: &self.stream,
attach_mode: self.attach_mode,
concurrent_managed_access: self.concurrent_managed_access,
marker: PhantomData,
})
}
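    /// Splits the slice into two non-overlapping immutable views at `mid`.
    /// Panics if `mid > self.len()`.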
pub fn split_at(&self, mid: usize) -> (UnifiedView<'_, T>, UnifiedView<'_, T>) {
self.try_split_at(mid).unwrap()
}
pub fn try_split_at(&self, mid: usize) -> Option<(UnifiedView<'_, T>, UnifiedView<'_, T>)> {
(mid <= self.len).then(|| {
let a = UnifiedView {
ptr: self.cu_device_ptr,
len: mid,
event: &self.event,
stream: &self.stream,
attach_mode: self.attach_mode,
concurrent_managed_access: self.concurrent_managed_access,
marker: PhantomData,
};
let b = UnifiedView {
ptr: self.cu_device_ptr + (mid * std::mem::size_of::<T>()) as u64,
len: self.len - mid,
event: &self.event,
stream: &self.stream,
attach_mode: self.attach_mode,
concurrent_managed_access: self.concurrent_managed_access,
marker: PhantomData,
};
(a, b)
})
}
pub fn split_at_mut(&mut self, mid: usize) -> (UnifiedViewMut<'_, T>, UnifiedViewMut<'_, T>) {
self.try_split_at_mut(mid).unwrap()
}
pub fn try_split_at_mut(
&mut self,
mid: usize,
) -> Option<(UnifiedViewMut<'_, T>, UnifiedViewMut<'_, T>)> {
let length = self.len;
(mid <= length).then(|| {
let a = UnifiedViewMut {
ptr: self.cu_device_ptr,
len: mid,
event: &self.event,
stream: &self.stream,
attach_mode: self.attach_mode,
concurrent_managed_access: self.concurrent_managed_access,
marker: PhantomData,
};
let b = UnifiedViewMut {
ptr: self.cu_device_ptr + (mid * std::mem::size_of::<T>()) as u64,
len: self.len - mid,
event: &self.event,
stream: &self.stream,
attach_mode: self.attach_mode,
concurrent_managed_access: self.concurrent_managed_access,
marker: PhantomData,
};
(a, b)
})
}
}
impl<T> DeviceSlice<T> for UnifiedSlice<T> {
fn len(&self) -> usize {
self.len
}
fn stream(&self) -> &Arc<CudaStream> {
&self.stream
}
}
impl<T> DevicePtr<T> for UnifiedSlice<T> {
fn device_ptr<'a>(
&'a self,
stream: &'a CudaStream,
) -> (sys::CUdeviceptr, super::SyncOnDrop<'a>) {
stream.ctx.record_err(self.check_device_access(stream));
stream.ctx.record_err(stream.wait(&self.event));
(
self.cu_device_ptr,
super::SyncOnDrop::Record(Some((&self.event, stream))),
)
}
}
impl<T> DevicePtrMut<T> for UnifiedSlice<T> {
fn device_ptr_mut<'a>(
&'a mut self,
stream: &'a CudaStream,
) -> (sys::CUdeviceptr, super::SyncOnDrop<'a>) {
stream.ctx.record_err(self.check_device_access(stream));
stream.ctx.record_err(stream.wait(&self.event));
(
self.cu_device_ptr,
super::SyncOnDrop::Record(Some((&self.event, stream))),
)
}
}
impl<T: ValidAsZeroBits> UnifiedSlice<T> {
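    /// Waits for outstanding device work (per the attach mode and the slice's
    /// event), then reinterprets the allocation as a host slice.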
pub fn as_slice(&self) -> Result<&[T], DriverError> {
self.check_host_access()?;
self.event.synchronize()?;
Ok(unsafe { std::slice::from_raw_parts(self.cu_device_ptr as *const T, self.len) })
}
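    /// Mutable variant of [UnifiedSlice::as_slice()].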
pub fn as_mut_slice(&mut self) -> Result<&mut [T], DriverError> {
self.check_host_access()?;
self.event.synchronize()?;
Ok(unsafe { std::slice::from_raw_parts_mut(self.cu_device_ptr as *mut T, self.len) })
}
}
impl<T> HostSlice<T> for UnifiedSlice<T> {
fn len(&self) -> usize {
self.len
}
unsafe fn stream_synced_slice<'a>(
&'a self,
stream: &'a CudaStream,
) -> (&'a [T], super::SyncOnDrop<'a>) {
stream.ctx.record_err(self.check_device_access(stream));
stream.ctx.record_err(stream.wait(&self.event));
(
std::slice::from_raw_parts(self.cu_device_ptr as *const T, self.len),
super::SyncOnDrop::Record(Some((&self.event, stream))),
)
}
unsafe fn stream_synced_mut_slice<'a>(
&'a mut self,
stream: &'a CudaStream,
) -> (&'a mut [T], super::SyncOnDrop<'a>) {
stream.ctx.record_err(self.check_device_access(stream));
stream.ctx.record_err(stream.wait(&self.event));
(
std::slice::from_raw_parts_mut(self.cu_device_ptr as *mut T, self.len),
super::SyncOnDrop::Record(Some((&self.event, stream))),
)
}
}
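// Passing a UnifiedSlice as a kernel argument validates stream access, then
// makes the launch wait on (and re-record) the slice's event so later host or
// device accesses observe the kernel's writes.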
unsafe impl<'a, 'b: 'a, T> PushKernelArg<&'b UnifiedSlice<T>> for LaunchArgs<'a> {
#[inline(always)]
fn arg(&mut self, arg: &'b UnifiedSlice<T>) -> &mut Self {
self.stream
.ctx
.record_err(arg.check_device_access(self.stream));
self.waits.push(&arg.event);
self.records.push(&arg.event);
self.args
.push((&arg.cu_device_ptr) as *const sys::CUdeviceptr as _);
self
}
}
unsafe impl<'a, 'b: 'a, T> PushKernelArg<&'b mut UnifiedSlice<T>> for LaunchArgs<'a> {
#[inline(always)]
fn arg(&mut self, arg: &'b mut UnifiedSlice<T>) -> &mut Self {
self.stream
.ctx
.record_err(arg.check_device_access(self.stream));
self.waits.push(&arg.event);
self.records.push(&arg.event);
self.args
.push((&arg.cu_device_ptr) as *const sys::CUdeviceptr as _);
self
}
}
impl<T> DeviceSlice<T> for UnifiedView<'_, T> {
fn len(&self) -> usize {
self.len
}
fn stream(&self) -> &Arc<CudaStream> {
self.stream
}
}
impl<T> DeviceSlice<T> for UnifiedViewMut<'_, T> {
fn len(&self) -> usize {
self.len
}
fn stream(&self) -> &Arc<CudaStream> {
self.stream
}
}
impl<T> DevicePtr<T> for UnifiedView<'_, T> {
fn device_ptr<'a>(
&'a self,
stream: &'a CudaStream,
) -> (sys::CUdeviceptr, super::SyncOnDrop<'a>) {
stream.ctx.record_err(check_device_access(
self.attach_mode,
self.stream,
self.concurrent_managed_access,
stream,
));
stream.ctx.record_err(stream.wait(self.event));
(
self.ptr,
super::SyncOnDrop::Record(Some((self.event, stream))),
)
}
}
impl<T> DevicePtr<T> for UnifiedViewMut<'_, T> {
fn device_ptr<'a>(
&'a self,
stream: &'a CudaStream,
) -> (sys::CUdeviceptr, super::SyncOnDrop<'a>) {
stream.ctx.record_err(check_device_access(
self.attach_mode,
self.stream,
self.concurrent_managed_access,
stream,
));
stream.ctx.record_err(stream.wait(self.event));
(
self.ptr,
super::SyncOnDrop::Record(Some((self.event, stream))),
)
}
}
impl<T> DevicePtrMut<T> for UnifiedViewMut<'_, T> {
fn device_ptr_mut<'a>(
&'a mut self,
stream: &'a CudaStream,
) -> (sys::CUdeviceptr, super::SyncOnDrop<'a>) {
stream.ctx.record_err(check_device_access(
self.attach_mode,
self.stream,
self.concurrent_managed_access,
stream,
));
stream.ctx.record_err(stream.wait(self.event));
(
self.ptr,
super::SyncOnDrop::Record(Some((self.event, stream))),
)
}
}
impl<T> HostSlice<T> for UnifiedView<'_, T> {
fn len(&self) -> usize {
self.len
}
unsafe fn stream_synced_slice<'a>(
&'a self,
stream: &'a CudaStream,
) -> (&'a [T], super::SyncOnDrop<'a>) {
stream.ctx.record_err(check_device_access(
self.attach_mode,
self.stream,
self.concurrent_managed_access,
stream,
));
stream.ctx.record_err(stream.wait(self.event));
(
std::slice::from_raw_parts(self.ptr as *const T, self.len),
super::SyncOnDrop::Record(Some((self.event, stream))),
)
}
unsafe fn stream_synced_mut_slice<'a>(
&'a mut self,
stream: &'a CudaStream,
) -> (&'a mut [T], super::SyncOnDrop<'a>) {
stream.ctx.record_err(check_device_access(
self.attach_mode,
self.stream,
self.concurrent_managed_access,
stream,
));
stream.ctx.record_err(stream.wait(self.event));
(
std::slice::from_raw_parts_mut(self.ptr as *mut T, self.len),
super::SyncOnDrop::Record(Some((self.event, stream))),
)
}
}
impl<T> HostSlice<T> for UnifiedViewMut<'_, T> {
fn len(&self) -> usize {
self.len
}
unsafe fn stream_synced_slice<'a>(
&'a self,
stream: &'a CudaStream,
) -> (&'a [T], super::SyncOnDrop<'a>) {
stream.ctx.record_err(check_device_access(
self.attach_mode,
self.stream,
self.concurrent_managed_access,
stream,
));
stream.ctx.record_err(stream.wait(self.event));
(
std::slice::from_raw_parts(self.ptr as *const T, self.len),
super::SyncOnDrop::Record(Some((self.event, stream))),
)
}
unsafe fn stream_synced_mut_slice<'a>(
&'a mut self,
stream: &'a CudaStream,
) -> (&'a mut [T], super::SyncOnDrop<'a>) {
stream.ctx.record_err(check_device_access(
self.attach_mode,
self.stream,
self.concurrent_managed_access,
stream,
));
stream.ctx.record_err(stream.wait(self.event));
(
std::slice::from_raw_parts_mut(self.ptr as *mut T, self.len),
super::SyncOnDrop::Record(Some((self.event, stream))),
)
}
}
unsafe impl<'a, 'b: 'a, 'c: 'b, T> PushKernelArg<&'b UnifiedView<'c, T>> for LaunchArgs<'a> {
#[inline(always)]
fn arg(&mut self, arg: &'b UnifiedView<'c, T>) -> &mut Self {
self.stream.ctx.record_err(check_device_access(
arg.attach_mode,
arg.stream,
arg.concurrent_managed_access,
self.stream,
));
self.waits.push(arg.event);
self.records.push(arg.event);
self.args.push((&arg.ptr) as *const sys::CUdeviceptr as _);
self
}
}
unsafe impl<'a, 'b: 'a, 'c: 'b, T> PushKernelArg<&'b UnifiedViewMut<'c, T>> for LaunchArgs<'a> {
#[inline(always)]
fn arg(&mut self, arg: &'b UnifiedViewMut<'c, T>) -> &mut Self {
self.stream.ctx.record_err(check_device_access(
arg.attach_mode,
arg.stream,
arg.concurrent_managed_access,
self.stream,
));
self.waits.push(arg.event);
self.records.push(arg.event);
self.args.push((&arg.ptr) as *const sys::CUdeviceptr as _);
self
}
}
unsafe impl<'a, 'b: 'a, 'c: 'b, T> PushKernelArg<&'b mut UnifiedViewMut<'c, T>> for LaunchArgs<'a> {
#[inline(always)]
fn arg(&mut self, arg: &'b mut UnifiedViewMut<'c, T>) -> &mut Self {
self.stream.ctx.record_err(check_device_access(
arg.attach_mode,
arg.stream,
arg.concurrent_managed_access,
self.stream,
));
self.waits.push(arg.event);
self.records.push(arg.event);
self.args.push((&arg.ptr) as *const sys::CUdeviceptr as _);
self
}
}
impl<'a, T> UnifiedView<'a, T> {
pub fn len(&self) -> usize {
self.len
}
pub fn is_empty(&self) -> bool {
self.len == 0
}
pub fn slice(&self, bounds: impl RangeBounds<usize>) -> Self {
self.try_slice(bounds).unwrap()
}
pub fn try_slice(&self, bounds: impl RangeBounds<usize>) -> Option<Self> {
to_range(bounds, self.len).map(|(start, end)| UnifiedView {
ptr: self.ptr + (start * std::mem::size_of::<T>()) as u64,
len: end - start,
event: self.event,
stream: self.stream,
attach_mode: self.attach_mode,
concurrent_managed_access: self.concurrent_managed_access,
marker: PhantomData,
})
}
pub fn split_at(&self, mid: usize) -> (Self, Self) {
self.try_split_at(mid).unwrap()
}
pub fn try_split_at(&self, mid: usize) -> Option<(Self, Self)> {
(mid <= self.len).then(|| {
let a = UnifiedView {
ptr: self.ptr,
len: mid,
event: self.event,
stream: self.stream,
attach_mode: self.attach_mode,
concurrent_managed_access: self.concurrent_managed_access,
marker: PhantomData,
};
let b = UnifiedView {
ptr: self.ptr + (mid * std::mem::size_of::<T>()) as u64,
len: self.len - mid,
event: self.event,
stream: self.stream,
attach_mode: self.attach_mode,
concurrent_managed_access: self.concurrent_managed_access,
marker: PhantomData,
};
(a, b)
})
}
}
impl<'a, T> UnifiedViewMut<'a, T> {
pub fn len(&self) -> usize {
self.len
}
pub fn is_empty(&self) -> bool {
self.len == 0
}
pub fn as_view<'b>(&'b self) -> UnifiedView<'b, T> {
UnifiedView {
ptr: self.ptr,
len: self.len,
event: self.event,
stream: self.stream,
attach_mode: self.attach_mode,
concurrent_managed_access: self.concurrent_managed_access,
marker: PhantomData,
}
}
pub fn slice<'b>(&'b self, bounds: impl RangeBounds<usize>) -> UnifiedView<'b, T> {
self.try_slice(bounds).unwrap()
}
pub fn try_slice<'b>(&'b self, bounds: impl RangeBounds<usize>) -> Option<UnifiedView<'b, T>> {
to_range(bounds, self.len).map(|(start, end)| UnifiedView {
ptr: self.ptr + (start * std::mem::size_of::<T>()) as u64,
len: end - start,
event: self.event,
stream: self.stream,
attach_mode: self.attach_mode,
concurrent_managed_access: self.concurrent_managed_access,
marker: PhantomData,
})
}
pub fn slice_mut<'b>(&'b mut self, bounds: impl RangeBounds<usize>) -> UnifiedViewMut<'b, T> {
self.try_slice_mut(bounds).unwrap()
}
pub fn try_slice_mut<'b>(
&'b mut self,
bounds: impl RangeBounds<usize>,
) -> Option<UnifiedViewMut<'b, T>> {
to_range(bounds, self.len).map(|(start, end)| UnifiedViewMut {
ptr: self.ptr + (start * std::mem::size_of::<T>()) as u64,
len: end - start,
event: self.event,
stream: self.stream,
attach_mode: self.attach_mode,
concurrent_managed_access: self.concurrent_managed_access,
marker: PhantomData,
})
}
pub fn split_at_mut<'b>(
&'b mut self,
mid: usize,
) -> (UnifiedViewMut<'b, T>, UnifiedViewMut<'b, T>) {
self.try_split_at_mut(mid).unwrap()
}
pub fn try_split_at_mut<'b>(
&'b mut self,
mid: usize,
) -> Option<(UnifiedViewMut<'b, T>, UnifiedViewMut<'b, T>)> {
(mid <= self.len()).then(|| {
let a = UnifiedViewMut {
ptr: self.ptr,
len: mid,
event: self.event,
stream: self.stream,
attach_mode: self.attach_mode,
concurrent_managed_access: self.concurrent_managed_access,
marker: PhantomData,
};
let b = UnifiedViewMut {
ptr: self.ptr + (mid * std::mem::size_of::<T>()) as u64,
len: self.len - mid,
event: self.event,
stream: self.stream,
attach_mode: self.attach_mode,
concurrent_managed_access: self.concurrent_managed_access,
marker: PhantomData,
};
(a, b)
})
}
}
#[cfg(feature = "nvrtc")]
#[cfg(test)]
mod tests {
#![allow(clippy::needless_range_loop)]
use crate::driver::{LaunchConfig, PushKernelArg};
use super::*;
#[test]
fn test_unified_memory_global() -> Result<(), DriverError> {
let ctx = CudaContext::new(0)?;
let mut a = unsafe { ctx.alloc_unified::<f32>(100, true) }?;
{
let buf = a.as_mut_slice()?;
for i in 0..100 {
buf[i] = i as f32;
}
}
{
let buf = a.as_slice()?;
for i in 0..100 {
assert_eq!(buf[i], i as f32);
}
}
let ptx = crate::nvrtc::compile_ptx(
"
extern \"C\" __global__ void kernel(float *buf) {
if (threadIdx.x < 100) {
assert(buf[threadIdx.x] == static_cast<float>(threadIdx.x));
}
}",
)
.unwrap();
let module = ctx.load_module(ptx)?;
let f = module.load_function("kernel")?;
let stream1 = ctx.default_stream();
unsafe {
stream1
.launch_builder(&f)
.arg(&mut a)
.launch(LaunchConfig::for_num_elems(100))
}?;
stream1.synchronize()?;
let stream2 = ctx.new_stream()?;
unsafe {
stream2
.launch_builder(&f)
.arg(&mut a)
.launch(LaunchConfig::for_num_elems(100))
}?;
stream2.synchronize()?;
{
let buf = a.as_slice()?;
for i in 0..100 {
assert_eq!(buf[i], i as f32);
}
}
let vs = stream1.clone_dtoh(&a)?;
for i in 0..100 {
assert_eq!(vs[i], i as f32);
}
let b = stream1.clone_htod(&a)?;
let vs = stream1.clone_dtoh(&b)?;
for i in 0..100 {
assert_eq!(vs[i], i as f32);
}
stream1.memset_zeros(&mut a)?;
{
let buf = a.as_slice()?;
for i in 0..100 {
assert_eq!(buf[i], 0.0);
}
}
Ok(())
}
#[test]
fn test_unified_memory_host() -> Result<(), DriverError> {
let ctx = CudaContext::new(0)?;
let mut a = unsafe { ctx.alloc_unified::<f32>(100, false) }?;
{
let buf = a.as_mut_slice()?;
for i in 0..100 {
buf[i] = i as f32;
}
}
{
let buf = a.as_slice()?;
for i in 0..100 {
assert_eq!(buf[i], i as f32);
}
}
let ptx = crate::nvrtc::compile_ptx(
"
extern \"C\" __global__ void kernel(float *buf) {
if (threadIdx.x < 100) {
assert(buf[threadIdx.x] == static_cast<float>(threadIdx.x));
}
}",
)
.unwrap();
let module = ctx.load_module(ptx)?;
let f = module.load_function("kernel")?;
let stream1 = ctx.default_stream();
unsafe {
stream1
.launch_builder(&f)
.arg(&mut a)
.launch(LaunchConfig::for_num_elems(100))
}?;
stream1.synchronize()?;
let stream2 = ctx.new_stream()?;
unsafe {
stream2
.launch_builder(&f)
.arg(&mut a)
.launch(LaunchConfig::for_num_elems(100))
}?;
stream2.synchronize()?;
{
let buf = a.as_slice()?;
for i in 0..100 {
assert_eq!(buf[i], i as f32);
}
}
let vs = stream1.clone_dtoh(&a)?;
for i in 0..100 {
assert_eq!(vs[i], i as f32);
}
let b = stream1.clone_htod(&a)?;
let vs = stream1.clone_dtoh(&b)?;
for i in 0..100 {
assert_eq!(vs[i], i as f32);
}
stream1.memset_zeros(&mut a)?;
{
let buf = a.as_slice()?;
for i in 0..100 {
assert_eq!(buf[i], 0.0);
}
}
Ok(())
}
#[test]
fn test_unified_memory_single_stream() -> Result<(), DriverError> {
let ctx = CudaContext::new(0)?;
let mut a = unsafe { ctx.alloc_unified::<f32>(100, true) }?;
{
let buf = a.as_mut_slice()?;
for i in 0..100 {
buf[i] = i as f32;
}
}
{
let buf = a.as_slice()?;
for i in 0..100 {
assert_eq!(buf[i], i as f32);
}
}
let ptx = crate::nvrtc::compile_ptx(
"
extern \"C\" __global__ void kernel(float *buf) {
if (threadIdx.x < 100) {
assert(buf[threadIdx.x] == static_cast<float>(threadIdx.x));
}
}",
)
.unwrap();
let module = ctx.load_module(ptx)?;
let f = module.load_function("kernel")?;
let stream2 = ctx.new_stream()?;
a.attach(&stream2, sys::CUmemAttach_flags::CU_MEM_ATTACH_SINGLE)?;
unsafe {
stream2
.launch_builder(&f)
.arg(&mut a)
.launch(LaunchConfig::for_num_elems(100))
}?;
stream2.synchronize()?;
let stream1 = ctx.default_stream();
unsafe {
stream1
.launch_builder(&f)
.arg(&mut a)
.launch(LaunchConfig::for_num_elems(100))
}
.expect_err("Other stream access should've failed");
{
let buf = a.as_slice()?;
for i in 0..100 {
assert_eq!(buf[i], i as f32);
}
}
let vs = stream2.clone_dtoh(&a)?;
for i in 0..100 {
assert_eq!(vs[i], i as f32);
}
let b = stream2.clone_htod(&a)?;
let vs = stream2.clone_dtoh(&b)?;
for i in 0..100 {
assert_eq!(vs[i], i as f32);
}
stream2.memset_zeros(&mut a)?;
{
let buf = a.as_slice()?;
for i in 0..100 {
assert_eq!(buf[i], 0.0);
}
}
Ok(())
}
#[test]
fn test_unified_slice_copy_to_views() -> Result<(), DriverError> {
let ctx = CudaContext::new(0)?;
let stream = ctx.default_stream();
let mut smalls = [
unsafe { ctx.alloc_unified::<f32>(2, true) }?,
unsafe { ctx.alloc_unified::<f32>(2, true) }?,
unsafe { ctx.alloc_unified::<f32>(2, true) }?,
unsafe { ctx.alloc_unified::<f32>(2, true) }?,
unsafe { ctx.alloc_unified::<f32>(2, true) }?,
];
{
let buf = smalls[0].as_mut_slice()?;
buf[0] = -1.0;
buf[1] = -0.8;
}
{
let buf = smalls[1].as_mut_slice()?;
buf[0] = -0.6;
buf[1] = -0.4;
}
{
let buf = smalls[2].as_mut_slice()?;
buf[0] = -0.2;
buf[1] = 0.0;
}
{
let buf = smalls[3].as_mut_slice()?;
buf[0] = 0.2;
buf[1] = 0.4;
}
{
let buf = smalls[4].as_mut_slice()?;
buf[0] = 0.6;
buf[1] = 0.8;
}
let mut big = unsafe { ctx.alloc_unified::<f32>(10, true) }?;
stream.memset_zeros(&mut big)?;
let mut offset = 0;
for small in smalls.iter() {
let mut sub = big.slice_mut(offset..offset + small.len());
stream.memcpy_dtod(small, &mut sub)?;
offset += small.len();
}
stream.synchronize()?;
let result = stream.clone_dtoh(&big)?;
assert_eq!(
result,
[-1.0, -0.8, -0.6, -0.4, -0.2, 0.0, 0.2, 0.4, 0.6, 0.8]
);
Ok(())
}
#[test]
fn test_unified_slice_split_at() -> Result<(), DriverError> {
let ctx = CudaContext::new(0)?;
let stream = ctx.default_stream();
let mut unified = unsafe { ctx.alloc_unified::<f32>(10, true) }?;
{
let buf = unified.as_mut_slice()?;
for i in 0..10 {
buf[i] = i as f32;
}
}
let (left, right) = unified.split_at(5);
assert_eq!(left.len(), 5);
assert_eq!(right.len(), 5);
let left_data = stream.clone_dtoh(&left)?;
let right_data = stream.clone_dtoh(&right)?;
assert_eq!(left_data, [0.0, 1.0, 2.0, 3.0, 4.0]);
assert_eq!(right_data, [5.0, 6.0, 7.0, 8.0, 9.0]);
let (mut left_mut, right_mut) = unified.split_at_mut(5);
assert_eq!(left_mut.len(), 5);
assert_eq!(right_mut.len(), 5);
let zeros = std::vec![0.0f32; 5];
stream.memcpy_htod(&zeros, &mut left_mut)?;
stream.synchronize()?;
let result = stream.clone_dtoh(&unified)?;
assert_eq!(result, [0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
Ok(())
}
#[test]
fn test_unified_slice_views_respect_stream_attachment() -> Result<(), DriverError> {
let ctx = CudaContext::new(0)?;
let mut unified = unsafe { ctx.alloc_unified::<f32>(100, true) }?;
{
let buf = unified.as_mut_slice()?;
for i in 0..100 {
buf[i] = i as f32;
}
}
let stream1 = ctx.default_stream();
let stream2 = ctx.new_stream()?;
let view1 = unified.slice(0..50);
let data1 = stream1.clone_dtoh(&view1)?;
assert_eq!(data1[0], 0.0);
assert_eq!(data1[49], 49.0);
let view2 = unified.slice(50..100);
let data2 = stream2.clone_dtoh(&view2)?;
assert_eq!(data2[0], 50.0);
assert_eq!(data2[49], 99.0);
let mut view_mut = unified.slice_mut(10..20);
let write_data = std::vec![999.0f32; 10];
stream1.memcpy_htod(&write_data, &mut view_mut)?;
stream1.synchronize()?;
let verify_data = stream1.clone_dtoh(&unified)?;
for i in 0..10 {
assert_eq!(
verify_data[i], i as f32,
"Data before write range should be unchanged"
);
}
for i in 10..20 {
assert_eq!(verify_data[i], 999.0, "Data in write range should be 999.0");
}
for i in 20..100 {
assert_eq!(
verify_data[i], i as f32,
"Data after write range should be unchanged"
);
}
unified.attach(&stream2, sys::CUmemAttach_flags::CU_MEM_ATTACH_SINGLE)?;
let view_single = unified.slice(0..50);
let data_ok = stream2.clone_dtoh(&view_single)?;
assert_eq!(data_ok[0], 0.0);
let _ = stream1.clone_dtoh(&view_single);
assert!(
ctx.check_err().is_err(),
"Expected error to be recorded when accessing SINGLE mode view from wrong stream"
);
let mut view_single_mut = unified.slice_mut(30..40);
let write_data2 = std::vec![777.0f32; 10];
stream2.memcpy_htod(&write_data2, &mut view_single_mut)?;
stream2.synchronize()?;
let verify_data2 = stream2.clone_dtoh(&unified)?;
for i in 30..40 {
assert_eq!(
verify_data2[i], 777.0,
"Data written through SINGLE mode view should be 777.0"
);
}
let mut view_wrong_stream = unified.slice_mut(40..50);
let _ = stream1.memcpy_htod(&write_data2, &mut view_wrong_stream);
assert!(
ctx.check_err().is_err(),
"Expected error to be recorded when writing to SINGLE mode view from wrong stream"
);
Ok(())
}
}