Skip to main content

oxicuda_memory/
unified.rs

1//! Unified (managed) memory buffer.
2//!
3//! [`UnifiedBuffer<T>`] wraps `cuMemAllocManaged`, which allocates memory
4//! that is automatically migrated between host and device by the CUDA
5//! Unified Memory subsystem.  The allocation is accessible from both CPU
6//! code (via [`as_slice`](UnifiedBuffer::as_slice) /
7//! [`as_mut_slice`](UnifiedBuffer::as_mut_slice)) and GPU kernels (via
8//! [`as_device_ptr`](UnifiedBuffer::as_device_ptr)).
9//!
10//! # Coherence caveat
11//!
12//! The host-side accessors are only safe to call when no GPU kernel is
13//! concurrently reading or writing the same memory.  After launching a
14//! kernel that touches a unified buffer, synchronise the stream (or the
15//! entire context) before accessing the data from the host.
16//!
17//! # Ownership
18//!
19//! The allocation is freed with `cuMemFree_v2` on drop.  Errors during
20//! drop are logged via [`tracing::warn`].
21//!
22//! # Example
23//!
24//! ```rust,no_run
25//! # use oxicuda_memory::UnifiedBuffer;
26//! let mut ubuf = UnifiedBuffer::<f32>::alloc(512)?;
27//! // Write from the host side (no kernel running).
28//! for (i, v) in ubuf.as_mut_slice().iter_mut().enumerate() {
29//!     *v = i as f32;
30//! }
31//! // Pass ubuf.as_device_ptr() to a kernel…
32//! # Ok::<(), oxicuda_driver::error::CudaError>(())
33//! ```
34
35use std::marker::PhantomData;
36
37use oxicuda_driver::error::{CudaError, CudaResult};
38use oxicuda_driver::ffi::{CU_MEM_ATTACH_GLOBAL, CUdeviceptr};
39use oxicuda_driver::loader::try_driver;
40
41// ---------------------------------------------------------------------------
42// UnifiedBuffer<T>
43// ---------------------------------------------------------------------------
44
45/// A contiguous buffer of `T` elements in CUDA unified (managed) memory.
46///
47/// Unified memory is accessible from both the host CPU and the GPU device.
48/// The CUDA driver transparently migrates pages between host and device as
49/// needed.  This simplifies programming at the cost of potential migration
50/// overhead compared to explicit device buffers.
51pub struct UnifiedBuffer<T: Copy> {
52    /// The CUDA device pointer.  For managed memory this value is also a
53    /// valid host pointer (on 64-bit systems with UVA).
54    ptr: CUdeviceptr,
55    /// Host-accessible pointer derived from `ptr`.
56    host_ptr: *mut T,
57    /// Number of `T` elements (not bytes).
58    len: usize,
59    /// Marker to tie the generic parameter `T` to this struct.
60    _phantom: PhantomData<T>,
61}
62
63// SAFETY: Unified memory is accessible from any thread on both host and
64// device.  Proper synchronisation is the caller's responsibility.
65unsafe impl<T: Copy + Send> Send for UnifiedBuffer<T> {}
66unsafe impl<T: Copy + Sync> Sync for UnifiedBuffer<T> {}
67
68impl<T: Copy> UnifiedBuffer<T> {
69    /// Allocates a unified memory buffer capable of holding `n` elements of
70    /// type `T`.
71    ///
72    /// The memory is allocated with [`CU_MEM_ATTACH_GLOBAL`], making it
73    /// accessible from any stream on any device in the system.
74    ///
75    /// # Errors
76    ///
77    /// * [`CudaError::InvalidValue`] if `n` is zero.
78    /// * [`CudaError::OutOfMemory`] if the allocation fails.
79    /// * Other driver errors from `cuMemAllocManaged`.
80    pub fn alloc(n: usize) -> CudaResult<Self> {
81        if n == 0 {
82            return Err(CudaError::InvalidValue);
83        }
84        let byte_size = n
85            .checked_mul(std::mem::size_of::<T>())
86            .ok_or(CudaError::InvalidValue)?;
87        let api = try_driver()?;
88        let mut dev_ptr: CUdeviceptr = 0;
89        // SAFETY: `cu_mem_alloc_managed` writes a valid device pointer that
90        // is also host-accessible (UVA).
91        let rc =
92            unsafe { (api.cu_mem_alloc_managed)(&mut dev_ptr, byte_size, CU_MEM_ATTACH_GLOBAL) };
93        oxicuda_driver::check(rc)?;
94        // On 64-bit systems with UVA, the device pointer value is the same
95        // as the host virtual address.
96        let host_ptr = dev_ptr as *mut T;
97        Ok(Self {
98            ptr: dev_ptr,
99            host_ptr,
100            len: n,
101            _phantom: PhantomData,
102        })
103    }
104
105    /// Returns the number of `T` elements in this buffer.
106    #[inline]
107    pub fn len(&self) -> usize {
108        self.len
109    }
110
111    /// Returns `true` if the buffer contains zero elements.
112    #[inline]
113    pub fn is_empty(&self) -> bool {
114        self.len == 0
115    }
116
117    /// Returns the total size of the allocation in bytes.
118    #[inline]
119    pub fn byte_size(&self) -> usize {
120        self.len * std::mem::size_of::<T>()
121    }
122
123    /// Returns the raw [`CUdeviceptr`] handle for use in kernel launches
124    /// and other device-side operations.
125    #[inline]
126    pub fn as_device_ptr(&self) -> CUdeviceptr {
127        self.ptr
128    }
129
130    /// Returns a shared slice over the buffer's host-accessible contents.
131    ///
132    /// # Safety note
133    ///
134    /// This is only safe to call when no GPU kernel is concurrently
135    /// reading or writing this buffer.  Synchronise the relevant stream
136    /// or context before calling this method.
137    #[inline]
138    pub fn as_slice(&self) -> &[T] {
139        // SAFETY: `host_ptr` is valid for `len` elements when no device
140        // kernel is concurrently accessing the memory.  The caller is
141        // responsible for proper synchronisation.
142        unsafe { std::slice::from_raw_parts(self.host_ptr, self.len) }
143    }
144
145    /// Returns a mutable slice over the buffer's host-accessible contents.
146    ///
147    /// # Safety note
148    ///
149    /// This is only safe to call when no GPU kernel is concurrently
150    /// reading or writing this buffer.  Synchronise the relevant stream
151    /// or context before calling this method.
152    #[inline]
153    pub fn as_mut_slice(&mut self) -> &mut [T] {
154        // SAFETY: `host_ptr` is valid for `len` elements when no device
155        // kernel is concurrently accessing the memory.  The caller is
156        // responsible for proper synchronisation.
157        unsafe { std::slice::from_raw_parts_mut(self.host_ptr, self.len) }
158    }
159}
160
161impl<T: Copy> Drop for UnifiedBuffer<T> {
162    fn drop(&mut self) {
163        if let Ok(api) = try_driver() {
164            // SAFETY: `self.ptr` was allocated by `cu_mem_alloc_managed`
165            // and has not yet been freed.
166            let rc = unsafe { (api.cu_mem_free_v2)(self.ptr) };
167            if rc != 0 {
168                tracing::warn!(
169                    cuda_error = rc,
170                    ptr = self.ptr,
171                    len = self.len,
172                    "cuMemFree_v2 failed during UnifiedBuffer drop"
173                );
174            }
175        }
176    }
177}