use crate::check_errors;
use crate::descriptor_set::layout::DescriptorType;
use crate::device::Device;
use crate::format::Format;
use crate::image::view::ImageViewType;
use crate::pipeline::graphics::input_assembly::PrimitiveTopology;
use crate::pipeline::layout::PipelineLayoutPcRange;
use crate::shader::spirv::{Capability, Spirv, SpirvError};
use crate::sync::PipelineStages;
use crate::DeviceSize;
use crate::OomError;
use crate::Version;
use crate::VulkanObject;
use fnv::FnvHashMap;
use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::error;
use std::error::Error;
use std::ffi::CStr;
use std::ffi::CString;
use std::fmt;
use std::fmt::Display;
use std::mem;
use std::mem::MaybeUninit;
use std::ops::BitOr;
use std::ptr;
use std::sync::Arc;
pub mod reflect;
pub mod spirv;
use spirv::ExecutionModel;
include!(concat!(env!("OUT_DIR"), "/spirv_reqs.rs"));
/// A shader module created on a Vulkan device from SPIR-V code, together with the reflection
/// data of its entry points.
#[derive(Debug)]
pub struct ShaderModule {
/// Raw Vulkan handle; destroyed in the `Drop` implementation.
handle: ash::vk::ShaderModule,
/// Device the module was created on; held so it outlives the module handle.
device: Arc<Device>,
/// Entry point information, keyed first by entry point name and then by execution model,
/// because the same name may be declared for several execution models.
entry_points: HashMap<String, HashMap<ExecutionModel, EntryPointInfo>>,
}
impl ShaderModule {
    /// Builds a new shader module from SPIR-V 32-bit words.
    ///
    /// The words are parsed to extract the SPIR-V version, the enabled capabilities and
    /// extensions, and the reflection data of every entry point. All of this is forwarded to
    /// [`ShaderModule::from_words_with_data`].
    ///
    /// # Errors
    ///
    /// Returns [`ShaderCreationError::SpirvError`] if `words` cannot be parsed as SPIR-V. Also
    /// returns any error produced by [`ShaderModule::from_words_with_data`].
    ///
    /// # Safety
    ///
    /// Only the SPIR-V version, capabilities and extensions are checked against the device. The
    /// caller must ensure the code is otherwise valid SPIR-V that the device can consume.
    pub unsafe fn from_words(
        device: Arc<Device>,
        words: &[u32],
    ) -> Result<Arc<ShaderModule>, ShaderCreationError> {
        let spirv = Spirv::new(words)?;

        Self::from_words_with_data(
            device,
            words,
            spirv.version(),
            reflect::spirv_capabilities(&spirv),
            reflect::spirv_extensions(&spirv),
            reflect::entry_points(&spirv, false),
        )
    }

    /// Builds a new shader module from SPIR-V bytes.
    ///
    /// The bytes are interpreted as native-endian 32-bit words. They do NOT need to be 4-byte
    /// aligned; unaligned input is copied into an aligned buffer first.
    ///
    /// # Panics
    ///
    /// Panics if the length of `bytes` is not a multiple of 4.
    ///
    /// # Errors and Safety
    ///
    /// Same as [`ShaderModule::from_words`].
    pub unsafe fn from_bytes(
        device: Arc<Device>,
        bytes: &[u8],
    ) -> Result<Arc<ShaderModule>, ShaderCreationError> {
        let words = Self::bytes_to_words(bytes);
        Self::from_words(device, &words)
    }

    /// Builds a new shader module from SPIR-V words and reflection data the caller has already
    /// extracted.
    ///
    /// Checks run in this order: the SPIR-V version, then every capability, then every
    /// extension. Each is checked against `device`, and all checks complete before the Vulkan
    /// shader module is created.
    ///
    /// # Errors
    ///
    /// - [`ShaderCreationError::SpirvVersionNotSupported`],
    ///   [`ShaderCreationError::SpirvCapabilityNotSupported`] or
    ///   [`ShaderCreationError::SpirvExtensionNotSupported`] for the first requirement the device
    ///   does not meet.
    /// - [`ShaderCreationError::OomError`] if the Vulkan call creating the module fails.
    ///
    /// # Safety
    ///
    /// The requirements of [`ShaderModule::from_words`] apply. In addition, the caller must
    /// ensure that the version, capabilities, extensions and entry points accurately describe
    /// `words`. They are trusted as given and are never cross-checked against the code.
    pub unsafe fn from_words_with_data<'a>(
        device: Arc<Device>,
        words: &[u32],
        spirv_version: Version,
        spirv_capabilities: impl IntoIterator<Item = &'a Capability>,
        spirv_extensions: impl IntoIterator<Item = &'a str>,
        entry_points: impl IntoIterator<Item = (String, ExecutionModel, EntryPointInfo)>,
    ) -> Result<Arc<ShaderModule>, ShaderCreationError> {
        if let Err(reason) = check_spirv_version(&device, spirv_version) {
            return Err(ShaderCreationError::SpirvVersionNotSupported {
                version: spirv_version,
                reason,
            });
        }

        for capability in spirv_capabilities {
            if let Err(reason) = check_spirv_capability(&device, capability.clone()) {
                return Err(ShaderCreationError::SpirvCapabilityNotSupported {
                    capability: capability.clone(),
                    reason,
                });
            }
        }

        for extension in spirv_extensions {
            if let Err(reason) = check_spirv_extension(&device, extension) {
                return Err(ShaderCreationError::SpirvExtensionNotSupported {
                    extension: extension.to_owned(),
                    reason,
                });
            }
        }

        // Group the entry points by name, then by execution model, in a single pass.
        // This happens before the Vulkan handle exists, so a panicking caller-supplied
        // iterator cannot leak it. If the same `(name, execution model)` pair appears more
        // than once, the later entry replaces the earlier one.
        let mut entry_point_map: HashMap<String, HashMap<ExecutionModel, EntryPointInfo>> =
            HashMap::new();
        for (name, execution, info) in entry_points {
            entry_point_map
                .entry(name)
                .or_default()
                .insert(execution, info);
        }

        let handle = {
            let infos = ash::vk::ShaderModuleCreateInfo {
                flags: ash::vk::ShaderModuleCreateFlags::empty(),
                // Vulkan expects the code size in bytes, not in words.
                code_size: words.len() * mem::size_of::<u32>(),
                p_code: words.as_ptr(),
                ..Default::default()
            };

            let fns = device.fns();
            let mut output = MaybeUninit::uninit();
            check_errors(fns.v1_0.create_shader_module(
                device.internal_object(),
                &infos,
                ptr::null(),
                output.as_mut_ptr(),
            ))?;
            output.assume_init()
        };

        Ok(Arc::new(ShaderModule {
            handle,
            device,
            entry_points: entry_point_map,
        }))
    }

    /// Builds a new shader module from SPIR-V bytes and reflection data the caller has already
    /// extracted.
    ///
    /// The bytes are interpreted as native-endian 32-bit words. They do NOT need to be 4-byte
    /// aligned.
    ///
    /// # Panics
    ///
    /// Panics if the length of `bytes` is not a multiple of 4.
    ///
    /// # Errors and Safety
    ///
    /// Same as [`ShaderModule::from_words_with_data`].
    pub unsafe fn from_bytes_with_data<'a>(
        device: Arc<Device>,
        bytes: &[u8],
        spirv_version: Version,
        spirv_capabilities: impl IntoIterator<Item = &'a Capability>,
        spirv_extensions: impl IntoIterator<Item = &'a str>,
        entry_points: impl IntoIterator<Item = (String, ExecutionModel, EntryPointInfo)>,
    ) -> Result<Arc<ShaderModule>, ShaderCreationError> {
        let words = Self::bytes_to_words(bytes);
        Self::from_words_with_data(
            device,
            &words,
            spirv_version,
            spirv_capabilities,
            spirv_extensions,
            entry_points,
        )
    }

    /// Returns the entry point with the given name, if it exists and is unambiguous.
    ///
    /// Returns `None` if no entry point has this name, or if several execution models share
    /// it. In that case, use [`ShaderModule::entry_point_with_execution`] instead.
    pub fn entry_point<'a>(&'a self, name: &str) -> Option<EntryPoint<'a>> {
        let infos = self.entry_points.get(name)?;
        if infos.len() != 1 {
            return None;
        }
        let info = infos.values().next()?;
        Some(self.make_entry_point(name, info))
    }

    /// Returns the entry point with the given name and execution model, if it exists.
    pub fn entry_point_with_execution<'a>(
        &'a self,
        name: &str,
        execution: ExecutionModel,
    ) -> Option<EntryPoint<'a>> {
        let info = self.entry_points.get(name)?.get(&execution)?;
        Some(self.make_entry_point(name, info))
    }

    /// Wraps the reflection data `info` of the entry point `name` into an [`EntryPoint`] that
    /// borrows this module.
    fn make_entry_point<'a>(&'a self, name: &str, info: &'a EntryPointInfo) -> EntryPoint<'a> {
        EntryPoint {
            module: self,
            // SPIR-V literal strings are nul-terminated, so a nul byte inside `name` can only
            // come from caller-supplied reflection data.
            name: CString::new(name).expect("entry point name contains an interior nul byte"),
            info,
        }
    }

    /// Views `bytes` as native-endian 32-bit words.
    ///
    /// A `&[u8]` is only guaranteed to be 1-byte aligned (data from `include_bytes!`, for
    /// example). Reinterpreting its pointer as `*const u32` would therefore be undefined
    /// behaviour. Aligned input is borrowed in place; anything else is copied into an aligned
    /// `Vec<u32>`.
    ///
    /// # Panics
    ///
    /// Panics if the length of `bytes` is not a multiple of 4.
    fn bytes_to_words(bytes: &[u8]) -> Cow<'_, [u32]> {
        assert!((bytes.len() % 4) == 0);

        // SAFETY: every bit pattern is a valid `u32`, and `align_to` only returns a middle slice
        // that is correctly aligned for `u32`.
        let (prefix, words, suffix) = unsafe { bytes.align_to::<u32>() };
        if prefix.is_empty() && suffix.is_empty() {
            Cow::Borrowed(words)
        } else {
            Cow::Owned(
                bytes
                    .chunks_exact(4)
                    .map(|c| u32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
                    .collect(),
            )
        }
    }
}
unsafe impl VulkanObject for ShaderModule {
type Object = ash::vk::ShaderModule;
#[inline]
fn internal_object(&self) -> ash::vk::ShaderModule {
self.handle
}
}
impl Drop for ShaderModule {
    /// Destroys the Vulkan shader module using the device it was created on.
    #[inline]
    fn drop(&mut self) {
        let device = &self.device;
        // SAFETY: `self.handle` was created on `device` and is destroyed exactly once, here.
        unsafe {
            device
                .fns()
                .v1_0
                .destroy_shader_module(device.internal_object(), self.handle, ptr::null());
        }
    }
}
/// Error that can happen when creating a [`ShaderModule`].
#[derive(Clone, Debug)]
pub enum ShaderCreationError {
/// Creating the Vulkan shader module failed. The crate error returned by the creation call is
/// converted into this variant.
OomError(OomError),
/// The shader enables a SPIR-V capability that the device does not support.
SpirvCapabilityNotSupported {
/// The unsupported capability.
capability: Capability,
/// Why the device does not support it.
reason: ShaderSupportError,
},
/// The SPIR-V words could not be parsed.
SpirvError(SpirvError),
/// The shader enables a SPIR-V extension that the device does not support.
SpirvExtensionNotSupported {
/// Name of the unsupported extension.
extension: String,
/// Why the device does not support it.
reason: ShaderSupportError,
},
/// The SPIR-V version of the shader is not supported by the device.
SpirvVersionNotSupported {
/// The SPIR-V version declared by the shader.
version: Version,
/// Why the device does not support it.
reason: ShaderSupportError,
},
}
impl Error for ShaderCreationError {
    /// Exposes the underlying cause of the failure. Every variant carries one.
    #[inline]
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        let cause: &(dyn error::Error + 'static) = match self {
            Self::OomError(err) => err,
            Self::SpirvError(err) => err,
            Self::SpirvCapabilityNotSupported { reason, .. }
            | Self::SpirvExtensionNotSupported { reason, .. }
            | Self::SpirvVersionNotSupported { reason, .. } => reason,
        };
        Some(cause)
    }
}
impl Display for ShaderCreationError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::OomError(_) => f.write_str("not enough memory available"),
            Self::SpirvError(_) => f.write_str("the SPIR-V module could not be read"),
            Self::SpirvVersionNotSupported { version, .. } => write!(
                f,
                "the shader uses SPIR-V version {}.{}, which is not supported by the device",
                version.major, version.minor,
            ),
            Self::SpirvCapabilityNotSupported { capability, .. } => write!(
                f,
                "the SPIR-V capability {:?} enabled by the shader is not supported by the device",
                capability,
            ),
            Self::SpirvExtensionNotSupported { extension, .. } => write!(
                f,
                "the SPIR-V extension {} enabled by the shader is not supported by the device",
                extension,
            ),
        }
    }
}
impl From<crate::Error> for ShaderCreationError {
fn from(err: crate::Error) -> Self {
Self::OomError(err.into())
}
}
impl From<SpirvError> for ShaderCreationError {
fn from(err: SpirvError) -> Self {
Self::SpirvError(err)
}
}
/// Explains why a SPIR-V version, capability or extension is not supported by a device.
#[derive(Clone, Copy, Debug)]
pub enum ShaderSupportError {
/// Vulkan has no way to support this item at all (for example, a SPIR-V version that
/// `check_spirv_version` does not recognize).
NotSupportedByVulkan,
/// Support exists, but the device meets none of these alternatives. Any single one would be
/// enough.
RequirementsNotMet(&'static [&'static str]),
}
impl Error for ShaderSupportError {}

impl Display for ShaderSupportError {
    /// Writes the reason, listing the alternative requirements separated by commas.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let requirements = match self {
            Self::NotSupportedByVulkan => return f.write_str("not supported by Vulkan"),
            Self::RequirementsNotMet(requirements) => requirements,
        };

        f.write_str("at least one of the following must be available/enabled on the device: ")?;
        for (index, requirement) in requirements.iter().enumerate() {
            if index > 0 {
                f.write_str(", ")?;
            }
            f.write_str(requirement)?;
        }
        Ok(())
    }
}
/// Reflection data describing one entry point of a shader module.
#[derive(Clone, Debug)]
pub struct EntryPointInfo {
/// The shader stage of the entry point, with stage-specific execution information.
pub execution: ShaderExecution,
/// Requirements for each descriptor used by the entry point.
/// NOTE(review): the key is presumably `(set, binding)`; confirm against `reflect`.
pub descriptor_requirements: FnvHashMap<(u32, u32), DescriptorRequirements>,
/// The push constant range used by the entry point, if any.
pub push_constant_requirements: Option<PipelineLayoutPcRange>,
/// Requirements for each specialization constant.
/// NOTE(review): the key is presumably the specialization constant ID; confirm.
pub specialization_constant_requirements: FnvHashMap<u32, SpecializationConstantRequirements>,
/// Interface of the values consumed by the entry point.
pub input_interface: ShaderInterface,
/// Interface of the values produced by the entry point.
pub output_interface: ShaderInterface,
}
/// An entry point of a [`ShaderModule`], borrowed from the module.
#[derive(Clone, Debug)]
pub struct EntryPoint<'a> {
/// The module that declares this entry point.
module: &'a ShaderModule,
/// Entry point name as a nul-terminated C string.
name: CString,
/// Reflection data of the entry point, owned by the module.
info: &'a EntryPointInfo,
}
impl<'a> EntryPoint<'a> {
#[inline]
pub fn module(&self) -> &'a ShaderModule {
self.module
}
#[inline]
pub fn name(&self) -> &CStr {
&self.name
}
#[inline]
pub fn execution(&self) -> &ShaderExecution {
&self.info.execution
}
#[inline]
pub fn descriptor_requirements(
&self,
) -> impl ExactSizeIterator<Item = ((u32, u32), &DescriptorRequirements)> {
self.info
.descriptor_requirements
.iter()
.map(|(k, v)| (*k, v))
}
#[inline]
pub fn push_constant_requirements(&self) -> Option<&PipelineLayoutPcRange> {
self.info.push_constant_requirements.as_ref()
}
#[inline]
pub fn specialization_constant_requirements(
&self,
) -> impl ExactSizeIterator<Item = (u32, &SpecializationConstantRequirements)> {
self.info
.specialization_constant_requirements
.iter()
.map(|(k, v)| (*k, v))
}
#[inline]
pub fn input_interface(&self) -> &ShaderInterface {
&self.info.input_interface
}
#[inline]
pub fn output_interface(&self) -> &ShaderInterface {
&self.info.output_interface
}
}
/// The shader stage an entry point runs in, with any stage-specific execution information.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShaderExecution {
Vertex,
TessellationControl,
TessellationEvaluation,
/// Geometry stage; carries the input primitive type the shader declares.
Geometry(GeometryShaderExecution),
Fragment,
Compute,
}
/// Execution information specific to a geometry shader entry point.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct GeometryShaderExecution {
/// The kind of primitive the geometry shader receives as input.
pub input: GeometryShaderInput,
}
/// Input primitive type declared by a geometry shader. `is_compatible_with` checks it against a
/// primitive topology.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GeometryShaderInput {
Points,
Lines,
LinesWithAdjacency,
Triangles,
TrianglesWithAdjacency,
}
impl GeometryShaderInput {
    /// Returns whether a geometry shader declaring this input primitive can consume primitives
    /// assembled with `topology`.
    ///
    /// Strip and fan topologies are accepted alongside their list counterparts. Adjacency
    /// inputs only accept adjacency topologies.
    #[inline]
    pub fn is_compatible_with(&self, topology: PrimitiveTopology) -> bool {
        matches!(
            (*self, topology),
            (Self::Points, PrimitiveTopology::PointList)
                | (Self::Lines, PrimitiveTopology::LineList)
                | (Self::Lines, PrimitiveTopology::LineStrip)
                | (Self::LinesWithAdjacency, PrimitiveTopology::LineListWithAdjacency)
                | (Self::LinesWithAdjacency, PrimitiveTopology::LineStripWithAdjacency)
                | (Self::Triangles, PrimitiveTopology::TriangleList)
                | (Self::Triangles, PrimitiveTopology::TriangleStrip)
                | (Self::Triangles, PrimitiveTopology::TriangleFan)
                | (
                    Self::TrianglesWithAdjacency,
                    PrimitiveTopology::TriangleListWithAdjacency
                )
                | (
                    Self::TrianglesWithAdjacency,
                    PrimitiveTopology::TriangleStripWithAdjacency
                )
        )
    }
}
/// What a shader requires of a single descriptor. `intersection` merges two such descriptions.
#[derive(Clone, Debug, Default)]
pub struct DescriptorRequirements {
/// Descriptor types the shader accepts for this descriptor; `intersection` keeps only the
/// types both sides accept.
pub descriptor_types: Vec<DescriptorType>,
/// Number of descriptors required; `intersection` takes the maximum.
pub descriptor_count: u32,
/// Image format required, if the shader constrains it.
pub format: Option<Format>,
/// Image view type required, if the shader constrains it.
pub image_view_type: Option<ImageViewType>,
/// Whether the descriptor must be multisampled; both sides must agree in `intersection`.
pub multisampled: bool,
/// NOTE(review): presumably whether the shader writes through this descriptor; confirm.
/// `intersection` ORs this flag.
pub mutable: bool,
/// Shader stages that use this descriptor.
pub stages: ShaderStages,
}
impl DescriptorRequirements {
    /// Merges two sets of requirements for the same descriptor into one set satisfying both.
    ///
    /// Field by field, the result:
    /// - keeps the descriptor types common to both, in `self`'s order;
    /// - takes the larger descriptor count;
    /// - keeps whichever format or image view type is specified;
    /// - ORs the `mutable` flag;
    /// - unions the stages.
    ///
    /// # Errors
    ///
    /// Returns an error, checked in this order, if:
    /// - no descriptor type is shared;
    /// - both sides specify different formats;
    /// - both sides specify different image view types;
    /// - the multisampling requirements differ.
    pub fn intersection(&self, other: &Self) -> Result<Self, DescriptorRequirementsIncompatible> {
        let mut descriptor_types = self.descriptor_types.clone();
        descriptor_types.retain(|ty| other.descriptor_types.contains(ty));
        if descriptor_types.is_empty() {
            return Err(DescriptorRequirementsIncompatible::DescriptorType);
        }

        let format = match (self.format, other.format) {
            (Some(first), Some(second)) if first != second => {
                return Err(DescriptorRequirementsIncompatible::Format);
            }
            (first, second) => first.or(second),
        };

        let image_view_type = match (self.image_view_type, other.image_view_type) {
            (Some(first), Some(second)) if first != second => {
                return Err(DescriptorRequirementsIncompatible::ImageViewType);
            }
            (first, second) => first.or(second),
        };

        if self.multisampled != other.multisampled {
            return Err(DescriptorRequirementsIncompatible::Multisampled);
        }

        Ok(Self {
            descriptor_types,
            descriptor_count: self.descriptor_count.max(other.descriptor_count),
            format,
            image_view_type,
            multisampled: self.multisampled,
            mutable: self.mutable || other.mutable,
            stages: self.stages | other.stages,
        })
    }
}
/// Reason why `DescriptorRequirements::intersection` could not merge two requirement sets.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DescriptorRequirementsIncompatible {
/// The two sides have no descriptor type in common.
DescriptorType,
/// The two sides require different formats.
Format,
/// The two sides require different image view types.
ImageViewType,
/// The two sides disagree on whether the descriptor is multisampled.
Multisampled,
}
impl Error for DescriptorRequirementsIncompatible {}
impl Display for DescriptorRequirementsIncompatible {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
DescriptorRequirementsIncompatible::DescriptorType => {
write!(
f,
"the allowed descriptor types of the two descriptors do not overlap"
)
}
DescriptorRequirementsIncompatible::Format => {
write!(f, "the descriptors require different formats")
}
DescriptorRequirementsIncompatible::ImageViewType => {
write!(f, "the descriptors require different image view types")
}
DescriptorRequirementsIncompatible::Multisampled => {
write!(
f,
"the multisampling requirements of the descriptors differ"
)
}
}
}
}
/// What a shader requires of a single specialization constant.
#[derive(Clone, Copy, Debug)]
pub struct SpecializationConstantRequirements {
/// NOTE(review): presumably the size in bytes of the constant's value; confirm.
pub size: DeviceSize,
}
/// Types that provide values for a shader's specialization constants, described by a list of
/// map entries.
///
/// # Safety
///
/// NOTE(review): presumably every entry returned by `descriptors` must describe bytes that
/// lie entirely within a value of `Self` (`offset + size <= size_of::<Self>()`). The data is
/// likely handed to Vulkan as raw bytes. Confirm against the pipeline creation code.
pub unsafe trait SpecializationConstants {
/// Returns the map entries describing where each specialization constant lives in `Self`.
fn descriptors() -> &'static [SpecializationMapEntry];
}
unsafe impl SpecializationConstants for () {
    /// The unit type provides no specialization constants, so its map is empty.
    #[inline]
    fn descriptors() -> &'static [SpecializationMapEntry] {
        const NO_ENTRIES: &[SpecializationMapEntry] = &[];
        NO_ENTRIES
    }
}
/// Describes where one specialization constant is stored in the specialization data.
/// Its fields mirror `ash::vk::SpecializationMapEntry` one to one (see the `From` impl).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(C)]
pub struct SpecializationMapEntry {
/// ID of the specialization constant in the shader.
pub constant_id: u32,
/// Byte offset of the constant's value in the specialization data.
pub offset: u32,
/// Byte size of the constant's value; `usize` to match the Vulkan `size` field.
pub size: usize,
}
impl From<SpecializationMapEntry> for ash::vk::SpecializationMapEntry {
#[inline]
fn from(val: SpecializationMapEntry) -> Self {
Self {
constant_id: val.constant_id,
offset: val.offset,
size: val.size,
}
}
}
/// The input or output interface of a shader entry point, as a list of elements.
#[derive(Clone, Debug)]
pub struct ShaderInterface {
/// The interface elements, in the order they were provided.
elements: Vec<ShaderInterfaceEntry>,
}
impl ShaderInterface {
    /// Builds an interface from `elements` without validating them.
    ///
    /// # Safety
    ///
    /// NOTE(review): the elements are stored as-is. Presumably the caller must ensure they are
    /// consistent (for example, that their locations do not overlap); confirm.
    #[inline]
    pub unsafe fn new_unchecked(elements: Vec<ShaderInterfaceEntry>) -> ShaderInterface {
        Self { elements }
    }

    /// Returns an interface with no elements.
    pub const fn empty() -> ShaderInterface {
        Self {
            elements: Vec::new(),
        }
    }

    /// Returns the elements of the interface.
    #[inline]
    pub fn elements(&self) -> &[ShaderInterfaceEntry] {
        &self.elements
    }

    /// Checks that `other` describes the same interface as `self`.
    ///
    /// Both interfaces must have the same number of elements. Then, for every location covered
    /// by an element of `self`, the first element of `other` covering that location must have
    /// the same type.
    ///
    /// # Errors
    ///
    /// Returns the first mismatch found: element count, a location missing from `other`, or a
    /// differing type at a location.
    pub fn matches(&self, other: &ShaderInterface) -> Result<(), ShaderInterfaceMismatchError> {
        let ours = self.elements();
        let theirs = other.elements();

        if ours.len() != theirs.len() {
            return Err(ShaderInterfaceMismatchError::ElementsCountMismatch {
                self_elements: ours.len() as u32,
                other_elements: theirs.len() as u32,
            });
        }

        for a in ours {
            let end = a.location + a.ty.num_locations();
            for loc in a.location..end {
                let covering = theirs
                    .iter()
                    .find(|e| e.location <= loc && loc < e.location + e.ty.num_locations());
                let b = covering
                    .ok_or(ShaderInterfaceMismatchError::MissingElement { location: loc })?;

                if b.ty != a.ty {
                    return Err(ShaderInterfaceMismatchError::TypeMismatch {
                        location: loc,
                        self_ty: a.ty,
                        other_ty: b.ty,
                    });
                }
            }
        }

        Ok(())
    }
}
/// One element of a shader interface.
#[derive(Debug, Clone)]
pub struct ShaderInterfaceEntry {
/// First location occupied by the element; it spans `ty.num_locations()` locations.
pub location: u32,
/// Component index within the first location.
/// NOTE(review): not considered by `ShaderInterface::matches`.
pub component: u32,
/// Name of the element, if known.
pub name: Option<Cow<'static, str>>,
/// Type of the element.
pub ty: ShaderInterfaceEntryType,
}
/// Type of a shader interface element: a scalar or vector, possibly in an array.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ShaderInterfaceEntryType {
/// Scalar type of each component.
pub base_type: ShaderScalarType,
/// Number of vector components; `to_format` only handles 1 to 4.
pub num_components: u32,
/// Number of array elements; `num_locations` returns this value for 32-bit types.
pub num_elements: u32,
/// Whether the components are 64 bits wide (rejected by `to_format`/`num_locations`).
pub is_64bit: bool,
}
impl ShaderInterfaceEntryType {
    /// Returns the 32-bit format matching this scalar type and component count.
    ///
    /// # Panics
    ///
    /// Panics if the type is 64-bit, or if the component count is not between 1 and 4.
    pub(crate) fn to_format(&self) -> Format {
        assert!(!self.is_64bit);
        match (self.base_type, self.num_components) {
            (ShaderScalarType::Float, 1) => Format::R32_SFLOAT,
            (ShaderScalarType::Float, 2) => Format::R32G32_SFLOAT,
            (ShaderScalarType::Float, 3) => Format::R32G32B32_SFLOAT,
            (ShaderScalarType::Float, 4) => Format::R32G32B32A32_SFLOAT,
            (ShaderScalarType::Sint, 1) => Format::R32_SINT,
            (ShaderScalarType::Sint, 2) => Format::R32G32_SINT,
            (ShaderScalarType::Sint, 3) => Format::R32G32B32_SINT,
            (ShaderScalarType::Sint, 4) => Format::R32G32B32A32_SINT,
            (ShaderScalarType::Uint, 1) => Format::R32_UINT,
            (ShaderScalarType::Uint, 2) => Format::R32G32_UINT,
            (ShaderScalarType::Uint, 3) => Format::R32G32B32_UINT,
            (ShaderScalarType::Uint, 4) => Format::R32G32B32A32_UINT,
            _ => unreachable!(),
        }
    }

    /// Returns how many consecutive locations the element occupies: one per array element.
    ///
    /// # Panics
    ///
    /// Panics if the type is 64-bit. NOTE(review): 64-bit 3- and 4-component vectors would
    /// need two locations each, which is not modelled here.
    pub(crate) fn num_locations(&self) -> u32 {
        assert!(!self.is_64bit);
        let per_element = 1;
        self.num_elements * per_element
    }
}
/// Scalar type of the components of a shader interface element.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ShaderScalarType {
/// Floating point (mapped to `*_SFLOAT` formats by `to_format`).
Float,
/// Signed integer (mapped to `*_SINT` formats).
Sint,
/// Unsigned integer (mapped to `*_UINT` formats).
Uint,
}
/// Reason why `ShaderInterface::matches` rejected a pair of interfaces.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ShaderInterfaceMismatchError {
/// The interfaces have a different number of elements.
ElementsCountMismatch {
/// Element count of `self`.
self_elements: u32,
/// Element count of `other`.
other_elements: u32,
},
/// No element of `other` covers a location used by `self`.
MissingElement {
/// The uncovered location.
location: u32,
},
/// The elements covering a location have different types.
TypeMismatch {
/// The location where the types differ.
location: u32,
/// Type on the `self` side.
self_ty: ShaderInterfaceEntryType,
/// Type on the `other` side.
other_ty: ShaderInterfaceEntryType,
},
}
impl error::Error for ShaderInterfaceMismatchError {}
impl fmt::Display for ShaderInterfaceMismatchError {
#[inline]
fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
write!(
fmt,
"{}",
match *self {
ShaderInterfaceMismatchError::ElementsCountMismatch { .. } => {
"the number of elements mismatches"
}
ShaderInterfaceMismatchError::MissingElement { .. } => "an element is missing",
ShaderInterfaceMismatchError::TypeMismatch { .. } => {
"the type of an element does not match"
}
}
)
}
}
/// A single shader stage.
///
/// Each discriminant is the raw value of the matching `VkShaderStageFlags` bit. The
/// conversion to `ash::vk::ShaderStageFlags` relies on this through `val as u32`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum ShaderStage {
Vertex = ash::vk::ShaderStageFlags::VERTEX.as_raw(),
TessellationControl = ash::vk::ShaderStageFlags::TESSELLATION_CONTROL.as_raw(),
TessellationEvaluation = ash::vk::ShaderStageFlags::TESSELLATION_EVALUATION.as_raw(),
Geometry = ash::vk::ShaderStageFlags::GEOMETRY.as_raw(),
Fragment = ash::vk::ShaderStageFlags::FRAGMENT.as_raw(),
Compute = ash::vk::ShaderStageFlags::COMPUTE.as_raw(),
Raygen = ash::vk::ShaderStageFlags::RAYGEN_KHR.as_raw(),
AnyHit = ash::vk::ShaderStageFlags::ANY_HIT_KHR.as_raw(),
ClosestHit = ash::vk::ShaderStageFlags::CLOSEST_HIT_KHR.as_raw(),
Miss = ash::vk::ShaderStageFlags::MISS_KHR.as_raw(),
Intersection = ash::vk::ShaderStageFlags::INTERSECTION_KHR.as_raw(),
Callable = ash::vk::ShaderStageFlags::CALLABLE_KHR.as_raw(),
}
impl From<ShaderExecution> for ShaderStage {
#[inline]
fn from(val: ShaderExecution) -> Self {
match val {
ShaderExecution::Vertex => Self::Vertex,
ShaderExecution::TessellationControl => Self::TessellationControl,
ShaderExecution::TessellationEvaluation => Self::TessellationEvaluation,
ShaderExecution::Geometry(_) => Self::Geometry,
ShaderExecution::Fragment => Self::Fragment,
ShaderExecution::Compute => Self::Compute,
}
}
}
impl From<ShaderStage> for ShaderStages {
    /// Returns a set containing only `val`.
    #[inline]
    fn from(val: ShaderStage) -> Self {
        let mut stages = Self::none();
        let flag = match val {
            ShaderStage::Vertex => &mut stages.vertex,
            ShaderStage::TessellationControl => &mut stages.tessellation_control,
            ShaderStage::TessellationEvaluation => &mut stages.tessellation_evaluation,
            ShaderStage::Geometry => &mut stages.geometry,
            ShaderStage::Fragment => &mut stages.fragment,
            ShaderStage::Compute => &mut stages.compute,
            ShaderStage::Raygen => &mut stages.raygen,
            ShaderStage::AnyHit => &mut stages.any_hit,
            ShaderStage::ClosestHit => &mut stages.closest_hit,
            ShaderStage::Miss => &mut stages.miss,
            ShaderStage::Intersection => &mut stages.intersection,
            ShaderStage::Callable => &mut stages.callable,
        };
        *flag = true;
        stages
    }
}
impl From<ShaderStage> for ash::vk::ShaderStageFlags {
    /// Converts using the discriminant, which equals the Vulkan flag bit.
    #[inline]
    fn from(val: ShaderStage) -> Self {
        let raw_bit = val as u32;
        ash::vk::ShaderStageFlags::from_raw(raw_bit)
    }
}
/// A set of shader stages, stored as one flag per stage. `Default` is the empty set.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ShaderStages {
pub vertex: bool,
pub tessellation_control: bool,
pub tessellation_evaluation: bool,
pub geometry: bool,
pub fragment: bool,
pub compute: bool,
pub raygen: bool,
pub any_hit: bool,
pub closest_hit: bool,
pub miss: bool,
pub intersection: bool,
pub callable: bool,
}
impl ShaderStages {
    /// Returns the set containing every shader stage.
    #[inline]
    pub const fn all() -> ShaderStages {
        ShaderStages {
            vertex: true,
            tessellation_control: true,
            tessellation_evaluation: true,
            geometry: true,
            fragment: true,
            compute: true,
            raygen: true,
            any_hit: true,
            closest_hit: true,
            miss: true,
            intersection: true,
            callable: true,
        }
    }

    /// Returns the empty set.
    #[inline]
    pub const fn none() -> ShaderStages {
        ShaderStages {
            vertex: false,
            tessellation_control: false,
            tessellation_evaluation: false,
            geometry: false,
            fragment: false,
            compute: false,
            raygen: false,
            any_hit: false,
            closest_hit: false,
            miss: false,
            intersection: false,
            callable: false,
        }
    }

    /// Returns the set of graphics stages: vertex, both tessellation stages, geometry and
    /// fragment.
    #[inline]
    pub const fn all_graphics() -> ShaderStages {
        ShaderStages {
            vertex: true,
            tessellation_control: true,
            tessellation_evaluation: true,
            geometry: true,
            fragment: true,
            compute: false,
            raygen: false,
            any_hit: false,
            closest_hit: false,
            miss: false,
            intersection: false,
            callable: false,
        }
    }

    /// Returns the set containing only the compute stage.
    #[inline]
    pub const fn compute() -> ShaderStages {
        ShaderStages {
            vertex: false,
            tessellation_control: false,
            tessellation_evaluation: false,
            geometry: false,
            fragment: false,
            compute: true,
            raygen: false,
            any_hit: false,
            closest_hit: false,
            miss: false,
            intersection: false,
            callable: false,
        }
    }

    /// Returns whether every stage of `other` is also in `self`.
    #[inline]
    pub const fn is_superset_of(&self, other: &ShaderStages) -> bool {
        // A stage wanted by `other` must also be present in `self`.
        let ShaderStages {
            vertex,
            tessellation_control,
            tessellation_evaluation,
            geometry,
            fragment,
            compute,
            raygen,
            any_hit,
            closest_hit,
            miss,
            intersection,
            callable,
        } = *other;
        (!vertex || self.vertex)
            && (!tessellation_control || self.tessellation_control)
            && (!tessellation_evaluation || self.tessellation_evaluation)
            && (!geometry || self.geometry)
            && (!fragment || self.fragment)
            && (!compute || self.compute)
            && (!raygen || self.raygen)
            && (!any_hit || self.any_hit)
            && (!closest_hit || self.closest_hit)
            && (!miss || self.miss)
            && (!intersection || self.intersection)
            && (!callable || self.callable)
    }

    /// Returns whether `self` and `other` share at least one stage.
    #[inline]
    pub const fn intersects(&self, other: &ShaderStages) -> bool {
        let ShaderStages {
            vertex,
            tessellation_control,
            tessellation_evaluation,
            geometry,
            fragment,
            compute,
            raygen,
            any_hit,
            closest_hit,
            miss,
            intersection,
            callable,
        } = *other;
        (self.vertex && vertex)
            || (self.tessellation_control && tessellation_control)
            || (self.tessellation_evaluation && tessellation_evaluation)
            || (self.geometry && geometry)
            || (self.fragment && fragment)
            || (self.compute && compute)
            || (self.raygen && raygen)
            || (self.any_hit && any_hit)
            || (self.closest_hit && closest_hit)
            || (self.miss && miss)
            || (self.intersection && intersection)
            || (self.callable && callable)
    }

    /// Returns the set of stages present in `self`, `other`, or both.
    #[inline]
    pub const fn union(&self, other: &Self) -> Self {
        let a = *self;
        let b = *other;
        Self {
            vertex: a.vertex || b.vertex,
            tessellation_control: a.tessellation_control || b.tessellation_control,
            tessellation_evaluation: a.tessellation_evaluation || b.tessellation_evaluation,
            geometry: a.geometry || b.geometry,
            fragment: a.fragment || b.fragment,
            compute: a.compute || b.compute,
            raygen: a.raygen || b.raygen,
            any_hit: a.any_hit || b.any_hit,
            closest_hit: a.closest_hit || b.closest_hit,
            miss: a.miss || b.miss,
            intersection: a.intersection || b.intersection,
            callable: a.callable || b.callable,
        }
    }
}
impl From<ShaderStages> for ash::vk::ShaderStageFlags {
    /// Sets the Vulkan flag bit of every stage present in `val`.
    #[inline]
    fn from(val: ShaderStages) -> Self {
        type Flags = ash::vk::ShaderStageFlags;

        // Exhaustive destructuring: adding a stage field forces this table to be updated.
        let ShaderStages {
            vertex,
            tessellation_control,
            tessellation_evaluation,
            geometry,
            fragment,
            compute,
            raygen,
            any_hit,
            closest_hit,
            miss,
            intersection,
            callable,
        } = val;

        let table = [
            (vertex, Flags::VERTEX),
            (tessellation_control, Flags::TESSELLATION_CONTROL),
            (tessellation_evaluation, Flags::TESSELLATION_EVALUATION),
            (geometry, Flags::GEOMETRY),
            (fragment, Flags::FRAGMENT),
            (compute, Flags::COMPUTE),
            (raygen, Flags::RAYGEN_KHR),
            (any_hit, Flags::ANY_HIT_KHR),
            (closest_hit, Flags::CLOSEST_HIT_KHR),
            (miss, Flags::MISS_KHR),
            (intersection, Flags::INTERSECTION_KHR),
            (callable, Flags::CALLABLE_KHR),
        ];

        table
            .iter()
            .filter(|(enabled, _)| *enabled)
            .fold(Flags::empty(), |acc, (_, bit)| acc | *bit)
    }
}
impl From<ash::vk::ShaderStageFlags> for ShaderStages {
    /// Marks each stage whose Vulkan flag bit is set in `val`; unknown bits are ignored.
    #[inline]
    fn from(val: ash::vk::ShaderStageFlags) -> Self {
        let has = |bit: ash::vk::ShaderStageFlags| val.intersects(bit);
        Self {
            vertex: has(ash::vk::ShaderStageFlags::VERTEX),
            tessellation_control: has(ash::vk::ShaderStageFlags::TESSELLATION_CONTROL),
            tessellation_evaluation: has(ash::vk::ShaderStageFlags::TESSELLATION_EVALUATION),
            geometry: has(ash::vk::ShaderStageFlags::GEOMETRY),
            fragment: has(ash::vk::ShaderStageFlags::FRAGMENT),
            compute: has(ash::vk::ShaderStageFlags::COMPUTE),
            raygen: has(ash::vk::ShaderStageFlags::RAYGEN_KHR),
            any_hit: has(ash::vk::ShaderStageFlags::ANY_HIT_KHR),
            closest_hit: has(ash::vk::ShaderStageFlags::CLOSEST_HIT_KHR),
            miss: has(ash::vk::ShaderStageFlags::MISS_KHR),
            intersection: has(ash::vk::ShaderStageFlags::INTERSECTION_KHR),
            callable: has(ash::vk::ShaderStageFlags::CALLABLE_KHR),
        }
    }
}
impl BitOr for ShaderStages {
type Output = ShaderStages;
#[inline]
fn bitor(self, other: ShaderStages) -> ShaderStages {
ShaderStages {
vertex: self.vertex || other.vertex,
tessellation_control: self.tessellation_control || other.tessellation_control,
tessellation_evaluation: self.tessellation_evaluation || other.tessellation_evaluation,
geometry: self.geometry || other.geometry,
fragment: self.fragment || other.fragment,
compute: self.compute || other.compute,
raygen: self.raygen || other.raygen,
any_hit: self.any_hit || other.any_hit,
closest_hit: self.closest_hit || other.closest_hit,
miss: self.miss || other.miss,
intersection: self.intersection || other.intersection,
callable: self.callable || other.callable,
}
}
}
impl From<ShaderStages> for PipelineStages {
    /// Maps each shader stage to the pipeline stage that runs it. All ray tracing shader
    /// stages collapse into the single ray tracing pipeline stage; every other pipeline stage
    /// is left unset.
    #[inline]
    fn from(stages: ShaderStages) -> PipelineStages {
        // Exhaustive destructuring: adding a field to `ShaderStages` forces this to be revisited.
        let ShaderStages {
            vertex: vertex_shader,
            tessellation_control: tessellation_control_shader,
            tessellation_evaluation: tessellation_evaluation_shader,
            geometry: geometry_shader,
            fragment: fragment_shader,
            compute: compute_shader,
            raygen,
            any_hit,
            closest_hit,
            miss,
            intersection,
            callable,
        } = stages;

        let ray_tracing_shader =
            raygen || any_hit || closest_hit || miss || intersection || callable;

        PipelineStages {
            vertex_shader,
            tessellation_control_shader,
            tessellation_evaluation_shader,
            geometry_shader,
            fragment_shader,
            compute_shader,
            ray_tracing_shader,
            ..PipelineStages::none()
        }
    }
}
/// Checks whether `device` can consume SPIR-V modules of the given `version`.
///
/// The patch component is ignored. Versions 1.1 to 1.3 need Vulkan 1.1. Version 1.4 needs
/// Vulkan 1.2 or the `khr_spirv_1_4` extension. Version 1.5 needs Vulkan 1.2. Any version not
/// listed here, other than 1.0, is reported as unsupported by Vulkan.
fn check_spirv_version(device: &Device, mut version: Version) -> Result<(), ShaderSupportError> {
    const NEEDS_VK_1_1: &[&str] = &["Vulkan API version 1.1"];
    const NEEDS_VK_1_2_OR_EXT: &[&str] = &["Vulkan API version 1.2", "extension `khr_spirv_1_4`"];
    const NEEDS_VK_1_2: &[&str] = &["Vulkan API version 1.2"];

    // Only major.minor matter for compatibility.
    version.patch = 0;

    let (satisfied, requirements) = match version {
        Version::V1_0 => return Ok(()),
        Version::V1_1 | Version::V1_2 | Version::V1_3 => {
            (device.api_version() >= Version::V1_1, NEEDS_VK_1_1)
        }
        Version::V1_4 => (
            device.api_version() >= Version::V1_2 || device.enabled_extensions().khr_spirv_1_4,
            NEEDS_VK_1_2_OR_EXT,
        ),
        Version::V1_5 => (device.api_version() >= Version::V1_2, NEEDS_VK_1_2),
        _ => return Err(ShaderSupportError::NotSupportedByVulkan),
    };

    if satisfied {
        Ok(())
    } else {
        Err(ShaderSupportError::RequirementsNotMet(requirements))
    }
}