use crate::{feature_info::with_feature_info, nvsdk_ngx::*};
use ash::{Entry, vk::PhysicalDevice};
use std::{ffi::CStr, ptr, slice};
use uuid::Uuid;
use wgpu::{
Adapter, Device, DeviceDescriptor, Instance, InstanceDescriptor, Limits, Queue,
RequestDeviceError,
hal::{
DeviceError, InstanceError,
api::Vulkan,
vulkan::{CreateDeviceCallbackArgs, CreateInstanceCallbackArgs},
},
};
/// Creates a Vulkan-backed [`Instance`] whose creation callback registers the DLSS
/// instance extensions.
///
/// The relevant fields of `instance_descriptor` are copied into a `wgpu::hal` instance
/// descriptor, and [`register_instance_extensions`] is run from the hal creation callback.
/// `feature_support` is updated by that callback for features whose extensions are
/// unavailable.
///
/// # Errors
///
/// Returns the hal instance-creation error first, if any; otherwise returns the error
/// recorded by the extension-registration callback, if any.
pub fn create_instance(
    project_id: Uuid,
    instance_descriptor: &InstanceDescriptor,
    feature_support: &mut FeatureSupport,
) -> Result<Instance, InitializationError> {
    let hal_descriptor = wgpu::hal::InstanceDescriptor {
        name: "wgpu",
        flags: instance_descriptor.flags,
        memory_budget_thresholds: instance_descriptor.memory_budget_thresholds,
        backend_options: instance_descriptor.backend_options.clone(),
        telemetry: None,
        display: None,
    };

    // The callback cannot return a value, so its outcome is stashed here and
    // inspected once instance creation has finished.
    let mut registration = Ok(());

    // SAFETY: NOTE(review) - relies on the wgpu-hal contract of `init_with_callback`;
    // the callback only appends extension names and updates `feature_support`.
    let raw_instance = unsafe {
        wgpu::hal::vulkan::Instance::init_with_callback(
            &hal_descriptor,
            Some(Box::new(|mut callback_args| {
                registration = register_instance_extensions(
                    project_id,
                    &mut callback_args,
                    feature_support,
                );
            })),
        )
    }?;
    registration?;

    // SAFETY: NOTE(review) - `raw_instance` was just created by the Vulkan hal backend,
    // matching the `Vulkan` api parameter; confirm any further `from_hal` preconditions.
    let instance = unsafe { Instance::from_hal::<Vulkan>(raw_instance) };
    Ok(instance)
}
/// Appends the Vulkan instance extensions needed by the NGX SuperSampling and
/// RayReconstruction features to `args.extensions`.
///
/// For each feature, in that order:
/// - if every required extension is supported, the extensions are appended;
/// - if some extension is unsupported, the matching flag in `feature_support` is set to
///   `false` and nothing is appended;
/// - if the query fails, the error is recorded and processing continues with the next
///   feature.
///
/// # Errors
///
/// Returns the error of the last feature query that failed, if any.
pub fn register_instance_extensions(
    project_id: Uuid,
    args: &mut CreateInstanceCallbackArgs,
    feature_support: &mut FeatureSupport,
) -> Result<(), RegisterInstanceExtensionsError> {
    // Each NGX feature paired with the support flag it controls.
    let features = [
        (
            NVSDK_NGX_Feature_NVSDK_NGX_Feature_SuperSampling,
            &mut feature_support.super_resolution_supported,
        ),
        (
            NVSDK_NGX_Feature_NVSDK_NGX_Feature_RayReconstruction,
            &mut feature_support.ray_reconstruction_supported,
        ),
    ];

    let mut outcome = Ok(());
    for (feature_id, supported_flag) in features {
        let (extension_names, all_available) =
            match required_instance_extensions(project_id, feature_id, args.entry) {
                Ok(requirements) => requirements,
                Err(error) => {
                    // Keep going so the remaining feature is still processed.
                    outcome = Err(error);
                    continue;
                }
            };

        if all_available {
            args.extensions.extend(extension_names);
        } else {
            *supported_flag = false;
        }
    }
    outcome
}
/// Opens a Vulkan device on `adapter` with the DLSS device extensions registered and
/// wraps it into a wgpu [`Device`] and [`Queue`].
///
/// # Parameters
///
/// - `project_id`: forwarded to the NGX extension queries.
/// - `adapter`: must use the Vulkan backend.
/// - `device_descriptor`: supplies the required features and memory hints for the hal
///   device, and is passed on to `create_device_from_hal`.
/// - `feature_support`: updated by [`register_device_extensions`] for features whose
///   device extensions are unavailable.
/// - `limits`: limits used when opening the hal device; when `None`, `adapter.limits()`
///   is used instead.
///
/// # Errors
///
/// Returns [`InitializationError::UnsupportedBackend`] if `adapter` is not a Vulkan
/// adapter, the hal error if opening the device fails, the error recorded by the
/// extension-registration callback, or the error from `create_device_from_hal`.
pub fn request_device(
    project_id: Uuid,
    adapter: &Adapter,
    device_descriptor: &DeviceDescriptor,
    feature_support: &mut FeatureSupport,
    limits: Option<Limits>,
) -> Result<(Device, Queue), InitializationError> {
    unsafe {
        let raw_adapter = adapter
            .as_hal::<Vulkan>()
            .ok_or(InitializationError::UnsupportedBackend)?;
        // The callback cannot return a value, so its outcome is stored here and checked
        // after the device has been opened.
        let mut result = Ok(());
        let open_device = raw_adapter.open_with_callback(
            device_descriptor.required_features,
            // Only query the adapter's limits when the caller did not provide any.
            &limits.unwrap_or_else(|| adapter.limits()),
            &device_descriptor.memory_hints,
            Some(Box::new(|mut args| {
                result = register_device_extensions(
                    project_id,
                    &mut args,
                    &raw_adapter,
                    feature_support,
                );
            })),
        )?;
        result?;
        Ok(adapter.create_device_from_hal::<Vulkan>(open_device, device_descriptor)?)
    }
}
/// Appends the Vulkan device extensions needed by the NGX SuperSampling and
/// RayReconstruction features to `args.extensions`.
///
/// For each feature, in that order:
/// - if `raw_adapter` supports every required extension, the extensions are appended;
/// - if some extension is unsupported, the matching flag in `feature_support` is set to
///   `false` and nothing is appended;
/// - if the query fails, the error is recorded and processing continues with the next
///   feature.
///
/// # Errors
///
/// Returns the error of the last feature query that failed, if any.
pub fn register_device_extensions(
    project_id: Uuid,
    args: &mut CreateDeviceCallbackArgs,
    raw_adapter: &wgpu::hal::vulkan::Adapter,
    feature_support: &mut FeatureSupport,
) -> Result<(), RegisterInstanceExtensionsError> {
    let instance_handle = raw_adapter.shared_instance().raw_instance().handle();
    let physical_device = raw_adapter.raw_physical_device();

    // Each NGX feature paired with the support flag it controls.
    let features = [
        (
            NVSDK_NGX_Feature_NVSDK_NGX_Feature_SuperSampling,
            &mut feature_support.super_resolution_supported,
        ),
        (
            NVSDK_NGX_Feature_NVSDK_NGX_Feature_RayReconstruction,
            &mut feature_support.ray_reconstruction_supported,
        ),
    ];

    let mut outcome = Ok(());
    for (feature_id, supported_flag) in features {
        let query = required_device_extensions(
            project_id,
            feature_id,
            raw_adapter,
            instance_handle,
            physical_device,
        );
        let (extension_names, all_available) = match query {
            Ok(requirements) => requirements,
            Err(error) => {
                // Keep going so the remaining feature is still processed.
                outcome = Err(error);
                continue;
            }
        };

        if all_available {
            args.extensions.extend(extension_names);
        } else {
            *supported_flag = false;
        }
    }
    outcome
}
fn required_instance_extensions(
project_id: Uuid,
feature_id: NVSDK_NGX_Feature,
entry: &Entry,
) -> Result<(impl Iterator<Item = &'static CStr>, bool), RegisterInstanceExtensionsError> {
with_feature_info(project_id, feature_id, |feature_info| unsafe {
let mut required_extensions = ptr::null_mut();
let mut required_extension_count = 0;
check_ngx_result(NVSDK_NGX_VULKAN_GetFeatureInstanceExtensionRequirements(
feature_info,
&mut required_extension_count,
&mut required_extensions,
))?;
let required_extensions =
slice::from_raw_parts(required_extensions, required_extension_count as usize);
let required_extensions = required_extensions
.iter()
.map(|extension| CStr::from_ptr(extension.extension_name.as_ptr()));
let supported_extensions = entry.enumerate_instance_extension_properties(None)?;
let extensions_supported = required_extensions.clone().all(|required_extension| {
supported_extensions
.iter()
.any(|extension| extension.extension_name_as_c_str() == Ok(required_extension))
});
Ok((required_extensions, extensions_supported))
})
}
fn required_device_extensions(
project_id: Uuid,
feature_id: NVSDK_NGX_Feature,
raw_adapter: &wgpu::hal::vulkan::Adapter,
raw_instance: ash::vk::Instance,
raw_physical_device: PhysicalDevice,
) -> Result<(impl Iterator<Item = &'static CStr>, bool), RegisterInstanceExtensionsError> {
with_feature_info(project_id, feature_id, |feature_info| unsafe {
let mut required_extensions = ptr::null_mut();
let mut required_extension_count = 0;
check_ngx_result(NVSDK_NGX_VULKAN_GetFeatureDeviceExtensionRequirements(
raw_instance,
raw_physical_device,
feature_info,
&mut required_extension_count,
&mut required_extensions,
))?;
let required_extensions =
slice::from_raw_parts(required_extensions, required_extension_count as usize);
let required_extensions = required_extensions
.iter()
.map(|extension| CStr::from_ptr(extension.extension_name.as_ptr()));
let extensions_supported = required_extensions.clone().all(|required_extension| {
raw_adapter
.physical_device_capabilities()
.supports_extension(required_extension)
});
Ok((required_extensions, extensions_supported))
})
}
/// Records which DLSS features remain usable after extension registration.
///
/// Both flags start as `true` via [`Default`]; `register_instance_extensions` and
/// `register_device_extensions` set a flag to `false` when not all Vulkan extensions
/// required by the corresponding NGX feature are supported.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FeatureSupport {
    /// `false` once the extensions required by the NGX SuperSampling feature were found
    /// to be unsupported.
    pub super_resolution_supported: bool,
    /// `false` once the extensions required by the NGX RayReconstruction feature were
    /// found to be unsupported.
    pub ray_reconstruction_supported: bool,
}
impl Default for FeatureSupport {
fn default() -> Self {
Self {
super_resolution_supported: true,
ray_reconstruction_supported: true,
}
}
}
/// Errors returned by `create_instance` and `request_device`.
///
/// All wrapping variants are `transparent`: they forward display and source to the
/// wrapped error, and `#[from]` lets `?` convert the wrapped error automatically.
#[derive(thiserror::Error, Debug)]
pub enum InitializationError {
/// Wraps a `wgpu::hal::InstanceError`.
#[error(transparent)]
InstanceError(#[from] InstanceError),
/// Wraps a `wgpu::RequestDeviceError`.
#[error(transparent)]
RequestDeviceError(#[from] RequestDeviceError),
/// Wraps a `wgpu::hal::DeviceError`.
#[error(transparent)]
DeviceError(#[from] DeviceError),
/// Wraps a raw Vulkan result code (`ash::vk::Result`).
#[error(transparent)]
VulkanError(#[from] ash::vk::Result),
/// Wraps a `DlssError`.
#[error(transparent)]
DlssError(#[from] DlssError),
/// Returned by `request_device` when the adapter cannot be accessed as a Vulkan
/// hal adapter.
#[error("Provided adapter is not using the Vulkan backend")]
UnsupportedBackend,
}
/// Errors returned while querying and registering the extensions required by NGX.
///
/// Despite the name, this type is used by both `register_instance_extensions` and
/// `register_device_extensions`.
#[derive(thiserror::Error, Debug)]
pub enum RegisterInstanceExtensionsError {
/// Wraps a raw Vulkan result code (`ash::vk::Result`).
#[error(transparent)]
VulkanError(#[from] ash::vk::Result),
/// Wraps a `DlssError`.
#[error(transparent)]
DlssError(#[from] DlssError),
}
impl From<RegisterInstanceExtensionsError> for InitializationError {
fn from(value: RegisterInstanceExtensionsError) -> Self {
match value {
RegisterInstanceExtensionsError::VulkanError(err) => {
InitializationError::VulkanError(err)
}
RegisterInstanceExtensionsError::DlssError(dlss_error) => {
InitializationError::DlssError(dlss_error)
}
}
}
}