1use core::ffi::c_void;
13
14use baracuda_cuda_sys::runtime::runtime;
15use baracuda_cuda_sys::runtime::types::{
16 cudaMemAccessDesc, cudaMemAllocationHandleType, cudaMemAllocationProp, cudaMemAllocationType,
17 cudaMemGenericAllocationHandle_t, cudaMemLocation, cudaMemLocationType,
18};
19
20use crate::device::Device;
21use crate::error::{check, Result};
22use crate::mempool::AccessFlags;
23
24pub fn address_reserve(size: usize, alignment: usize, flags: u64) -> Result<*mut c_void> {
29 let r = runtime()?;
30 let cu = r.cuda_mem_address_reserve()?;
31 let mut ptr: *mut c_void = core::ptr::null_mut();
32 check(unsafe { cu(&mut ptr, size, alignment, core::ptr::null_mut(), flags) })?;
33 Ok(ptr)
34}
35
36pub unsafe fn address_free(ptr: *mut c_void, size: usize) -> Result<()> { unsafe {
42 let r = runtime()?;
43 let cu = r.cuda_mem_address_free()?;
44 check(cu(ptr, size))
45}}
46
47pub fn device_alloc_prop(device: &Device) -> cudaMemAllocationProp {
50 cudaMemAllocationProp {
51 alloc_type: cudaMemAllocationType::PINNED,
52 requested_handle_types: cudaMemAllocationHandleType::NONE,
53 location: cudaMemLocation {
54 type_: cudaMemLocationType::DEVICE,
55 id: device.ordinal(),
56 },
57 win32_handle_meta_data: core::ptr::null_mut(),
58 allocation_flags: [0; 32],
59 }
60}
61
62pub fn allocation_granularity(prop: &cudaMemAllocationProp, option: i32) -> Result<usize> {
65 let r = runtime()?;
66 let cu = r.cuda_mem_get_allocation_granularity()?;
67 let mut g: usize = 0;
68 check(unsafe {
69 cu(
70 &mut g,
71 prop as *const cudaMemAllocationProp as *const c_void,
72 option,
73 )
74 })?;
75 Ok(g)
76}
77
78#[derive(Debug)]
80pub struct MemHandle {
81 handle: cudaMemGenericAllocationHandle_t,
82}
83
84impl MemHandle {
85 pub fn new(size: usize, prop: &cudaMemAllocationProp, flags: u64) -> Result<Self> {
87 let r = runtime()?;
88 let cu = r.cuda_mem_create()?;
89 let mut h: cudaMemGenericAllocationHandle_t = 0;
90 check(unsafe {
91 cu(
92 &mut h,
93 size,
94 prop as *const cudaMemAllocationProp as *const c_void,
95 flags,
96 )
97 })?;
98 Ok(Self { handle: h })
99 }
100
101 pub unsafe fn from_raw(handle: cudaMemGenericAllocationHandle_t) -> Self {
108 Self { handle }
109 }
110
111 #[inline]
112 pub fn as_raw(&self) -> cudaMemGenericAllocationHandle_t {
113 self.handle
114 }
115
116 pub unsafe fn retain(addr: *mut c_void) -> Result<Self> { unsafe {
122 let r = runtime()?;
123 let cu = r.cuda_mem_retain_allocation_handle()?;
124 let mut h: cudaMemGenericAllocationHandle_t = 0;
125 check(cu(&mut h, addr))?;
126 Ok(Self { handle: h })
127 }}
128
129 pub fn properties(&self) -> Result<cudaMemAllocationProp> {
131 let r = runtime()?;
132 let cu = r.cuda_mem_get_allocation_properties_from_handle()?;
133 let mut prop = cudaMemAllocationProp::default();
134 check(unsafe {
135 cu(
136 &mut prop as *mut cudaMemAllocationProp as *mut c_void,
137 self.handle,
138 )
139 })?;
140 Ok(prop)
141 }
142}
143
144impl Drop for MemHandle {
145 fn drop(&mut self) {
146 if let Ok(r) = runtime() {
147 if let Ok(cu) = r.cuda_mem_release() {
148 let _ = unsafe { cu(self.handle) };
149 }
150 }
151 }
152}
153
154pub unsafe fn map(
161 ptr: *mut c_void,
162 size: usize,
163 offset: usize,
164 handle: &MemHandle,
165 flags: u64,
166) -> Result<()> { unsafe {
167 let r = runtime()?;
168 let cu = r.cuda_mem_map()?;
169 check(cu(ptr, size, offset, handle.as_raw(), flags))
170}}
171
172pub unsafe fn unmap(ptr: *mut c_void, size: usize) -> Result<()> { unsafe {
178 let r = runtime()?;
179 let cu = r.cuda_mem_unmap()?;
180 check(cu(ptr, size))
181}}
182
183pub unsafe fn set_access(
189 ptr: *mut c_void,
190 size: usize,
191 device: &Device,
192 flags: AccessFlags,
193) -> Result<()> { unsafe {
194 let r = runtime()?;
195 let cu = r.cuda_mem_set_access()?;
196 let flags_raw = match flags {
198 AccessFlags::None => 0,
199 AccessFlags::Read => 1,
200 AccessFlags::ReadWrite => 3,
201 };
202 let desc = cudaMemAccessDesc {
203 location: cudaMemLocation {
204 type_: cudaMemLocationType::DEVICE,
205 id: device.ordinal(),
206 },
207 flags: flags_raw,
208 };
209 check(cu(
210 ptr,
211 size,
212 &desc as *const cudaMemAccessDesc as *const c_void,
213 1,
214 ))
215}}
216
217pub unsafe fn get_access(ptr: *mut c_void, device: &Device) -> Result<u64> { unsafe {
223 let r = runtime()?;
224 let cu = r.cuda_mem_get_access()?;
225 let loc = cudaMemLocation {
226 type_: cudaMemLocationType::DEVICE,
227 id: device.ordinal(),
228 };
229 let mut flags: u64 = 0;
230 check(cu(
231 &mut flags,
232 &loc as *const cudaMemLocation as *const c_void,
233 ptr,
234 ))?;
235 Ok(flags)
236}}