Skip to main content

oxicuda_memory/
device_buffer.rs

1//! Type-safe device (GPU VRAM) memory buffer.
2//!
3//! [`DeviceBuffer<T>`] owns a contiguous allocation of `T` elements in device
4//! memory.  It supports synchronous and asynchronous copies to/from host
5//! memory, device-to-device copies, and zero-initialisation via `cuMemsetD8`.
6//!
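//! The buffer is parameterised over `T: Copy`, which guarantees that elements
//! can be duplicated bitwise and carry no `Drop` glue.  Note that `Copy` alone
//! does not rule out references or raw pointers, whose host addresses are
//! meaningless on the GPU; store only plain-old-data element types.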
10//!
11//! # Ownership
12//!
13//! The allocation is freed automatically when the buffer is dropped.  If
14//! `cuMemFree_v2` fails during [`Drop`], the error is logged via
15//! [`tracing::warn`] rather than panicking.
16//!
17//! # Example
18//!
19//! ```rust,no_run
20//! # use oxicuda_memory::DeviceBuffer;
21//! let mut buf = DeviceBuffer::<f32>::alloc(1024)?;
22//! let host_data = vec![1.0_f32; 1024];
23//! buf.copy_from_host(&host_data)?;
24//!
25//! let mut result = vec![0.0_f32; 1024];
26//! buf.copy_to_host(&mut result)?;
27//! assert_eq!(result, host_data);
28//! # Ok::<(), oxicuda_driver::error::CudaError>(())
29//! ```
30
31use std::ffi::c_void;
32use std::marker::PhantomData;
33
34use oxicuda_driver::error::{CudaError, CudaResult};
35use oxicuda_driver::ffi::CUdeviceptr;
36use oxicuda_driver::loader::try_driver;
37use oxicuda_driver::stream::Stream;
38
39// ---------------------------------------------------------------------------
40// DeviceBuffer<T>
41// ---------------------------------------------------------------------------
42
43/// A contiguous buffer of `T` elements allocated in GPU device memory.
44///
45/// The buffer owns the underlying `CUdeviceptr` allocation and frees it on
46/// drop.  All copy operations validate that source and destination lengths
47/// match, returning [`CudaError::InvalidValue`] on mismatch.
48pub struct DeviceBuffer<T: Copy> {
49    /// Raw CUDA device pointer to the start of the allocation.
50    ptr: CUdeviceptr,
51    /// Number of `T` elements (not bytes).
52    len: usize,
53    /// Marker to tie the generic parameter `T` to this struct.
54    _phantom: PhantomData<T>,
55}
56
57// SAFETY: Device memory is not bound to a specific host thread.  The raw
58// pointer is a `u64` handle managed by the CUDA driver, which is thread-safe
59// for memory operations when properly synchronised.
60unsafe impl<T: Copy + Send> Send for DeviceBuffer<T> {}
61unsafe impl<T: Copy + Sync> Sync for DeviceBuffer<T> {}
62
63impl<T: Copy> DeviceBuffer<T> {
64    /// Allocates a device buffer capable of holding `n` elements of type `T`.
65    ///
66    /// # Errors
67    ///
68    /// * [`CudaError::InvalidValue`] if `n` is zero.
69    /// * [`CudaError::OutOfMemory`] if the GPU cannot satisfy the request.
70    /// * Other driver errors propagated from `cuMemAlloc_v2`.
71    pub fn alloc(n: usize) -> CudaResult<Self> {
72        if n == 0 {
73            return Err(CudaError::InvalidValue);
74        }
75        let byte_size = n
76            .checked_mul(std::mem::size_of::<T>())
77            .ok_or(CudaError::InvalidValue)?;
78        let api = try_driver()?;
79        let mut ptr: CUdeviceptr = 0;
80        // SAFETY: `cu_mem_alloc_v2` writes a valid device pointer on success.
81        let rc = unsafe { (api.cu_mem_alloc_v2)(&mut ptr, byte_size) };
82        oxicuda_driver::check(rc)?;
83        Ok(Self {
84            ptr,
85            len: n,
86            _phantom: PhantomData,
87        })
88    }
89
90    /// Allocates a device buffer of `n` elements and zero-initialises every byte.
91    ///
92    /// This is equivalent to [`alloc`](Self::alloc) followed by a
93    /// `cuMemsetD8_v2` call that writes `0` to every byte.
94    ///
95    /// # Errors
96    ///
97    /// Same as [`alloc`](Self::alloc), plus any error from `cuMemsetD8_v2`.
98    pub fn zeroed(n: usize) -> CudaResult<Self> {
99        let buf = Self::alloc(n)?;
100        let api = try_driver()?;
101        // SAFETY: the buffer was just allocated with the correct byte size.
102        let rc = unsafe { (api.cu_memset_d8_v2)(buf.ptr, 0, buf.byte_size()) };
103        oxicuda_driver::check(rc)?;
104        Ok(buf)
105    }
106
107    /// Allocates a device buffer and copies the contents of `data` into it.
108    ///
109    /// The resulting buffer has the same length as the input slice.
110    ///
111    /// # Errors
112    ///
113    /// * [`CudaError::InvalidValue`] if `data` is empty.
114    /// * Other driver errors from allocation or the host-to-device copy.
115    pub fn from_host(data: &[T]) -> CudaResult<Self> {
116        let mut buf = Self::alloc(data.len())?;
117        buf.copy_from_host(data)?;
118        Ok(buf)
119    }
120
121    /// Copies data from a host slice into this device buffer (synchronous).
122    ///
123    /// The slice length must exactly match the buffer length.
124    ///
125    /// # Errors
126    ///
127    /// * [`CudaError::InvalidValue`] if `src.len() != self.len()`.
128    /// * Other driver errors from `cuMemcpyHtoD_v2`.
129    pub fn copy_from_host(&mut self, src: &[T]) -> CudaResult<()> {
130        if src.len() != self.len {
131            return Err(CudaError::InvalidValue);
132        }
133        let api = try_driver()?;
134        // SAFETY: `src` is a valid host slice with the correct byte count.
135        let rc = unsafe {
136            (api.cu_memcpy_htod_v2)(self.ptr, src.as_ptr().cast::<c_void>(), self.byte_size())
137        };
138        oxicuda_driver::check(rc)
139    }
140
141    /// Copies this device buffer's contents into a host slice (synchronous).
142    ///
143    /// The slice length must exactly match the buffer length.
144    ///
145    /// # Errors
146    ///
147    /// * [`CudaError::InvalidValue`] if `dst.len() != self.len()`.
148    /// * Other driver errors from `cuMemcpyDtoH_v2`.
149    pub fn copy_to_host(&self, dst: &mut [T]) -> CudaResult<()> {
150        if dst.len() != self.len {
151            return Err(CudaError::InvalidValue);
152        }
153        let api = try_driver()?;
154        // SAFETY: `dst` is a valid host slice with the correct byte count.
155        let rc = unsafe {
156            (api.cu_memcpy_dtoh_v2)(
157                dst.as_mut_ptr().cast::<c_void>(),
158                self.ptr,
159                self.byte_size(),
160            )
161        };
162        oxicuda_driver::check(rc)
163    }
164
165    /// Copies the entire contents of another device buffer into this one.
166    ///
167    /// Both buffers must have the same length.
168    ///
169    /// # Errors
170    ///
171    /// * [`CudaError::InvalidValue`] if `src.len() != self.len()`.
172    /// * Other driver errors from `cuMemcpyDtoD_v2`.
173    pub fn copy_from_device(&mut self, src: &DeviceBuffer<T>) -> CudaResult<()> {
174        if src.len != self.len {
175            return Err(CudaError::InvalidValue);
176        }
177        let api = try_driver()?;
178        // SAFETY: both pointers are valid device allocations of the same size.
179        let rc = unsafe { (api.cu_memcpy_dtod_v2)(self.ptr, src.ptr, self.byte_size()) };
180        oxicuda_driver::check(rc)
181    }
182
183    /// Asynchronously copies data from a host slice into this device buffer.
184    ///
185    /// The copy is enqueued on `stream` and may not be complete when this
186    /// function returns.  The caller must ensure that `src` remains valid
187    /// (i.e., is not moved or dropped) until the stream has been
188    /// synchronised.  For guaranteed correctness, prefer using a
189    /// [`PinnedBuffer`](crate::PinnedBuffer) as the source.
190    ///
191    /// # Errors
192    ///
193    /// * [`CudaError::InvalidValue`] if `src.len() != self.len()`.
194    /// * Other driver errors from `cuMemcpyHtoDAsync_v2`.
195    pub fn copy_from_host_async(&mut self, src: &[T], stream: &Stream) -> CudaResult<()> {
196        if src.len() != self.len {
197            return Err(CudaError::InvalidValue);
198        }
199        let api = try_driver()?;
200        // SAFETY: the caller is responsible for keeping `src` alive until
201        // the stream completes.
202        let rc = unsafe {
203            (api.cu_memcpy_htod_async_v2)(
204                self.ptr,
205                src.as_ptr().cast::<c_void>(),
206                self.byte_size(),
207                stream.raw(),
208            )
209        };
210        oxicuda_driver::check(rc)
211    }
212
213    /// Asynchronously copies this device buffer's contents into a host slice.
214    ///
215    /// The copy is enqueued on `stream` and may not be complete when this
216    /// function returns.  The caller must ensure that `dst` remains valid
217    /// and is not read until the stream has been synchronised.  For
218    /// guaranteed correctness, prefer using a
219    /// [`PinnedBuffer`](crate::PinnedBuffer) as the destination.
220    ///
221    /// # Errors
222    ///
223    /// * [`CudaError::InvalidValue`] if `dst.len() != self.len()`.
224    /// * Other driver errors from `cuMemcpyDtoHAsync_v2`.
225    pub fn copy_to_host_async(&self, dst: &mut [T], stream: &Stream) -> CudaResult<()> {
226        if dst.len() != self.len {
227            return Err(CudaError::InvalidValue);
228        }
229        let api = try_driver()?;
230        // SAFETY: the caller is responsible for keeping `dst` alive until
231        // the stream completes.
232        let rc = unsafe {
233            (api.cu_memcpy_dtoh_async_v2)(
234                dst.as_mut_ptr().cast::<c_void>(),
235                self.ptr,
236                self.byte_size(),
237                stream.raw(),
238            )
239        };
240        oxicuda_driver::check(rc)
241    }
242
243    /// Returns the number of `T` elements in this buffer.
244    #[inline]
245    pub fn len(&self) -> usize {
246        self.len
247    }
248
249    /// Returns `true` if the buffer contains zero elements.
250    ///
251    /// In practice this is always `false` because [`alloc`](Self::alloc)
252    /// rejects zero-length allocations.
253    #[inline]
254    pub fn is_empty(&self) -> bool {
255        self.len == 0
256    }
257
258    /// Returns the total size of the allocation in bytes.
259    #[inline]
260    pub fn byte_size(&self) -> usize {
261        self.len * std::mem::size_of::<T>()
262    }
263
264    /// Returns the raw [`CUdeviceptr`] handle for this buffer.
265    ///
266    /// This is useful when passing the pointer to kernel launch parameters
267    /// or other low-level driver calls.
268    #[inline]
269    pub fn as_device_ptr(&self) -> CUdeviceptr {
270        self.ptr
271    }
272
273    /// Returns a borrowed [`DeviceSlice`] referencing a sub-range of this
274    /// buffer starting at element `offset` and spanning `len` elements.
275    ///
276    /// # Errors
277    ///
278    /// Returns [`CudaError::InvalidValue`] if the requested range exceeds
279    /// the buffer bounds (i.e., `offset + len > self.len()`).
280    pub fn slice(&self, offset: usize, len: usize) -> CudaResult<DeviceSlice<'_, T>> {
281        let end = offset.checked_add(len).ok_or(CudaError::InvalidValue)?;
282        if end > self.len {
283            return Err(CudaError::InvalidValue);
284        }
285        let byte_offset = offset
286            .checked_mul(std::mem::size_of::<T>())
287            .ok_or(CudaError::InvalidValue)?;
288        Ok(DeviceSlice {
289            ptr: self.ptr + byte_offset as u64,
290            len,
291            _phantom: PhantomData,
292        })
293    }
294}
295
296impl<T: Copy> Drop for DeviceBuffer<T> {
297    fn drop(&mut self) {
298        if let Ok(api) = try_driver() {
299            // SAFETY: `self.ptr` was allocated by `cu_mem_alloc_v2` and has
300            // not yet been freed.
301            let rc = unsafe { (api.cu_mem_free_v2)(self.ptr) };
302            if rc != 0 {
303                tracing::warn!(
304                    cuda_error = rc,
305                    ptr = self.ptr,
306                    len = self.len,
307                    "cuMemFree_v2 failed during DeviceBuffer drop"
308                );
309            }
310        }
311    }
312}
313
314// ---------------------------------------------------------------------------
315// DeviceSlice<'a, T>
316// ---------------------------------------------------------------------------
317
318/// A borrowed, non-owning view into a sub-range of a [`DeviceBuffer`].
319///
320/// A `DeviceSlice` does not own the memory it points to — it borrows from
321/// the parent [`DeviceBuffer`] and is lifetime-bound to it.  This is useful
322/// for passing sub-regions of a buffer to kernels or copy operations without
323/// extra allocations.
324///
325/// `DeviceSlice` does **not** implement [`Drop`]; the parent buffer is
326/// responsible for freeing the allocation.
327pub struct DeviceSlice<'a, T: Copy> {
328    /// Raw device pointer to the start of this slice within the parent buffer.
329    ptr: CUdeviceptr,
330    /// Number of `T` elements in this slice.
331    len: usize,
332    /// Ties the lifetime to the parent buffer and the element type.
333    _phantom: PhantomData<&'a T>,
334}
335
336impl<T: Copy> DeviceSlice<'_, T> {
337    /// Returns the number of `T` elements in this slice.
338    #[inline]
339    pub fn len(&self) -> usize {
340        self.len
341    }
342
343    /// Returns `true` if the slice contains zero elements.
344    #[inline]
345    pub fn is_empty(&self) -> bool {
346        self.len == 0
347    }
348
349    /// Returns the total size of this slice in bytes.
350    #[inline]
351    pub fn byte_size(&self) -> usize {
352        self.len * std::mem::size_of::<T>()
353    }
354
355    /// Returns the raw [`CUdeviceptr`] handle for the start of this slice.
356    #[inline]
357    pub fn as_device_ptr(&self) -> CUdeviceptr {
358        self.ptr
359    }
360}