spirv-layout 1.0.0

SPIRV reflection utility for deriving Vulkan DescriptorSetLayouts
Documentation
#![allow(unknown_lints)]
#![warn(clippy::all, clippy::pedantic)]
#![allow(
    clippy::unreadable_literal,
    clippy::too_many_lines,
    clippy::must_use_candidate
)]

use std::{collections::HashMap, str::Utf8Error};

use ops::{Dim, Op};
use thiserror::Error;

mod ops;

#[derive(Debug, Clone, Error)]
pub enum Error {
    #[error("{0}")]
    Other(String),
    #[error("invalid header")]
    InvalidHeader,
    #[error("invalid bytes in commmand")]
    InvalidOp,
    #[error("invalid id")]
    InvalidId,
    #[error("invalid utf-8 in string")]
    StringFormat(#[from] Utf8Error),
}

pub type SpirvResult<T> = ::std::result::Result<T, Error>;

/// Stores the reflection info of a single SPIRV module.
///
/// Created with [`Module::from_words`] and queried through the `get_*` accessors.
#[derive(Debug)]
pub struct Module {
    /// Type declarations recognised in the given SPIRV module, keyed by their SPIR-V result id.
    types: HashMap<u32, Type>,
    /// Variables whose pointer type has the `Uniform` or `UniformConstant` storage class.
    /// Each entry's `type_id` is the pointed-to type, not the pointer type.
    uniforms: Vec<Variable>,
    /// Variables whose pointer type has the `PushConstant` storage class.
    /// Each entry's `type_id` is the pointed-to type, not the pointer type.
    push_constants: Vec<PushConstantVariable>,
}

impl Module {
    /// Generates reflection info from a given stream of `words`.
    ///
    /// # Errors
    /// - [`Error::InvalidHeader`] if the SPIRV header is not valid
    /// - [`Error::InvalidOp`] if the binary representation of any instruction in `words` is not valid
    /// - [`Error::InvalidId`] if any type declaration in the SPIRV module reference non-existent IDs
    /// - [`Error::StringFormat`] if any `OpCode` contains a String with invalid UTF-8 characters
    /// - [`Error::Other`] if any other errors occur
    pub fn from_words(mut words: &[u32]) -> SpirvResult<Self> {
        // Check the SPIRV header magic number
        if words.len() < 6 || words[0] != 0x07230203 {
            return Err(Error::InvalidHeader);
        }

        // Skip the rest of the header (Should be parsed in the future)
        words = &words[5..];

        // decode all opcodes
        let mut ops = Vec::new();
        while !words.is_empty() {
            let op = Op::decode(&mut words)?;
            ops.push(op);
        }

        // All OpConstant values are stored in this Vec
        let mut constants = HashMap::new();
        // All type declarations are stored in this Vec
        let mut types = HashMap::new();
        // All variable declarations are stored in this Vec
        let mut vars = HashMap::new();

        Self::collect_types_and_vars(&ops, &mut types, &mut constants, &mut vars)?;
        Self::collect_decorations_and_names(&ops, &mut types, &mut vars);

        // uniforms are all variables that are a pointer with a storage class of Uniform or UniformConstant
        let uniforms = vars
            .iter()
            .filter_map(|(_id, var)| {
                if let Some(Type::Pointer {
                    storage_class: StorageClass::Uniform | StorageClass::UniformConstant,
                    pointed_type_id,
                }) = types.get(&var.type_id)
                {
                    Some(Variable {
                        set: var.set,
                        binding: var.binding,
                        type_id: *pointed_type_id, // for convenience, we store the pointed-to type instead of the pointer, since every uniform is a pointer
                        name: var.name.clone(),
                    })
                } else {
                    None
                }
            })
            .collect();

        let push_constants = vars
            .iter()
            .filter_map(|(_id, var)| {
                if let Some(Type::Pointer {
                    storage_class: StorageClass::PushConstant,
                    pointed_type_id,
                }) = types.get(&var.type_id)
                {
                    Some(PushConstantVariable {
                        type_id: *pointed_type_id,
                        name: var.name.clone(),
                    })
                } else {
                    None
                }
            })
            .collect();

        Ok(Self {
            types,
            uniforms,
            push_constants,
        })
    }

    /// Returns the [`Type`] definition indicated by `type_id`, or `None` if `type_id` is not a type.
    pub fn get_type(&self, type_id: u32) -> Option<&Type> {
        self.types.get(&type_id)
    }

    /// Returns all uniform variables declared in the given SPIR-V module.
    pub fn get_uniforms(&self) -> &[Variable] {
        &self.uniforms
    }

    /// Returns all push-constant variables declared in the given SPIR-V module.
    pub fn get_push_constants(&self) -> &[PushConstantVariable] {
        &self.push_constants
    }

    /// Calculates the size of a primitive type or Struct.
    ///
    /// # Returns
    /// The size of the type indicated by `type_id`
    /// or [`None`] if the size of the given type is not known.
    pub fn get_type_size(&self, type_id: u32) -> Option<u32> {
        if let Some(ty) = self.types.get(&type_id) {
            match ty {
                Type::Int32 | Type::UInt32 | Type::Float32 => Some(4),
                Type::Vec2 => Some(8),
                Type::Vec3 => Some(12),
                Type::Vec4 => Some(16),
                Type::Mat3 => Some(48), // Mat3 works like three Vec3 after another, Vec3 has alignment of Vec4
                Type::Mat4 => Some(64),
                Type::Struct { elements, .. } => {
                    // Since there is no Size Decoration in SPIRV that tells us the size,
                    // we calculate it from the offset of the last member and its size.
                    let last_element = elements.iter().max_by_key(|e| e.offset.unwrap_or(0))?;
                    let offset = last_element.offset?;
                    let size = self.get_type_size(last_element.type_id)?;

                    Some(offset + size)
                }
                _ => None,
            }
        } else {
            None
        }
    }

    /// Parses all the Op*Decoration and Op*Name instructions
    fn collect_decorations_and_names(
        ops: &[Op],
        types: &mut HashMap<u32, Type>,
        vars: &mut HashMap<u32, Variable>,
    ) {
        for op in ops {
            match op {
                Op::OpName { target, name } => {
                    if let Some(target) = vars.get_mut(&target.0) {
                        target.name = Some(name.clone());
                    } else if let Some(Type::Struct { name: n, .. }) = types.get_mut(&target.0) {
                        *n = Some(name.clone());
                    }
                }
                Op::OpMemberName {
                    target,
                    member_index,
                    name,
                } => {
                    if let Some(Type::Struct { elements, .. }) = types.get_mut(&target.0) {
                        if elements.len() > *member_index as usize {
                            elements[*member_index as usize].name = Some(name.clone());
                        }
                    }
                }
                Op::OpDecorate { target, decoration } => match decoration {
                    ops::Decoration::Binding { binding } => {
                        if let Some(target) = vars.get_mut(&target.0) {
                            target.binding = Some(*binding);
                        }
                    }
                    ops::Decoration::DescriptorSet { set } => {
                        if let Some(target) = vars.get_mut(&target.0) {
                            target.set = Some(*set);
                        }
                    }
                    _ => {}
                },
                Op::OpMemberDecorate {
                    target,
                    member_index,
                    decoration: ops::Decoration::Offset { offset },
                } => {
                    if let Some(Type::Struct { elements, .. }) = types.get_mut(&target.0) {
                        if elements.len() > *member_index as usize {
                            elements[*member_index as usize].offset = Some(*offset);
                        }
                    }
                }
                _ => {}
            }
        }
    }

    // Parses all the OpType* and OpVariable instructions
    fn collect_types_and_vars(
        ops: &[Op],
        types: &mut HashMap<u32, Type>,
        constants: &mut HashMap<u32, u32>,
        vars: &mut HashMap<u32, Variable>,
    ) -> SpirvResult<()> {
        for op in ops {
            match op {
                Op::OpTypeVoid { result } => {
                    types.insert(result.0, Type::Void);
                }
                Op::OpTypeBool { result } => {
                    types.insert(result.0, Type::Bool);
                }
                Op::OpTypeInt {
                    result,
                    width,
                    signed,
                } => {
                    if *width != 32 {
                        types.insert(result.0, Type::Unknown);
                    } else if *signed == 0 {
                        types.insert(result.0, Type::UInt32);
                    } else {
                        types.insert(result.0, Type::Int32);
                    }
                }
                Op::OpTypeFloat { result, width } => {
                    if *width == 32 {
                        types.insert(result.0, Type::Float32);
                    } else {
                        types.insert(result.0, Type::Unknown);
                    }
                }
                Op::OpTypeVector {
                    result,
                    component_type,
                    component_count,
                } => {
                    if let Some(t) = types.get(&component_type.0) {
                        if let Type::Float32 = t {
                            match component_count {
                                2 => {
                                    types.insert(result.0, Type::Vec2);
                                }
                                3 => {
                                    types.insert(result.0, Type::Vec3);
                                }
                                4 => {
                                    types.insert(result.0, Type::Vec4);
                                }
                                _ => {
                                    types.insert(result.0, Type::Unknown);
                                }
                            }
                        } else {
                            types.insert(result.0, Type::Unknown);
                        }
                    } else {
                        return Err(Error::InvalidId);
                    }
                }
                Op::OpTypeMatrix {
                    result,
                    column_type,
                    column_count,
                } => {
                    let t = types
                        .get(&column_type.0)
                        .map(|column_type| match column_type {
                            Type::Vec3 if *column_count == 3 => Type::Mat3,
                            Type::Vec4 if *column_count == 4 => Type::Mat4,
                            _ => Type::Unknown,
                        })
                        .unwrap_or(Type::Unknown);
                    types.insert(result.0, t);
                }
                Op::OpTypeImage {
                    result,
                    sampled_type,
                    dim,
                    depth,
                    arrayed: _,
                    ms: _,
                    sampled,
                    format,
                    access: _,
                } => {
                    let t = if let Some(Type::Float32) = types.get(&sampled_type.0) {
                        if let Dim::D2 {} = dim {
                            Type::Image2D {
                                depth: *depth != 0,
                                sampled: *sampled != 0,
                                format: *format,
                            }
                        } else {
                            Type::Unknown
                        }
                    } else {
                        Type::Unknown
                    };
                    types.insert(result.0, t);
                }
                Op::OpTypeSampler { result } => {
                    types.insert(result.0, Type::Sampler);
                }
                Op::OpTypeSampledImage { result, image_type } => {
                    let t = if let Some(Type::Image2D { .. }) = types.get(&image_type.0) {
                        Type::SampledImage {
                            image_type_id: image_type.0,
                        }
                    } else {
                        Type::Unknown
                    };
                    types.insert(result.0, t);
                }
                Op::OpTypeArray {
                    result,
                    element_type,
                    length,
                } => {
                    if let Some(length) = constants.get(&length.0) {
                        types.insert(
                            result.0,
                            Type::Array {
                                element_type_id: element_type.0,
                                length: Some(*length),
                            },
                        );
                    } else {
                        return Err(Error::InvalidId);
                    }
                }
                Op::OpTypeRuntimeArray {
                    result,
                    element_type,
                } => {
                    types.insert(
                        result.0,
                        Type::Array {
                            element_type_id: element_type.0,
                            length: None,
                        },
                    );
                }
                Op::OpTypeStruct {
                    result,
                    element_types,
                } => {
                    types.insert(
                        result.0,
                        Type::Struct {
                            name: None,
                            elements: element_types
                                .iter()
                                .map(|e| StructMember {
                                    name: None,
                                    type_id: e.0,
                                    offset: None,
                                })
                                .collect(),
                        },
                    );
                }
                Op::OpTypePointer {
                    result,
                    storage_class,
                    pointed_type,
                } => {
                    types.insert(
                        result.0,
                        Type::Pointer {
                            storage_class: match storage_class {
                                ops::StorageClass::Unknown => StorageClass::Unknown,
                                ops::StorageClass::UniformConstant {}
                                | ops::StorageClass::Uniform {} => StorageClass::Uniform,
                                ops::StorageClass::PushConstant {} => StorageClass::PushConstant,
                            },
                            pointed_type_id: pointed_type.0,
                        },
                    );
                }
                Op::OpConstant {
                    result_type,
                    result,
                    value,
                } => {
                    if let Some(Type::UInt32) = types.get(&result_type.0) {
                        if value.len() == 1 {
                            constants.insert(result.0, value[0]);
                        }
                    }
                }
                Op::OpVariable {
                    result_type,
                    result,
                    storage_class: _,
                    initializer: _,
                } => {
                    vars.insert(
                        result.0,
                        Variable {
                            set: None,
                            binding: None,
                            type_id: result_type.0,
                            name: None,
                        },
                    );
                }
                _ => {}
            }
        }

        Ok(())
    }
}

/// Represents a type declared in a SPIRV module.
///
/// Types are declared in a hierarchy, with e.g. pointers relying on previously declared types as pointed-to types.
/// Composite variants refer to their component types by SPIR-V result id; resolve those ids
/// with [`Module::get_type`].
#[derive(Debug)]
#[non_exhaustive]
pub enum Type {
    /// An unsupported type (e.g. an integer or float that is not 32 bits wide)
    Unknown,
    /// The Void type
    Void,
    /// A boolean
    Bool,
    /// A signed 32-Bit integer
    Int32,
    /// An unsigned 32-Bit integer
    UInt32,
    /// A 32-Bit float
    Float32,
    /// A 2 component, 32-Bit float vector (GLSL: vec2)
    Vec2,
    /// A 3 component, 32-Bit float vector (GLSL: vec3)
    Vec3,
    /// A 4 component, 32-Bit float vector (GLSL: vec4)
    Vec4,
    /// A 3x3, 32-Bit float matrix made of three `Vec3` columns (GLSL: mat3)
    Mat3,
    /// A 4x4, 32-Bit float matrix made of four `Vec4` columns (GLSL: mat4)
    Mat4,
    /// A 2D image whose sampled type is a 32-Bit float
    Image2D {
        /// true if this image is a depth image
        depth: bool,
        /// true if this image can be sampled from
        sampled: bool,
        /// The raw SPIR-V `Image Format` operand (0 means `Unknown`; storage images may declare a concrete format)
        format: u32,
    },
    /// An opaque sampler object
    Sampler,
    /// A combined image and sampler (Vulkan: CombinedImageSampler descriptor)
    SampledImage {
        /// type id of the [`Type::Image2D`] contained in the SampledImage
        image_type_id: u32,
    },
    /// Either a static array with known length (`length` is [`Some`]) or dynamic array with unknown length (`length` is [`None`])
    Array {
        /// type id of the contained type
        element_type_id: u32,
        /// length of the array, resolved from an `OpConstant` (if known)
        length: Option<u32>,
    },
    /// A struct containing other types
    Struct {
        /// The struct's name, taken from `OpName` (if present)
        name: Option<String>,
        /// members of the struct, in the order they appear in the SPIRV module (not necessarily ascending offsets)
        elements: Vec<StructMember>,
    },
    /// A pointer pointing to another type
    Pointer {
        /// The type of storage this pointer points to
        storage_class: StorageClass,
        /// The type id of the pointed-to type
        pointed_type_id: u32,
    },
}

/// Describes a single member of a [`Type::Struct`] type
#[derive(Debug)]
pub struct StructMember {
    /// The member's name, taken from `OpMemberName` (if present)
    pub name: Option<String>,
    /// The type id of the member's [`Type`]
    pub type_id: u32,
    /// Byte offset of the member within the struct, taken from its `Offset` member decoration (if present)
    pub offset: Option<u32>,
}

/// Describes what type of storage a pointer points to
#[derive(Debug)]
#[non_exhaustive]
pub enum StorageClass {
    /// Any storage class not modelled by this crate
    Unknown,
    /// The pointer is a uniform variable (Uniform blocks)
    Uniform,
    /// The pointer is a uniform variable (Images, etc.)
    UniformConstant,
    /// The pointer is a push constant
    PushConstant,
}

/// Describes a variable declared in a SPIRV module
#[derive(Debug, Clone)]
pub struct Variable {
    /// Which DescriptorSet the variable is contained in, from its `DescriptorSet` decoration (if known)
    pub set: Option<u32>,
    /// Which DescriptorSet binding the variable is contained in, from its `Binding` decoration (if known)
    pub binding: Option<u32>,
    /// The type id of the variable's [`Type`].
    /// For variables returned by [`Module::get_uniforms`] this is the pointed-to type, not the pointer type.
    pub type_id: u32,
    /// The variable's name, taken from `OpName` (if known)
    pub name: Option<String>,
}

/// Describes a push-constant variable declared in a SPIRV module
#[derive(Debug, Clone)]
pub struct PushConstantVariable {
    /// The type id of the pointed-to [`Type`] (not the pointer type of the variable itself)
    pub type_id: u32,
    /// The variable's name, taken from `OpName` (if known)
    pub name: Option<String>,
}