#![cfg(all(feature = "vkfft", feature = "cufft", target_os = "linux"))]
use std::fs::File;
use std::os::fd::FromRawFd;
use std::sync::Arc;
use ash::khr::external_memory_fd;
use ash::vk;
use cudarc::driver::{CudaContext, MappedBuffer};
use crate::{Device, GpuError};
/// One GPU allocation that is reachable from both Vulkan and CUDA.
///
/// Vulkan owns the buffer and its backing device memory; CUDA reaches the same allocation
/// through an imported `OPAQUE_FD` handle that is mapped as a device pointer (`mapped`).
pub struct SharedMemory {
    /// Logical device that created `vk_buffer` and `vk_memory`; reused to map and free them.
    ash_device: ash::Device,
    /// Buffer bound at offset 0 of `vk_memory`.
    vk_buffer: vk::Buffer,
    /// Exportable allocation backing `vk_buffer`.
    vk_memory: vk::DeviceMemory,
    /// Byte size requested by the caller; the underlying allocation may be larger.
    size_bytes: u64,
    /// CUDA mapping of the whole imported allocation.
    mapped: MappedBuffer,
}
impl SharedMemory {
    /// Allocates `size_bytes` of Vulkan memory on `vk_dev` and imports it into `cuda_ctx`, so
    /// that both APIs address the same allocation.
    ///
    /// Vulkan side:
    /// - a `STORAGE_BUFFER | TRANSFER_SRC | TRANSFER_DST` buffer;
    /// - bound to a `HOST_VISIBLE | HOST_COHERENT` allocation;
    /// - created with `OPAQUE_FD` export enabled.
    ///
    /// The exported file descriptor is handed to CUDA, which maps the whole allocation.
    ///
    /// NOTE(review): the cloned `ash::Device` stored in the result does not keep the
    /// underlying `VkDevice` alive. The returned value must be dropped before `vk_dev` is
    /// destroyed; confirm that callers guarantee this.
    ///
    /// # Errors
    ///
    /// - any error from `crate::shared_fft::check_same_gpu` or `Device::raw_vulkan`;
    /// - [`GpuError::VulkanHandlesUnavailable`] if `wgpu` does not expose its Vulkan handles;
    /// - `GpuError::ShaderCompilation` (reused here as the generic Vulkan-failure variant) if
    ///   `size_bytes` is zero, no host-visible coherent memory type fits the buffer, or a
    ///   Vulkan call fails;
    /// - `GpuError::CudaError` if the CUDA import or mapping fails.
    ///
    /// Any Vulkan buffer or memory created before a failure is destroyed before returning.
    pub fn new(
        vk_dev: &Device,
        cuda_ctx: &Arc<CudaContext>,
        size_bytes: u64,
    ) -> Result<Self, GpuError> {
        // `VkBufferCreateInfo::size` must be greater than 0. Reject early instead of making
        // an invalid Vulkan call.
        if size_bytes == 0 {
            return Err(GpuError::ShaderCompilation {
                msg: "shared memory size must be greater than zero".into(),
            });
        }
        crate::shared_fft::check_same_gpu(vk_dev, cuda_ctx)?;
        let (ash_device, ash_instance) = Self::raw_ash_handles(vk_dev)?;
        // Queried before any Vulkan object exists, so a failure here leaks nothing.
        let mem_props = Self::memory_properties(vk_dev, &ash_instance)?;

        let vk_buffer = Self::create_exportable_buffer(&ash_device, size_bytes)?;
        // SAFETY: `vk_buffer` is a valid buffer that was just created on `ash_device`.
        let mem_req = unsafe { ash_device.get_buffer_memory_requirements(vk_buffer) };

        let vk_memory =
            match Self::allocate_and_bind(&ash_device, &mem_props, mem_req, vk_buffer) {
                Ok(memory) => memory,
                Err(err) => {
                    // SAFETY: the buffer was created above, has no memory bound, and has not
                    // been handed out to anyone.
                    unsafe { ash_device.destroy_buffer(vk_buffer, None) };
                    return Err(err);
                }
            };

        let mapped = match Self::import_into_cuda(
            &ash_instance,
            &ash_device,
            cuda_ctx,
            vk_memory,
            mem_req.size,
        ) {
            Ok(mapped) => mapped,
            Err(err) => {
                // SAFETY: both handles were created above and are referenced nowhere else.
                // Any CUDA objects created inside the helper were dropped when it returned
                // (presumably releasing the CUDA-side reference; see cudarc's Drop impls).
                unsafe {
                    ash_device.destroy_buffer(vk_buffer, None);
                    ash_device.free_memory(vk_memory, None);
                }
                return Err(err);
            }
        };

        Ok(Self {
            vk_buffer,
            vk_memory,
            ash_device,
            size_bytes,
            mapped,
        })
    }

    /// Clones the raw `ash` device and instance out of `wgpu`'s Vulkan backend.
    fn raw_ash_handles(vk_dev: &Device) -> Result<(ash::Device, ash::Instance), GpuError> {
        use wgpu::hal::api::Vulkan;
        // SAFETY: the raw handles are only cloned here, never destroyed through these clones.
        unsafe {
            let hal_device = vk_dev
                .device
                .as_hal::<Vulkan>()
                .ok_or(GpuError::VulkanHandlesUnavailable)?;
            let device = hal_device.raw_device().clone();
            let hal_instance = vk_dev
                .instance
                .as_hal::<Vulkan>()
                .ok_or(GpuError::VulkanHandlesUnavailable)?;
            let instance = hal_instance.shared_instance().raw_instance().clone();
            Ok((device, instance))
        }
    }

    /// Reads the memory heaps and types of the physical device behind `vk_dev`.
    fn memory_properties(
        vk_dev: &Device,
        ash_instance: &ash::Instance,
    ) -> Result<vk::PhysicalDeviceMemoryProperties, GpuError> {
        use ash::vk::Handle;
        let phys_dev_handle: u64 = vk_dev.raw_vulkan()?.physical_device;
        let phys_dev = vk::PhysicalDevice::from_raw(phys_dev_handle);
        // SAFETY: `phys_dev` comes from `vk_dev.raw_vulkan()` and is assumed to belong to
        // `ash_instance`.
        Ok(unsafe { ash_instance.get_physical_device_memory_properties(phys_dev) })
    }

    /// Creates a buffer whose memory may later be exported as an `OPAQUE_FD`.
    fn create_exportable_buffer(
        ash_device: &ash::Device,
        size_bytes: u64,
    ) -> Result<vk::Buffer, GpuError> {
        let mut external_buf_info = vk::ExternalMemoryBufferCreateInfo::default()
            .handle_types(vk::ExternalMemoryHandleTypeFlags::OPAQUE_FD);
        let buf_info = vk::BufferCreateInfo::default()
            .size(size_bytes)
            .usage(
                vk::BufferUsageFlags::STORAGE_BUFFER
                    | vk::BufferUsageFlags::TRANSFER_SRC
                    | vk::BufferUsageFlags::TRANSFER_DST,
            )
            .sharing_mode(vk::SharingMode::EXCLUSIVE)
            .push_next(&mut external_buf_info);
        // SAFETY: `buf_info` and its `pNext` chain outlive the call.
        unsafe { ash_device.create_buffer(&buf_info, None) }.map_err(|e| {
            GpuError::ShaderCompilation {
                msg: format!("create_buffer (shared): {e:?}"),
            }
        })
    }

    /// Allocates exportable `HOST_VISIBLE | HOST_COHERENT` memory for `vk_buffer` and binds it.
    ///
    /// If binding fails, the freshly allocated memory is freed before returning.
    fn allocate_and_bind(
        ash_device: &ash::Device,
        mem_props: &vk::PhysicalDeviceMemoryProperties,
        mem_req: vk::MemoryRequirements,
        vk_buffer: vk::Buffer,
    ) -> Result<vk::DeviceMemory, GpuError> {
        let wanted =
            vk::MemoryPropertyFlags::HOST_VISIBLE | vk::MemoryPropertyFlags::HOST_COHERENT;
        // First memory type that the buffer accepts and that carries every `wanted` flag.
        let mem_type_idx = (0..mem_props.memory_type_count)
            .find(|&i| {
                let supported = (mem_req.memory_type_bits & (1 << i)) != 0;
                let has_flags = mem_props.memory_types[i as usize]
                    .property_flags
                    .contains(wanted);
                supported && has_flags
            })
            .ok_or_else(|| GpuError::ShaderCompilation {
                msg: "no host-visible coherent memory type supports OPAQUE_FD export".into(),
            })?;
        let mut export_info = vk::ExportMemoryAllocateInfo::default()
            .handle_types(vk::ExternalMemoryHandleTypeFlags::OPAQUE_FD);
        let alloc_info = vk::MemoryAllocateInfo::default()
            .allocation_size(mem_req.size)
            .memory_type_index(mem_type_idx)
            .push_next(&mut export_info);
        // SAFETY: `alloc_info` and its `pNext` chain outlive the call.
        let vk_memory = unsafe { ash_device.allocate_memory(&alloc_info, None) }.map_err(|e| {
            GpuError::ShaderCompilation {
                msg: format!("allocate_memory (shared): {e:?}"),
            }
        })?;
        // SAFETY: the size and type index of `vk_memory` were taken from the buffer's own
        // requirements, and offset 0 satisfies any alignment.
        if let Err(e) = unsafe { ash_device.bind_buffer_memory(vk_buffer, vk_memory, 0) } {
            // SAFETY: the allocation is neither bound nor mapped and is owned only by this
            // function.
            unsafe { ash_device.free_memory(vk_memory, None) };
            return Err(GpuError::ShaderCompilation {
                msg: format!("bind_buffer_memory (shared): {e:?}"),
            });
        }
        Ok(vk_memory)
    }

    /// Exports `vk_memory` as an `OPAQUE_FD`, then imports and maps it in `cuda_ctx`.
    fn import_into_cuda(
        ash_instance: &ash::Instance,
        ash_device: &ash::Device,
        cuda_ctx: &Arc<CudaContext>,
        vk_memory: vk::DeviceMemory,
        allocation_size: u64,
    ) -> Result<MappedBuffer, GpuError> {
        let ext_mem_fd_loader = external_memory_fd::Device::new(ash_instance, ash_device);
        let fd_info = vk::MemoryGetFdInfoKHR::default()
            .memory(vk_memory)
            .handle_type(vk::ExternalMemoryHandleTypeFlags::OPAQUE_FD);
        // SAFETY: `vk_memory` was allocated with `OPAQUE_FD` among its export handle types.
        let raw_fd = unsafe { ext_mem_fd_loader.get_memory_fd(&fd_info) }.map_err(|e| {
            GpuError::ShaderCompilation {
                msg: format!("vkGetMemoryFdKHR: {e:?}"),
            }
        })?;
        // SAFETY: vkGetMemoryFdKHR returns a new descriptor owned by the caller. Wrapping it
        // in a `File` gives it exactly one owner, so it is closed if the import fails.
        let file = unsafe { File::from_raw_fd(raw_fd) };
        // SAFETY: the descriptor refers to an allocation of exactly `allocation_size` bytes.
        let ext_mem = unsafe { cuda_ctx.import_external_memory(file, allocation_size) }
            .map_err(|e| GpuError::CudaError(format!("import_external_memory: {e:?}")))?;
        ext_mem
            .map_all()
            .map_err(|e| GpuError::CudaError(format!("ExternalMemory::map_all: {e:?}")))
    }

    /// Copies `bytes` into the start of the shared allocation through a host mapping.
    ///
    /// The memory was allocated `HOST_COHERENT`, so no explicit flush is issued after the
    /// copy.
    ///
    /// NOTE(review): `vkMapMemory` requires that the memory is not already host-mapped, but
    /// this method takes `&self`. Concurrent calls (if `SharedMemory` is `Sync`) and
    /// in-flight CUDA work on the buffer are not guarded here. Confirm that callers
    /// serialize access.
    ///
    /// # Panics
    ///
    /// Panics if `bytes` is longer than [`len`](Self::len).
    ///
    /// # Errors
    ///
    /// Returns `GpuError::ShaderCompilation` if `vkMapMemory` fails.
    pub fn write_host_bytes(&self, bytes: &[u8]) -> Result<(), GpuError> {
        assert!(
            bytes.len() as u64 <= self.size_bytes,
            "write of {} bytes exceeds shared buffer of {} bytes",
            bytes.len(),
            self.size_bytes
        );
        // SAFETY: `vk_memory` is a live host-visible allocation of at least `size_bytes` bytes.
        let ptr = unsafe {
            self.ash_device.map_memory(
                self.vk_memory,
                0,
                self.size_bytes,
                vk::MemoryMapFlags::empty(),
            )
        }
        .map_err(|e| GpuError::ShaderCompilation {
            msg: format!("map_memory: {e:?}"),
        })?;
        // SAFETY: `ptr` addresses `size_bytes >= bytes.len()` mapped bytes that cannot alias
        // the caller's slice. The mapping is released immediately after the copy.
        unsafe {
            core::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr.cast::<u8>(), bytes.len());
            self.ash_device.unmap_memory(self.vk_memory);
        }
        Ok(())
    }

    /// CUDA mapping of the whole imported allocation.
    pub fn cuda_view(&self) -> &MappedBuffer {
        &self.mapped
    }

    /// Size in bytes requested at construction (the allocation itself may be larger).
    pub fn len(&self) -> u64 {
        self.size_bytes
    }

    /// Whether the requested size is zero; always `false` for values built by
    /// [`new`](Self::new), which rejects zero sizes.
    pub fn is_empty(&self) -> bool {
        self.size_bytes == 0
    }
}
impl Drop for SharedMemory {
    /// Destroys the Vulkan buffer, then frees its backing memory.
    ///
    /// Does not wait for the device to go idle; the buffer is assumed not to be in use by
    /// pending GPU work at this point.
    ///
    /// NOTE(review): `Drop::drop` runs before the struct's fields are dropped. The CUDA
    /// mapping in `mapped` is therefore released *after* the Vulkan memory is freed here.
    /// The exported `OPAQUE_FD` holds its own reference to the memory payload, but confirm
    /// this teardown order against the driver if issues appear.
    fn drop(&mut self) {
        let device = &self.ash_device;
        // SAFETY: both handles were created on `device` in `SharedMemory::new` and are owned
        // exclusively by this value.
        unsafe {
            device.destroy_buffer(self.vk_buffer, None);
            device.free_memory(self.vk_memory, None);
        }
    }
}