// fil_rustacuda/memory/device/device_slice.rs

use crate::error::{CudaResult, ToResult};
use crate::memory::device::AsyncCopyDestination;
use crate::memory::device::{CopyDestination, DeviceBuffer};
use crate::memory::DeviceCopy;
use crate::memory::DevicePointer;
use crate::stream::Stream;
use std::iter::{ExactSizeIterator, FusedIterator};
use std::mem;
use std::ops::{
    Index, IndexMut, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive,
};

use std::os::raw::c_void;
use std::slice::{self, Chunks, ChunksMut};

/// Fixed-size device-side slice.
#[derive(Debug)]
#[repr(C)]
pub struct DeviceSlice<T>([T]);
// This works by faking a regular slice out of the device raw-pointer and the length, then
// transmuting. I have no idea if this is safe or not. Probably not, though I can't imagine how
// the compiler could possibly know that the pointer is not dereferenceable. I'm banking that we
// get proper Dynamically-Sized Types before the compiler authors break this assumption.
impl<T> DeviceSlice<T> {
    /// Returns the number of elements in the slice.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let a = DeviceBuffer::from_slice(&[1, 2, 3]).unwrap();
    /// assert_eq!(a.len(), 3);
    /// ```
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns `true` if the slice has a length of 0.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let a: DeviceBuffer<u64> = unsafe { DeviceBuffer::uninitialized(0).unwrap() };
    /// assert!(a.is_empty());
    /// ```
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Returns a raw device-pointer to the slice's buffer.
    ///
    /// The caller must ensure that the slice outlives the pointer this function returns, or else
    /// it will end up pointing to garbage. The caller must also ensure that the pointer is not
    /// dereferenced by the CPU.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let a = DeviceBuffer::from_slice(&[1, 2, 3]).unwrap();
    /// println!("{:p}", a.as_ptr());
    /// ```
    pub fn as_ptr(&self) -> *const T {
        self.0.as_ptr()
    }

    /// Returns an unsafe mutable device-pointer to the slice's buffer.
    ///
    /// The caller must ensure that the slice outlives the pointer this function returns, or else
    /// it will end up pointing to garbage. The caller must also ensure that the pointer is not
    /// dereferenced by the CPU.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let mut a = DeviceBuffer::from_slice(&[1, 2, 3]).unwrap();
    /// println!("{:p}", a.as_mut_ptr());
    /// ```
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.0.as_mut_ptr()
    }

    /// Divides one DeviceSlice into two at a given index.
    ///
    /// The first will contain all indices from `[0, mid)` (excluding the index `mid` itself) and
    /// the second will contain all indices from `[mid, len)` (excluding the index `len` itself).
    ///
    /// # Panics
    ///
    /// Panics if `mid > len`.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let buf = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4, 5]).unwrap();
    /// let (left, right) = buf.split_at(3);
    /// let mut left_host = [0u64, 0, 0];
    /// let mut right_host = [0u64, 0, 0];
    /// left.copy_to(&mut left_host).unwrap();
    /// right.copy_to(&mut right_host).unwrap();
    /// assert_eq!([0u64, 1, 2], left_host);
    /// assert_eq!([3u64, 4, 5], right_host);
    /// ```
    pub fn split_at(&self, mid: usize) -> (&DeviceSlice<T>, &DeviceSlice<T>) {
        let (left, right) = self.0.split_at(mid);
        unsafe {
            (
                DeviceSlice::from_slice(left),
                DeviceSlice::from_slice(right),
            )
        }
    }

    /// Divides one mutable DeviceSlice into two at a given index.
    ///
    /// The first will contain all indices from `[0, mid)` (excluding the index `mid` itself) and
    /// the second will contain all indices from `[mid, len)` (excluding the index `len` itself).
    ///
    /// # Panics
    ///
    /// Panics if `mid > len`.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let mut buf = DeviceBuffer::from_slice(&[0u64, 0, 0, 0, 0, 0]).unwrap();
    ///
    /// {
    ///     let (left, right) = buf.split_at_mut(3);
    ///     let left_host = [0u64, 1, 2];
    ///     let right_host = [3u64, 4, 5];
    ///     left.copy_from(&left_host).unwrap();
    ///     right.copy_from(&right_host).unwrap();
    /// }
    ///
    /// let mut host_full = [0u64; 6];
    /// buf.copy_to(&mut host_full).unwrap();
    /// assert_eq!([0u64, 1, 2, 3, 4, 5], host_full);
    /// ```
    pub fn split_at_mut(&mut self, mid: usize) -> (&mut DeviceSlice<T>, &mut DeviceSlice<T>) {
        let (left, right) = self.0.split_at_mut(mid);
        unsafe {
            (
                DeviceSlice::from_slice_mut(left),
                DeviceSlice::from_slice_mut(right),
            )
        }
    }

    /// Returns an iterator over `chunk_size` elements of the slice at a time. The chunks are
    /// device slices and do not overlap. If `chunk_size` does not divide the length of the slice,
    /// then the last chunk will not have length `chunk_size`.
    ///
    /// See `exact_chunks` for a variant of this iterator that always returns chunks of exactly
    /// `chunk_size` elements.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is 0.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let slice = DeviceBuffer::from_slice(&[1u64, 2, 3, 4, 5]).unwrap();
    /// let mut iter = slice.chunks(2);
    ///
    /// assert_eq!(iter.next().unwrap().len(), 2);
    ///
    /// let mut host_buf = [0u64, 0];
    /// iter.next().unwrap().copy_to(&mut host_buf).unwrap();
    /// assert_eq!([3, 4], host_buf);
    ///
    /// assert_eq!(iter.next().unwrap().len(), 1);
    /// ```
    pub fn chunks(&self, chunk_size: usize) -> DeviceChunks<T> {
        DeviceChunks(self.0.chunks(chunk_size))
    }

    /// Returns an iterator over `chunk_size` elements of the slice at a time. The chunks are
    /// mutable device slices and do not overlap. If `chunk_size` does not divide the length of the
    /// slice, then the last chunk will not have length `chunk_size`.
    ///
    /// See `exact_chunks` for a variant of this iterator that always returns chunks of exactly
    /// `chunk_size` elements.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is 0.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let mut slice = DeviceBuffer::from_slice(&[0u64, 0, 0, 0, 0]).unwrap();
    /// {
    ///     let mut iter = slice.chunks_mut(2);
    ///
    ///     assert_eq!(iter.next().unwrap().len(), 2);
    ///
    ///     let host_buf = [2u64, 3];
    ///     iter.next().unwrap().copy_from(&host_buf).unwrap();
    ///
    ///     assert_eq!(iter.next().unwrap().len(), 1);
    /// }
    ///
    /// let mut host_buf = [0u64, 0, 0, 0, 0];
    /// slice.copy_to(&mut host_buf).unwrap();
    /// assert_eq!([0u64, 0, 2, 3, 0], host_buf);
    /// ```
    pub fn chunks_mut(&mut self, chunk_size: usize) -> DeviceChunksMut<T> {
        DeviceChunksMut(self.0.chunks_mut(chunk_size))
    }

    /// Private function used to transmute a CPU slice (which must have the device pointer as its
    /// buffer pointer) to a DeviceSlice. Completely unsafe.
    pub(super) unsafe fn from_slice(slice: &[T]) -> &DeviceSlice<T> {
        &*(slice as *const [T] as *const DeviceSlice<T>)
    }

    /// Private function used to transmute a mutable CPU slice (which must have the device pointer
    /// as its buffer pointer) to a mutable DeviceSlice. Completely unsafe.
    pub(super) unsafe fn from_slice_mut(slice: &mut [T]) -> &mut DeviceSlice<T> {
        &mut *(slice as *mut [T] as *mut DeviceSlice<T>)
    }

    /// Returns a `DevicePointer<T>` to the buffer.
    ///
    /// The caller must ensure that the buffer outlives the returned pointer, or it will end up
    /// pointing to garbage.
    ///
    /// Modifying `DeviceBuffer` is guaranteed not to cause its buffer to be reallocated, so pointers
    /// cannot be invalidated in that manner, but other types may be added in the future which can
    /// reallocate.
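    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming a live CUDA context; the returned pointer would typically be
    /// passed to a kernel launch and must never be dereferenced on the CPU:
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let mut buf = DeviceBuffer::from_slice(&[1u64, 2, 3]).unwrap();
    /// let ptr: DevicePointer<u64> = buf.as_device_ptr();
    /// // Print the raw device address; `as_raw` yields the underlying *const u64.
    /// println!("{:p}", ptr.as_raw());
    /// ```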
    pub fn as_device_ptr(&mut self) -> DevicePointer<T> {
        unsafe { DevicePointer::wrap(self.0.as_mut_ptr()) }
    }

    /// Forms a slice from a `DevicePointer` and a length.
    ///
    /// The `len` argument is the number of _elements_, not the number of bytes.
    ///
    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that the given pointer is valid for `len`
    /// elements, nor whether the lifetime inferred is a suitable lifetime for the returned slice.
    ///
    /// # Caveat
    ///
    /// The lifetime for the returned slice is inferred from its usage. To prevent accidental
    /// misuse, it's suggested to tie the lifetime to whatever source lifetime is safe in the
    /// context, such as by providing a helper function taking the lifetime of a host value for the
    /// slice, or by explicit annotation.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let mut x = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4, 5]).unwrap();
    /// // Manually slice the buffer (this is not recommended!)
    /// let ptr = unsafe { x.as_device_ptr().offset(1) };
    /// let slice = unsafe { DeviceSlice::from_raw_parts(ptr, 2) };
    /// let mut host_buf = [0u64, 0];
    /// slice.copy_to(&mut host_buf).unwrap();
    /// assert_eq!([1u64, 2], host_buf);
    /// ```
    #[allow(clippy::needless_pass_by_value)]
    pub unsafe fn from_raw_parts<'a>(data: DevicePointer<T>, len: usize) -> &'a DeviceSlice<T> {
        DeviceSlice::from_slice(slice::from_raw_parts(data.as_raw(), len))
    }

    /// Performs the same functionality as `from_raw_parts`, except that a
    /// mutable slice is returned.
    ///
    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that the given pointer is valid for `len`
    /// elements, nor whether the lifetime inferred is a suitable lifetime for the returned slice.
    /// In addition, the caller is given no non-aliasing guarantee for the returned mutable slice.
    /// As with `from_raw_parts`, `data` must be non-null and aligned even for zero-length slices.
    ///
    /// See the documentation of `from_raw_parts` for more details.
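    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming a live CUDA context. Note that `slice` mutably aliases `x`'s
    /// allocation, so `x` must not be used while `slice` is live:
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let mut x = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4, 5]).unwrap();
    /// // Manually slice the buffer (this is not recommended!)
    /// let ptr = unsafe { x.as_device_ptr().offset(2) };
    /// let slice = unsafe { DeviceSlice::from_raw_parts_mut(ptr, 2) };
    /// slice.copy_from(&[7u64, 8]).unwrap();
    /// let mut host_buf = [0u64, 0];
    /// slice.copy_to(&mut host_buf).unwrap();
    /// assert_eq!([7u64, 8], host_buf);
    /// ```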
    pub unsafe fn from_raw_parts_mut<'a>(
        mut data: DevicePointer<T>,
        len: usize,
    ) -> &'a mut DeviceSlice<T> {
        DeviceSlice::from_slice_mut(slice::from_raw_parts_mut(data.as_raw_mut(), len))
    }
}

/// An iterator over a [`DeviceSlice`](struct.DeviceSlice.html) in (non-overlapping) chunks
/// (`chunk_size` elements at a time).
///
/// When the slice len is not evenly divided by the chunk size, the last slice of the iteration
/// will be the remainder.
///
/// This struct is created by the `chunks` method on [`DeviceSlice`](struct.DeviceSlice.html).
#[derive(Debug, Clone)]
pub struct DeviceChunks<'a, T: 'a>(Chunks<'a, T>);
impl<'a, T> Iterator for DeviceChunks<'a, T> {
    type Item = &'a DeviceSlice<T>;

    fn next(&mut self) -> Option<&'a DeviceSlice<T>> {
        self.0
            .next()
            .map(|slice| unsafe { DeviceSlice::from_slice(slice) })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.0.size_hint()
    }

    fn count(self) -> usize {
        self.0.len()
    }

    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        self.0
            .nth(n)
            .map(|slice| unsafe { DeviceSlice::from_slice(slice) })
    }

    #[inline]
    fn last(self) -> Option<Self::Item> {
        self.0
            .last()
            .map(|slice| unsafe { DeviceSlice::from_slice(slice) })
    }
}
impl<'a, T> DoubleEndedIterator for DeviceChunks<'a, T> {
    #[inline]
    fn next_back(&mut self) -> Option<&'a DeviceSlice<T>> {
        self.0
            .next_back()
            .map(|slice| unsafe { DeviceSlice::from_slice(slice) })
    }
}
impl<'a, T> ExactSizeIterator for DeviceChunks<'a, T> {}
impl<'a, T> FusedIterator for DeviceChunks<'a, T> {}

/// An iterator over a [`DeviceSlice`](struct.DeviceSlice.html) in (non-overlapping) mutable chunks
/// (`chunk_size` elements at a time).
///
/// When the slice len is not evenly divided by the chunk size, the last slice of the iteration
/// will be the remainder.
///
/// This struct is created by the `chunks_mut` method on [`DeviceSlice`](struct.DeviceSlice.html).
#[derive(Debug)]
pub struct DeviceChunksMut<'a, T: 'a>(ChunksMut<'a, T>);
impl<'a, T> Iterator for DeviceChunksMut<'a, T> {
    type Item = &'a mut DeviceSlice<T>;

    fn next(&mut self) -> Option<&'a mut DeviceSlice<T>> {
        self.0
            .next()
            .map(|slice| unsafe { DeviceSlice::from_slice_mut(slice) })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.0.size_hint()
    }

    fn count(self) -> usize {
        self.0.len()
    }

    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        self.0
            .nth(n)
            .map(|slice| unsafe { DeviceSlice::from_slice_mut(slice) })
    }

    #[inline]
    fn last(self) -> Option<Self::Item> {
        self.0
            .last()
            .map(|slice| unsafe { DeviceSlice::from_slice_mut(slice) })
    }
}
impl<'a, T> DoubleEndedIterator for DeviceChunksMut<'a, T> {
    #[inline]
    fn next_back(&mut self) -> Option<&'a mut DeviceSlice<T>> {
        self.0
            .next_back()
            .map(|slice| unsafe { DeviceSlice::from_slice_mut(slice) })
    }
}
impl<'a, T> ExactSizeIterator for DeviceChunksMut<'a, T> {}
impl<'a, T> FusedIterator for DeviceChunksMut<'a, T> {}

macro_rules! impl_index {
    ($($t:ty)*) => {
        $(
            impl<T> Index<$t> for DeviceSlice<T>
            {
                type Output = DeviceSlice<T>;

                fn index(&self, index: $t) -> &Self {
                    unsafe { DeviceSlice::from_slice(self.0.index(index)) }
                }
            }

            impl<T> IndexMut<$t> for DeviceSlice<T>
            {
                fn index_mut(&mut self, index: $t) -> &mut Self {
                    unsafe { DeviceSlice::from_slice_mut(self.0.index_mut(index)) }
                }
            }
        )*
    }
}
impl_index! {
    Range<usize>
    RangeFull
    RangeFrom<usize>
    RangeInclusive<usize>
    RangeTo<usize>
    RangeToInclusive<usize>
}
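
// A usage sketch for the range-indexing impls above (assumes a live CUDA context): indexing
// with a range yields a sub-slice view without copying any device memory.
//
//     let buf = DeviceBuffer::from_slice(&[0u64, 1, 2, 3]).unwrap();
//     let front = &buf[..2];
//     assert_eq!(front.len(), 2);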
impl<T> crate::private::Sealed for DeviceSlice<T> {}
impl<T: DeviceCopy, I: AsRef<[T]> + AsMut<[T]> + ?Sized> CopyDestination<I> for DeviceSlice<T> {
    fn copy_from(&mut self, val: &I) -> CudaResult<()> {
        let val = val.as_ref();
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = mem::size_of::<T>() * self.len();
        if size != 0 {
            unsafe {
                cuda_driver_sys::cuMemcpyHtoD_v2(
                    self.0.as_mut_ptr() as u64,
                    val.as_ptr() as *const c_void,
                    size,
                )
                .to_result()?
            }
        }
        Ok(())
    }

    fn copy_to(&self, val: &mut I) -> CudaResult<()> {
        let val = val.as_mut();
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = mem::size_of::<T>() * self.len();
        if size != 0 {
            unsafe {
                cuda_driver_sys::cuMemcpyDtoH_v2(
                    val.as_mut_ptr() as *mut c_void,
                    self.as_ptr() as u64,
                    size,
                )
                .to_result()?
            }
        }
        Ok(())
    }
}
impl<T: DeviceCopy> CopyDestination<DeviceSlice<T>> for DeviceSlice<T> {
    fn copy_from(&mut self, val: &DeviceSlice<T>) -> CudaResult<()> {
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = mem::size_of::<T>() * self.len();
        if size != 0 {
            unsafe {
                cuda_driver_sys::cuMemcpyDtoD_v2(
                    self.0.as_mut_ptr() as u64,
                    val.as_ptr() as u64,
                    size,
                )
                .to_result()?
            }
        }
        Ok(())
    }

    fn copy_to(&self, val: &mut DeviceSlice<T>) -> CudaResult<()> {
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = mem::size_of::<T>() * self.len();
        if size != 0 {
            unsafe {
                cuda_driver_sys::cuMemcpyDtoD_v2(
                    val.as_mut_ptr() as u64,
                    self.as_ptr() as u64,
                    size,
                )
                .to_result()?
            }
        }
        Ok(())
    }
}
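
// A device-to-device copy sketch for the impl above (assumes a live CUDA context):
//
//     let src = DeviceBuffer::from_slice(&[1u64, 2, 3]).unwrap();
//     let mut dst = unsafe { DeviceBuffer::uninitialized(3).unwrap() };
//     // DeviceBuffer derefs to DeviceSlice, so this resolves to the impl above.
//     dst.copy_from(&src as &DeviceSlice<u64>).unwrap();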
impl<T: DeviceCopy> CopyDestination<DeviceBuffer<T>> for DeviceSlice<T> {
    fn copy_from(&mut self, val: &DeviceBuffer<T>) -> CudaResult<()> {
        self.copy_from(val as &DeviceSlice<T>)
    }

    fn copy_to(&self, val: &mut DeviceBuffer<T>) -> CudaResult<()> {
        self.copy_to(val as &mut DeviceSlice<T>)
    }
}
impl<T: DeviceCopy, I: AsRef<[T]> + AsMut<[T]> + ?Sized> AsyncCopyDestination<I>
    for DeviceSlice<T>
{
    unsafe fn async_copy_from(&mut self, val: &I, stream: &Stream) -> CudaResult<()> {
        let val = val.as_ref();
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = mem::size_of::<T>() * self.len();
        if size != 0 {
            cuda_driver_sys::cuMemcpyHtoDAsync_v2(
                self.0.as_mut_ptr() as u64,
                val.as_ptr() as *const c_void,
                size,
                stream.as_inner(),
            )
            .to_result()?
        }
        Ok(())
    }

    unsafe fn async_copy_to(&self, val: &mut I, stream: &Stream) -> CudaResult<()> {
        let val = val.as_mut();
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = mem::size_of::<T>() * self.len();
        if size != 0 {
            cuda_driver_sys::cuMemcpyDtoHAsync_v2(
                val.as_mut_ptr() as *mut c_void,
                self.as_ptr() as u64,
                size,
                stream.as_inner(),
            )
            .to_result()?
        }
        Ok(())
    }
}
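
// An async-copy sketch for the impl above (assumes a live CUDA context; the host buffer must
// stay alive and untouched until the stream is synchronized, which is why these methods are
// unsafe):
//
//     use rustacuda::stream::{Stream, StreamFlags};
//     let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
//     let mut buf = DeviceBuffer::from_slice(&[0u64; 4]).unwrap();
//     let host = [1u64, 2, 3, 4];
//     unsafe { buf.async_copy_from(&host, &stream).unwrap() };
//     stream.synchronize().unwrap();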
impl<T: DeviceCopy> AsyncCopyDestination<DeviceSlice<T>> for DeviceSlice<T> {
    unsafe fn async_copy_from(&mut self, val: &DeviceSlice<T>, stream: &Stream) -> CudaResult<()> {
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = mem::size_of::<T>() * self.len();
        if size != 0 {
            cuda_driver_sys::cuMemcpyDtoDAsync_v2(
                self.0.as_mut_ptr() as u64,
                val.as_ptr() as u64,
                size,
                stream.as_inner(),
            )
            .to_result()?
        }
        Ok(())
    }

    unsafe fn async_copy_to(&self, val: &mut DeviceSlice<T>, stream: &Stream) -> CudaResult<()> {
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = mem::size_of::<T>() * self.len();
        if size != 0 {
            cuda_driver_sys::cuMemcpyDtoDAsync_v2(
                val.as_mut_ptr() as u64,
                self.as_ptr() as u64,
                size,
                stream.as_inner(),
            )
            .to_result()?
        }
        Ok(())
    }
}
impl<T: DeviceCopy> AsyncCopyDestination<DeviceBuffer<T>> for DeviceSlice<T> {
    unsafe fn async_copy_from(&mut self, val: &DeviceBuffer<T>, stream: &Stream) -> CudaResult<()> {
        self.async_copy_from(val as &DeviceSlice<T>, stream)
    }

    unsafe fn async_copy_to(&self, val: &mut DeviceBuffer<T>, stream: &Stream) -> CudaResult<()> {
        self.async_copy_to(val as &mut DeviceSlice<T>, stream)
    }
}