Skip to main content

baracuda_runtime/
memory.rs

1//! Runtime-API device memory.
2
3use core::ffi::c_void;
4use core::marker::PhantomData;
5use core::mem::size_of;
6
7use baracuda_cuda_sys::runtime::{cudaMemcpyKind, runtime};
8use baracuda_types::DeviceRepr;
9
10use crate::error::{check, Result};
11use crate::stream::Stream;
12
13/// Owned, typed allocation of device memory (Runtime API).
14pub struct DeviceBuffer<T: DeviceRepr> {
15    ptr: *mut c_void,
16    len: usize,
17    _marker: PhantomData<T>,
18}
19
20unsafe impl<T: DeviceRepr + Send> Send for DeviceBuffer<T> {}
21
22impl<T: DeviceRepr> core::fmt::Debug for DeviceBuffer<T> {
23    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
24        f.debug_struct("DeviceBuffer")
25            .field("ptr", &self.ptr)
26            .field("len", &self.len)
27            .field("type", &core::any::type_name::<T>())
28            .finish()
29    }
30}
31
32impl<T: DeviceRepr> DeviceBuffer<T> {
33    /// Allocate an uninitialized buffer of `len` elements on the current device.
34    pub fn new(len: usize) -> Result<Self> {
35        let r = runtime()?;
36        let cu = r.cuda_malloc()?;
37        let bytes = len
38            .checked_mul(size_of::<T>())
39            .expect("overflow computing allocation size");
40        let mut ptr: *mut c_void = core::ptr::null_mut();
41        check(unsafe { cu(&mut ptr, bytes) })?;
42        Ok(Self {
43            ptr,
44            len,
45            _marker: PhantomData,
46        })
47    }
48
49    /// Allocate and zero-fill.
50    pub fn zeros(len: usize) -> Result<Self> {
51        let buf = Self::new(len)?;
52        let r = runtime()?;
53        let cu = r.cuda_memset()?;
54        let bytes = len * size_of::<T>();
55        check(unsafe { cu(buf.ptr, 0, bytes) })?;
56        Ok(buf)
57    }
58
59    /// Allocate and synchronously copy `src` from host memory.
60    pub fn from_slice(src: &[T]) -> Result<Self> {
61        let buf = Self::new(src.len())?;
62        buf.copy_from_host(src)?;
63        Ok(buf)
64    }
65
66    /// Synchronous H2D copy.
67    pub fn copy_from_host(&self, src: &[T]) -> Result<()> {
68        assert_eq!(src.len(), self.len);
69        let r = runtime()?;
70        let cu = r.cuda_memcpy()?;
71        let bytes = self.len * size_of::<T>();
72        check(unsafe {
73            cu(
74                self.ptr,
75                src.as_ptr() as *const c_void,
76                bytes,
77                cudaMemcpyKind::HostToDevice,
78            )
79        })
80    }
81
82    /// Synchronous D2H copy.
83    pub fn copy_to_host(&self, dst: &mut [T]) -> Result<()> {
84        assert_eq!(dst.len(), self.len);
85        let r = runtime()?;
86        let cu = r.cuda_memcpy()?;
87        let bytes = self.len * size_of::<T>();
88        check(unsafe {
89            cu(
90                dst.as_mut_ptr() as *mut c_void,
91                self.ptr,
92                bytes,
93                cudaMemcpyKind::DeviceToHost,
94            )
95        })
96    }
97
98    /// Asynchronous H2D copy on `stream`.
99    pub fn copy_from_host_async(&self, src: &[T], stream: &Stream) -> Result<()> {
100        assert_eq!(src.len(), self.len);
101        let r = runtime()?;
102        let cu = r.cuda_memcpy_async()?;
103        let bytes = self.len * size_of::<T>();
104        check(unsafe {
105            cu(
106                self.ptr,
107                src.as_ptr() as *const c_void,
108                bytes,
109                cudaMemcpyKind::HostToDevice,
110                stream.as_raw(),
111            )
112        })
113    }
114
115    /// Asynchronous D2H copy on `stream`.
116    pub fn copy_to_host_async(&self, dst: &mut [T], stream: &Stream) -> Result<()> {
117        assert_eq!(dst.len(), self.len);
118        let r = runtime()?;
119        let cu = r.cuda_memcpy_async()?;
120        let bytes = self.len * size_of::<T>();
121        check(unsafe {
122            cu(
123                dst.as_mut_ptr() as *mut c_void,
124                self.ptr,
125                bytes,
126                cudaMemcpyKind::DeviceToHost,
127                stream.as_raw(),
128            )
129        })
130    }
131
132    /// Number of elements.
133    #[inline]
134    pub fn len(&self) -> usize {
135        self.len
136    }
137
138    /// Size in bytes.
139    #[inline]
140    pub fn byte_size(&self) -> usize {
141        self.len * size_of::<T>()
142    }
143
144    /// `true` if zero elements.
145    #[inline]
146    pub fn is_empty(&self) -> bool {
147        self.len == 0
148    }
149
150    /// Raw device pointer. Use with care.
151    #[inline]
152    pub fn as_raw(&self) -> *mut c_void {
153        self.ptr
154    }
155
156    /// Raw device pointer as the u64 value kernels expect. Convenience
157    /// wrapper around [`as_raw`](Self::as_raw).
158    #[inline]
159    pub fn as_device_ptr(&self) -> u64 {
160        self.ptr as u64
161    }
162}
163
164impl<T: DeviceRepr> Drop for DeviceBuffer<T> {
165    fn drop(&mut self) {
166        if self.ptr.is_null() {
167            return;
168        }
169        if let Ok(r) = runtime() {
170            if let Ok(cu) = r.cuda_free() {
171                let _ = unsafe { cu(self.ptr) };
172            }
173        }
174    }
175}
176
177// ---- Mem info / prefetch / advise ----------------------------------------
178
179/// `cudaMemGetInfo` — `(free, total)` bytes on the current device.
180pub fn mem_get_info() -> Result<(u64, u64)> {
181    let r = runtime()?;
182    let cu = r.cuda_mem_get_info()?;
183    let mut free: usize = 0;
184    let mut total: usize = 0;
185    check(unsafe { cu(&mut free, &mut total) })?;
186    Ok((free as u64, total as u64))
187}
188
/// Destination for [`mem_prefetch_async`] and the device argument of
/// [`mem_advise`].
///
/// The CUDA Runtime's v1 entry points take a device ordinal, and the host is
/// encoded as `cudaCpuDeviceId` (-1).
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum PrefetchTarget {
    /// A specific CUDA device, identified by its ordinal.
    Device(i32),
    /// The host CPU.
    Host,
}

impl PrefetchTarget {
    /// `cudaCpuDeviceId`: the pseudo device ordinal CUDA uses for the host.
    const CPU_DEVICE_ID: i32 = -1;

    /// Ordinal passed to the runtime for this target.
    #[inline]
    fn as_raw(self) -> i32 {
        match self {
            Self::Host => Self::CPU_DEVICE_ID,
            Self::Device(ordinal) => ordinal,
        }
    }
}
208
209/// Prefetch `count` bytes of unified memory at `dev_ptr` to `target`,
210/// ordered on `stream`. `dev_ptr` must be a managed-memory allocation
211/// (from [`ManagedBuffer`] or `cudaMallocManaged`).
212///
213/// # Safety
214///
215/// `dev_ptr..dev_ptr+count` must be a live managed allocation.
216pub unsafe fn mem_prefetch_async(
217    dev_ptr: *const core::ffi::c_void,
218    count: usize,
219    target: PrefetchTarget,
220    stream: &Stream,
221) -> Result<()> {
222    let r = runtime()?;
223    let cu = r.cuda_mem_prefetch_async()?;
224    check(cu(dev_ptr, count, target.as_raw(), stream.as_raw()))
225}
226
227/// `cudaMemAdvise` — unified-memory placement hint. `advice` is a
228/// constant from [`baracuda_cuda_sys::runtime::types::cudaMemoryAdvise`].
229///
230/// # Safety
231///
232/// `dev_ptr..dev_ptr+count` must be a live managed allocation.
233pub unsafe fn mem_advise(
234    dev_ptr: *const core::ffi::c_void,
235    count: usize,
236    advice: i32,
237    target: PrefetchTarget,
238) -> Result<()> {
239    let r = runtime()?;
240    let cu = r.cuda_mem_advise()?;
241    check(cu(dev_ptr, count, advice, target.as_raw()))
242}
243
244// ---- Managed memory -------------------------------------------------------
245
246/// Unified managed-memory buffer — allocated via `cudaMallocManaged`.
247/// Accessible from both host and device without explicit copies.
248pub struct ManagedBuffer<T: DeviceRepr> {
249    ptr: *mut T,
250    len: usize,
251    _marker: PhantomData<T>,
252}
253
254unsafe impl<T: DeviceRepr + Send> Send for ManagedBuffer<T> {}
255unsafe impl<T: DeviceRepr + Sync> Sync for ManagedBuffer<T> {}
256
257impl<T: DeviceRepr> core::fmt::Debug for ManagedBuffer<T> {
258    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
259        f.debug_struct("ManagedBuffer")
260            .field("ptr", &self.ptr)
261            .field("len", &self.len)
262            .field("type", &core::any::type_name::<T>())
263            .finish()
264    }
265}
266
267impl<T: DeviceRepr> ManagedBuffer<T> {
268    /// Allocate `len` managed elements with the default attach (`GLOBAL`).
269    pub fn new(len: usize) -> Result<Self> {
270        use baracuda_cuda_sys::runtime::types::cudaMemAttach;
271        Self::with_flags(len, cudaMemAttach::GLOBAL)
272    }
273
274    /// Allocate with explicit attach flags (see
275    /// [`baracuda_cuda_sys::runtime::types::cudaMemAttach`]).
276    pub fn with_flags(len: usize, flags: u32) -> Result<Self> {
277        let r = runtime()?;
278        let cu = r.cuda_malloc_managed()?;
279        let bytes = len
280            .checked_mul(size_of::<T>())
281            .expect("overflow computing allocation size");
282        let mut ptr: *mut c_void = core::ptr::null_mut();
283        check(unsafe { cu(&mut ptr, bytes, flags) })?;
284        Ok(Self {
285            ptr: ptr as *mut T,
286            len,
287            _marker: PhantomData,
288        })
289    }
290
291    /// Number of elements.
292    #[inline]
293    pub fn len(&self) -> usize {
294        self.len
295    }
296
297    #[inline]
298    pub fn is_empty(&self) -> bool {
299        self.len == 0
300    }
301
302    /// Raw pointer — usable from both host and device code.
303    #[inline]
304    pub fn as_ptr(&self) -> *const T {
305        self.ptr
306    }
307
308    #[inline]
309    pub fn as_mut_ptr(&mut self) -> *mut T {
310        self.ptr
311    }
312
313    /// Access as a host slice (synchronizes through device cache on access).
314    pub fn as_slice(&self) -> &[T] {
315        // SAFETY: ptr is live for len elements; managed memory is
316        // host-accessible on supported platforms.
317        unsafe { core::slice::from_raw_parts(self.ptr, self.len) }
318    }
319
320    pub fn as_mut_slice(&mut self) -> &mut [T] {
321        unsafe { core::slice::from_raw_parts_mut(self.ptr, self.len) }
322    }
323}
324
325impl<T: DeviceRepr> Drop for ManagedBuffer<T> {
326    fn drop(&mut self) {
327        if self.ptr.is_null() {
328            return;
329        }
330        if let Ok(r) = runtime() {
331            if let Ok(cu) = r.cuda_free() {
332                let _ = unsafe { cu(self.ptr as *mut c_void) };
333            }
334        }
335    }
336}
337
338// ---- Pinned host memory --------------------------------------------------
339
340/// Flags for `cudaHostAlloc`. See
341/// [`baracuda_cuda_sys::runtime::types::cudaHostAllocFlags`] for raw values.
342pub mod pinned_flags {
343    pub use baracuda_cuda_sys::runtime::types::cudaHostAllocFlags::*;
344}
345
346/// Pinned (page-locked) host allocation — CUDA-owned memory that supports
347/// real async H↔D copies without staging.
348pub struct PinnedHostBuffer<T: DeviceRepr> {
349    ptr: *mut T,
350    len: usize,
351    _marker: PhantomData<T>,
352}
353
354unsafe impl<T: DeviceRepr + Send> Send for PinnedHostBuffer<T> {}
355unsafe impl<T: DeviceRepr + Sync> Sync for PinnedHostBuffer<T> {}
356
357impl<T: DeviceRepr> core::fmt::Debug for PinnedHostBuffer<T> {
358    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
359        f.debug_struct("PinnedHostBuffer")
360            .field("ptr", &self.ptr)
361            .field("len", &self.len)
362            .finish()
363    }
364}
365
366impl<T: DeviceRepr> PinnedHostBuffer<T> {
367    /// Allocate `len` pinned elements with default flags.
368    pub fn new(len: usize) -> Result<Self> {
369        Self::with_flags(len, 0)
370    }
371
372    /// Allocate with `cudaHostAllocFlags` bitmask.
373    pub fn with_flags(len: usize, flags: u32) -> Result<Self> {
374        let r = runtime()?;
375        let cu = r.cuda_host_alloc()?;
376        let bytes = len
377            .checked_mul(size_of::<T>())
378            .expect("overflow computing allocation size");
379        let mut ptr: *mut c_void = core::ptr::null_mut();
380        check(unsafe { cu(&mut ptr, bytes, flags) })?;
381        Ok(Self {
382            ptr: ptr as *mut T,
383            len,
384            _marker: PhantomData,
385        })
386    }
387
388    /// Device-side pointer that aliases this pinned region (requires
389    /// `MAPPED` flag at alloc time).
390    pub fn device_ptr(&self) -> Result<*mut c_void> {
391        let r = runtime()?;
392        let cu = r.cuda_host_get_device_pointer()?;
393        let mut dev: *mut c_void = core::ptr::null_mut();
394        check(unsafe { cu(&mut dev, self.ptr as *mut c_void, 0) })?;
395        Ok(dev)
396    }
397
398    /// Query the flags this buffer was created with.
399    pub fn flags(&self) -> Result<u32> {
400        let r = runtime()?;
401        let cu = r.cuda_host_get_flags()?;
402        let mut f: core::ffi::c_uint = 0;
403        check(unsafe { cu(&mut f, self.ptr as *mut c_void) })?;
404        Ok(f)
405    }
406
407    #[inline]
408    pub fn len(&self) -> usize {
409        self.len
410    }
411    #[inline]
412    pub fn is_empty(&self) -> bool {
413        self.len == 0
414    }
415    #[inline]
416    pub fn as_ptr(&self) -> *const T {
417        self.ptr
418    }
419    #[inline]
420    pub fn as_mut_ptr(&mut self) -> *mut T {
421        self.ptr
422    }
423}
424
425impl<T: DeviceRepr> core::ops::Deref for PinnedHostBuffer<T> {
426    type Target = [T];
427    fn deref(&self) -> &[T] {
428        unsafe { core::slice::from_raw_parts(self.ptr, self.len) }
429    }
430}
431
432impl<T: DeviceRepr> core::ops::DerefMut for PinnedHostBuffer<T> {
433    fn deref_mut(&mut self) -> &mut [T] {
434        unsafe { core::slice::from_raw_parts_mut(self.ptr, self.len) }
435    }
436}
437
438impl<T: DeviceRepr> Drop for PinnedHostBuffer<T> {
439    fn drop(&mut self) {
440        if self.ptr.is_null() {
441            return;
442        }
443        if let Ok(r) = runtime() {
444            if let Ok(cu) = r.cuda_free_host() {
445                let _ = unsafe { cu(self.ptr as *mut c_void) };
446            }
447        }
448    }
449}
450
451/// RAII guard for `cudaHostRegister` — pins an existing host slice and
452/// unregisters on drop.
453pub struct PinnedRegistration<'a, T: DeviceRepr> {
454    ptr: *mut T,
455    len: usize,
456    _borrow: PhantomData<&'a mut [T]>,
457}
458
459unsafe impl<T: DeviceRepr + Send> Send for PinnedRegistration<'_, T> {}
460
461impl<T: DeviceRepr> core::fmt::Debug for PinnedRegistration<'_, T> {
462    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
463        f.debug_struct("PinnedRegistration")
464            .field("ptr", &self.ptr)
465            .field("len", &self.len)
466            .finish()
467    }
468}
469
470impl<'a, T: DeviceRepr> PinnedRegistration<'a, T> {
471    /// Pin `slice` with `flags = 0` until the guard drops.
472    pub fn register(slice: &'a mut [T]) -> Result<Self> {
473        Self::register_with_flags(slice, 0)
474    }
475
476    pub fn register_with_flags(slice: &'a mut [T], flags: u32) -> Result<Self> {
477        let r = runtime()?;
478        let cu = r.cuda_host_register()?;
479        check(unsafe {
480            cu(
481                slice.as_mut_ptr() as *mut c_void,
482                core::mem::size_of_val(slice),
483                flags,
484            )
485        })?;
486        Ok(Self {
487            ptr: slice.as_mut_ptr(),
488            len: slice.len(),
489            _borrow: PhantomData,
490        })
491    }
492
493    #[inline]
494    pub fn len(&self) -> usize {
495        self.len
496    }
497    #[inline]
498    pub fn is_empty(&self) -> bool {
499        self.len == 0
500    }
501}
502
503impl<T: DeviceRepr> Drop for PinnedRegistration<'_, T> {
504    fn drop(&mut self) {
505        if self.ptr.is_null() {
506            return;
507        }
508        if let Ok(r) = runtime() {
509            if let Ok(cu) = r.cuda_host_unregister() {
510                let _ = unsafe { cu(self.ptr as *mut c_void) };
511            }
512        }
513    }
514}
515
516// ---- Async alloc / free --------------------------------------------------
517
518impl<T: DeviceRepr> DeviceBuffer<T> {
519    /// Asynchronously allocate `len` elements on `stream` from the device's
520    /// default memory pool (CUDA 11.2+).
521    pub fn new_async(len: usize, stream: &Stream) -> Result<Self> {
522        let r = runtime()?;
523        let cu = r.cuda_malloc_async()?;
524        let bytes = len
525            .checked_mul(size_of::<T>())
526            .expect("overflow computing allocation size");
527        let mut ptr: *mut c_void = core::ptr::null_mut();
528        check(unsafe { cu(&mut ptr, bytes, stream.as_raw()) })?;
529        Ok(Self {
530            ptr,
531            len,
532            _marker: PhantomData,
533        })
534    }
535
536    /// Free this buffer asynchronously on `stream`. Consumes `self` so
537    /// the sync `Drop` does not also free.
538    pub fn free_async(mut self, stream: &Stream) -> Result<()> {
539        let ptr = core::mem::replace(&mut self.ptr, core::ptr::null_mut());
540        if ptr.is_null() {
541            return Ok(());
542        }
543        let r = runtime()?;
544        let cu = r.cuda_free_async()?;
545        check(unsafe { cu(ptr, stream.as_raw()) })
546    }
547
548    /// Asynchronous memset of `self` to byte value `value` on `stream`.
549    pub fn memset_async(&self, value: u8, stream: &Stream) -> Result<()> {
550        let r = runtime()?;
551        let cu = r.cuda_memset_async()?;
552        let bytes = self.len * size_of::<T>();
553        check(unsafe { cu(self.ptr, value as core::ffi::c_int, bytes, stream.as_raw()) })
554    }
555}
556
557// ---- Peer memcpy ---------------------------------------------------------
558
559/// Peer-to-peer device memory copy. Both buffers must be on enabled-peer
560/// devices (see [`crate::Device::enable_peer_access`]).
561pub fn memcpy_peer<T: DeviceRepr>(
562    dst: &DeviceBuffer<T>,
563    dst_device: &crate::Device,
564    src: &DeviceBuffer<T>,
565    src_device: &crate::Device,
566) -> Result<()> {
567    assert_eq!(dst.len(), src.len());
568    let r = runtime()?;
569    let cu = r.cuda_memcpy_peer()?;
570    let bytes = src.len() * size_of::<T>();
571    check(unsafe {
572        cu(
573            dst.as_raw(),
574            dst_device.ordinal(),
575            src.as_raw(),
576            src_device.ordinal(),
577            bytes,
578        )
579    })
580}
581
582/// Async peer-to-peer memcpy ordered on `stream`.
583pub fn memcpy_peer_async<T: DeviceRepr>(
584    dst: &DeviceBuffer<T>,
585    dst_device: &crate::Device,
586    src: &DeviceBuffer<T>,
587    src_device: &crate::Device,
588    stream: &Stream,
589) -> Result<()> {
590    assert_eq!(dst.len(), src.len());
591    let r = runtime()?;
592    let cu = r.cuda_memcpy_peer_async()?;
593    let bytes = src.len() * size_of::<T>();
594    check(unsafe {
595        cu(
596            dst.as_raw(),
597            dst_device.ordinal(),
598            src.as_raw(),
599            src_device.ordinal(),
600            bytes,
601            stream.as_raw(),
602        )
603    })
604}