Skip to main content

dynamo_memory/pool/
cuda.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! CUDA memory pool for efficient device memory allocation in hot paths.
5//!
6//! This module provides a safe wrapper around CUDA's memory pool APIs, enabling
7//! fast async allocations that avoid the overhead of cudaMalloc/cudaFree per call.
8//! Memory is returned to the pool on free and reused for subsequent allocations.
9//!
10//! # Thread Safety
11//!
12//! [`CudaMemPool`] uses internal locking to serialize host-side calls to the CUDA
13//! driver. This is required because `cuMemAllocFromPoolAsync` is not host-thread
14//! reentrant. The GPU-side operations remain stream-ordered and asynchronous.
15
16use anyhow::{Result, anyhow};
17use cudarc::driver::sys::{
18    self, CUmemAllocationType, CUmemLocationType, CUmemPool_attribute, CUmemPoolProps,
19    CUmemoryPool, CUresult, CUstream,
20};
21use cudarc::driver::{CudaContext, CudaStream};
22use std::ptr;
23use std::sync::{Arc, Mutex};
24
25/// Builder for creating a CUDA memory pool with configurable parameters.
26///
27/// # Example
28/// ```ignore
29/// let pool = CudaMemPoolBuilder::new(context, 64 * 1024 * 1024) // 64 MiB reserve
30///     .release_threshold(32 * 1024 * 1024) // 32 MiB release threshold
31///     .build()?;
32/// ```
33pub struct CudaMemPoolBuilder {
34    /// CUDA context for the target device.
35    context: Arc<CudaContext>,
36    /// Bytes to pre-allocate to warm the pool.
37    reserve_size: usize,
38    /// Optional threshold above which memory is returned to the system on free.
39    release_threshold: Option<u64>,
40}
41
42impl CudaMemPoolBuilder {
43    /// Create a new builder with the required reserve size.
44    ///
45    /// # Arguments
46    /// * `context` - CUDA context for the device
47    /// * `reserve_size` - Number of bytes to pre-allocate to warm the pool
48    pub fn new(context: Arc<CudaContext>, reserve_size: usize) -> Self {
49        Self {
50            context,
51            reserve_size,
52            release_threshold: None,
53        }
54    }
55
56    /// Set the release threshold for the pool.
57    ///
58    /// Memory above this threshold is returned to the system when freed.
59    /// If not set, no release threshold is configured (CUDA default behavior).
60    pub fn release_threshold(mut self, threshold: u64) -> Self {
61        self.release_threshold = Some(threshold);
62        self
63    }
64
65    /// Build the CUDA memory pool.
66    ///
67    /// This will:
68    /// 1. Create the pool
69    /// 2. Set the release threshold if configured
70    /// 3. Pre-allocate and free memory to warm the pool
71    pub fn build(self) -> Result<CudaMemPool> {
72        // Initialize pool properties
73        let mut props: CUmemPoolProps = unsafe { std::mem::zeroed() };
74        props.allocType = CUmemAllocationType::CU_MEM_ALLOCATION_TYPE_PINNED;
75        props.location.type_ = CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE;
76        props.location.id = self.context.cu_device();
77
78        let mut pool: CUmemoryPool = ptr::null_mut();
79
80        // Create the pool
81        let result = unsafe { sys::cuMemPoolCreate(&mut pool, &props) };
82        if result != CUresult::CUDA_SUCCESS {
83            return Err(anyhow!("cuMemPoolCreate failed with error: {:?}", result));
84        }
85
86        // Set release threshold if configured
87        if let Some(threshold) = self.release_threshold {
88            let result = unsafe {
89                sys::cuMemPoolSetAttribute(
90                    pool,
91                    CUmemPool_attribute::CU_MEMPOOL_ATTR_RELEASE_THRESHOLD,
92                    &threshold as *const u64 as *mut std::ffi::c_void,
93                )
94            };
95            if result != CUresult::CUDA_SUCCESS {
96                // Clean up on failure
97                unsafe { sys::cuMemPoolDestroy(pool) };
98                return Err(anyhow!(
99                    "cuMemPoolSetAttribute failed with error: {:?}",
100                    result
101                ));
102            }
103        }
104
105        let cuda_pool = CudaMemPool {
106            inner: Mutex::new(pool),
107        };
108
109        // Warm the pool by pre-allocating and freeing memory
110        if self.reserve_size > 0 {
111            // Create a temporary stream for warming
112            let stream = self.context.new_stream()?;
113
114            // Allocate to warm the pool (using safe variant)
115            let ptr = cuda_pool.alloc_async(self.reserve_size, &stream)?;
116
117            // Free back to pool (memory stays reserved)
118            cuda_pool.free_async(ptr, &stream)?;
119
120            // Synchronize to ensure operations complete
121            // SAFETY: stream.cu_stream() is valid for the lifetime of `stream`
122            let result = unsafe { sys::cuStreamSynchronize(stream.cu_stream()) };
123            if result != CUresult::CUDA_SUCCESS {
124                return Err(anyhow!(
125                    "cuStreamSynchronize failed with error: {:?}",
126                    result
127                ));
128            }
129        }
130
131        Ok(cuda_pool)
132    }
133}
134
135/// Safe wrapper around a CUDA memory pool.
136///
137/// The pool amortizes allocation overhead by maintaining a reservoir of device memory.
138/// Allocations are fast sub-allocations from this reservoir, and frees return memory
139/// to the pool rather than the OS (until the release threshold is exceeded).
140///
141/// # Thread Safety
142///
143/// This type uses internal locking to serialize host-side calls to CUDA driver APIs.
144/// `cuMemAllocFromPoolAsync` is not host-thread reentrant, so concurrent calls from
145/// multiple threads must be serialized. The GPU-side operations remain asynchronous
146/// and stream-ordered.
147///
148/// Use [`CudaMemPoolBuilder`] for configurable pool creation with pre-allocation.
149pub struct CudaMemPool {
150    /// Mutex protecting the pool handle for host-thread serialization.
151    ///
152    /// CUDA's `cuMemAllocFromPoolAsync` does not guarantee host-thread reentrancy,
153    /// so all calls to the pool must be serialized on the host side.
154    inner: Mutex<CUmemoryPool>,
155}
156
157// SAFETY: CudaMemPool is Send because the Mutex serializes all host-side access
158// to the pool handle, and CUDA driver state is thread-safe when properly serialized.
159unsafe impl Send for CudaMemPool {}
160
161// SAFETY: CudaMemPool is Sync because all access to the pool handle goes through
162// the Mutex, which serializes host-thread access. The CUDA driver requires this
163// serialization because cuMemAllocFromPoolAsync is not host-thread reentrant.
164unsafe impl Sync for CudaMemPool {}
165
166impl CudaMemPool {
167    /// Create a builder for a new CUDA memory pool.
168    ///
169    /// # Arguments
170    /// * `context` - CUDA context for the device
171    /// * `reserve_size` - Number of bytes to pre-allocate to warm the pool
172    pub fn builder(context: Arc<CudaContext>, reserve_size: usize) -> CudaMemPoolBuilder {
173        CudaMemPoolBuilder::new(context, reserve_size)
174    }
175
176    /// Allocate memory from the pool asynchronously.
177    ///
178    /// This is the safe variant that takes a `&CudaStream` reference, ensuring
179    /// the stream is valid for the duration of the call.
180    ///
181    /// The allocation is stream-ordered; the memory is available for use
182    /// after all preceding operations on the stream complete.
183    ///
184    /// # Host Serialization
185    ///
186    /// This method acquires an internal mutex because `cuMemAllocFromPoolAsync`
187    /// is not host-thread reentrant. The allocation itself is stream-ordered on
188    /// the GPU side.
189    ///
190    /// # Arguments
191    /// * `size` - Size in bytes to allocate
192    /// * `stream` - CUDA stream for async ordering
193    ///
194    /// # Returns
195    /// Device pointer to the allocated memory
196    pub fn alloc_async(&self, size: usize, stream: &CudaStream) -> Result<u64> {
197        // SAFETY: stream.cu_stream() returns a valid handle owned by the CudaStream,
198        // and the borrow ensures the stream lives for the duration of this call.
199        unsafe { self.alloc_async_raw(size, stream.cu_stream()) }
200    }
201
202    /// Allocate memory from the pool asynchronously (raw stream handle variant).
203    ///
204    /// This is the unsafe variant for use when you have a raw `CUstream` handle
205    /// from sources other than cudarc's `CudaStream`.
206    ///
207    /// # Host Serialization
208    ///
209    /// This method acquires an internal mutex because `cuMemAllocFromPoolAsync`
210    /// is not host-thread reentrant.
211    ///
212    /// # Arguments
213    /// * `size` - Size in bytes to allocate
214    /// * `stream` - Raw CUDA stream handle for async ordering
215    ///
216    /// # Returns
217    /// Device pointer to the allocated memory
218    ///
219    /// # Safety
220    ///
221    /// The caller must ensure that `stream` is a valid CUDA stream handle that
222    /// will remain valid for the duration of this call.
223    pub unsafe fn alloc_async_raw(&self, size: usize, stream: CUstream) -> Result<u64> {
224        let pool = self
225            .inner
226            .lock()
227            .map_err(|e| anyhow!("mutex poisoned: {}", e))?;
228
229        let mut ptr: u64 = 0;
230
231        let result = unsafe { sys::cuMemAllocFromPoolAsync(&mut ptr, size, *pool, stream) };
232
233        if result != CUresult::CUDA_SUCCESS {
234            return Err(anyhow!(
235                "cuMemAllocFromPoolAsync failed with error: {:?}",
236                result
237            ));
238        }
239
240        Ok(ptr)
241    }
242
243    /// Free memory back to the pool asynchronously.
244    ///
245    /// This is the safe variant that takes a `&CudaStream` reference.
246    ///
247    /// The memory is returned to the pool's reservoir (not the OS) and can be
248    /// reused by subsequent allocations. The free is stream-ordered.
249    ///
250    /// # Arguments
251    /// * `ptr` - Device pointer previously allocated from this pool
252    /// * `stream` - CUDA stream for async ordering
253    pub fn free_async(&self, ptr: u64, stream: &CudaStream) -> Result<()> {
254        // SAFETY: stream.cu_stream() returns a valid handle owned by the CudaStream,
255        // and the borrow ensures the stream lives for the duration of this call.
256        unsafe { self.free_async_raw(ptr, stream.cu_stream()) }
257    }
258
259    // NOTE: Unlike alloc_async_raw, this method does NOT acquire the pool mutex.
260    // The mutex in alloc_async_raw ensures each allocation returns a unique pointer.
261    // cuMemFreeAsync only enqueues a stream-ordered free operation for that unique
262    // pointer - multiple threads can safely enqueue frees for different unique pointers
263    // concurrently. The actual return-to-pool happens asynchronously on the GPU side.
264
265    /// Free memory back to the pool asynchronously (raw stream handle variant).
266    ///
267    /// This is the unsafe variant for use when you have a raw `CUstream` handle.
268    ///
269    /// The memory is returned to the pool's reservoir (not the OS) and can be
270    /// reused by subsequent allocations. The free is stream-ordered.
271    ///
272    /// # Arguments
273    /// * `ptr` - Device pointer previously allocated from this pool
274    /// * `stream` - Raw CUDA stream handle for async ordering
275    ///
276    /// # Safety
277    ///
278    /// The caller must ensure that:
279    /// - `ptr` is a valid device pointer previously allocated from this pool
280    /// - `stream` is a valid CUDA stream handle
281    pub unsafe fn free_async_raw(&self, ptr: u64, stream: CUstream) -> Result<()> {
282        let result = unsafe { sys::cuMemFreeAsync(ptr, stream) };
283
284        if result != CUresult::CUDA_SUCCESS {
285            return Err(anyhow!("cuMemFreeAsync failed with error: {:?}", result));
286        }
287
288        Ok(())
289    }
290}
291
292impl Drop for CudaMemPool {
293    fn drop(&mut self) {
294        // No need to lock - we have &mut self so exclusive access is guaranteed
295        let pool = self
296            .inner
297            .get_mut()
298            .expect("mutex should not be poisoned during drop");
299
300        // Destroy the pool, releasing all memory back to the system
301        let result = unsafe { sys::cuMemPoolDestroy(*pool) };
302        if result != CUresult::CUDA_SUCCESS {
303            tracing::warn!("cuMemPoolDestroy failed with error: {:?}", result);
304        }
305    }
306}
307
#[cfg(all(test, feature = "testing-cuda"))]
mod tests {
    use super::*;

    /// Open device 0, or log the reason and return `None` so the caller can skip.
    fn context_or_skip() -> Option<Arc<CudaContext>> {
        match CudaContext::new(0) {
            Ok(ctx) => Some(ctx),
            Err(e) => {
                eprintln!("Skipping test - no CUDA device: {:?}", e);
                None
            }
        }
    }

    /// Build the pool, or log the failure and return `None` so the caller can skip.
    fn build_or_skip(builder: CudaMemPoolBuilder) -> Option<CudaMemPool> {
        match builder.build() {
            Ok(pool) => Some(pool),
            Err(e) => {
                // Wrapped in `Some` so the message matches the `Option` debug form.
                eprintln!("Skipping test - pool creation failed: {:?}", Some(e));
                None
            }
        }
    }

    /// Builder path with both a warm-up reservation and a release threshold.
    #[test]
    fn test_pool_creation_with_builder() {
        let Some(context) = context_or_skip() else {
            return;
        };

        let builder = CudaMemPool::builder(context.clone(), 1024 * 1024) // 1 MiB reserve
            .release_threshold(64 * 1024 * 1024); // 64 MiB threshold

        let Some(pool) = build_or_skip(builder) else {
            return;
        };
        drop(pool);
    }

    /// Builder path with no warm-up (`reserve_size == 0`) and no release threshold.
    #[test]
    fn test_pool_creation_no_threshold() {
        let Some(context) = context_or_skip() else {
            return;
        };

        let Some(pool) = build_or_skip(CudaMemPool::builder(context, 0)) else {
            return;
        };
        drop(pool);
    }
}