// file: rustacuda/memory/device/device_box.rs

1use crate::error::{CudaResult, DropResult, ToResult};
2use crate::memory::device::AsyncCopyDestination;
3use crate::memory::device::CopyDestination;
4use crate::memory::malloc::{cuda_free, cuda_malloc};
5use crate::memory::DeviceCopy;
6use crate::memory::DevicePointer;
7use crate::stream::Stream;
8use std::fmt::{self, Pointer};
9use std::mem;
10
11use std::os::raw::c_void;
12
/// A pointer type for heap-allocation in CUDA device memory.
///
/// See the [`module-level documentation`](../memory/index.html) for more information on device memory.
#[derive(Debug)]
pub struct DeviceBox<T> {
    // Device-side pointer to the single `T` owned by this box. It is null for
    // zero-sized `T` (no allocation is made) and after the pointer has been
    // moved out (see `into_device` / `DeviceBox::drop`); `Drop` skips null.
    ptr: DevicePointer<T>,
}
20impl<T: DeviceCopy> DeviceBox<T> {
21    /// Allocate device memory and place val into it.
22    ///
23    /// This doesn't actually allocate if `T` is zero-sized.
24    ///
25    /// # Errors
26    ///
27    /// If a CUDA error occurs, return the error.
28    ///
29    /// # Examples
30    ///
31    /// ```
32    /// # let _context = rustacuda::quick_init().unwrap();
33    /// use rustacuda::memory::*;
34    /// let five = DeviceBox::new(&5).unwrap();
35    /// ```
36    pub fn new(val: &T) -> CudaResult<Self> {
37        let mut dev_box = unsafe { DeviceBox::uninitialized()? };
38        dev_box.copy_from(val)?;
39        Ok(dev_box)
40    }
41}
impl<T> DeviceBox<T> {
    /// Allocate device memory, but do not initialize it.
    ///
    /// This doesn't actually allocate if `T` is zero-sized.
    ///
    /// # Safety
    ///
    /// Since the backing memory is not initialized, this function is not safe. The caller must
    /// ensure that the backing memory is set to a valid value before it is read, else undefined
    /// behavior may occur.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let mut five = unsafe { DeviceBox::uninitialized().unwrap() };
    /// five.copy_from(&5u64).unwrap();
    /// ```
    pub unsafe fn uninitialized() -> CudaResult<Self> {
        if mem::size_of::<T>() == 0 {
            // Zero-sized types need no backing storage: a null device pointer
            // marks "no allocation", and `Drop` skips null pointers.
            Ok(DeviceBox {
                ptr: DevicePointer::null(),
            })
        } else {
            // Allocate room for exactly one `T` on the device.
            let ptr = cuda_malloc(1)?;
            Ok(DeviceBox { ptr })
        }
    }

    /// Allocate device memory and fill it with zeroes (`0u8`).
    ///
    /// This doesn't actually allocate if `T` is zero-sized.
    ///
    /// # Safety
    ///
    /// The backing memory is zeroed, which may not be a valid bit-pattern for type `T`. The caller
    /// must ensure either that all-zeroes is a valid bit-pattern for type `T` or that the backing
    /// memory is set to a valid value before it is read.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let mut zero = unsafe { DeviceBox::zeroed().unwrap() };
    /// let mut value = 5u64;
    /// zero.copy_to(&mut value).unwrap();
    /// assert_eq!(0, value);
    /// ```
    pub unsafe fn zeroed() -> CudaResult<Self> {
        let mut new_box = DeviceBox::uninitialized()?;
        if mem::size_of::<T>() != 0 {
            // Byte-wise memset of the device allocation to 0u8.
            cuda_driver_sys::cuMemsetD8_v2(
                new_box.as_device_ptr().as_raw_mut() as u64,
                0,
                mem::size_of::<T>(),
            )
            .to_result()?;
        }
        Ok(new_box)
    }

    /// Constructs a DeviceBox from a raw pointer.
    ///
    /// After calling this function, the raw pointer and the memory it points to is owned by the
    /// DeviceBox. The DeviceBox destructor will free the allocated memory, but will not call the destructor
    /// of `T`. This function may accept any pointer produced by the `cuMemAllocManaged` CUDA API
    /// call.
    ///
    /// NOTE(review): allocation in this type comes from `cuda_malloc` (i.e. `cuMemAlloc`),
    /// so the `cuMemAllocManaged` mention above looks like a copy-paste from a
    /// unified-memory type — confirm which API the pointer must come from.
    ///
    /// # Safety
    ///
    /// This function is unsafe because improper use may lead to memory problems. For example, a
    /// double free may occur if this function is called twice on the same pointer, or a segfault
    /// may occur if the pointer is not one returned by the appropriate API call.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let x = DeviceBox::new(&5).unwrap();
    /// let ptr = DeviceBox::into_device(x).as_raw_mut();
    /// let x = unsafe { DeviceBox::from_raw(ptr) };
    /// ```
    pub unsafe fn from_raw(ptr: *mut T) -> Self {
        DeviceBox {
            ptr: DevicePointer::wrap(ptr),
        }
    }

    /// Constructs a DeviceBox from a DevicePointer.
    ///
    /// After calling this function, the pointer and the memory it points to is owned by the
    /// DeviceBox. The DeviceBox destructor will free the allocated memory, but will not call the destructor
    /// of `T`. This function may accept any pointer produced by the `cuMemAllocManaged` CUDA API
    /// call, such as one taken from `DeviceBox::into_device`.
    ///
    /// NOTE(review): as with `from_raw`, the `cuMemAllocManaged` claim above should be
    /// verified — this type's own allocations come from `cuda_malloc`.
    ///
    /// # Safety
    ///
    /// This function is unsafe because improper use may lead to memory problems. For example, a
    /// double free may occur if this function is called twice on the same pointer, or a segfault
    /// may occur if the pointer is not one returned by the appropriate API call.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let x = DeviceBox::new(&5).unwrap();
    /// let ptr = DeviceBox::into_device(x);
    /// let x = unsafe { DeviceBox::from_device(ptr) };
    /// ```
    pub unsafe fn from_device(ptr: DevicePointer<T>) -> Self {
        DeviceBox { ptr }
    }

    /// Consumes the DeviceBox, returning the wrapped DevicePointer.
    ///
    /// After calling this function, the caller is responsible for the memory previously managed by
    /// the DeviceBox. In particular, the caller should properly destroy T and deallocate the memory.
    /// The easiest way to do so is to create a new DeviceBox using the `DeviceBox::from_device` function.
    ///
    /// Note: This is an associated function, which means that you have to call it as
    /// `DeviceBox::into_device(b)` instead of `b.into_device()` This is so that there is no conflict with
    /// a method on the inner type.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let x = DeviceBox::new(&5).unwrap();
    /// let ptr = DeviceBox::into_device(x);
    /// # unsafe { DeviceBox::from_device(ptr) };
    /// ```
    #[allow(clippy::wrong_self_convention)]
    pub fn into_device(mut b: DeviceBox<T>) -> DevicePointer<T> {
        // Swap the pointer out (leaving null) and forget the box so its `Drop`
        // never runs — ownership of the allocation transfers to the caller.
        let ptr = mem::replace(&mut b.ptr, DevicePointer::null());
        mem::forget(b);
        ptr
    }

    /// Returns the contained device pointer without consuming the box.
    ///
    /// This is useful for passing the box to a kernel launch.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let mut x = DeviceBox::new(&5).unwrap();
    /// let ptr = x.as_device_ptr();
    /// println!("{:p}", ptr);
    /// ```
    pub fn as_device_ptr(&mut self) -> DevicePointer<T> {
        self.ptr
    }

    /// Destroy a `DeviceBox`, returning an error.
    ///
    /// Deallocating device memory can return errors from previous asynchronous work. This function
    /// destroys the given box and returns the error and the un-destroyed box on failure.
    ///
    /// # Example
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let x = DeviceBox::new(&5).unwrap();
    /// match DeviceBox::drop(x) {
    ///     Ok(()) => println!("Successfully destroyed"),
    ///     Err((e, dev_box)) => {
    ///         println!("Failed to destroy box: {:?}", e);
    ///         // Do something with dev_box
    ///     },
    /// }
    /// ```
    pub fn drop(mut dev_box: DeviceBox<T>) -> DropResult<DeviceBox<T>> {
        // Null pointer: zero-sized T or already-moved-out box — nothing to free.
        if dev_box.ptr.is_null() {
            return Ok(());
        }

        // Take the pointer out, leaving null behind so that `dev_box`'s `Drop`
        // (which skips null) cannot double-free on the error path.
        let ptr = mem::replace(&mut dev_box.ptr, DevicePointer::null());
        unsafe {
            match cuda_free(ptr) {
                Ok(()) => {
                    // `forget` skips the (now no-op) destructor entirely.
                    mem::forget(dev_box);
                    Ok(())
                }
                // Hand back a reconstructed box owning the still-live pointer.
                Err(e) => Err((e, DeviceBox { ptr })),
            }
        }
    }
}
238impl<T> Drop for DeviceBox<T> {
239    fn drop(&mut self) {
240        if self.ptr.is_null() {
241            return;
242        }
243
244        let ptr = mem::replace(&mut self.ptr, DevicePointer::null());
245        // No choice but to panic if this fails.
246        unsafe {
247            cuda_free(ptr).expect("Failed to deallocate CUDA memory.");
248        }
249    }
250}
251impl<T> Pointer for DeviceBox<T> {
252    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
253        fmt::Pointer::fmt(&self.ptr, f)
254    }
255}
// Sealing impl: `crate::private::Sealed` is a crate-private supertrait, so only
// types inside this crate (like `DeviceBox`) can implement the copy traits.
impl<T> crate::private::Sealed for DeviceBox<T> {}
257impl<T: DeviceCopy> CopyDestination<T> for DeviceBox<T> {
258    fn copy_from(&mut self, val: &T) -> CudaResult<()> {
259        let size = mem::size_of::<T>();
260        if size != 0 {
261            unsafe {
262                cuda_driver_sys::cuMemcpyHtoD_v2(
263                    self.ptr.as_raw_mut() as u64,
264                    val as *const T as *const c_void,
265                    size,
266                )
267                .to_result()?
268            }
269        }
270        Ok(())
271    }
272
273    fn copy_to(&self, val: &mut T) -> CudaResult<()> {
274        let size = mem::size_of::<T>();
275        if size != 0 {
276            unsafe {
277                cuda_driver_sys::cuMemcpyDtoH_v2(
278                    val as *const T as *mut c_void,
279                    self.ptr.as_raw() as u64,
280                    size,
281                )
282                .to_result()?
283            }
284        }
285        Ok(())
286    }
287}
impl<T: DeviceCopy> CopyDestination<DeviceBox<T>> for DeviceBox<T> {
    // Device-to-device copy: `val`'s allocation -> this box's allocation.
    // Skipped for zero-sized `T`, which owns no allocation.
    fn copy_from(&mut self, val: &DeviceBox<T>) -> CudaResult<()> {
        let size = mem::size_of::<T>();
        if size != 0 {
            unsafe {
                cuda_driver_sys::cuMemcpyDtoD_v2(
                    self.ptr.as_raw_mut() as u64,
                    val.ptr.as_raw() as u64,
                    size,
                )
                .to_result()?
            }
        }
        Ok(())
    }

    // Device-to-device copy: this box's allocation -> `val`'s allocation.
    // Skipped for zero-sized `T`.
    fn copy_to(&self, val: &mut DeviceBox<T>) -> CudaResult<()> {
        let size = mem::size_of::<T>();
        if size != 0 {
            unsafe {
                cuda_driver_sys::cuMemcpyDtoD_v2(
                    val.ptr.as_raw_mut() as u64,
                    self.ptr.as_raw() as u64,
                    size,
                )
                .to_result()?
            }
        }
        Ok(())
    }
}
impl<T: DeviceCopy> AsyncCopyDestination<DeviceBox<T>> for DeviceBox<T> {
    // Asynchronous device-to-device copy (`val` -> self) enqueued on `stream`.
    // NOTE(review): being async, the copy has presumably not completed when
    // this returns — both boxes must stay alive until the stream is
    // synchronized; confirm against the `AsyncCopyDestination` trait docs.
    unsafe fn async_copy_from(&mut self, val: &DeviceBox<T>, stream: &Stream) -> CudaResult<()> {
        let size = mem::size_of::<T>();
        if size != 0 {
            cuda_driver_sys::cuMemcpyDtoDAsync_v2(
                self.ptr.as_raw_mut() as u64,
                val.ptr.as_raw() as u64,
                size,
                stream.as_inner(),
            )
            .to_result()?
        }
        Ok(())
    }

    // Asynchronous device-to-device copy (self -> `val`) enqueued on `stream`.
    // Same lifetime caveat as `async_copy_from` above.
    unsafe fn async_copy_to(&self, val: &mut DeviceBox<T>, stream: &Stream) -> CudaResult<()> {
        let size = mem::size_of::<T>();
        if size != 0 {
            cuda_driver_sys::cuMemcpyDtoDAsync_v2(
                val.ptr.as_raw_mut() as u64,
                self.ptr.as_raw() as u64,
                size,
                stream.as_inner(),
            )
            .to_result()?
        }
        Ok(())
    }
}
348
// Hardware tests: every test below requires a working CUDA driver/device,
// initialized via `crate::quick_init()`.
#[cfg(test)]
mod test_device_box {
    use super::*;

    // Zero-sized type used to exercise the no-allocation (null pointer) path.
    #[derive(Clone, Debug)]
    struct ZeroSizedType;
    unsafe impl DeviceCopy for ZeroSizedType {}

    // Basic allocate-then-drop round trip must not error or leak.
    #[test]
    fn test_allocate_and_free_device_box() {
        let _context = crate::quick_init().unwrap();
        let x = DeviceBox::new(&5u64).unwrap();
        drop(x);
    }

    // Non-zero-sized values get a real (non-null) device allocation.
    #[test]
    fn test_device_box_allocates_for_non_zst() {
        let _context = crate::quick_init().unwrap();
        let x = DeviceBox::new(&5u64).unwrap();
        let ptr = DeviceBox::into_device(x);
        assert!(!ptr.is_null());
        let _ = unsafe { DeviceBox::from_device(ptr) };
    }

    // Zero-sized values are represented by a null pointer — no allocation.
    #[test]
    fn test_device_box_doesnt_allocate_for_zero_sized_type() {
        let _context = crate::quick_init().unwrap();
        let x = DeviceBox::new(&ZeroSizedType).unwrap();
        let ptr = DeviceBox::into_device(x);
        assert!(ptr.is_null());
        let _ = unsafe { DeviceBox::from_device(ptr) };
    }

    // into_device/from_device round trip hands ownership out and back cleanly.
    #[test]
    fn test_into_from_device() {
        let _context = crate::quick_init().unwrap();
        let x = DeviceBox::new(&5u64).unwrap();
        let ptr = DeviceBox::into_device(x);
        let _ = unsafe { DeviceBox::from_device(ptr) };
    }

    // Host -> device -> host round trip preserves the value.
    #[test]
    fn test_copy_host_to_device() {
        let _context = crate::quick_init().unwrap();
        let y = 5u64;
        let mut x = DeviceBox::new(&0u64).unwrap();
        x.copy_from(&y).unwrap();
        let mut z = 10u64;
        x.copy_to(&mut z).unwrap();
        assert_eq!(y, z);
    }

    // Device -> host copy reads back what `new` stored.
    #[test]
    fn test_copy_device_to_host() {
        let _context = crate::quick_init().unwrap();
        let x = DeviceBox::new(&5u64).unwrap();
        let mut y = 0u64;
        x.copy_to(&mut y).unwrap();
        assert_eq!(5, y);
    }

    // Device-to-device copies in both directions (copy_to and copy_from).
    #[test]
    fn test_copy_device_to_device() {
        let _context = crate::quick_init().unwrap();
        let x = DeviceBox::new(&5u64).unwrap();
        let mut y = DeviceBox::new(&0u64).unwrap();
        let mut z = DeviceBox::new(&0u64).unwrap();
        x.copy_to(&mut y).unwrap();
        z.copy_from(&y).unwrap();

        let mut h = 0u64;
        z.copy_to(&mut h).unwrap();
        assert_eq!(5, h);
    }

    // The comparison/hash/format impls on DevicePointer must operate on the
    // pointer value only, never dereference device memory from the host.
    #[test]
    fn test_device_pointer_implements_traits_safely() {
        let _context = crate::quick_init().unwrap();
        let mut x = DeviceBox::new(&5u64).unwrap();
        let mut y = DeviceBox::new(&0u64).unwrap();

        // If the impls dereference the pointer, this should segfault.
        let _ = Ord::cmp(&x.as_device_ptr(), &y.as_device_ptr());
        let _ = PartialOrd::partial_cmp(&x.as_device_ptr(), &y.as_device_ptr());
        let _ = PartialEq::eq(&x.as_device_ptr(), &y.as_device_ptr());

        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        std::hash::Hash::hash(&x.as_device_ptr(), &mut hasher);

        let _ = format!("{:?}", x.as_device_ptr());
        let _ = format!("{:p}", x.as_device_ptr());
    }
}