// rustacuda/memory/device/device_buffer.rs

1use crate::error::{CudaResult, DropResult, ToResult};
2use crate::memory::device::{AsyncCopyDestination, CopyDestination, DeviceSlice};
3use crate::memory::malloc::{cuda_free, cuda_malloc};
4use crate::memory::DeviceCopy;
5use crate::memory::DevicePointer;
6use crate::stream::Stream;
7use std::mem;
8use std::ops::{Deref, DerefMut};
9
10use std::ptr;
11
/// Fixed-size device-side buffer. Provides basic access to device memory.
#[derive(Debug)]
pub struct DeviceBuffer<T> {
    // Device pointer to the start of the allocation. For zero-length buffers or
    // zero-sized `T`, this holds a well-aligned dangling pointer instead of a
    // real allocation (see `uninitialized`).
    buf: DevicePointer<T>,
    // Number of `T` elements (not bytes) this buffer holds.
    capacity: usize,
}
18impl<T> DeviceBuffer<T> {
19    /// Allocate a new device buffer large enough to hold `size` `T`'s, but without
20    /// initializing the contents.
21    ///
22    /// # Errors
23    ///
24    /// If the allocation fails, returns the error from CUDA. If `size` is large enough that
25    /// `size * mem::sizeof::<T>()` overflows usize, then returns InvalidMemoryAllocation.
26    ///
27    /// # Safety
28    ///
29    /// The caller must ensure that the contents of the buffer are initialized before reading from
30    /// the buffer.
31    ///
32    /// # Examples
33    ///
34    /// ```
35    /// # let _context = rustacuda::quick_init().unwrap();
36    /// use rustacuda::memory::*;
37    /// let mut buffer = unsafe { DeviceBuffer::uninitialized(5).unwrap() };
38    /// buffer.copy_from(&[0u64, 1, 2, 3, 4]).unwrap();
39    /// ```
40    pub unsafe fn uninitialized(size: usize) -> CudaResult<Self> {
41        let ptr = if size > 0 && mem::size_of::<T>() > 0 {
42            cuda_malloc(size)?
43        } else {
44            DevicePointer::wrap(ptr::NonNull::dangling().as_ptr() as *mut T)
45        };
46        Ok(DeviceBuffer {
47            buf: ptr,
48            capacity: size,
49        })
50    }
51
52    /// Allocate a new device buffer large enough to hold `size` `T`'s and fill the contents with
53    /// zeroes (`0u8`).
54    ///
55    /// # Errors
56    ///
57    /// If the allocation fails, returns the error from CUDA. If `size` is large enough that
58    /// `size * mem::sizeof::<T>()` overflows usize, then returns InvalidMemoryAllocation.
59    ///
60    /// # Safety
61    ///
62    /// The backing memory is zeroed, which may not be a valid bit-pattern for type `T`. The caller
63    /// must ensure either that all-zeroes is a valid bit-pattern for type `T` or that the backing
64    /// memory is set to a valid value before it is read.
65    ///
66    /// # Examples
67    ///
68    /// ```
69    /// # let _context = rustacuda::quick_init().unwrap();
70    /// use rustacuda::memory::*;
71    /// let buffer = unsafe { DeviceBuffer::zeroed(5).unwrap() };
72    /// let mut host_values = [1u64, 2, 3, 4, 5];
73    /// buffer.copy_to(&mut host_values).unwrap();
74    /// assert_eq!([0u64, 0, 0, 0, 0], host_values);
75    /// ```
76    pub unsafe fn zeroed(size: usize) -> CudaResult<Self> {
77        let ptr = if size > 0 && mem::size_of::<T>() > 0 {
78            let mut ptr = cuda_malloc(size)?;
79            cuda_driver_sys::cuMemsetD8_v2(ptr.as_raw_mut() as u64, 0, size * mem::size_of::<T>())
80                .to_result()?;
81            ptr
82        } else {
83            DevicePointer::wrap(ptr::NonNull::dangling().as_ptr() as *mut T)
84        };
85        Ok(DeviceBuffer {
86            buf: ptr,
87            capacity: size,
88        })
89    }
90
91    /// Creates a `DeviceBuffer<T>` directly from the raw components of another device buffer.
92    ///
93    /// # Safety
94    ///
95    /// This is highly unsafe, due to the number of invariants that aren't
96    /// checked:
97    ///
98    /// * `ptr` needs to have been previously allocated via `DeviceBuffer` or
99    /// [`cuda_malloc`](fn.cuda_malloc.html).
100    /// * `ptr`'s `T` needs to have the same size and alignment as it was allocated with.
101    /// * `capacity` needs to be the capacity that the pointer was allocated with.
102    ///
103    /// Violating these may cause problems like corrupting the CUDA driver's
104    /// internal data structures.
105    ///
106    /// The ownership of `ptr` is effectively transferred to the
107    /// `DeviceBuffer<T>` which may then deallocate, reallocate or change the
108    /// contents of memory pointed to by the pointer at will. Ensure
109    /// that nothing else uses the pointer after calling this
110    /// function.
111    ///
112    /// # Examples
113    ///
114    /// ```
115    /// # let _context = rustacuda::quick_init().unwrap();
116    /// use std::mem;
117    /// use rustacuda::memory::*;
118    ///
119    /// let mut buffer = DeviceBuffer::from_slice(&[0u64; 5]).unwrap();
120    /// let ptr = buffer.as_device_ptr();
121    /// let size = buffer.len();
122    ///
123    /// mem::forget(buffer);
124    ///
125    /// let buffer = unsafe { DeviceBuffer::from_raw_parts(ptr, size) };
126    /// ```
127    pub unsafe fn from_raw_parts(ptr: DevicePointer<T>, capacity: usize) -> DeviceBuffer<T> {
128        DeviceBuffer { buf: ptr, capacity }
129    }
130
131    /// Destroy a `DeviceBuffer`, returning an error.
132    ///
133    /// Deallocating device memory can return errors from previous asynchronous work. This function
134    /// destroys the given buffer and returns the error and the un-destroyed buffer on failure.
135    ///
136    /// # Example
137    ///
138    /// ```
139    /// # let _context = rustacuda::quick_init().unwrap();
140    /// use rustacuda::memory::*;
141    /// let x = DeviceBuffer::from_slice(&[10, 20, 30]).unwrap();
142    /// match DeviceBuffer::drop(x) {
143    ///     Ok(()) => println!("Successfully destroyed"),
144    ///     Err((e, buf)) => {
145    ///         println!("Failed to destroy buffer: {:?}", e);
146    ///         // Do something with buf
147    ///     },
148    /// }
149    /// ```
150    pub fn drop(mut dev_buf: DeviceBuffer<T>) -> DropResult<DeviceBuffer<T>> {
151        if dev_buf.buf.is_null() {
152            return Ok(());
153        }
154
155        if dev_buf.capacity > 0 && mem::size_of::<T>() > 0 {
156            let capacity = dev_buf.capacity;
157            let ptr = mem::replace(&mut dev_buf.buf, DevicePointer::null());
158            unsafe {
159                match cuda_free(ptr) {
160                    Ok(()) => {
161                        mem::forget(dev_buf);
162                        Ok(())
163                    }
164                    Err(e) => Err((e, DeviceBuffer::from_raw_parts(ptr, capacity))),
165                }
166            }
167        } else {
168            Ok(())
169        }
170    }
171}
172impl<T: DeviceCopy> DeviceBuffer<T> {
173    /// Allocate a new device buffer of the same size as `slice`, initialized with a clone of
174    /// the data in `slice`.
175    ///
176    /// # Errors
177    ///
178    /// If the allocation fails, returns the error from CUDA.
179    ///
180    /// # Examples
181    ///
182    /// ```
183    /// # let _context = rustacuda::quick_init().unwrap();
184    /// use rustacuda::memory::*;
185    /// let values = [0u64; 5];
186    /// let mut buffer = DeviceBuffer::from_slice(&values).unwrap();
187    /// ```
188    pub fn from_slice(slice: &[T]) -> CudaResult<Self> {
189        unsafe {
190            let mut uninit = DeviceBuffer::uninitialized(slice.len())?;
191            uninit.copy_from(slice)?;
192            Ok(uninit)
193        }
194    }
195
196    /// Asynchronously allocate a new buffer of the same size as `slice`, initialized
197    /// with a clone of the data in `slice`.
198    ///
199    /// # Safety
200    ///
201    /// For why this function is unsafe, see [AsyncCopyDestination](trait.AsyncCopyDestination.html)
202    ///
203    /// # Errors
204    ///
205    /// If the allocation fails, returns the error from CUDA.
206    ///
207    /// # Examples
208    ///
209    /// ```
210    /// # let _context = rustacuda::quick_init().unwrap();
211    /// use rustacuda::memory::*;
212    /// use rustacuda::stream::{Stream, StreamFlags};
213    ///
214    /// let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
215    /// let values = [0u64; 5];
216    /// unsafe {
217    ///     let mut buffer = DeviceBuffer::from_slice_async(&values, &stream).unwrap();
218    ///     stream.synchronize();
219    ///     // Perform some operation on the buffer
220    /// }
221    /// ```
222    pub unsafe fn from_slice_async(slice: &[T], stream: &Stream) -> CudaResult<Self> {
223        let mut uninit = DeviceBuffer::uninitialized(slice.len())?;
224        uninit.async_copy_from(slice, stream)?;
225        Ok(uninit)
226    }
227}
228impl<T> Deref for DeviceBuffer<T> {
229    type Target = DeviceSlice<T>;
230
231    fn deref(&self) -> &DeviceSlice<T> {
232        unsafe {
233            DeviceSlice::from_slice(::std::slice::from_raw_parts(
234                self.buf.as_raw(),
235                self.capacity,
236            ))
237        }
238    }
239}
240impl<T> DerefMut for DeviceBuffer<T> {
241    fn deref_mut(&mut self) -> &mut DeviceSlice<T> {
242        unsafe {
243            &mut *(::std::slice::from_raw_parts_mut(self.buf.as_raw_mut(), self.capacity)
244                as *mut [T] as *mut DeviceSlice<T>)
245        }
246    }
247}
248impl<T> Drop for DeviceBuffer<T> {
249    fn drop(&mut self) {
250        if self.buf.is_null() {
251            return;
252        }
253
254        if self.capacity > 0 && mem::size_of::<T>() > 0 {
255            // No choice but to panic if this fails.
256            let ptr = mem::replace(&mut self.buf, DevicePointer::null());
257            unsafe {
258                cuda_free(ptr).expect("Failed to deallocate CUDA Device memory.");
259            }
260        }
261        self.capacity = 0;
262    }
263}
264
#[cfg(test)]
mod test_device_buffer {
    // Integration-style tests; each one creates its own CUDA context via
    // `quick_init`, so they require a working CUDA device to run.
    use super::*;
    use crate::memory::device::DeviceBox;
    use crate::stream::{Stream, StreamFlags};

    // Used to exercise the zero-sized-type code paths (no real allocation).
    #[derive(Clone, Debug)]
    struct ZeroSizedType;
    unsafe impl DeviceCopy for ZeroSizedType {}

    // Allocation followed by implicit deallocation must not panic.
    #[test]
    fn test_from_slice_drop() {
        let _context = crate::quick_init().unwrap();
        let buf = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4, 5]).unwrap();
        drop(buf);
    }

    // Round-trip: host -> device -> host preserves the data.
    #[test]
    fn test_copy_to_from_device() {
        let _context = crate::quick_init().unwrap();
        let start = [0u64, 1, 2, 3, 4, 5];
        let mut end = [0u64, 0, 0, 0, 0, 0];
        let buf = DeviceBuffer::from_slice(&start).unwrap();
        buf.copy_to(&mut end).unwrap();
        assert_eq!(start, end);
    }

    // Same round-trip via the async API; synchronize before asserting.
    #[test]
    fn test_async_copy_to_from_device() {
        let _context = crate::quick_init().unwrap();
        let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
        let start = [0u64, 1, 2, 3, 4, 5];
        let mut end = [0u64, 0, 0, 0, 0, 0];
        unsafe {
            let buf = DeviceBuffer::from_slice_async(&start, &stream).unwrap();
            buf.async_copy_to(&mut end, &stream).unwrap();
        }
        stream.synchronize().unwrap();
        assert_eq!(start, end);
    }

    // Indexing a buffer yields a DeviceSlice usable as a copy source/target.
    #[test]
    fn test_slice() {
        let _context = crate::quick_init().unwrap();
        let start = [0u64, 1, 2, 3, 4, 5];
        let mut end = [0u64, 0];
        let mut buf = DeviceBuffer::from_slice(&[0u64, 0, 0, 0]).unwrap();
        buf.copy_from(&start[0..4]).unwrap();
        buf[0..2].copy_to(&mut end).unwrap();
        assert_eq!(start[0..2], end);
    }

    // As above, but with the async copy API.
    #[test]
    fn test_async_slice() {
        let _context = crate::quick_init().unwrap();
        let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
        let start = [0u64, 1, 2, 3, 4, 5];
        let mut end = [0u64, 0];
        unsafe {
            let mut buf = DeviceBuffer::from_slice_async(&[0u64, 0, 0, 0], &stream).unwrap();
            buf.async_copy_from(&start[0..4], &stream).unwrap();
            buf[0..2].async_copy_to(&mut end, &stream).unwrap();
            stream.synchronize().unwrap();
            assert_eq!(start[0..2], end);
        }
    }

    // The wrong-size tests below rely on the copy methods panicking when the
    // source and destination lengths differ.
    #[test]
    #[should_panic]
    fn test_copy_to_d2h_wrong_size() {
        let _context = crate::quick_init().unwrap();
        let buf = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4, 5]).unwrap();
        let mut end = [0u64, 1, 2, 3, 4];
        let _ = buf.copy_to(&mut end);
    }

    #[test]
    #[should_panic]
    fn test_async_copy_to_d2h_wrong_size() {
        let _context = crate::quick_init().unwrap();
        let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
        unsafe {
            let buf = DeviceBuffer::from_slice_async(&[0u64, 1, 2, 3, 4, 5], &stream).unwrap();
            let mut end = [0u64, 1, 2, 3, 4];
            let _ = buf.async_copy_to(&mut end, &stream);
        }
    }

    #[test]
    #[should_panic]
    fn test_copy_from_h2d_wrong_size() {
        let _context = crate::quick_init().unwrap();
        let start = [0u64, 1, 2, 3, 4];
        let mut buf = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4, 5]).unwrap();
        let _ = buf.copy_from(&start);
    }

    #[test]
    #[should_panic]
    fn test_async_copy_from_h2d_wrong_size() {
        let _context = crate::quick_init().unwrap();
        let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
        let start = [0u64, 1, 2, 3, 4];
        unsafe {
            let mut buf = DeviceBuffer::from_slice_async(&[0u64, 1, 2, 3, 4, 5], &stream).unwrap();
            let _ = buf.async_copy_from(&start, &stream);
        }
    }

    // Device-to-device copies through slices of different buffers.
    #[test]
    fn test_copy_device_slice_to_device() {
        let _context = crate::quick_init().unwrap();
        let start = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4, 5]).unwrap();
        let mut mid = DeviceBuffer::from_slice(&[0u64, 0, 0, 0]).unwrap();
        let mut end = DeviceBuffer::from_slice(&[0u64, 0]).unwrap();
        let mut host_end = [0u64, 0];
        start[1..5].copy_to(&mut mid).unwrap();
        end.copy_from(&mid[1..3]).unwrap();
        end.copy_to(&mut host_end).unwrap();
        assert_eq!([2u64, 3], host_end);
    }

    #[test]
    fn test_async_copy_device_slice_to_device() {
        let _context = crate::quick_init().unwrap();
        let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
        unsafe {
            let start = DeviceBuffer::from_slice_async(&[0u64, 1, 2, 3, 4, 5], &stream).unwrap();
            let mut mid = DeviceBuffer::from_slice_async(&[0u64, 0, 0, 0], &stream).unwrap();
            let mut end = DeviceBuffer::from_slice_async(&[0u64, 0], &stream).unwrap();
            let mut host_end = [0u64, 0];
            start[1..5].async_copy_to(&mut mid, &stream).unwrap();
            end.async_copy_from(&mid[1..3], &stream).unwrap();
            end.async_copy_to(&mut host_end, &stream).unwrap();
            stream.synchronize().unwrap();
            assert_eq!([2u64, 3], host_end);
        }
    }

    #[test]
    #[should_panic]
    fn test_copy_to_d2d_wrong_size() {
        let _context = crate::quick_init().unwrap();
        let buf = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4, 5]).unwrap();
        let mut end = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4]).unwrap();
        let _ = buf.copy_to(&mut end);
    }

    #[test]
    #[should_panic]
    fn test_async_copy_to_d2d_wrong_size() {
        let _context = crate::quick_init().unwrap();
        let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
        unsafe {
            let buf = DeviceBuffer::from_slice_async(&[0u64, 1, 2, 3, 4, 5], &stream).unwrap();
            let mut end = DeviceBuffer::from_slice_async(&[0u64, 1, 2, 3, 4], &stream).unwrap();
            let _ = buf.async_copy_to(&mut end, &stream);
        }
    }

    #[test]
    #[should_panic]
    fn test_copy_from_d2d_wrong_size() {
        let _context = crate::quick_init().unwrap();
        let mut buf = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4, 5]).unwrap();
        let start = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4]).unwrap();
        let _ = buf.copy_from(&start);
    }

    #[test]
    #[should_panic]
    fn test_async_copy_from_d2d_wrong_size() {
        let _context = crate::quick_init().unwrap();
        let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
        unsafe {
            let mut buf = DeviceBuffer::from_slice_async(&[0u64, 1, 2, 3, 4, 5], &stream).unwrap();
            let start = DeviceBuffer::from_slice_async(&[0u64, 1, 2, 3, 4], &stream).unwrap();
            let _ = buf.async_copy_from(&start, &stream);
        }
    }

    // Buffers of non-DeviceCopy types can be allocated and sliced, just not
    // copied to/from the host.
    #[test]
    fn test_can_create_uninitialized_non_devicecopy_buffers() {
        let _context = crate::quick_init().unwrap();
        unsafe {
            let _box: DeviceBox<Vec<u8>> = DeviceBox::uninitialized().unwrap();
            let buffer: DeviceBuffer<Vec<u8>> = DeviceBuffer::uninitialized(10).unwrap();
            let _slice = &buffer[0..5];
        }
    }

    // A large (but not total-memory) allocation should succeed, proving the
    // element count is converted to bytes correctly (no over-allocation).
    #[test]
    fn test_allocate_correct_size() {
        use crate::context::CurrentContext;

        let _context = crate::quick_init().unwrap();
        let total_memory = CurrentContext::get_device()
            .unwrap()
            .total_memory()
            .unwrap();

        // Don't allocate all memory to leave some space for the display's frame buffer
        let allocation_size = (total_memory * 3) / 4 / mem::size_of::<u64>();
        unsafe {
            // Test if allocation fails with an out-of-memory error
            let _buffer = DeviceBuffer::<u64>::uninitialized(allocation_size).unwrap();
        };
    }
}