1use core::ffi::c_void;
4use core::marker::PhantomData;
5use core::mem::size_of;
6
7use baracuda_cuda_sys::runtime::{cudaMemcpyKind, runtime};
8use baracuda_types::DeviceRepr;
9
10use crate::error::{check, Result};
11use crate::stream::Stream;
12
/// Typed handle to a device allocation holding `len` elements of `T`.
///
/// The stored pointer is handed to the runtime's `cuda_free` entry point when
/// the value is dropped (see the `Drop` impl in this file).
pub struct DeviceBuffer<T: DeviceRepr> {
    /// Raw address returned by the allocation call; may be null, in which
    /// case dropping the buffer releases nothing.
    ptr: *mut c_void,
    /// Number of `T` elements (not bytes).
    len: usize,
    /// Ties the buffer to its element type without storing a `T`.
    _marker: PhantomData<T>,
}

// The raw-pointer field suppresses the automatic `Send` impl; this opts back
// in whenever the element type is `Send`.
// NOTE(review): assumes the runtime allows the allocation to be used and
// freed from a thread other than the one that created it — confirm.
unsafe impl<T: DeviceRepr + Send> Send for DeviceBuffer<T> {}
21
22impl<T: DeviceRepr> core::fmt::Debug for DeviceBuffer<T> {
23 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
24 f.debug_struct("DeviceBuffer")
25 .field("ptr", &self.ptr)
26 .field("len", &self.len)
27 .field("type", &core::any::type_name::<T>())
28 .finish()
29 }
30}
31
32impl<T: DeviceRepr> DeviceBuffer<T> {
33 pub fn new(len: usize) -> Result<Self> {
35 let r = runtime()?;
36 let cu = r.cuda_malloc()?;
37 let bytes = len
38 .checked_mul(size_of::<T>())
39 .expect("overflow computing allocation size");
40 let mut ptr: *mut c_void = core::ptr::null_mut();
41 check(unsafe { cu(&mut ptr, bytes) })?;
42 Ok(Self {
43 ptr,
44 len,
45 _marker: PhantomData,
46 })
47 }
48
49 pub fn zeros(len: usize) -> Result<Self> {
51 let buf = Self::new(len)?;
52 let r = runtime()?;
53 let cu = r.cuda_memset()?;
54 let bytes = len * size_of::<T>();
55 check(unsafe { cu(buf.ptr, 0, bytes) })?;
56 Ok(buf)
57 }
58
59 pub fn from_slice(src: &[T]) -> Result<Self> {
61 let buf = Self::new(src.len())?;
62 buf.copy_from_host(src)?;
63 Ok(buf)
64 }
65
66 pub fn copy_from_host(&self, src: &[T]) -> Result<()> {
68 assert_eq!(src.len(), self.len);
69 let r = runtime()?;
70 let cu = r.cuda_memcpy()?;
71 let bytes = self.len * size_of::<T>();
72 check(unsafe {
73 cu(
74 self.ptr,
75 src.as_ptr() as *const c_void,
76 bytes,
77 cudaMemcpyKind::HostToDevice,
78 )
79 })
80 }
81
82 pub fn copy_to_host(&self, dst: &mut [T]) -> Result<()> {
84 assert_eq!(dst.len(), self.len);
85 let r = runtime()?;
86 let cu = r.cuda_memcpy()?;
87 let bytes = self.len * size_of::<T>();
88 check(unsafe {
89 cu(
90 dst.as_mut_ptr() as *mut c_void,
91 self.ptr,
92 bytes,
93 cudaMemcpyKind::DeviceToHost,
94 )
95 })
96 }
97
98 pub fn copy_from_host_async(&self, src: &[T], stream: &Stream) -> Result<()> {
100 assert_eq!(src.len(), self.len);
101 let r = runtime()?;
102 let cu = r.cuda_memcpy_async()?;
103 let bytes = self.len * size_of::<T>();
104 check(unsafe {
105 cu(
106 self.ptr,
107 src.as_ptr() as *const c_void,
108 bytes,
109 cudaMemcpyKind::HostToDevice,
110 stream.as_raw(),
111 )
112 })
113 }
114
115 pub fn copy_to_host_async(&self, dst: &mut [T], stream: &Stream) -> Result<()> {
117 assert_eq!(dst.len(), self.len);
118 let r = runtime()?;
119 let cu = r.cuda_memcpy_async()?;
120 let bytes = self.len * size_of::<T>();
121 check(unsafe {
122 cu(
123 dst.as_mut_ptr() as *mut c_void,
124 self.ptr,
125 bytes,
126 cudaMemcpyKind::DeviceToHost,
127 stream.as_raw(),
128 )
129 })
130 }
131
132 #[inline]
134 pub fn len(&self) -> usize {
135 self.len
136 }
137
138 #[inline]
140 pub fn byte_size(&self) -> usize {
141 self.len * size_of::<T>()
142 }
143
144 #[inline]
146 pub fn is_empty(&self) -> bool {
147 self.len == 0
148 }
149
150 #[inline]
152 pub fn as_raw(&self) -> *mut c_void {
153 self.ptr
154 }
155
156 #[inline]
159 pub fn as_device_ptr(&self) -> u64 {
160 self.ptr as u64
161 }
162}
163
164impl<T: DeviceRepr> Drop for DeviceBuffer<T> {
165 fn drop(&mut self) {
166 if self.ptr.is_null() {
167 return;
168 }
169 if let Ok(r) = runtime() {
170 if let Ok(cu) = r.cuda_free() {
171 let _ = unsafe { cu(self.ptr) };
172 }
173 }
174 }
175}
176
177pub fn mem_get_info() -> Result<(u64, u64)> {
181 let r = runtime()?;
182 let cu = r.cuda_mem_get_info()?;
183 let mut free: usize = 0;
184 let mut total: usize = 0;
185 check(unsafe { cu(&mut free, &mut total) })?;
186 Ok((free as u64, total as u64))
187}
188
/// Destination selector used by the prefetch and advise helpers in this
/// module.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum PrefetchTarget {
    /// A device; the wrapped `i32` is passed to the runtime unchanged.
    Device(i32),
    /// The host side.
    Host,
}

impl PrefetchTarget {
    /// Lowers the target to the integer handed to the runtime call:
    /// the wrapped value for `Device`, `-1` for `Host`.
    #[inline]
    fn as_raw(self) -> i32 {
        // Presumably CUDA's `cudaCpuDeviceId` — confirm against the headers.
        const HOST_SENTINEL: i32 = -1;
        match self {
            Self::Host => HOST_SENTINEL,
            Self::Device(ordinal) => ordinal,
        }
    }
}
208
209pub unsafe fn mem_prefetch_async(
217 dev_ptr: *const core::ffi::c_void,
218 count: usize,
219 target: PrefetchTarget,
220 stream: &Stream,
221) -> Result<()> {
222 let r = runtime()?;
223 let cu = r.cuda_mem_prefetch_async()?;
224 check(cu(dev_ptr, count, target.as_raw(), stream.as_raw()))
225}
226
227pub unsafe fn mem_advise(
234 dev_ptr: *const core::ffi::c_void,
235 count: usize,
236 advice: i32,
237 target: PrefetchTarget,
238) -> Result<()> {
239 let r = runtime()?;
240 let cu = r.cuda_mem_advise()?;
241 check(cu(dev_ptr, count, advice, target.as_raw()))
242}
243
/// Typed handle to an allocation obtained from the runtime's
/// `cuda_malloc_managed` entry point, holding `len` elements of `T`.
///
/// The memory is exposed to host code as a slice (`as_slice` /
/// `as_mut_slice`) and released through `cuda_free` on drop.
pub struct ManagedBuffer<T: DeviceRepr> {
    /// Typed pointer to the first element.
    ptr: *mut T,
    /// Number of `T` elements (not bytes).
    len: usize,
    /// Ties the buffer to its element type without storing a `T`.
    _marker: PhantomData<T>,
}

// The raw-pointer field suppresses the automatic `Send`/`Sync` impls; these
// opt back in, bounded by the element type.
// NOTE(review): assumes no other owner touches the allocation concurrently
// (including device work still in flight) — confirm.
unsafe impl<T: DeviceRepr + Send> Send for ManagedBuffer<T> {}
unsafe impl<T: DeviceRepr + Sync> Sync for ManagedBuffer<T> {}
256
257impl<T: DeviceRepr> core::fmt::Debug for ManagedBuffer<T> {
258 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
259 f.debug_struct("ManagedBuffer")
260 .field("ptr", &self.ptr)
261 .field("len", &self.len)
262 .field("type", &core::any::type_name::<T>())
263 .finish()
264 }
265}
266
267impl<T: DeviceRepr> ManagedBuffer<T> {
268 pub fn new(len: usize) -> Result<Self> {
270 use baracuda_cuda_sys::runtime::types::cudaMemAttach;
271 Self::with_flags(len, cudaMemAttach::GLOBAL)
272 }
273
274 pub fn with_flags(len: usize, flags: u32) -> Result<Self> {
277 let r = runtime()?;
278 let cu = r.cuda_malloc_managed()?;
279 let bytes = len
280 .checked_mul(size_of::<T>())
281 .expect("overflow computing allocation size");
282 let mut ptr: *mut c_void = core::ptr::null_mut();
283 check(unsafe { cu(&mut ptr, bytes, flags) })?;
284 Ok(Self {
285 ptr: ptr as *mut T,
286 len,
287 _marker: PhantomData,
288 })
289 }
290
291 #[inline]
293 pub fn len(&self) -> usize {
294 self.len
295 }
296
297 #[inline]
298 pub fn is_empty(&self) -> bool {
299 self.len == 0
300 }
301
302 #[inline]
304 pub fn as_ptr(&self) -> *const T {
305 self.ptr
306 }
307
308 #[inline]
309 pub fn as_mut_ptr(&mut self) -> *mut T {
310 self.ptr
311 }
312
313 pub fn as_slice(&self) -> &[T] {
315 unsafe { core::slice::from_raw_parts(self.ptr, self.len) }
318 }
319
320 pub fn as_mut_slice(&mut self) -> &mut [T] {
321 unsafe { core::slice::from_raw_parts_mut(self.ptr, self.len) }
322 }
323}
324
325impl<T: DeviceRepr> Drop for ManagedBuffer<T> {
326 fn drop(&mut self) {
327 if self.ptr.is_null() {
328 return;
329 }
330 if let Ok(r) = runtime() {
331 if let Ok(cu) = r.cuda_free() {
332 let _ = unsafe { cu(self.ptr as *mut c_void) };
333 }
334 }
335 }
336}
337
/// Re-exports every item of `cudaHostAllocFlags` under a shorter path.
///
/// NOTE(review): presumably these are the values meant for the `flags`
/// argument of `PinnedHostBuffer::with_flags` — confirm.
pub mod pinned_flags {
    pub use baracuda_cuda_sys::runtime::types::cudaHostAllocFlags::*;
}
345
/// Typed handle to a host allocation obtained from the runtime's
/// `cuda_host_alloc` entry point, holding `len` elements of `T`.
///
/// Dereferences to `[T]` and is released through `cuda_free_host` on drop.
pub struct PinnedHostBuffer<T: DeviceRepr> {
    /// Typed pointer to the first element.
    ptr: *mut T,
    /// Number of `T` elements (not bytes).
    len: usize,
    /// Ties the buffer to its element type without storing a `T`.
    _marker: PhantomData<T>,
}

// The raw-pointer field suppresses the automatic `Send`/`Sync` impls; these
// opt back in, bounded by the element type.
// NOTE(review): assumes the allocation may be used and freed from any
// thread — confirm.
unsafe impl<T: DeviceRepr + Send> Send for PinnedHostBuffer<T> {}
unsafe impl<T: DeviceRepr + Sync> Sync for PinnedHostBuffer<T> {}
356
357impl<T: DeviceRepr> core::fmt::Debug for PinnedHostBuffer<T> {
358 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
359 f.debug_struct("PinnedHostBuffer")
360 .field("ptr", &self.ptr)
361 .field("len", &self.len)
362 .finish()
363 }
364}
365
366impl<T: DeviceRepr> PinnedHostBuffer<T> {
367 pub fn new(len: usize) -> Result<Self> {
369 Self::with_flags(len, 0)
370 }
371
372 pub fn with_flags(len: usize, flags: u32) -> Result<Self> {
374 let r = runtime()?;
375 let cu = r.cuda_host_alloc()?;
376 let bytes = len
377 .checked_mul(size_of::<T>())
378 .expect("overflow computing allocation size");
379 let mut ptr: *mut c_void = core::ptr::null_mut();
380 check(unsafe { cu(&mut ptr, bytes, flags) })?;
381 Ok(Self {
382 ptr: ptr as *mut T,
383 len,
384 _marker: PhantomData,
385 })
386 }
387
388 pub fn device_ptr(&self) -> Result<*mut c_void> {
391 let r = runtime()?;
392 let cu = r.cuda_host_get_device_pointer()?;
393 let mut dev: *mut c_void = core::ptr::null_mut();
394 check(unsafe { cu(&mut dev, self.ptr as *mut c_void, 0) })?;
395 Ok(dev)
396 }
397
398 pub fn flags(&self) -> Result<u32> {
400 let r = runtime()?;
401 let cu = r.cuda_host_get_flags()?;
402 let mut f: core::ffi::c_uint = 0;
403 check(unsafe { cu(&mut f, self.ptr as *mut c_void) })?;
404 Ok(f)
405 }
406
407 #[inline]
408 pub fn len(&self) -> usize {
409 self.len
410 }
411 #[inline]
412 pub fn is_empty(&self) -> bool {
413 self.len == 0
414 }
415 #[inline]
416 pub fn as_ptr(&self) -> *const T {
417 self.ptr
418 }
419 #[inline]
420 pub fn as_mut_ptr(&mut self) -> *mut T {
421 self.ptr
422 }
423}
424
425impl<T: DeviceRepr> core::ops::Deref for PinnedHostBuffer<T> {
426 type Target = [T];
427 fn deref(&self) -> &[T] {
428 unsafe { core::slice::from_raw_parts(self.ptr, self.len) }
429 }
430}
431
432impl<T: DeviceRepr> core::ops::DerefMut for PinnedHostBuffer<T> {
433 fn deref_mut(&mut self) -> &mut [T] {
434 unsafe { core::slice::from_raw_parts_mut(self.ptr, self.len) }
435 }
436}
437
438impl<T: DeviceRepr> Drop for PinnedHostBuffer<T> {
439 fn drop(&mut self) {
440 if self.ptr.is_null() {
441 return;
442 }
443 if let Ok(r) = runtime() {
444 if let Ok(cu) = r.cuda_free_host() {
445 let _ = unsafe { cu(self.ptr as *mut c_void) };
446 }
447 }
448 }
449}
450
/// Guard representing a caller-owned slice that has been passed to the
/// runtime's `cuda_host_register` entry point.
///
/// Dropping the guard calls `cuda_host_unregister` on the same pointer.
pub struct PinnedRegistration<'a, T: DeviceRepr> {
    /// Start of the registered slice.
    ptr: *mut T,
    /// Number of `T` elements in the registered slice.
    len: usize,
    /// Carries the `&'a mut [T]` borrow taken at registration, so the slice
    /// stays exclusively borrowed for as long as the guard lives.
    _borrow: PhantomData<&'a mut [T]>,
}

// The raw-pointer field suppresses the automatic `Send` impl; this opts back
// in whenever the element type is `Send`.
// NOTE(review): assumes unregistering from another thread is permitted by
// the runtime — confirm.
unsafe impl<T: DeviceRepr + Send> Send for PinnedRegistration<'_, T> {}
460
461impl<T: DeviceRepr> core::fmt::Debug for PinnedRegistration<'_, T> {
462 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
463 f.debug_struct("PinnedRegistration")
464 .field("ptr", &self.ptr)
465 .field("len", &self.len)
466 .finish()
467 }
468}
469
470impl<'a, T: DeviceRepr> PinnedRegistration<'a, T> {
471 pub fn register(slice: &'a mut [T]) -> Result<Self> {
473 Self::register_with_flags(slice, 0)
474 }
475
476 pub fn register_with_flags(slice: &'a mut [T], flags: u32) -> Result<Self> {
477 let r = runtime()?;
478 let cu = r.cuda_host_register()?;
479 check(unsafe {
480 cu(
481 slice.as_mut_ptr() as *mut c_void,
482 core::mem::size_of_val(slice),
483 flags,
484 )
485 })?;
486 Ok(Self {
487 ptr: slice.as_mut_ptr(),
488 len: slice.len(),
489 _borrow: PhantomData,
490 })
491 }
492
493 #[inline]
494 pub fn len(&self) -> usize {
495 self.len
496 }
497 #[inline]
498 pub fn is_empty(&self) -> bool {
499 self.len == 0
500 }
501}
502
503impl<T: DeviceRepr> Drop for PinnedRegistration<'_, T> {
504 fn drop(&mut self) {
505 if self.ptr.is_null() {
506 return;
507 }
508 if let Ok(r) = runtime() {
509 if let Ok(cu) = r.cuda_host_unregister() {
510 let _ = unsafe { cu(self.ptr as *mut c_void) };
511 }
512 }
513 }
514}
515
516impl<T: DeviceRepr> DeviceBuffer<T> {
519 pub fn new_async(len: usize, stream: &Stream) -> Result<Self> {
522 let r = runtime()?;
523 let cu = r.cuda_malloc_async()?;
524 let bytes = len
525 .checked_mul(size_of::<T>())
526 .expect("overflow computing allocation size");
527 let mut ptr: *mut c_void = core::ptr::null_mut();
528 check(unsafe { cu(&mut ptr, bytes, stream.as_raw()) })?;
529 Ok(Self {
530 ptr,
531 len,
532 _marker: PhantomData,
533 })
534 }
535
536 pub fn free_async(mut self, stream: &Stream) -> Result<()> {
539 let ptr = core::mem::replace(&mut self.ptr, core::ptr::null_mut());
540 if ptr.is_null() {
541 return Ok(());
542 }
543 let r = runtime()?;
544 let cu = r.cuda_free_async()?;
545 check(unsafe { cu(ptr, stream.as_raw()) })
546 }
547
548 pub fn memset_async(&self, value: u8, stream: &Stream) -> Result<()> {
550 let r = runtime()?;
551 let cu = r.cuda_memset_async()?;
552 let bytes = self.len * size_of::<T>();
553 check(unsafe { cu(self.ptr, value as core::ffi::c_int, bytes, stream.as_raw()) })
554 }
555}
556
557pub fn memcpy_peer<T: DeviceRepr>(
562 dst: &DeviceBuffer<T>,
563 dst_device: &crate::Device,
564 src: &DeviceBuffer<T>,
565 src_device: &crate::Device,
566) -> Result<()> {
567 assert_eq!(dst.len(), src.len());
568 let r = runtime()?;
569 let cu = r.cuda_memcpy_peer()?;
570 let bytes = src.len() * size_of::<T>();
571 check(unsafe {
572 cu(
573 dst.as_raw(),
574 dst_device.ordinal(),
575 src.as_raw(),
576 src_device.ordinal(),
577 bytes,
578 )
579 })
580}
581
582pub fn memcpy_peer_async<T: DeviceRepr>(
584 dst: &DeviceBuffer<T>,
585 dst_device: &crate::Device,
586 src: &DeviceBuffer<T>,
587 src_device: &crate::Device,
588 stream: &Stream,
589) -> Result<()> {
590 assert_eq!(dst.len(), src.len());
591 let r = runtime()?;
592 let cu = r.cuda_memcpy_peer_async()?;
593 let bytes = src.len() * size_of::<T>();
594 check(unsafe {
595 cu(
596 dst.as_raw(),
597 dst_device.ordinal(),
598 src.as_raw(),
599 src_device.ordinal(),
600 bytes,
601 stream.as_raw(),
602 )
603 })
604}