use std::sync::Arc;
use baracuda_cuda_sys::types::{
CUDA_EXTERNAL_MEMORY_BUFFER_DESC, CUDA_EXTERNAL_MEMORY_HANDLE_DESC,
CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC, CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS,
CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS,
};
use baracuda_cuda_sys::{driver, CUdeviceptr, CUexternalMemory, CUexternalSemaphore};
use crate::context::Context;
use crate::error::{check, Result};
use crate::stream::Stream;
/// Cheaply clonable wrapper around an imported external memory object.
///
/// The derived `Clone` only clones the inner `Arc`, so every clone refers to
/// the same underlying `CUexternalMemory` handle.
#[derive(Clone)]
pub struct ExternalMemory {
/// Shared owner of the raw handle and the context captured at import time.
inner: Arc<ExternalMemoryInner>,
}
/// Shared state behind [`ExternalMemory`]: the raw driver handle together
/// with a clone of the context that was current when it was imported.
struct ExternalMemoryInner {
/// Raw handle filled in by the driver's import call.
handle: CUexternalMemory,
/// Never read in this file (hence the lint allow). Presumably held so the
/// context outlives the handle - TODO confirm against `Context` semantics.
#[allow(dead_code)]
context: Context,
}
// SAFETY / NOTE(review): `CUexternalMemory` is a raw pointer type (it is
// initialised from `core::ptr::null_mut()` in `import`), so `Send`/`Sync` are
// not derived automatically and are asserted by hand here. This assumes the
// driver allows the handle to be used and destroyed from any thread -
// confirm against the CUDA driver documentation.
unsafe impl Send for ExternalMemoryInner {}
unsafe impl Sync for ExternalMemoryInner {}
/// Debug output is labelled `ExternalMemory` (the public wrapper's name) and
/// shows only the raw handle; the remaining fields are elided with `..`.
impl core::fmt::Debug for ExternalMemoryInner {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("ExternalMemory");
        out.field("handle", &self.handle);
        out.finish_non_exhaustive()
    }
}
/// Formats the wrapper exactly like its shared inner state.
impl core::fmt::Debug for ExternalMemory {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Fully-qualified call: method syntax (`self.inner.fmt(f)`) only
        // resolves when a trait providing `fmt` is in scope, and becomes
        // ambiguous if more than one formatting trait is imported.
        core::fmt::Debug::fmt(&*self.inner, f)
    }
}
impl ExternalMemory {
    /// Imports the external memory object described by `desc`.
    ///
    /// Makes `context` current, resolves the driver's import entry point and
    /// calls it with `desc`. On success the resulting handle is stored
    /// together with a clone of `context`.
    ///
    /// # Errors
    ///
    /// Propagates any failure from making the context current, from loading
    /// the driver or the entry point, or from the import call itself.
    ///
    /// # Safety
    ///
    /// `desc` is handed to the driver as-is. The caller must ensure it
    /// describes a valid, importable handle. NOTE(review): the exact
    /// requirements are those of the CUDA driver API - confirm there.
    pub unsafe fn import(
        context: &Context,
        desc: &CUDA_EXTERNAL_MEMORY_HANDLE_DESC,
    ) -> Result<Self> {
        // SAFETY: the caller upholds this function's contract for `desc`;
        // `raw` is a live local for the whole duration of the driver call.
        unsafe {
            context.set_current()?;
            let drv = driver()?;
            let import_fn = drv.cu_import_external_memory()?;
            let mut raw: CUexternalMemory = core::ptr::null_mut();
            check(import_fn(&mut raw, desc))?;
            let shared = ExternalMemoryInner {
                handle: raw,
                context: context.clone(),
            };
            Ok(Self {
                inner: Arc::new(shared),
            })
        }
    }

    /// Asks the driver for a device pointer mapped onto this external memory.
    ///
    /// `offset`, `size` and `flags` are copied verbatim into the buffer
    /// descriptor passed to the driver; its reserved words are zeroed.
    ///
    /// # Errors
    ///
    /// Propagates failures from loading the driver or the entry point, and
    /// any error status returned by the mapping call.
    pub fn mapped_buffer(&self, offset: u64, size: u64, flags: u32) -> Result<CUdeviceptr> {
        let drv = driver()?;
        let map_fn = drv.cu_external_memory_get_mapped_buffer()?;
        let buffer_desc = CUDA_EXTERNAL_MEMORY_BUFFER_DESC {
            offset,
            size,
            flags,
            reserved: [0; 16],
        };
        let mut device_ptr = CUdeviceptr(0);
        // SAFETY: the out-pointer and descriptor are live locals for the
        // duration of the call, and the handle stays owned by `self.inner`.
        let status = unsafe { map_fn(&mut device_ptr, self.inner.handle, &buffer_desc) };
        check(status)?;
        Ok(device_ptr)
    }

    /// Returns the raw driver handle without transferring ownership.
    #[inline]
    pub fn as_raw(&self) -> CUexternalMemory {
        self.inner.handle
    }
}
/// Best-effort destruction of the external memory handle.
///
/// A null handle is skipped, and so is the destroy call when the driver or
/// its destroy entry point cannot be obtained. The status returned by the
/// destroy call is discarded because `drop` has no way to report it.
impl Drop for ExternalMemoryInner {
    fn drop(&mut self) {
        let raw = self.handle;
        if raw.is_null() {
            return;
        }
        let Ok(drv) = driver() else {
            return;
        };
        let Ok(destroy_fn) = drv.cu_destroy_external_memory() else {
            return;
        };
        // SAFETY: `raw` is the non-null handle owned by this value, and this
        // is the only place in the file that destroys it.
        let _ = unsafe { destroy_fn(raw) };
    }
}
/// Cheaply clonable wrapper around an imported external semaphore.
///
/// The derived `Clone` only clones the inner `Arc`, so every clone refers to
/// the same underlying `CUexternalSemaphore` handle.
#[derive(Clone)]
pub struct ExternalSemaphore {
/// Shared owner of the raw handle and the context captured at import time.
inner: Arc<ExternalSemaphoreInner>,
}
/// Shared state behind [`ExternalSemaphore`]: the raw driver handle together
/// with a clone of the context that was current when it was imported.
struct ExternalSemaphoreInner {
/// Raw handle filled in by the driver's import call.
handle: CUexternalSemaphore,
/// Never read in this file (hence the lint allow). Presumably held so the
/// context outlives the handle - TODO confirm against `Context` semantics.
#[allow(dead_code)]
context: Context,
}
// SAFETY / NOTE(review): `CUexternalSemaphore` is a raw pointer type (it is
// initialised from `core::ptr::null_mut()` in `import`), so `Send`/`Sync` are
// not derived automatically and are asserted by hand here. This assumes the
// driver allows the handle to be used and destroyed from any thread -
// confirm against the CUDA driver documentation.
unsafe impl Send for ExternalSemaphoreInner {}
unsafe impl Sync for ExternalSemaphoreInner {}
/// Debug output is labelled `ExternalSemaphore` (the public wrapper's name)
/// and shows only the raw handle; the remaining fields are elided with `..`.
impl core::fmt::Debug for ExternalSemaphoreInner {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("ExternalSemaphore");
        out.field("handle", &self.handle);
        out.finish_non_exhaustive()
    }
}
/// Formats the wrapper exactly like its shared inner state.
impl core::fmt::Debug for ExternalSemaphore {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Fully-qualified call: method syntax (`self.inner.fmt(f)`) only
        // resolves when a trait providing `fmt` is in scope, and becomes
        // ambiguous if more than one formatting trait is imported.
        core::fmt::Debug::fmt(&*self.inner, f)
    }
}
impl ExternalSemaphore {
    /// Imports the external semaphore described by `desc`.
    ///
    /// Makes `context` current, resolves the driver's import entry point and
    /// calls it with `desc`. On success the resulting handle is stored
    /// together with a clone of `context`.
    ///
    /// # Errors
    ///
    /// Propagates any failure from making the context current, from loading
    /// the driver or the entry point, or from the import call itself.
    ///
    /// # Safety
    ///
    /// `desc` is handed to the driver as-is. The caller must ensure it
    /// describes a valid, importable handle. NOTE(review): the exact
    /// requirements are those of the CUDA driver API - confirm there.
    pub unsafe fn import(
        context: &Context,
        desc: &CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC,
    ) -> Result<Self> {
        // SAFETY: the caller upholds this function's contract for `desc`;
        // `handle` is a live local for the whole duration of the driver call.
        unsafe {
            context.set_current()?;
            let d = driver()?;
            let cu = d.cu_import_external_semaphore()?;
            let mut handle: CUexternalSemaphore = core::ptr::null_mut();
            check(cu(&mut handle, desc))?;
            Ok(Self {
                inner: Arc::new(ExternalSemaphoreInner {
                    handle,
                    context: context.clone(),
                }),
            })
        }
    }

    /// Enqueues a signal of this semaphore on `stream`.
    ///
    /// The signal parameters are built with
    /// `CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS::fence_value(value)`
    /// (presumably carrying `value` as the fence payload - verify against
    /// that constructor).
    ///
    /// # Errors
    ///
    /// Propagates failures from loading the driver or the entry point, and
    /// any error status returned by the signal call.
    pub fn signal_fence_async(&self, value: u64, stream: &Stream) -> Result<()> {
        let d = driver()?;
        let cu = d.cu_signal_external_semaphores_async()?;
        let params = CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS::fence_value(value);
        // SAFETY: both pointers refer to single live values (the handle owned
        // by `self.inner` and the local `params`), matching the count of 1.
        check(unsafe { cu(&self.inner.handle, &params, 1, stream.as_raw()) })
    }

    /// Enqueues a wait on this semaphore on `stream`.
    ///
    /// The wait parameters are built with
    /// `CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS::fence_value(value)`
    /// (presumably carrying `value` as the fence payload - verify against
    /// that constructor).
    ///
    /// # Errors
    ///
    /// Propagates failures from loading the driver or the entry point, and
    /// any error status returned by the wait call.
    pub fn wait_fence_async(&self, value: u64, stream: &Stream) -> Result<()> {
        let d = driver()?;
        let cu = d.cu_wait_external_semaphores_async()?;
        let params = CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS::fence_value(value);
        // SAFETY: both pointers refer to single live values (the handle owned
        // by `self.inner` and the local `params`), matching the count of 1.
        check(unsafe { cu(&self.inner.handle, &params, 1, stream.as_raw()) })
    }

    /// Returns the raw driver handle without transferring ownership.
    #[inline]
    pub fn as_raw(&self) -> CUexternalSemaphore {
        self.inner.handle
    }
}
/// Best-effort destruction of the external semaphore handle.
///
/// A null handle is skipped, and so is the destroy call when the driver or
/// its destroy entry point cannot be obtained. The status returned by the
/// destroy call is discarded because `drop` has no way to report it.
impl Drop for ExternalSemaphoreInner {
    fn drop(&mut self) {
        let raw = self.handle;
        if raw.is_null() {
            return;
        }
        let Ok(drv) = driver() else {
            return;
        };
        let Ok(destroy_fn) = drv.cu_destroy_external_semaphore() else {
            return;
        };
        // SAFETY: `raw` is the non-null handle owned by this value, and this
        // is the only place in the file that destroys it.
        let _ = unsafe { destroy_fn(raw) };
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use baracuda_cuda_sys::types::CUexternalMemoryHandleType;

    /// Pins the byte size of every descriptor / parameter struct that this
    /// module passes to the driver, so layout drift is caught at test time.
    #[test]
    fn struct_sizes_match_cuda_abi() {
        assert_eq!(core::mem::size_of::<CUDA_EXTERNAL_MEMORY_HANDLE_DESC>(), 104);
        assert_eq!(core::mem::size_of::<CUDA_EXTERNAL_MEMORY_BUFFER_DESC>(), 88);
        assert_eq!(core::mem::size_of::<CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC>(), 96);
        assert_eq!(core::mem::size_of::<CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS>(), 144);
        assert_eq!(core::mem::size_of::<CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS>(), 144);
    }

    /// Checks that the fd and Win32 builders place their arguments in the
    /// expected descriptor fields.
    #[test]
    fn handle_desc_builders_encode_fd_and_win32() {
        // File-descriptor flavour: the fd lands in the first handle word.
        let fd_desc = CUDA_EXTERNAL_MEMORY_HANDLE_DESC::from_fd(42, 1024);
        assert_eq!(fd_desc.type_, CUexternalMemoryHandleType::OPAQUE_FD);
        assert_eq!(fd_desc.size, 1024);
        assert_eq!(fd_desc.handle[0] as i32, 42);

        // Win32 flavour: handle and name occupy the first two handle words.
        // The pointers are arbitrary bit patterns and are never dereferenced
        // by this test.
        let raw_handle = 0xDEAD_BEEF_1234_5678u64 as *mut core::ffi::c_void;
        let raw_name = 0xAAAA_BBBB_CCCC_DDDDu64 as *const core::ffi::c_void;
        let win32_desc = unsafe {
            CUDA_EXTERNAL_MEMORY_HANDLE_DESC::from_win32_handle(
                CUexternalMemoryHandleType::OPAQUE_WIN32,
                raw_handle,
                raw_name,
                2048,
            )
        };
        assert_eq!(win32_desc.handle[0], 0xDEAD_BEEF_1234_5678);
        assert_eq!(win32_desc.handle[1], 0xAAAA_BBBB_CCCC_DDDD);
        assert_eq!(win32_desc.size, 2048);
    }
}