use crate::core::Instance;
use ash::vk;
use std::ffi::{CStr, CString};
use thiserror::Error;
#[derive(Debug, Error)]
pub enum DeviceError {
#[error("Failed to create device: {0}")]
CreationFailed(vk::Result),
#[error("No suitable physical device found")]
NoSuitableDevice,
#[error("Extension not supported: {0}")]
ExtensionNotSupported(String),
#[error("Memory type not found")]
MemoryTypeNotFound,
#[error("Queue family not found")]
QueueFamilyNotFound,
#[error("Failed to convert string: {0}")]
StringConversion(#[from] std::ffi::NulError),
}
/// Snapshot of a physical device's properties, queried once so callers can
/// inspect or rank devices without issuing further Vulkan calls.
#[derive(Debug, Clone)]
pub struct PhysicalDeviceInfo {
    /// Handle of the physical device the data below was queried from.
    pub device: vk::PhysicalDevice,
    /// Core properties (device name, device type, ...).
    pub properties: vk::PhysicalDeviceProperties,
    /// Core features supported by the device.
    pub features: vk::PhysicalDeviceFeatures,
    /// Properties of every queue family; the position in the Vec is the queue
    /// family index.
    pub queue_families: Vec<vk::QueueFamilyProperties>,
    /// Memory heaps and memory types exposed by the device.
    pub memory_properties: vk::PhysicalDeviceMemoryProperties,
}

impl PhysicalDeviceInfo {
    /// Queries and caches all properties of `physical_device`.
    ///
    /// # Safety
    ///
    /// `physical_device` must be a valid handle enumerated from `instance`, and
    /// `instance` must still be alive.
    pub unsafe fn from_device(instance: &Instance, physical_device: vk::PhysicalDevice) -> Self {
        // SAFETY: guaranteed by the caller per this function's `# Safety` contract.
        unsafe {
            Self {
                device: physical_device,
                properties: instance.get_physical_device_properties(physical_device),
                features: instance.get_physical_device_features(physical_device),
                queue_families: instance
                    .get_physical_device_queue_family_properties(physical_device),
                memory_properties: instance.get_physical_device_memory_properties(physical_device),
            }
        }
    }

    /// Returns the device name reported by the driver.
    ///
    /// Invalid UTF-8 is replaced with U+FFFD. The fixed-size name array is
    /// parsed without `unsafe`: reading stops at the first NUL. If the array has
    /// no terminator, its full contents are used instead of reading past its end.
    pub fn device_name(&self) -> String {
        // `c_char` is `i8` or `u8` depending on the target; reinterpret each
        // element as a raw byte.
        let bytes: Vec<u8> = self
            .properties
            .device_name
            .iter()
            .map(|&c| c as u8)
            .collect();
        match CStr::from_bytes_until_nul(&bytes) {
            Ok(name) => name.to_string_lossy().into_owned(),
            // Missing terminator: use the whole array rather than overrunning it
            // as `CStr::from_ptr` would.
            Err(_) => String::from_utf8_lossy(&bytes).into_owned(),
        }
    }

    /// Returns `true` if the device type is `DISCRETE_GPU`.
    #[inline]
    pub fn is_discrete_gpu(&self) -> bool {
        self.properties.device_type == vk::PhysicalDeviceType::DISCRETE_GPU
    }

    /// Returns `true` if at least one queue family supports all of `flags`.
    pub fn supports_queue_family(&self, flags: vk::QueueFlags) -> bool {
        self.queue_families
            .iter()
            .any(|qf| qf.queue_flags.contains(flags))
    }
}
/// Parameters consumed by `Device::new`.
///
/// `Default` gives empty extension and queue lists, plus
/// `vk::PhysicalDeviceFeatures::default()` for the features.
#[derive(Default, Debug, Clone)]
pub struct DeviceCreateInfo {
    /// Device extension names to enable; converted to C strings at creation.
    pub extensions: Vec<String>,
    /// Core device features to enable.
    pub features: vk::PhysicalDeviceFeatures,
    /// Queue requests; Vulkan expects each entry to name a distinct queue family.
    pub queue_create_infos: Vec<QueueCreateInfo>,
}
/// Request for queues from a single queue family.
///
/// Vulkan reads exactly `queue_count` priorities from `queue_priorities`, so
/// the Vec must hold at least that many values. [`QueueCreateInfo::new`]
/// derives `queue_count` from the priorities, so the two cannot disagree.
#[derive(Debug, Clone)]
pub struct QueueCreateInfo {
    /// Index of the queue family to create queues from.
    pub queue_family_index: u32,
    /// Number of queues to create in that family.
    pub queue_count: u32,
    /// One priority per queue; the Vulkan spec requires values in `0.0..=1.0`.
    pub queue_priorities: Vec<f32>,
}

impl QueueCreateInfo {
    /// Requests one queue per entry of `queue_priorities` from family
    /// `queue_family_index`.
    ///
    /// # Panics
    ///
    /// Panics if `queue_priorities` has more than `u32::MAX` entries.
    pub fn new(queue_family_index: u32, queue_priorities: Vec<f32>) -> Self {
        let queue_count = u32::try_from(queue_priorities.len())
            .expect("queue priority count exceeds u32::MAX");
        Self {
            queue_family_index,
            queue_count,
            queue_priorities,
        }
    }
}
pub struct Device {
logical_device: ash::Device,
physical_device: vk::PhysicalDevice,
memory_properties: vk::PhysicalDeviceMemoryProperties,
queue_families: Vec<vk::QueueFamilyProperties>,
}
impl Device {
pub fn new(
instance: &Instance,
physical_device: vk::PhysicalDevice,
create_info: DeviceCreateInfo,
) -> Result<Self, DeviceError> {
let memory_properties =
unsafe { instance.get_physical_device_memory_properties(physical_device) };
let queue_families =
unsafe { instance.get_physical_device_queue_family_properties(physical_device) };
let extensions: Vec<CString> = create_info
.extensions
.iter()
.map(|s| CString::new(s.as_str()))
.collect::<Result<Vec<_>, _>>()?;
let extension_ptrs: Vec<*const i8> = extensions.iter().map(|s| s.as_ptr()).collect();
let queue_create_infos: Vec<vk::DeviceQueueCreateInfo> = create_info
.queue_create_infos
.iter()
.map(|qci| vk::DeviceQueueCreateInfo {
queue_family_index: qci.queue_family_index,
queue_count: qci.queue_count,
p_queue_priorities: qci.queue_priorities.as_ptr(),
..Default::default()
})
.collect();
let device_create_info = vk::DeviceCreateInfo {
queue_create_info_count: queue_create_infos.len() as u32,
p_queue_create_infos: queue_create_infos.as_ptr(),
enabled_extension_count: extension_ptrs.len() as u32,
pp_enabled_extension_names: extension_ptrs.as_ptr(),
p_enabled_features: &create_info.features,
..Default::default()
};
let logical_device = unsafe {
instance
.handle()
.create_device(physical_device, &device_create_info, None)
.map_err(DeviceError::CreationFailed)?
};
Ok(Self {
logical_device,
physical_device,
memory_properties,
queue_families,
})
}
#[inline]
pub fn handle(&self) -> &ash::Device {
&self.logical_device
}
#[inline]
pub fn physical_device(&self) -> vk::PhysicalDevice {
self.physical_device
}
#[inline]
pub fn memory_properties(&self) -> &vk::PhysicalDeviceMemoryProperties {
&self.memory_properties
}
#[inline]
pub fn queue_families(&self) -> &[vk::QueueFamilyProperties] {
&self.queue_families
}
pub fn find_memory_type(
&self,
type_filter: u32,
properties: vk::MemoryPropertyFlags,
) -> Option<u32> {
for i in 0..self.memory_properties.memory_type_count {
let memory_type = &self.memory_properties.memory_types[i as usize];
if (type_filter & (1 << i)) != 0 && memory_type.property_flags.contains(properties) {
return Some(i);
}
}
None
}
pub fn find_queue_family(&self, flags: vk::QueueFlags) -> Option<u32> {
self.queue_families
.iter()
.enumerate()
.find(|(_, qf)| qf.queue_flags.contains(flags))
.map(|(i, _)| i as u32)
}
pub fn wait_idle(&self) -> Result<(), DeviceError> {
unsafe {
self.logical_device
.device_wait_idle()
.map_err(DeviceError::CreationFailed)
}
}
}
impl Drop for Device {
fn drop(&mut self) {
unsafe {
self.logical_device.destroy_device(None);
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::InstanceCreateInfo;

    // These tests use the real Vulkan loader and driver. They need at least one
    // physical device with a graphics-capable queue family.

    /// Creates an instance with validation layers disabled.
    fn create_test_instance() -> Instance {
        Instance::new(InstanceCreateInfo {
            enable_validation: false,
            ..Default::default()
        })
        .expect("failed to create Vulkan instance")
    }

    /// Index of the first queue family whose flags include `GRAPHICS`.
    fn first_graphics_family(families: &[vk::QueueFamilyProperties]) -> Option<u32> {
        families
            .iter()
            .position(|qf| qf.queue_flags.contains(vk::QueueFlags::GRAPHICS))
            .map(|i| i as u32)
    }

    /// Device parameters requesting one queue, priority 1.0, from `family`.
    fn single_queue_info(family: u32) -> DeviceCreateInfo {
        DeviceCreateInfo {
            extensions: Vec::new(),
            queue_create_infos: vec![QueueCreateInfo {
                queue_family_index: family,
                queue_count: 1,
                queue_priorities: vec![1.0],
            }],
            ..Default::default()
        }
    }

    #[test]
    fn test_device_creation() {
        let instance = create_test_instance();
        let physical_devices = instance
            .enumerate_physical_devices()
            .expect("failed to enumerate physical devices");
        let physical_device = *physical_devices
            .first()
            .expect("no Vulkan physical device available");
        // SAFETY: `physical_device` was enumerated from `instance`, which is alive.
        let device_info = unsafe { PhysicalDeviceInfo::from_device(&instance, physical_device) };
        println!("Device: {}", device_info.device_name());
        println!("Discrete GPU: {}", device_info.is_discrete_gpu());
        let graphics_family = first_graphics_family(&device_info.queue_families)
            .expect("physical device has no graphics queue family");
        // `expect` surfaces the DeviceError on failure. The binding is declared
        // after `instance`, so the device is dropped (destroyed) first.
        let _device = Device::new(&instance, physical_device, single_queue_info(graphics_family))
            .expect("logical device creation failed");
    }

    #[test]
    fn test_find_memory_type() {
        let instance = create_test_instance();
        let physical_devices = instance
            .enumerate_physical_devices()
            .expect("failed to enumerate physical devices");
        let physical_device = *physical_devices
            .first()
            .expect("no Vulkan physical device available");
        // SAFETY: `physical_device` was enumerated from `instance`, which is alive.
        let queue_families =
            unsafe { instance.get_physical_device_queue_family_properties(physical_device) };
        let graphics_family = first_graphics_family(&queue_families)
            .expect("physical device has no graphics queue family");
        let device = Device::new(&instance, physical_device, single_queue_info(graphics_family))
            .expect("logical device creation failed");
        // The device caches the same queue-family list, so its lookup must agree.
        assert_eq!(
            device.find_queue_family(vk::QueueFlags::GRAPHICS),
            Some(graphics_family)
        );
        // The Vulkan spec guarantees at least one memory type that is both
        // HOST_VISIBLE and HOST_COHERENT.
        let memory_type = device.find_memory_type(
            u32::MAX,
            vk::MemoryPropertyFlags::HOST_VISIBLE | vk::MemoryPropertyFlags::HOST_COHERENT,
        );
        assert!(memory_type.is_some());
    }
}