fil_rustacuda/memory/device/device_slice.rs
use crate::error::{CudaResult, ToResult};
use crate::memory::device::AsyncCopyDestination;
use crate::memory::device::{CopyDestination, DeviceBuffer};
use crate::memory::DeviceCopy;
use crate::memory::DevicePointer;
use crate::stream::Stream;
use std::iter::{ExactSizeIterator, FusedIterator};
use std::mem;
use std::ops::{
    Index, IndexMut, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive,
};

use std::os::raw::c_void;
use std::slice::{self, Chunks, ChunksMut};

/// Fixed-size device-side slice.
#[derive(Debug)]
#[repr(C)]
pub struct DeviceSlice<T>([T]);
// This works by faking a regular slice out of the device raw-pointer and the length, then
// transmuting. I have no idea if this is safe or not. Probably not, though I can't imagine how
// the compiler could possibly know that the pointer is not dereferenceable. I'm banking that we
// get proper Dynamically-Sized Types before the compiler authors break this assumption.
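//
// A sketch of the trick (hypothetical `device_ptr` and `len`; not meant to compile on its own):
//
//     let fake: &[T] = slice::from_raw_parts(device_ptr as *const T, len);
//     let dev: &DeviceSlice<T> = &*(fake as *const [T] as *const DeviceSlice<T>);
//
// `from_slice` and `from_raw_parts` below are the real implementations of this cast.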
impl<T> DeviceSlice<T> {
    /// Returns the number of elements in the slice.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let a = DeviceBuffer::from_slice(&[1, 2, 3]).unwrap();
    /// assert_eq!(a.len(), 3);
    /// ```
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns `true` if the slice has a length of 0.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let a: DeviceBuffer<u64> = unsafe { DeviceBuffer::uninitialized(0).unwrap() };
    /// assert!(a.is_empty());
    /// ```
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Returns a raw device-pointer to the slice's buffer.
    ///
    /// The caller must ensure that the slice outlives the pointer this function returns, or else
    /// it will end up pointing to garbage. The caller must also ensure that the pointer is not
    /// dereferenced by the CPU.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let a = DeviceBuffer::from_slice(&[1, 2, 3]).unwrap();
    /// println!("{:p}", a.as_ptr());
    /// ```
    pub fn as_ptr(&self) -> *const T {
        self.0.as_ptr()
    }

    /// Returns an unsafe mutable device-pointer to the slice's buffer.
    ///
    /// The caller must ensure that the slice outlives the pointer this function returns, or else
    /// it will end up pointing to garbage. The caller must also ensure that the pointer is not
    /// dereferenced by the CPU.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let mut a = DeviceBuffer::from_slice(&[1, 2, 3]).unwrap();
    /// println!("{:p}", a.as_mut_ptr());
    /// ```
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.0.as_mut_ptr()
    }

    /// Divides one DeviceSlice into two at a given index.
    ///
    /// The first will contain all indices from `[0, mid)` (excluding the index `mid` itself) and
    /// the second will contain all indices from `[mid, len)` (excluding the index `len` itself).
    ///
    /// # Panics
    ///
    /// Panics if `mid > len`.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let buf = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4, 5]).unwrap();
    /// let (left, right) = buf.split_at(3);
    /// let mut left_host = [0u64, 0, 0];
    /// let mut right_host = [0u64, 0, 0];
    /// left.copy_to(&mut left_host).unwrap();
    /// right.copy_to(&mut right_host).unwrap();
    /// assert_eq!([0u64, 1, 2], left_host);
    /// assert_eq!([3u64, 4, 5], right_host);
    /// ```
    pub fn split_at(&self, mid: usize) -> (&DeviceSlice<T>, &DeviceSlice<T>) {
        let (left, right) = self.0.split_at(mid);
        unsafe {
            (
                DeviceSlice::from_slice(left),
                DeviceSlice::from_slice(right),
            )
        }
    }

    /// Divides one mutable DeviceSlice into two at a given index.
    ///
    /// The first will contain all indices from `[0, mid)` (excluding the index `mid` itself) and
    /// the second will contain all indices from `[mid, len)` (excluding the index `len` itself).
    ///
    /// # Panics
    ///
    /// Panics if `mid > len`.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let mut buf = DeviceBuffer::from_slice(&[0u64, 0, 0, 0, 0, 0]).unwrap();
    ///
    /// {
    ///     let (left, right) = buf.split_at_mut(3);
    ///     let left_host = [0u64, 1, 2];
    ///     let right_host = [3u64, 4, 5];
    ///     left.copy_from(&left_host).unwrap();
    ///     right.copy_from(&right_host).unwrap();
    /// }
    ///
    /// let mut host_full = [0u64; 6];
    /// buf.copy_to(&mut host_full).unwrap();
    /// assert_eq!([0u64, 1, 2, 3, 4, 5], host_full);
    /// ```
    pub fn split_at_mut(&mut self, mid: usize) -> (&mut DeviceSlice<T>, &mut DeviceSlice<T>) {
        let (left, right) = self.0.split_at_mut(mid);
        unsafe {
            (
                DeviceSlice::from_slice_mut(left),
                DeviceSlice::from_slice_mut(right),
            )
        }
    }

    /// Returns an iterator over `chunk_size` elements of the slice at a time. The chunks are
    /// device slices and do not overlap. If `chunk_size` does not divide the length of the slice,
    /// then the last chunk will not have length `chunk_size`.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is 0.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let slice = DeviceBuffer::from_slice(&[1u64, 2, 3, 4, 5]).unwrap();
    /// let mut iter = slice.chunks(2);
    ///
    /// assert_eq!(iter.next().unwrap().len(), 2);
    ///
    /// let mut host_buf = [0u64, 0];
    /// iter.next().unwrap().copy_to(&mut host_buf).unwrap();
    /// assert_eq!([3, 4], host_buf);
    ///
    /// assert_eq!(iter.next().unwrap().len(), 1);
    /// ```
    pub fn chunks(&self, chunk_size: usize) -> DeviceChunks<T> {
        DeviceChunks(self.0.chunks(chunk_size))
    }

    /// Returns an iterator over `chunk_size` elements of the slice at a time. The chunks are
    /// mutable device slices and do not overlap. If `chunk_size` does not divide the length of
    /// the slice, then the last chunk will not have length `chunk_size`.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is 0.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let mut slice = DeviceBuffer::from_slice(&[0u64, 0, 0, 0, 0]).unwrap();
    /// {
    ///     let mut iter = slice.chunks_mut(2);
    ///
    ///     assert_eq!(iter.next().unwrap().len(), 2);
    ///
    ///     let host_buf = [2u64, 3];
    ///     iter.next().unwrap().copy_from(&host_buf).unwrap();
    ///
    ///     assert_eq!(iter.next().unwrap().len(), 1);
    /// }
    ///
    /// let mut host_buf = [0u64, 0, 0, 0, 0];
    /// slice.copy_to(&mut host_buf).unwrap();
    /// assert_eq!([0u64, 0, 2, 3, 0], host_buf);
    /// ```
    pub fn chunks_mut(&mut self, chunk_size: usize) -> DeviceChunksMut<T> {
        DeviceChunksMut(self.0.chunks_mut(chunk_size))
    }

    /// Private function used to transmute a CPU slice (which must have the device pointer as its
    /// buffer pointer) to a DeviceSlice. Completely unsafe.
    pub(super) unsafe fn from_slice(slice: &[T]) -> &DeviceSlice<T> {
        &*(slice as *const [T] as *const DeviceSlice<T>)
    }

    /// Private function used to transmute a mutable CPU slice (which must have the device pointer
    /// as its buffer pointer) to a mutable DeviceSlice. Completely unsafe.
    pub(super) unsafe fn from_slice_mut(slice: &mut [T]) -> &mut DeviceSlice<T> {
        &mut *(slice as *mut [T] as *mut DeviceSlice<T>)
    }

    /// Returns a `DevicePointer<T>` to the buffer.
    ///
    /// The caller must ensure that the buffer outlives the returned pointer, or it will end up
    /// pointing to garbage.
    ///
    /// Modifying `DeviceBuffer` is guaranteed not to cause its buffer to be reallocated, so
    /// pointers cannot be invalidated in that manner, but other types may be added in the future
    /// which can reallocate.
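    ///
    /// # Examples
    ///
    /// A minimal sketch; the wrapped pointer is a device address, so it is only printed here,
    /// never dereferenced on the host:
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let mut a = DeviceBuffer::from_slice(&[1u64, 2, 3]).unwrap();
    /// let ptr = a.as_device_ptr();
    /// println!("{:p}", ptr.as_raw());
    /// ```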
    pub fn as_device_ptr(&mut self) -> DevicePointer<T> {
        unsafe { DevicePointer::wrap(self.0.as_mut_ptr()) }
    }

    /// Forms a slice from a `DevicePointer` and a length.
    ///
    /// The `len` argument is the number of _elements_, not the number of bytes.
    ///
    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that the given pointer is valid for `len`
    /// elements, nor whether the lifetime inferred is a suitable lifetime for the returned slice.
    ///
    /// # Caveat
    ///
    /// The lifetime for the returned slice is inferred from its usage. To prevent accidental
    /// misuse, it's suggested to tie the lifetime to whatever source lifetime is safe in the
    /// context, such as by providing a helper function taking the lifetime of a host value for
    /// the slice, or by explicit annotation.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let mut x = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4, 5]).unwrap();
    /// // Manually slice the buffer (this is not recommended!)
    /// let ptr = unsafe { x.as_device_ptr().offset(1) };
    /// let slice = unsafe { DeviceSlice::from_raw_parts(ptr, 2) };
    /// let mut host_buf = [0u64, 0];
    /// slice.copy_to(&mut host_buf).unwrap();
    /// assert_eq!([1u64, 2], host_buf);
    /// ```
    #[allow(clippy::needless_pass_by_value)]
    pub unsafe fn from_raw_parts<'a>(data: DevicePointer<T>, len: usize) -> &'a DeviceSlice<T> {
        DeviceSlice::from_slice(slice::from_raw_parts(data.as_raw(), len))
    }

    /// Performs the same functionality as `from_raw_parts`, except that a
    /// mutable slice is returned.
    ///
    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that the given pointer is valid for `len`
    /// elements, nor whether the lifetime inferred is a suitable lifetime for the returned slice.
    /// In addition, there is no guarantee that the returned mutable slice does not alias any
    /// other; `data` must also be non-null and aligned even for zero-length slices, as with
    /// `from_raw_parts`.
    ///
    /// See the documentation of `from_raw_parts` for more details.
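    ///
    /// # Examples
    ///
    /// A minimal sketch mirroring the `from_raw_parts` example above:
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let mut x = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4, 5]).unwrap();
    /// // Manually slice the buffer (this is not recommended!)
    /// let ptr = unsafe { x.as_device_ptr().offset(1) };
    /// let slice = unsafe { DeviceSlice::from_raw_parts_mut(ptr, 2) };
    /// let host_buf = [10u64, 20];
    /// slice.copy_from(&host_buf).unwrap();
    /// ```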
    pub unsafe fn from_raw_parts_mut<'a>(
        mut data: DevicePointer<T>,
        len: usize,
    ) -> &'a mut DeviceSlice<T> {
        DeviceSlice::from_slice_mut(slice::from_raw_parts_mut(data.as_raw_mut(), len))
    }
}

/// An iterator over a [`DeviceSlice`](struct.DeviceSlice.html) in (non-overlapping) chunks
/// (`chunk_size` elements at a time).
///
/// When the slice length is not evenly divided by the chunk size, the last slice of the iteration
/// will be the remainder.
///
/// This struct is created by the `chunks` method on [`DeviceSlice`](struct.DeviceSlice.html).
#[derive(Debug, Clone)]
pub struct DeviceChunks<'a, T: 'a>(Chunks<'a, T>);
impl<'a, T> Iterator for DeviceChunks<'a, T> {
    type Item = &'a DeviceSlice<T>;

    fn next(&mut self) -> Option<&'a DeviceSlice<T>> {
        self.0
            .next()
            .map(|slice| unsafe { DeviceSlice::from_slice(slice) })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.0.size_hint()
    }

    fn count(self) -> usize {
        self.0.len()
    }

    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        self.0
            .nth(n)
            .map(|slice| unsafe { DeviceSlice::from_slice(slice) })
    }

    #[inline]
    fn last(self) -> Option<Self::Item> {
        self.0
            .last()
            .map(|slice| unsafe { DeviceSlice::from_slice(slice) })
    }
}
impl<'a, T> DoubleEndedIterator for DeviceChunks<'a, T> {
    #[inline]
    fn next_back(&mut self) -> Option<&'a DeviceSlice<T>> {
        self.0
            .next_back()
            .map(|slice| unsafe { DeviceSlice::from_slice(slice) })
    }
}
impl<'a, T> ExactSizeIterator for DeviceChunks<'a, T> {}
impl<'a, T> FusedIterator for DeviceChunks<'a, T> {}

/// An iterator over a [`DeviceSlice`](struct.DeviceSlice.html) in (non-overlapping) mutable chunks
/// (`chunk_size` elements at a time).
///
/// When the slice length is not evenly divided by the chunk size, the last slice of the iteration
/// will be the remainder.
///
/// This struct is created by the `chunks_mut` method on [`DeviceSlice`](struct.DeviceSlice.html).
#[derive(Debug)]
pub struct DeviceChunksMut<'a, T: 'a>(ChunksMut<'a, T>);
impl<'a, T> Iterator for DeviceChunksMut<'a, T> {
    type Item = &'a mut DeviceSlice<T>;

    fn next(&mut self) -> Option<&'a mut DeviceSlice<T>> {
        self.0
            .next()
            .map(|slice| unsafe { DeviceSlice::from_slice_mut(slice) })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.0.size_hint()
    }

    fn count(self) -> usize {
        self.0.len()
    }

    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        self.0
            .nth(n)
            .map(|slice| unsafe { DeviceSlice::from_slice_mut(slice) })
    }

    #[inline]
    fn last(self) -> Option<Self::Item> {
        self.0
            .last()
            .map(|slice| unsafe { DeviceSlice::from_slice_mut(slice) })
    }
}
impl<'a, T> DoubleEndedIterator for DeviceChunksMut<'a, T> {
    #[inline]
    fn next_back(&mut self) -> Option<&'a mut DeviceSlice<T>> {
        self.0
            .next_back()
            .map(|slice| unsafe { DeviceSlice::from_slice_mut(slice) })
    }
}
impl<'a, T> ExactSizeIterator for DeviceChunksMut<'a, T> {}
impl<'a, T> FusedIterator for DeviceChunksMut<'a, T> {}

macro_rules! impl_index {
    ($($t:ty)*) => {
        $(
            impl<T> Index<$t> for DeviceSlice<T>
            {
                type Output = DeviceSlice<T>;

                fn index(&self, index: $t) -> &Self {
                    unsafe { DeviceSlice::from_slice(self.0.index(index)) }
                }
            }

            impl<T> IndexMut<$t> for DeviceSlice<T>
            {
                fn index_mut(&mut self, index: $t) -> &mut Self {
                    unsafe { DeviceSlice::from_slice_mut(self.0.index_mut(index)) }
                }
            }
        )*
    }
}
impl_index! {
    Range<usize>
    RangeFull
    RangeFrom<usize>
    RangeInclusive<usize>
    RangeTo<usize>
    RangeToInclusive<usize>
}
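// Usage sketch for the range indexing above (comment only; relies on `DeviceBuffer<T>`
// dereferencing to `DeviceSlice<T>`, as the forwarding impls below do):
//
//     let buf = DeviceBuffer::from_slice(&[0u64, 1, 2, 3]).unwrap();
//     let sub = &buf[1..3]; // a &DeviceSlice<u64> covering elements 1 and 2
//     assert_eq!(sub.len(), 2);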
impl<T> crate::private::Sealed for DeviceSlice<T> {}
impl<T: DeviceCopy, I: AsRef<[T]> + AsMut<[T]> + ?Sized> CopyDestination<I> for DeviceSlice<T> {
    fn copy_from(&mut self, val: &I) -> CudaResult<()> {
        let val = val.as_ref();
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = mem::size_of::<T>() * self.len();
        if size != 0 {
            unsafe {
                cuda_driver_sys::cuMemcpyHtoD_v2(
                    self.0.as_mut_ptr() as u64,
                    val.as_ptr() as *const c_void,
                    size,
                )
                .to_result()?
            }
        }
        Ok(())
    }

    fn copy_to(&self, val: &mut I) -> CudaResult<()> {
        let val = val.as_mut();
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = mem::size_of::<T>() * self.len();
        if size != 0 {
            unsafe {
                cuda_driver_sys::cuMemcpyDtoH_v2(
                    val.as_mut_ptr() as *mut c_void,
                    self.as_ptr() as u64,
                    size,
                )
                .to_result()?
            }
        }
        Ok(())
    }
}
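// The device-to-device impls below stay on the GPU and avoid a round trip through host memory.
// A minimal usage sketch (comment only; relies on `DeviceBuffer<T>` dereferencing to
// `DeviceSlice<T>`):
//
//     let src = DeviceBuffer::from_slice(&[1u64, 2, 3]).unwrap();
//     let mut dst = unsafe { DeviceBuffer::uninitialized(3).unwrap() };
//     dst.copy_from(&src).unwrap();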
impl<T: DeviceCopy> CopyDestination<DeviceSlice<T>> for DeviceSlice<T> {
    fn copy_from(&mut self, val: &DeviceSlice<T>) -> CudaResult<()> {
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = mem::size_of::<T>() * self.len();
        if size != 0 {
            unsafe {
                cuda_driver_sys::cuMemcpyDtoD_v2(
                    self.0.as_mut_ptr() as u64,
                    val.as_ptr() as u64,
                    size,
                )
                .to_result()?
            }
        }
        Ok(())
    }

    fn copy_to(&self, val: &mut DeviceSlice<T>) -> CudaResult<()> {
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = mem::size_of::<T>() * self.len();
        if size != 0 {
            unsafe {
                cuda_driver_sys::cuMemcpyDtoD_v2(
                    val.as_mut_ptr() as u64,
                    self.as_ptr() as u64,
                    size,
                )
                .to_result()?
            }
        }
        Ok(())
    }
}
impl<T: DeviceCopy> CopyDestination<DeviceBuffer<T>> for DeviceSlice<T> {
    fn copy_from(&mut self, val: &DeviceBuffer<T>) -> CudaResult<()> {
        self.copy_from(val as &DeviceSlice<T>)
    }

    fn copy_to(&self, val: &mut DeviceBuffer<T>) -> CudaResult<()> {
        self.copy_to(val as &mut DeviceSlice<T>)
    }
}
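// The async variants below are unsafe because the copy is merely enqueued on the stream: the
// host buffer must stay alive (and unmodified) until the stream is synchronized. A minimal
// sketch with hypothetical `dev_slice` and `host_buf`:
//
//     use rustacuda::stream::{Stream, StreamFlags};
//     let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;
//     unsafe { dev_slice.async_copy_from(&host_buf, &stream)? };
//     stream.synchronize()?;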
impl<T: DeviceCopy, I: AsRef<[T]> + AsMut<[T]> + ?Sized> AsyncCopyDestination<I>
    for DeviceSlice<T>
{
    unsafe fn async_copy_from(&mut self, val: &I, stream: &Stream) -> CudaResult<()> {
        let val = val.as_ref();
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = mem::size_of::<T>() * self.len();
        if size != 0 {
            cuda_driver_sys::cuMemcpyHtoDAsync_v2(
                self.0.as_mut_ptr() as u64,
                val.as_ptr() as *const c_void,
                size,
                stream.as_inner(),
            )
            .to_result()?
        }
        Ok(())
    }

    unsafe fn async_copy_to(&self, val: &mut I, stream: &Stream) -> CudaResult<()> {
        let val = val.as_mut();
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = mem::size_of::<T>() * self.len();
        if size != 0 {
            cuda_driver_sys::cuMemcpyDtoHAsync_v2(
                val.as_mut_ptr() as *mut c_void,
                self.as_ptr() as u64,
                size,
                stream.as_inner(),
            )
            .to_result()?
        }
        Ok(())
    }
}
impl<T: DeviceCopy> AsyncCopyDestination<DeviceSlice<T>> for DeviceSlice<T> {
    unsafe fn async_copy_from(&mut self, val: &DeviceSlice<T>, stream: &Stream) -> CudaResult<()> {
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = mem::size_of::<T>() * self.len();
        if size != 0 {
            cuda_driver_sys::cuMemcpyDtoDAsync_v2(
                self.0.as_mut_ptr() as u64,
                val.as_ptr() as u64,
                size,
                stream.as_inner(),
            )
            .to_result()?
        }
        Ok(())
    }

    unsafe fn async_copy_to(&self, val: &mut DeviceSlice<T>, stream: &Stream) -> CudaResult<()> {
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = mem::size_of::<T>() * self.len();
        if size != 0 {
            cuda_driver_sys::cuMemcpyDtoDAsync_v2(
                val.as_mut_ptr() as u64,
                self.as_ptr() as u64,
                size,
                stream.as_inner(),
            )
            .to_result()?
        }
        Ok(())
    }
}
impl<T: DeviceCopy> AsyncCopyDestination<DeviceBuffer<T>> for DeviceSlice<T> {
    unsafe fn async_copy_from(&mut self, val: &DeviceBuffer<T>, stream: &Stream) -> CudaResult<()> {
        self.async_copy_from(val as &DeviceSlice<T>, stream)
    }

    unsafe fn async_copy_to(&self, val: &mut DeviceBuffer<T>, stream: &Stream) -> CudaResult<()> {
        self.async_copy_to(val as &mut DeviceSlice<T>, stream)
    }
}