baracuda_runtime/
mempool.rs1use std::sync::Arc;
9
10use baracuda_cuda_sys::runtime::runtime;
11use baracuda_cuda_sys::runtime::types::{
12 cudaMemAccessDesc, cudaMemAllocationHandleType, cudaMemAllocationType, cudaMemLocation,
13 cudaMemLocationType, cudaMemPoolAttr, cudaMemPoolProps, cudaMemPoolPtrExportData,
14 cudaMemPool_t,
15};
16
17use crate::device::Device;
18use crate::error::{check, Result};
19use crate::stream::Stream;
20
/// Level of access a device has to the allocations of a [`MemoryPool`].
///
/// Converted to and from the raw `cudaMemAccessFlags` values when talking to
/// the CUDA runtime.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum AccessFlags {
    /// No access.
    None,
    /// Read-only access.
    Read,
    /// Read and write access.
    ReadWrite,
}
28
29impl AccessFlags {
30 #[inline]
31 fn raw(self) -> core::ffi::c_int {
32 use baracuda_cuda_sys::runtime::types::cudaMemAccessFlags;
33 match self {
34 AccessFlags::None => cudaMemAccessFlags::NONE,
35 AccessFlags::Read => cudaMemAccessFlags::READ,
36 AccessFlags::ReadWrite => cudaMemAccessFlags::READ_WRITE,
37 }
38 }
39
40 #[inline]
41 fn from_raw(raw: core::ffi::c_int) -> Self {
42 use baracuda_cuda_sys::runtime::types::cudaMemAccessFlags;
43 match raw {
44 x if x == cudaMemAccessFlags::READ => AccessFlags::Read,
45 x if x == cudaMemAccessFlags::READ_WRITE => AccessFlags::ReadWrite,
46 _ => AccessFlags::None,
47 }
48 }
49}
50
51#[derive(Clone)]
54pub struct MemoryPool {
55 inner: Arc<MemoryPoolInner>,
56}
57
58struct MemoryPoolInner {
59 handle: cudaMemPool_t,
60 owned: bool,
61}
62
63unsafe impl Send for MemoryPoolInner {}
64unsafe impl Sync for MemoryPoolInner {}
65
66impl core::fmt::Debug for MemoryPoolInner {
67 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
68 f.debug_struct("MemoryPool")
69 .field("handle", &self.handle)
70 .field("owned", &self.owned)
71 .finish()
72 }
73}
74
75impl core::fmt::Debug for MemoryPool {
76 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
77 self.inner.fmt(f)
78 }
79}
80
81impl MemoryPool {
82 pub fn new(device: &Device) -> Result<Self> {
84 let r = runtime()?;
85 let cu = r.cuda_mem_pool_create()?;
86 let props = cudaMemPoolProps {
87 alloc_type: cudaMemAllocationType::PINNED,
88 handle_types: cudaMemAllocationHandleType::NONE,
89 location: cudaMemLocation {
90 type_: cudaMemLocationType::DEVICE,
91 id: device.ordinal(),
92 },
93 ..Default::default()
94 };
95 let mut handle: cudaMemPool_t = core::ptr::null_mut();
96 check(unsafe { cu(&mut handle, &props) })?;
97 Ok(Self {
98 inner: Arc::new(MemoryPoolInner {
99 handle,
100 owned: true,
101 }),
102 })
103 }
104
105 pub unsafe fn from_borrowed(handle: cudaMemPool_t) -> Self {
111 Self {
112 inner: Arc::new(MemoryPoolInner {
113 handle,
114 owned: false,
115 }),
116 }
117 }
118
119 #[inline]
120 pub fn as_raw(&self) -> cudaMemPool_t {
121 self.inner.handle
122 }
123
124 pub fn set_release_threshold(&self, bytes: u64) -> Result<()> {
127 let r = runtime()?;
128 let cu = r.cuda_mem_pool_set_attribute()?;
129 let mut v = bytes;
130 check(unsafe {
131 cu(
132 self.inner.handle,
133 cudaMemPoolAttr::RELEASE_THRESHOLD,
134 &mut v as *mut u64 as *mut core::ffi::c_void,
135 )
136 })
137 }
138
139 pub fn release_threshold(&self) -> Result<u64> {
140 self.get_u64_attr(cudaMemPoolAttr::RELEASE_THRESHOLD)
141 }
142
143 pub fn used_bytes(&self) -> Result<u64> {
145 self.get_u64_attr(cudaMemPoolAttr::USED_MEM_CURRENT)
146 }
147
148 pub fn reserved_bytes(&self) -> Result<u64> {
150 self.get_u64_attr(cudaMemPoolAttr::RESERVED_MEM_CURRENT)
151 }
152
153 fn get_u64_attr(&self, attr: i32) -> Result<u64> {
154 let r = runtime()?;
155 let cu = r.cuda_mem_pool_get_attribute()?;
156 let mut v: u64 = 0;
157 check(unsafe {
158 cu(
159 self.inner.handle,
160 attr,
161 &mut v as *mut u64 as *mut core::ffi::c_void,
162 )
163 })?;
164 Ok(v)
165 }
166
167 pub fn trim_to(&self, min_bytes_to_keep: usize) -> Result<()> {
169 let r = runtime()?;
170 let cu = r.cuda_mem_pool_trim_to()?;
171 check(unsafe { cu(self.inner.handle, min_bytes_to_keep) })
172 }
173
174 pub fn set_access(&self, device: &Device, flags: AccessFlags) -> Result<()> {
176 let r = runtime()?;
177 let cu = r.cuda_mem_pool_set_access()?;
178 let desc = cudaMemAccessDesc {
179 location: cudaMemLocation {
180 type_: cudaMemLocationType::DEVICE,
181 id: device.ordinal(),
182 },
183 flags: flags.raw(),
184 };
185 check(unsafe { cu(self.inner.handle, &desc, 1) })
186 }
187
188 pub fn access(&self, device: &Device) -> Result<AccessFlags> {
190 let r = runtime()?;
191 let cu = r.cuda_mem_pool_get_access()?;
192 let mut loc = cudaMemLocation {
193 type_: cudaMemLocationType::DEVICE,
194 id: device.ordinal(),
195 };
196 let mut flags: core::ffi::c_int = 0;
197 check(unsafe { cu(&mut flags, self.inner.handle, &mut loc) })?;
198 Ok(AccessFlags::from_raw(flags))
199 }
200
201 pub fn alloc_async(&self, bytes: usize, stream: &Stream) -> Result<*mut core::ffi::c_void> {
206 let r = runtime()?;
207 let cu = r.cuda_malloc_from_pool_async()?;
208 let mut ptr: *mut core::ffi::c_void = core::ptr::null_mut();
209 check(unsafe { cu(&mut ptr, bytes, self.inner.handle, stream.as_raw()) })?;
210 Ok(ptr)
211 }
212
213 pub unsafe fn free_async(&self, ptr: *mut core::ffi::c_void, stream: &Stream) -> Result<()> { unsafe {
220 let r = runtime()?;
221 let cu = r.cuda_free_async()?;
222 check(cu(ptr, stream.as_raw()))
223 }}
224
225 pub unsafe fn export_pointer(
231 &self,
232 ptr: *mut core::ffi::c_void,
233 ) -> Result<cudaMemPoolPtrExportData> { unsafe {
234 let r = runtime()?;
235 let cu = r.cuda_mem_pool_export_pointer()?;
236 let mut data = cudaMemPoolPtrExportData::default();
237 check(cu(&mut data, ptr))?;
238 Ok(data)
239 }}
240
241 pub fn import_pointer(
243 &self,
244 mut data: cudaMemPoolPtrExportData,
245 ) -> Result<*mut core::ffi::c_void> {
246 let r = runtime()?;
247 let cu = r.cuda_mem_pool_import_pointer()?;
248 let mut ptr: *mut core::ffi::c_void = core::ptr::null_mut();
249 check(unsafe { cu(&mut ptr, self.inner.handle, &mut data) })?;
250 Ok(ptr)
251 }
252}
253
254impl Drop for MemoryPoolInner {
255 fn drop(&mut self) {
256 if !self.owned || self.handle.is_null() {
257 return;
258 }
259 if let Ok(r) = runtime() {
260 if let Ok(cu) = r.cuda_mem_pool_destroy() {
261 let _ = unsafe { cu(self.handle) };
262 }
263 }
264 }
265}
266
267pub fn default_pool(device: &Device) -> Result<MemoryPool> {
269 let r = runtime()?;
270 let cu = r.cuda_device_get_default_mem_pool()?;
271 let mut handle: cudaMemPool_t = core::ptr::null_mut();
272 check(unsafe { cu(&mut handle, device.ordinal()) })?;
273 Ok(unsafe { MemoryPool::from_borrowed(handle) })
275}
276
277pub fn current_pool(device: &Device) -> Result<MemoryPool> {
279 let r = runtime()?;
280 let cu = r.cuda_device_get_mem_pool()?;
281 let mut handle: cudaMemPool_t = core::ptr::null_mut();
282 check(unsafe { cu(&mut handle, device.ordinal()) })?;
283 Ok(unsafe { MemoryPool::from_borrowed(handle) })
284}
285
286pub fn set_current_pool(device: &Device, pool: &MemoryPool) -> Result<()> {
288 let r = runtime()?;
289 let cu = r.cuda_device_set_mem_pool()?;
290 check(unsafe { cu(device.ordinal(), pool.as_raw()) })
291}