use ash::vk;
use std::ffi::{CStr, CString};
use thiserror::Error;
#[cfg(debug_assertions)]
use ash::ext::debug_utils;
/// Errors reported by the Vulkan [`Instance`] wrapper in this file.
#[derive(Debug, Error)]
pub enum InstanceError {
/// A Vulkan call returned an error code. Despite the message text, this file
/// uses the variant for layer enumeration, instance creation, debug-messenger
/// creation and physical-device enumeration failures alike.
#[error("Failed to create Vulkan instance: {0}")]
CreationFailed(vk::Result),
/// A requested instance extension is not supported; carries the extension name.
/// NOTE(review): never constructed in the visible part of this file.
#[error("Extension not supported: {0}")]
ExtensionNotSupported(String),
/// A requested instance layer is not supported; carries the layer name.
/// NOTE(review): never constructed in the visible part of this file.
#[error("Layer not supported: {0}")]
LayerNotSupported(String),
/// Validation was requested but `VK_LAYER_KHRONOS_validation` is not among
/// the layers reported by the Vulkan loader.
#[error("Validation layers not available")]
ValidationLayersNotAvailable,
/// `ash::Entry::load` failed; the underlying loading error is discarded.
#[error("Failed to load Vulkan entry point")]
LoadingFailed,
/// A string could not be converted to a C string because it contains an
/// interior NUL byte.
#[error("Failed to convert string: {0}")]
StringConversion(#[from] std::ffi::NulError),
}
/// Parameters consumed by [`Instance::new`].
#[derive(Debug, Clone)]
pub struct InstanceCreateInfo {
/// Application name passed to Vulkan; must not contain interior NUL bytes.
pub app_name: String,
/// Application version passed to Vulkan as `application_version`
/// (the default is built with `vk::make_api_version`).
pub app_version: u32,
/// Requests the Khronos validation layer and a debug messenger.
/// Only honoured in builds with `debug_assertions`; otherwise it is ignored.
pub enable_validation: bool,
/// Names of additional instance extensions to enable; each must not contain
/// interior NUL bytes.
pub extensions: Vec<String>,
}
impl Default for InstanceCreateInfo {
    /// Returns the stock configuration: the "shdrlib Application" name,
    /// version 1.0.0, no extra extensions, and validation switched on exactly
    /// when the crate is compiled with `debug_assertions`.
    fn default() -> Self {
        let app_name = String::from("shdrlib Application");
        let app_version = vk::make_api_version(0, 1, 0, 0);
        // `cfg!` evaluates to a compile-time boolean, so debug builds get
        // `true` and release builds get `false`.
        let enable_validation = cfg!(debug_assertions);

        Self {
            app_name,
            app_version,
            enable_validation,
            extensions: Vec::new(),
        }
    }
}
/// Owning wrapper around a Vulkan instance; the instance (and the debug
/// messenger, if one was created) is destroyed in the `Drop` implementation.
pub struct Instance {
// Instance-level function table and handle created by `Instance::new`.
instance: ash::Instance,
// Loader entry point obtained from `ash::Entry::load`; stored alongside the instance.
entry: ash::Entry,
// Debug-utils extension loader; `Some` only when validation was enabled.
#[cfg(debug_assertions)]
debug_utils: Option<debug_utils::Instance>,
// Messenger handle created through `debug_utils`; `Some` only when validation was enabled.
#[cfg(debug_assertions)]
debug_messenger: Option<vk::DebugUtilsMessengerEXT>,
}
impl Instance {
    /// Loads the Vulkan library and creates an instance requesting API version 1.3.
    ///
    /// In builds with `debug_assertions`, `create_info.enable_validation`
    /// additionally enables the `VK_LAYER_KHRONOS_validation` layer and the
    /// `VK_EXT_debug_utils` extension, and installs a messenger that forwards
    /// ERROR, WARNING and INFO messages to `debug_callback`. In builds without
    /// `debug_assertions` the flag is ignored.
    ///
    /// # Errors
    ///
    /// * [`InstanceError::LoadingFailed`] if the Vulkan library cannot be loaded.
    /// * [`InstanceError::StringConversion`] if the application name or an
    ///   extension name contains an interior NUL byte.
    /// * [`InstanceError::ValidationLayersNotAvailable`] if validation was
    ///   requested but the Khronos validation layer is not installed.
    /// * [`InstanceError::CreationFailed`] if layer enumeration, instance
    ///   creation or debug-messenger creation returns a Vulkan error.
    pub fn new(create_info: InstanceCreateInfo) -> Result<Self, InstanceError> {
        // SAFETY: loading the system Vulkan library; the returned entry is stored
        // in `Self` next to everything created from it.
        let entry = unsafe { ash::Entry::load() }.map_err(|_| InstanceError::LoadingFailed)?;

        let app_name = CString::new(create_info.app_name.as_str())?;
        let engine_name = CString::new("shdrlib")?;
        let app_info = vk::ApplicationInfo {
            p_application_name: app_name.as_ptr(),
            application_version: create_info.app_version,
            p_engine_name: engine_name.as_ptr(),
            engine_version: vk::make_api_version(0, 0, 1, 0),
            api_version: vk::API_VERSION_1_3,
            ..Default::default()
        };

        // Only mutated when validation support is compiled in.
        #[cfg_attr(not(debug_assertions), allow(unused_mut))]
        let mut extensions: Vec<CString> = create_info
            .extensions
            .iter()
            .map(|s| CString::new(s.as_str()))
            .collect::<Result<Vec<_>, _>>()?;

        // Declared at function scope on purpose: `layers` stores a raw pointer
        // into this string, so it must stay alive until `create_instance` returns.
        #[cfg(debug_assertions)]
        let validation_layer_name = CString::new("VK_LAYER_KHRONOS_validation")?;

        // `c_char` is `i8` on some targets and `u8` on others, so the pointer
        // type must not hard-code the signedness.
        #[cfg_attr(not(debug_assertions), allow(unused_mut))]
        let mut layers: Vec<*const std::os::raw::c_char> = Vec::new();

        #[cfg(debug_assertions)]
        if create_info.enable_validation {
            // SAFETY: `entry` holds valid loader function pointers.
            let available_layers = unsafe { entry.enumerate_instance_layer_properties() }
                .map_err(InstanceError::CreationFailed)?;

            let validation_available = available_layers.iter().any(|layer| {
                // SAFETY: Vulkan reports `layer_name` as a NUL-terminated string
                // stored inside the fixed-size array.
                let name = unsafe { CStr::from_ptr(layer.layer_name.as_ptr()) };
                name == validation_layer_name.as_c_str()
            });
            if !validation_available {
                return Err(InstanceError::ValidationLayersNotAvailable);
            }
            layers.push(validation_layer_name.as_ptr());

            // Skip the push when the caller already requested the extension so
            // the name is not handed to Vulkan twice.
            let debug_utils_name = CString::new("VK_EXT_debug_utils")?;
            if !extensions.contains(&debug_utils_name) {
                extensions.push(debug_utils_name);
            }
        }

        // Built only after `extensions` is final: the pointers borrow from the
        // `CString`s, which live until the end of this function.
        let extension_ptrs: Vec<*const std::os::raw::c_char> =
            extensions.iter().map(|s| s.as_ptr()).collect();

        let create_info_vk = vk::InstanceCreateInfo {
            p_application_info: &app_info,
            enabled_extension_count: extension_ptrs.len() as u32,
            pp_enabled_extension_names: extension_ptrs.as_ptr(),
            enabled_layer_count: layers.len() as u32,
            pp_enabled_layer_names: layers.as_ptr(),
            ..Default::default()
        };

        // SAFETY: every pointer reachable from `create_info_vk` refers to a local
        // that is still alive for the duration of this call.
        let instance = unsafe { entry.create_instance(&create_info_vk, None) }
            .map_err(InstanceError::CreationFailed)?;

        #[cfg(debug_assertions)]
        let (debug_utils, debug_messenger) = if create_info.enable_validation {
            let loader = debug_utils::Instance::new(&entry, &instance);
            let messenger_info = vk::DebugUtilsMessengerCreateInfoEXT {
                message_severity: vk::DebugUtilsMessageSeverityFlagsEXT::ERROR
                    | vk::DebugUtilsMessageSeverityFlagsEXT::WARNING
                    | vk::DebugUtilsMessageSeverityFlagsEXT::INFO,
                message_type: vk::DebugUtilsMessageTypeFlagsEXT::GENERAL
                    | vk::DebugUtilsMessageTypeFlagsEXT::VALIDATION
                    | vk::DebugUtilsMessageTypeFlagsEXT::PERFORMANCE,
                pfn_user_callback: Some(debug_callback),
                ..Default::default()
            };

            // SAFETY: `instance` was created with VK_EXT_debug_utils enabled above.
            match unsafe { loader.create_debug_utils_messenger(&messenger_info, None) } {
                Ok(messenger) => (Some(loader), Some(messenger)),
                Err(err) => {
                    // `Self` does not exist yet, so `Drop` will not run; destroy
                    // the instance here or it would leak.
                    // SAFETY: no child objects of `instance` have been created.
                    unsafe { instance.destroy_instance(None) };
                    return Err(InstanceError::CreationFailed(err));
                }
            }
        } else {
            (None, None)
        };

        Ok(Self {
            instance,
            entry,
            #[cfg(debug_assertions)]
            debug_utils,
            #[cfg(debug_assertions)]
            debug_messenger,
        })
    }

    /// Returns the wrapped `ash::Instance`.
    #[inline]
    pub fn handle(&self) -> &ash::Instance {
        &self.instance
    }

    /// Returns the `ash::Entry` the instance was created from.
    #[inline]
    pub fn entry(&self) -> &ash::Entry {
        &self.entry
    }

    /// Lists the physical devices visible to this instance.
    ///
    /// # Errors
    ///
    /// Returns [`InstanceError::CreationFailed`] carrying the Vulkan result code
    /// if enumeration fails (the variant is reused here for enumeration errors).
    pub fn enumerate_physical_devices(&self) -> Result<Vec<vk::PhysicalDevice>, InstanceError> {
        // SAFETY: `self.instance` stays valid until `Drop` runs.
        unsafe { self.instance.enumerate_physical_devices() }
            .map_err(InstanceError::CreationFailed)
    }

    /// Queries the properties of `physical_device`.
    ///
    /// # Safety
    ///
    /// `physical_device` must be a valid handle obtained from this instance.
    #[inline]
    pub unsafe fn get_physical_device_properties(
        &self,
        physical_device: vk::PhysicalDevice,
    ) -> vk::PhysicalDeviceProperties {
        unsafe {
            self.instance
                .get_physical_device_properties(physical_device)
        }
    }

    /// Queries the features of `physical_device`.
    ///
    /// # Safety
    ///
    /// `physical_device` must be a valid handle obtained from this instance.
    #[inline]
    pub unsafe fn get_physical_device_features(
        &self,
        physical_device: vk::PhysicalDevice,
    ) -> vk::PhysicalDeviceFeatures {
        unsafe {
            self.instance
                .get_physical_device_features(physical_device)
        }
    }

    /// Queries the memory properties of `physical_device`.
    ///
    /// # Safety
    ///
    /// `physical_device` must be a valid handle obtained from this instance.
    #[inline]
    pub unsafe fn get_physical_device_memory_properties(
        &self,
        physical_device: vk::PhysicalDevice,
    ) -> vk::PhysicalDeviceMemoryProperties {
        unsafe {
            self.instance
                .get_physical_device_memory_properties(physical_device)
        }
    }

    /// Queries the queue family properties of `physical_device`.
    ///
    /// # Safety
    ///
    /// `physical_device` must be a valid handle obtained from this instance.
    #[inline]
    pub unsafe fn get_physical_device_queue_family_properties(
        &self,
        physical_device: vk::PhysicalDevice,
    ) -> Vec<vk::QueueFamilyProperties> {
        unsafe {
            self.instance
                .get_physical_device_queue_family_properties(physical_device)
        }
    }
}
impl Drop for Instance {
    /// Tears down the debug messenger first (when both the loader and the
    /// messenger handle are present), then destroys the instance itself.
    fn drop(&mut self) {
        #[cfg(debug_assertions)]
        {
            let messenger_parts = self.debug_utils.as_ref().zip(self.debug_messenger);
            if let Some((loader, messenger)) = messenger_parts {
                // The messenger must be destroyed before the instance it belongs to.
                unsafe { loader.destroy_debug_utils_messenger(messenger, None) };
            }
        }

        unsafe { self.instance.destroy_instance(None) };
    }
}
/// Debug-utils messenger callback: prints one line per message to stderr in
/// the form `[SEVERITY] [TYPE] message`.
///
/// `message_type` is a bitmask and may carry several bits at once (for example
/// `VALIDATION | PERFORMANCE`); all recognised bits are printed joined by `|`.
/// A null `p_callback_data` or null message pointer is printed as an empty
/// message. Always returns `vk::FALSE`.
///
/// # Safety
///
/// Intended to be invoked by the Vulkan implementation only. If non-null,
/// `p_callback_data` must point to a valid callback-data structure whose
/// `p_message`, if non-null, is a NUL-terminated string.
#[cfg(debug_assertions)]
unsafe extern "system" fn debug_callback(
    message_severity: vk::DebugUtilsMessageSeverityFlagsEXT,
    message_type: vk::DebugUtilsMessageTypeFlagsEXT,
    p_callback_data: *const vk::DebugUtilsMessengerCallbackDataEXT,
    _user_data: *mut std::os::raw::c_void,
) -> vk::Bool32 {
    use std::io::Write as _;

    // Borrow the callback data instead of copying it, and tolerate null pointers.
    // SAFETY: per the contract above, a non-null pointer refers to valid data.
    let message = match unsafe { p_callback_data.as_ref() } {
        Some(data) if !data.p_message.is_null() => {
            // SAFETY: `p_message` is non-null and NUL-terminated per the contract.
            unsafe { CStr::from_ptr(data.p_message) }.to_string_lossy()
        }
        _ => std::borrow::Cow::Borrowed(""),
    };

    let severity = match message_severity {
        vk::DebugUtilsMessageSeverityFlagsEXT::VERBOSE => "[VERBOSE]",
        vk::DebugUtilsMessageSeverityFlagsEXT::INFO => "[INFO]",
        vk::DebugUtilsMessageSeverityFlagsEXT::WARNING => "[WARNING]",
        vk::DebugUtilsMessageSeverityFlagsEXT::ERROR => "[ERROR]",
        _ => "[UNKNOWN]",
    };

    // Test each bit separately: an exact-equality match would label any
    // combined mask as UNKNOWN.
    let type_names: Vec<&str> = [
        (vk::DebugUtilsMessageTypeFlagsEXT::GENERAL, "GENERAL"),
        (vk::DebugUtilsMessageTypeFlagsEXT::VALIDATION, "VALIDATION"),
        (vk::DebugUtilsMessageTypeFlagsEXT::PERFORMANCE, "PERFORMANCE"),
    ]
    .iter()
    .filter(|(flag, _)| message_type.contains(*flag))
    .map(|(_, name)| *name)
    .collect();
    let msg_type = if type_names.is_empty() {
        "UNKNOWN".to_string()
    } else {
        type_names.join("|")
    };

    // `eprintln!` panics when stderr cannot be written, and a panic must not
    // escape an `extern "system"` function; ignore the write error instead.
    let _ = writeln!(std::io::stderr(), "{} [{}] {}", severity, msg_type, message);

    vk::FALSE
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default creation parameters with validation switched off, so the tests
    /// do not depend on the validation layer being installed.
    fn without_validation() -> InstanceCreateInfo {
        InstanceCreateInfo {
            enable_validation: false,
            ..Default::default()
        }
    }

    /// Creating an instance without validation succeeds.
    #[test]
    fn test_instance_creation() {
        let created = Instance::new(without_validation());
        assert!(created.is_ok());
    }

    /// Enumeration succeeds and reports at least one physical device.
    #[test]
    fn test_enumerate_physical_devices() {
        let instance = Instance::new(without_validation()).unwrap();
        let found = instance.enumerate_physical_devices();
        assert!(found.is_ok());
        assert!(!found.unwrap().is_empty());
    }
}