use crate::error::{CudaResult, ToResult};
use crate::memory::device::AsyncCopyDestination;
use crate::memory::device::{CopyDestination, DeviceBuffer};
use crate::memory::DeviceCopy;
use crate::memory::DevicePointer;
use crate::stream::Stream;
use std::iter::{ExactSizeIterator, FusedIterator};
use std::mem;
use std::ops::{
Index, IndexMut, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive,
};
use std::os::raw::c_void;
use std::slice::{self, Chunks, ChunksMut};
/// Fixed-size view into a contiguous region of device (GPU) memory.
///
/// This is a dynamically-sized newtype over `[T]` and is only ever handed out
/// by reference (`&DeviceSlice<T>` / `&mut DeviceSlice<T>`). The inner slice's
/// data pointer refers to *device* memory (it is passed to `cuMemcpy*` calls
/// below), so it must never be dereferenced on the host; only the
/// pointer/length metadata is used host-side.
///
/// `#[repr(C)]` pins the layout so the raw casts in `from_slice` /
/// `from_slice_mut` are sound.
#[derive(Debug)]
#[repr(C)]
pub struct DeviceSlice<T>([T]);
impl<T> DeviceSlice<T> {
    /// Returns the number of elements in the slice.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns `true` if the slice contains no elements.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Returns a raw pointer to the start of the slice. The pointer refers to
    /// device memory and must not be dereferenced on the host.
    pub fn as_ptr(&self) -> *const T {
        self.0.as_ptr()
    }

    /// Returns a mutable raw pointer to the start of the slice (device
    /// memory — see `as_ptr`).
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.0.as_mut_ptr()
    }

    /// Divides the slice into two sub-slices at index `mid`.
    ///
    /// # Panics
    ///
    /// Panics if `mid > self.len()` (propagated from `slice::split_at`).
    pub fn split_at(&self, mid: usize) -> (&DeviceSlice<T>, &DeviceSlice<T>) {
        let (left, right) = self.0.split_at(mid);
        // SAFETY: both halves borrow the same device allocation as `self`,
        // so re-wrapping them preserves the device-memory invariant.
        unsafe {
            (
                DeviceSlice::from_slice(left),
                DeviceSlice::from_slice(right),
            )
        }
    }

    /// Mutable variant of `split_at`; the two halves are disjoint.
    ///
    /// # Panics
    ///
    /// Panics if `mid > self.len()`.
    pub fn split_at_mut(&mut self, mid: usize) -> (&mut DeviceSlice<T>, &mut DeviceSlice<T>) {
        let (left, right) = self.0.split_at_mut(mid);
        // SAFETY: `slice::split_at_mut` guarantees the halves do not overlap.
        unsafe {
            (
                DeviceSlice::from_slice_mut(left),
                DeviceSlice::from_slice_mut(right),
            )
        }
    }

    /// Returns an iterator over `chunk_size`-element sub-slices; the final
    /// chunk may be shorter.
    pub fn chunks(&self, chunk_size: usize) -> DeviceChunks<T> {
        DeviceChunks(self.0.chunks(chunk_size))
    }

    /// Mutable variant of `chunks`.
    pub fn chunks_mut(&mut self, chunk_size: usize) -> DeviceChunksMut<T> {
        DeviceChunksMut(self.0.chunks_mut(chunk_size))
    }

    /// Reinterprets a `&[T]` whose data pointer actually refers to device
    /// memory as a `&DeviceSlice<T>`.
    ///
    /// # Safety
    ///
    /// The slice's data pointer must be a valid device pointer. The cast
    /// itself is layout-sound because `DeviceSlice` is a `#[repr(C)]`
    /// newtype around `[T]`.
    pub(super) unsafe fn from_slice(slice: &[T]) -> &DeviceSlice<T> {
        &*(slice as *const [T] as *const DeviceSlice<T>)
    }

    /// Mutable variant of `from_slice`; same safety requirements.
    pub(super) unsafe fn from_slice_mut(slice: &mut [T]) -> &mut DeviceSlice<T> {
        &mut *(slice as *mut [T] as *mut DeviceSlice<T>)
    }

    /// Returns a `DevicePointer` to the start of the slice.
    pub fn as_device_ptr(&mut self) -> DevicePointer<T> {
        // SAFETY: the slice's data pointer is a device pointer by this
        // type's invariant, so wrapping it is sound.
        unsafe { DevicePointer::wrap(self.0.as_mut_ptr()) }
    }

    /// Forms a device slice from a device pointer and a length.
    ///
    /// # Safety
    ///
    /// `data` must point to `len` contiguous, initialized elements of device
    /// memory that remain valid for the caller-chosen lifetime `'a`, with no
    /// mutable aliases during that lifetime.
    // The pointer is taken by value on purpose, mirroring the signature of
    // `slice::from_raw_parts` — hence the clippy allow.
    #[allow(clippy::needless_pass_by_value)]
    pub unsafe fn from_raw_parts<'a>(data: DevicePointer<T>, len: usize) -> &'a DeviceSlice<T> {
        DeviceSlice::from_slice(slice::from_raw_parts(data.as_raw(), len))
    }

    /// Mutable variant of `from_raw_parts`.
    ///
    /// # Safety
    ///
    /// As for `from_raw_parts`, and additionally no other reference to the
    /// memory may exist for the duration of `'a`.
    pub unsafe fn from_raw_parts_mut<'a>(
        mut data: DevicePointer<T>,
        len: usize,
    ) -> &'a mut DeviceSlice<T> {
        DeviceSlice::from_slice_mut(slice::from_raw_parts_mut(data.as_raw_mut(), len))
    }
}
/// Iterator over non-overlapping chunks of a `DeviceSlice`, created by
/// `DeviceSlice::chunks`. Thin wrapper around the standard library's
/// `slice::Chunks`, re-wrapping each chunk as a `DeviceSlice`.
#[derive(Debug, Clone)]
pub struct DeviceChunks<'a, T: 'a>(Chunks<'a, T>);
impl<'a, T> Iterator for DeviceChunks<'a, T> {
    type Item = &'a DeviceSlice<T>;

    fn next(&mut self) -> Option<&'a DeviceSlice<T>> {
        let chunk = self.0.next()?;
        // SAFETY: every sub-slice produced by the inner `Chunks` borrows the
        // same device allocation, so viewing it as a `DeviceSlice` is sound.
        Some(unsafe { DeviceSlice::from_slice(chunk) })
    }

    /// Delegates directly to the inner iterator.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.0.size_hint()
    }

    /// `Chunks` knows its exact length, so no iteration is needed.
    fn count(self) -> usize {
        self.0.len()
    }

    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        let chunk = self.0.nth(n)?;
        // SAFETY: as in `next`.
        Some(unsafe { DeviceSlice::from_slice(chunk) })
    }

    #[inline]
    fn last(self) -> Option<Self::Item> {
        let chunk = self.0.last()?;
        // SAFETY: as in `next`.
        Some(unsafe { DeviceSlice::from_slice(chunk) })
    }
}
impl<'a, T> DoubleEndedIterator for DeviceChunks<'a, T> {
    #[inline]
    fn next_back(&mut self) -> Option<&'a DeviceSlice<T>> {
        let chunk = self.0.next_back()?;
        // SAFETY: same invariant as the forward direction — the chunk
        // borrows the same device allocation.
        Some(unsafe { DeviceSlice::from_slice(chunk) })
    }
}

// The inner `Chunks` reports an exact length and never resumes after `None`,
// so these marker traits are inherited for free.
impl<'a, T> ExactSizeIterator for DeviceChunks<'a, T> {}
impl<'a, T> FusedIterator for DeviceChunks<'a, T> {}
/// Mutable counterpart of `DeviceChunks`, created by
/// `DeviceSlice::chunks_mut`. Wraps the standard library's
/// `slice::ChunksMut`, re-wrapping each chunk as a mutable `DeviceSlice`.
#[derive(Debug)]
pub struct DeviceChunksMut<'a, T: 'a>(ChunksMut<'a, T>);
impl<'a, T> Iterator for DeviceChunksMut<'a, T> {
    type Item = &'a mut DeviceSlice<T>;

    fn next(&mut self) -> Option<&'a mut DeviceSlice<T>> {
        let chunk = self.0.next()?;
        // SAFETY: the inner `ChunksMut` hands out disjoint sub-slices of the
        // same device allocation, so each may be wrapped as a `DeviceSlice`.
        Some(unsafe { DeviceSlice::from_slice_mut(chunk) })
    }

    /// Delegates directly to the inner iterator.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.0.size_hint()
    }

    /// `ChunksMut` knows its exact length, so no iteration is needed.
    fn count(self) -> usize {
        self.0.len()
    }

    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        let chunk = self.0.nth(n)?;
        // SAFETY: as in `next`.
        Some(unsafe { DeviceSlice::from_slice_mut(chunk) })
    }

    #[inline]
    fn last(self) -> Option<Self::Item> {
        let chunk = self.0.last()?;
        // SAFETY: as in `next`.
        Some(unsafe { DeviceSlice::from_slice_mut(chunk) })
    }
}
impl<'a, T> DoubleEndedIterator for DeviceChunksMut<'a, T> {
    #[inline]
    fn next_back(&mut self) -> Option<&'a mut DeviceSlice<T>> {
        let chunk = self.0.next_back()?;
        // SAFETY: same disjointness invariant as the forward direction.
        Some(unsafe { DeviceSlice::from_slice_mut(chunk) })
    }
}

// Exact length and fused behavior are inherited from the inner `ChunksMut`.
impl<'a, T> ExactSizeIterator for DeviceChunksMut<'a, T> {}
impl<'a, T> FusedIterator for DeviceChunksMut<'a, T> {}
/// Generates `Index`/`IndexMut` implementations for `DeviceSlice` over each
/// range type passed in. Indexing a device slice yields a device sub-slice;
/// out-of-bounds ranges panic exactly as they do for `[T]`.
macro_rules! impl_index {
    ($($t:ty)*) => {
        $(
            impl<T> Index<$t> for DeviceSlice<T> {
                type Output = DeviceSlice<T>;

                fn index(&self, index: $t) -> &Self {
                    // SAFETY: the sub-slice borrows the same device
                    // allocation as `self`.
                    unsafe { DeviceSlice::from_slice(self.0.index(index)) }
                }
            }

            impl<T> IndexMut<$t> for DeviceSlice<T> {
                fn index_mut(&mut self, index: $t) -> &mut Self {
                    // SAFETY: as above; exclusivity follows from `&mut self`.
                    unsafe { DeviceSlice::from_slice_mut(self.0.index_mut(index)) }
                }
            }
        )*
    };
}
impl_index! {
    Range<usize>
    RangeFull
    RangeFrom<usize>
    RangeInclusive<usize>
    RangeTo<usize>
    RangeToInclusive<usize>
}
impl<T> crate::private::Sealed for DeviceSlice<T> {}
impl<T: DeviceCopy, I: AsRef<[T]> + AsMut<[T]> + ?Sized> CopyDestination<I> for DeviceSlice<T> {
    /// Synchronously copies host memory `val` into this device slice
    /// (`cuMemcpyHtoD_v2`).
    ///
    /// # Panics
    ///
    /// Panics if the source and destination lengths differ.
    fn copy_from(&mut self, val: &I) -> CudaResult<()> {
        let val = val.as_ref();
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = mem::size_of::<T>() * self.len();
        // Skip the driver call entirely for zero-byte copies.
        if size != 0 {
            // SAFETY: the length check above guarantees both buffers are
            // valid for `size` bytes. The destination pointer is cast to
            // `u64` because the driver API takes device pointers as
            // `CUdeviceptr` (a 64-bit integer).
            unsafe {
                cuda_driver_sys::cuMemcpyHtoD_v2(
                    self.0.as_mut_ptr() as u64,
                    val.as_ptr() as *const c_void,
                    size,
                )
                .to_result()?
            }
        }
        Ok(())
    }

    /// Synchronously copies this device slice into host memory `val`
    /// (`cuMemcpyDtoH_v2`).
    ///
    /// # Panics
    ///
    /// Panics if the source and destination lengths differ.
    fn copy_to(&self, val: &mut I) -> CudaResult<()> {
        let val = val.as_mut();
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = mem::size_of::<T>() * self.len();
        if size != 0 {
            // SAFETY: see `copy_from`; the direction is device-to-host here,
            // so the host pointer is the destination argument.
            unsafe {
                cuda_driver_sys::cuMemcpyDtoH_v2(
                    val.as_mut_ptr() as *mut c_void,
                    self.as_ptr() as u64,
                    size,
                )
                .to_result()?
            }
        }
        Ok(())
    }
}
impl<T: DeviceCopy> CopyDestination<DeviceSlice<T>> for DeviceSlice<T> {
    /// Synchronously copies another device slice into this one
    /// (`cuMemcpyDtoD_v2`).
    ///
    /// # Panics
    ///
    /// Panics if the two slices have different lengths.
    fn copy_from(&mut self, val: &DeviceSlice<T>) -> CudaResult<()> {
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = mem::size_of::<T>() * self.len();
        // Zero-byte copies skip the driver call.
        if size != 0 {
            // SAFETY: lengths match, so both device buffers are valid for
            // `size` bytes; destination comes first in the driver call.
            unsafe {
                cuda_driver_sys::cuMemcpyDtoD_v2(
                    self.0.as_mut_ptr() as u64,
                    val.as_ptr() as u64,
                    size,
                )
                .to_result()?
            }
        }
        Ok(())
    }

    /// Synchronously copies this device slice into another device slice
    /// (`cuMemcpyDtoD_v2`, roles reversed relative to `copy_from`).
    ///
    /// # Panics
    ///
    /// Panics if the two slices have different lengths.
    fn copy_to(&self, val: &mut DeviceSlice<T>) -> CudaResult<()> {
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = mem::size_of::<T>() * self.len();
        if size != 0 {
            // SAFETY: as in `copy_from`, with `val` as the destination.
            unsafe {
                cuda_driver_sys::cuMemcpyDtoD_v2(
                    val.as_mut_ptr() as u64,
                    self.as_ptr() as u64,
                    size,
                )
                .to_result()?
            }
        }
        Ok(())
    }
}
impl<T: DeviceCopy> CopyDestination<DeviceBuffer<T>> for DeviceSlice<T> {
    /// Copies from a whole device buffer by delegating to the
    /// slice-to-slice implementation (the buffer derefs to a slice).
    fn copy_from(&mut self, val: &DeviceBuffer<T>) -> CudaResult<()> {
        let src: &DeviceSlice<T> = val;
        self.copy_from(src)
    }

    /// Copies into a whole device buffer via the slice-to-slice impl.
    fn copy_to(&self, val: &mut DeviceBuffer<T>) -> CudaResult<()> {
        let dst: &mut DeviceSlice<T> = val;
        self.copy_to(dst)
    }
}
impl<T: DeviceCopy, I: AsRef<[T]> + AsMut<[T]> + ?Sized> AsyncCopyDestination<I>
    for DeviceSlice<T>
{
    /// Enqueues an asynchronous host-to-device copy on `stream`
    /// (`cuMemcpyHtoDAsync_v2`).
    ///
    /// # Safety
    ///
    /// The copy is only *enqueued*; the caller must keep `val` alive and
    /// unmodified (and not read `self`) until the stream has been
    /// synchronized.
    ///
    /// # Panics
    ///
    /// Panics if the source and destination lengths differ.
    unsafe fn async_copy_from(&mut self, val: &I, stream: &Stream) -> CudaResult<()> {
        let val = val.as_ref();
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = mem::size_of::<T>() * self.len();
        // Zero-byte copies skip the driver call.
        if size != 0 {
            // Device pointers are passed to the driver as `u64`
            // (`CUdeviceptr`); lengths were checked above.
            cuda_driver_sys::cuMemcpyHtoDAsync_v2(
                self.0.as_mut_ptr() as u64,
                val.as_ptr() as *const c_void,
                size,
                stream.as_inner(),
            )
            .to_result()?
        }
        Ok(())
    }

    /// Enqueues an asynchronous device-to-host copy on `stream`
    /// (`cuMemcpyDtoHAsync_v2`).
    ///
    /// # Safety
    ///
    /// As for `async_copy_from`: `val` must not be read or freed until the
    /// stream has been synchronized.
    ///
    /// # Panics
    ///
    /// Panics if the source and destination lengths differ.
    unsafe fn async_copy_to(&self, val: &mut I, stream: &Stream) -> CudaResult<()> {
        let val = val.as_mut();
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = mem::size_of::<T>() * self.len();
        if size != 0 {
            cuda_driver_sys::cuMemcpyDtoHAsync_v2(
                val.as_mut_ptr() as *mut c_void,
                self.as_ptr() as u64,
                size,
                stream.as_inner(),
            )
            .to_result()?
        }
        Ok(())
    }
}
impl<T: DeviceCopy> AsyncCopyDestination<DeviceSlice<T>> for DeviceSlice<T> {
    /// Enqueues an asynchronous device-to-device copy from `val` into `self`
    /// on `stream` (`cuMemcpyDtoDAsync_v2`).
    ///
    /// # Safety
    ///
    /// The copy is only *enqueued*; neither slice may be accessed or freed
    /// until the stream has been synchronized.
    ///
    /// # Panics
    ///
    /// Panics if the two slices have different lengths.
    unsafe fn async_copy_from(&mut self, val: &DeviceSlice<T>, stream: &Stream) -> CudaResult<()> {
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = mem::size_of::<T>() * self.len();
        // Zero-byte copies skip the driver call.
        if size != 0 {
            // Destination pointer first; both are `CUdeviceptr` (`u64`).
            cuda_driver_sys::cuMemcpyDtoDAsync_v2(
                self.0.as_mut_ptr() as u64,
                val.as_ptr() as u64,
                size,
                stream.as_inner(),
            )
            .to_result()?
        }
        Ok(())
    }

    /// Enqueues an asynchronous device-to-device copy from `self` into `val`
    /// on `stream` (`cuMemcpyDtoDAsync_v2`, roles reversed).
    ///
    /// # Safety
    ///
    /// As for `async_copy_from`.
    ///
    /// # Panics
    ///
    /// Panics if the two slices have different lengths.
    unsafe fn async_copy_to(&self, val: &mut DeviceSlice<T>, stream: &Stream) -> CudaResult<()> {
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = mem::size_of::<T>() * self.len();
        if size != 0 {
            cuda_driver_sys::cuMemcpyDtoDAsync_v2(
                val.as_mut_ptr() as u64,
                self.as_ptr() as u64,
                size,
                stream.as_inner(),
            )
            .to_result()?
        }
        Ok(())
    }
}
impl<T: DeviceCopy> AsyncCopyDestination<DeviceBuffer<T>> for DeviceSlice<T> {
    /// Async copy from a whole device buffer, delegating to the
    /// slice-to-slice implementation (the buffer derefs to a slice).
    ///
    /// # Safety
    ///
    /// Same contract as the slice-to-slice `async_copy_from`: nothing may be
    /// accessed or freed until the stream is synchronized.
    unsafe fn async_copy_from(&mut self, val: &DeviceBuffer<T>, stream: &Stream) -> CudaResult<()> {
        let src: &DeviceSlice<T> = val;
        self.async_copy_from(src, stream)
    }

    /// Async copy into a whole device buffer via the slice-to-slice impl.
    ///
    /// # Safety
    ///
    /// Same contract as the slice-to-slice `async_copy_to`.
    unsafe fn async_copy_to(&self, val: &mut DeviceBuffer<T>, stream: &Stream) -> CudaResult<()> {
        let dst: &mut DeviceSlice<T> = val;
        self.async_copy_to(dst, stream)
    }
}