oxicuda_memory/device_buffer.rs
1//! Type-safe device (GPU VRAM) memory buffer.
2//!
3//! [`DeviceBuffer<T>`] owns a contiguous allocation of `T` elements in device
4//! memory. It supports synchronous and asynchronous copies to/from host
5//! memory, device-to-device copies, and zero-initialisation via `cuMemsetD8`.
6//!
7//! The buffer is parameterised over `T: Copy` so that only plain-old-data
8//! types can be stored — no heap pointers that would be meaningless on the
9//! GPU.
10//!
11//! # Ownership
12//!
13//! The allocation is freed automatically when the buffer is dropped. If
14//! `cuMemFree_v2` fails during [`Drop`], the error is logged via
15//! [`tracing::warn`] rather than panicking.
16//!
17//! # Example
18//!
19//! ```rust,no_run
20//! # use oxicuda_memory::DeviceBuffer;
21//! let mut buf = DeviceBuffer::<f32>::alloc(1024)?;
22//! let host_data = vec![1.0_f32; 1024];
23//! buf.copy_from_host(&host_data)?;
24//!
25//! let mut result = vec![0.0_f32; 1024];
26//! buf.copy_to_host(&mut result)?;
27//! assert_eq!(result, host_data);
28//! # Ok::<(), oxicuda_driver::error::CudaError>(())
29//! ```
30
31use std::ffi::c_void;
32use std::marker::PhantomData;
33
34use oxicuda_driver::error::{CudaError, CudaResult};
35use oxicuda_driver::ffi::CUdeviceptr;
36use oxicuda_driver::loader::try_driver;
37use oxicuda_driver::stream::Stream;
38
39// ---------------------------------------------------------------------------
40// DeviceBuffer<T>
41// ---------------------------------------------------------------------------
42
43/// A contiguous buffer of `T` elements allocated in GPU device memory.
44///
45/// The buffer owns the underlying `CUdeviceptr` allocation and frees it on
46/// drop. All copy operations validate that source and destination lengths
47/// match, returning [`CudaError::InvalidValue`] on mismatch.
48pub struct DeviceBuffer<T: Copy> {
49 /// Raw CUDA device pointer to the start of the allocation.
50 ptr: CUdeviceptr,
51 /// Number of `T` elements (not bytes).
52 len: usize,
53 /// Marker to tie the generic parameter `T` to this struct.
54 _phantom: PhantomData<T>,
55}
56
57// SAFETY: Device memory is not bound to a specific host thread. The raw
58// pointer is a `u64` handle managed by the CUDA driver, which is thread-safe
59// for memory operations when properly synchronised.
60unsafe impl<T: Copy + Send> Send for DeviceBuffer<T> {}
61unsafe impl<T: Copy + Sync> Sync for DeviceBuffer<T> {}
62
63impl<T: Copy> DeviceBuffer<T> {
64 /// Allocates a device buffer capable of holding `n` elements of type `T`.
65 ///
66 /// # Errors
67 ///
68 /// * [`CudaError::InvalidValue`] if `n` is zero.
69 /// * [`CudaError::OutOfMemory`] if the GPU cannot satisfy the request.
70 /// * Other driver errors propagated from `cuMemAlloc_v2`.
71 pub fn alloc(n: usize) -> CudaResult<Self> {
72 if n == 0 {
73 return Err(CudaError::InvalidValue);
74 }
75 let byte_size = n
76 .checked_mul(std::mem::size_of::<T>())
77 .ok_or(CudaError::InvalidValue)?;
78 let api = try_driver()?;
79 let mut ptr: CUdeviceptr = 0;
80 // SAFETY: `cu_mem_alloc_v2` writes a valid device pointer on success.
81 let rc = unsafe { (api.cu_mem_alloc_v2)(&mut ptr, byte_size) };
82 oxicuda_driver::check(rc)?;
83 Ok(Self {
84 ptr,
85 len: n,
86 _phantom: PhantomData,
87 })
88 }
89
90 /// Allocates a device buffer of `n` elements and zero-initialises every byte.
91 ///
92 /// This is equivalent to [`alloc`](Self::alloc) followed by a
93 /// `cuMemsetD8_v2` call that writes `0` to every byte.
94 ///
95 /// # Errors
96 ///
97 /// Same as [`alloc`](Self::alloc), plus any error from `cuMemsetD8_v2`.
98 pub fn zeroed(n: usize) -> CudaResult<Self> {
99 let buf = Self::alloc(n)?;
100 let api = try_driver()?;
101 // SAFETY: the buffer was just allocated with the correct byte size.
102 let rc = unsafe { (api.cu_memset_d8_v2)(buf.ptr, 0, buf.byte_size()) };
103 oxicuda_driver::check(rc)?;
104 Ok(buf)
105 }
106
107 /// Allocates a device buffer and copies the contents of `data` into it.
108 ///
109 /// The resulting buffer has the same length as the input slice.
110 ///
111 /// # Errors
112 ///
113 /// * [`CudaError::InvalidValue`] if `data` is empty.
114 /// * Other driver errors from allocation or the host-to-device copy.
115 pub fn from_host(data: &[T]) -> CudaResult<Self> {
116 let mut buf = Self::alloc(data.len())?;
117 buf.copy_from_host(data)?;
118 Ok(buf)
119 }
120
121 /// Copies data from a host slice into this device buffer (synchronous).
122 ///
123 /// The slice length must exactly match the buffer length.
124 ///
125 /// # Errors
126 ///
127 /// * [`CudaError::InvalidValue`] if `src.len() != self.len()`.
128 /// * Other driver errors from `cuMemcpyHtoD_v2`.
129 pub fn copy_from_host(&mut self, src: &[T]) -> CudaResult<()> {
130 if src.len() != self.len {
131 return Err(CudaError::InvalidValue);
132 }
133 let api = try_driver()?;
134 // SAFETY: `src` is a valid host slice with the correct byte count.
135 let rc = unsafe {
136 (api.cu_memcpy_htod_v2)(self.ptr, src.as_ptr().cast::<c_void>(), self.byte_size())
137 };
138 oxicuda_driver::check(rc)
139 }
140
141 /// Copies this device buffer's contents into a host slice (synchronous).
142 ///
143 /// The slice length must exactly match the buffer length.
144 ///
145 /// # Errors
146 ///
147 /// * [`CudaError::InvalidValue`] if `dst.len() != self.len()`.
148 /// * Other driver errors from `cuMemcpyDtoH_v2`.
149 pub fn copy_to_host(&self, dst: &mut [T]) -> CudaResult<()> {
150 if dst.len() != self.len {
151 return Err(CudaError::InvalidValue);
152 }
153 let api = try_driver()?;
154 // SAFETY: `dst` is a valid host slice with the correct byte count.
155 let rc = unsafe {
156 (api.cu_memcpy_dtoh_v2)(
157 dst.as_mut_ptr().cast::<c_void>(),
158 self.ptr,
159 self.byte_size(),
160 )
161 };
162 oxicuda_driver::check(rc)
163 }
164
165 /// Copies the entire contents of another device buffer into this one.
166 ///
167 /// Both buffers must have the same length.
168 ///
169 /// # Errors
170 ///
171 /// * [`CudaError::InvalidValue`] if `src.len() != self.len()`.
172 /// * Other driver errors from `cuMemcpyDtoD_v2`.
173 pub fn copy_from_device(&mut self, src: &DeviceBuffer<T>) -> CudaResult<()> {
174 if src.len != self.len {
175 return Err(CudaError::InvalidValue);
176 }
177 let api = try_driver()?;
178 // SAFETY: both pointers are valid device allocations of the same size.
179 let rc = unsafe { (api.cu_memcpy_dtod_v2)(self.ptr, src.ptr, self.byte_size()) };
180 oxicuda_driver::check(rc)
181 }
182
183 /// Asynchronously copies data from a host slice into this device buffer.
184 ///
185 /// The copy is enqueued on `stream` and may not be complete when this
186 /// function returns. The caller must ensure that `src` remains valid
187 /// (i.e., is not moved or dropped) until the stream has been
188 /// synchronised. For guaranteed correctness, prefer using a
189 /// [`PinnedBuffer`](crate::PinnedBuffer) as the source.
190 ///
191 /// # Errors
192 ///
193 /// * [`CudaError::InvalidValue`] if `src.len() != self.len()`.
194 /// * Other driver errors from `cuMemcpyHtoDAsync_v2`.
195 pub fn copy_from_host_async(&mut self, src: &[T], stream: &Stream) -> CudaResult<()> {
196 if src.len() != self.len {
197 return Err(CudaError::InvalidValue);
198 }
199 let api = try_driver()?;
200 // SAFETY: the caller is responsible for keeping `src` alive until
201 // the stream completes.
202 let rc = unsafe {
203 (api.cu_memcpy_htod_async_v2)(
204 self.ptr,
205 src.as_ptr().cast::<c_void>(),
206 self.byte_size(),
207 stream.raw(),
208 )
209 };
210 oxicuda_driver::check(rc)
211 }
212
213 /// Asynchronously copies this device buffer's contents into a host slice.
214 ///
215 /// The copy is enqueued on `stream` and may not be complete when this
216 /// function returns. The caller must ensure that `dst` remains valid
217 /// and is not read until the stream has been synchronised. For
218 /// guaranteed correctness, prefer using a
219 /// [`PinnedBuffer`](crate::PinnedBuffer) as the destination.
220 ///
221 /// # Errors
222 ///
223 /// * [`CudaError::InvalidValue`] if `dst.len() != self.len()`.
224 /// * Other driver errors from `cuMemcpyDtoHAsync_v2`.
225 pub fn copy_to_host_async(&self, dst: &mut [T], stream: &Stream) -> CudaResult<()> {
226 if dst.len() != self.len {
227 return Err(CudaError::InvalidValue);
228 }
229 let api = try_driver()?;
230 // SAFETY: the caller is responsible for keeping `dst` alive until
231 // the stream completes.
232 let rc = unsafe {
233 (api.cu_memcpy_dtoh_async_v2)(
234 dst.as_mut_ptr().cast::<c_void>(),
235 self.ptr,
236 self.byte_size(),
237 stream.raw(),
238 )
239 };
240 oxicuda_driver::check(rc)
241 }
242
243 /// Returns the number of `T` elements in this buffer.
244 #[inline]
245 pub fn len(&self) -> usize {
246 self.len
247 }
248
249 /// Returns `true` if the buffer contains zero elements.
250 ///
251 /// In practice this is always `false` because [`alloc`](Self::alloc)
252 /// rejects zero-length allocations.
253 #[inline]
254 pub fn is_empty(&self) -> bool {
255 self.len == 0
256 }
257
258 /// Returns the total size of the allocation in bytes.
259 #[inline]
260 pub fn byte_size(&self) -> usize {
261 self.len * std::mem::size_of::<T>()
262 }
263
264 /// Returns the raw [`CUdeviceptr`] handle for this buffer.
265 ///
266 /// This is useful when passing the pointer to kernel launch parameters
267 /// or other low-level driver calls.
268 #[inline]
269 pub fn as_device_ptr(&self) -> CUdeviceptr {
270 self.ptr
271 }
272
273 /// Returns a borrowed [`DeviceSlice`] referencing a sub-range of this
274 /// buffer starting at element `offset` and spanning `len` elements.
275 ///
276 /// # Errors
277 ///
278 /// Returns [`CudaError::InvalidValue`] if the requested range exceeds
279 /// the buffer bounds (i.e., `offset + len > self.len()`).
280 pub fn slice(&self, offset: usize, len: usize) -> CudaResult<DeviceSlice<'_, T>> {
281 let end = offset.checked_add(len).ok_or(CudaError::InvalidValue)?;
282 if end > self.len {
283 return Err(CudaError::InvalidValue);
284 }
285 let byte_offset = offset
286 .checked_mul(std::mem::size_of::<T>())
287 .ok_or(CudaError::InvalidValue)?;
288 Ok(DeviceSlice {
289 ptr: self.ptr + byte_offset as u64,
290 len,
291 _phantom: PhantomData,
292 })
293 }
294}
295
296impl<T: Copy> Drop for DeviceBuffer<T> {
297 fn drop(&mut self) {
298 if let Ok(api) = try_driver() {
299 // SAFETY: `self.ptr` was allocated by `cu_mem_alloc_v2` and has
300 // not yet been freed.
301 let rc = unsafe { (api.cu_mem_free_v2)(self.ptr) };
302 if rc != 0 {
303 tracing::warn!(
304 cuda_error = rc,
305 ptr = self.ptr,
306 len = self.len,
307 "cuMemFree_v2 failed during DeviceBuffer drop"
308 );
309 }
310 }
311 }
312}
313
314// ---------------------------------------------------------------------------
315// DeviceSlice<'a, T>
316// ---------------------------------------------------------------------------
317
318/// A borrowed, non-owning view into a sub-range of a [`DeviceBuffer`].
319///
320/// A `DeviceSlice` does not own the memory it points to — it borrows from
321/// the parent [`DeviceBuffer`] and is lifetime-bound to it. This is useful
322/// for passing sub-regions of a buffer to kernels or copy operations without
323/// extra allocations.
324///
325/// `DeviceSlice` does **not** implement [`Drop`]; the parent buffer is
326/// responsible for freeing the allocation.
327pub struct DeviceSlice<'a, T: Copy> {
328 /// Raw device pointer to the start of this slice within the parent buffer.
329 ptr: CUdeviceptr,
330 /// Number of `T` elements in this slice.
331 len: usize,
332 /// Ties the lifetime to the parent buffer and the element type.
333 _phantom: PhantomData<&'a T>,
334}
335
336impl<T: Copy> DeviceSlice<'_, T> {
337 /// Returns the number of `T` elements in this slice.
338 #[inline]
339 pub fn len(&self) -> usize {
340 self.len
341 }
342
343 /// Returns `true` if the slice contains zero elements.
344 #[inline]
345 pub fn is_empty(&self) -> bool {
346 self.len == 0
347 }
348
349 /// Returns the total size of this slice in bytes.
350 #[inline]
351 pub fn byte_size(&self) -> usize {
352 self.len * std::mem::size_of::<T>()
353 }
354
355 /// Returns the raw [`CUdeviceptr`] handle for the start of this slice.
356 #[inline]
357 pub fn as_device_ptr(&self) -> CUdeviceptr {
358 self.ptr
359 }
360}