oxicuda_memory/unified.rs
1//! Unified (managed) memory buffer.
2//!
3//! [`UnifiedBuffer<T>`] wraps `cuMemAllocManaged`, which allocates memory
4//! that is automatically migrated between host and device by the CUDA
5//! Unified Memory subsystem. The allocation is accessible from both CPU
6//! code (via [`as_slice`](UnifiedBuffer::as_slice) /
7//! [`as_mut_slice`](UnifiedBuffer::as_mut_slice)) and GPU kernels (via
8//! [`as_device_ptr`](UnifiedBuffer::as_device_ptr)).
9//!
10//! # Coherence caveat
11//!
12//! The host-side accessors are only safe to call when no GPU kernel is
13//! concurrently reading or writing the same memory. After launching a
14//! kernel that touches a unified buffer, synchronise the stream (or the
15//! entire context) before accessing the data from the host.
16//!
17//! # Ownership
18//!
19//! The allocation is freed with `cuMemFree_v2` on drop. Errors during
20//! drop are logged via [`tracing::warn`].
21//!
22//! # Example
23//!
24//! ```rust,no_run
25//! # use oxicuda_memory::UnifiedBuffer;
26//! let mut ubuf = UnifiedBuffer::<f32>::alloc(512)?;
27//! // Write from the host side (no kernel running).
28//! for (i, v) in ubuf.as_mut_slice().iter_mut().enumerate() {
29//! *v = i as f32;
30//! }
31//! // Pass ubuf.as_device_ptr() to a kernel…
32//! # Ok::<(), oxicuda_driver::error::CudaError>(())
33//! ```
34
35use std::marker::PhantomData;
36
37use oxicuda_driver::error::{CudaError, CudaResult};
38use oxicuda_driver::ffi::{CU_MEM_ATTACH_GLOBAL, CUdeviceptr};
39use oxicuda_driver::loader::try_driver;
40
41// ---------------------------------------------------------------------------
42// UnifiedBuffer<T>
43// ---------------------------------------------------------------------------
44
/// A contiguous buffer of `T` elements in CUDA unified (managed) memory.
///
/// Unified memory is accessible from both the host CPU and the GPU device.
/// The CUDA driver transparently migrates pages between host and device as
/// needed. This simplifies programming at the cost of potential migration
/// overhead compared to explicit device buffers.
///
/// `T: Copy` keeps the element type trivially copyable: the buffer's bytes
/// may be touched by the device at any time, so types with drop glue or
/// interior pointers could not be handled soundly here.
///
/// Invariant: `ptr` and `host_ptr` refer to the same live managed
/// allocation of exactly `len` elements until `Drop` runs.
pub struct UnifiedBuffer<T: Copy> {
    /// The CUDA device pointer. For managed memory this value is also a
    /// valid host pointer (on 64-bit systems with UVA).
    ptr: CUdeviceptr,
    /// Host-accessible pointer derived from `ptr`. Cached as `*mut T` so
    /// the slice accessors don't repeat the cast on every call.
    host_ptr: *mut T,
    /// Number of `T` elements (not bytes).
    len: usize,
    /// Marker to tie the generic parameter `T` to this struct (the raw
    /// pointers alone would not make the struct generic-aware for variance
    /// and auto-trait purposes).
    _phantom: PhantomData<T>,
}
62
// SAFETY: Unified memory is accessible from any thread on both host and
// device; the allocation itself has no thread affinity. The raw pointers
// only suppress the auto traits, so we reinstate them with the same bounds
// an owning container of `T` would have (`T: Send` to move the buffer
// across threads, `T: Sync` to share `&UnifiedBuffer<T>`). Host/device
// coherence remains the caller's responsibility, as documented on the
// slice accessors.
unsafe impl<T: Copy + Send> Send for UnifiedBuffer<T> {}
unsafe impl<T: Copy + Sync> Sync for UnifiedBuffer<T> {}
67
68impl<T: Copy> UnifiedBuffer<T> {
69 /// Allocates a unified memory buffer capable of holding `n` elements of
70 /// type `T`.
71 ///
72 /// The memory is allocated with [`CU_MEM_ATTACH_GLOBAL`], making it
73 /// accessible from any stream on any device in the system.
74 ///
75 /// # Errors
76 ///
77 /// * [`CudaError::InvalidValue`] if `n` is zero.
78 /// * [`CudaError::OutOfMemory`] if the allocation fails.
79 /// * Other driver errors from `cuMemAllocManaged`.
80 pub fn alloc(n: usize) -> CudaResult<Self> {
81 if n == 0 {
82 return Err(CudaError::InvalidValue);
83 }
84 let byte_size = n
85 .checked_mul(std::mem::size_of::<T>())
86 .ok_or(CudaError::InvalidValue)?;
87 let api = try_driver()?;
88 let mut dev_ptr: CUdeviceptr = 0;
89 // SAFETY: `cu_mem_alloc_managed` writes a valid device pointer that
90 // is also host-accessible (UVA).
91 let rc =
92 unsafe { (api.cu_mem_alloc_managed)(&mut dev_ptr, byte_size, CU_MEM_ATTACH_GLOBAL) };
93 oxicuda_driver::check(rc)?;
94 // On 64-bit systems with UVA, the device pointer value is the same
95 // as the host virtual address.
96 let host_ptr = dev_ptr as *mut T;
97 Ok(Self {
98 ptr: dev_ptr,
99 host_ptr,
100 len: n,
101 _phantom: PhantomData,
102 })
103 }
104
105 /// Returns the number of `T` elements in this buffer.
106 #[inline]
107 pub fn len(&self) -> usize {
108 self.len
109 }
110
111 /// Returns `true` if the buffer contains zero elements.
112 #[inline]
113 pub fn is_empty(&self) -> bool {
114 self.len == 0
115 }
116
117 /// Returns the total size of the allocation in bytes.
118 #[inline]
119 pub fn byte_size(&self) -> usize {
120 self.len * std::mem::size_of::<T>()
121 }
122
123 /// Returns the raw [`CUdeviceptr`] handle for use in kernel launches
124 /// and other device-side operations.
125 #[inline]
126 pub fn as_device_ptr(&self) -> CUdeviceptr {
127 self.ptr
128 }
129
130 /// Returns a shared slice over the buffer's host-accessible contents.
131 ///
132 /// # Safety note
133 ///
134 /// This is only safe to call when no GPU kernel is concurrently
135 /// reading or writing this buffer. Synchronise the relevant stream
136 /// or context before calling this method.
137 #[inline]
138 pub fn as_slice(&self) -> &[T] {
139 // SAFETY: `host_ptr` is valid for `len` elements when no device
140 // kernel is concurrently accessing the memory. The caller is
141 // responsible for proper synchronisation.
142 unsafe { std::slice::from_raw_parts(self.host_ptr, self.len) }
143 }
144
145 /// Returns a mutable slice over the buffer's host-accessible contents.
146 ///
147 /// # Safety note
148 ///
149 /// This is only safe to call when no GPU kernel is concurrently
150 /// reading or writing this buffer. Synchronise the relevant stream
151 /// or context before calling this method.
152 #[inline]
153 pub fn as_mut_slice(&mut self) -> &mut [T] {
154 // SAFETY: `host_ptr` is valid for `len` elements when no device
155 // kernel is concurrently accessing the memory. The caller is
156 // responsible for proper synchronisation.
157 unsafe { std::slice::from_raw_parts_mut(self.host_ptr, self.len) }
158 }
159}
160
161impl<T: Copy> Drop for UnifiedBuffer<T> {
162 fn drop(&mut self) {
163 if let Ok(api) = try_driver() {
164 // SAFETY: `self.ptr` was allocated by `cu_mem_alloc_managed`
165 // and has not yet been freed.
166 let rc = unsafe { (api.cu_mem_free_v2)(self.ptr) };
167 if rc != 0 {
168 tracing::warn!(
169 cuda_error = rc,
170 ptr = self.ptr,
171 len = self.len,
172 "cuMemFree_v2 failed during UnifiedBuffer drop"
173 );
174 }
175 }
176 }
177}