Skip to main content

baracuda_driver/
external.rs

1//! External memory / semaphore interop — import buffers and sync
2//! primitives from Vulkan, D3D11, D3D12, NvSci, and OpaqueFd sources.
3//!
4//! A typical pipeline is:
5//!
6//! 1. A graphics API (Vulkan, D3D12) exports a buffer or image and a
7//!    timeline fence as OS-level handles (file descriptor on Linux, NT
8//!    HANDLE on Windows).
9//! 2. CUDA imports those handles via [`ExternalMemory::import`] and
10//!    [`ExternalSemaphore::import`].
11//! 3. CUDA obtains a device pointer into the shared buffer with
12//!    [`ExternalMemory::mapped_buffer`].
13//! 4. Each frame, CUDA [`ExternalSemaphore::wait_fence_async`]s on the
14//!    graphics-API fence, does compute, then
15//!    [`ExternalSemaphore::signal_fence_async`]s a fence value the
16//!    graphics API is waiting on.
17//!
//! **Testing note:** the baracuda crate ships these APIs but cannot
//! test them end to end without a live Vulkan/D3D12 context. The unit
//! tests in this module only check FFI struct sizes against the CUDA
//! headers and the encoding produced by the handle-descriptor builders;
//! a matching external-memory/-semaphore example belongs in an
//! `examples/external_interop` crate (not yet present).
23
24use std::sync::Arc;
25
26use baracuda_cuda_sys::types::{
27    CUDA_EXTERNAL_MEMORY_BUFFER_DESC, CUDA_EXTERNAL_MEMORY_HANDLE_DESC,
28    CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC, CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS,
29    CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS,
30};
31use baracuda_cuda_sys::{driver, CUdeviceptr, CUexternalMemory, CUexternalSemaphore};
32
33use crate::context::Context;
34use crate::error::{check, Result};
35use crate::stream::Stream;
36
37/// An imported external-memory handle (Vulkan VkDeviceMemory, D3D12 heap,
38/// NvSciBuf, ...). Destroyed on drop.
39#[derive(Clone)]
40pub struct ExternalMemory {
41    inner: Arc<ExternalMemoryInner>,
42}
43
44struct ExternalMemoryInner {
45    handle: CUexternalMemory,
46    #[allow(dead_code)]
47    context: Context,
48}
49
50unsafe impl Send for ExternalMemoryInner {}
51unsafe impl Sync for ExternalMemoryInner {}
52
53impl core::fmt::Debug for ExternalMemoryInner {
54    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
55        f.debug_struct("ExternalMemory")
56            .field("handle", &self.handle)
57            .finish_non_exhaustive()
58    }
59}
60
61impl core::fmt::Debug for ExternalMemory {
62    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
63        self.inner.fmt(f)
64    }
65}
66
67impl ExternalMemory {
68    /// Import an external-memory handle described by `desc`.
69    ///
70    /// # Safety
71    ///
72    /// `desc` must describe a live OS object that the calling process has
73    /// permission to access — a Vulkan-exported file descriptor, a
74    /// D3D12-exported NT HANDLE, etc. CUDA retains a reference to the
75    /// underlying object until this `ExternalMemory` drops.
76    pub unsafe fn import(
77        context: &Context,
78        desc: &CUDA_EXTERNAL_MEMORY_HANDLE_DESC,
79    ) -> Result<Self> { unsafe {
80        context.set_current()?;
81        let d = driver()?;
82        let cu = d.cu_import_external_memory()?;
83        let mut handle: CUexternalMemory = core::ptr::null_mut();
84        check(cu(&mut handle, desc))?;
85        Ok(Self {
86            inner: Arc::new(ExternalMemoryInner {
87                handle,
88                context: context.clone(),
89            }),
90        })
91    }}
92
93    /// Expose a subregion of the external memory as a device pointer.
94    /// The returned pointer is valid until this `ExternalMemory` drops.
95    pub fn mapped_buffer(&self, offset: u64, size: u64, flags: u32) -> Result<CUdeviceptr> {
96        let d = driver()?;
97        let cu = d.cu_external_memory_get_mapped_buffer()?;
98        let desc = CUDA_EXTERNAL_MEMORY_BUFFER_DESC {
99            offset,
100            size,
101            flags,
102            reserved: [0; 16],
103        };
104        let mut ptr = CUdeviceptr(0);
105        check(unsafe { cu(&mut ptr, self.inner.handle, &desc) })?;
106        Ok(ptr)
107    }
108
109    #[inline]
110    pub fn as_raw(&self) -> CUexternalMemory {
111        self.inner.handle
112    }
113}
114
115impl Drop for ExternalMemoryInner {
116    fn drop(&mut self) {
117        if self.handle.is_null() {
118            return;
119        }
120        if let Ok(d) = driver() {
121            if let Ok(cu) = d.cu_destroy_external_memory() {
122                let _ = unsafe { cu(self.handle) };
123            }
124        }
125    }
126}
127
128/// An imported external-semaphore handle (Vulkan VkSemaphore / timeline,
129/// D3D12 fence, NvSciSync, keyed mutex). Destroyed on drop.
130#[derive(Clone)]
131pub struct ExternalSemaphore {
132    inner: Arc<ExternalSemaphoreInner>,
133}
134
135struct ExternalSemaphoreInner {
136    handle: CUexternalSemaphore,
137    #[allow(dead_code)]
138    context: Context,
139}
140
141unsafe impl Send for ExternalSemaphoreInner {}
142unsafe impl Sync for ExternalSemaphoreInner {}
143
144impl core::fmt::Debug for ExternalSemaphoreInner {
145    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
146        f.debug_struct("ExternalSemaphore")
147            .field("handle", &self.handle)
148            .finish_non_exhaustive()
149    }
150}
151
152impl core::fmt::Debug for ExternalSemaphore {
153    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
154        self.inner.fmt(f)
155    }
156}
157
158impl ExternalSemaphore {
159    /// Import an external-semaphore handle described by `desc`.
160    ///
161    /// # Safety
162    ///
163    /// Same discipline as [`ExternalMemory::import`]: `desc.handle` must
164    /// be a live OS object this process may access.
165    pub unsafe fn import(
166        context: &Context,
167        desc: &CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC,
168    ) -> Result<Self> { unsafe {
169        context.set_current()?;
170        let d = driver()?;
171        let cu = d.cu_import_external_semaphore()?;
172        let mut handle: CUexternalSemaphore = core::ptr::null_mut();
173        check(cu(&mut handle, desc))?;
174        Ok(Self {
175            inner: Arc::new(ExternalSemaphoreInner {
176                handle,
177                context: context.clone(),
178            }),
179        })
180    }}
181
182    /// Enqueue a signal of fence value `value` on `stream` for timeline /
183    /// D3D12 fence semaphores.
184    pub fn signal_fence_async(&self, value: u64, stream: &Stream) -> Result<()> {
185        let d = driver()?;
186        let cu = d.cu_signal_external_semaphores_async()?;
187        let params = CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS::fence_value(value);
188        check(unsafe { cu(&self.inner.handle, &params, 1, stream.as_raw()) })
189    }
190
191    /// Enqueue a wait for fence value `value` on `stream`. The stream
192    /// blocks (on-device) until the external fence reaches that value.
193    pub fn wait_fence_async(&self, value: u64, stream: &Stream) -> Result<()> {
194        let d = driver()?;
195        let cu = d.cu_wait_external_semaphores_async()?;
196        let params = CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS::fence_value(value);
197        check(unsafe { cu(&self.inner.handle, &params, 1, stream.as_raw()) })
198    }
199
200    #[inline]
201    pub fn as_raw(&self) -> CUexternalSemaphore {
202        self.inner.handle
203    }
204}
205
206impl Drop for ExternalSemaphoreInner {
207    fn drop(&mut self) {
208        if self.handle.is_null() {
209            return;
210        }
211        if let Ok(d) = driver() {
212            if let Ok(cu) = d.cu_destroy_external_semaphore() {
213                let _ = unsafe { cu(self.handle) };
214            }
215        }
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use baracuda_cuda_sys::types::CUexternalMemoryHandleType;

    #[test]
    fn struct_sizes_match_cuda_abi() {
        use core::mem::size_of;

        // Guards against layout drift. Every expected size below comes from
        // the CUDA 13.0 headers, so a mismatch surfaces here before any of
        // these structs crosses the FFI boundary.
        assert_eq!(size_of::<CUDA_EXTERNAL_MEMORY_HANDLE_DESC>(), 104);
        assert_eq!(size_of::<CUDA_EXTERNAL_MEMORY_BUFFER_DESC>(), 88);
        assert_eq!(size_of::<CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC>(), 96);

        // Signal and wait params are both
        // params (72) + flags (4) + reserved[16] (64) = 140 bytes,
        // rounded up to 144 by 8-byte alignment.
        let semaphore_params_size = 144;
        assert_eq!(size_of::<CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS>(), semaphore_params_size);
        assert_eq!(size_of::<CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS>(), semaphore_params_size);
    }

    #[test]
    fn handle_desc_builders_encode_fd_and_win32() {
        // POSIX fd: the descriptor lands in the low 4 bytes of handle[0].
        let fd_desc = CUDA_EXTERNAL_MEMORY_HANDLE_DESC::from_fd(42, 1024);
        assert_eq!(fd_desc.type_, CUexternalMemoryHandleType::OPAQUE_FD);
        assert_eq!(fd_desc.size, 1024);
        assert_eq!(fd_desc.handle[0] as i32, 42);

        // Win32: the NT HANDLE fills handle[0] and the name pointer fills
        // handle[1]. Distinct bit patterns make a swapped slot obvious.
        let nt_handle = 0xDEAD_BEEF_1234_5678u64 as *mut core::ffi::c_void;
        let name = 0xAAAA_BBBB_CCCC_DDDDu64 as *const core::ffi::c_void;
        let win_desc = unsafe {
            CUDA_EXTERNAL_MEMORY_HANDLE_DESC::from_win32_handle(
                CUexternalMemoryHandleType::OPAQUE_WIN32,
                nt_handle,
                name,
                2048,
            )
        };
        assert_eq!(win_desc.handle[0], 0xDEAD_BEEF_1234_5678);
        assert_eq!(win_desc.handle[1], 0xAAAA_BBBB_CCCC_DDDD);
        assert_eq!(win_desc.size, 2048);
    }
}