Skip to main content

baracuda_driver/
mempool.rs

1//! Stream-ordered memory pools (CUDA 11.2+).
2//!
3//! A [`MemoryPool`] is a GPU-backed allocator with a release threshold:
4//! `cuMemAllocFromPoolAsync` allocates out of the pool on a specific stream,
//! `cuMemFreeAsync` returns memory to the pool, and the pool keeps freed
//! memory reserved — up to its release threshold (see
//! [`MemoryPool::set_release_threshold`]) — instead of handing it straight
//! back to the OS, so later allocations from the pool are cheap. Framework
//! memory allocators (PyTorch's caching allocator, JAX, TensorFlow) are built
//! around the same caching idea.
9//!
10//! Every device has a default pool accessible via [`default_pool`]. You
11//! can also create independent pools via [`MemoryPool::new`].
12
13use core::ffi::c_void;
14use std::sync::Arc;
15
16use baracuda_cuda_sys::types::{
17    CUmemAccessDesc, CUmemAllocationHandleType, CUmemAllocationType, CUmemLocation,
18    CUmemLocationType, CUmemPoolProps, CUmemPoolPtrExportData, CUmemPool_attribute,
19};
20use baracuda_cuda_sys::{driver, CUdeviceptr, CUmemoryPool};
21
22use crate::context::Context;
23use crate::device::Device;
24use crate::error::{check, Result};
25use crate::stream::Stream;
26use crate::vmm::AccessFlags;
27
28/// A CUDA memory pool. Dropping the handle calls `cuMemPoolDestroy` (the
29/// default per-device pool is *not* owned by this type — see [`default_pool`]
30/// vs [`MemoryPool::new`]).
31#[derive(Clone)]
32pub struct MemoryPool {
33    inner: Arc<MemoryPoolInner>,
34}
35
36struct MemoryPoolInner {
37    handle: CUmemoryPool,
38    owned: bool,
39    #[allow(dead_code)]
40    context: Context,
41}
42
43unsafe impl Send for MemoryPoolInner {}
44unsafe impl Sync for MemoryPoolInner {}
45
46impl core::fmt::Debug for MemoryPoolInner {
47    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
48        f.debug_struct("MemoryPool")
49            .field("handle", &self.handle)
50            .field("owned", &self.owned)
51            .finish_non_exhaustive()
52    }
53}
54
55impl core::fmt::Debug for MemoryPool {
56    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
57        self.inner.fmt(f)
58    }
59}
60
61impl MemoryPool {
62    /// Create a fresh pool whose backing memory lives on `device`. The pool
63    /// is destroyed when the last `MemoryPool` clone drops.
64    pub fn new(context: &Context, device: &Device) -> Result<Self> {
65        context.set_current()?;
66        let d = driver()?;
67        let cu = d.cu_mem_pool_create()?;
68        let props = CUmemPoolProps {
69            alloc_type: CUmemAllocationType::PINNED,
70            handle_types: CUmemAllocationHandleType::NONE,
71            location: CUmemLocation {
72                type_: CUmemLocationType::DEVICE,
73                id: device.as_raw().0,
74            },
75            ..Default::default()
76        };
77        let mut handle: CUmemoryPool = core::ptr::null_mut();
78        check(unsafe { cu(&mut handle, &props) })?;
79        Ok(Self {
80            inner: Arc::new(MemoryPoolInner {
81                handle,
82                owned: true,
83                context: context.clone(),
84            }),
85        })
86    }
87
88    /// Wrap a raw pool handle without taking ownership. Drop is a no-op.
89    ///
90    /// # Safety
91    ///
92    /// `handle` must be a valid `CUmemoryPool`. The caller guarantees it
93    /// outlives this wrapper.
94    pub unsafe fn from_borrowed(context: &Context, handle: CUmemoryPool) -> Self {
95        Self {
96            inner: Arc::new(MemoryPoolInner {
97                handle,
98                owned: false,
99                context: context.clone(),
100            }),
101        }
102    }
103
104    /// Raw `CUmemoryPool`. Use with care.
105    #[inline]
106    pub fn as_raw(&self) -> CUmemoryPool {
107        self.inner.handle
108    }
109
110    /// `u64` release threshold — bytes above which the pool starts returning
111    /// memory to the OS. Default is 0 (aggressive release).
112    pub fn set_release_threshold(&self, bytes: u64) -> Result<()> {
113        let d = driver()?;
114        let cu = d.cu_mem_pool_set_attribute()?;
115        let mut v = bytes;
116        check(unsafe {
117            cu(
118                self.inner.handle,
119                CUmemPool_attribute::RELEASE_THRESHOLD,
120                &mut v as *mut u64 as *mut c_void,
121            )
122        })
123    }
124
125    pub fn release_threshold(&self) -> Result<u64> {
126        let d = driver()?;
127        let cu = d.cu_mem_pool_get_attribute()?;
128        let mut v: u64 = 0;
129        check(unsafe {
130            cu(
131                self.inner.handle,
132                CUmemPool_attribute::RELEASE_THRESHOLD,
133                &mut v as *mut u64 as *mut c_void,
134            )
135        })?;
136        Ok(v)
137    }
138
139    /// Current bytes handed out to allocations.
140    pub fn used_bytes(&self) -> Result<u64> {
141        let d = driver()?;
142        let cu = d.cu_mem_pool_get_attribute()?;
143        let mut v: u64 = 0;
144        check(unsafe {
145            cu(
146                self.inner.handle,
147                CUmemPool_attribute::USED_MEM_CURRENT,
148                &mut v as *mut u64 as *mut c_void,
149            )
150        })?;
151        Ok(v)
152    }
153
154    /// Current bytes reserved for the pool (used + free-but-kept).
155    pub fn reserved_bytes(&self) -> Result<u64> {
156        let d = driver()?;
157        let cu = d.cu_mem_pool_get_attribute()?;
158        let mut v: u64 = 0;
159        check(unsafe {
160            cu(
161                self.inner.handle,
162                CUmemPool_attribute::RESERVED_MEM_CURRENT,
163                &mut v as *mut u64 as *mut c_void,
164            )
165        })?;
166        Ok(v)
167    }
168
169    /// Release memory down to `min_bytes_to_keep` bytes.
170    pub fn trim_to(&self, min_bytes_to_keep: usize) -> Result<()> {
171        let d = driver()?;
172        let cu = d.cu_mem_pool_trim_to()?;
173        check(unsafe { cu(self.inner.handle, min_bytes_to_keep) })
174    }
175
176    /// Grant `device` the specified access to allocations from this pool.
177    /// Required for peer-access patterns.
178    pub fn set_access(&self, device: &Device, flags: AccessFlags) -> Result<()> {
179        let d = driver()?;
180        let cu = d.cu_mem_pool_set_access()?;
181        let desc = CUmemAccessDesc {
182            location: CUmemLocation {
183                type_: CUmemLocationType::DEVICE,
184                id: device.as_raw().0,
185            },
186            flags: flags.raw(),
187        };
188        check(unsafe { cu(self.inner.handle, &desc, 1) })
189    }
190
191    /// Query `device`'s access flags for this pool.
192    pub fn access(&self, device: &Device) -> Result<AccessFlags> {
193        let d = driver()?;
194        let cu = d.cu_mem_pool_get_access()?;
195        let mut loc = CUmemLocation {
196            type_: CUmemLocationType::DEVICE,
197            id: device.as_raw().0,
198        };
199        let mut flags: core::ffi::c_int = 0;
200        check(unsafe { cu(&mut flags, self.inner.handle, &mut loc) })?;
201        Ok(AccessFlags::from_raw(flags))
202    }
203
204    /// Allocate `bytes` bytes of device memory from this pool, ordered on
205    /// `stream`. Free via [`crate::DeviceBuffer::free_async`] or by letting
206    /// the returned buffer drop (sync free).
207    pub fn alloc_async(&self, bytes: usize, stream: &Stream) -> Result<CUdeviceptr> {
208        let d = driver()?;
209        let cu = d.cu_mem_alloc_from_pool_async()?;
210        let mut ptr = CUdeviceptr(0);
211        check(unsafe { cu(&mut ptr, bytes, self.inner.handle, stream.as_raw()) })?;
212        Ok(ptr)
213    }
214
215    /// Export an opaque blob that a peer process can import via
216    /// [`MemoryPool::import_pointer`]. Both ends must share the same pool
217    /// via its shareable-handle mechanism first (see
218    /// [`MemoryPool::export_to_shareable_handle`]).
219    pub fn export_pointer(&self, ptr: CUdeviceptr) -> Result<CUmemPoolPtrExportData> {
220        let d = driver()?;
221        let cu = d.cu_mem_pool_export_pointer()?;
222        let mut data = CUmemPoolPtrExportData::default();
223        check(unsafe { cu(&mut data, ptr) })?;
224        Ok(data)
225    }
226
227    /// Inverse of [`MemoryPool::export_pointer`]: resolve the exported blob
228    /// to a device pointer valid in the importing process.
229    pub fn import_pointer(&self, mut data: CUmemPoolPtrExportData) -> Result<CUdeviceptr> {
230        let d = driver()?;
231        let cu = d.cu_mem_pool_import_pointer()?;
232        let mut ptr = CUdeviceptr(0);
233        check(unsafe { cu(&mut ptr, self.inner.handle, &mut data) })?;
234        Ok(ptr)
235    }
236}
237
238impl AccessFlags {
239    #[inline]
240    fn from_raw(raw: core::ffi::c_int) -> Self {
241        use baracuda_cuda_sys::types::CUmemAccess_flags;
242        match raw {
243            x if x == CUmemAccess_flags::READ => AccessFlags::Read,
244            x if x == CUmemAccess_flags::READWRITE => AccessFlags::ReadWrite,
245            _ => AccessFlags::None,
246        }
247    }
248}
249
250impl Drop for MemoryPoolInner {
251    fn drop(&mut self) {
252        if !self.owned || self.handle.is_null() {
253            return;
254        }
255        if let Ok(d) = driver() {
256            if let Ok(cu) = d.cu_mem_pool_destroy() {
257                let _ = unsafe { cu(self.handle) };
258            }
259        }
260    }
261}
262
263/// Return the device's default memory pool — shared across the process,
264/// not owned by the caller.
265pub fn default_pool(context: &Context, device: &Device) -> Result<MemoryPool> {
266    context.set_current()?;
267    let d = driver()?;
268    let cu = d.cu_device_get_default_mem_pool()?;
269    let mut handle: CUmemoryPool = core::ptr::null_mut();
270    check(unsafe { cu(&mut handle, device.as_raw()) })?;
271    // SAFETY: returned handle is owned by the driver, not us — mark
272    // non-owning so our Drop doesn't try to destroy it.
273    Ok(unsafe { MemoryPool::from_borrowed(context, handle) })
274}
275
276/// Return the current memory pool that `cuMemAllocAsync` uses on `device`
277/// (defaults to the default pool unless changed via [`set_current_pool`]).
278pub fn current_pool(context: &Context, device: &Device) -> Result<MemoryPool> {
279    context.set_current()?;
280    let d = driver()?;
281    let cu = d.cu_device_get_mem_pool()?;
282    let mut handle: CUmemoryPool = core::ptr::null_mut();
283    check(unsafe { cu(&mut handle, device.as_raw()) })?;
284    Ok(unsafe { MemoryPool::from_borrowed(context, handle) })
285}
286
287/// Replace the pool used by `cuMemAllocAsync` on `device`. Pool must
288/// outlive all async allocations it services.
289pub fn set_current_pool(device: &Device, pool: &MemoryPool) -> Result<()> {
290    let d = driver()?;
291    let cu = d.cu_device_set_mem_pool()?;
292    check(unsafe { cu(device.as_raw(), pool.as_raw()) })
293}