use crate::error;
use crate::Compiler;
use spirv::StorageClass;
use spirv_cross_sys::{BaseType, SpvId, VariableId};
use crate::error::{SpirvCrossError, ToContextError};
use crate::handle::Handle;
use crate::handle::{ConstantId, TypeId};
use crate::sealed::Sealed;
use crate::string::CompilerStr;
use spirv_cross_sys as sys;
/// The category of a scalar value: signed integer, unsigned integer,
/// floating point, or boolean.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(u8)]
pub enum ScalarKind {
/// Signed integer.
Int = 0,
/// Unsigned integer.
Uint = 1,
/// Floating point.
Float = 2,
/// Boolean.
Bool = 3,
}
/// Width of a scalar value, where each discriminant is the width in bits.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(u8)]
pub enum BitWidth {
    /// A single bit (booleans).
    Bit = 1,
    /// 8 bits.
    Byte = 8,
    /// 16 bits.
    HalfWord = 16,
    /// 32 bits.
    Word = 32,
    /// 64 bits.
    DoubleWord = 64,
}

impl BitWidth {
    /// Number of bytes occupied by a value of this width.
    ///
    /// Widths narrower than a byte (i.e. [`BitWidth::Bit`]) still occupy one
    /// whole byte.
    pub const fn byte_size(&self) -> usize {
        // The discriminant of every variant is its width in bits.
        let bits = *self as usize;
        if bits < 8 {
            1
        } else {
            bits / 8
        }
    }
}
/// A scalar type: its kind together with its bit width.
///
/// Built from a SPIRV-Cross base type through `TryFrom<BaseType>`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Scalar {
/// Whether the scalar is a signed/unsigned integer, a float, or a boolean.
pub kind: ScalarKind,
/// The width of the scalar.
pub size: BitWidth,
}
impl TryFrom<BaseType> for Scalar {
    type Error = SpirvCrossError;

    /// Converts a numeric SPIRV-Cross base type into a [`Scalar`].
    ///
    /// # Errors
    /// Returns [`SpirvCrossError::InvalidArgument`] if `value` is not a boolean,
    /// integer, or floating-point base type.
    fn try_from(value: BaseType) -> Result<Self, Self::Error> {
        let (kind, size) = match value {
            BaseType::Boolean => (ScalarKind::Bool, BitWidth::Bit),
            BaseType::Int8 => (ScalarKind::Int, BitWidth::Byte),
            BaseType::Int16 => (ScalarKind::Int, BitWidth::HalfWord),
            BaseType::Int32 => (ScalarKind::Int, BitWidth::Word),
            BaseType::Int64 => (ScalarKind::Int, BitWidth::DoubleWord),
            BaseType::Uint8 => (ScalarKind::Uint, BitWidth::Byte),
            BaseType::Uint16 => (ScalarKind::Uint, BitWidth::HalfWord),
            BaseType::Uint32 => (ScalarKind::Uint, BitWidth::Word),
            BaseType::Uint64 => (ScalarKind::Uint, BitWidth::DoubleWord),
            BaseType::Fp16 => (ScalarKind::Float, BitWidth::HalfWord),
            BaseType::Fp32 => (ScalarKind::Float, BitWidth::Word),
            BaseType::Fp64 => (ScalarKind::Float, BitWidth::DoubleWord),
            _ => {
                return Err(SpirvCrossError::InvalidArgument(String::from(
                    "Invalid base type used to instantiate a scalar",
                )))
            }
        };
        Ok(Self { kind, size })
    }
}
/// A reflected SPIR-V type.
#[derive(Debug, Clone)]
pub struct Type<'a> {
/// Handle of the type this describes.
pub id: Handle<TypeId>,
/// The type's debug name, or `None` when SPIRV-Cross reports an empty name.
pub name: Option<CompilerStr<'a>>,
/// The shape of the type (scalar, vector, struct, array, ...).
pub inner: TypeInner<'a>,
/// A hint for the size of the type; see [`TypeSizeHint`].
pub size_hint: TypeSizeHint,
}
/// A single member of a reflected struct.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct StructMember<'a> {
/// The type of this member.
pub id: Handle<TypeId>,
/// The struct type this member belongs to.
pub struct_type: Handle<TypeId>,
/// The member's debug name, or `None` when SPIRV-Cross reports an empty name.
pub name: Option<CompilerStr<'a>>,
/// Zero-based position of the member within its struct.
pub index: usize,
/// Offset of the member, as reported by `spvc_compiler_type_struct_member_offset`.
pub offset: u32,
/// Declared size of the member, as reported by SPIRV-Cross.
pub size: usize,
/// The member's matrix stride, or `None` if SPIRV-Cross could not report one.
pub matrix_stride: Option<u32>,
/// The member's array stride, or `None` if SPIRV-Cross could not report one.
pub array_stride: Option<u32>,
}
/// A reflected struct type.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct StructType<'a> {
/// Handle of the struct type.
pub id: Handle<TypeId>,
/// Declared size of the struct, as reported by SPIRV-Cross.
pub size: usize,
/// The struct's members, in member-index order.
pub members: Vec<StructMember<'a>>,
}
/// The length of one array dimension.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum ArrayDimension {
/// A length given as a literal value. This module treats a literal `0`
/// as a runtime-sized dimension.
Literal(u32),
/// A length held in a constant (for example a specialization constant),
/// whose value must be looked up to know the length.
Constant(Handle<ConstantId>),
}
/// How an image type is used.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum ImageClass {
/// A non-storage image whose SPIRV-Cross base type is `SampledImage`
/// (a combined image-sampler).
Sampled {
/// Whether the image is a depth image.
depth: bool,
/// Whether the image is multisampled.
multisampled: bool,
/// Whether the image is arrayed.
arrayed: bool,
},
/// An image that is neither a storage image nor a combined image-sampler.
Texture {
/// Whether the image is multisampled.
multisampled: bool,
/// Whether the image is arrayed.
arrayed: bool,
},
/// A storage image.
Storage {
/// The image's declared storage format.
format: spirv::ImageFormat,
},
}
/// A reflected image type.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ImageType {
/// Handle of the image type.
pub id: Handle<TypeId>,
/// The image's sampled type, as reported by `spvc_type_get_image_sampled_type`.
pub sampled_type: Handle<TypeId>,
/// The image dimensionality.
pub dimension: spirv::Dim,
/// Whether the image is sampled, a plain texture, or a storage image.
pub class: ImageClass,
}
/// The structural shape of a reflected type.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum TypeInner<'a> {
/// SPIRV-Cross reported an unknown base type.
Unknown,
/// The `void` type.
Void,
/// A pointer. Atomic counter types are also reported as this variant.
Pointer {
/// The base type id reported by `spvc_type_get_base_type_id`.
base: Handle<TypeId>,
/// Storage class of the pointer.
storage: StorageClass,
/// Whether the pointer is a forward pointer.
forward: bool,
},
/// A struct.
Struct(StructType<'a>),
/// A single scalar.
Scalar(Scalar),
/// A vector of `width` scalars.
Vector {
width: u32,
scalar: Scalar,
},
/// A matrix with `columns` columns and `rows` rows of `scalar`.
Matrix {
columns: u32,
rows: u32,
scalar: Scalar,
},
/// A (possibly multi-dimensional) array.
Array {
/// The element type id reported by `spvc_type_get_base_type_id`.
base: Handle<TypeId>,
/// Storage class reported for the array type.
storage: StorageClass,
/// Length of each dimension, in the order SPIRV-Cross reports them.
dimensions: Vec<ArrayDimension>,
/// The type's `ArrayStride` decoration, if present.
stride: Option<u32>,
},
/// An image or combined image-sampler.
Image(ImageType),
/// An acceleration structure.
AccelerationStructure,
/// A sampler.
Sampler,
}
/// Size hint for a matrix whose real size depends on its matrix stride and
/// on whether it is laid out row-major or column-major.
#[derive(Debug, Clone)]
pub struct MatrixStrideHole {
    columns: usize,
    rows: usize,
    declared: usize,
}

impl Sealed for MatrixStrideHole {}

impl ResolveSize for MatrixStrideHole {
    /// `(matrix_stride, is_row_major)`.
    type Hole = (usize, bool);

    /// The declared size recorded when the hint was built.
    fn declared(&self) -> usize {
        self.declared
    }

    /// The stride is applied once per row for row-major matrices and once
    /// per column for column-major matrices.
    fn resolve(&self, (stride, is_row_major): Self::Hole) -> usize {
        let strided_vectors = match is_row_major {
            true => self.rows,
            false => self.columns,
        };
        stride * strided_vectors
    }
}
/// Size hint for a struct whose last member is a runtime-sized array.
///
/// `declared` is the struct's declared size as reported by SPIRV-Cross, and
/// `stride` is the array stride of the trailing runtime array.
#[derive(Debug, Clone)]
pub struct ArraySizeHole {
    stride: usize,
    declared: usize,
}

/// Size hint for an array without an `ArrayStride` decoration whose element
/// type does not have a static size.
#[derive(Debug, Clone)]
pub struct UnknownStrideHole {
    hint: Box<TypeSizeHint>,
    count: usize,
}

impl Sealed for UnknownStrideHole {}

impl ResolveSize for UnknownStrideHole {
    /// Computes the real size of one element from the element's size hint.
    type Hole = Box<dyn FnOnce(&TypeSizeHint) -> usize>;

    /// `count` times the element's declared size.
    fn declared(&self) -> usize {
        self.count * self.hint.declared()
    }

    /// `count` times the element size returned by `hole`.
    fn resolve(&self, hole: Self::Hole) -> usize {
        self.count * hole(&self.hint)
    }
}

// A plain `usize` is a size that is always fully known.
impl ResolveSize for usize {
    // `Infallible` cannot be constructed, so `resolve` can never be called.
    type Hole = core::convert::Infallible;

    fn declared(&self) -> usize {
        *self
    }

    fn resolve(&self, _hole: Self::Hole) -> usize {
        self.declared()
    }
}

impl ResolveSize for ArraySizeHole {
    /// Number of elements in the runtime array.
    type Hole = usize;

    /// The declared struct size, without any runtime array elements.
    fn declared(&self) -> usize {
        self.declared
    }

    /// Size of the whole struct when its runtime array holds `count` elements.
    fn resolve(&self, count: Self::Hole) -> usize {
        // The declared size already covers the members in front of the runtime
        // array, so the elements are added on top of it rather than replacing it.
        self.declared + count * self.stride
    }
}

impl Sealed for ArraySizeHole {}
impl Sealed for usize {}
/// A size hint for a reflected type.
///
/// `Static` sizes are fully known. Every other variant carries a "hole" that
/// can be filled through [`ResolveSize::resolve`] to get the real size; until
/// then [`TypeSizeHint::declared`] reports the size as declared.
#[derive(Debug, Clone)]
pub enum TypeSizeHint {
    /// The size is fully known.
    Static(usize),
    /// A struct ending in a runtime-sized array; resolved with an element count.
    RuntimeArray(ArraySizeHole),
    /// A matrix; resolved with its matrix stride and row-major flag.
    Matrix(MatrixStrideHole),
    /// An array without `ArrayStride` whose element size is not static.
    UnknownArrayStride(UnknownStrideHole),
}

impl TypeSizeHint {
    /// The size of the type as declared, without filling any hole.
    pub fn declared(&self) -> usize {
        match self {
            Self::Static(size) => *size,
            Self::RuntimeArray(hole) => hole.declared(),
            Self::Matrix(hole) => hole.declared(),
            Self::UnknownArrayStride(hole) => hole.declared(),
        }
    }

    /// Whether the size is fully known, with no hole left to resolve.
    pub fn is_static(&self) -> bool {
        matches!(self, Self::Static(_))
    }
}
/// A size whose final value may depend on information that is not available
/// at reflection time (the "hole").
///
/// The `Sealed` supertrait presumably restricts implementations to this crate.
pub trait ResolveSize: Sealed {
/// The information needed to compute the final size.
type Hole;
/// The size as declared, before any hole is filled.
fn declared(&self) -> usize;
/// Computes the size using the supplied hole value.
fn resolve(&self, hole: Self::Hole) -> usize;
}
impl<T> Compiler<T> {
fn process_struct(&self, struct_ty_id: TypeId) -> error::Result<StructType<'_>> {
unsafe {
let ty = sys::spvc_compiler_get_type_handle(self.ptr.as_ptr(), struct_ty_id);
let base_ty = sys::spvc_type_get_basetype(ty);
assert_eq!(base_ty, BaseType::Struct);
let mut struct_size = 0;
sys::spvc_compiler_get_declared_struct_size(self.ptr.as_ptr(), ty, &mut struct_size)
.ok(self)?;
let member_type_len = sys::spvc_type_get_num_member_types(ty);
let mut members = Vec::with_capacity(member_type_len as usize);
for i in 0..member_type_len {
let id = sys::spvc_type_get_member_type(ty, i);
let name = CompilerStr::from_ptr(
sys::spvc_compiler_get_member_name(self.ptr.as_ptr(), struct_ty_id, i),
self.ctx.drop_guard(),
);
let name = if name.as_ref().is_empty() {
None
} else {
Some(name)
};
let mut size = 0;
sys::spvc_compiler_get_declared_struct_member_size(
self.ptr.as_ptr(),
ty,
i,
&mut size,
)
.ok(self)?;
let mut offset = 0;
sys::spvc_compiler_type_struct_member_offset(self.ptr.as_ptr(), ty, i, &mut offset)
.ok(self)?;
let mut matrix_stride = 0;
let matrix_stride = sys::spvc_compiler_type_struct_member_matrix_stride(
self.ptr.as_ptr(),
ty,
i,
&mut matrix_stride,
)
.ok(self)
.ok()
.map(|_| matrix_stride);
let mut array_stride = 0;
let array_stride = sys::spvc_compiler_type_struct_member_array_stride(
self.ptr.as_ptr(),
ty,
i,
&mut array_stride,
)
.ok(self)
.ok()
.map(|_| array_stride);
members.push(StructMember {
name,
id: self.create_handle(id),
struct_type: self.create_handle(struct_ty_id),
offset,
size,
index: i as usize,
matrix_stride,
array_stride,
})
}
Ok(StructType {
id: self.create_handle(struct_ty_id),
size: struct_size,
members,
})
}
}
fn process_vector(&self, id: TypeId, vec_width: u32) -> error::Result<TypeInner<'_>> {
unsafe {
let ty = sys::spvc_compiler_get_type_handle(self.ptr.as_ptr(), id);
let base_ty = sys::spvc_type_get_basetype(ty);
Ok(TypeInner::Vector {
width: vec_width,
scalar: base_ty.try_into()?,
})
}
}
fn process_matrix(&self, id: TypeId, rows: u32, columns: u32) -> error::Result<TypeInner<'_>> {
unsafe {
let ty = sys::spvc_compiler_get_type_handle(self.ptr.as_ptr(), id);
let base_ty = sys::spvc_type_get_basetype(ty);
Ok(TypeInner::Matrix {
rows,
columns,
scalar: base_ty.try_into()?,
})
}
}
fn process_array<'a>(
&self,
id: TypeId,
name: Option<CompilerStr<'a>>,
) -> error::Result<Type<'a>> {
unsafe {
let ty = sys::spvc_compiler_get_type_handle(self.ptr.as_ptr(), id);
let base_type_id = sys::spvc_type_get_base_type_id(ty);
let array_dim_len = sys::spvc_type_get_num_array_dimensions(ty);
let mut array_dims = Vec::with_capacity(array_dim_len as usize);
for i in 0..array_dim_len {
array_dims.push(sys::spvc_type_get_array_dimension(ty, i))
}
let mut array_is_literal = Vec::with_capacity(array_dim_len as usize);
for i in 0..array_dim_len {
array_is_literal.push(sys::spvc_type_array_dimension_is_literal(ty, i))
}
let storage_class = sys::spvc_type_get_storage_class(ty);
let Some(storage_class) = spirv::StorageClass::from_u32(storage_class.0 as u32) else {
return Err(SpirvCrossError::InvalidSpirv(format!(
"Unknown StorageClass found: {}",
storage_class.0
)));
};
let array_dims = array_dims
.into_iter()
.enumerate()
.map(|(index, dim)| {
if array_is_literal[index] {
ArrayDimension::Literal(dim.0)
} else {
ArrayDimension::Constant(self.create_handle(ConstantId(dim)))
}
})
.collect();
let id = self.create_handle(id);
let stride = self
.decoration(id, spirv::Decoration::ArrayStride)?
.and_then(|s| s.as_literal());
let inner = TypeInner::Array {
base: self.create_handle(base_type_id),
storage: storage_class,
dimensions: array_dims,
stride,
};
let size_hint = self.type_size_hint(&inner)?;
Ok(Type {
name,
id,
inner,
size_hint,
})
}
}
fn process_image(&self, id: TypeId) -> error::Result<ImageType> {
unsafe {
let ty = sys::spvc_compiler_get_type_handle(self.ptr.as_ptr(), id);
let base_ty = sys::spvc_type_get_basetype(ty);
let sampled_id = sys::spvc_type_get_image_sampled_type(ty);
let dimension = sys::spvc_type_get_image_dimension(ty);
let depth = sys::spvc_type_get_image_is_depth(ty);
let arrayed = sys::spvc_type_get_image_arrayed(ty);
let storage = sys::spvc_type_get_image_is_storage(ty);
let multisampled = sys::spvc_type_get_image_multisampled(ty);
let format = sys::spvc_type_get_image_storage_format(ty);
let Some(format) = spirv::ImageFormat::from_u32(format.0 as u32) else {
return Err(SpirvCrossError::InvalidSpirv(format!(
"Unknown image format found: {}",
format.0
)));
};
let Some(dimension) = spirv::Dim::from_u32(dimension.0 as u32) else {
return Err(SpirvCrossError::InvalidSpirv(format!(
"Unknown image dimension found: {}",
dimension.0
)));
};
let class = if storage {
ImageClass::Storage { format }
} else if base_ty == BaseType::SampledImage {
ImageClass::Sampled {
depth,
multisampled,
arrayed,
}
} else {
ImageClass::Texture {
multisampled,
arrayed,
}
};
Ok(ImageType {
id: self.create_handle(id),
sampled_type: self.create_handle(sampled_id),
dimension,
class,
})
}
}
pub fn type_description(&self, id: Handle<TypeId>) -> error::Result<Type<'_>> {
let id = self.yield_id(id)?;
unsafe {
let ty = sys::spvc_compiler_get_type_handle(self.ptr.as_ptr(), id);
let base_type_id = sys::spvc_type_get_base_type_id(ty);
let base_ty = sys::spvc_type_get_basetype(ty);
let name = CompilerStr::from_ptr(
sys::spvc_compiler_get_name(self.ptr.as_ptr(), id.0),
self.ctx.drop_guard(),
);
let name = if name.as_ref().is_empty() {
None
} else {
Some(name)
};
let array_dim_len = sys::spvc_type_get_num_array_dimensions(ty);
if array_dim_len != 0 {
return self.process_array(id, name);
}
if sys::spvc_rs_type_is_pointer(ty) {
let storage_class = sys::spvc_type_get_storage_class(ty);
let Some(storage_class) = spirv::StorageClass::from_u32(storage_class.0 as u32)
else {
return Err(SpirvCrossError::InvalidSpirv(format!(
"Unknown StorageClass found: {}",
storage_class.0
)));
};
let forward = sys::spvc_rs_type_is_forward_pointer(ty);
let inner = TypeInner::Pointer {
base: self.create_handle(base_type_id),
storage: storage_class,
forward,
};
let size_hint = self.type_size_hint(&inner)?;
return Ok(Type {
name,
id: self.create_handle(id),
inner,
size_hint,
});
}
let vec_size = sys::spvc_type_get_vector_size(ty);
let columns = sys::spvc_type_get_columns(ty);
let mut maybe_non_scalar = None;
if vec_size > 1 && columns == 1 {
maybe_non_scalar = Some(self.process_vector(id, vec_size)?);
}
if vec_size > 1 && columns > 1 {
maybe_non_scalar = Some(self.process_matrix(id, vec_size, columns)?);
}
let inner = match base_ty {
BaseType::Struct => {
let ty = self.process_struct(id)?;
TypeInner::Struct(ty)
}
BaseType::Image | BaseType::SampledImage => {
TypeInner::Image(self.process_image(id)?)
}
BaseType::Sampler => TypeInner::Sampler,
BaseType::Boolean
| BaseType::Int8
| BaseType::Uint8
| BaseType::Int16
| BaseType::Uint16
| BaseType::Int32
| BaseType::Uint32
| BaseType::Int64
| BaseType::Uint64
| BaseType::Fp16
| BaseType::Fp32
| BaseType::Fp64 => {
if let Some(prep) = maybe_non_scalar {
prep
} else {
TypeInner::Scalar(base_ty.try_into()?)
}
}
BaseType::Unknown => TypeInner::Unknown,
BaseType::Void => TypeInner::Void,
BaseType::AtomicCounter => {
let storage_class = sys::spvc_type_get_storage_class(ty);
let Some(storage_class) = spirv::StorageClass::from_u32(storage_class.0 as u32)
else {
return Err(SpirvCrossError::InvalidSpirv(format!(
"Unknown StorageClass found: {}",
storage_class.0
)));
};
let forward = sys::spvc_rs_type_is_forward_pointer(ty);
TypeInner::Pointer {
base: self.create_handle(base_type_id),
storage: storage_class,
forward,
}
}
BaseType::AccelerationStructure => TypeInner::AccelerationStructure,
};
let size_hint = self.type_size_hint(&inner)?;
let ty = Type {
name,
id: self.create_handle(id),
inner,
size_hint,
};
Ok(ty)
}
}
fn type_size_hint(&self, ty: &TypeInner) -> error::Result<TypeSizeHint> {
Ok(match ty {
TypeInner::Pointer { .. } => TypeSizeHint::Static(BitWidth::Word.byte_size()),
TypeInner::Struct(s) => {
if let Some(stride) = self.struct_has_runtime_array(s)? {
TypeSizeHint::RuntimeArray(ArraySizeHole {
stride: stride as usize,
declared: s.size,
})
} else {
TypeSizeHint::Static(s.size)
}
}
TypeInner::Scalar(s) => TypeSizeHint::Static(s.size.byte_size()),
TypeInner::Vector { width, scalar } => {
TypeSizeHint::Static((*width as usize) * scalar.size.byte_size())
}
TypeInner::Matrix {
columns,
rows,
scalar,
} => {
let rows_aligned = ((rows + 3) & !0x3) as usize;
let scalar_width = scalar.size.byte_size();
let columns = *columns as usize;
let declared = rows_aligned * scalar_width * columns;
TypeSizeHint::Matrix(MatrixStrideHole {
columns,
rows: *rows as usize,
declared,
})
}
TypeInner::Array {
dimensions,
stride,
base,
..
} => {
let mut count = 1usize;
for dim in dimensions.iter() {
match dim {
ArrayDimension::Literal(a) => count *= *a as usize,
ArrayDimension::Constant(c) => {
let value = self.specialization_constant_value::<u32>(*c)?;
count *= value as usize;
} }
}
if let Some(stride) = stride {
TypeSizeHint::Static(count * (*stride as usize))
} else {
let base_stride = self.type_description(*base)?.size_hint;
if base_stride.is_static() {
TypeSizeHint::Static(count * base_stride.declared())
} else {
TypeSizeHint::UnknownArrayStride(UnknownStrideHole {
hint: Box::new(base_stride),
count,
})
}
}
}
TypeInner::Image(_)
| TypeInner::AccelerationStructure
| TypeInner::Sampler
| TypeInner::Unknown
| TypeInner::Void => TypeSizeHint::Static(0),
})
}
fn struct_has_runtime_array(&self, struct_type: &StructType) -> error::Result<Option<u32>> {
if let Some(last) = struct_type.members.last() {
let Some(array_stride) = last.array_stride else {
return Ok(None);
};
let inner = self.type_description(last.id)?.inner;
if let TypeInner::Array { dimensions, .. } = inner {
if let Some(ArrayDimension::Literal(0)) = dimensions.first() {
return Ok(Some(array_stride));
}
}
}
Ok(None)
}
pub fn variable_type(
&self,
variable: impl Into<Handle<VariableId>>,
) -> error::Result<Handle<TypeId>> {
let variable = variable.into();
let variable_id = self.yield_id(variable)?;
unsafe {
let mut type_id = TypeId(SpvId(0));
sys::spvc_rs_compiler_variable_get_type(self.ptr.as_ptr(), variable_id, &mut type_id)
.ok(self)?;
Ok(self.create_handle(type_id))
}
}
}
#[cfg(test)]
mod test {
    use crate::error::SpirvCrossError;
    use crate::Compiler;
    use crate::{targets, Module};

    static BASIC_SPV: &[u8] = include_bytes!("../../basic.spv");

    /// Describes the first uniform buffer's type, then drops the compiler before
    /// printing the reflected resources.
    #[test]
    pub fn get_stage_outputs() -> Result<(), SpirvCrossError> {
        let vec = Vec::from(BASIC_SPV);
        let words = Module::from_words(bytemuck::cast_slice(&vec));
        let compiler: Compiler<targets::None> = Compiler::new(words)?;
        let resources = compiler.shader_resources()?.all_resources()?;
        let ty = compiler.type_description(resources.uniform_buffers[0].base_type_id)?;
        eprintln!("{ty:?}");
        // Printing after `drop` presumably checks that `resources` stays usable
        // without a live compiler.
        drop(compiler);
        eprintln!("{resources:?}");
        Ok(())
    }

    /// Renaming a struct member is visible through `member_name`.
    #[test]
    pub fn set_member_name_validity_test() -> Result<(), SpirvCrossError> {
        let vec = Vec::from(BASIC_SPV);
        let words = Module::from_words(bytemuck::cast_slice(&vec));
        let mut compiler: Compiler<targets::None> = Compiler::new(words)?;
        let resources = compiler.shader_resources()?.all_resources()?;
        let ty = compiler.type_description(resources.uniform_buffers[0].base_type_id)?;
        let id = ty.id;

        let name = compiler.member_name(id, 0)?;
        assert_eq!(Some("MVP"), name.as_deref());

        compiler.set_member_name(ty.id, 0, "NotMVP")?;
        let name = compiler.member_name(id, 0)?;
        assert_eq!(Some("NotMVP"), name.as_deref());

        // Reflection must still succeed after the rename.
        let resources = compiler.shader_resources()?.all_resources()?;
        let _ty = compiler.type_description(resources.uniform_buffers[0].base_type_id)?;
        Ok(())
    }

    /// `variable_type` agrees with the type id reported by resource reflection.
    #[test]
    pub fn get_variable_type_test() -> Result<(), SpirvCrossError> {
        let vec = Vec::from(BASIC_SPV);
        let words = Module::from_words(bytemuck::cast_slice(&vec));
        let compiler: Compiler<targets::None> = Compiler::new(words)?;
        let resources = compiler.shader_resources()?.all_resources()?;
        let variable = resources.uniform_buffers[0].id;
        assert_eq!(
            resources.uniform_buffers[0].type_id.id(),
            compiler.variable_type(variable)?.id()
        );
        eprintln!("{:?}", resources);
        Ok(())
    }
}