baracuda_driver/
mempool.rs1use core::ffi::c_void;
14use std::sync::Arc;
15
16use baracuda_cuda_sys::types::{
17 CUmemAccessDesc, CUmemAllocationHandleType, CUmemAllocationType, CUmemLocation,
18 CUmemLocationType, CUmemPoolProps, CUmemPoolPtrExportData, CUmemPool_attribute,
19};
20use baracuda_cuda_sys::{driver, CUdeviceptr, CUmemoryPool};
21
22use crate::context::Context;
23use crate::device::Device;
24use crate::error::{check, Result};
25use crate::stream::Stream;
26use crate::vmm::AccessFlags;
27
/// A CUDA stream-ordered memory pool (`CUmemoryPool`).
///
/// Cloning is cheap: every clone shares one reference-counted handle. For
/// pools created with [`MemoryPool::new`], the underlying pool is destroyed
/// when the last clone is dropped; pools wrapped with
/// [`MemoryPool::from_borrowed`] (e.g. those returned by [`default_pool`] and
/// [`current_pool`]) are never destroyed by this type.
#[derive(Clone)]
pub struct MemoryPool {
    inner: Arc<MemoryPoolInner>,
}
35
/// Shared state behind a [`MemoryPool`]. It is dropped when the last `Arc`
/// reference to it goes away, at which point an owned pool is destroyed.
struct MemoryPoolInner {
    /// Raw driver handle for the pool.
    handle: CUmemoryPool,
    /// `true` when this wrapper created the pool and must destroy it on drop;
    /// `false` for borrowed handles, which are never destroyed here.
    owned: bool,
    // Never read. Presumably stored so the context stays alive at least as
    // long as the pool handle (including the destroy call in `Drop`).
    // NOTE(review): confirm that this is the intent.
    #[allow(dead_code)]
    context: Context,
}
42
// SAFETY: `MemoryPoolInner` holds a raw-pointer `CUmemoryPool` handle, which
// makes the type `!Send`/`!Sync` by default. This module never dereferences
// the handle; it only passes it back to driver entry points.
// NOTE(review): soundness also relies on (a) the CUDA driver accepting pool
// handles from any thread, and (b) `Context` being safe to move and share
// across threads. These impls override the auto-trait check for `context` as
// well, so confirm both.
unsafe impl Send for MemoryPoolInner {}
unsafe impl Sync for MemoryPoolInner {}
45
46impl core::fmt::Debug for MemoryPoolInner {
47 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
48 f.debug_struct("MemoryPool")
49 .field("handle", &self.handle)
50 .field("owned", &self.owned)
51 .finish_non_exhaustive()
52 }
53}
54
55impl core::fmt::Debug for MemoryPool {
56 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
57 self.inner.fmt(f)
58 }
59}
60
61impl MemoryPool {
62 pub fn new(context: &Context, device: &Device) -> Result<Self> {
65 context.set_current()?;
66 let d = driver()?;
67 let cu = d.cu_mem_pool_create()?;
68 let props = CUmemPoolProps {
69 alloc_type: CUmemAllocationType::PINNED,
70 handle_types: CUmemAllocationHandleType::NONE,
71 location: CUmemLocation {
72 type_: CUmemLocationType::DEVICE,
73 id: device.as_raw().0,
74 },
75 ..Default::default()
76 };
77 let mut handle: CUmemoryPool = core::ptr::null_mut();
78 check(unsafe { cu(&mut handle, &props) })?;
79 Ok(Self {
80 inner: Arc::new(MemoryPoolInner {
81 handle,
82 owned: true,
83 context: context.clone(),
84 }),
85 })
86 }
87
88 pub unsafe fn from_borrowed(context: &Context, handle: CUmemoryPool) -> Self {
95 Self {
96 inner: Arc::new(MemoryPoolInner {
97 handle,
98 owned: false,
99 context: context.clone(),
100 }),
101 }
102 }
103
104 #[inline]
106 pub fn as_raw(&self) -> CUmemoryPool {
107 self.inner.handle
108 }
109
110 pub fn set_release_threshold(&self, bytes: u64) -> Result<()> {
113 let d = driver()?;
114 let cu = d.cu_mem_pool_set_attribute()?;
115 let mut v = bytes;
116 check(unsafe {
117 cu(
118 self.inner.handle,
119 CUmemPool_attribute::RELEASE_THRESHOLD,
120 &mut v as *mut u64 as *mut c_void,
121 )
122 })
123 }
124
125 pub fn release_threshold(&self) -> Result<u64> {
126 let d = driver()?;
127 let cu = d.cu_mem_pool_get_attribute()?;
128 let mut v: u64 = 0;
129 check(unsafe {
130 cu(
131 self.inner.handle,
132 CUmemPool_attribute::RELEASE_THRESHOLD,
133 &mut v as *mut u64 as *mut c_void,
134 )
135 })?;
136 Ok(v)
137 }
138
139 pub fn used_bytes(&self) -> Result<u64> {
141 let d = driver()?;
142 let cu = d.cu_mem_pool_get_attribute()?;
143 let mut v: u64 = 0;
144 check(unsafe {
145 cu(
146 self.inner.handle,
147 CUmemPool_attribute::USED_MEM_CURRENT,
148 &mut v as *mut u64 as *mut c_void,
149 )
150 })?;
151 Ok(v)
152 }
153
154 pub fn reserved_bytes(&self) -> Result<u64> {
156 let d = driver()?;
157 let cu = d.cu_mem_pool_get_attribute()?;
158 let mut v: u64 = 0;
159 check(unsafe {
160 cu(
161 self.inner.handle,
162 CUmemPool_attribute::RESERVED_MEM_CURRENT,
163 &mut v as *mut u64 as *mut c_void,
164 )
165 })?;
166 Ok(v)
167 }
168
169 pub fn trim_to(&self, min_bytes_to_keep: usize) -> Result<()> {
171 let d = driver()?;
172 let cu = d.cu_mem_pool_trim_to()?;
173 check(unsafe { cu(self.inner.handle, min_bytes_to_keep) })
174 }
175
176 pub fn set_access(&self, device: &Device, flags: AccessFlags) -> Result<()> {
179 let d = driver()?;
180 let cu = d.cu_mem_pool_set_access()?;
181 let desc = CUmemAccessDesc {
182 location: CUmemLocation {
183 type_: CUmemLocationType::DEVICE,
184 id: device.as_raw().0,
185 },
186 flags: flags.raw(),
187 };
188 check(unsafe { cu(self.inner.handle, &desc, 1) })
189 }
190
191 pub fn access(&self, device: &Device) -> Result<AccessFlags> {
193 let d = driver()?;
194 let cu = d.cu_mem_pool_get_access()?;
195 let mut loc = CUmemLocation {
196 type_: CUmemLocationType::DEVICE,
197 id: device.as_raw().0,
198 };
199 let mut flags: core::ffi::c_int = 0;
200 check(unsafe { cu(&mut flags, self.inner.handle, &mut loc) })?;
201 Ok(AccessFlags::from_raw(flags))
202 }
203
204 pub fn alloc_async(&self, bytes: usize, stream: &Stream) -> Result<CUdeviceptr> {
208 let d = driver()?;
209 let cu = d.cu_mem_alloc_from_pool_async()?;
210 let mut ptr = CUdeviceptr(0);
211 check(unsafe { cu(&mut ptr, bytes, self.inner.handle, stream.as_raw()) })?;
212 Ok(ptr)
213 }
214
215 pub fn export_pointer(&self, ptr: CUdeviceptr) -> Result<CUmemPoolPtrExportData> {
220 let d = driver()?;
221 let cu = d.cu_mem_pool_export_pointer()?;
222 let mut data = CUmemPoolPtrExportData::default();
223 check(unsafe { cu(&mut data, ptr) })?;
224 Ok(data)
225 }
226
227 pub fn import_pointer(&self, mut data: CUmemPoolPtrExportData) -> Result<CUdeviceptr> {
230 let d = driver()?;
231 let cu = d.cu_mem_pool_import_pointer()?;
232 let mut ptr = CUdeviceptr(0);
233 check(unsafe { cu(&mut ptr, self.inner.handle, &mut data) })?;
234 Ok(ptr)
235 }
236}
237
238impl AccessFlags {
239 #[inline]
240 fn from_raw(raw: core::ffi::c_int) -> Self {
241 use baracuda_cuda_sys::types::CUmemAccess_flags;
242 match raw {
243 x if x == CUmemAccess_flags::READ => AccessFlags::Read,
244 x if x == CUmemAccess_flags::READWRITE => AccessFlags::ReadWrite,
245 _ => AccessFlags::None,
246 }
247 }
248}
249
250impl Drop for MemoryPoolInner {
251 fn drop(&mut self) {
252 if !self.owned || self.handle.is_null() {
253 return;
254 }
255 if let Ok(d) = driver() {
256 if let Ok(cu) = d.cu_mem_pool_destroy() {
257 let _ = unsafe { cu(self.handle) };
258 }
259 }
260 }
261}
262
263pub fn default_pool(context: &Context, device: &Device) -> Result<MemoryPool> {
266 context.set_current()?;
267 let d = driver()?;
268 let cu = d.cu_device_get_default_mem_pool()?;
269 let mut handle: CUmemoryPool = core::ptr::null_mut();
270 check(unsafe { cu(&mut handle, device.as_raw()) })?;
271 Ok(unsafe { MemoryPool::from_borrowed(context, handle) })
274}
275
276pub fn current_pool(context: &Context, device: &Device) -> Result<MemoryPool> {
279 context.set_current()?;
280 let d = driver()?;
281 let cu = d.cu_device_get_mem_pool()?;
282 let mut handle: CUmemoryPool = core::ptr::null_mut();
283 check(unsafe { cu(&mut handle, device.as_raw()) })?;
284 Ok(unsafe { MemoryPool::from_borrowed(context, handle) })
285}
286
287pub fn set_current_pool(device: &Device, pool: &MemoryPool) -> Result<()> {
290 let d = driver()?;
291 let cu = d.cu_device_set_mem_pool()?;
292 check(unsafe { cu(device.as_raw(), pool.as_raw()) })
293}