use std::ffi::c_void;
use std::os::unix::io::{FromRawFd, OwnedFd};
use ash::khr::external_memory_fd;
use ash::vk;
use gpufft_cuda_sys as sys;
use super::SharedMemoryError;
/// One allocation reachable both from Vulkan (as a `vk::Buffer`) and from CUDA
/// (as a linear device pointer), shared through an exported `OPAQUE_FD`
/// memory handle.
pub struct SharedMemory {
    // --- Vulkan side ---
    /// Logical device that created `vk_buffer` and allocated `vk_memory`;
    /// retained so both can be released later.
    ash_device: ash::Device,
    /// Buffer bound to `vk_memory` at offset 0.
    vk_buffer: vk::Buffer,
    /// Exportable allocation backing both the buffer and the CUDA mapping.
    vk_memory: vk::DeviceMemory,

    // --- CUDA side ---
    /// CUDA handle for the imported Vulkan allocation.
    ext_mem_handle: sys::cudaExternalMemory_t,
    /// CUDA device address of the mapped allocation.
    device_ptr: *mut c_void,

    // --- sizes ---
    /// Bytes actually allocated (from the buffer's memory requirements);
    /// may be larger than `size_bytes`.
    alloc_size: u64,
    /// Bytes requested by the caller.
    size_bytes: u64,
}
// SAFETY(review): `device_ptr` is a CUDA device address. In this file it is only
// returned to callers (`cuda_device_ptr`) and passed to `cudaFree`, never
// dereferenced on the host, so moving the owner between threads is assumed
// sound. Confirm the CUDA handles tolerate being released from another thread.
unsafe impl Send for SharedMemory {}
// NOTE(review): `write_host_bytes(&self)` calls `vkMapMemory`/`vkUnmapMemory` on
// `vk_memory`. Vulkan requires host access to that memory to be externally
// synchronized, but this type does not serialize concurrent callers sharing a
// `&SharedMemory`. Confirm that callers do.
unsafe impl Sync for SharedMemory {}
impl SharedMemory {
    /// Creates a `size_bytes` buffer that is simultaneously a Vulkan
    /// `vk::Buffer` and a CUDA device allocation.
    ///
    /// Steps:
    /// 1. Make `cuda_dev` current.
    /// 2. Create a buffer declared exportable as `OPAQUE_FD`.
    /// 3. Back it with exportable host-visible, host-coherent memory.
    /// 4. Export that memory as a file descriptor.
    /// 5. Import the fd into CUDA and map the whole allocation as a linear
    ///    device buffer.
    ///
    /// # Errors
    ///
    /// Returns [`SharedMemoryError::Cuda`] if `make_current`,
    /// `cudaImportExternalMemory` or `cudaExternalMemoryGetMappedBuffer` fails.
    /// Returns [`SharedMemoryError::Vulkan`] if buffer creation, memory-type
    /// selection, allocation, binding or fd export fails.
    ///
    /// Every Vulkan/CUDA object created before the failing step is released
    /// before the error is returned.
    pub fn new(
        vk_dev: &crate::vulkan::VulkanDevice,
        cuda_dev: &crate::cuda::CudaDevice,
        size_bytes: u64,
    ) -> Result<Self, SharedMemoryError> {
        cuda_dev
            .make_current()
            .map_err(|e| SharedMemoryError::Cuda(format!("{e:?}")))?;
        let handles = vk_dev.raw_handles();
        let ash_device = handles.device.clone();
        let ash_instance = handles.instance.clone();
        let phys_dev = handles.physical_device;

        let vk_buffer = Self::create_exportable_buffer(&ash_device, size_bytes)?;
        let mem_req = unsafe { ash_device.get_buffer_memory_requirements(vk_buffer) };

        // From here on, every failure must destroy `vk_buffer`.
        let vk_memory = match Self::allocate_bound_memory(
            &ash_instance,
            &ash_device,
            phys_dev,
            vk_buffer,
            &mem_req,
        ) {
            Ok(memory) => memory,
            Err(e) => {
                unsafe { ash_device.destroy_buffer(vk_buffer, None) };
                return Err(e);
            }
        };

        // From here on, every failure must also free `vk_memory`.
        let (ext_mem_handle, device_ptr) =
            match Self::import_into_cuda(&ash_instance, &ash_device, vk_memory, mem_req.size) {
                Ok(imported) => imported,
                Err(e) => {
                    unsafe {
                        ash_device.destroy_buffer(vk_buffer, None);
                        ash_device.free_memory(vk_memory, None);
                    }
                    return Err(e);
                }
            };

        Ok(Self {
            vk_buffer,
            vk_memory,
            ash_device,
            alloc_size: mem_req.size,
            size_bytes,
            ext_mem_handle,
            device_ptr,
        })
    }

    /// Creates a buffer whose backing memory may be exported as `OPAQUE_FD`.
    fn create_exportable_buffer(
        device: &ash::Device,
        size_bytes: u64,
    ) -> Result<vk::Buffer, SharedMemoryError> {
        let mut external_buf_info = vk::ExternalMemoryBufferCreateInfo::default()
            .handle_types(vk::ExternalMemoryHandleTypeFlags::OPAQUE_FD);
        let buf_info = vk::BufferCreateInfo::default()
            .size(size_bytes)
            .usage(
                vk::BufferUsageFlags::STORAGE_BUFFER
                    | vk::BufferUsageFlags::TRANSFER_SRC
                    | vk::BufferUsageFlags::TRANSFER_DST,
            )
            .sharing_mode(vk::SharingMode::EXCLUSIVE)
            .push_next(&mut external_buf_info);
        unsafe { device.create_buffer(&buf_info, None) }
            .map_err(|e| SharedMemoryError::Vulkan(format!("create_buffer: {e:?}")))
    }

    /// Allocates exportable host-visible, host-coherent memory satisfying
    /// `mem_req` and binds it to `buffer` at offset 0.
    ///
    /// If binding fails, the new allocation is freed. `buffer` is never
    /// destroyed here; that stays the caller's responsibility.
    fn allocate_bound_memory(
        instance: &ash::Instance,
        device: &ash::Device,
        phys_dev: vk::PhysicalDevice,
        buffer: vk::Buffer,
        mem_req: &vk::MemoryRequirements,
    ) -> Result<vk::DeviceMemory, SharedMemoryError> {
        let mem_props = unsafe { instance.get_physical_device_memory_properties(phys_dev) };
        // HOST_COHERENT lets `write_host_bytes` skip explicit flushes.
        let wanted =
            vk::MemoryPropertyFlags::HOST_VISIBLE | vk::MemoryPropertyFlags::HOST_COHERENT;
        let mem_type_idx = (0..mem_props.memory_type_count)
            .find(|&i| {
                let supported = (mem_req.memory_type_bits & (1 << i)) != 0;
                let has_flags = mem_props.memory_types[i as usize]
                    .property_flags
                    .contains(wanted);
                supported && has_flags
            })
            .ok_or_else(|| {
                SharedMemoryError::Vulkan(
                    "no host-visible coherent memory type supports OPAQUE_FD export".into(),
                )
            })?;

        let mut export_info = vk::ExportMemoryAllocateInfo::default()
            .handle_types(vk::ExternalMemoryHandleTypeFlags::OPAQUE_FD);
        let alloc_info = vk::MemoryAllocateInfo::default()
            .allocation_size(mem_req.size)
            .memory_type_index(mem_type_idx)
            .push_next(&mut export_info);
        let memory = unsafe { device.allocate_memory(&alloc_info, None) }
            .map_err(|e| SharedMemoryError::Vulkan(format!("allocate_memory: {e:?}")))?;

        if let Err(e) = unsafe { device.bind_buffer_memory(buffer, memory, 0) } {
            unsafe { device.free_memory(memory, None) };
            return Err(SharedMemoryError::Vulkan(format!(
                "bind_buffer_memory: {e:?}"
            )));
        }
        Ok(memory)
    }

    /// Exports `memory` as an opaque fd, imports it into CUDA and maps all
    /// `alloc_size` bytes as a linear device buffer.
    ///
    /// Returns the CUDA external-memory handle and the mapped device pointer.
    /// On failure, the fd (when CUDA did not take it) and the external-memory
    /// handle (when created) are released here. The Vulkan objects are left to
    /// the caller.
    fn import_into_cuda(
        instance: &ash::Instance,
        device: &ash::Device,
        memory: vk::DeviceMemory,
        alloc_size: u64,
    ) -> Result<(sys::cudaExternalMemory_t, *mut c_void), SharedMemoryError> {
        let raw_fd = {
            let loader = external_memory_fd::Device::new(instance, device);
            let fd_info = vk::MemoryGetFdInfoKHR::default()
                .memory(memory)
                .handle_type(vk::ExternalMemoryHandleTypeFlags::OPAQUE_FD);
            unsafe { loader.get_memory_fd(&fd_info) }
                .map_err(|e| SharedMemoryError::Vulkan(format!("vkGetMemoryFdKHR: {e:?}")))?
        };

        let mut ext_mem_handle: sys::cudaExternalMemory_t = std::ptr::null_mut();
        let mut desc: sys::cudaExternalMemoryHandleDesc = unsafe { std::mem::zeroed() };
        desc.type_ = sys::cudaExternalMemoryHandleType_cudaExternalMemoryHandleTypeOpaqueFd;
        desc.handle.fd = raw_fd;
        desc.size = alloc_size;
        desc.flags = 0;
        let rc = unsafe { sys::cudaImportExternalMemory(&mut ext_mem_handle, &desc) };
        if rc != sys::cudaError_cudaSuccess {
            // CUDA takes ownership of the fd only when the import succeeds,
            // so close it ourselves on failure.
            drop(unsafe { OwnedFd::from_raw_fd(raw_fd) });
            return Err(SharedMemoryError::Cuda(format!(
                "cudaImportExternalMemory: {rc:?}"
            )));
        }

        let mut device_ptr: *mut c_void = std::ptr::null_mut();
        let mut buf_desc: sys::cudaExternalMemoryBufferDesc = unsafe { std::mem::zeroed() };
        buf_desc.offset = 0;
        buf_desc.size = alloc_size;
        buf_desc.flags = 0;
        let rc = unsafe {
            sys::cudaExternalMemoryGetMappedBuffer(&mut device_ptr, ext_mem_handle, &buf_desc)
        };
        if rc != sys::cudaError_cudaSuccess {
            unsafe { sys::cudaDestroyExternalMemory(ext_mem_handle) };
            return Err(SharedMemoryError::Cuda(format!(
                "cudaExternalMemoryGetMappedBuffer: {rc:?}"
            )));
        }
        Ok((ext_mem_handle, device_ptr))
    }

    /// Copies `bytes` to the start (offset 0) of the shared allocation through
    /// a temporary host mapping.
    ///
    /// `new` only selects HOST_COHERENT memory, so no explicit flush is issued
    /// before unmapping.
    ///
    /// NOTE(review): `vkMapMemory` requires host access to `vk_memory` to be
    /// externally synchronized, and this method takes `&self`. Concurrent
    /// callers are not serialized here, and nothing orders this write against
    /// CUDA work using `cuda_device_ptr`. Confirm that callers handle both.
    ///
    /// # Panics
    ///
    /// Panics if `bytes.len()` exceeds [`SharedMemory::len`].
    ///
    /// # Errors
    ///
    /// Returns [`SharedMemoryError::Vulkan`] if mapping the memory fails.
    pub fn write_host_bytes(&self, bytes: &[u8]) -> Result<(), SharedMemoryError> {
        assert!(
            bytes.len() as u64 <= self.size_bytes,
            "write_host_bytes: {} bytes exceeds allocation size {}",
            bytes.len(),
            self.size_bytes,
        );
        unsafe {
            let ptr = self
                .ash_device
                .map_memory(
                    self.vk_memory,
                    0,
                    self.alloc_size,
                    vk::MemoryMapFlags::empty(),
                )
                .map_err(|e| SharedMemoryError::Vulkan(format!("map_memory: {e:?}")))?;
            std::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr.cast::<u8>(), bytes.len());
            self.ash_device.unmap_memory(self.vk_memory);
        }
        Ok(())
    }

    /// CUDA device address of the mapped allocation. It is released with
    /// `cudaFree` when `self` is dropped, so it must not be used afterwards.
    pub fn cuda_device_ptr(&self) -> *mut c_void {
        self.device_ptr
    }

    /// Size in bytes requested at construction. The underlying Vulkan
    /// allocation may be larger.
    pub fn len(&self) -> u64 {
        self.size_bytes
    }

    /// `true` when the requested size was zero.
    pub fn is_empty(&self) -> bool {
        self.size_bytes == 0
    }
}
impl Drop for SharedMemory {
    /// Tears the shared allocation down CUDA side first, Vulkan side last:
    /// 1. Release the mapped device pointer with `cudaFree`.
    /// 2. Destroy the imported external-memory object.
    /// 3. Destroy the Vulkan buffer and free its memory.
    ///
    /// Status codes are discarded because `drop` cannot report them.
    ///
    /// NOTE(review): `new` makes the CUDA device current before creating these
    /// objects, but nothing here does. The CUDA calls run against whatever
    /// device is current on the dropping thread; confirm that is correct.
    fn drop(&mut self) {
        let Self {
            ref ash_device,
            vk_buffer,
            vk_memory,
            ext_mem_handle,
            device_ptr,
            ..
        } = *self;

        // The CUDA mapping was imported from the Vulkan allocation, so it is
        // released before the Vulkan objects.
        unsafe {
            sys::cudaFree(device_ptr);
            sys::cudaDestroyExternalMemory(ext_mem_handle);
        }
        unsafe {
            ash_device.destroy_buffer(vk_buffer, None);
            ash_device.free_memory(vk_memory, None);
        }
    }
}