// vulkano 0.27.0
//
// Safe wrapper for the Vulkan graphics API
// Documentation
// Copyright (c) 2016 The vulkano developers
// Licensed under the Apache License, Version 2.0
// <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT
// license <LICENSE-MIT or https://opensource.org/licenses/MIT>,
// at your option. All files in the project carrying such
// notice may not be copied, modified, or distributed except
// according to those terms.

//! A program that is run on the device.
//!
//! In Vulkan, shaders are grouped in *shader modules*. Each shader module is built from SPIR-V
//! code and can contain one or more entry points. Note that for the moment the official
//! GLSL-to-SPIR-V compiler does not support multiple entry points.
//!
//! The vulkano library can parse and introspect SPIR-V code, but it does not fully validate the
//! code. You are encouraged to use the `vulkano-shaders` crate that will generate Rust code that
//! wraps around vulkano's shaders API.

use crate::check_errors;
use crate::descriptor_set::layout::DescriptorType;
use crate::device::Device;
use crate::format::Format;
use crate::image::view::ImageViewType;
use crate::pipeline::graphics::input_assembly::PrimitiveTopology;
use crate::pipeline::layout::PipelineLayoutPcRange;
use crate::shader::spirv::{Capability, Spirv, SpirvError};
use crate::sync::PipelineStages;
use crate::DeviceSize;
use crate::OomError;
use crate::Version;
use crate::VulkanObject;
use fnv::FnvHashMap;
use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::error;
use std::error::Error;
use std::ffi::CStr;
use std::ffi::CString;
use std::fmt;
use std::fmt::Display;
use std::mem;
use std::mem::MaybeUninit;
use std::ops::BitOr;
use std::ptr;
use std::sync::Arc;

pub mod reflect;
pub mod spirv;

use spirv::ExecutionModel;

// Generated by build.rs
include!(concat!(env!("OUT_DIR"), "/spirv_reqs.rs"));

/// Contains SPIR-V code with one or more entry points.
///
/// Owns a Vulkan shader module handle, which is destroyed when this object is dropped.
#[derive(Debug)]
pub struct ShaderModule {
    // Raw Vulkan handle, created with `create_shader_module` in `from_words_with_data`.
    handle: ash::vk::ShaderModule,
    // The device the module was created on; holding the `Arc` keeps it alive as long as the
    // module exists.
    device: Arc<Device>,
    // Entry point information, keyed first by entry point name and then by execution model,
    // because one name can be shared by entry points with different execution models.
    entry_points: HashMap<String, HashMap<ExecutionModel, EntryPointInfo>>,
}

impl ShaderModule {
    /// Builds a new shader module from SPIR-V 32-bit words. The shader code is parsed and the
    /// necessary information is extracted from it.
    ///
    /// # Errors
    ///
    /// - `SpirvError` if the code cannot be parsed.
    /// - `SpirvVersionNotSupported`, `SpirvCapabilityNotSupported` or
    ///   `SpirvExtensionNotSupported` if the code uses something `device` does not support.
    /// - `OomError` if creating the Vulkan shader module fails.
    ///
    /// # Safety
    ///
    /// - The SPIR-V code is not validated beyond the minimum needed to extract the information.
    pub unsafe fn from_words(
        device: Arc<Device>,
        words: &[u32],
    ) -> Result<Arc<ShaderModule>, ShaderCreationError> {
        let spirv = Spirv::new(words)?;

        Self::from_words_with_data(
            device,
            words,
            spirv.version(),
            reflect::spirv_capabilities(&spirv),
            reflect::spirv_extensions(&spirv),
            reflect::entry_points(&spirv, false),
        )
    }

    /// As `from_words`, but takes a slice of bytes.
    ///
    /// The bytes are interpreted as native-endian 32-bit words. `bytes` does not need to be
    /// 4-byte aligned; if it is not, the words are first copied into an aligned buffer.
    ///
    /// # Panics
    ///
    /// - Panics if the length of `bytes` is not a multiple of 4.
    ///
    /// # Safety
    ///
    /// - Same requirements as `from_words`.
    pub unsafe fn from_bytes(
        device: Arc<Device>,
        bytes: &[u8],
    ) -> Result<Arc<ShaderModule>, ShaderCreationError> {
        let words = Self::bytes_to_words(bytes);
        Self::from_words(device, &words)
    }

    /// As `from_words`, but does not parse the code. Instead, you must provide the needed
    /// information yourself. This can be useful if you've already done parsing yourself and
    /// want to prevent Vulkano from doing it a second time.
    ///
    /// If `entry_points` contains the same (name, execution model) pair more than once, the
    /// last occurrence is kept.
    ///
    /// # Errors
    ///
    /// - `SpirvVersionNotSupported`, `SpirvCapabilityNotSupported` or
    ///   `SpirvExtensionNotSupported` if the provided data names something `device` does not
    ///   support.
    /// - `OomError` if creating the Vulkan shader module fails.
    ///
    /// # Safety
    ///
    /// - The SPIR-V code is not validated at all.
    /// - The provided information must match what the SPIR-V code contains.
    pub unsafe fn from_words_with_data<'a>(
        device: Arc<Device>,
        words: &[u32],
        spirv_version: Version,
        spirv_capabilities: impl IntoIterator<Item = &'a Capability>,
        spirv_extensions: impl IntoIterator<Item = &'a str>,
        entry_points: impl IntoIterator<Item = (String, ExecutionModel, EntryPointInfo)>,
    ) -> Result<Arc<ShaderModule>, ShaderCreationError> {
        if let Err(reason) = check_spirv_version(&device, spirv_version) {
            return Err(ShaderCreationError::SpirvVersionNotSupported {
                version: spirv_version,
                reason,
            });
        }

        for capability in spirv_capabilities {
            if let Err(reason) = check_spirv_capability(&device, capability.clone()) {
                return Err(ShaderCreationError::SpirvCapabilityNotSupported {
                    capability: capability.clone(),
                    reason,
                });
            }
        }

        for extension in spirv_extensions {
            if let Err(reason) = check_spirv_extension(&device, extension) {
                return Err(ShaderCreationError::SpirvExtensionNotSupported {
                    extension: extension.to_owned(),
                    reason,
                });
            }
        }

        let handle = {
            let infos = ash::vk::ShaderModuleCreateInfo {
                flags: ash::vk::ShaderModuleCreateFlags::empty(),
                // `code_size` is expressed in bytes, not in words.
                code_size: words.len() * mem::size_of::<u32>(),
                p_code: words.as_ptr(),
                ..Default::default()
            };

            let fns = device.fns();
            let mut output = MaybeUninit::uninit();
            check_errors(fns.v1_0.create_shader_module(
                device.internal_object(),
                &infos,
                ptr::null(),
                output.as_mut_ptr(),
            ))?;
            // The call succeeded, so the handle has been written into `output`.
            output.assume_init()
        };

        // Group the entry points by name, then by execution model, in a single pass. Each
        // `EntryPointInfo` is moved into place rather than cloned; a later duplicate of the
        // same (name, model) pair replaces the earlier one.
        let mut grouped: HashMap<String, HashMap<ExecutionModel, EntryPointInfo>> =
            HashMap::new();
        for (name, execution_model, info) in entry_points {
            grouped
                .entry(name)
                .or_default()
                .insert(execution_model, info);
        }

        Ok(Arc::new(ShaderModule {
            handle,
            device,
            entry_points: grouped,
        }))
    }

    /// As `from_words_with_data`, but takes a slice of bytes.
    ///
    /// The bytes are interpreted as native-endian 32-bit words. `bytes` does not need to be
    /// 4-byte aligned; if it is not, the words are first copied into an aligned buffer.
    ///
    /// # Panics
    ///
    /// - Panics if the length of `bytes` is not a multiple of 4.
    ///
    /// # Safety
    ///
    /// - Same requirements as `from_words_with_data`.
    pub unsafe fn from_bytes_with_data<'a>(
        device: Arc<Device>,
        bytes: &[u8],
        spirv_version: Version,
        spirv_capabilities: impl IntoIterator<Item = &'a Capability>,
        spirv_extensions: impl IntoIterator<Item = &'a str>,
        entry_points: impl IntoIterator<Item = (String, ExecutionModel, EntryPointInfo)>,
    ) -> Result<Arc<ShaderModule>, ShaderCreationError> {
        let words = Self::bytes_to_words(bytes);
        Self::from_words_with_data(
            device,
            &words,
            spirv_version,
            spirv_capabilities,
            spirv_extensions,
            entry_points,
        )
    }

    /// Reinterprets `bytes` as a slice of native-endian `u32` words.
    ///
    /// Borrows `bytes` directly when it is suitably aligned for `u32`, and otherwise copies the
    /// data into a newly allocated, aligned buffer. Casting an unaligned `*const u8` to
    /// `*const u32` and building a slice from it would be undefined behavior.
    ///
    /// # Panics
    ///
    /// - Panics if the length of `bytes` is not a multiple of 4.
    fn bytes_to_words(bytes: &[u8]) -> Cow<'_, [u32]> {
        assert!((bytes.len() % 4) == 0);

        // SAFETY: every bit pattern is a valid `u32`, so viewing the aligned middle part of the
        // byte slice as `u32`s is sound.
        let (prefix, words, suffix) = unsafe { bytes.align_to::<u32>() };

        if prefix.is_empty() && suffix.is_empty() {
            Cow::Borrowed(words)
        } else {
            Cow::Owned(
                bytes
                    .chunks_exact(mem::size_of::<u32>())
                    .map(|chunk| u32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
                    .collect::<Vec<u32>>(),
            )
        }
    }

    /// Returns information about the entry point with the provided name.
    ///
    /// Returns `None` in any of these cases:
    /// - no entry point with that name exists in the shader module;
    /// - multiple entry points with the same name exist;
    /// - `name` contains an interior NUL byte, because such a name cannot be passed to Vulkan.
    pub fn entry_point<'a>(&'a self, name: &str) -> Option<EntryPoint<'a>> {
        let infos = self.entry_points.get(name)?;

        // One name used by several execution models is ambiguous without a model.
        if infos.len() != 1 {
            return None;
        }

        let info = infos.values().next()?;
        Some(EntryPoint {
            module: self,
            name: CString::new(name).ok()?,
            info,
        })
    }

    /// Returns information about the entry point with the provided name and execution model.
    ///
    /// Returns `None` if no entry point with that name and execution model exists in the
    /// shader module, or if `name` contains an interior NUL byte.
    pub fn entry_point_with_execution<'a>(
        &'a self,
        name: &str,
        execution: ExecutionModel,
    ) -> Option<EntryPoint<'a>> {
        let info = self.entry_points.get(name)?.get(&execution)?;
        Some(EntryPoint {
            module: self,
            name: CString::new(name).ok()?,
            info,
        })
    }
}

unsafe impl VulkanObject for ShaderModule {
    type Object = ash::vk::ShaderModule;

    /// Returns the raw `VkShaderModule` handle owned by this object.
    #[inline]
    fn internal_object(&self) -> Self::Object {
        let Self { handle, .. } = *self;
        handle
    }
}

impl Drop for ShaderModule {
    /// Destroys the Vulkan shader module. The `device` field is dropped only after this body
    /// has run, so the device is still alive during the destroy call.
    #[inline]
    fn drop(&mut self) {
        let raw_device = self.device.internal_object();
        let fns = self.device.fns();

        // SAFETY: `handle` was created on this device, is owned solely by `self`, and is never
        // used again once the module is dropped.
        unsafe {
            fns.v1_0
                .destroy_shader_module(raw_device, self.handle, ptr::null());
        }
    }
}

/// Error that can happen when creating a new shader module.
#[derive(Clone, Debug)]
pub enum ShaderCreationError {
    /// Creating the Vulkan shader module failed; converted from the Vulkan error code.
    OomError(OomError),
    /// The shader enables a SPIR-V capability that the device does not support.
    SpirvCapabilityNotSupported {
        /// The unsupported capability.
        capability: Capability,
        /// Why the capability is not supported.
        reason: ShaderSupportError,
    },
    /// The SPIR-V code could not be parsed.
    SpirvError(SpirvError),
    /// The shader enables a SPIR-V extension that the device does not support.
    SpirvExtensionNotSupported {
        /// Name of the unsupported extension.
        extension: String,
        /// Why the extension is not supported.
        reason: ShaderSupportError,
    },
    /// The SPIR-V version of the shader is not supported by the device.
    SpirvVersionNotSupported {
        /// The SPIR-V version declared by the shader.
        version: Version,
        /// Why the version is not supported.
        reason: ShaderSupportError,
    },
}

impl Error for ShaderCreationError {
    /// Every variant wraps an underlying cause, which is exposed as the error source.
    #[inline]
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        match self {
            Self::OomError(inner) => Some(inner),
            Self::SpirvError(inner) => Some(inner),
            Self::SpirvCapabilityNotSupported { reason, .. }
            | Self::SpirvExtensionNotSupported { reason, .. }
            | Self::SpirvVersionNotSupported { reason, .. } => Some(reason),
        }
    }
}

impl Display for ShaderCreationError {
    /// Writes a short, lowercase description of the error (the cause is available through
    /// `Error::source`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::OomError(_) => f.write_str("not enough memory available"),
            Self::SpirvError(_) => f.write_str("the SPIR-V module could not be read"),
            Self::SpirvCapabilityNotSupported { capability, .. } => {
                let message = "the SPIR-V capability {:?} enabled by the shader is not supported by the device";
                let _ = message;
                write!(
                    f,
                    "the SPIR-V capability {:?} enabled by the shader is not supported by the device",
                    capability
                )
            }
            Self::SpirvExtensionNotSupported { extension, .. } => write!(
                f,
                "the SPIR-V extension {} enabled by the shader is not supported by the device",
                extension
            ),
            Self::SpirvVersionNotSupported { version, .. } => {
                let (major, minor) = (version.major, version.minor);
                write!(
                    f,
                    "the shader uses SPIR-V version {}.{}, which is not supported by the device",
                    major, minor
                )
            }
        }
    }
}

impl From<crate::Error> for ShaderCreationError {
    fn from(err: crate::Error) -> Self {
        Self::OomError(err.into())
    }
}

impl From<SpirvError> for ShaderCreationError {
    fn from(err: SpirvError) -> Self {
        Self::SpirvError(err)
    }
}

/// Error that can happen when checking whether a shader is supported by a device.
#[derive(Clone, Copy, Debug)]
pub enum ShaderSupportError {
    /// The item is not supported by Vulkan.
    NotSupportedByVulkan,
    /// None of the listed requirements is available/enabled on the device; at least one of
    /// them must be.
    RequirementsNotMet(&'static [&'static str]),
}

impl Error for ShaderSupportError {}

impl Display for ShaderSupportError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let requirements = match *self {
            Self::NotSupportedByVulkan => return f.write_str("not supported by Vulkan"),
            Self::RequirementsNotMet(requirements) => requirements,
        };

        f.write_str("at least one of the following must be available/enabled on the device: ")?;
        for (index, requirement) in requirements.iter().enumerate() {
            if index > 0 {
                f.write_str(", ")?;
            }
            f.write_str(requirement)?;
        }
        Ok(())
    }
}

/// The information associated with a single entry point in a shader.
#[derive(Clone, Debug)]
pub struct EntryPointInfo {
    /// The execution model of the entry point, with any model-specific data.
    pub execution: ShaderExecution,
    /// Requirements on descriptors, keyed by a `(u32, u32)` pair — presumably
    /// `(set, binding)`; TODO confirm against the reflection code.
    pub descriptor_requirements: FnvHashMap<(u32, u32), DescriptorRequirements>,
    /// The push constant range required by the entry point, if it uses push constants.
    pub push_constant_requirements: Option<PipelineLayoutPcRange>,
    /// Requirements on specialization constants, keyed by a `u32` — presumably the constant
    /// ID (`SpecId`); TODO confirm.
    pub specialization_constant_requirements: FnvHashMap<u32, SpecializationConstantRequirements>,
    /// The input interface of the entry point.
    pub input_interface: ShaderInterface,
    /// The output interface of the entry point.
    pub output_interface: ShaderInterface,
}

/// Represents a shader entry point in a shader module.
///
/// Can be obtained by calling [`entry_point`](ShaderModule::entry_point) on the shader module.
/// It borrows the module, so it cannot outlive it.
#[derive(Clone, Debug)]
pub struct EntryPoint<'a> {
    // The module that contains the entry point.
    module: &'a ShaderModule,
    // The entry point name, stored NUL-terminated for passing to Vulkan.
    name: CString,
    // Reflection information for this entry point, owned by `module`.
    info: &'a EntryPointInfo,
}

impl<'a> EntryPoint<'a> {
    /// The shader module that this entry point belongs to.
    #[inline]
    pub fn module(&self) -> &'a ShaderModule {
        let owner = self.module;
        owner
    }

    /// The name of the entry point, as a NUL-terminated C string.
    #[inline]
    pub fn name(&self) -> &CStr {
        self.name.as_c_str()
    }

    /// The execution model of the shader, together with any model-specific data.
    #[inline]
    pub fn execution(&self) -> &ShaderExecution {
        let info = self.info;
        &info.execution
    }

    /// Iterates over the descriptor requirements, yielding each `(u32, u32)` key by value.
    #[inline]
    pub fn descriptor_requirements(
        &self,
    ) -> impl ExactSizeIterator<Item = ((u32, u32), &DescriptorRequirements)> {
        let requirements = &self.info.descriptor_requirements;
        requirements.iter().map(|(&key, reqs)| (key, reqs))
    }

    /// The push constant range required by the entry point, if any.
    #[inline]
    pub fn push_constant_requirements(&self) -> Option<&PipelineLayoutPcRange> {
        let range = &self.info.push_constant_requirements;
        range.as_ref()
    }

    /// Iterates over the specialization constant requirements, yielding each `u32` key by
    /// value.
    #[inline]
    pub fn specialization_constant_requirements(
        &self,
    ) -> impl ExactSizeIterator<Item = (u32, &SpecializationConstantRequirements)> {
        let requirements = &self.info.specialization_constant_requirements;
        requirements.iter().map(|(&id, reqs)| (id, reqs))
    }

    /// The input attributes consumed by this shader stage.
    #[inline]
    pub fn input_interface(&self) -> &ShaderInterface {
        let info = self.info;
        &info.input_interface
    }

    /// The output attributes produced by this shader stage.
    #[inline]
    pub fn output_interface(&self) -> &ShaderInterface {
        let info = self.info;
        &info.output_interface
    }
}

/// The mode in which a shader executes. This includes both information about the shader type/stage,
/// and additional data relevant to particular shader types.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShaderExecution {
    /// A vertex shader.
    Vertex,
    /// A tessellation control shader.
    TessellationControl,
    /// A tessellation evaluation shader.
    TessellationEvaluation,
    /// A geometry shader, with its geometry-specific execution data.
    Geometry(GeometryShaderExecution),
    /// A fragment shader.
    Fragment,
    /// A compute shader.
    Compute,
}

/*#[derive(Clone, Copy, Debug)]
pub struct TessellationShaderExecution {
    pub num_output_vertices: u32,
    pub point_mode: bool,
    pub subdivision: TessellationShaderSubdivision,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TessellationShaderSubdivision {
    Triangles,
    Quads,
    Isolines,
}*/

/// The mode in which a geometry shader executes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct GeometryShaderExecution {
    /// The primitive type that the geometry shader expects as input.
    pub input: GeometryShaderInput,
    // Not exposed yet:
    /*pub max_output_vertices: u32,
    pub num_invocations: u32,
    pub output: GeometryShaderOutput,*/
}

/// The input primitive type that is expected by a geometry shader.
///
/// Use [`is_compatible_with`](GeometryShaderInput::is_compatible_with) to check which
/// primitive topologies can feed a given input type.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GeometryShaderInput {
    /// Individual points.
    Points,
    /// Lines without adjacency information.
    Lines,
    /// Lines with adjacency information.
    LinesWithAdjacency,
    /// Triangles without adjacency information.
    Triangles,
    /// Triangles with adjacency information.
    TrianglesWithAdjacency,
}

impl GeometryShaderInput {
    /// Returns true if the given primitive topology can be used as input for this geometry shader.
    ///
    /// List, strip and fan topologies of the matching primitive kind are accepted; adjacency
    /// topologies only match the corresponding `*WithAdjacency` input. Any other topology
    /// (for example patch lists) is rejected.
    #[inline]
    pub fn is_compatible_with(&self, topology: PrimitiveTopology) -> bool {
        matches!(
            (*self, topology),
            (Self::Points, PrimitiveTopology::PointList)
                | (Self::Lines, PrimitiveTopology::LineList)
                | (Self::Lines, PrimitiveTopology::LineStrip)
                | (Self::LinesWithAdjacency, PrimitiveTopology::LineListWithAdjacency)
                | (Self::LinesWithAdjacency, PrimitiveTopology::LineStripWithAdjacency)
                | (Self::Triangles, PrimitiveTopology::TriangleList)
                | (Self::Triangles, PrimitiveTopology::TriangleStrip)
                | (Self::Triangles, PrimitiveTopology::TriangleFan)
                | (Self::TrianglesWithAdjacency, PrimitiveTopology::TriangleListWithAdjacency)
                | (Self::TrianglesWithAdjacency, PrimitiveTopology::TriangleStripWithAdjacency)
        )
    }
}

/*#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GeometryShaderOutput {
    Points,
    LineStrip,
    TriangleStrip,
}*/

/// The requirements imposed by a shader on a descriptor within a descriptor set layout, and on any
/// resource that is bound to that descriptor.
///
/// Requirements coming from different uses of the same descriptor can be combined with
/// `DescriptorRequirements::intersection`.
#[derive(Clone, Debug, Default)]
pub struct DescriptorRequirements {
    /// The descriptor types that are allowed. The descriptor must have one of these types.
    pub descriptor_types: Vec<DescriptorType>,

    /// The number of descriptors (array elements) that the shader requires. The descriptor set
    /// layout can declare more than this, but never less.
    pub descriptor_count: u32,

    /// The image format that is required for image views bound to this descriptor. If this is
    /// `None`, then any image format is allowed.
    pub format: Option<Format>,

    /// The view type that is required for image views bound to this descriptor. This is `None` for
    /// non-image descriptors.
    pub image_view_type: Option<ImageViewType>,

    /// Whether image views bound to this descriptor must have multisampling enabled or disabled.
    pub multisampled: bool,

    /// Whether the shader requires mutable (exclusive) access to the resource bound to this
    /// descriptor.
    pub mutable: bool,

    /// The shader stages that the descriptor must be declared for.
    pub stages: ShaderStages,
}

impl DescriptorRequirements {
    /// Produces the intersection of two descriptor requirements, so that the requirements of both
    /// are satisfied. An error is returned if the requirements conflict.
    pub fn intersection(&self, other: &Self) -> Result<Self, DescriptorRequirementsIncompatible> {
        let descriptor_types: Vec<_> = self
            .descriptor_types
            .iter()
            .copied()
            .filter(|ty| other.descriptor_types.contains(&ty))
            .collect();

        if descriptor_types.is_empty() {
            return Err(DescriptorRequirementsIncompatible::DescriptorType);
        }

        if let (Some(first), Some(second)) = (self.format, other.format) {
            if first != second {
                return Err(DescriptorRequirementsIncompatible::Format);
            }
        }

        if let (Some(first), Some(second)) = (self.image_view_type, other.image_view_type) {
            if first != second {
                return Err(DescriptorRequirementsIncompatible::ImageViewType);
            }
        }

        if self.multisampled != other.multisampled {
            return Err(DescriptorRequirementsIncompatible::Multisampled);
        }

        Ok(Self {
            descriptor_types,
            descriptor_count: self.descriptor_count.max(other.descriptor_count),
            format: self.format.or(other.format),
            image_view_type: self.image_view_type.or(other.image_view_type),
            multisampled: self.multisampled,
            mutable: self.mutable || other.mutable,
            stages: self.stages | other.stages,
        })
    }
}

/// An error that can be returned when trying to create the intersection of two
/// `DescriptorRequirements` values.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DescriptorRequirementsIncompatible {
    /// The allowed descriptor types of the descriptors do not overlap.
    DescriptorType,
    /// The descriptors require different formats.
    Format,
    /// The descriptors require different image view types.
    ImageViewType,
    /// The multisampling requirements of the descriptors differ.
    Multisampled,
}

impl Error for DescriptorRequirementsIncompatible {}

impl Display for DescriptorRequirementsIncompatible {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::DescriptorType => {
                "the allowed descriptor types of the two descriptors do not overlap"
            }
            Self::Format => "the descriptors require different formats",
            Self::ImageViewType => "the descriptors require different image view types",
            Self::Multisampled => "the multisampling requirements of the descriptors differ",
        };
        f.write_str(message)
    }
}

/// The requirements imposed by a shader on a specialization constant.
#[derive(Clone, Copy, Debug)]
pub struct SpecializationConstantRequirements {
    /// Size of the constant — presumably in bytes, like `SpecializationMapEntry::size`;
    /// TODO confirm.
    pub size: DeviceSize,
}

/// Trait for types that contain specialization data for shaders.
///
/// Shader modules can contain what are called *specialization constants*. They are the same as
/// constants, except that their values can be defined when you create a compute pipeline or a
/// graphics pipeline. This is done by passing a type that implements the
/// `SpecializationConstants` trait and that stores the values in question. The `descriptors()`
/// method of this trait describes where each constant is located within that type.
///
/// Boolean specialization constants must be stored as 32-bit integers, where `0` means `false` and
/// any non-zero value means `true`. Integer and floating-point specialization constants are
/// stored as their Rust equivalent.
///
/// This trait is implemented on `()` for shaders that don't have any specialization constants.
///
/// # Example
///
/// ```rust
/// use vulkano::shader::SpecializationConstants;
/// use vulkano::shader::SpecializationMapEntry;
///
/// #[repr(C)]      // `#[repr(C)]` guarantees that the struct has a specific layout
/// struct MySpecConstants {
///     my_integer_constant: i32,
///     a_boolean: u32,
///     floating_point: f32,
/// }
///
/// unsafe impl SpecializationConstants for MySpecConstants {
///     fn descriptors() -> &'static [SpecializationMapEntry] {
///         static DESCRIPTORS: [SpecializationMapEntry; 3] = [
///             SpecializationMapEntry {
///                 constant_id: 0,
///                 offset: 0,
///                 size: 4,
///             },
///             SpecializationMapEntry {
///                 constant_id: 1,
///                 offset: 4,
///                 size: 4,
///             },
///             SpecializationMapEntry {
///                 constant_id: 2,
///                 offset: 8,
///                 size: 4,
///             },
///         ];
///
///         &DESCRIPTORS
///     }
/// }
/// ```
///
/// # Safety
///
/// - The `SpecializationMapEntry` values returned must contain valid offsets and sizes for the
///   implementing type.
/// - The size of each `SpecializationMapEntry` must match the size of the corresponding constant
///   (`4` for booleans).
///
pub unsafe trait SpecializationConstants {
    /// Returns descriptors of the struct's layout, one per specialization constant.
    fn descriptors() -> &'static [SpecializationMapEntry];
}

unsafe impl SpecializationConstants for () {
    /// `()` carries no data, so there are no constants to describe.
    #[inline]
    fn descriptors() -> &'static [SpecializationMapEntry] {
        const NO_ENTRIES: &[SpecializationMapEntry] = &[];
        NO_ENTRIES
    }
}

/// Describes an individual constant to set in the shader. Also a field in the struct.
// Implementation note: has the same memory representation as a `VkSpecializationMapEntry`
// (`#[repr(C)]`, fields in the same order and of the same types).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(C)]
pub struct SpecializationMapEntry {
    /// Identifier of the constant in the shader that corresponds to this field.
    ///
    /// For SPIR-V, this must be the value of the `SpecId` decoration applied to the specialization
    /// constant.
    /// For GLSL, this must be the value of `N` in the `layout(constant_id = N)` attribute applied
    /// to a constant.
    pub constant_id: u32,

    /// Offset in bytes within the struct where the data can be found.
    pub offset: u32,

    /// Size of the data in bytes. Must match the size of the constant (`4` for booleans).
    pub size: usize,
}

impl From<SpecializationMapEntry> for ash::vk::SpecializationMapEntry {
    #[inline]
    fn from(val: SpecializationMapEntry) -> Self {
        Self {
            constant_id: val.constant_id,
            offset: val.offset,
            size: val.size,
        }
    }
}

/// Type that contains the definition of an interface between two shader stages, or between
/// the outside and a shader stage.
#[derive(Clone, Debug)]
pub struct ShaderInterface {
    // The interface variables. `new_unchecked` requires at most one entry per location.
    elements: Vec<ShaderInterfaceEntry>,
}

impl ShaderInterface {
    /// Constructs a new `ShaderInterface` from its elements without checking them.
    ///
    /// # Safety
    ///
    /// - Must only provide one entry per location.
    /// - The format of each element must not be larger than 128 bits.
    // TODO: 4x64 bit formats are possible, but they require special handling.
    // TODO: could this be made safe?
    #[inline]
    pub unsafe fn new_unchecked(elements: Vec<ShaderInterfaceEntry>) -> ShaderInterface {
        Self { elements }
    }

    /// Creates a description of an empty shader interface.
    pub const fn empty() -> ShaderInterface {
        Self {
            elements: Vec::new(),
        }
    }

    /// Returns a slice containing the elements of the interface.
    #[inline]
    pub fn elements(&self) -> &[ShaderInterfaceEntry] {
        &self.elements
    }

    /// Checks whether the interface is potentially compatible with another one.
    ///
    /// Both interfaces must have the same number of elements. Every location covered by an
    /// element of `self` must be covered by an element of `other` with exactly the same type.
    /// Element names are not compared.
    ///
    /// Returns `Ok` if the two interfaces are compatible.
    pub fn matches(&self, other: &ShaderInterface) -> Result<(), ShaderInterfaceMismatchError> {
        let ours = self.elements();
        let theirs = other.elements();

        if ours.len() != theirs.len() {
            return Err(ShaderInterfaceMismatchError::ElementsCountMismatch {
                self_elements: ours.len() as u32,
                other_elements: theirs.len() as u32,
            });
        }

        for mine in ours {
            let first = mine.location;
            let end = mine.location + mine.ty.num_locations();

            for location in first..end {
                // Short-circuit order matters: `num_locations` is only evaluated for elements
                // that start at or before `location`.
                let counterpart = theirs
                    .iter()
                    .find(|e| location >= e.location && location < e.location + e.ty.num_locations())
                    .ok_or(ShaderInterfaceMismatchError::MissingElement { location })?;

                if mine.ty != counterpart.ty {
                    return Err(ShaderInterfaceMismatchError::TypeMismatch {
                        location,
                        self_ty: mine.ty,
                        other_ty: counterpart.ty,
                    });
                }

                // TODO: decide whether element names should also be required to match.
            }
        }

        // The element counts are equal, so there is no need to iterate over `other`'s
        // elements as well.
        Ok(())
    }
}

/// Entry of a shader interface definition, describing one interface variable.
#[derive(Debug, Clone)]
pub struct ShaderInterfaceEntry {
    /// The location slot that the variable starts at.
    pub location: u32,

    /// The component slot that the variable starts at. Must be in the range 0..=3.
    pub component: u32,

    /// Name of the element, or `None` if the name is unknown. Names are not compared by
    /// `ShaderInterface::matches`.
    pub name: Option<Cow<'static, str>>,

    /// The type of the variable.
    pub ty: ShaderInterfaceEntryType,
}

/// The type of a variable in a shader interface.
///
/// Two interface entries are only considered compatible if all of these fields are equal.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ShaderInterfaceEntryType {
    /// The base numeric type.
    pub base_type: ShaderScalarType,

    /// The number of vector components. Must be in the range 1..=4.
    pub num_components: u32,

    /// The number of array elements or matrix columns. Each one occupies one location.
    pub num_elements: u32,

    /// Whether the base type is 64 bits wide. If true, each item of the base type takes up two
    /// component slots instead of one. 64-bit types are not fully supported yet: format and
    /// location computations panic on them.
    pub is_64bit: bool,
}

impl ShaderInterfaceEntryType {
    /// Returns the 32-bit vertex format matching this type's base type and component count.
    ///
    /// # Panics
    ///
    /// - Panics if the type is 64-bit (not implemented yet).
    /// - Panics if `num_components` is outside 1..=4.
    pub(crate) fn to_format(&self) -> Format {
        assert!(!self.is_64bit); // TODO: implement
        match (self.base_type, self.num_components) {
            (ShaderScalarType::Float, 1) => Format::R32_SFLOAT,
            (ShaderScalarType::Float, 2) => Format::R32G32_SFLOAT,
            (ShaderScalarType::Float, 3) => Format::R32G32B32_SFLOAT,
            (ShaderScalarType::Float, 4) => Format::R32G32B32A32_SFLOAT,
            (ShaderScalarType::Sint, 1) => Format::R32_SINT,
            (ShaderScalarType::Sint, 2) => Format::R32G32_SINT,
            (ShaderScalarType::Sint, 3) => Format::R32G32B32_SINT,
            (ShaderScalarType::Sint, 4) => Format::R32G32B32A32_SINT,
            (ShaderScalarType::Uint, 1) => Format::R32_UINT,
            (ShaderScalarType::Uint, 2) => Format::R32G32_UINT,
            (ShaderScalarType::Uint, 3) => Format::R32G32B32_UINT,
            (ShaderScalarType::Uint, 4) => Format::R32G32B32A32_UINT,
            _ => unreachable!(),
        }
    }

    /// Returns how many consecutive location slots a variable of this type occupies.
    ///
    /// # Panics
    ///
    /// - Panics if the type is 64-bit (not implemented yet).
    pub(crate) fn num_locations(&self) -> u32 {
        assert!(!self.is_64bit); // TODO: implement
        let Self { num_elements, .. } = *self;
        num_elements
    }
}

/// The numeric base type of a shader interface variable.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ShaderScalarType {
    /// Floating-point (maps to `*_SFLOAT` formats).
    Float,
    /// Signed integer (maps to `*_SINT` formats).
    Sint,
    /// Unsigned integer (maps to `*_UINT` formats).
    Uint,
}

/// Error that can happen when the interface mismatches between two shader stages.
///
/// Returned by `ShaderInterface::matches`; "first" refers to `self` and "second" to `other`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ShaderInterfaceMismatchError {
    /// The number of elements is not the same between the two shader interfaces.
    ElementsCountMismatch {
        /// Number of elements in the first interface.
        self_elements: u32,
        /// Number of elements in the second interface.
        other_elements: u32,
    },

    /// An element is missing from one of the interfaces.
    MissingElement {
        /// Location of the missing element.
        location: u32,
    },

    /// The type of an element does not match.
    TypeMismatch {
        /// Location of the element that mismatches.
        location: u32,
        /// Type in the first interface.
        self_ty: ShaderInterfaceEntryType,
        /// Type in the second interface.
        other_ty: ShaderInterfaceEntryType,
    },
}

impl error::Error for ShaderInterfaceMismatchError {}

impl fmt::Display for ShaderInterfaceMismatchError {
    #[inline]
    fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        write!(
            fmt,
            "{}",
            match *self {
                ShaderInterfaceMismatchError::ElementsCountMismatch { .. } => {
                    "the number of elements mismatches"
                }
                ShaderInterfaceMismatchError::MissingElement { .. } => "an element is missing",
                ShaderInterfaceMismatchError::TypeMismatch { .. } => {
                    "the type of an element does not match"
                }
            }
        )
    }
}

/// A single shader stage.
///
/// Each variant's discriminant is the raw value of the matching `ash::vk::ShaderStageFlags`
/// bit, so a `ShaderStage` can be turned into that flag with a plain `as u32` cast.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum ShaderStage {
    /// Vertex stage (`ShaderStageFlags::VERTEX`).
    Vertex = ash::vk::ShaderStageFlags::VERTEX.as_raw(),
    /// Tessellation control stage (`ShaderStageFlags::TESSELLATION_CONTROL`).
    TessellationControl = ash::vk::ShaderStageFlags::TESSELLATION_CONTROL.as_raw(),
    /// Tessellation evaluation stage (`ShaderStageFlags::TESSELLATION_EVALUATION`).
    TessellationEvaluation = ash::vk::ShaderStageFlags::TESSELLATION_EVALUATION.as_raw(),
    /// Geometry stage (`ShaderStageFlags::GEOMETRY`).
    Geometry = ash::vk::ShaderStageFlags::GEOMETRY.as_raw(),
    /// Fragment stage (`ShaderStageFlags::FRAGMENT`).
    Fragment = ash::vk::ShaderStageFlags::FRAGMENT.as_raw(),
    /// Compute stage (`ShaderStageFlags::COMPUTE`).
    Compute = ash::vk::ShaderStageFlags::COMPUTE.as_raw(),
    /// Ray generation stage (`ShaderStageFlags::RAYGEN_KHR`).
    Raygen = ash::vk::ShaderStageFlags::RAYGEN_KHR.as_raw(),
    /// Any-hit stage (`ShaderStageFlags::ANY_HIT_KHR`).
    AnyHit = ash::vk::ShaderStageFlags::ANY_HIT_KHR.as_raw(),
    /// Closest-hit stage (`ShaderStageFlags::CLOSEST_HIT_KHR`).
    ClosestHit = ash::vk::ShaderStageFlags::CLOSEST_HIT_KHR.as_raw(),
    /// Miss stage (`ShaderStageFlags::MISS_KHR`).
    Miss = ash::vk::ShaderStageFlags::MISS_KHR.as_raw(),
    /// Intersection stage (`ShaderStageFlags::INTERSECTION_KHR`).
    Intersection = ash::vk::ShaderStageFlags::INTERSECTION_KHR.as_raw(),
    /// Callable stage (`ShaderStageFlags::CALLABLE_KHR`).
    Callable = ash::vk::ShaderStageFlags::CALLABLE_KHR.as_raw(),
}

impl From<ShaderExecution> for ShaderStage {
    #[inline]
    fn from(val: ShaderExecution) -> Self {
        match val {
            ShaderExecution::Vertex => Self::Vertex,
            ShaderExecution::TessellationControl => Self::TessellationControl,
            ShaderExecution::TessellationEvaluation => Self::TessellationEvaluation,
            ShaderExecution::Geometry(_) => Self::Geometry,
            ShaderExecution::Fragment => Self::Fragment,
            ShaderExecution::Compute => Self::Compute,
        }
    }
}

impl From<ShaderStage> for ShaderStages {
    #[inline]
    fn from(val: ShaderStage) -> Self {
        match val {
            ShaderStage::Vertex => Self {
                vertex: true,
                ..Self::none()
            },
            ShaderStage::TessellationControl => Self {
                tessellation_control: true,
                ..Self::none()
            },
            ShaderStage::TessellationEvaluation => Self {
                tessellation_evaluation: true,
                ..Self::none()
            },
            ShaderStage::Geometry => Self {
                geometry: true,
                ..Self::none()
            },
            ShaderStage::Fragment => Self {
                fragment: true,
                ..Self::none()
            },
            ShaderStage::Compute => Self {
                compute: true,
                ..Self::none()
            },
            ShaderStage::Raygen => Self {
                raygen: true,
                ..Self::none()
            },
            ShaderStage::AnyHit => Self {
                any_hit: true,
                ..Self::none()
            },
            ShaderStage::ClosestHit => Self {
                closest_hit: true,
                ..Self::none()
            },
            ShaderStage::Miss => Self {
                miss: true,
                ..Self::none()
            },
            ShaderStage::Intersection => Self {
                intersection: true,
                ..Self::none()
            },
            ShaderStage::Callable => Self {
                callable: true,
                ..Self::none()
            },
        }
    }
}

impl From<ShaderStage> for ash::vk::ShaderStageFlags {
    /// Converts one stage into its single-bit Vulkan flag.
    ///
    /// `ShaderStage` discriminants are defined as the raw flag values, so the cast is exact.
    #[inline]
    fn from(val: ShaderStage) -> Self {
        let raw_bit = val as u32;
        ash::vk::ShaderStageFlags::from_raw(raw_bit)
    }
}

/// A set of shader stages, stored as one `bool` per stage.
///
/// `ShaderStages::default()` is the empty set, identical to [`ShaderStages::none`].
// TODO: add example with BitOr
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ShaderStages {
    /// Vertex stage.
    pub vertex: bool,
    /// Tessellation control stage.
    pub tessellation_control: bool,
    /// Tessellation evaluation stage.
    pub tessellation_evaluation: bool,
    /// Geometry stage.
    pub geometry: bool,
    /// Fragment stage.
    pub fragment: bool,
    /// Compute stage.
    pub compute: bool,
    /// Ray generation stage.
    pub raygen: bool,
    /// Any-hit stage.
    pub any_hit: bool,
    /// Closest-hit stage.
    pub closest_hit: bool,
    /// Miss stage.
    pub miss: bool,
    /// Intersection stage.
    pub intersection: bool,
    /// Callable stage.
    pub callable: bool,
}

impl ShaderStages {
    /// Returns a set in which every stage is enabled.
    // TODO: add example
    #[inline]
    pub const fn all() -> ShaderStages {
        Self {
            vertex: true,
            tessellation_control: true,
            tessellation_evaluation: true,
            geometry: true,
            fragment: true,
            compute: true,
            raygen: true,
            any_hit: true,
            closest_hit: true,
            miss: true,
            intersection: true,
            callable: true,
        }
    }

    /// Returns the empty set, in which every stage is disabled.
    // TODO: add example
    #[inline]
    pub const fn none() -> ShaderStages {
        Self {
            vertex: false,
            tessellation_control: false,
            tessellation_evaluation: false,
            geometry: false,
            fragment: false,
            compute: false,
            raygen: false,
            any_hit: false,
            closest_hit: false,
            miss: false,
            intersection: false,
            callable: false,
        }
    }

    /// Returns a set with the five graphics stages enabled: vertex, both tessellation stages,
    /// geometry and fragment. Compute and ray tracing stages are disabled.
    // TODO: add example
    #[inline]
    pub const fn all_graphics() -> ShaderStages {
        Self {
            fragment: true,
            geometry: true,
            tessellation_evaluation: true,
            tessellation_control: true,
            vertex: true,
            ..Self::none()
        }
    }

    /// Returns a set in which only the compute stage is enabled.
    // TODO: add example
    #[inline]
    pub const fn compute() -> ShaderStages {
        Self {
            compute: true,
            ..Self::none()
        }
    }

    /// Returns `true` if every stage enabled in `other` is also enabled in `self`.
    // TODO: add example
    #[inline]
    pub const fn is_superset_of(&self, other: &ShaderStages) -> bool {
        let ShaderStages {
            vertex,
            tessellation_control,
            tessellation_evaluation,
            geometry,
            fragment,
            compute,
            raygen,
            any_hit,
            closest_hit,
            miss,
            intersection,
            callable,
        } = *other;

        // A stage is fine when `other` does not need it, or when `self` also has it.
        (!vertex || self.vertex)
            && (!tessellation_control || self.tessellation_control)
            && (!tessellation_evaluation || self.tessellation_evaluation)
            && (!geometry || self.geometry)
            && (!fragment || self.fragment)
            && (!compute || self.compute)
            && (!raygen || self.raygen)
            && (!any_hit || self.any_hit)
            && (!closest_hit || self.closest_hit)
            && (!miss || self.miss)
            && (!intersection || self.intersection)
            && (!callable || self.callable)
    }

    /// Returns `true` if at least one stage is enabled in both `self` and `other`.
    // TODO: add example
    #[inline]
    pub const fn intersects(&self, other: &ShaderStages) -> bool {
        let ShaderStages {
            vertex,
            tessellation_control,
            tessellation_evaluation,
            geometry,
            fragment,
            compute,
            raygen,
            any_hit,
            closest_hit,
            miss,
            intersection,
            callable,
        } = *other;

        (self.vertex && vertex)
            || (self.tessellation_control && tessellation_control)
            || (self.tessellation_evaluation && tessellation_evaluation)
            || (self.geometry && geometry)
            || (self.fragment && fragment)
            || (self.compute && compute)
            || (self.raygen && raygen)
            || (self.any_hit && any_hit)
            || (self.closest_hit && closest_hit)
            || (self.miss && miss)
            || (self.intersection && intersection)
            || (self.callable && callable)
    }

    /// Returns a set containing every stage enabled in `self`, in `other`, or in both.
    #[inline]
    pub const fn union(&self, other: &Self) -> Self {
        let ShaderStages {
            vertex,
            tessellation_control,
            tessellation_evaluation,
            geometry,
            fragment,
            compute,
            raygen,
            any_hit,
            closest_hit,
            miss,
            intersection,
            callable,
        } = *other;

        Self {
            vertex: self.vertex || vertex,
            tessellation_control: self.tessellation_control || tessellation_control,
            tessellation_evaluation: self.tessellation_evaluation || tessellation_evaluation,
            geometry: self.geometry || geometry,
            fragment: self.fragment || fragment,
            compute: self.compute || compute,
            raygen: self.raygen || raygen,
            any_hit: self.any_hit || any_hit,
            closest_hit: self.closest_hit || closest_hit,
            miss: self.miss || miss,
            intersection: self.intersection || intersection,
            callable: self.callable || callable,
        }
    }
}

impl From<ShaderStages> for ash::vk::ShaderStageFlags {
    /// Packs a stage set into Vulkan stage flags, setting one bit per enabled stage.
    #[inline]
    fn from(val: ShaderStages) -> Self {
        use ash::vk::ShaderStageFlags as Flags;

        // Exhaustive destructuring: adding a field to `ShaderStages` breaks the build here
        // until the new stage is mapped below.
        let ShaderStages {
            vertex,
            tessellation_control,
            tessellation_evaluation,
            geometry,
            fragment,
            compute,
            raygen,
            any_hit,
            closest_hit,
            miss,
            intersection,
            callable,
        } = val;

        let mapping = [
            (vertex, Flags::VERTEX),
            (tessellation_control, Flags::TESSELLATION_CONTROL),
            (tessellation_evaluation, Flags::TESSELLATION_EVALUATION),
            (geometry, Flags::GEOMETRY),
            (fragment, Flags::FRAGMENT),
            (compute, Flags::COMPUTE),
            (raygen, Flags::RAYGEN_KHR),
            (any_hit, Flags::ANY_HIT_KHR),
            (closest_hit, Flags::CLOSEST_HIT_KHR),
            (miss, Flags::MISS_KHR),
            (intersection, Flags::INTERSECTION_KHR),
            (callable, Flags::CALLABLE_KHR),
        ];

        mapping
            .iter()
            .filter(|&&(enabled, _)| enabled)
            .fold(Flags::empty(), |acc, &(_, flag)| acc | flag)
    }
}

impl From<ash::vk::ShaderStageFlags> for ShaderStages {
    #[inline]
    fn from(val: ash::vk::ShaderStageFlags) -> Self {
        Self {
            vertex: val.intersects(ash::vk::ShaderStageFlags::VERTEX),
            tessellation_control: val.intersects(ash::vk::ShaderStageFlags::TESSELLATION_CONTROL),
            tessellation_evaluation: val
                .intersects(ash::vk::ShaderStageFlags::TESSELLATION_EVALUATION),
            geometry: val.intersects(ash::vk::ShaderStageFlags::GEOMETRY),
            fragment: val.intersects(ash::vk::ShaderStageFlags::FRAGMENT),
            compute: val.intersects(ash::vk::ShaderStageFlags::COMPUTE),
            raygen: val.intersects(ash::vk::ShaderStageFlags::RAYGEN_KHR),
            any_hit: val.intersects(ash::vk::ShaderStageFlags::ANY_HIT_KHR),
            closest_hit: val.intersects(ash::vk::ShaderStageFlags::CLOSEST_HIT_KHR),
            miss: val.intersects(ash::vk::ShaderStageFlags::MISS_KHR),
            intersection: val.intersects(ash::vk::ShaderStageFlags::INTERSECTION_KHR),
            callable: val.intersects(ash::vk::ShaderStageFlags::CALLABLE_KHR),
        }
    }
}

impl BitOr for ShaderStages {
    type Output = ShaderStages;

    #[inline]
    fn bitor(self, other: ShaderStages) -> ShaderStages {
        ShaderStages {
            vertex: self.vertex || other.vertex,
            tessellation_control: self.tessellation_control || other.tessellation_control,
            tessellation_evaluation: self.tessellation_evaluation || other.tessellation_evaluation,
            geometry: self.geometry || other.geometry,
            fragment: self.fragment || other.fragment,
            compute: self.compute || other.compute,
            raygen: self.raygen || other.raygen,
            any_hit: self.any_hit || other.any_hit,
            closest_hit: self.closest_hit || other.closest_hit,
            miss: self.miss || other.miss,
            intersection: self.intersection || other.intersection,
            callable: self.callable || other.callable,
        }
    }
}

impl From<ShaderStages> for PipelineStages {
    /// Maps each shader stage to the pipeline stage that executes it.
    ///
    /// The six ray tracing shader stages all map onto the single `ray_tracing_shader`
    /// pipeline stage. Every other `PipelineStages` field is taken from
    /// `PipelineStages::none()`.
    #[inline]
    fn from(stages: ShaderStages) -> PipelineStages {
        // Exhaustive destructuring: a new `ShaderStages` field must be handled here.
        let ShaderStages {
            vertex,
            tessellation_control,
            tessellation_evaluation,
            geometry,
            fragment,
            compute,
            raygen,
            any_hit,
            closest_hit,
            miss,
            intersection,
            callable,
        } = stages;

        let ray_tracing = raygen || any_hit || closest_hit || miss || intersection || callable;

        PipelineStages {
            vertex_shader: vertex,
            tessellation_control_shader: tessellation_control,
            tessellation_evaluation_shader: tessellation_evaluation,
            geometry_shader: geometry,
            fragment_shader: fragment,
            compute_shader: compute,
            ray_tracing_shader: ray_tracing,
            ..PipelineStages::none()
        }
    }
}

/// Verifies that `device` can consume SPIR-V code of the given `version`.
///
/// The patch component of `version` is ignored. Accepted versions and their requirements:
///
/// - 1.0: always accepted.
/// - 1.1, 1.2 and 1.3: Vulkan API version 1.1.
/// - 1.4: Vulkan API version 1.2, or the `khr_spirv_1_4` extension enabled on the device.
/// - 1.5: Vulkan API version 1.2.
///
/// Returns `ShaderSupportError::RequirementsNotMet` naming the missing requirement(s) when the
/// device falls short, and `ShaderSupportError::NotSupportedByVulkan` for any other version.
fn check_spirv_version(device: &Device, mut version: Version) -> Result<(), ShaderSupportError> {
    // Compatibility only depends on major.minor, so clear the patch before matching.
    version.patch = 0;

    match version {
        Version::V1_0 => Ok(()),
        Version::V1_1 | Version::V1_2 | Version::V1_3 => {
            if device.api_version() >= Version::V1_1 {
                Ok(())
            } else {
                Err(ShaderSupportError::RequirementsNotMet(&[
                    "Vulkan API version 1.1",
                ]))
            }
        }
        Version::V1_4 => {
            // Either a new enough core version or the dedicated extension is sufficient.
            let supported = device.api_version() >= Version::V1_2
                || device.enabled_extensions().khr_spirv_1_4;
            if supported {
                Ok(())
            } else {
                Err(ShaderSupportError::RequirementsNotMet(&[
                    "Vulkan API version 1.2",
                    "extension `khr_spirv_1_4`",
                ]))
            }
        }
        Version::V1_5 => {
            if device.api_version() >= Version::V1_2 {
                Ok(())
            } else {
                Err(ShaderSupportError::RequirementsNotMet(&[
                    "Vulkan API version 1.2",
                ]))
            }
        }
        _ => Err(ShaderSupportError::NotSupportedByVulkan),
    }
}