// oxicuda_memory/pool.rs
1//! Stream-ordered memory pool for efficient async allocation.
2//!
3//! Requires CUDA 11.2+ driver.  Gated behind the `pool` feature.
4//!
5//! Stream-ordered memory pools allow allocation and deallocation to be
6//! ordered relative to other operations on a CUDA stream, enabling the
7//! driver to reuse memory more aggressively and avoid synchronisation
8//! barriers that would otherwise be needed for conventional
9//! `cuMemAlloc` / `cuMemFree` calls.
10//!
11//! # Implementation note
12//!
13//! This implementation provides a practical fallback pool that reuses freed
14//! allocations by size and uses `cuMemAlloc_v2` / `cuMemFree_v2` under the
15//! hood.  It keeps the same API surface as a stream-ordered pool, but does
16//! not yet expose native CUDA mempool handles.
17//!
18//! # API
19//!
20//! ```rust,ignore
21//! let pool = MemoryPool::new(device)?;
22//! let buf = PooledBuffer::<f32>::alloc_async(&pool, 1024, &stream)?;
23//! // … use buf in kernels on `stream` …
//! // buf is returned to the pool's free list when dropped.
25//! ```
26
27#![cfg(feature = "pool")]
28
29use std::collections::HashMap;
30use std::marker::PhantomData;
31use std::sync::atomic::{AtomicUsize, Ordering};
32use std::sync::{Arc, Mutex};
33
34use oxicuda_driver::error::{CudaError, CudaResult, check};
35use oxicuda_driver::ffi::{
36    CUdeviceptr, CUmemAllocationHandleType, CUmemAllocationType, CUmemLocation, CUmemLocationType,
37    CUmemPoolProps, CUmemoryPool,
38};
39use oxicuda_driver::loader::try_driver;
40use oxicuda_driver::stream::Stream;
41use tracing::warn;
42
43// ---------------------------------------------------------------------------
44// MemoryPool
45// ---------------------------------------------------------------------------
46
/// Statistics for a memory pool's allocation behaviour.
///
/// These statistics track the total bytes allocated, peak usage,
/// allocation count, and free count for a given pool.  A snapshot is
/// obtained via [`MemoryPool::stats`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct PoolStats {
    /// Total bytes currently allocated from the pool.
    pub allocated_bytes: usize,
    /// Peak bytes allocated at any point during the pool's lifetime.
    pub peak_bytes: usize,
    /// Total number of allocations performed.
    pub allocation_count: u64,
    /// Total number of frees performed.
    pub free_count: u64,
}
75
/// Shared pool state, held via `Arc` by the [`MemoryPool`] front-end and by
/// every live [`PooledBuffer`], so buffers stay valid regardless of drop
/// order.
#[derive(Debug)]
struct MemoryPoolInner {
    // Raw pool handle; currently always 0 (no native CUDA mempool is
    // created yet — see `MemoryPool::raw_handle`).
    handle: u64,
    // Device ordinal this pool targets.
    device_ordinal: i32,
    // Auto-trim threshold in bytes: frees trim the cache down to this.
    threshold_bytes: AtomicUsize,
    // Total bytes currently sitting in `free_bins` awaiting reuse.
    cached_bytes: AtomicUsize,
    // Lifetime allocation statistics.
    stats: Mutex<PoolStats>,
    // Freed allocations binned by exact byte size for reuse.
    free_bins: Mutex<HashMap<usize, Vec<CUdeviceptr>>>,
}
85
86impl MemoryPoolInner {
87    fn allocate_fresh(&self, bytes: usize) -> CudaResult<CUdeviceptr> {
88        let api = try_driver()?;
89        let mut ptr: CUdeviceptr = 0;
90        let rc = unsafe { (api.cu_mem_alloc_v2)(&mut ptr, bytes) };
91        oxicuda_driver::check(rc)?;
92        Ok(ptr)
93    }
94
95    fn free_ptr(&self, ptr: CUdeviceptr) -> CudaResult<()> {
96        let api = try_driver()?;
97        let rc = unsafe { (api.cu_mem_free_v2)(ptr) };
98        oxicuda_driver::check(rc)
99    }
100
101    fn try_pop_reuse(&self, bytes: usize) -> CudaResult<Option<CUdeviceptr>> {
102        let mut bins = self.free_bins.lock().map_err(|_| CudaError::Unknown(0))?;
103        let maybe_ptr = bins.get_mut(&bytes).and_then(Vec::pop);
104        if maybe_ptr.is_some() {
105            self.cached_bytes.fetch_sub(bytes, Ordering::Relaxed);
106        }
107        Ok(maybe_ptr)
108    }
109
110    fn stash_freed(&self, ptr: CUdeviceptr, bytes: usize) -> CudaResult<()> {
111        let mut bins = self.free_bins.lock().map_err(|_| CudaError::Unknown(0))?;
112        bins.entry(bytes).or_default().push(ptr);
113        self.cached_bytes.fetch_add(bytes, Ordering::Relaxed);
114        Ok(())
115    }
116
117    fn release_cached_until(&self, keep_bytes: usize) -> CudaResult<()> {
118        loop {
119            let cached = self.cached_bytes.load(Ordering::Relaxed);
120            if cached <= keep_bytes {
121                return Ok(());
122            }
123
124            let popped = {
125                let mut bins = self.free_bins.lock().map_err(|_| CudaError::Unknown(0))?;
126                let mut candidate: Option<(usize, CUdeviceptr)> = None;
127                for (size, vec) in bins.iter_mut() {
128                    if let Some(ptr) = vec.pop() {
129                        candidate = Some((*size, ptr));
130                        break;
131                    }
132                }
133                candidate
134            };
135
136            let Some((size, ptr)) = popped else {
137                return Ok(());
138            };
139            self.free_ptr(ptr)?;
140            self.cached_bytes.fetch_sub(size, Ordering::Relaxed);
141        }
142    }
143
144    fn update_alloc_stats(&self, bytes: usize) {
145        if let Ok(mut stats) = self.stats.lock() {
146            stats.allocated_bytes = stats.allocated_bytes.saturating_add(bytes);
147            stats.allocation_count = stats.allocation_count.saturating_add(1);
148            if stats.allocated_bytes > stats.peak_bytes {
149                stats.peak_bytes = stats.allocated_bytes;
150            }
151        }
152    }
153
154    fn update_free_stats(&self, bytes: usize) {
155        if let Ok(mut stats) = self.stats.lock() {
156            stats.allocated_bytes = stats.allocated_bytes.saturating_sub(bytes);
157            stats.free_count = stats.free_count.saturating_add(1);
158        }
159    }
160}
161
162impl Drop for MemoryPoolInner {
163    fn drop(&mut self) {
164        let Ok(mut bins) = self.free_bins.lock() else {
165            return;
166        };
167        let mut to_free: Vec<CUdeviceptr> = Vec::new();
168        for vec in bins.values_mut() {
169            to_free.append(vec);
170        }
171        drop(bins);
172
173        for ptr in to_free {
174            if let Err(e) = self.free_ptr(ptr) {
175                warn!("failed to free pooled pointer {ptr:#x} during drop: {e}");
176            }
177        }
178    }
179}
180
/// A stream-ordered memory pool (CUDA 11.2+).
///
/// `MemoryPool` is a software pool layered on top of `cuMemAlloc_v2` /
/// `cuMemFree_v2`: freed [`PooledBuffer`]s are cached by exact byte size
/// and reused by later same-sized allocations, reducing allocation latency.
/// For a thin wrapper over the *native* CUDA stream-ordered memory pool
/// API (`cuMemPoolCreate`, `cuMemAllocFromPoolAsync`, `cuMemFreeAsync`),
/// use [`NativeMemoryPool`].
pub struct MemoryPool {
    /// Shared pool state, also held by every live [`PooledBuffer`].
    inner: Arc<MemoryPoolInner>,
}
185
impl MemoryPool {
    /// Creates an in-process pooling allocator for the given device.
    ///
    /// # Errors
    ///
    /// Returns [`CudaError::InvalidDevice`] if `device_ordinal` is
    /// negative.  No driver call is made here; device memory is only
    /// touched when buffers are allocated.
    pub fn new(device_ordinal: i32) -> CudaResult<Self> {
        if device_ordinal < 0 {
            return Err(CudaError::InvalidDevice);
        }
        Ok(Self {
            inner: Arc::new(MemoryPoolInner {
                handle: 0,
                device_ordinal,
                threshold_bytes: AtomicUsize::new(0),
                cached_bytes: AtomicUsize::new(0),
                stats: Mutex::new(PoolStats::default()),
                free_bins: Mutex::new(HashMap::new()),
            }),
        })
    }

    /// Returns the raw pool handle.
    ///
    /// # Status
    ///
    /// Returns `0` until the pool is properly initialised — this software
    /// pool does not yet create a native CUDA mempool handle.
    #[inline]
    pub fn raw_handle(&self) -> u64 {
        self.inner.handle
    }

    /// Returns the device ordinal this pool targets.
    #[inline]
    pub fn device_ordinal(&self) -> i32 {
        self.inner.device_ordinal
    }

    /// Returns a snapshot of the current pool statistics.
    ///
    /// Falls back to `PoolStats::default()` if the stats mutex is poisoned.
    #[inline]
    pub fn stats(&self) -> PoolStats {
        self.inner.stats.lock().map(|s| *s).unwrap_or_default()
    }

    /// Trims the pool, releasing unused memory back to the OS.
    ///
    /// Attempts to release memory such that the pool retains at most
    /// `min_bytes` of unused (cached) memory.
    ///
    /// # Errors
    ///
    /// Returns an error if the free-list mutex is poisoned, if no CUDA
    /// driver is available, or if a driver free call fails.
    pub fn trim(&mut self, min_bytes: usize) -> CudaResult<()> {
        self.inner.release_cached_until(min_bytes)
    }

    /// Sets the threshold at which the pool will automatically release
    /// memory back to the OS.
    ///
    /// When the pool's unused memory exceeds `bytes`, subsequent frees
    /// will trigger automatic trimming.  The cache is also trimmed down to
    /// the new threshold immediately.
    ///
    /// # Errors
    ///
    /// Returns an error if the free-list mutex is poisoned, if no CUDA
    /// driver is available, or if a driver free call fails.
    pub fn set_threshold(&mut self, bytes: usize) -> CudaResult<()> {
        self.inner.threshold_bytes.store(bytes, Ordering::Relaxed);
        self.inner.release_cached_until(bytes)
    }
}
256
257// ---------------------------------------------------------------------------
258// PooledBuffer<T>
259// ---------------------------------------------------------------------------
260
/// A device buffer allocated from a [`MemoryPool`].
///
/// Unlike [`DeviceBuffer`](crate::DeviceBuffer), dropping a `PooledBuffer`
/// does not necessarily return memory to the driver: the allocation is
/// stashed in the pool's size-keyed free list so later same-sized
/// allocations can reuse it without a driver round-trip.
///
/// # Status
///
/// This type allocates from an in-process memory pool and returns buffers
/// to that pool on drop.  The drop path runs synchronously on the CPU; no
/// stream-ordered free is enqueued by the current implementation.
pub struct PooledBuffer<T: Copy> {
    /// Raw device pointer to the pooled allocation.
    ptr: CUdeviceptr,
    /// Number of `T` elements.
    len: usize,
    /// Number of bytes in this allocation (the pool's reuse-bin key).
    bytes: usize,
    /// Owning pool; keeps the shared pool state alive while this buffer
    /// exists.
    pool: Arc<MemoryPoolInner>,
    /// Marker for the element type.
    _phantom: PhantomData<T>,
}
284
285impl<T: Copy> PooledBuffer<T> {
286    /// Asynchronously allocates a buffer of `n` elements from the given pool.
287    ///
288    /// The allocation is ordered relative to other operations on `stream`.
289    ///
290    /// # Errors
291    ///
292    pub fn alloc_async(pool: &MemoryPool, n: usize, _stream: &Stream) -> CudaResult<Self> {
293        if n == 0 {
294            return Err(CudaError::InvalidValue);
295        }
296        let bytes = n
297            .checked_mul(std::mem::size_of::<T>())
298            .ok_or(CudaError::InvalidValue)?;
299        let ptr = if let Some(reused) = pool.inner.try_pop_reuse(bytes)? {
300            reused
301        } else {
302            pool.inner.allocate_fresh(bytes)?
303        };
304        pool.inner.update_alloc_stats(bytes);
305
306        Ok(Self {
307            ptr,
308            len: n,
309            bytes,
310            pool: Arc::clone(&pool.inner),
311            _phantom: PhantomData,
312        })
313    }
314
315    /// Returns the number of `T` elements in this buffer.
316    #[inline]
317    pub fn len(&self) -> usize {
318        self.len
319    }
320
321    /// Returns `true` if the buffer contains zero elements.
322    #[inline]
323    pub fn is_empty(&self) -> bool {
324        self.len == 0
325    }
326
327    /// Returns the total size of the allocation in bytes.
328    #[inline]
329    pub fn byte_size(&self) -> usize {
330        self.bytes
331    }
332
333    /// Returns the raw [`CUdeviceptr`] handle.
334    #[inline]
335    pub fn as_device_ptr(&self) -> CUdeviceptr {
336        self.ptr
337    }
338}
339
340impl<T: Copy> Drop for PooledBuffer<T> {
341    fn drop(&mut self) {
342        if self.ptr == 0 {
343            return;
344        }
345
346        if let Err(e) = self.pool.stash_freed(self.ptr, self.bytes) {
347            warn!("failed to return pooled pointer to free list: {e}; freeing directly");
348            if let Err(free_err) = self.pool.free_ptr(self.ptr) {
349                warn!("direct free of pooled pointer failed: {free_err}");
350            }
351            self.pool.update_free_stats(self.bytes);
352            self.ptr = 0;
353            return;
354        }
355
356        self.pool.update_free_stats(self.bytes);
357        let threshold = self.pool.threshold_bytes.load(Ordering::Relaxed);
358        if let Err(e) = self.pool.release_cached_until(threshold) {
359            warn!("pool threshold trim failed: {e}");
360        }
361        self.ptr = 0;
362    }
363}
364
365// ---------------------------------------------------------------------------
366// NativeMemoryPool — thin wrapper over the CUDA stream-ordered pool API
367// ---------------------------------------------------------------------------
368
/// Configuration for a [`NativeMemoryPool`].
///
/// The `Default` value targets device 0 with no size cap.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NativeMemoryPoolProps {
    /// Device ordinal that physically backs the pool.
    pub device_ordinal: i32,
    /// Maximum aggregate size (bytes) the pool may hold.  `0` = unlimited.
    pub max_size_bytes: usize,
}
377
/// Thin wrapper around the CUDA driver's stream-ordered memory pool
/// (`cuMemPoolCreate` / `cuMemPoolDestroy`).
///
/// Allocations are issued via [`NativeMemoryPool::alloc_async`] which
/// invokes `cuMemAllocFromPoolAsync`; frees are issued via
/// [`NativeMemoryPool::free_async`] which invokes `cuMemFreeAsync`.
///
/// # Stream-ordering
///
/// The CUDA stream-ordered pool API requires the caller to ensure all
/// outstanding work on the stream has completed before destroying the
/// pool.  The [`Drop`] implementation calls `cuMemPoolDestroy` and
/// silently swallows any error to honour the standard Drop convention.
/// Call [`NativeMemoryPool::destroy`] explicitly to surface destruction
/// errors.
///
/// # Status
///
/// On systems without a CUDA driver (e.g. macOS), [`NativeMemoryPool::new`]
/// fails with [`CudaError::NotInitialized`].  On older drivers that lack
/// the pool entry points it fails with [`CudaError::NotSupported`].
pub struct NativeMemoryPool {
    /// Opaque driver pool handle; null after [`NativeMemoryPool::destroy`].
    raw: CUmemoryPool,
    /// Device ordinal that physically backs the pool.
    device_ordinal: i32,
}
403
// SAFETY: `CUmemoryPool` is an opaque driver handle.  The CUDA driver is
// thread-safe; multiple threads may issue stream-ordered allocations from
// the same pool concurrently.  No interior state lives on the Rust side,
// so sharing `&NativeMemoryPool` across threads is sound.
unsafe impl Send for NativeMemoryPool {}
unsafe impl Sync for NativeMemoryPool {}
409
impl NativeMemoryPool {
    /// Creates a new native memory pool on the device described by `props`.
    ///
    /// # Errors
    ///
    /// * [`CudaError::InvalidDevice`] if `props.device_ordinal` is
    ///   negative.
    /// * [`CudaError::NotInitialized`] if no CUDA driver is available.
    /// * [`CudaError::NotSupported`] if the driver does not export
    ///   `cuMemPoolCreate`.
    /// * Other [`CudaError`] variants on driver failure.
    pub fn new(props: NativeMemoryPoolProps) -> CudaResult<Self> {
        if props.device_ordinal < 0 {
            return Err(CudaError::InvalidDevice);
        }

        let api = try_driver()?;
        let f = api.cu_mem_pool_create.ok_or(CudaError::NotSupported)?;

        // Pinned device allocations located on the target device; no
        // exportable (IPC) handle types requested.
        let pool_props = CUmemPoolProps {
            alloc_type: CUmemAllocationType::Pinned as u32,
            handle_types: CUmemAllocationHandleType::None as u32,
            location: CUmemLocation {
                loc_type: CUmemLocationType::Device as u32,
                id: props.device_ordinal,
            },
            max_size: props.max_size_bytes,
            ..CUmemPoolProps::default()
        };

        let mut raw = CUmemoryPool::default();
        // SAFETY: `raw` is a valid out-parameter and `pool_props` outlives
        // the call.
        check(unsafe { f(&mut raw, &pool_props) })?;

        Ok(Self {
            raw,
            device_ordinal: props.device_ordinal,
        })
    }

    /// Returns the raw [`CUmemoryPool`] handle.
    #[inline]
    pub fn raw(&self) -> CUmemoryPool {
        self.raw
    }

    /// Returns the device ordinal that backs this pool.
    #[inline]
    pub fn device_ordinal(&self) -> i32 {
        self.device_ordinal
    }

    /// Asynchronously allocates `bytes` of memory from the pool, ordered
    /// against `stream`.
    ///
    /// # Errors
    ///
    /// * [`CudaError::InvalidValue`] if `bytes` is zero.
    /// * [`CudaError::NotInitialized`] if no CUDA driver is available.
    /// * [`CudaError::NotSupported`] if the driver does not export
    ///   `cuMemAllocFromPoolAsync`.
    /// * Other [`CudaError`] variants on driver failure.
    pub fn alloc_async(&self, bytes: usize, stream: &Stream) -> CudaResult<CUdeviceptr> {
        if bytes == 0 {
            return Err(CudaError::InvalidValue);
        }
        let api = try_driver()?;
        let f = api
            .cu_mem_alloc_from_pool_async
            .ok_or(CudaError::NotSupported)?;
        let mut ptr: CUdeviceptr = 0;
        // SAFETY: `ptr` is a valid out-parameter; `self.raw` and
        // `stream.raw()` are live driver handles.
        check(unsafe { f(&mut ptr, bytes, self.raw, stream.raw()) })?;
        Ok(ptr)
    }

    /// Asynchronously frees a pointer previously returned by
    /// [`alloc_async`](Self::alloc_async), ordered against `stream`.
    ///
    /// # Errors
    ///
    /// * [`CudaError::NotInitialized`] if no CUDA driver is available.
    /// * [`CudaError::NotSupported`] if the driver does not export
    ///   `cuMemFreeAsync`.
    /// * Other [`CudaError`] variants on driver failure.
    pub fn free_async(&self, ptr: CUdeviceptr, stream: &Stream) -> CudaResult<()> {
        let api = try_driver()?;
        let f = api.cu_mem_free_async.ok_or(CudaError::NotSupported)?;
        // SAFETY: caller guarantees `ptr` came from this pool's
        // `alloc_async`; the free is enqueued on a live stream handle.
        check(unsafe { f(ptr, stream.raw()) })
    }

    /// Destroys the pool, returning any driver error to the caller.
    ///
    /// The caller is responsible for ensuring all outstanding work on
    /// streams that allocated from this pool has completed before calling
    /// `destroy`.
    ///
    /// After this call returns, the [`Drop`] implementation will be a
    /// no-op.
    ///
    /// # Errors
    ///
    /// * [`CudaError::NotInitialized`] if no CUDA driver is available.
    /// * [`CudaError::NotSupported`] if the driver does not export
    ///   `cuMemPoolDestroy`.
    /// * Other [`CudaError`] variants on driver failure.
    pub fn destroy(mut self) -> CudaResult<()> {
        self.destroy_inner()
    }

    /// Shared destruction path for `destroy` and `Drop`.  Idempotent: a
    /// null handle is a successful no-op.
    fn destroy_inner(&mut self) -> CudaResult<()> {
        if self.raw.is_null() {
            return Ok(());
        }
        let api = try_driver()?;
        let f = api.cu_mem_pool_destroy.ok_or(CudaError::NotSupported)?;
        // SAFETY: `self.raw` is a live pool handle created by
        // `cuMemPoolCreate` and destroyed at most once.
        let result = check(unsafe { f(self.raw) });
        // Always clear the handle so Drop is a no-op even if destroy fails.
        self.raw = CUmemoryPool::default();
        result
    }
}
529
530impl Drop for NativeMemoryPool {
531    fn drop(&mut self) {
532        if let Err(e) = self.destroy_inner() {
533            warn!("failed to destroy native memory pool during drop: {e}");
534        }
535    }
536}
537
538// ---------------------------------------------------------------------------
539// Tests
540// ---------------------------------------------------------------------------
541
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `err` means the CUDA driver is absent or lacks the
    /// memory-pool entry points.
    fn is_driver_unavailable(err: &CudaError) -> bool {
        matches!(err, CudaError::NotInitialized | CudaError::NotSupported)
    }

    #[test]
    fn native_memory_pool_props_default() {
        let NativeMemoryPoolProps {
            device_ordinal,
            max_size_bytes,
        } = NativeMemoryPoolProps::default();
        assert_eq!(device_ordinal, 0);
        assert_eq!(max_size_bytes, 0);
    }

    #[test]
    fn native_memory_pool_new_negative_device_fails() {
        let result = NativeMemoryPool::new(NativeMemoryPoolProps {
            device_ordinal: -1,
            max_size_bytes: 0,
        });
        assert_eq!(result.err(), Some(CudaError::InvalidDevice));
    }

    /// Without a CUDA driver, `NativeMemoryPool::new` must fail with one of
    /// the driver-unavailability error kinds rather than panicking.
    #[test]
    fn native_memory_pool_new_no_driver_returns_driver_unavailable() {
        match NativeMemoryPool::new(NativeMemoryPoolProps::default()) {
            Ok(pool) => {
                // CUDA available: explicit destroy must succeed too.
                let destroy = pool.destroy();
                assert!(destroy.is_ok(), "destroy failed: {destroy:?}");
            }
            Err(e) => assert!(
                is_driver_unavailable(&e),
                "expected driver-unavailable error, got {e:?}"
            ),
        }
    }

    /// On macOS specifically, every driver-calling method must return
    /// [`CudaError::NotInitialized`] (no library to load).
    #[cfg(target_os = "macos")]
    #[test]
    fn macos_native_pool_returns_not_initialized() {
        let Err(err) = NativeMemoryPool::new(NativeMemoryPoolProps::default()) else {
            panic!("expected NotInitialized on macOS, got Ok");
        };
        assert!(
            matches!(err, CudaError::NotInitialized),
            "expected NotInitialized, got {err:?}"
        );
    }
}