baracuda_driver/
external.rs1use std::sync::Arc;
25
26use baracuda_cuda_sys::types::{
27 CUDA_EXTERNAL_MEMORY_BUFFER_DESC, CUDA_EXTERNAL_MEMORY_HANDLE_DESC,
28 CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC, CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS,
29 CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS,
30};
31use baracuda_cuda_sys::{driver, CUdeviceptr, CUexternalMemory, CUexternalSemaphore};
32
33use crate::context::Context;
34use crate::error::{check, Result};
35use crate::stream::Stream;
36
/// Handle to an external memory object that has been imported into CUDA.
///
/// Cloning is cheap: every clone shares the same reference-counted inner
/// value, and the driver handle is released (best effort) when the last clone
/// is dropped.
#[derive(Clone)]
pub struct ExternalMemory {
    // Shared owner of the raw driver handle and the context it was imported under.
    inner: Arc<ExternalMemoryInner>,
}
43
/// Shared state behind [`ExternalMemory`]: the raw driver handle plus the
/// context that was supplied at import time.
struct ExternalMemoryInner {
    /// Raw handle written by the import call; destroyed in `Drop`.
    handle: CUexternalMemory,
    /// Clone of the context passed to `import`. It is never read (hence the
    /// `dead_code` allowance); presumably it is held so the context outlives
    /// the handle — TODO confirm against `Context`'s drop semantics.
    #[allow(dead_code)]
    context: Context,
}

// SAFETY (review note): `handle` is a raw pointer value, so `Send`/`Sync` are
// not derived automatically and are asserted by hand here. This assumes the
// CUDA driver allows an external-memory handle to be used and destroyed from
// any thread — TODO confirm against the driver documentation.
unsafe impl Send for ExternalMemoryInner {}
unsafe impl Sync for ExternalMemoryInner {}
52
53impl core::fmt::Debug for ExternalMemoryInner {
54 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
55 f.debug_struct("ExternalMemory")
56 .field("handle", &self.handle)
57 .finish_non_exhaustive()
58 }
59}
60
61impl core::fmt::Debug for ExternalMemory {
62 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
63 self.inner.fmt(f)
64 }
65}
66
67impl ExternalMemory {
68 pub unsafe fn import(
77 context: &Context,
78 desc: &CUDA_EXTERNAL_MEMORY_HANDLE_DESC,
79 ) -> Result<Self> { unsafe {
80 context.set_current()?;
81 let d = driver()?;
82 let cu = d.cu_import_external_memory()?;
83 let mut handle: CUexternalMemory = core::ptr::null_mut();
84 check(cu(&mut handle, desc))?;
85 Ok(Self {
86 inner: Arc::new(ExternalMemoryInner {
87 handle,
88 context: context.clone(),
89 }),
90 })
91 }}
92
93 pub fn mapped_buffer(&self, offset: u64, size: u64, flags: u32) -> Result<CUdeviceptr> {
96 let d = driver()?;
97 let cu = d.cu_external_memory_get_mapped_buffer()?;
98 let desc = CUDA_EXTERNAL_MEMORY_BUFFER_DESC {
99 offset,
100 size,
101 flags,
102 reserved: [0; 16],
103 };
104 let mut ptr = CUdeviceptr(0);
105 check(unsafe { cu(&mut ptr, self.inner.handle, &desc) })?;
106 Ok(ptr)
107 }
108
109 #[inline]
110 pub fn as_raw(&self) -> CUexternalMemory {
111 self.inner.handle
112 }
113}
114
115impl Drop for ExternalMemoryInner {
116 fn drop(&mut self) {
117 if self.handle.is_null() {
118 return;
119 }
120 if let Ok(d) = driver() {
121 if let Ok(cu) = d.cu_destroy_external_memory() {
122 let _ = unsafe { cu(self.handle) };
123 }
124 }
125 }
126}
127
/// Handle to an external semaphore that has been imported into CUDA.
///
/// Cloning is cheap: every clone shares the same reference-counted inner
/// value, and the driver handle is released (best effort) when the last clone
/// is dropped.
#[derive(Clone)]
pub struct ExternalSemaphore {
    // Shared owner of the raw driver handle and the context it was imported under.
    inner: Arc<ExternalSemaphoreInner>,
}
134
/// Shared state behind [`ExternalSemaphore`]: the raw driver handle plus the
/// context that was supplied at import time.
struct ExternalSemaphoreInner {
    /// Raw handle written by the import call; destroyed in `Drop`.
    handle: CUexternalSemaphore,
    /// Clone of the context passed to `import`. It is never read (hence the
    /// `dead_code` allowance); presumably it is held so the context outlives
    /// the handle — TODO confirm against `Context`'s drop semantics.
    #[allow(dead_code)]
    context: Context,
}

// SAFETY (review note): `handle` is a raw pointer value, so `Send`/`Sync` are
// not derived automatically and are asserted by hand here. This assumes the
// CUDA driver allows an external-semaphore handle to be used and destroyed
// from any thread — TODO confirm against the driver documentation.
unsafe impl Send for ExternalSemaphoreInner {}
unsafe impl Sync for ExternalSemaphoreInner {}
143
144impl core::fmt::Debug for ExternalSemaphoreInner {
145 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
146 f.debug_struct("ExternalSemaphore")
147 .field("handle", &self.handle)
148 .finish_non_exhaustive()
149 }
150}
151
152impl core::fmt::Debug for ExternalSemaphore {
153 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
154 self.inner.fmt(f)
155 }
156}
157
158impl ExternalSemaphore {
159 pub unsafe fn import(
166 context: &Context,
167 desc: &CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC,
168 ) -> Result<Self> { unsafe {
169 context.set_current()?;
170 let d = driver()?;
171 let cu = d.cu_import_external_semaphore()?;
172 let mut handle: CUexternalSemaphore = core::ptr::null_mut();
173 check(cu(&mut handle, desc))?;
174 Ok(Self {
175 inner: Arc::new(ExternalSemaphoreInner {
176 handle,
177 context: context.clone(),
178 }),
179 })
180 }}
181
182 pub fn signal_fence_async(&self, value: u64, stream: &Stream) -> Result<()> {
185 let d = driver()?;
186 let cu = d.cu_signal_external_semaphores_async()?;
187 let params = CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS::fence_value(value);
188 check(unsafe { cu(&self.inner.handle, ¶ms, 1, stream.as_raw()) })
189 }
190
191 pub fn wait_fence_async(&self, value: u64, stream: &Stream) -> Result<()> {
194 let d = driver()?;
195 let cu = d.cu_wait_external_semaphores_async()?;
196 let params = CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS::fence_value(value);
197 check(unsafe { cu(&self.inner.handle, ¶ms, 1, stream.as_raw()) })
198 }
199
200 #[inline]
201 pub fn as_raw(&self) -> CUexternalSemaphore {
202 self.inner.handle
203 }
204}
205
206impl Drop for ExternalSemaphoreInner {
207 fn drop(&mut self) {
208 if self.handle.is_null() {
209 return;
210 }
211 if let Ok(d) = driver() {
212 if let Ok(cu) = d.cu_destroy_external_semaphore() {
213 let _ = unsafe { cu(self.handle) };
214 }
215 }
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use baracuda_cuda_sys::types::CUexternalMemoryHandleType;

    /// Pins the byte size of every FFI descriptor used by this module, so an
    /// accidental layout change in the bindings is caught at test time.
    #[test]
    fn struct_sizes_match_cuda_abi() {
        use core::mem::size_of;

        // External-memory descriptors.
        assert_eq!(size_of::<CUDA_EXTERNAL_MEMORY_HANDLE_DESC>(), 104);
        assert_eq!(size_of::<CUDA_EXTERNAL_MEMORY_BUFFER_DESC>(), 88);

        // External-semaphore descriptors and parameter blocks.
        assert_eq!(size_of::<CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC>(), 96);
        assert_eq!(size_of::<CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS>(), 144);
        assert_eq!(size_of::<CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS>(), 144);
    }

    /// Checks that the descriptor builders store their arguments in the
    /// expected slots of the `handle` array.
    #[test]
    fn handle_desc_builders_encode_fd_and_win32() {
        // File-descriptor flavour: the fd ends up in the first handle slot.
        let fd_desc = CUDA_EXTERNAL_MEMORY_HANDLE_DESC::from_fd(42, 1024);
        assert_eq!(fd_desc.type_, CUexternalMemoryHandleType::OPAQUE_FD);
        assert_eq!(fd_desc.size, 1024);
        assert_eq!(fd_desc.handle[0] as i32, 42);

        // Win32 flavour: two pointer-sized values, one per handle slot. The
        // pointers are never dereferenced, only their bit patterns matter.
        // (Presumably the second pointer is the object name — confirm.)
        let win32_handle: *mut core::ffi::c_void = 0xDEAD_BEEF_1234_5678u64 as *mut _;
        let object_name: *const core::ffi::c_void = 0xAAAA_BBBB_CCCC_DDDDu64 as *const _;
        let win32_desc = unsafe {
            CUDA_EXTERNAL_MEMORY_HANDLE_DESC::from_win32_handle(
                CUexternalMemoryHandleType::OPAQUE_WIN32,
                win32_handle,
                object_name,
                2048,
            )
        };
        assert_eq!(win32_desc.handle[0], 0xDEAD_BEEF_1234_5678);
        assert_eq!(win32_desc.handle[1], 0xAAAA_BBBB_CCCC_DDDD);
        assert_eq!(win32_desc.size, 2048);
    }
}