baracuda_driver/
pinned.rs1use core::ffi::c_void;
16use core::mem::size_of;
17use core::ops::{Deref, DerefMut};
18
19use baracuda_cuda_sys::{driver, CUdeviceptr};
20use baracuda_types::DeviceRepr;
21
22use crate::context::Context;
23use crate::error::{check, Result};
24
/// Flag bits for [`PinnedBuffer::with_flags`], forwarded unchanged to `cuMemHostAlloc`.
///
/// Combine bits with `|`. The values mirror the driver's `CU_MEMHOSTALLOC_*` constants.
/// NOTE: `cuMemHostRegister` (reached through `PinnedRegistration::register_with_flags`)
/// uses its own `CU_MEMHOSTREGISTER_*` flag set, in which `0x04` is *not* write-combined —
/// do not assume these constants carry the same meaning there.
pub mod flags {
    /// `CU_MEMHOSTALLOC_PORTABLE`: the memory is treated as pinned by all CUDA contexts,
    /// not only the one that allocated it.
    pub const PORTABLE: u32 = 0x01;
    /// `CU_MEMHOSTALLOC_DEVICEMAP`: maps the allocation into the device address space so a
    /// device pointer can be obtained for it.
    pub const DEVICEMAP: u32 = 0x02;
    /// `CU_MEMHOSTALLOC_WRITECOMBINED`: allocates write-combined memory, which can transfer
    /// faster over PCIe on some systems but is slow for the CPU to read.
    pub const WRITECOMBINED: u32 = 0x04;
}
37
/// Page-locked host buffer of `len` elements of `T`, allocated through `cuMemHostAlloc`
/// and released with `cuMemFreeHost` when dropped.
///
/// When the requested byte count is zero the buffer stores a dangling (non-null, aligned)
/// pointer and no driver allocation is made.
pub struct PinnedBuffer<T: DeviceRepr> {
    /// Start of the host allocation, or a dangling pointer for a zero-byte buffer.
    ptr: *mut T,
    /// Number of `T` elements in the buffer.
    len: usize,
    // Clone of the context the allocation was made in, kept for the buffer's whole lifetime.
    // `allow(dead_code)` covers builds where the field is only stored and never read.
    #[allow(dead_code)]
    context: Context,
    // Marks that the buffer logically owns values of type `T`.
    _marker: core::marker::PhantomData<T>,
}
46
47unsafe impl<T: DeviceRepr + Send> Send for PinnedBuffer<T> {}
48unsafe impl<T: DeviceRepr + Sync> Sync for PinnedBuffer<T> {}
49
50impl<T: DeviceRepr> core::fmt::Debug for PinnedBuffer<T> {
51 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
52 f.debug_struct("PinnedBuffer")
53 .field("ptr", &self.ptr)
54 .field("len", &self.len)
55 .field("type", &core::any::type_name::<T>())
56 .finish()
57 }
58}
59
60impl<T: DeviceRepr> PinnedBuffer<T> {
61 pub fn new(context: &Context, len: usize) -> Result<Self> {
63 Self::with_flags(context, len, 0)
64 }
65
66 pub fn with_flags(context: &Context, len: usize, flags: u32) -> Result<Self> {
76 let bytes = len
77 .checked_mul(size_of::<T>())
78 .expect("overflow in PinnedBuffer byte count");
79 if bytes == 0 {
80 return Ok(Self {
81 ptr: core::ptr::NonNull::<T>::dangling().as_ptr(),
82 len,
83 context: context.clone(),
84 _marker: core::marker::PhantomData,
85 });
86 }
87 context.set_current()?;
88 let d = driver()?;
89 let cu = d.cu_mem_host_alloc()?;
90 let mut p: *mut c_void = core::ptr::null_mut();
91 check(unsafe { cu(&mut p, bytes, flags) })?;
92 Ok(Self {
93 ptr: p as *mut T,
94 len,
95 context: context.clone(),
96 _marker: core::marker::PhantomData,
97 })
98 }
99
100 pub fn device_ptr(&self) -> Result<CUdeviceptr> {
108 if self.len == 0 {
109 return Ok(CUdeviceptr(0));
110 }
111 let d = driver()?;
112 let cu = d.cu_mem_host_get_device_pointer()?;
113 let mut dptr = CUdeviceptr(0);
114 check(unsafe { cu(&mut dptr, self.ptr as *mut c_void, 0) })?;
115 Ok(dptr)
116 }
117
118 pub fn flags(&self) -> Result<u32> {
122 if self.len == 0 {
123 return Ok(0);
124 }
125 let d = driver()?;
126 let cu = d.cu_mem_host_get_flags()?;
127 let mut flags: core::ffi::c_uint = 0;
128 check(unsafe { cu(&mut flags, self.ptr as *mut c_void) })?;
129 Ok(flags)
130 }
131
132 #[inline]
133 pub fn len(&self) -> usize {
134 self.len
135 }
136 #[inline]
137 pub fn is_empty(&self) -> bool {
138 self.len == 0
139 }
140 #[inline]
141 pub fn as_ptr(&self) -> *const T {
142 self.ptr
143 }
144 #[inline]
145 pub fn as_mut_ptr(&mut self) -> *mut T {
146 self.ptr
147 }
148}
149
150impl<T: DeviceRepr> Deref for PinnedBuffer<T> {
151 type Target = [T];
152 fn deref(&self) -> &[T] {
153 unsafe { core::slice::from_raw_parts(self.ptr, self.len) }
155 }
156}
157
158impl<T: DeviceRepr> DerefMut for PinnedBuffer<T> {
159 fn deref_mut(&mut self) -> &mut [T] {
160 unsafe { core::slice::from_raw_parts_mut(self.ptr, self.len) }
161 }
162}
163
164impl<T: DeviceRepr> Drop for PinnedBuffer<T> {
165 fn drop(&mut self) {
166 if self.len == 0 || self.ptr.is_null() {
170 return;
171 }
172 if let Ok(d) = driver() {
173 if let Ok(cu) = d.cu_mem_free_host() {
174 let _ = unsafe { cu(self.ptr as *mut c_void) };
175 }
176 }
177 }
178}
179
/// Registration of a caller-owned slice as page-locked memory via `cuMemHostRegister`,
/// undone with `cuMemHostUnregister` when dropped.
///
/// The registration holds the slice's exclusive borrow for `'a`, so the slice cannot be
/// touched from Rust while it is registered.
pub struct PinnedRegistration<'a, T: DeviceRepr> {
    /// Start of the registered slice.
    ptr: *mut T,
    /// Number of `T` elements in the registered slice.
    len: usize,
    // Ties the registration to the `&'a mut [T]` borrow it was created from.
    _borrow: core::marker::PhantomData<&'a mut [T]>,
}
189
190unsafe impl<T: DeviceRepr + Send> Send for PinnedRegistration<'_, T> {}
191unsafe impl<T: DeviceRepr + Sync> Sync for PinnedRegistration<'_, T> {}
192
193impl<'a, T: DeviceRepr> PinnedRegistration<'a, T> {
194 pub fn register(slice: &'a mut [T]) -> Result<Self> {
196 Self::register_with_flags(slice, 0)
197 }
198
199 pub fn register_with_flags(slice: &'a mut [T], flags: u32) -> Result<Self> {
200 let d = driver()?;
201 let cu = d.cu_mem_host_register()?;
202 let bytes = core::mem::size_of_val(slice);
203 check(unsafe { cu(slice.as_mut_ptr() as *mut c_void, bytes, flags) })?;
204 Ok(Self {
205 ptr: slice.as_mut_ptr(),
206 len: slice.len(),
207 _borrow: core::marker::PhantomData,
208 })
209 }
210
211 pub fn device_ptr(&self) -> Result<CUdeviceptr> {
213 let d = driver()?;
214 let cu = d.cu_mem_host_get_device_pointer()?;
215 let mut dptr = CUdeviceptr(0);
216 check(unsafe { cu(&mut dptr, self.ptr as *mut c_void, 0) })?;
217 Ok(dptr)
218 }
219
220 #[inline]
221 pub fn len(&self) -> usize {
222 self.len
223 }
224 #[inline]
225 pub fn is_empty(&self) -> bool {
226 self.len == 0
227 }
228}
229
230impl<T: DeviceRepr> core::fmt::Debug for PinnedRegistration<'_, T> {
231 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
232 f.debug_struct("PinnedRegistration")
233 .field("ptr", &self.ptr)
234 .field("len", &self.len)
235 .finish()
236 }
237}
238
239impl<T: DeviceRepr> Drop for PinnedRegistration<'_, T> {
240 fn drop(&mut self) {
241 if self.ptr.is_null() {
242 return;
243 }
244 if let Ok(d) = driver() {
245 if let Ok(cu) = d.cu_mem_host_unregister() {
246 let _ = unsafe { cu(self.ptr as *mut c_void) };
247 }
248 }
249 }
250}