Skip to main content

baracuda_runtime/
mempool.rs

1//! Stream-ordered memory pools (Runtime API, CUDA 11.2+).
2//!
3//! Mirrors [`baracuda_driver::mempool`] — a pool is a device-backed
4//! allocator with a configurable release threshold, accessed via
5//! `cudaMallocFromPoolAsync` and returned with `cudaFreeAsync`.
6//! Each device exposes a default pool via [`default_pool`].
7
8use std::sync::Arc;
9
10use baracuda_cuda_sys::runtime::runtime;
11use baracuda_cuda_sys::runtime::types::{
12    cudaMemAccessDesc, cudaMemAllocationHandleType, cudaMemAllocationType, cudaMemLocation,
13    cudaMemLocationType, cudaMemPoolAttr, cudaMemPoolProps, cudaMemPoolPtrExportData,
14    cudaMemPool_t,
15};
16
17use crate::device::Device;
18use crate::error::{check, Result};
19use crate::stream::Stream;
20
/// Access rights granted to a device for a pool's allocations.
///
/// Each variant corresponds to one of the runtime's `cudaMemAccessFlags`
/// constants (`NONE`, `READ`, `READ_WRITE`).
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum AccessFlags {
    /// No access (`cudaMemAccessFlags::NONE`).
    None,
    /// Read access (`cudaMemAccessFlags::READ`).
    Read,
    /// Read and write access (`cudaMemAccessFlags::READ_WRITE`).
    ReadWrite,
}
28
29impl AccessFlags {
30    #[inline]
31    fn raw(self) -> core::ffi::c_int {
32        use baracuda_cuda_sys::runtime::types::cudaMemAccessFlags;
33        match self {
34            AccessFlags::None => cudaMemAccessFlags::NONE,
35            AccessFlags::Read => cudaMemAccessFlags::READ,
36            AccessFlags::ReadWrite => cudaMemAccessFlags::READ_WRITE,
37        }
38    }
39
40    #[inline]
41    fn from_raw(raw: core::ffi::c_int) -> Self {
42        use baracuda_cuda_sys::runtime::types::cudaMemAccessFlags;
43        match raw {
44            x if x == cudaMemAccessFlags::READ => AccessFlags::Read,
45            x if x == cudaMemAccessFlags::READ_WRITE => AccessFlags::ReadWrite,
46            _ => AccessFlags::None,
47        }
48    }
49}
50
/// A memory pool. Owned pools are destroyed on last-clone drop; borrowed
/// pools (returned by [`default_pool`] / [`current_pool`]) are not.
///
/// Cloning is cheap: clones share the same underlying handle through an
/// [`Arc`].
#[derive(Clone)]
pub struct MemoryPool {
    // Shared handle state; the pool handle is released (if owned) when the
    // last `Arc` reference is dropped.
    inner: Arc<MemoryPoolInner>,
}
57
/// Shared state behind a [`MemoryPool`]: the raw pool handle plus an
/// ownership flag consulted on drop.
struct MemoryPoolInner {
    // Raw runtime pool handle.
    handle: cudaMemPool_t,
    // `true` when this wrapper created the pool and must destroy it on drop;
    // `false` for handles wrapped via `MemoryPool::from_borrowed`.
    owned: bool,
}

// NOTE(review): these impls assert that the raw pool handle may be moved to
// and used from any host thread. Nothing in this file establishes that —
// presumably it relies on the CUDA runtime's thread-safety guarantees for
// memory-pool handles; confirm against the CUDA documentation.
unsafe impl Send for MemoryPoolInner {}
unsafe impl Sync for MemoryPoolInner {}
65
66impl core::fmt::Debug for MemoryPoolInner {
67    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
68        f.debug_struct("MemoryPool")
69            .field("handle", &self.handle)
70            .field("owned", &self.owned)
71            .finish()
72    }
73}
74
75impl core::fmt::Debug for MemoryPool {
76    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
77        self.inner.fmt(f)
78    }
79}
80
81impl MemoryPool {
82    /// Create a fresh pool backed on `device`.
83    pub fn new(device: &Device) -> Result<Self> {
84        let r = runtime()?;
85        let cu = r.cuda_mem_pool_create()?;
86        let props = cudaMemPoolProps {
87            alloc_type: cudaMemAllocationType::PINNED,
88            handle_types: cudaMemAllocationHandleType::NONE,
89            location: cudaMemLocation {
90                type_: cudaMemLocationType::DEVICE,
91                id: device.ordinal(),
92            },
93            ..Default::default()
94        };
95        let mut handle: cudaMemPool_t = core::ptr::null_mut();
96        check(unsafe { cu(&mut handle, &props) })?;
97        Ok(Self {
98            inner: Arc::new(MemoryPoolInner {
99                handle,
100                owned: true,
101            }),
102        })
103    }
104
105    /// Wrap a raw pool handle without taking ownership.
106    ///
107    /// # Safety
108    ///
109    /// `handle` must outlive this wrapper.
110    pub unsafe fn from_borrowed(handle: cudaMemPool_t) -> Self {
111        Self {
112            inner: Arc::new(MemoryPoolInner {
113                handle,
114                owned: false,
115            }),
116        }
117    }
118
119    #[inline]
120    pub fn as_raw(&self) -> cudaMemPool_t {
121        self.inner.handle
122    }
123
124    /// Set the release threshold (bytes retained before the pool starts
125    /// returning memory to the OS). Default is 0.
126    pub fn set_release_threshold(&self, bytes: u64) -> Result<()> {
127        let r = runtime()?;
128        let cu = r.cuda_mem_pool_set_attribute()?;
129        let mut v = bytes;
130        check(unsafe {
131            cu(
132                self.inner.handle,
133                cudaMemPoolAttr::RELEASE_THRESHOLD,
134                &mut v as *mut u64 as *mut core::ffi::c_void,
135            )
136        })
137    }
138
139    pub fn release_threshold(&self) -> Result<u64> {
140        self.get_u64_attr(cudaMemPoolAttr::RELEASE_THRESHOLD)
141    }
142
143    /// Current bytes handed out to allocations.
144    pub fn used_bytes(&self) -> Result<u64> {
145        self.get_u64_attr(cudaMemPoolAttr::USED_MEM_CURRENT)
146    }
147
148    /// Current bytes reserved for the pool (used + kept-free).
149    pub fn reserved_bytes(&self) -> Result<u64> {
150        self.get_u64_attr(cudaMemPoolAttr::RESERVED_MEM_CURRENT)
151    }
152
153    fn get_u64_attr(&self, attr: i32) -> Result<u64> {
154        let r = runtime()?;
155        let cu = r.cuda_mem_pool_get_attribute()?;
156        let mut v: u64 = 0;
157        check(unsafe {
158            cu(
159                self.inner.handle,
160                attr,
161                &mut v as *mut u64 as *mut core::ffi::c_void,
162            )
163        })?;
164        Ok(v)
165    }
166
167    /// Release memory down to `min_bytes_to_keep`.
168    pub fn trim_to(&self, min_bytes_to_keep: usize) -> Result<()> {
169        let r = runtime()?;
170        let cu = r.cuda_mem_pool_trim_to()?;
171        check(unsafe { cu(self.inner.handle, min_bytes_to_keep) })
172    }
173
174    /// Grant `device` the specified access to allocations from this pool.
175    pub fn set_access(&self, device: &Device, flags: AccessFlags) -> Result<()> {
176        let r = runtime()?;
177        let cu = r.cuda_mem_pool_set_access()?;
178        let desc = cudaMemAccessDesc {
179            location: cudaMemLocation {
180                type_: cudaMemLocationType::DEVICE,
181                id: device.ordinal(),
182            },
183            flags: flags.raw(),
184        };
185        check(unsafe { cu(self.inner.handle, &desc, 1) })
186    }
187
188    /// Query `device`'s access flags for this pool.
189    pub fn access(&self, device: &Device) -> Result<AccessFlags> {
190        let r = runtime()?;
191        let cu = r.cuda_mem_pool_get_access()?;
192        let mut loc = cudaMemLocation {
193            type_: cudaMemLocationType::DEVICE,
194            id: device.ordinal(),
195        };
196        let mut flags: core::ffi::c_int = 0;
197        check(unsafe { cu(&mut flags, self.inner.handle, &mut loc) })?;
198        Ok(AccessFlags::from_raw(flags))
199    }
200
201    /// Allocate `bytes` bytes of device memory from this pool, ordered on
202    /// `stream`. Returns a raw device pointer — free via
203    /// [`crate::DeviceBuffer::free_async`] or by calling
204    /// [`Self::free_async`] on the raw pointer.
205    pub fn alloc_async(&self, bytes: usize, stream: &Stream) -> Result<*mut core::ffi::c_void> {
206        let r = runtime()?;
207        let cu = r.cuda_malloc_from_pool_async()?;
208        let mut ptr: *mut core::ffi::c_void = core::ptr::null_mut();
209        check(unsafe { cu(&mut ptr, bytes, self.inner.handle, stream.as_raw()) })?;
210        Ok(ptr)
211    }
212
213    /// Free a device pointer previously returned by
214    /// [`Self::alloc_async`] (routes through `cudaFreeAsync`).
215    ///
216    /// # Safety
217    ///
218    /// `ptr` must be a live allocation from this (or another) pool.
219    pub unsafe fn free_async(&self, ptr: *mut core::ffi::c_void, stream: &Stream) -> Result<()> { unsafe {
220        let r = runtime()?;
221        let cu = r.cuda_free_async()?;
222        check(cu(ptr, stream.as_raw()))
223    }}
224
225    /// Export a pointer in this pool for sharing with a peer process.
226    ///
227    /// # Safety
228    ///
229    /// `ptr` must be a live allocation from this pool.
230    pub unsafe fn export_pointer(
231        &self,
232        ptr: *mut core::ffi::c_void,
233    ) -> Result<cudaMemPoolPtrExportData> { unsafe {
234        let r = runtime()?;
235        let cu = r.cuda_mem_pool_export_pointer()?;
236        let mut data = cudaMemPoolPtrExportData::default();
237        check(cu(&mut data, ptr))?;
238        Ok(data)
239    }}
240
241    /// Import an exported pointer into this pool.
242    pub fn import_pointer(
243        &self,
244        mut data: cudaMemPoolPtrExportData,
245    ) -> Result<*mut core::ffi::c_void> {
246        let r = runtime()?;
247        let cu = r.cuda_mem_pool_import_pointer()?;
248        let mut ptr: *mut core::ffi::c_void = core::ptr::null_mut();
249        check(unsafe { cu(&mut ptr, self.inner.handle, &mut data) })?;
250        Ok(ptr)
251    }
252}
253
254impl Drop for MemoryPoolInner {
255    fn drop(&mut self) {
256        if !self.owned || self.handle.is_null() {
257            return;
258        }
259        if let Ok(r) = runtime() {
260            if let Ok(cu) = r.cuda_mem_pool_destroy() {
261                let _ = unsafe { cu(self.handle) };
262            }
263        }
264    }
265}
266
267/// Return the device's default memory pool (borrowed — not destroyed on drop).
268pub fn default_pool(device: &Device) -> Result<MemoryPool> {
269    let r = runtime()?;
270    let cu = r.cuda_device_get_default_mem_pool()?;
271    let mut handle: cudaMemPool_t = core::ptr::null_mut();
272    check(unsafe { cu(&mut handle, device.ordinal()) })?;
273    // SAFETY: the runtime owns the default pool; we wrap non-owning.
274    Ok(unsafe { MemoryPool::from_borrowed(handle) })
275}
276
277/// Return the pool currently used by `cudaMallocAsync` on `device`.
278pub fn current_pool(device: &Device) -> Result<MemoryPool> {
279    let r = runtime()?;
280    let cu = r.cuda_device_get_mem_pool()?;
281    let mut handle: cudaMemPool_t = core::ptr::null_mut();
282    check(unsafe { cu(&mut handle, device.ordinal()) })?;
283    Ok(unsafe { MemoryPool::from_borrowed(handle) })
284}
285
286/// Replace the pool used by `cudaMallocAsync` on `device`.
287pub fn set_current_pool(device: &Device, pool: &MemoryPool) -> Result<()> {
288    let r = runtime()?;
289    let cu = r.cuda_device_set_mem_pool()?;
290    check(unsafe { cu(device.ordinal(), pool.as_raw()) })
291}