use crate::{shader::SpecializationConstant, Version};
use ahash::HashMap;
use smallvec::{smallvec, SmallVec};
use std::{
borrow::Cow,
error::Error,
fmt::{Display, Error as FmtError, Formatter},
string::FromUtf8Error,
};
// Helpers used by `Spirv::apply_specialization` to rewrite specialization constants.
mod specialization;
// Parser code generated by the build script into `OUT_DIR`.
// NOTE(review): presumably this defines `Instruction`, `Decoration`, `Instruction::parse`
// and `Instruction::result_id`, which are used below but not declared in this file —
// confirm against the build script.
include!(concat!(env!("OUT_DIR"), "/spirv_parse.rs"));
/// A parsed SPIR-V module.
///
/// Instructions are grouped by the module section they belong to, and every result id is
/// indexed together with the names and decorations that target it.
#[derive(Clone, Debug)]
pub struct Spirv {
    /// Version decoded from word 1 of the module header.
    version: Version,
    /// One greater than the largest result id in the module.
    bound: u32,
    /// Every instruction that defines a result id, keyed by that id, with the names and
    /// decorations targeting it. Decoration-group ids are not present.
    ids: HashMap<Id, IdInfo>,
    /// `Capability` instructions.
    instructions_capability: Vec<Instruction>,
    /// `Extension` instructions.
    instructions_extension: Vec<Instruction>,
    /// `ExtInstImport` instructions.
    instructions_ext_inst_import: Vec<Instruction>,
    /// The module's `MemoryModel` instruction.
    instruction_memory_model: Instruction,
    /// `EntryPoint` instructions.
    instructions_entry_point: Vec<Instruction>,
    /// `ExecutionMode` and `ExecutionModeId` instructions.
    instructions_execution_mode: Vec<Instruction>,
    /// `Name` and `MemberName` instructions whose target id exists.
    instructions_name: Vec<Instruction>,
    /// Decoration instructions, with decoration groups already expanded onto their targets.
    instructions_decoration: Vec<Instruction>,
    /// Types, constants, specialization constants, global variables and `Undef`s declared
    /// outside of any function.
    instructions_global: Vec<Instruction>,
    /// Function bodies, keyed by the result id of their `Function` instruction.
    functions: HashMap<Id, FunctionInfo>,
}
impl Spirv {
    /// Parses a SPIR-V module from its 32-bit words.
    ///
    /// The first 5 words are the module header. Only the magic number (word 0, which must be
    /// `0x07230203`) and the version (word 1) are read from it. The remaining words are split
    /// into instructions and sorted into per-section lists. Every instruction with a result
    /// id is also indexed by that id, together with the names and decorations targeting it.
    /// Decoration groups are expanded onto their targets.
    ///
    /// # Errors
    ///
    /// - [`SpirvError::InvalidHeader`] if there are fewer than 5 words or the magic number
    ///   is wrong.
    /// - [`SpirvError::DuplicateId`] if two instructions define the same result id.
    /// - [`SpirvError::ParseError`] if an instruction cannot be parsed.
    ///
    /// # Panics
    ///
    /// NOTE(review): several malformed inputs panic instead of returning an error:
    /// - a module with no `MemoryModel` instruction;
    /// - a `Decorate`/`DecorateId`/`DecorateString`/`MemberDecorate`/`MemberDecorateString`
    ///   or group-decorate target id that is never defined;
    /// - an out-of-range struct member index in a member decoration or `MemberName`;
    /// - `GroupDecorate`/`GroupMemberDecorate` referencing an unknown decoration group;
    /// - a `DecorateString` applied to a decoration group that is later used by
    ///   `GroupDecorate` (hits `unreachable!`);
    /// - a `DecorateId` applied through `GroupMemberDecorate`.
    pub fn new(words: &[u32]) -> Result<Spirv, SpirvError> {
        // Header: only word 0 (magic number) and word 1 (version) are inspected.
        if words.len() < 5 {
            return Err(SpirvError::InvalidHeader);
        }
        if words[0] != 0x07230203 {
            return Err(SpirvError::InvalidHeader);
        }
        // Word 1: major version in bits 16..24, minor in bits 8..16, low byte as `patch`.
        let version = Version {
            major: (words[1] & 0x00ff0000) >> 16,
            minor: (words[1] & 0x0000ff00) >> 8,
            patch: words[1] & 0x000000ff,
        };
        // One greater than the largest result id seen so far.
        let mut bound = 0;
        let mut ids = HashMap::default();
        let mut instructions_capability = Vec::new();
        let mut instructions_extension = Vec::new();
        let mut instructions_ext_inst_import = Vec::new();
        let mut instructions_memory_model = Vec::new();
        let mut instructions_entry_point = Vec::new();
        let mut instructions_execution_mode = Vec::new();
        let mut instructions_name = Vec::new();
        let mut instructions_decoration = Vec::new();
        let mut instructions_global = Vec::new();
        let mut functions = HashMap::default();
        // While between `Function` and `FunctionEnd`, the instruction list of that function.
        let mut current_function: Option<&mut Vec<Instruction>> = None;
        for instruction in iter_instructions(&words[5..]) {
            let instruction = instruction?;
            // Index every instruction that defines a result id. Struct types get one
            // `StructMemberInfo` slot per member so member names/decorations can be attached.
            if let Some(id) = instruction.result_id() {
                bound = bound.max(u32::from(id) + 1);
                let members = if let Instruction::TypeStruct {
                    ref member_types, ..
                } = instruction
                {
                    member_types
                        .iter()
                        .map(|_| StructMemberInfo::default())
                        .collect()
                } else {
                    Vec::new()
                };
                let data = IdInfo {
                    instruction: instruction.clone(),
                    names: Vec::new(),
                    decorations: Vec::new(),
                    members,
                };
                if ids.insert(id, data).is_some() {
                    return Err(SpirvError::DuplicateId { id });
                }
            }
            // Debug line information is not stored in any section.
            if matches!(
                instruction,
                Instruction::Line { .. } | Instruction::NoLine { .. }
            ) {
                continue;
            }
            if current_function.is_some() {
                // Inside a function body: every instruction up to and including
                // `FunctionEnd` belongs to the current function.
                match instruction {
                    Instruction::FunctionEnd { .. } => {
                        current_function.take().unwrap().push(instruction);
                    }
                    _ => current_function.as_mut().unwrap().push(instruction),
                }
            } else {
                // Outside any function: route the instruction to its module section.
                let destination = match instruction {
                    Instruction::Function { result_id, .. } => {
                        // NOTE(review): `current_function` is already `None` in this branch;
                        // the assignment looks like it only exists to help the borrow checker
                        // before `functions` is borrowed again — confirm before removing.
                        current_function = None;
                        // Only entry points / execution modes already seen are attached.
                        let function = functions.entry(result_id).or_insert_with(|| {
                            let entry_point = instructions_entry_point
                                .iter()
                                .find(|instruction| {
                                    matches!(
                                        **instruction,
                                        Instruction::EntryPoint { entry_point, .. }
                                        if entry_point == result_id
                                    )
                                })
                                .cloned();
                            let execution_modes = instructions_execution_mode
                                .iter()
                                .filter(|instruction| {
                                    matches!(
                                        **instruction,
                                        Instruction::ExecutionMode { entry_point, .. }
                                        | Instruction::ExecutionModeId { entry_point, .. }
                                        if entry_point == result_id
                                    )
                                })
                                .cloned()
                                .collect();
                            FunctionInfo {
                                instructions: Vec::new(),
                                entry_point,
                                execution_modes,
                            }
                        });
                        // Following instructions go into this function. The `Function`
                        // instruction itself is pushed below through `destination`.
                        current_function.insert(&mut function.instructions)
                    }
                    Instruction::Capability { .. } => &mut instructions_capability,
                    Instruction::Extension { .. } => &mut instructions_extension,
                    Instruction::ExtInstImport { .. } => &mut instructions_ext_inst_import,
                    Instruction::MemoryModel { .. } => &mut instructions_memory_model,
                    Instruction::EntryPoint { .. } => &mut instructions_entry_point,
                    Instruction::ExecutionMode { .. } | Instruction::ExecutionModeId { .. } => {
                        &mut instructions_execution_mode
                    }
                    Instruction::Name { .. } | Instruction::MemberName { .. } => {
                        &mut instructions_name
                    }
                    Instruction::Decorate { .. }
                    | Instruction::MemberDecorate { .. }
                    | Instruction::DecorationGroup { .. }
                    | Instruction::GroupDecorate { .. }
                    | Instruction::GroupMemberDecorate { .. }
                    | Instruction::DecorateId { .. }
                    | Instruction::DecorateString { .. }
                    | Instruction::MemberDecorateString { .. } => &mut instructions_decoration,
                    // Types, constants, spec constants, global variables and undefs.
                    Instruction::TypeVoid { .. }
                    | Instruction::TypeBool { .. }
                    | Instruction::TypeInt { .. }
                    | Instruction::TypeFloat { .. }
                    | Instruction::TypeVector { .. }
                    | Instruction::TypeMatrix { .. }
                    | Instruction::TypeImage { .. }
                    | Instruction::TypeSampler { .. }
                    | Instruction::TypeSampledImage { .. }
                    | Instruction::TypeArray { .. }
                    | Instruction::TypeRuntimeArray { .. }
                    | Instruction::TypeStruct { .. }
                    | Instruction::TypeOpaque { .. }
                    | Instruction::TypePointer { .. }
                    | Instruction::TypeFunction { .. }
                    | Instruction::TypeEvent { .. }
                    | Instruction::TypeDeviceEvent { .. }
                    | Instruction::TypeReserveId { .. }
                    | Instruction::TypeQueue { .. }
                    | Instruction::TypePipe { .. }
                    | Instruction::TypeForwardPointer { .. }
                    | Instruction::TypePipeStorage { .. }
                    | Instruction::TypeNamedBarrier { .. }
                    | Instruction::TypeRayQueryKHR { .. }
                    | Instruction::TypeAccelerationStructureKHR { .. }
                    | Instruction::TypeCooperativeMatrixNV { .. }
                    | Instruction::TypeVmeImageINTEL { .. }
                    | Instruction::TypeAvcImePayloadINTEL { .. }
                    | Instruction::TypeAvcRefPayloadINTEL { .. }
                    | Instruction::TypeAvcSicPayloadINTEL { .. }
                    | Instruction::TypeAvcMcePayloadINTEL { .. }
                    | Instruction::TypeAvcMceResultINTEL { .. }
                    | Instruction::TypeAvcImeResultINTEL { .. }
                    | Instruction::TypeAvcImeResultSingleReferenceStreamoutINTEL { .. }
                    | Instruction::TypeAvcImeResultDualReferenceStreamoutINTEL { .. }
                    | Instruction::TypeAvcImeSingleReferenceStreaminINTEL { .. }
                    | Instruction::TypeAvcImeDualReferenceStreaminINTEL { .. }
                    | Instruction::TypeAvcRefResultINTEL { .. }
                    | Instruction::TypeAvcSicResultINTEL { .. }
                    | Instruction::ConstantTrue { .. }
                    | Instruction::ConstantFalse { .. }
                    | Instruction::Constant { .. }
                    | Instruction::ConstantComposite { .. }
                    | Instruction::ConstantSampler { .. }
                    | Instruction::ConstantNull { .. }
                    | Instruction::ConstantPipeStorage { .. }
                    | Instruction::SpecConstantTrue { .. }
                    | Instruction::SpecConstantFalse { .. }
                    | Instruction::SpecConstant { .. }
                    | Instruction::SpecConstantComposite { .. }
                    | Instruction::SpecConstantOp { .. }
                    | Instruction::Variable { .. }
                    | Instruction::Undef { .. } => &mut instructions_global,
                    // Anything else outside a function is not stored in a section (it is
                    // still in `ids` if it has a result id).
                    _ => continue,
                };
                destination.push(instruction);
            }
        }
        // NOTE(review): panics if the module has no `MemoryModel` instruction; only the first
        // one is kept.
        let instruction_memory_model = instructions_memory_model.drain(..).next().unwrap();
        // Expand decoration groups. Decorations targeting a `DecorationGroup` are collected
        // per group. `GroupDecorate` / `GroupMemberDecorate` are replaced by one ordinary
        // (member) decoration per target. Every resulting decoration is also recorded on
        // the target's `IdInfo` / `StructMemberInfo`.
        let mut decoration_groups: HashMap<Id, Vec<Instruction>> = HashMap::default();
        let instructions_decoration = instructions_decoration
            .into_iter()
            .flat_map(|instruction| -> SmallVec<[Instruction; 1]> {
                match instruction {
                    Instruction::Decorate { target, .. }
                    | Instruction::DecorateId { target, .. }
                    | Instruction::DecorateString { target, .. } => {
                        let id_info = ids.get_mut(&target).unwrap();
                        if matches!(id_info.instruction(), Instruction::DecorationGroup { .. }) {
                            // Held back until a group-decorate instruction applies it.
                            decoration_groups
                                .entry(target)
                                .or_default()
                                .push(instruction);
                            smallvec![]
                        } else {
                            id_info.decorations.push(instruction.clone());
                            smallvec![instruction]
                        }
                    }
                    Instruction::MemberDecorate {
                        structure_type: target,
                        member,
                        ..
                    }
                    | Instruction::MemberDecorateString {
                        struct_type: target,
                        member,
                        ..
                    } => {
                        ids.get_mut(&target).unwrap().members[member as usize]
                            .decorations
                            .push(instruction.clone());
                        smallvec![instruction]
                    }
                    Instruction::DecorationGroup { result_id } => {
                        // The group id is removed from the id table and the instruction is
                        // dropped from the decoration list.
                        decoration_groups.entry(result_id).or_default();
                        ids.remove(&result_id);
                        smallvec![]
                    }
                    Instruction::GroupDecorate {
                        decoration_group,
                        ref targets,
                    } => {
                        // One direct decoration per (target, group decoration) pair.
                        let decorations = &decoration_groups[&decoration_group];
                        (targets.iter().copied())
                            .flat_map(|target| {
                                decorations
                                    .iter()
                                    .map(move |instruction| (target, instruction))
                            })
                            .map(|(target, instruction)| {
                                let id_info = ids.get_mut(&target).unwrap();
                                match instruction {
                                    Instruction::Decorate { ref decoration, .. } => {
                                        let instruction = Instruction::Decorate {
                                            target,
                                            decoration: decoration.clone(),
                                        };
                                        id_info.decorations.push(instruction.clone());
                                        instruction
                                    }
                                    Instruction::DecorateId { ref decoration, .. } => {
                                        let instruction = Instruction::DecorateId {
                                            target,
                                            decoration: decoration.clone(),
                                        };
                                        id_info.decorations.push(instruction.clone());
                                        instruction
                                    }
                                    // NOTE(review): also reached by a `DecorateString` stored
                                    // in a group, which panics here.
                                    _ => unreachable!(),
                                }
                            })
                            .collect()
                    }
                    Instruction::GroupMemberDecorate {
                        decoration_group,
                        ref targets,
                    } => {
                        // Each target is a (struct type id, member index) pair.
                        let decorations = &decoration_groups[&decoration_group];
                        (targets.iter().copied())
                            .flat_map(|target| {
                                decorations
                                    .iter()
                                    .map(move |instruction| (target, instruction))
                            })
                            .map(|((structure_type, member), instruction)| {
                                let member_info =
                                    &mut ids.get_mut(&structure_type).unwrap().members
                                        [member as usize];
                                match instruction {
                                    Instruction::Decorate { ref decoration, .. } => {
                                        let instruction = Instruction::MemberDecorate {
                                            structure_type,
                                            member,
                                            decoration: decoration.clone(),
                                        };
                                        member_info.decorations.push(instruction.clone());
                                        instruction
                                    }
                                    Instruction::DecorateId { .. } => {
                                        panic!(
                                            "a DecorateId instruction targets a decoration group, \
and that decoration group is applied using a \
GroupMemberDecorate instruction, but there is no \
MemberDecorateId instruction"
                                        );
                                    }
                                    _ => unreachable!(),
                                }
                            })
                            .collect()
                    }
                    _ => smallvec![instruction],
                }
            })
            .collect();
        // Attach names to their targets. Names whose target is not in `ids` (for example a
        // removed decoration group) are dropped from the name list.
        instructions_name.retain(|instruction| match *instruction {
            Instruction::Name { target, .. } => {
                if let Some(id_info) = ids.get_mut(&target) {
                    id_info.names.push(instruction.clone());
                    true
                } else {
                    false
                }
            }
            Instruction::MemberName { ty, member, .. } => {
                if let Some(id_info) = ids.get_mut(&ty) {
                    id_info.members[member as usize]
                        .names
                        .push(instruction.clone());
                    true
                } else {
                    false
                }
            }
            _ => unreachable!(),
        });
        Ok(Spirv {
            version,
            bound,
            ids,
            instructions_capability,
            instructions_extension,
            instructions_ext_inst_import,
            instruction_memory_model,
            instructions_entry_point,
            instructions_execution_mode,
            instructions_name,
            instructions_decoration,
            instructions_global,
            functions,
        })
    }

    /// Returns the SPIR-V version from the module header.
    #[inline]
    pub fn version(&self) -> Version {
        self.version
    }

    /// Returns information about a result id.
    ///
    /// # Panics
    ///
    /// Panics if `id` is not in the module's id table.
    #[inline]
    pub fn id(&self, id: Id) -> &IdInfo {
        &self.ids[&id]
    }

    /// Returns the function whose `Function` instruction has result id `id`.
    ///
    /// # Panics
    ///
    /// Panics if there is no such function.
    #[inline]
    pub fn function(&self, id: Id) -> &FunctionInfo {
        &self.functions[&id]
    }

    /// Iterates over the `Capability` instructions.
    #[inline]
    pub fn iter_capability(&self) -> impl ExactSizeIterator<Item = &Instruction> {
        self.instructions_capability.iter()
    }

    /// Iterates over the `Extension` instructions.
    #[inline]
    pub fn iter_extension(&self) -> impl ExactSizeIterator<Item = &Instruction> {
        self.instructions_extension.iter()
    }

    /// Iterates over the `ExtInstImport` instructions.
    #[inline]
    pub fn iter_ext_inst_import(&self) -> impl ExactSizeIterator<Item = &Instruction> {
        self.instructions_ext_inst_import.iter()
    }

    /// Returns the module's `MemoryModel` instruction (the first one, if there were several).
    #[inline]
    pub fn memory_model(&self) -> &Instruction {
        &self.instruction_memory_model
    }

    /// Iterates over the `EntryPoint` instructions.
    #[inline]
    pub fn iter_entry_point(&self) -> impl ExactSizeIterator<Item = &Instruction> {
        self.instructions_entry_point.iter()
    }

    /// Iterates over the `ExecutionMode` and `ExecutionModeId` instructions.
    #[inline]
    pub fn iter_execution_mode(&self) -> impl ExactSizeIterator<Item = &Instruction> {
        self.instructions_execution_mode.iter()
    }

    /// Iterates over the `Name` / `MemberName` instructions whose target exists.
    #[inline]
    pub fn iter_name(&self) -> impl ExactSizeIterator<Item = &Instruction> {
        self.instructions_name.iter()
    }

    /// Iterates over the decoration instructions, with decoration groups already expanded.
    #[inline]
    pub fn iter_decoration(&self) -> impl ExactSizeIterator<Item = &Instruction> {
        self.instructions_decoration.iter()
    }

    /// Iterates over the types, constants, global variables and undefs declared outside
    /// functions.
    #[inline]
    pub fn iter_global(&self) -> impl ExactSizeIterator<Item = &Instruction> {
        self.instructions_global.iter()
    }

    /// Iterates over all functions, in unspecified (hash map) order.
    #[inline]
    pub fn iter_functions(&self) -> impl ExactSizeIterator<Item = &FunctionInfo> {
        self.functions.values()
    }

    /// Applies specialization constant values to the module.
    ///
    /// The global instructions are rewritten by
    /// `specialization::replace_specialization_instructions` using `specialization_info`.
    /// NOTE(review): the map keys are presumably `SpecId` constant ids — confirm.
    ///
    /// Afterwards the id table is refreshed from the rewritten instructions:
    /// - existing ids get their new defining instruction and lose their `SpecId` decorations;
    /// - ids introduced by the rewrite are added, and `bound` grows to cover them;
    /// - every `SpecId` decoration is removed from the module's decoration list.
    pub fn apply_specialization(
        &mut self,
        specialization_info: &HashMap<u32, SpecializationConstant>,
    ) {
        self.instructions_global = specialization::replace_specialization_instructions(
            specialization_info,
            self.instructions_global.drain(..),
            &self.ids,
            self.bound,
        );
        for instruction in &self.instructions_global {
            if let Some(id) = instruction.result_id() {
                if let Some(id_info) = self.ids.get_mut(&id) {
                    id_info.instruction = instruction.clone();
                    id_info.decorations.retain(|instruction| {
                        !matches!(
                            instruction,
                            Instruction::Decorate {
                                decoration: Decoration::SpecId { .. },
                                ..
                            }
                        )
                    });
                } else {
                    // An id created by the specialization rewrite.
                    self.ids.insert(
                        id,
                        IdInfo {
                            instruction: instruction.clone(),
                            names: Vec::new(),
                            decorations: Vec::new(),
                            members: Vec::new(),
                        },
                    );
                    self.bound = self.bound.max(u32::from(id) + 1);
                }
            }
        }
        self.instructions_decoration.retain(|instruction| {
            !matches!(
                instruction,
                Instruction::Decorate {
                    decoration: Decoration::SpecId { .. },
                    ..
                }
            )
        });
    }
}
/// A SPIR-V result id (`<id>`).
///
/// Displayed in disassembly style as `%` followed by the number, e.g. `%12`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct Id(u32);

impl Id {
    /// Returns the underlying numeric value of this id.
    #[inline]
    pub const fn as_raw(self) -> u32 {
        let Id(raw) = self;
        raw
    }
}

impl From<Id> for u32 {
    #[inline]
    fn from(id: Id) -> u32 {
        id.0
    }
}

impl Display for Id {
    #[inline]
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> {
        let Self(raw) = self;
        // Width/fill flags of the outer formatter are intentionally not forwarded.
        write!(f, "%{}", raw)
    }
}
/// Information about a single result id in a [`Spirv`] module.
#[derive(Clone, Debug)]
pub struct IdInfo {
    /// The instruction that defines this id.
    instruction: Instruction,
    /// `Name` instructions targeting this id.
    names: Vec<Instruction>,
    /// Decoration instructions applied to this id, including those propagated from
    /// decoration groups.
    decorations: Vec<Instruction>,
    /// One entry per struct member; empty unless the defining instruction is `TypeStruct`.
    members: Vec<StructMemberInfo>,
}
impl IdInfo {
    /// Returns the instruction whose result is this id.
    #[inline]
    pub fn instruction(&self) -> &Instruction {
        let Self { instruction, .. } = self;
        instruction
    }

    /// Iterates over the `Name` instructions that target this id.
    #[inline]
    pub fn iter_name(&self) -> impl ExactSizeIterator<Item = &Instruction> {
        let Self { names, .. } = self;
        names.iter()
    }

    /// Iterates over the decorations applied to this id, including ones that came from a
    /// decoration group.
    #[inline]
    pub fn iter_decoration(&self) -> impl ExactSizeIterator<Item = &Instruction> {
        let Self { decorations, .. } = self;
        decorations.iter()
    }

    /// Iterates over per-member information; empty for anything but struct types.
    #[inline]
    pub fn iter_members(&self) -> impl ExactSizeIterator<Item = &StructMemberInfo> {
        let Self { members, .. } = self;
        members.iter()
    }
}
/// Names and decorations attached to one member of a struct type.
#[derive(Clone, Debug, Default)]
pub struct StructMemberInfo {
    /// `MemberName` instructions for this member.
    names: Vec<Instruction>,
    /// `MemberDecorate` / `MemberDecorateString` instructions for this member, including
    /// those produced from `GroupMemberDecorate`.
    decorations: Vec<Instruction>,
}
impl StructMemberInfo {
    /// Iterates over the `MemberName` instructions for this member.
    #[inline]
    pub fn iter_name(&self) -> impl ExactSizeIterator<Item = &Instruction> {
        let Self { names, .. } = self;
        names.iter()
    }

    /// Iterates over the member decorations applied to this member.
    #[inline]
    pub fn iter_decoration(&self) -> impl ExactSizeIterator<Item = &Instruction> {
        let Self { decorations, .. } = self;
        decorations.iter()
    }
}
/// The body of one function in a [`Spirv`] module, plus the entry-point information that
/// refers to it.
#[derive(Clone, Debug)]
pub struct FunctionInfo {
    /// All instructions from `Function` through `FunctionEnd`, excluding `Line` / `NoLine`.
    instructions: Vec<Instruction>,
    /// The `EntryPoint` instruction naming this function, if one appeared before it.
    entry_point: Option<Instruction>,
    /// `ExecutionMode` / `ExecutionModeId` instructions for this function that appeared
    /// before it.
    execution_modes: Vec<Instruction>,
}
impl FunctionInfo {
    /// Iterates over the function's instructions, from `Function` to `FunctionEnd`.
    #[inline]
    pub fn iter_instructions(&self) -> impl ExactSizeIterator<Item = &Instruction> {
        let Self { instructions, .. } = self;
        instructions.iter()
    }

    /// Returns the `EntryPoint` instruction for this function, if it is an entry point.
    #[inline]
    pub fn entry_point(&self) -> Option<&Instruction> {
        let Self { entry_point, .. } = self;
        entry_point.as_ref()
    }

    /// Iterates over the execution modes declared for this function.
    #[inline]
    pub fn iter_execution_mode(&self) -> impl ExactSizeIterator<Item = &Instruction> {
        let Self {
            execution_modes, ..
        } = self;
        execution_modes.iter()
    }
}
fn iter_instructions(
mut words: &[u32],
) -> impl Iterator<Item = Result<Instruction, ParseError>> + '_ {
let mut index = 0;
let next = move || -> Option<Result<Instruction, ParseError>> {
if words.is_empty() {
return None;
}
let word_count = (words[0] >> 16) as usize;
assert!(word_count >= 1);
if words.len() < word_count {
return Some(Err(ParseError {
instruction: index,
word: words.len(),
error: ParseErrors::UnexpectedEOF,
words: words.to_owned(),
}));
}
let mut reader = InstructionReader::new(&words[0..word_count], index);
let instruction = match Instruction::parse(&mut reader) {
Ok(x) => x,
Err(err) => return Some(Err(err)),
};
if !reader.is_empty() {
return Some(Err(reader.map_err(ParseErrors::LeftoverOperands)));
}
words = &words[word_count..];
index += 1;
Some(Ok(instruction))
};
std::iter::from_fn(next)
}
/// Cursor over the words of a single instruction, used while parsing its operands.
#[derive(Debug)]
struct InstructionReader<'a> {
    /// All words of the instruction, including its first (word count / opcode) word.
    words: &'a [u32],
    /// Index into `words` of the next word to consume.
    next_word: usize,
    /// Index of this instruction in the module's instruction stream, used in errors.
    instruction: usize,
}
impl<'a> InstructionReader<'a> {
fn new(words: &'a [u32], instruction: usize) -> Self {
debug_assert!(!words.is_empty());
Self {
words,
next_word: 0,
instruction,
}
}
fn is_empty(&self) -> bool {
self.next_word >= self.words.len()
}
fn map_err(&self, error: ParseErrors) -> ParseError {
ParseError {
instruction: self.instruction,
word: self.next_word - 1, error,
words: self.words.to_owned(),
}
}
fn next_word(&mut self) -> Result<u32, ParseError> {
let word = *self.words.get(self.next_word).ok_or(ParseError {
instruction: self.instruction,
word: self.next_word, error: ParseErrors::MissingOperands,
words: self.words.to_owned(),
})?;
self.next_word += 1;
Ok(word)
}
fn next_string(&mut self) -> Result<String, ParseError> {
let mut bytes = Vec::new();
loop {
let word = self.next_word()?.to_le_bytes();
if let Some(nul) = word.iter().position(|&b| b == 0) {
bytes.extend(&word[0..nul]);
break;
} else {
bytes.extend(word);
}
}
String::from_utf8(bytes).map_err(|err| self.map_err(ParseErrors::FromUtf8Error(err)))
}
fn remainder(&mut self) -> Vec<u32> {
let vec = self.words[self.next_word..].to_owned();
self.next_word = self.words.len();
vec
}
}
/// Error returned when a SPIR-V module cannot be read by [`Spirv::new`].
#[derive(Clone, Debug)]
pub enum SpirvError {
    /// More than one instruction defines the same result id.
    DuplicateId {
        /// The id that was defined more than once.
        id: Id,
    },
    /// There are fewer than 5 words, or the magic number is wrong.
    InvalidHeader,
    /// An individual instruction could not be parsed; see [`Error::source`] for details.
    ParseError(ParseError),
}

impl Display for SpirvError {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> {
        match self {
            Self::DuplicateId { id } => write!(f, "id {} is assigned more than once", id),
            Self::InvalidHeader => f.write_str("the SPIR-V module header is invalid"),
            Self::ParseError(_) => f.write_str("parse error"),
        }
    }
}

impl Error for SpirvError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        if let Self::ParseError(err) = self {
            Some(err)
        } else {
            None
        }
    }
}

impl From<ParseError> for SpirvError {
    fn from(err: ParseError) -> Self {
        SpirvError::ParseError(err)
    }
}
/// Error produced while parsing a single SPIR-V instruction.
#[derive(Clone, Debug)]
pub struct ParseError {
    /// Zero-based index of the instruction in the module's instruction stream (after the
    /// header).
    pub instruction: usize,
    /// Offset into `words` at which the error was detected. It may equal `words.len()` when
    /// the input ran out.
    pub word: usize,
    /// What went wrong.
    pub error: ParseErrors,
    /// The words of the instruction being parsed. For `UnexpectedEOF`, these are all the
    /// words remaining in the stream.
    pub words: Vec<u32>,
}

impl Display for ParseError {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> {
        let Self {
            instruction,
            word,
            error,
            ..
        } = self;
        write!(f, "at instruction {}, word {}: {}", instruction, word, error)
    }
}

impl Error for ParseError {}
/// The specific reason a SPIR-V instruction failed to parse.
#[derive(Clone, Debug)]
pub enum ParseErrors {
    /// A string literal operand was not valid UTF-8.
    FromUtf8Error(FromUtf8Error),
    /// Words remained in the instruction after all of its operands were parsed.
    LeftoverOperands,
    /// The instruction's operands need more words than the instruction contains.
    MissingOperands,
    /// The word stream ended in the middle of an instruction.
    UnexpectedEOF,
    /// An operand value is not a known enumerant of the named enum.
    UnknownEnumerant(&'static str, u32),
    /// The instruction opcode is not recognized.
    UnknownOpcode(u16),
    /// The opcode operand of a spec-constant operation is not recognized.
    UnknownSpecConstantOpcode(u16),
}

impl Display for ParseErrors {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> {
        let message = match self {
            Self::UnknownEnumerant(ty, enumerant) => {
                return write!(f, "invalid enumerant {} for enum {}", enumerant, ty);
            }
            Self::UnknownOpcode(opcode) => {
                return write!(f, "invalid instruction opcode {}", opcode);
            }
            Self::UnknownSpecConstantOpcode(opcode) => {
                return write!(f, "invalid spec constant instruction opcode {}", opcode);
            }
            Self::FromUtf8Error(_) => "invalid UTF-8 in string literal",
            Self::LeftoverOperands => "unparsed operands remaining",
            Self::MissingOperands => {
                "the instruction and its operands require more words than are present in the \
                 instruction"
            }
            Self::UnexpectedEOF => "encountered unexpected end of file",
        };
        f.write_str(message)
    }
}
/// Reinterprets a SPIR-V binary given as bytes as a sequence of 32-bit words.
///
/// Bytes are interpreted as little-endian. On little-endian targets the input is borrowed
/// without copying when `bytemuck::try_cast_slice` accepts it (e.g. it is aligned for
/// `u32`). Otherwise the words are decoded into a newly allocated vector.
///
/// # Errors
///
/// Returns [`SpirvBytesNotMultipleOf4`] if `bytes.len()` is not a multiple of 4.
pub fn bytes_to_words(bytes: &[u8]) -> Result<Cow<'_, [u32]>, SpirvBytesNotMultipleOf4> {
    // Checked first: no cast can succeed on a length that is not a whole number of words.
    if bytes.len() % 4 != 0 {
        return Err(SpirvBytesNotMultipleOf4);
    }

    // Zero-copy fast path when native order already matches the little-endian encoding.
    #[cfg(target_endian = "little")]
    if let Ok(words) = bytemuck::try_cast_slice(bytes) {
        return Ok(Cow::Borrowed(words));
    }

    let decoded = bytes
        .chunks_exact(4)
        .map(|quad| u32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]))
        .collect::<Vec<u32>>();
    Ok(Cow::Owned(decoded))
}
/// Error returned by [`bytes_to_words`] when the input cannot be split into whole 32-bit
/// words.
#[derive(Clone, Copy, Debug, Default)]
pub struct SpirvBytesNotMultipleOf4;

impl Display for SpirvBytesNotMultipleOf4 {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("the length of the provided slice is not a multiple of 4")
    }
}

impl Error for SpirvBytesNotMultipleOf4 {}