// dynamo_memory/pool/cuda.rs
1// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! CUDA memory pool for efficient device memory allocation in hot paths.
5//!
6//! This module provides a safe wrapper around CUDA's memory pool APIs, enabling
7//! fast async allocations that avoid the overhead of cudaMalloc/cudaFree per call.
8//! Memory is returned to the pool on free and reused for subsequent allocations.
9
10use anyhow::{Result, anyhow};
11use cudarc::driver::sys::{
12 self, CUmemAllocationType, CUmemLocationType, CUmemPool_attribute, CUmemPoolProps,
13 CUmemoryPool, CUresult, CUstream,
14};
15use cudarc::driver::{CudaContext, CudaStream};
16use std::ptr;
17use std::sync::{Arc, Mutex};
18
/// Builder for creating a CUDA memory pool with configurable parameters.
///
/// Obtain one via [`CudaMemPool::builder`] or [`CudaMemPoolBuilder::new`], chain
/// optional settings, then call [`CudaMemPoolBuilder::build`].
///
/// # Example
/// ```ignore
/// let pool = CudaMemPoolBuilder::new(context, 64 * 1024 * 1024) // 64 MiB reserve
///     .release_threshold(32 * 1024 * 1024) // 32 MiB release threshold
///     .build()?;
/// ```
pub struct CudaMemPoolBuilder {
    /// CUDA context for the target device.
    context: Arc<CudaContext>,
    /// Bytes to pre-allocate (then free) to warm the pool; 0 skips warming.
    reserve_size: usize,
    /// Optional threshold above which memory is returned to the system on free.
    /// `None` leaves the CUDA driver's default behavior in place.
    release_threshold: Option<u64>,
}
35
36impl CudaMemPoolBuilder {
37 /// Create a new builder with the required reserve size.
38 ///
39 /// # Arguments
40 /// * `context` - CUDA context for the device
41 /// * `reserve_size` - Number of bytes to pre-allocate to warm the pool
42 pub fn new(context: Arc<CudaContext>, reserve_size: usize) -> Self {
43 Self {
44 context,
45 reserve_size,
46 release_threshold: None,
47 }
48 }
49
50 /// Set the release threshold for the pool.
51 ///
52 /// Memory above this threshold is returned to the system when freed.
53 /// If not set, no release threshold is configured (CUDA default behavior).
54 pub fn release_threshold(mut self, threshold: u64) -> Self {
55 self.release_threshold = Some(threshold);
56 self
57 }
58
59 /// Build the CUDA memory pool.
60 ///
61 /// This will:
62 /// 1. Create the pool
63 /// 2. Set the release threshold if configured
64 /// 3. Pre-allocate and free memory to warm the pool
65 pub fn build(self) -> Result<CudaMemPool> {
66 // Initialize pool properties
67 let mut props: CUmemPoolProps = unsafe { std::mem::zeroed() };
68 props.allocType = CUmemAllocationType::CU_MEM_ALLOCATION_TYPE_PINNED;
69 props.location.type_ = CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE;
70 props.location.id = self.context.cu_device();
71
72 let mut pool: CUmemoryPool = ptr::null_mut();
73
74 // Create the pool
75 let result = unsafe { sys::cuMemPoolCreate(&mut pool, &props) };
76 if result != CUresult::CUDA_SUCCESS {
77 return Err(anyhow!("cuMemPoolCreate failed with error: {:?}", result));
78 }
79
80 // Set release threshold if configured
81 if let Some(threshold) = self.release_threshold {
82 let result = unsafe {
83 sys::cuMemPoolSetAttribute(
84 pool,
85 CUmemPool_attribute::CU_MEMPOOL_ATTR_RELEASE_THRESHOLD,
86 &threshold as *const u64 as *mut std::ffi::c_void,
87 )
88 };
89 if result != CUresult::CUDA_SUCCESS {
90 // Clean up on failure
91 unsafe { sys::cuMemPoolDestroy(pool) };
92 return Err(anyhow!(
93 "cuMemPoolSetAttribute failed with error: {:?}",
94 result
95 ));
96 }
97 }
98
99 let cuda_pool = CudaMemPool {
100 inner: Mutex::new(pool),
101 };
102
103 // Warm the pool by pre-allocating and freeing memory
104 if self.reserve_size > 0 {
105 // Create a temporary stream for warming
106 let stream = self.context.new_stream()?;
107
108 // Allocate to warm the pool (using safe variant)
109 let ptr = cuda_pool.alloc_async(self.reserve_size, &stream)?;
110
111 // Free back to pool (memory stays reserved)
112 cuda_pool.free_async(ptr, &stream)?;
113
114 // Synchronize to ensure operations complete
115 // SAFETY: stream.cu_stream() is valid for the lifetime of `stream`
116 let result = unsafe { sys::cuStreamSynchronize(stream.cu_stream()) };
117 if result != CUresult::CUDA_SUCCESS {
118 return Err(anyhow!(
119 "cuStreamSynchronize failed with error: {:?}",
120 result
121 ));
122 }
123 }
124
125 Ok(cuda_pool)
126 }
127}
128
/// Safe wrapper around a CUDA memory pool.
///
/// The pool amortizes allocation overhead by maintaining a reservoir of device memory.
/// Allocations are fast sub-allocations from this reservoir, and frees return memory
/// to the pool rather than the OS (until the release threshold is exceeded).
///
/// Dropping a `CudaMemPool` destroys the underlying driver pool and releases
/// its memory back to the system.
///
/// # Thread Safety
///
/// This type uses internal locking to serialize host-side calls to CUDA driver APIs.
/// `cuMemAllocFromPoolAsync` is not host-thread reentrant, so concurrent calls from
/// multiple threads must be serialized. The GPU-side operations remain asynchronous
/// and stream-ordered.
///
/// Use [`CudaMemPoolBuilder`] for configurable pool creation with pre-allocation.
pub struct CudaMemPool {
    /// Mutex protecting the pool handle for host-thread serialization.
    ///
    /// CUDA's `cuMemAllocFromPoolAsync` does not guarantee host-thread reentrancy,
    /// so all calls to the pool must be serialized on the host side.
    inner: Mutex<CUmemoryPool>,
}
150
// These impls are needed because CUmemoryPool is a raw pointer, which disables
// the automatic Send/Sync derivation for the containing struct.

// SAFETY: CudaMemPool is Send because the Mutex serializes all host-side access
// to the pool handle, and CUDA driver state is thread-safe when properly serialized.
unsafe impl Send for CudaMemPool {}

// SAFETY: CudaMemPool is Sync because all access to the pool handle goes through
// the Mutex, which serializes host-thread access. The CUDA driver requires this
// serialization because cuMemAllocFromPoolAsync is not host-thread reentrant.
unsafe impl Sync for CudaMemPool {}
159
160impl CudaMemPool {
161 /// Create a builder for a new CUDA memory pool.
162 ///
163 /// # Arguments
164 /// * `context` - CUDA context for the device
165 /// * `reserve_size` - Number of bytes to pre-allocate to warm the pool
166 pub fn builder(context: Arc<CudaContext>, reserve_size: usize) -> CudaMemPoolBuilder {
167 CudaMemPoolBuilder::new(context, reserve_size)
168 }
169
170 /// Allocate memory from the pool asynchronously.
171 ///
172 /// This is the safe variant that takes a `&CudaStream` reference, ensuring
173 /// the stream is valid for the duration of the call.
174 ///
175 /// The allocation is stream-ordered; the memory is available for use
176 /// after all preceding operations on the stream complete.
177 ///
178 /// # Host Serialization
179 ///
180 /// This method acquires an internal mutex because `cuMemAllocFromPoolAsync`
181 /// is not host-thread reentrant. The allocation itself is stream-ordered on
182 /// the GPU side.
183 ///
184 /// # Arguments
185 /// * `size` - Size in bytes to allocate
186 /// * `stream` - CUDA stream for async ordering
187 ///
188 /// # Returns
189 /// Device pointer to the allocated memory
190 pub fn alloc_async(&self, size: usize, stream: &CudaStream) -> Result<u64> {
191 // SAFETY: stream.cu_stream() returns a valid handle owned by the CudaStream,
192 // and the borrow ensures the stream lives for the duration of this call.
193 unsafe { self.alloc_async_raw(size, stream.cu_stream()) }
194 }
195
196 /// Allocate memory from the pool asynchronously (raw stream handle variant).
197 ///
198 /// This is the unsafe variant for use when you have a raw `CUstream` handle
199 /// from sources other than cudarc's `CudaStream`.
200 ///
201 /// # Host Serialization
202 ///
203 /// This method acquires an internal mutex because `cuMemAllocFromPoolAsync`
204 /// is not host-thread reentrant.
205 ///
206 /// # Arguments
207 /// * `size` - Size in bytes to allocate
208 /// * `stream` - Raw CUDA stream handle for async ordering
209 ///
210 /// # Returns
211 /// Device pointer to the allocated memory
212 ///
213 /// # Safety
214 ///
215 /// The caller must ensure that `stream` is a valid CUDA stream handle that
216 /// will remain valid for the duration of this call.
217 pub unsafe fn alloc_async_raw(&self, size: usize, stream: CUstream) -> Result<u64> {
218 let pool = self
219 .inner
220 .lock()
221 .map_err(|e| anyhow!("mutex poisoned: {}", e))?;
222
223 let mut ptr: u64 = 0;
224
225 let result = unsafe { sys::cuMemAllocFromPoolAsync(&mut ptr, size, *pool, stream) };
226
227 if result != CUresult::CUDA_SUCCESS {
228 return Err(anyhow!(
229 "cuMemAllocFromPoolAsync failed with error: {:?}",
230 result
231 ));
232 }
233
234 Ok(ptr)
235 }
236
237 /// Free memory back to the pool asynchronously.
238 ///
239 /// This is the safe variant that takes a `&CudaStream` reference.
240 ///
241 /// The memory is returned to the pool's reservoir (not the OS) and can be
242 /// reused by subsequent allocations. The free is stream-ordered.
243 ///
244 /// # Arguments
245 /// * `ptr` - Device pointer previously allocated from this pool
246 /// * `stream` - CUDA stream for async ordering
247 pub fn free_async(&self, ptr: u64, stream: &CudaStream) -> Result<()> {
248 // SAFETY: stream.cu_stream() returns a valid handle owned by the CudaStream,
249 // and the borrow ensures the stream lives for the duration of this call.
250 unsafe { self.free_async_raw(ptr, stream.cu_stream()) }
251 }
252
253 // NOTE: Unlike alloc_async_raw, this method does NOT acquire the pool mutex.
254 // The mutex in alloc_async_raw ensures each allocation returns a unique pointer.
255 // cuMemFreeAsync only enqueues a stream-ordered free operation for that unique
256 // pointer - multiple threads can safely enqueue frees for different unique pointers
257 // concurrently. The actual return-to-pool happens asynchronously on the GPU side.
258
259 /// Free memory back to the pool asynchronously (raw stream handle variant).
260 ///
261 /// This is the unsafe variant for use when you have a raw `CUstream` handle.
262 ///
263 /// The memory is returned to the pool's reservoir (not the OS) and can be
264 /// reused by subsequent allocations. The free is stream-ordered.
265 ///
266 /// # Arguments
267 /// * `ptr` - Device pointer previously allocated from this pool
268 /// * `stream` - Raw CUDA stream handle for async ordering
269 ///
270 /// # Safety
271 ///
272 /// The caller must ensure that:
273 /// - `ptr` is a valid device pointer previously allocated from this pool
274 /// - `stream` is a valid CUDA stream handle
275 pub unsafe fn free_async_raw(&self, ptr: u64, stream: CUstream) -> Result<()> {
276 let result = unsafe { sys::cuMemFreeAsync(ptr, stream) };
277
278 if result != CUresult::CUDA_SUCCESS {
279 return Err(anyhow!("cuMemFreeAsync failed with error: {:?}", result));
280 }
281
282 Ok(())
283 }
284}
285
286impl Drop for CudaMemPool {
287 fn drop(&mut self) {
288 // No need to lock - we have &mut self so exclusive access is guaranteed
289 let pool = self
290 .inner
291 .get_mut()
292 .expect("mutex should not be poisoned during drop");
293
294 // Destroy the pool, releasing all memory back to the system
295 let result = unsafe { sys::cuMemPoolDestroy(*pool) };
296 if result != CUresult::CUDA_SUCCESS {
297 tracing::warn!("cuMemPoolDestroy failed with error: {:?}", result);
298 }
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Acquire a context on device 0, or `None` when no CUDA device is
    /// available (tests skip rather than fail in that case).
    fn device_context() -> Option<Arc<CudaContext>> {
        match CudaContext::new(0) {
            Ok(ctx) => Some(ctx),
            Err(e) => {
                eprintln!("Skipping test - no CUDA device: {:?}", e);
                None
            }
        }
    }

    #[test]
    fn test_pool_creation_with_builder() {
        let Some(context) = device_context() else {
            return;
        };

        // Exercise both builder knobs: a warm reserve and a release threshold.
        let built = CudaMemPool::builder(context.clone(), 1024 * 1024) // 1 MiB reserve
            .release_threshold(64 * 1024 * 1024) // 64 MiB threshold
            .build();

        match built {
            Ok(pool) => drop(pool),
            Err(e) => eprintln!("Skipping test - pool creation failed: {:?}", Some(e)),
        }
    }

    #[test]
    fn test_pool_creation_no_threshold() {
        let Some(context) = device_context() else {
            return;
        };

        // Minimal configuration: zero reserve, no release threshold.
        match CudaMemPool::builder(context, 0).build() {
            Ok(pool) => drop(pool),
            Err(e) => eprintln!("Skipping test - pool creation failed: {:?}", Some(e)),
        }
    }
}
351}