use crate::error::{CudaResult, DropResult, ToResult};
use crate::memory::device::{AsyncCopyDestination, CopyDestination, DeviceSlice};
use crate::memory::malloc::{cuda_free, cuda_malloc};
use crate::memory::DeviceCopy;
use crate::memory::DevicePointer;
use crate::stream::Stream;
use std::mem;
use std::ops::{Deref, DerefMut};
use std::ptr;

/// Fixed-size device-side buffer. Analogous to `Box<[T]>`, but resident in CUDA device memory.
#[derive(Debug)]
pub struct DeviceBuffer<T> {
    buf: DevicePointer<T>,
    capacity: usize,
}

impl<T> DeviceBuffer<T> {
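    /// Allocate a new device buffer large enough to hold `size` `T`s, without
    /// initializing the contents.
    ///
    /// # Safety
    ///
    /// The caller must ensure that the contents of the buffer are initialized
    /// before reading from it.
    ///
    /// # Examples
    ///
    /// A minimal sketch; the doc-test assumes this crate is built as
    /// `rustacuda` and that a CUDA context can be created:
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let mut buffer = unsafe { DeviceBuffer::uninitialized(5).unwrap() };
    /// buffer.copy_from(&[0u64, 1, 2, 3, 4]).unwrap();
    /// ```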
    pub unsafe fn uninitialized(size: usize) -> CudaResult<Self> {
        let ptr = if size > 0 && mem::size_of::<T>() > 0 {
            cuda_malloc(size)?
        } else {
            // Zero-byte requests (empty buffers or zero-sized types) never
            // touch the CUDA allocator; a dangling pointer stands in instead.
            DevicePointer::wrap(ptr::NonNull::dangling().as_ptr() as *mut T)
        };
        Ok(DeviceBuffer {
            buf: ptr,
            capacity: size,
        })
    }

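    /// Allocate a new device buffer large enough to hold `size` `T`s and fill
    /// the contents with zeroes (`0u8`).
    ///
    /// # Safety
    ///
    /// The backing memory is zeroed, which may not be a valid bit-pattern for
    /// type `T`. The caller must ensure either that all-zeroes is valid for
    /// `T` or that the memory is overwritten before it is read from.
    ///
    /// # Examples
    ///
    /// A minimal sketch, under the same crate-name assumption as above:
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let buffer = unsafe { DeviceBuffer::zeroed(5).unwrap() };
    /// let mut host_values = [1u64, 2, 3, 4, 5];
    /// buffer.copy_to(&mut host_values).unwrap();
    /// assert_eq!([0u64, 0, 0, 0, 0], host_values);
    /// ```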
    pub unsafe fn zeroed(size: usize) -> CudaResult<Self> {
        let ptr = if size > 0 && mem::size_of::<T>() > 0 {
            let mut ptr = cuda_malloc(size)?;
            // Zero the freshly allocated bytes via the driver API.
            cuda_driver_sys::cuMemsetD8_v2(ptr.as_raw_mut() as u64, 0, size * mem::size_of::<T>())
                .to_result()?;
            ptr
        } else {
            DevicePointer::wrap(ptr::NonNull::dangling().as_ptr() as *mut T)
        };
        Ok(DeviceBuffer {
            buf: ptr,
            capacity: size,
        })
    }

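    /// Creates a `DeviceBuffer<T>` directly from the raw components of
    /// another device buffer.
    ///
    /// # Safety
    ///
    /// Several invariants are not checked:
    ///
    /// - `ptr` must have been previously allocated via `DeviceBuffer` or
    ///   `cuda_malloc`, with the same size and alignment for `T`.
    /// - `capacity` must be the capacity the pointer was allocated with.
    ///
    /// Ownership of `ptr` is transferred to the returned buffer, which may
    /// deallocate it; nothing else may use the pointer afterwards.
    ///
    /// # Examples
    ///
    /// A minimal round-trip sketch (same crate-name assumption as above;
    /// `as_device_ptr` and `len` come from the `DeviceSlice` deref target):
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// use std::mem;
    ///
    /// let mut buffer = DeviceBuffer::from_slice(&[0u64; 5]).unwrap();
    /// let ptr = buffer.as_device_ptr();
    /// let size = buffer.len();
    ///
    /// // Forget the original so the allocation is not freed twice.
    /// mem::forget(buffer);
    ///
    /// let buffer = unsafe { DeviceBuffer::from_raw_parts(ptr, size) };
    /// ```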
    pub unsafe fn from_raw_parts(ptr: DevicePointer<T>, capacity: usize) -> DeviceBuffer<T> {
        DeviceBuffer { buf: ptr, capacity }
    }

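    /// Destroy a `DeviceBuffer`, returning an error on failure.
    ///
    /// Deallocating device memory can return errors from previous asynchronous
    /// work. This function destroys the given buffer and, on failure, returns
    /// the error along with the un-destroyed buffer so the caller can retry.
    ///
    /// # Examples
    ///
    /// A minimal sketch, under the same crate-name assumption as above:
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let x = DeviceBuffer::from_slice(&[10, 20, 30]).unwrap();
    /// match DeviceBuffer::drop(x) {
    ///     Ok(()) => println!("Successfully destroyed"),
    ///     Err((e, buf)) => {
    ///         println!("Failed to destroy buffer: {:?}", e);
    ///         // buf is still alive; retry, or drop it normally.
    ///     }
    /// }
    /// ```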
    pub fn drop(mut dev_buf: DeviceBuffer<T>) -> DropResult<DeviceBuffer<T>> {
        if dev_buf.buf.is_null() {
            return Ok(());
        }

        if dev_buf.capacity > 0 && mem::size_of::<T>() > 0 {
            let capacity = dev_buf.capacity;
            let ptr = mem::replace(&mut dev_buf.buf, DevicePointer::null());
            unsafe {
                match cuda_free(ptr) {
                    Ok(()) => {
                        // Memory is freed; skip the Drop impl.
                        mem::forget(dev_buf);
                        Ok(())
                    }
                    // Freeing failed: reassemble the buffer and hand it back.
                    Err(e) => Err((e, DeviceBuffer::from_raw_parts(ptr, capacity))),
                }
            }
        } else {
            Ok(())
        }
    }
}

impl<T: DeviceCopy> DeviceBuffer<T> {
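    /// Allocate a new device buffer of the same size as `slice`, initialized
    /// with a clone of the data in `slice`.
    ///
    /// # Errors
    ///
    /// If the allocation fails, returns the error from CUDA.
    ///
    /// # Examples
    ///
    /// A minimal sketch, under the same crate-name assumption as above:
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// let values = [0u64; 5];
    /// let buffer = DeviceBuffer::from_slice(&values).unwrap();
    /// ```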
    pub fn from_slice(slice: &[T]) -> CudaResult<Self> {
        unsafe {
            let mut uninit = DeviceBuffer::uninitialized(slice.len())?;
            uninit.copy_from(slice)?;
            Ok(uninit)
        }
    }

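    /// Asynchronously allocate a new buffer of the same size as `slice`,
    /// initialized with a clone of the data in `slice`.
    ///
    /// # Safety
    ///
    /// The caller must ensure that `slice` outlives the asynchronous copy
    /// (see the `AsyncCopyDestination` trait for details).
    ///
    /// # Errors
    ///
    /// If the allocation fails, returns the error from CUDA.
    ///
    /// # Examples
    ///
    /// A minimal sketch, under the same crate-name assumption as above:
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// use rustacuda::stream::{Stream, StreamFlags};
    ///
    /// let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
    /// let values = [0u64; 5];
    /// unsafe {
    ///     let buffer = DeviceBuffer::from_slice_async(&values, &stream).unwrap();
    ///     stream.synchronize().unwrap();
    /// }
    /// ```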
    pub unsafe fn from_slice_async(slice: &[T], stream: &Stream) -> CudaResult<Self> {
        let mut uninit = DeviceBuffer::uninitialized(slice.len())?;
        uninit.async_copy_from(slice, stream)?;
        Ok(uninit)
    }
}

impl<T> Deref for DeviceBuffer<T> {
    type Target = DeviceSlice<T>;

    fn deref(&self) -> &DeviceSlice<T> {
        unsafe {
            DeviceSlice::from_slice(::std::slice::from_raw_parts(
                self.buf.as_raw(),
                self.capacity,
            ))
        }
    }
}

impl<T> DerefMut for DeviceBuffer<T> {
    fn deref_mut(&mut self) -> &mut DeviceSlice<T> {
        unsafe {
            &mut *(::std::slice::from_raw_parts_mut(self.buf.as_raw_mut(), self.capacity)
                as *mut [T] as *mut DeviceSlice<T>)
        }
    }
}

impl<T> Drop for DeviceBuffer<T> {
    fn drop(&mut self) {
        if self.buf.is_null() {
            return;
        }

        if self.capacity > 0 && mem::size_of::<T>() > 0 {
            let ptr = mem::replace(&mut self.buf, DevicePointer::null());
            unsafe {
                // A failed free is unrecoverable inside Drop; callers who want
                // to handle the error should use `DeviceBuffer::drop` instead.
                cuda_free(ptr).expect("Failed to deallocate CUDA Device memory.");
            }
        }
        self.capacity = 0;
    }
}

#[cfg(test)]
mod test_device_buffer {
    use super::*;
    use crate::memory::device::DeviceBox;
    use crate::stream::{Stream, StreamFlags};

    #[derive(Clone, Debug)]
    struct ZeroSizedType;
    unsafe impl DeviceCopy for ZeroSizedType {}

    #[test]
    fn test_from_slice_drop() {
        let _context = crate::quick_init().unwrap();
        let buf = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4, 5]).unwrap();
        drop(buf);
    }

    #[test]
    fn test_copy_to_from_device() {
        let _context = crate::quick_init().unwrap();
        let start = [0u64, 1, 2, 3, 4, 5];
        let mut end = [0u64, 0, 0, 0, 0, 0];
        let buf = DeviceBuffer::from_slice(&start).unwrap();
        buf.copy_to(&mut end).unwrap();
        assert_eq!(start, end);
    }

    #[test]
    fn test_async_copy_to_from_device() {
        let _context = crate::quick_init().unwrap();
        let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
        let start = [0u64, 1, 2, 3, 4, 5];
        let mut end = [0u64, 0, 0, 0, 0, 0];
        unsafe {
            let buf = DeviceBuffer::from_slice_async(&start, &stream).unwrap();
            buf.async_copy_to(&mut end, &stream).unwrap();
        }
        stream.synchronize().unwrap();
        assert_eq!(start, end);
    }

    #[test]
    fn test_slice() {
        let _context = crate::quick_init().unwrap();
        let start = [0u64, 1, 2, 3, 4, 5];
        let mut end = [0u64, 0];
        let mut buf = DeviceBuffer::from_slice(&[0u64, 0, 0, 0]).unwrap();
        buf.copy_from(&start[0..4]).unwrap();
        buf[0..2].copy_to(&mut end).unwrap();
        assert_eq!(start[0..2], end);
    }

    #[test]
    fn test_async_slice() {
        let _context = crate::quick_init().unwrap();
        let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
        let start = [0u64, 1, 2, 3, 4, 5];
        let mut end = [0u64, 0];
        unsafe {
            let mut buf = DeviceBuffer::from_slice_async(&[0u64, 0, 0, 0], &stream).unwrap();
            buf.async_copy_from(&start[0..4], &stream).unwrap();
            buf[0..2].async_copy_to(&mut end, &stream).unwrap();
            stream.synchronize().unwrap();
            assert_eq!(start[0..2], end);
        }
    }

    #[test]
    #[should_panic]
    fn test_copy_to_d2h_wrong_size() {
        let _context = crate::quick_init().unwrap();
        let buf = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4, 5]).unwrap();
        let mut end = [0u64, 1, 2, 3, 4];
        let _ = buf.copy_to(&mut end);
    }

    #[test]
    #[should_panic]
    fn test_async_copy_to_d2h_wrong_size() {
        let _context = crate::quick_init().unwrap();
        let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
        unsafe {
            let buf = DeviceBuffer::from_slice_async(&[0u64, 1, 2, 3, 4, 5], &stream).unwrap();
            let mut end = [0u64, 1, 2, 3, 4];
            let _ = buf.async_copy_to(&mut end, &stream);
        }
    }

    #[test]
    #[should_panic]
    fn test_copy_from_h2d_wrong_size() {
        let _context = crate::quick_init().unwrap();
        let start = [0u64, 1, 2, 3, 4];
        let mut buf = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4, 5]).unwrap();
        let _ = buf.copy_from(&start);
    }

    #[test]
    #[should_panic]
    fn test_async_copy_from_h2d_wrong_size() {
        let _context = crate::quick_init().unwrap();
        let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
        let start = [0u64, 1, 2, 3, 4];
        unsafe {
            let mut buf = DeviceBuffer::from_slice_async(&[0u64, 1, 2, 3, 4, 5], &stream).unwrap();
            let _ = buf.async_copy_from(&start, &stream);
        }
    }

    #[test]
    fn test_copy_device_slice_to_device() {
        let _context = crate::quick_init().unwrap();
        let start = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4, 5]).unwrap();
        let mut mid = DeviceBuffer::from_slice(&[0u64, 0, 0, 0]).unwrap();
        let mut end = DeviceBuffer::from_slice(&[0u64, 0]).unwrap();
        let mut host_end = [0u64, 0];
        start[1..5].copy_to(&mut mid).unwrap();
        end.copy_from(&mid[1..3]).unwrap();
        end.copy_to(&mut host_end).unwrap();
        assert_eq!([2u64, 3], host_end);
    }

    #[test]
    fn test_async_copy_device_slice_to_device() {
        let _context = crate::quick_init().unwrap();
        let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
        unsafe {
            let start = DeviceBuffer::from_slice_async(&[0u64, 1, 2, 3, 4, 5], &stream).unwrap();
            let mut mid = DeviceBuffer::from_slice_async(&[0u64, 0, 0, 0], &stream).unwrap();
            let mut end = DeviceBuffer::from_slice_async(&[0u64, 0], &stream).unwrap();
            let mut host_end = [0u64, 0];
            start[1..5].async_copy_to(&mut mid, &stream).unwrap();
            end.async_copy_from(&mid[1..3], &stream).unwrap();
            end.async_copy_to(&mut host_end, &stream).unwrap();
            stream.synchronize().unwrap();
            assert_eq!([2u64, 3], host_end);
        }
    }

    #[test]
    #[should_panic]
    fn test_copy_to_d2d_wrong_size() {
        let _context = crate::quick_init().unwrap();
        let buf = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4, 5]).unwrap();
        let mut end = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4]).unwrap();
        let _ = buf.copy_to(&mut end);
    }

    #[test]
    #[should_panic]
    fn test_async_copy_to_d2d_wrong_size() {
        let _context = crate::quick_init().unwrap();
        let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
        unsafe {
            let buf = DeviceBuffer::from_slice_async(&[0u64, 1, 2, 3, 4, 5], &stream).unwrap();
            let mut end = DeviceBuffer::from_slice_async(&[0u64, 1, 2, 3, 4], &stream).unwrap();
            let _ = buf.async_copy_to(&mut end, &stream);
        }
    }

    #[test]
    #[should_panic]
    fn test_copy_from_d2d_wrong_size() {
        let _context = crate::quick_init().unwrap();
        let mut buf = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4, 5]).unwrap();
        let start = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4]).unwrap();
        let _ = buf.copy_from(&start);
    }

    #[test]
    #[should_panic]
    fn test_async_copy_from_d2d_wrong_size() {
        let _context = crate::quick_init().unwrap();
        let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
        unsafe {
            let mut buf = DeviceBuffer::from_slice_async(&[0u64, 1, 2, 3, 4, 5], &stream).unwrap();
            let start = DeviceBuffer::from_slice_async(&[0u64, 1, 2, 3, 4], &stream).unwrap();
            let _ = buf.async_copy_from(&start, &stream);
        }
    }

    #[test]
    fn test_can_create_uninitialized_non_devicecopy_buffers() {
        let _context = crate::quick_init().unwrap();
        unsafe {
            let _box: DeviceBox<Vec<u8>> = DeviceBox::uninitialized().unwrap();
            let buffer: DeviceBuffer<Vec<u8>> = DeviceBuffer::uninitialized(10).unwrap();
            let _slice = &buffer[0..5];
        }
    }

    #[test]
    fn test_allocate_correct_size() {
        use crate::context::CurrentContext;

        let _context = crate::quick_init().unwrap();
        let total_memory = CurrentContext::get_device()
            .unwrap()
            .total_memory()
            .unwrap();

        // Allocate a buffer of 3/4 of the total memory available.
        let allocation_size = (total_memory * 3) / 4 / mem::size_of::<u64>();
        unsafe {
            let _buffer = DeviceBuffer::<u64>::uninitialized(allocation_size).unwrap();
        };
    }
}