Skip to main content

baracuda_driver/
pinned.rs

//! Pinned (page-locked) host memory.
//!
//! Regular Rust allocations are pageable — the operating system can move or
//! swap them out, so an async CUDA memcpy from such memory has to be staged
//! through a pinned buffer owned by the driver. Allocating the host buffer
//! pinned from the start removes that staging step and lets HtoD/DtoH copies
//! genuinely overlap with compute.
//!
//! Two flavors:
//!
//! - [`PinnedBuffer<T>`] — a pinned allocation owned by CUDA
//!   (`cuMemHostAlloc` + `cuMemFreeHost`).
//! - [`PinnedRegistration`] — pins an existing Rust allocation in place
//!   (`cuMemHostRegister` + `cuMemHostUnregister`).
14
15use core::ffi::c_void;
16use core::mem::size_of;
17use core::ops::{Deref, DerefMut};
18
19use baracuda_cuda_sys::{driver, CUdeviceptr};
20use baracuda_types::DeviceRepr;
21
22use crate::context::Context;
23use crate::error::{check, Result};
24
/// Flag bits for [`PinnedBuffer::with_flags`] (`cuMemHostAlloc`) and
/// [`PinnedRegistration::register_with_flags`] (`cuMemHostRegister`).
///
/// Combine bits with `|`:
///
/// - `PORTABLE` makes the pinned pages visible to every CUDA context in the
///   process.
/// - `DEVICEMAP` maps the allocation into device address space so kernels can
///   access it directly (zero-copy). Fetch the device address with
///   `device_ptr()`.
/// - `WRITECOMBINED` uses a write-combined caching mode. This speeds up HtoD
///   transfers at the cost of very slow host reads.
///
/// `PORTABLE` and `DEVICEMAP` have the same value and meaning for both driver
/// calls. `WRITECOMBINED` is only meaningful for `cuMemHostAlloc`. The same
/// bit (`0x04`) passed to `cuMemHostRegister` is
/// `CU_MEMHOSTREGISTER_IOMEMORY`, which means something entirely different.
pub mod flags {
    /// `CU_MEMHOSTALLOC_PORTABLE` / `CU_MEMHOSTREGISTER_PORTABLE`.
    pub const PORTABLE: u32 = 0x01;
    /// `CU_MEMHOSTALLOC_DEVICEMAP` / `CU_MEMHOSTREGISTER_DEVICEMAP`.
    pub const DEVICEMAP: u32 = 0x02;
    /// `CU_MEMHOSTALLOC_WRITECOMBINED`.
    ///
    /// Allocation only: for `cuMemHostRegister` this bit means `IOMEMORY`.
    pub const WRITECOMBINED: u32 = 0x04;
}
37
38/// A pinned host allocation owned by CUDA. `Deref`s to `&[T]` / `&mut [T]`.
39pub struct PinnedBuffer<T: DeviceRepr> {
40    ptr: *mut T,
41    len: usize,
42    #[allow(dead_code)]
43    context: Context,
44    _marker: core::marker::PhantomData<T>,
45}
46
47unsafe impl<T: DeviceRepr + Send> Send for PinnedBuffer<T> {}
48unsafe impl<T: DeviceRepr + Sync> Sync for PinnedBuffer<T> {}
49
50impl<T: DeviceRepr> core::fmt::Debug for PinnedBuffer<T> {
51    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
52        f.debug_struct("PinnedBuffer")
53            .field("ptr", &self.ptr)
54            .field("len", &self.len)
55            .field("type", &core::any::type_name::<T>())
56            .finish()
57    }
58}
59
60impl<T: DeviceRepr> PinnedBuffer<T> {
61    /// Allocate `len` pinned elements with flags = 0.
62    pub fn new(context: &Context, len: usize) -> Result<Self> {
63        Self::with_flags(context, len, 0)
64    }
65
66    /// Allocate `len` pinned elements, passing `flags` to `cuMemHostAlloc`
67    /// (see the [`flags`] module).
68    ///
69    /// Zero-length allocations (`len == 0` or a zero-sized `T`) short-circuit:
70    /// CUDA rejects 0-byte `cuMemHostAlloc` with a null pointer, which would
71    /// make [`Deref`]'s `slice::from_raw_parts(null, 0)` unsound. Instead we
72    /// use a dangling-but-aligned pointer (the same trick `Vec::new()` uses
73    /// internally), skip the CUDA call entirely, and [`Drop`] recognizes
74    /// the sentinel and skips `cuMemFreeHost`.
75    pub fn with_flags(context: &Context, len: usize, flags: u32) -> Result<Self> {
76        let bytes = len
77            .checked_mul(size_of::<T>())
78            .expect("overflow in PinnedBuffer byte count");
79        if bytes == 0 {
80            return Ok(Self {
81                ptr: core::ptr::NonNull::<T>::dangling().as_ptr(),
82                len,
83                context: context.clone(),
84                _marker: core::marker::PhantomData,
85            });
86        }
87        context.set_current()?;
88        let d = driver()?;
89        let cu = d.cu_mem_host_alloc()?;
90        let mut p: *mut c_void = core::ptr::null_mut();
91        check(unsafe { cu(&mut p, bytes, flags) })?;
92        Ok(Self {
93            ptr: p as *mut T,
94            len,
95            context: context.clone(),
96            _marker: core::marker::PhantomData,
97        })
98    }
99
100    /// Retrieve the device-visible pointer corresponding to this host
101    /// allocation. Requires the buffer was created with the `DEVICEMAP`
102    /// flag (or the process has `CU_CTX_MAP_HOST` enabled).
103    ///
104    /// On an empty buffer, returns `CUdeviceptr(0)` — there's no real
105    /// allocation to map. This mirrors [`crate::DeviceBuffer`]'s null
106    /// sentinel for zero-length device buffers.
107    pub fn device_ptr(&self) -> Result<CUdeviceptr> {
108        if self.len == 0 {
109            return Ok(CUdeviceptr(0));
110        }
111        let d = driver()?;
112        let cu = d.cu_mem_host_get_device_pointer()?;
113        let mut dptr = CUdeviceptr(0);
114        check(unsafe { cu(&mut dptr, self.ptr as *mut c_void, 0) })?;
115        Ok(dptr)
116    }
117
118    /// Query the flags the pinned allocation was created with. On an
119    /// empty buffer, returns `0` since there's no real allocation to
120    /// query.
121    pub fn flags(&self) -> Result<u32> {
122        if self.len == 0 {
123            return Ok(0);
124        }
125        let d = driver()?;
126        let cu = d.cu_mem_host_get_flags()?;
127        let mut flags: core::ffi::c_uint = 0;
128        check(unsafe { cu(&mut flags, self.ptr as *mut c_void) })?;
129        Ok(flags)
130    }
131
132    #[inline]
133    pub fn len(&self) -> usize {
134        self.len
135    }
136    #[inline]
137    pub fn is_empty(&self) -> bool {
138        self.len == 0
139    }
140    #[inline]
141    pub fn as_ptr(&self) -> *const T {
142        self.ptr
143    }
144    #[inline]
145    pub fn as_mut_ptr(&mut self) -> *mut T {
146        self.ptr
147    }
148}
149
150impl<T: DeviceRepr> Deref for PinnedBuffer<T> {
151    type Target = [T];
152    fn deref(&self) -> &[T] {
153        // SAFETY: ptr is live for len elements until Drop.
154        unsafe { core::slice::from_raw_parts(self.ptr, self.len) }
155    }
156}
157
158impl<T: DeviceRepr> DerefMut for PinnedBuffer<T> {
159    fn deref_mut(&mut self) -> &mut [T] {
160        unsafe { core::slice::from_raw_parts_mut(self.ptr, self.len) }
161    }
162}
163
164impl<T: DeviceRepr> Drop for PinnedBuffer<T> {
165    fn drop(&mut self) {
166        // Zero-length buffers use `NonNull::dangling()` as a sentinel —
167        // there's nothing to free. We also guard against a null ptr just
168        // in case some future constructor stores one.
169        if self.len == 0 || self.ptr.is_null() {
170            return;
171        }
172        if let Ok(d) = driver() {
173            if let Ok(cu) = d.cu_mem_free_host() {
174                let _ = unsafe { cu(self.ptr as *mut c_void) };
175            }
176        }
177    }
178}
179
180/// Pin an existing Rust slice for the lifetime of this guard. Unpin on drop.
181///
182/// Use this when the host buffer already exists (e.g. from a `Vec<T>`) and
183/// you want `cuMemcpy*Async` to fast-path on it.
184pub struct PinnedRegistration<'a, T: DeviceRepr> {
185    ptr: *mut T,
186    len: usize,
187    _borrow: core::marker::PhantomData<&'a mut [T]>,
188}
189
190unsafe impl<T: DeviceRepr + Send> Send for PinnedRegistration<'_, T> {}
191unsafe impl<T: DeviceRepr + Sync> Sync for PinnedRegistration<'_, T> {}
192
193impl<'a, T: DeviceRepr> PinnedRegistration<'a, T> {
194    /// Pin `slice` with flags = 0.
195    pub fn register(slice: &'a mut [T]) -> Result<Self> {
196        Self::register_with_flags(slice, 0)
197    }
198
199    pub fn register_with_flags(slice: &'a mut [T], flags: u32) -> Result<Self> {
200        let d = driver()?;
201        let cu = d.cu_mem_host_register()?;
202        let bytes = core::mem::size_of_val(slice);
203        check(unsafe { cu(slice.as_mut_ptr() as *mut c_void, bytes, flags) })?;
204        Ok(Self {
205            ptr: slice.as_mut_ptr(),
206            len: slice.len(),
207            _borrow: core::marker::PhantomData,
208        })
209    }
210
211    /// Device-side pointer aliasing this pinned region (requires `DEVICEMAP`).
212    pub fn device_ptr(&self) -> Result<CUdeviceptr> {
213        let d = driver()?;
214        let cu = d.cu_mem_host_get_device_pointer()?;
215        let mut dptr = CUdeviceptr(0);
216        check(unsafe { cu(&mut dptr, self.ptr as *mut c_void, 0) })?;
217        Ok(dptr)
218    }
219
220    #[inline]
221    pub fn len(&self) -> usize {
222        self.len
223    }
224    #[inline]
225    pub fn is_empty(&self) -> bool {
226        self.len == 0
227    }
228}
229
230impl<T: DeviceRepr> core::fmt::Debug for PinnedRegistration<'_, T> {
231    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
232        f.debug_struct("PinnedRegistration")
233            .field("ptr", &self.ptr)
234            .field("len", &self.len)
235            .finish()
236    }
237}
238
239impl<T: DeviceRepr> Drop for PinnedRegistration<'_, T> {
240    fn drop(&mut self) {
241        if self.ptr.is_null() {
242            return;
243        }
244        if let Ok(d) = driver() {
245            if let Ok(cu) = d.cu_mem_host_unregister() {
246                let _ = unsafe { cu(self.ptr as *mut c_void) };
247            }
248        }
249    }
250}