// rustacuda/memory/device/device_box.rs
1use crate::error::{CudaResult, DropResult, ToResult};
2use crate::memory::device::AsyncCopyDestination;
3use crate::memory::device::CopyDestination;
4use crate::memory::malloc::{cuda_free, cuda_malloc};
5use crate::memory::DeviceCopy;
6use crate::memory::DevicePointer;
7use crate::stream::Stream;
8use std::fmt::{self, Pointer};
9use std::mem;
10
11use std::os::raw::c_void;
12
/// A pointer type for heap-allocation in CUDA device memory.
///
/// Owns a single `T` located in device memory; the allocation is released when
/// the box is dropped. The destructor of `T` itself is never run (device memory
/// is opaque to the host), so `T` should be trivially destructible.
///
/// See the [`module-level documentation`](../memory/index.html) for more information on device memory.
#[derive(Debug)]
pub struct DeviceBox<T> {
    // Device-side pointer to the single T. Null when T is zero-sized (no
    // allocation is performed) or after ownership has been transferred out.
    ptr: DevicePointer<T>,
}
20impl<T: DeviceCopy> DeviceBox<T> {
21 /// Allocate device memory and place val into it.
22 ///
23 /// This doesn't actually allocate if `T` is zero-sized.
24 ///
25 /// # Errors
26 ///
27 /// If a CUDA error occurs, return the error.
28 ///
29 /// # Examples
30 ///
31 /// ```
32 /// # let _context = rustacuda::quick_init().unwrap();
33 /// use rustacuda::memory::*;
34 /// let five = DeviceBox::new(&5).unwrap();
35 /// ```
36 pub fn new(val: &T) -> CudaResult<Self> {
37 let mut dev_box = unsafe { DeviceBox::uninitialized()? };
38 dev_box.copy_from(val)?;
39 Ok(dev_box)
40 }
41}
impl<T> DeviceBox<T> {
    /// Allocate device memory, but do not initialize it.
    ///
    /// This doesn't actually allocate if `T` is zero-sized.
    ///
    /// # Safety
    ///
    /// Since the backing memory is not initialized, this function is not safe. The caller must
    /// ensure that the backing memory is set to a valid value before it is read, else undefined
    /// behavior may occur.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let mut five = unsafe { DeviceBox::uninitialized().unwrap() };
    /// five.copy_from(&5u64).unwrap();
    /// ```
    pub unsafe fn uninitialized() -> CudaResult<Self> {
        if mem::size_of::<T>() == 0 {
            // Zero-sized T needs no storage; a null device pointer marks the
            // "no allocation" case, which Drop and `DeviceBox::drop` skip.
            Ok(DeviceBox {
                ptr: DevicePointer::null(),
            })
        } else {
            // `cuda_malloc` takes an element count, so request one `T`.
            let ptr = cuda_malloc(1)?;
            Ok(DeviceBox { ptr })
        }
    }

    /// Allocate device memory and fill it with zeroes (`0u8`).
    ///
    /// This doesn't actually allocate if `T` is zero-sized.
    ///
    /// # Safety
    ///
    /// The backing memory is zeroed, which may not be a valid bit-pattern for type `T`. The caller
    /// must ensure either that all-zeroes is a valid bit-pattern for type `T` or that the backing
    /// memory is set to a valid value before it is read.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let mut zero = unsafe { DeviceBox::zeroed().unwrap() };
    /// let mut value = 5u64;
    /// zero.copy_to(&mut value).unwrap();
    /// assert_eq!(0, value);
    /// ```
    pub unsafe fn zeroed() -> CudaResult<Self> {
        let mut new_box = DeviceBox::uninitialized()?;
        if mem::size_of::<T>() != 0 {
            // Zero the allocation byte-by-byte with the driver API; skipped
            // for ZSTs, whose pointer is null and must not be passed to CUDA.
            cuda_driver_sys::cuMemsetD8_v2(
                new_box.as_device_ptr().as_raw_mut() as u64,
                0,
                mem::size_of::<T>(),
            )
            .to_result()?;
        }
        Ok(new_box)
    }

    /// Constructs a DeviceBox from a raw pointer.
    ///
    /// After calling this function, the raw pointer and the memory it points to is owned by the
    /// DeviceBox. The DeviceBox destructor will free the allocated memory, but will not call the destructor
    /// of `T`. This function may accept any pointer allocated by this crate's device allocator
    /// (`cuda_malloc`), such as one obtained from `DeviceBox::into_device(..).as_raw_mut()`.
    /// NOTE(review): the original docs said `cuMemAllocManaged`, which is the *unified* memory
    /// allocator — confirm, but `DeviceBox` itself allocates via `cuda_malloc`.
    ///
    /// # Safety
    ///
    /// This function is unsafe because improper use may lead to memory problems. For example, a
    /// double free may occur if this function is called twice on the same pointer, or a segfault
    /// may occur if the pointer is not one returned by the appropriate API call.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let x = DeviceBox::new(&5).unwrap();
    /// let ptr = DeviceBox::into_device(x).as_raw_mut();
    /// let x = unsafe { DeviceBox::from_raw(ptr) };
    /// ```
    pub unsafe fn from_raw(ptr: *mut T) -> Self {
        DeviceBox {
            ptr: DevicePointer::wrap(ptr),
        }
    }

    /// Constructs a DeviceBox from a DevicePointer.
    ///
    /// After calling this function, the pointer and the memory it points to is owned by the
    /// DeviceBox. The DeviceBox destructor will free the allocated memory, but will not call the destructor
    /// of `T`. This function may accept any pointer allocated by this crate's device allocator
    /// (`cuda_malloc`), such as one taken from `DeviceBox::into_device`.
    ///
    /// # Safety
    ///
    /// This function is unsafe because improper use may lead to memory problems. For example, a
    /// double free may occur if this function is called twice on the same pointer, or a segfault
    /// may occur if the pointer is not one returned by the appropriate API call.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let x = DeviceBox::new(&5).unwrap();
    /// let ptr = DeviceBox::into_device(x);
    /// let x = unsafe { DeviceBox::from_device(ptr) };
    /// ```
    pub unsafe fn from_device(ptr: DevicePointer<T>) -> Self {
        DeviceBox { ptr }
    }

    /// Consumes the DeviceBox, returning the wrapped DevicePointer.
    ///
    /// After calling this function, the caller is responsible for the memory previously managed by
    /// the DeviceBox. In particular, the caller should properly destroy T and deallocate the memory.
    /// The easiest way to do so is to create a new DeviceBox using the `DeviceBox::from_device` function.
    ///
    /// Note: This is an associated function, which means that you have to call it as
    /// `DeviceBox::into_device(b)` instead of `b.into_device()` This is so that there is no conflict with
    /// a method on the inner type.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let x = DeviceBox::new(&5).unwrap();
    /// let ptr = DeviceBox::into_device(x);
    /// # unsafe { DeviceBox::from_device(ptr) };
    /// ```
    #[allow(clippy::wrong_self_convention)]
    pub fn into_device(mut b: DeviceBox<T>) -> DevicePointer<T> {
        // Swap the pointer out and forget the box so its Drop impl never runs
        // — ownership of the allocation transfers to the returned pointer.
        let ptr = mem::replace(&mut b.ptr, DevicePointer::null());
        mem::forget(b);
        ptr
    }

    /// Returns the contained device pointer without consuming the box.
    ///
    /// This is useful for passing the box to a kernel launch.
    ///
    /// # Examples
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let mut x = DeviceBox::new(&5).unwrap();
    /// let ptr = x.as_device_ptr();
    /// println!("{:p}", ptr);
    /// ```
    pub fn as_device_ptr(&mut self) -> DevicePointer<T> {
        self.ptr
    }

    /// Destroy a `DeviceBox`, returning an error.
    ///
    /// Deallocating device memory can return errors from previous asynchronous work. This function
    /// destroys the given box and returns the error and the un-destroyed box on failure.
    ///
    /// # Example
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let x = DeviceBox::new(&5).unwrap();
    /// match DeviceBox::drop(x) {
    ///     Ok(()) => println!("Successfully destroyed"),
    ///     Err((e, dev_box)) => {
    ///         println!("Failed to destroy box: {:?}", e);
    ///         // Do something with dev_box
    ///     },
    /// }
    /// ```
    pub fn drop(mut dev_box: DeviceBox<T>) -> DropResult<DeviceBox<T>> {
        if dev_box.ptr.is_null() {
            // ZST / already-transferred box: nothing to free.
            return Ok(());
        }

        // Null out the pointer so the Drop impl (which runs when `dev_box`
        // goes out of scope on the error path) takes its early-return branch
        // and does not double-free.
        let ptr = mem::replace(&mut dev_box.ptr, DevicePointer::null());
        unsafe {
            match cuda_free(ptr) {
                Ok(()) => {
                    // Freed successfully; skip Drop entirely.
                    mem::forget(dev_box);
                    Ok(())
                }
                // Rebuild the box around the still-live pointer so the caller
                // can retry or inspect it.
                Err(e) => Err((e, DeviceBox { ptr })),
            }
        }
    }
}
238impl<T> Drop for DeviceBox<T> {
239 fn drop(&mut self) {
240 if self.ptr.is_null() {
241 return;
242 }
243
244 let ptr = mem::replace(&mut self.ptr, DevicePointer::null());
245 // No choice but to panic if this fails.
246 unsafe {
247 cuda_free(ptr).expect("Failed to deallocate CUDA memory.");
248 }
249 }
250}
251impl<T> Pointer for DeviceBox<T> {
252 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
253 fmt::Pointer::fmt(&self.ptr, f)
254 }
255}
// Marker impl: lets DeviceBox implement the crate's sealed traits
// (CopyDestination / AsyncCopyDestination) while keeping them unimplementable
// outside the crate.
impl<T> crate::private::Sealed for DeviceBox<T> {}
257impl<T: DeviceCopy> CopyDestination<T> for DeviceBox<T> {
258 fn copy_from(&mut self, val: &T) -> CudaResult<()> {
259 let size = mem::size_of::<T>();
260 if size != 0 {
261 unsafe {
262 cuda_driver_sys::cuMemcpyHtoD_v2(
263 self.ptr.as_raw_mut() as u64,
264 val as *const T as *const c_void,
265 size,
266 )
267 .to_result()?
268 }
269 }
270 Ok(())
271 }
272
273 fn copy_to(&self, val: &mut T) -> CudaResult<()> {
274 let size = mem::size_of::<T>();
275 if size != 0 {
276 unsafe {
277 cuda_driver_sys::cuMemcpyDtoH_v2(
278 val as *const T as *mut c_void,
279 self.ptr.as_raw() as u64,
280 size,
281 )
282 .to_result()?
283 }
284 }
285 Ok(())
286 }
287}
288impl<T: DeviceCopy> CopyDestination<DeviceBox<T>> for DeviceBox<T> {
289 fn copy_from(&mut self, val: &DeviceBox<T>) -> CudaResult<()> {
290 let size = mem::size_of::<T>();
291 if size != 0 {
292 unsafe {
293 cuda_driver_sys::cuMemcpyDtoD_v2(
294 self.ptr.as_raw_mut() as u64,
295 val.ptr.as_raw() as u64,
296 size,
297 )
298 .to_result()?
299 }
300 }
301 Ok(())
302 }
303
304 fn copy_to(&self, val: &mut DeviceBox<T>) -> CudaResult<()> {
305 let size = mem::size_of::<T>();
306 if size != 0 {
307 unsafe {
308 cuda_driver_sys::cuMemcpyDtoD_v2(
309 val.ptr.as_raw_mut() as u64,
310 self.ptr.as_raw() as u64,
311 size,
312 )
313 .to_result()?
314 }
315 }
316 Ok(())
317 }
318}
319impl<T: DeviceCopy> AsyncCopyDestination<DeviceBox<T>> for DeviceBox<T> {
320 unsafe fn async_copy_from(&mut self, val: &DeviceBox<T>, stream: &Stream) -> CudaResult<()> {
321 let size = mem::size_of::<T>();
322 if size != 0 {
323 cuda_driver_sys::cuMemcpyDtoDAsync_v2(
324 self.ptr.as_raw_mut() as u64,
325 val.ptr.as_raw() as u64,
326 size,
327 stream.as_inner(),
328 )
329 .to_result()?
330 }
331 Ok(())
332 }
333
334 unsafe fn async_copy_to(&self, val: &mut DeviceBox<T>, stream: &Stream) -> CudaResult<()> {
335 let size = mem::size_of::<T>();
336 if size != 0 {
337 cuda_driver_sys::cuMemcpyDtoDAsync_v2(
338 val.ptr.as_raw_mut() as u64,
339 self.ptr.as_raw() as u64,
340 size,
341 stream.as_inner(),
342 )
343 .to_result()?
344 }
345 Ok(())
346 }
347}
348
// Integration-style tests for DeviceBox. Each test calls `quick_init` and so
// requires a working CUDA driver and device at run time.
#[cfg(test)]
mod test_device_box {
    use super::*;

    // ZST used to verify the no-allocation path (null device pointer).
    #[derive(Clone, Debug)]
    struct ZeroSizedType;
    unsafe impl DeviceCopy for ZeroSizedType {}

    #[test]
    fn test_allocate_and_free_device_box() {
        let _context = crate::quick_init().unwrap();
        let x = DeviceBox::new(&5u64).unwrap();
        // Exercises the Drop impl (frees device memory; panics on failure).
        drop(x);
    }

    #[test]
    fn test_device_box_allocates_for_non_zst() {
        let _context = crate::quick_init().unwrap();
        let x = DeviceBox::new(&5u64).unwrap();
        let ptr = DeviceBox::into_device(x);
        assert!(!ptr.is_null());
        // Re-wrap so the allocation is freed at end of scope.
        let _ = unsafe { DeviceBox::from_device(ptr) };
    }

    #[test]
    fn test_device_box_doesnt_allocate_for_zero_sized_type() {
        let _context = crate::quick_init().unwrap();
        let x = DeviceBox::new(&ZeroSizedType).unwrap();
        let ptr = DeviceBox::into_device(x);
        // ZSTs get a null pointer instead of a real allocation.
        assert!(ptr.is_null());
        let _ = unsafe { DeviceBox::from_device(ptr) };
    }

    #[test]
    fn test_into_from_device() {
        let _context = crate::quick_init().unwrap();
        let x = DeviceBox::new(&5u64).unwrap();
        let ptr = DeviceBox::into_device(x);
        let _ = unsafe { DeviceBox::from_device(ptr) };
    }

    #[test]
    fn test_copy_host_to_device() {
        let _context = crate::quick_init().unwrap();
        let y = 5u64;
        let mut x = DeviceBox::new(&0u64).unwrap();
        x.copy_from(&y).unwrap();
        let mut z = 10u64;
        // Round-trip through the device proves both copy directions.
        x.copy_to(&mut z).unwrap();
        assert_eq!(y, z);
    }

    #[test]
    fn test_copy_device_to_host() {
        let _context = crate::quick_init().unwrap();
        let x = DeviceBox::new(&5u64).unwrap();
        let mut y = 0u64;
        x.copy_to(&mut y).unwrap();
        assert_eq!(5, y);
    }

    #[test]
    fn test_copy_device_to_device() {
        let _context = crate::quick_init().unwrap();
        let x = DeviceBox::new(&5u64).unwrap();
        let mut y = DeviceBox::new(&0u64).unwrap();
        let mut z = DeviceBox::new(&0u64).unwrap();
        // x -> y via copy_to, y -> z via copy_from: covers both DtoD methods.
        x.copy_to(&mut y).unwrap();
        z.copy_from(&y).unwrap();

        let mut h = 0u64;
        z.copy_to(&mut h).unwrap();
        assert_eq!(5, h);
    }

    #[test]
    fn test_device_pointer_implements_traits_safely() {
        let _context = crate::quick_init().unwrap();
        let mut x = DeviceBox::new(&5u64).unwrap();
        let mut y = DeviceBox::new(&0u64).unwrap();

        // If the impls dereference the pointer, this should segfault.
        let _ = Ord::cmp(&x.as_device_ptr(), &y.as_device_ptr());
        let _ = PartialOrd::partial_cmp(&x.as_device_ptr(), &y.as_device_ptr());
        let _ = PartialEq::eq(&x.as_device_ptr(), &y.as_device_ptr());

        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        std::hash::Hash::hash(&x.as_device_ptr(), &mut hasher);

        let _ = format!("{:?}", x.as_device_ptr());
        let _ = format!("{:p}", x.as_device_ptr());
    }
}