rotex-vulkan 0.1.0

A Vulkan backend for rotex_core
Documentation
use std::collections::BTreeMap;
use std::ffi::CStr;

use ash::vk;

use crate::core::Instance;
use crate::error::{Error, ErrorKind, Severity, vk_error};

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueueCategory {
    Graphics,
    Compute,
    Transfer,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct QueueRequest {
    pub category: QueueCategory,
    pub count: u32,
}

#[derive(Debug, Clone)]
pub struct DeviceDescriptor {
    pub required_features: vk::PhysicalDeviceFeatures,
    pub enable_swapchain: bool,
    pub queues: Vec<QueueRequest>,
}

pub struct Adapter {
    pub(crate) handle: vk::PhysicalDevice,
    name: String,
    device_type: vk::PhysicalDeviceType,
    limits: vk::PhysicalDeviceLimits,
}

impl Adapter {
    pub(crate) fn new(
        handle: vk::PhysicalDevice,
        name: String,
        device_type: vk::PhysicalDeviceType,
        limits: vk::PhysicalDeviceLimits,
    ) -> Self {
        Self {
            handle,
            name,
            device_type,
            limits,
        }
    }

    pub fn name(&self) -> &str {
        &self.name
    }

    pub fn device_type(&self) -> vk::PhysicalDeviceType {
        self.device_type
    }

    pub fn limits(&self) -> &vk::PhysicalDeviceLimits {
        &self.limits
    }

    pub fn physical_device(&self) -> vk::PhysicalDevice {
        self.handle
    }

    pub fn selection_score(&self) -> u32 {
        match self.device_type {
            vk::PhysicalDeviceType::DISCRETE_GPU => 400,
            vk::PhysicalDeviceType::INTEGRATED_GPU => 300,
            vk::PhysicalDeviceType::VIRTUAL_GPU => 200,
            vk::PhysicalDeviceType::CPU => 100,
            _ => 0,
        }
    }

    pub fn has_swapchain_extension(&self, instance: &Instance) -> Result<bool, Error> {
        let extensions = unsafe {
            instance
                .instance()
                .enumerate_device_extension_properties(self.handle)
        }
        .map_err(vk_error)?;
        Ok(extensions.iter().any(|ext| unsafe {
            CStr::from_ptr(ext.extension_name.as_ptr()) == vk::KHR_SWAPCHAIN_NAME
        }))
    }

    pub fn supports_queue_requests(&self, instance: &Instance, queues: &[QueueRequest]) -> bool {
        let queue_families = unsafe {
            instance
                .instance()
                .get_physical_device_queue_family_properties(self.handle)
        };
        let graphics_index = queue_families
            .iter()
            .enumerate()
            .find(|(_, family)| family.queue_flags.contains(vk::QueueFlags::GRAPHICS))
            .map(|(index, _)| index as u32);
        let compute_index = queue_families
            .iter()
            .enumerate()
            .find(|(_, family)| family.queue_flags.contains(vk::QueueFlags::COMPUTE))
            .map(|(index, _)| index as u32);
        let transfer_any_index = queue_families
            .iter()
            .enumerate()
            .find(|(_, family)| family.queue_flags.contains(vk::QueueFlags::TRANSFER))
            .map(|(index, _)| index as u32);
        let transfer_dedicated_index = queue_families
            .iter()
            .enumerate()
            .find(|(_, family)| {
                family.queue_flags.contains(vk::QueueFlags::TRANSFER)
                    && !family.queue_flags.contains(vk::QueueFlags::GRAPHICS)
                    && !family.queue_flags.contains(vk::QueueFlags::COMPUTE)
            })
            .map(|(index, _)| index as u32);

        let mut has_request = false;
        for request in queues {
            if request.count == 0 {
                continue;
            }
            has_request = true;
            let family_index = match request.category {
                QueueCategory::Graphics => graphics_index,
                QueueCategory::Compute => compute_index,
                QueueCategory::Transfer => transfer_dedicated_index
                    .or(graphics_index)
                    .or(transfer_any_index)
                    .or(compute_index),
            };
            if family_index.is_none() {
                return false;
            }
        }
        has_request
    }

    pub fn request_device(
        &self,
        instance: &Instance,
        desc: DeviceDescriptor,
    ) -> Result<Device, Error> {
        let queue_families = unsafe {
            instance
                .instance()
                .get_physical_device_queue_family_properties(self.handle)
        };
        let graphics_index = queue_families
            .iter()
            .enumerate()
            .find(|(_, family)| family.queue_flags.contains(vk::QueueFlags::GRAPHICS))
            .map(|(index, _)| index as u32);
        let compute_index = queue_families
            .iter()
            .enumerate()
            .find(|(_, family)| family.queue_flags.contains(vk::QueueFlags::COMPUTE))
            .map(|(index, _)| index as u32);
        let transfer_any_index = queue_families
            .iter()
            .enumerate()
            .find(|(_, family)| family.queue_flags.contains(vk::QueueFlags::TRANSFER))
            .map(|(index, _)| index as u32);
        let transfer_dedicated_index = queue_families
            .iter()
            .enumerate()
            .find(|(_, family)| {
                family.queue_flags.contains(vk::QueueFlags::TRANSFER)
                    && !family.queue_flags.contains(vk::QueueFlags::GRAPHICS)
                    && !family.queue_flags.contains(vk::QueueFlags::COMPUTE)
            })
            .map(|(index, _)| index as u32);

        if desc.enable_swapchain {
            let extensions = unsafe {
                instance
                    .instance()
                    .enumerate_device_extension_properties(self.handle)
            }
            .map_err(vk_error)?;
            let has_swapchain = extensions.iter().any(|ext| unsafe {
                CStr::from_ptr(ext.extension_name.as_ptr()) == vk::KHR_SWAPCHAIN_NAME
            });
            if !has_swapchain {
                return Err(Error {
                    kind: ErrorKind::NoCompatibleDevice,
                    severity: Severity::Fatal,
                });
            }
        }

        let mut allocations = Vec::new();
        for request in desc.queues {
            if request.count == 0 {
                continue;
            }
            let family_index = match request.category {
                QueueCategory::Graphics => graphics_index,
                QueueCategory::Compute => compute_index,
                QueueCategory::Transfer => transfer_dedicated_index
                    .or(graphics_index)
                    .or(transfer_any_index)
                    .or(compute_index),
            };
            let family_index = match family_index {
                Some(index) => index,
                None => {
                    return Err(Error {
                        kind: ErrorKind::NoCompatibleDevice,
                        severity: Severity::Fatal,
                    });
                }
            };
            allocations.push(QueueAllocation {
                category: request.category,
                family_index,
                count: request.count,
            });
        }

        if allocations.is_empty() {
            return Err(Error {
                kind: ErrorKind::NoCompatibleDevice,
                severity: Severity::Fatal,
            });
        }

        let mut queue_priorities: BTreeMap<u32, Vec<f32>> = BTreeMap::new();
        for allocation in &allocations {
            let entry = queue_priorities
                .entry(allocation.family_index)
                .or_insert_with(Vec::new);
            entry.extend(std::iter::repeat(1.0).take(allocation.count as usize));
        }

        for (family_index, priorities) in queue_priorities.iter_mut() {
            let max_supported = queue_families[*family_index as usize].queue_count as usize;
            if priorities.len() > max_supported {
                priorities.truncate(max_supported);
            }
        }

        let mut priorities_store = Vec::new();
        let mut queue_layouts = Vec::new();
        for (family_index, priorities) in queue_priorities {
            priorities_store.push(priorities);
            let idx = priorities_store.len() - 1;
            queue_layouts.push((family_index, idx));
        }

        let queue_create_infos: Vec<vk::DeviceQueueCreateInfo> = queue_layouts
            .into_iter()
            .map(|(family_index, idx)| {
                vk::DeviceQueueCreateInfo::default()
                    .queue_family_index(family_index)
                    .queue_priorities(&priorities_store[idx])
            })
            .collect();

        let device_extensions: Vec<*const i8> = if desc.enable_swapchain {
            vec![vk::KHR_SWAPCHAIN_NAME.as_ptr()]
        } else {
            Vec::new()
        };
        let device_create_info = vk::DeviceCreateInfo::default()
            .queue_create_infos(&queue_create_infos)
            .enabled_extension_names(&device_extensions)
            .enabled_features(&desc.required_features);

        let device = unsafe {
            instance
                .instance()
                .create_device(self.handle, &device_create_info, None)
        }
        .map_err(vk_error)?;

        let properties = unsafe {
            instance
                .instance()
                .get_physical_device_properties(self.handle)
        };

        Ok(Device {
            handle: self.handle,
            device,
            properties,
            queues: allocations,
        })
    }
}

#[derive(Debug, Clone)]
pub struct QueueAllocation {
    pub category: QueueCategory,
    pub family_index: u32,
    pub count: u32,
}

pub struct Device {
    pub(crate) handle: vk::PhysicalDevice,
    pub(crate) device: ash::Device,
    properties: vk::PhysicalDeviceProperties,
    queues: Vec<QueueAllocation>,
}

impl Device {
    pub fn logical_device(&self) -> &ash::Device {
        &self.device
    }

    pub fn physical_device(&self) -> vk::PhysicalDevice {
        self.handle
    }

    pub fn properties(&self) -> &vk::PhysicalDeviceProperties {
        &self.properties
    }

    pub fn queues(&self) -> &[QueueAllocation] {
        &self.queues
    }

    pub fn get_queue(&self, family_index: u32, queue_index: u32) -> vk::Queue {
        unsafe { self.device.get_device_queue(family_index, queue_index) }
    }

    pub fn find_memory_type(
        &self,
        instance: &Instance,
        type_filter: u32,
        properties: vk::MemoryPropertyFlags,
    ) -> Result<u32, Error> {
        let memory_properties = unsafe {
            instance
                .instance()
                .get_physical_device_memory_properties(self.physical_device())
        };

        for (index, memory_type) in memory_properties.memory_types.iter().enumerate() {
            let is_allowed_by_hardware = (type_filter & (1 << index)) != 0;
            let has_required_properties = memory_type.property_flags.contains(properties);

            if is_allowed_by_hardware && has_required_properties {
                return Ok(index as u32);
            }
        }

        Err(Error::fatal(ErrorKind::NoCompatibleDevice))
    }

    pub fn pad_uniform_buffer_size(&self, original_size: usize) -> usize {
        let min_alignment = self.properties.limits.min_uniform_buffer_offset_alignment as usize;
        let mut aligned_size = original_size;

        if min_alignment > 0 {
            aligned_size = (aligned_size + min_alignment - 1) & !(min_alignment - 1);
        }

        aligned_size
    }

    pub fn destroy(&mut self) {
        unsafe {
            self.device.destroy_device(None);
        }
    }
}