use ash::vk;
use gpu_allocator::vulkan::{Allocator, AllocatorCreateDesc, Allocation, AllocationCreateDesc, AllocationScheme};
use gpu_allocator::MemoryLocation;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use crate::backend::{MkGpuBackend, MkGpuCapabilities};
use crate::buffer::MkBufferUsage;
use crate::memory::MkMemoryType;
/// Configuration for creating a [`VulkanBackend`].
#[derive(Debug, Clone)]
pub struct VulkanConfig {
    /// Application name passed to the Vulkan instance (`VkApplicationInfo`).
    pub app_name: String,
    /// Application version passed to the Vulkan instance.
    pub app_version: u32,
    /// Enable the `VK_LAYER_KHRONOS_validation` layer on the instance.
    pub validation: bool,
    /// Index into the enumerated physical devices; `None` selects
    /// automatically (first discrete GPU, falling back to device 0).
    pub device_index: Option<usize>,
}
impl Default for VulkanConfig {
fn default() -> Self {
Self {
app_name: "memkit-gpu".to_string(),
app_version: 1,
validation: cfg!(debug_assertions),
device_index: None,
}
}
}
/// Vulkan implementation of [`MkGpuBackend`].
///
/// Owns the full Vulkan object hierarchy (instance, device, queue, command
/// pool) plus a gpu-allocator instance and a registry of live buffers keyed
/// by a monotonically increasing id.
pub struct VulkanBackend {
    // Kept alive so the dynamically loaded Vulkan library is not unloaded.
    #[allow(dead_code)]
    entry: ash::Entry,
    instance: ash::Instance,
    physical_device: vk::PhysicalDevice,
    device: ash::Device,
    // Single queue (index 0) from `queue_family_index`; used for all submits.
    queue: vk::Queue,
    queue_family_index: u32,
    // Pool for transient transfer command buffers (see `copy_buffer_regions`).
    command_pool: vk::CommandPool,
    // gpu-allocator state; the Mutex makes `&self` allocation calls thread-safe.
    allocator: Mutex<Allocator>,
    // Live buffers keyed by handle id.
    buffers: Mutex<HashMap<u64, VulkanBuffer>>,
    // Next handle id; fetch_add(Relaxed) is enough for uniqueness.
    next_id: AtomicU64,
    // Cached at startup for `capabilities()`.
    device_properties: vk::PhysicalDeviceProperties,
}
/// Internal record for a live buffer: the Vulkan handle, its backing
/// allocation, and the parameters it was created with.
struct VulkanBuffer {
    buffer: vk::Buffer,
    allocation: Allocation,
    // Requested size in bytes (the backing allocation may be larger,
    // per `get_buffer_memory_requirements`).
    size: usize,
    usage: MkBufferUsage,
    memory_type: MkMemoryType,
    // NOTE(review): never assigned anywhere in this file (always `None`);
    // mapping goes through `allocation.mapped_ptr()` instead. Dead field?
    mapped_ptr: Option<*mut u8>,
}
// SAFETY: the only field blocking the auto traits is `mapped_ptr`, a raw
// pointer into allocator-owned mapped memory whose lifetime is tied to
// `allocation`. All access to `VulkanBuffer` values is serialized through
// the `buffers` mutex in `VulkanBackend`, so cross-thread use is guarded.
unsafe impl Send for VulkanBuffer {}
unsafe impl Sync for VulkanBuffer {}
/// Cheap, clonable reference to a buffer owned by a [`VulkanBackend`].
/// Holds no Vulkan resources itself; `id` keys into the backend's buffer map.
#[derive(Clone, Debug)]
pub struct VulkanBufferHandle {
    id: u64,
    size: usize,
    memory_type: MkMemoryType,
}
// NOTE(review): `VulkanBufferHandle` contains only `u64`, `usize`, and
// `MkMemoryType`. If `MkMemoryType` is itself Send + Sync (a plain enum
// would be), these manual unsafe impls are redundant — the auto traits
// already apply. Kept as-is; confirm and consider removing.
unsafe impl Send for VulkanBufferHandle {}
unsafe impl Sync for VulkanBufferHandle {}
/// Errors produced by the Vulkan backend.
#[derive(Debug)]
pub enum VulkanError {
    /// Raw Vulkan result code from an `ash` call.
    Vulkan(vk::Result),
    /// Failure inside the gpu-allocator crate.
    Allocator(gpu_allocator::AllocationError),
    /// No physical device or queue family satisfied the requirements.
    NoSuitableDevice,
    /// Handle id not present in the buffer map (already destroyed?).
    BufferNotFound(u64),
    /// Mapping requested on memory that is not host-visible.
    NotMappable,
    /// The Vulkan loader library could not be loaded.
    LoadError(String),
}
impl std::fmt::Display for VulkanError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
VulkanError::Vulkan(e) => write!(f, "Vulkan error: {:?}", e),
VulkanError::Allocator(e) => write!(f, "Allocator error: {}", e),
VulkanError::NoSuitableDevice => write!(f, "No suitable GPU device found"),
VulkanError::BufferNotFound(id) => write!(f, "Buffer {} not found", id),
VulkanError::NotMappable => write!(f, "Buffer is not mappable"),
VulkanError::LoadError(msg) => write!(f, "Failed to load Vulkan: {}", msg),
}
}
}
// Marker impl: `Debug` + `Display` above provide everything `Error` needs.
impl std::error::Error for VulkanError {}
impl From<vk::Result> for VulkanError {
    /// Wrap a raw Vulkan result code, enabling `?` on `ash` calls.
    fn from(result: vk::Result) -> Self {
        Self::Vulkan(result)
    }
}
impl From<gpu_allocator::AllocationError> for VulkanError {
    /// Wrap a gpu-allocator failure, enabling `?` on allocator calls.
    fn from(err: gpu_allocator::AllocationError) -> Self {
        Self::Allocator(err)
    }
}
impl VulkanBackend {
    /// Create a backend according to `config`.
    ///
    /// # Errors
    /// [`VulkanError::LoadError`] if the Vulkan loader is missing,
    /// [`VulkanError::NoSuitableDevice`] if no device/queue qualifies, or a
    /// wrapped Vulkan / allocator error from any creation call.
    pub fn new(config: VulkanConfig) -> Result<Self, VulkanError> {
        unsafe { Self::create_internal(config) }
    }

    /// One-shot initialization: loader -> instance -> physical device ->
    /// logical device + queue -> command pool -> allocator.
    ///
    /// # Safety
    /// Invokes raw entry points from the dynamically loaded Vulkan library;
    /// the library must be a conforming Vulkan implementation.
    unsafe fn create_internal(config: VulkanConfig) -> Result<Self, VulkanError> {
        let entry = ash::Entry::load()
            .map_err(|e| VulkanError::LoadError(e.to_string()))?;
        // NOTE(review): `unwrap` panics if `config.app_name` contains an
        // interior NUL byte — consider mapping that to an error instead.
        let app_name = std::ffi::CString::new(config.app_name.as_str()).unwrap();
        let engine_name = std::ffi::CString::new("memkit-gpu").unwrap();
        let app_info = vk::ApplicationInfo::default()
            .application_name(&app_name)
            .application_version(config.app_version)
            .engine_name(&engine_name)
            .engine_version(1)
            // Targets Vulkan 1.2.
            .api_version(vk::make_api_version(0, 1, 2, 0));
        // Optionally enable the Khronos validation layer.
        // NOTE(review): availability is not checked first; if the layer is
        // not installed, `create_instance` fails (VK_ERROR_LAYER_NOT_PRESENT).
        let mut layers = Vec::new();
        if config.validation {
            let validation_layer = std::ffi::CString::new("VK_LAYER_KHRONOS_validation").unwrap();
            layers.push(validation_layer);
        }
        // `layers` owns the CStrings; it outlives `layer_ptrs` (same scope).
        let layer_ptrs: Vec<*const i8> = layers.iter().map(|l| l.as_ptr()).collect();
        let create_info = vk::InstanceCreateInfo::default()
            .application_info(&app_info)
            .enabled_layer_names(&layer_ptrs);
        let instance = entry.create_instance(&create_info, None)?;
        let physical_devices = instance.enumerate_physical_devices()?;
        if physical_devices.is_empty() {
            return Err(VulkanError::NoSuitableDevice);
        }
        // An explicit index wins; otherwise prefer the first discrete GPU,
        // falling back to device 0.
        let physical_device = if let Some(idx) = config.device_index {
            *physical_devices.get(idx).ok_or(VulkanError::NoSuitableDevice)?
        } else {
            physical_devices
                .iter()
                .find(|&&pd| {
                    let props = instance.get_physical_device_properties(pd);
                    props.device_type == vk::PhysicalDeviceType::DISCRETE_GPU
                })
                .copied()
                .unwrap_or(physical_devices[0])
        };
        let device_properties = instance.get_physical_device_properties(physical_device);
        // Pick a queue family supporting both GRAPHICS and TRANSFER.
        // NOTE(review): this rejects compute-only devices even though this
        // backend only records transfer work — confirm GRAPHICS is required.
        let queue_families = instance.get_physical_device_queue_family_properties(physical_device);
        let queue_family_index = queue_families
            .iter()
            .enumerate()
            .find(|(_, props)| props.queue_flags.contains(vk::QueueFlags::GRAPHICS | vk::QueueFlags::TRANSFER))
            .map(|(i, _)| i as u32)
            .ok_or(VulkanError::NoSuitableDevice)?;
        let queue_priorities = [1.0f32];
        let queue_create_info = vk::DeviceQueueCreateInfo::default()
            .queue_family_index(queue_family_index)
            .queue_priorities(&queue_priorities);
        let device_create_info = vk::DeviceCreateInfo::default()
            .queue_create_infos(std::slice::from_ref(&queue_create_info));
        let device = instance.create_device(physical_device, &device_create_info, None)?;
        let queue = device.get_device_queue(queue_family_index, 0);
        // RESET_COMMAND_BUFFER allows per-buffer reset, though current code
        // allocates and frees a fresh command buffer per copy instead.
        let pool_create_info = vk::CommandPoolCreateInfo::default()
            .queue_family_index(queue_family_index)
            .flags(vk::CommandPoolCreateFlags::RESET_COMMAND_BUFFER);
        let command_pool = device.create_command_pool(&pool_create_info, None)?;
        let allocator = Allocator::new(&AllocatorCreateDesc {
            instance: instance.clone(),
            device: device.clone(),
            physical_device,
            debug_settings: Default::default(),
            buffer_device_address: false,
            allocation_sizes: Default::default(),
        })?;
        Ok(Self {
            entry,
            instance,
            physical_device,
            device,
            queue,
            queue_family_index,
            command_pool,
            allocator: Mutex::new(allocator),
            buffers: Mutex::new(HashMap::new()),
            next_id: AtomicU64::new(1),
            device_properties,
        })
    }

    /// Map the portable memory-type enum onto gpu-allocator's location hint.
    fn memory_location_from_type(memory_type: MkMemoryType) -> MemoryLocation {
        match memory_type {
            MkMemoryType::DeviceLocal => MemoryLocation::GpuOnly,
            MkMemoryType::HostVisible => MemoryLocation::CpuToGpu,
            MkMemoryType::HostCached => MemoryLocation::GpuToCpu,
            // NOTE(review): Unified maps to CpuToGpu, same as HostVisible —
            // confirm this approximation is intended.
            MkMemoryType::Unified => MemoryLocation::CpuToGpu,
        }
    }

    /// Translate portable usage flags to Vulkan buffer-usage bits.
    /// An empty set defaults to TRANSFER_DST so creation never sees zero flags.
    fn vk_buffer_usage_from_mk(usage: MkBufferUsage) -> vk::BufferUsageFlags {
        let mut flags = vk::BufferUsageFlags::empty();
        if usage.contains(MkBufferUsage::TRANSFER_SRC) {
            flags |= vk::BufferUsageFlags::TRANSFER_SRC;
        }
        if usage.contains(MkBufferUsage::TRANSFER_DST) {
            flags |= vk::BufferUsageFlags::TRANSFER_DST;
        }
        if usage.contains(MkBufferUsage::UNIFORM) {
            flags |= vk::BufferUsageFlags::UNIFORM_BUFFER;
        }
        if usage.contains(MkBufferUsage::STORAGE) {
            flags |= vk::BufferUsageFlags::STORAGE_BUFFER;
        }
        if usage.contains(MkBufferUsage::VERTEX) {
            flags |= vk::BufferUsageFlags::VERTEX_BUFFER;
        }
        if usage.contains(MkBufferUsage::INDEX) {
            flags |= vk::BufferUsageFlags::INDEX_BUFFER;
        }
        if flags.is_empty() {
            flags = vk::BufferUsageFlags::TRANSFER_DST;
        }
        flags
    }
}
impl Drop for VulkanBackend {
    fn drop(&mut self) {
        // Teardown in reverse creation order:
        // buffers/allocations -> command pool -> device -> instance.
        unsafe {
            // Ensure no submitted work still references the resources.
            self.device.device_wait_idle().ok();
            let mut buffers = self.buffers.lock().unwrap();
            let mut allocator = self.allocator.lock().unwrap();
            for (_, buffer) in buffers.drain() {
                self.device.destroy_buffer(buffer.buffer, None);
                allocator.free(buffer.allocation).ok();
            }
            // Release the lock guards before tearing down the device.
            drop(allocator);
            drop(buffers);
            // NOTE(review): the `allocator` FIELD itself is dropped only
            // after this fn returns — i.e. after `destroy_device` /
            // `destroy_instance` below. If gpu-allocator's `Allocator::drop`
            // frees device memory, it will run against a destroyed device.
            // Consider storing `ManuallyDrop<Allocator>` (or an Option) and
            // taking/dropping it here, before `destroy_device`. Verify
            // against gpu-allocator's documented drop behavior.
            self.device.destroy_command_pool(self.command_pool, None);
            self.device.destroy_device(None);
            self.instance.destroy_instance(None);
        }
    }
}
impl MkGpuBackend for VulkanBackend {
    type BufferHandle = VulkanBufferHandle;
    type Error = VulkanError;

    /// Backend display name.
    fn name(&self) -> &'static str {
        "Vulkan"
    }

    /// Device limits and identity, from properties cached at creation.
    ///
    /// NOTE(review): `max_allocations`, `unified_memory` and
    /// `coherent_memory` are fixed values, not queried from the device
    /// (`maxMemoryAllocationCount` / heap flags would give real answers).
    fn capabilities(&self) -> MkGpuCapabilities {
        let props = &self.device_properties;
        MkGpuCapabilities {
            max_buffer_size: props.limits.max_storage_buffer_range as usize,
            max_allocations: 4096,
            unified_memory: false,
            coherent_memory: true,
            // SAFETY: `device_name` is a fixed-size, NUL-terminated array
            // filled in by the driver.
            device_name: unsafe {
                std::ffi::CStr::from_ptr(props.device_name.as_ptr())
                    .to_string_lossy()
                    .to_string()
            },
            vendor_name: format!("Vendor ID: {}", props.vendor_id),
        }
    }

    /// Create a buffer of `size` bytes and register it under a fresh id.
    ///
    /// # Errors
    /// Propagates Vulkan buffer-creation, allocation, and bind errors. On
    /// failure, any partially created resources (the `vk::Buffer`, the
    /// allocation) are released before returning, so nothing leaks.
    fn create_buffer(
        &self,
        size: usize,
        usage: MkBufferUsage,
        memory_type: MkMemoryType,
    ) -> Result<Self::BufferHandle, Self::Error> {
        let vk_usage = Self::vk_buffer_usage_from_mk(usage);
        let buffer_info = vk::BufferCreateInfo::default()
            .size(size as u64)
            .usage(vk_usage)
            .sharing_mode(vk::SharingMode::EXCLUSIVE);
        let buffer = unsafe { self.device.create_buffer(&buffer_info, None)? };
        let requirements = unsafe { self.device.get_buffer_memory_requirements(buffer) };
        // Allocate backing memory; destroy the vk::Buffer if this fails so
        // it is not leaked (a bare `?` here would skip cleanup).
        let allocation = match self.allocator.lock().unwrap().allocate(&AllocationCreateDesc {
            name: "memkit buffer",
            requirements,
            location: Self::memory_location_from_type(memory_type),
            linear: true, // plain buffers are always linear
            allocation_scheme: AllocationScheme::GpuAllocatorManaged,
        }) {
            Ok(a) => a,
            Err(e) => {
                unsafe { self.device.destroy_buffer(buffer, None) };
                return Err(e.into());
            }
        };
        if let Err(e) = unsafe {
            self.device
                .bind_buffer_memory(buffer, allocation.memory(), allocation.offset())
        } {
            // Roll back both resources on bind failure.
            unsafe { self.device.destroy_buffer(buffer, None) };
            self.allocator.lock().unwrap().free(allocation).ok();
            return Err(e.into());
        }
        // Relaxed suffices: the counter only needs to hand out unique ids.
        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        let vk_buffer = VulkanBuffer {
            buffer,
            allocation,
            size,
            usage,
            memory_type,
            mapped_ptr: None,
        };
        self.buffers.lock().unwrap().insert(id, vk_buffer);
        Ok(VulkanBufferHandle { id, size, memory_type })
    }

    /// Destroy the buffer and free its allocation.
    /// Unknown ids are ignored, so double-destroy is a harmless no-op.
    fn destroy_buffer(&self, handle: &Self::BufferHandle) {
        // Remove from the map first and release that lock, so the buffers
        // and allocator mutexes are never held at the same time.
        let removed = self.buffers.lock().unwrap().remove(&handle.id);
        if let Some(buffer) = removed {
            unsafe {
                self.device.destroy_buffer(buffer.buffer, None);
            }
            self.allocator.lock().unwrap().free(buffer.allocation).ok();
        }
    }

    /// Host pointer for a mapped buffer, or `None` for device-local memory
    /// or unknown handles. The allocation's existing mapping is reused, so
    /// no per-call map operation is performed here.
    fn map(&self, handle: &Self::BufferHandle) -> Option<*mut u8> {
        let buffers = self.buffers.lock().unwrap();
        let buffer = buffers.get(&handle.id)?;
        if buffer.memory_type == MkMemoryType::DeviceLocal {
            return None;
        }
        buffer.allocation.mapped_ptr().map(|p| p.as_ptr() as *mut u8)
    }

    /// No-op: `map` hands out the allocation's persistent mapping, so there
    /// is nothing to undo per call.
    fn unmap(&self, _handle: &Self::BufferHandle) {}

    /// Flush a host-written range so the device sees it. Best-effort:
    /// errors are swallowed, and unknown handles are ignored.
    ///
    /// NOTE(review): `offset`/`size` are passed through without rounding to
    /// `nonCoherentAtomSize`; the Vulkan spec requires aligned ranges for
    /// non-coherent memory — confirm the allocator only returns coherent
    /// host-visible memory here.
    fn flush(&self, handle: &Self::BufferHandle, offset: usize, size: usize) {
        let buffers = self.buffers.lock().unwrap();
        if let Some(buffer) = buffers.get(&handle.id) {
            unsafe {
                let range = vk::MappedMemoryRange::default()
                    .memory(buffer.allocation.memory())
                    // The range is relative to the underlying VkDeviceMemory,
                    // so the allocation's own offset must be added.
                    .offset(buffer.allocation.offset() + offset as u64)
                    .size(size as u64);
                self.device.flush_mapped_memory_ranges(&[range]).ok();
            }
        }
    }

    /// Invalidate a range before reading device-written data on the host.
    /// Same best-effort behavior and alignment caveat as `flush`.
    fn invalidate(&self, handle: &Self::BufferHandle, offset: usize, size: usize) {
        let buffers = self.buffers.lock().unwrap();
        if let Some(buffer) = buffers.get(&handle.id) {
            unsafe {
                let range = vk::MappedMemoryRange::default()
                    .memory(buffer.allocation.memory())
                    .offset(buffer.allocation.offset() + offset as u64)
                    .size(size as u64);
                self.device.invalidate_mapped_memory_ranges(&[range]).ok();
            }
        }
    }

    /// Copy `size` bytes from the start of `src` to the start of `dst`.
    fn copy_buffer(
        &self,
        src: &Self::BufferHandle,
        dst: &Self::BufferHandle,
        size: usize,
    ) -> Result<(), Self::Error> {
        self.copy_buffer_regions(src, 0, dst, 0, size)
    }

    /// Record and synchronously submit a one-shot GPU transfer.
    ///
    /// # Errors
    /// [`VulkanError::BufferNotFound`] for stale handles, otherwise any
    /// Vulkan error from recording/submission. The transient command buffer
    /// is freed on every path, including errors (previously it leaked
    /// whenever begin/end/submit/wait failed).
    fn copy_buffer_regions(
        &self,
        src: &Self::BufferHandle,
        src_offset: usize,
        dst: &Self::BufferHandle,
        dst_offset: usize,
        size: usize,
    ) -> Result<(), Self::Error> {
        let buffers = self.buffers.lock().unwrap();
        let src_buf = buffers.get(&src.id).ok_or(VulkanError::BufferNotFound(src.id))?;
        let dst_buf = buffers.get(&dst.id).ok_or(VulkanError::BufferNotFound(dst.id))?;
        let src_vk = src_buf.buffer;
        let dst_vk = dst_buf.buffer;
        // Release the map lock before the blocking submit so map/flush on
        // other threads are not stalled. NOTE(review): a concurrent
        // `destroy_buffer` on either handle during the copy would leave
        // these raw handles dangling — callers must not destroy buffers
        // that are in flight.
        drop(buffers);
        unsafe {
            let alloc_info = vk::CommandBufferAllocateInfo::default()
                .command_pool(self.command_pool)
                .level(vk::CommandBufferLevel::PRIMARY)
                .command_buffer_count(1);
            let cmd_buffers = self.device.allocate_command_buffers(&alloc_info)?;
            let cmd = cmd_buffers[0];
            // Record + submit inside a closure so the command buffer can be
            // freed exactly once below, on both success and error paths.
            let record_and_submit = || -> Result<(), VulkanError> {
                let begin_info = vk::CommandBufferBeginInfo::default()
                    .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT);
                self.device.begin_command_buffer(cmd, &begin_info)?;
                let copy_region = vk::BufferCopy {
                    src_offset: src_offset as u64,
                    dst_offset: dst_offset as u64,
                    size: size as u64,
                };
                self.device.cmd_copy_buffer(cmd, src_vk, dst_vk, &[copy_region]);
                self.device.end_command_buffer(cmd)?;
                let submit_info = vk::SubmitInfo::default()
                    .command_buffers(&cmd_buffers);
                self.device.queue_submit(self.queue, &[submit_info], vk::Fence::null())?;
                // Synchronous copy: block until the transfer completes.
                self.device.queue_wait_idle(self.queue)?;
                Ok(())
            };
            let result = record_and_submit();
            // Always return the command buffer to the pool.
            self.device.free_command_buffers(self.command_pool, &cmd_buffers);
            result
        }
    }

    /// Block until the device finishes all outstanding work.
    fn wait_idle(&self) -> Result<(), Self::Error> {
        unsafe {
            self.device.device_wait_idle()?;
        }
        Ok(())
    }
}