use alloc::{
format,
string::{String, ToString},
vec::Vec,
};
use core::fmt::{Error as FmtError, Write};
use crate::{arena::Handle, ir, proc::index, valid::ModuleInfo};
mod keywords;
pub mod sampler;
mod writer;
pub use writer::Writer;
/// Index of a Metal argument slot, as written in the `buffer(N)`, `texture(N)`
/// and `sampler(N)` attributes produced by `ResolvedBinding::try_fmt`.
pub type Slot = u8;
/// Index into [`Options::inline_samplers`] (see `ResolvedBinding::as_inline_sampler`).
pub type InlineSamplerIndex = u8;
/// How a sampler resource is provided to the generated MSL.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub enum BindSamplerTarget {
    /// Sampler bound at an argument slot; rendered as `sampler(N)`.
    Resource(Slot),
    /// Sampler taken from [`Options::inline_samplers`] at this index instead of
    /// being bound at a slot.
    Inline(InlineSamplerIndex),
}
/// Where a single resource binding is placed in the generated MSL.
///
/// When written as an attribute, `buffer` takes precedence over `texture`, which
/// takes precedence over a [`BindSamplerTarget::Resource`] sampler. A target with
/// none of those set cannot be written and yields [`Error::UnimplementedBindTarget`].
/// With either serde feature enabled, omitted fields take their `Default` values.
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
pub struct BindTarget {
    /// Buffer slot, rendered as `buffer(N)`.
    pub buffer: Option<Slot>,
    /// Texture slot, rendered as `texture(N)`.
    pub texture: Option<Slot>,
    /// Sampler placement; only the `Resource` form is rendered as an attribute.
    pub sampler: Option<BindSamplerTarget>,
    // NOTE(review): not read in this file; presumably marks the resource as
    // writable for the writer — confirm in writer.rs.
    pub mutable: bool,
}
/// One [`BindingMap`] entry in list form: a resource binding paired with its target.
///
/// `deserialize_binding_map` reads a list of these and builds the map from it.
#[cfg(any(feature = "serialize", feature = "deserialize"))]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
struct BindingMapSerialization {
    resource_binding: crate::ResourceBinding,
    bind_target: BindTarget,
}
/// Deserializes a [`BindingMap`] from a sequence of
/// `{ resource_binding, bind_target }` records.
///
/// Records are inserted in order, so when a binding appears more than once the
/// last record wins.
#[cfg(feature = "deserialize")]
fn deserialize_binding_map<'de, D>(deserializer: D) -> Result<BindingMap, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use serde::Deserialize;

    let entries = Vec::<BindingMapSerialization>::deserialize(deserializer)?;
    let mut bindings = BindingMap::default();
    // `extend` inserts one pair at a time, exactly like a manual insert loop.
    bindings.extend(
        entries
            .into_iter()
            .map(|entry| (entry.resource_binding, entry.bind_target)),
    );
    Ok(bindings)
}
/// Bind targets for an entry point's resources, keyed by their resource binding.
pub type BindingMap = alloc::collections::BTreeMap<crate::ResourceBinding, BindTarget>;
/// Resource placement for a single entry point.
#[derive(Clone, Debug, Default, Hash, Eq, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
pub struct EntryPointResources {
    /// Bind target for each resource binding. With the `deserialize` feature this
    /// is read from a list of entries rather than a map (see `deserialize_binding_map`).
    #[cfg_attr(
        feature = "deserialize",
        serde(deserialize_with = "deserialize_binding_map")
    )]
    pub resources: BindingMap,
    /// Buffer slot used for push constants (see `Options::resolve_push_constants`).
    pub push_constant_buffer: Option<Slot>,
    /// Buffer slot used for the sizes buffer (see `Options::resolve_sizes_buffer`).
    // NOTE(review): the contents of this buffer are defined by the writer —
    // presumably lengths of runtime-sized arrays; confirm in writer.rs.
    pub sizes_buffer: Option<Slot>,
}
/// Per-entry-point resource configuration, keyed by entry point name
/// (`Options::get_entry_point_resources` looks entries up by `ep.name`).
pub type EntryPointResourceMap = alloc::collections::BTreeMap<String, EntryPointResources>;
/// A binding resolved to the attribute it is written as in MSL
/// (see `ResolvedBinding::try_fmt`, which wraps it in ` [[...]]`).
enum ResolvedBinding {
    /// A built-in such as `position` or `thread_position_in_grid`.
    BuiltIn(crate::BuiltIn),
    /// A vertex input location, written as `attribute(N)`.
    Attribute(u32),
    /// A fragment output, written as `color(N)` plus `index(M)` when
    /// `blend_src` is present.
    Color {
        location: u32,
        blend_src: Option<u32>,
    },
    /// A user attribute, written as `user(<prefix><index>)` and optionally
    /// followed by an interpolation qualifier.
    User {
        prefix: &'static str,
        index: u32,
        interpolation: Option<ResolvedInterpolation>,
    },
    /// A resource placed at a buffer, texture, or sampler slot.
    Resource(BindTarget),
}
/// MSL interpolation qualifier for a user attribute. It is built from a naga
/// interpolation/sampling pair by `from_binding` and spelled by `try_fmt`
/// (e.g. `centroid_no_perspective`).
#[derive(Copy, Clone)]
enum ResolvedInterpolation {
    CenterPerspective,
    CenterNoPerspective,
    CentroidPerspective,
    CentroidNoPerspective,
    SamplePerspective,
    SampleNoPerspective,
    Flat,
}
/// Errors that abort MSL generation.
///
/// The size of this type is pinned by `test_error_size`. Keep large payloads out
/// of new variants.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Writing to the output sink failed.
    #[error(transparent)]
    Format(#[from] FmtError),
    /// A resource target has no buffer, texture, or slot sampler to write.
    #[error("bind target {0:?} is empty")]
    UnimplementedBindTarget(BindTarget),
    #[error("composing of {0:?} is not implemented yet")]
    UnsupportedCompose(Handle<crate::Type>),
    #[error("operation {0:?} is not implemented yet")]
    UnsupportedBinaryOp(crate::BinaryOperator),
    #[error("standard function '{0}' is not implemented yet")]
    UnsupportedCall(String),
    #[error("feature '{0}' is not implemented yet")]
    FeatureNotImplemented(String),
    /// The module contains something validation should have rejected.
    #[error("internal naga error: module should not have validated: {0}")]
    GenericValidation(String),
    /// A built-in with no MSL attribute spelling.
    #[error("BuiltIn {0:?} is not supported")]
    UnsupportedBuiltIn(crate::BuiltIn),
    #[error("capability {0:?} is not supported")]
    CapabilityNotSupported(crate::valid::Capabilities),
    /// An attribute that needs a newer `Options::lang_version`.
    #[error("attribute '{0}' is not supported for target MSL version")]
    UnsupportedAttribute(String),
    #[error("function '{0}' is not supported for target MSL version")]
    UnsupportedFunction(String),
    #[error("can not use writeable storage buffers in fragment stage prior to MSL 1.2")]
    UnsupportedWriteableStorageBuffer,
    #[error("can not use writeable storage textures in {0:?} stage prior to MSL 1.2")]
    UnsupportedWriteableStorageTexture(ir::ShaderStage),
    #[error("can not use read-write storage textures prior to MSL 1.2")]
    UnsupportedRWStorageTexture,
    #[error("array of '{0}' is not supported for target MSL version")]
    UnsupportedArrayOf(String),
    #[error("array of type '{0:?}' is not supported")]
    UnsupportedArrayOfType(Handle<crate::Type>),
    #[error("ray tracing is not supported prior to MSL 2.3")]
    UnsupportedRayTracing,
    /// Pipeline overrides must already be resolved when this backend runs.
    #[error("overrides should not be present at this stage")]
    Override,
    #[error("bitcasting to {0:?} is not supported")]
    UnsupportedBitCast(crate::TypeInner),
    #[error(transparent)]
    ResolveArraySizeError(#[from] crate::proc::ResolveArraySizeError),
    /// The entry point requested by stage and name does not exist.
    #[error("entry point with stage {0:?} and name '{1}' not found")]
    EntryPointNotFound(ir::ShaderStage, String),
}
/// Errors that prevent a single entry point from being translated.
///
/// These are reported per entry point in [`TranslationInfo::entry_point_names`]
/// rather than aborting the whole translation.
#[derive(Clone, Debug, PartialEq, thiserror::Error)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub enum EntryPointError {
    #[error("global '{0}' doesn't have a binding")]
    MissingBinding(String),
    /// No bind target is configured for this resource, and
    /// `Options::fake_missing_bindings` is off.
    #[error("mapping of {0:?} is missing")]
    MissingBindTarget(crate::ResourceBinding),
    /// No push-constant buffer slot is configured, and fake bindings are off.
    #[error("mapping for push constants is missing")]
    MissingPushConstants,
    /// No sizes-buffer slot is configured, and fake bindings are off.
    #[error("mapping for sizes buffer is missing")]
    MissingSizesBuffer,
}
/// Which side of the shader interface a binding sits on. This decides how
/// `Options::resolve_local_binding` spells a `@location` binding.
#[derive(Clone, Copy, Debug)]
enum LocationMode {
    /// Resolved to `attribute(N)`.
    VertexInput,
    /// Resolved to a `user(...)` attribute with interpolation; the only mode
    /// that keeps `invariant` on `position`.
    VertexOutput,
    /// Resolved to a `user(...)` attribute with interpolation.
    FragmentInput,
    /// Resolved to `color(N)`, optionally with a blend source index.
    FragmentOutput,
    /// `@location` bindings are invalid here.
    Uniform,
}
/// Module-wide options for the MSL backend.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(feature = "deserialize", serde(default))]
pub struct Options {
    /// Target MSL version as `(major, minor)`; compared against feature minimums
    /// such as `(1, 2)` or `(2, 1)`.
    pub lang_version: (u8, u8),
    /// Resource placement per entry point, keyed by entry point name.
    pub per_entry_point_map: EntryPointResourceMap,
    /// Samplers referenced by [`BindSamplerTarget::Inline`] indices.
    pub inline_samplers: Vec<sampler::InlineSampler>,
    /// Use the `locn` prefix for user attributes instead of `loc`.
    pub spirv_cross_compatibility: bool,
    /// Replace missing resource/push-constant/sizes bindings with a placeholder
    /// `user(fake0)` attribute instead of reporting an [`EntryPointError`].
    pub fake_missing_bindings: bool,
    // NOTE(review): the following three are consumed by the writer (not visible
    // here); their exact effects should be confirmed in writer.rs.
    pub bounds_check_policies: index::BoundsCheckPolicies,
    pub zero_initialize_workgroup_memory: bool,
    pub force_loop_bounding: bool,
}
impl Default for Options {
fn default() -> Self {
Options {
lang_version: (1, 0),
per_entry_point_map: EntryPointResourceMap::default(),
inline_samplers: Vec::new(),
spirv_cross_compatibility: false,
fake_missing_bindings: true,
bounds_check_policies: index::BoundsCheckPolicies::default(),
zero_initialize_workgroup_memory: true,
force_loop_bounding: true,
}
}
}
/// Format of a vertex attribute in a vertex buffer.
///
/// The type is `#[repr(u32)]` with explicit discriminants. Values 39–42 are
/// intentionally unassigned. Two variants carry explicit serde names because
/// their Rust identifiers do not map onto the external spelling.
#[repr(u32)]
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub enum VertexFormat {
    Uint8 = 0,
    Uint8x2 = 1,
    Uint8x4 = 2,
    Sint8 = 3,
    Sint8x2 = 4,
    Sint8x4 = 5,
    Unorm8 = 6,
    Unorm8x2 = 7,
    Unorm8x4 = 8,
    Snorm8 = 9,
    Snorm8x2 = 10,
    Snorm8x4 = 11,
    Uint16 = 12,
    Uint16x2 = 13,
    Uint16x4 = 14,
    Sint16 = 15,
    Sint16x2 = 16,
    Sint16x4 = 17,
    Unorm16 = 18,
    Unorm16x2 = 19,
    Unorm16x4 = 20,
    Snorm16 = 21,
    Snorm16x2 = 22,
    Snorm16x4 = 23,
    Float16 = 24,
    Float16x2 = 25,
    Float16x4 = 26,
    Float32 = 27,
    Float32x2 = 28,
    Float32x3 = 29,
    Float32x4 = 30,
    Uint32 = 31,
    Uint32x2 = 32,
    Uint32x3 = 33,
    Uint32x4 = 34,
    Sint32 = 35,
    Sint32x2 = 36,
    Sint32x3 = 37,
    Sint32x4 = 38,
    // 39..=42 are skipped on purpose; do not renumber the variants below.
    #[cfg_attr(
        any(feature = "serialize", feature = "deserialize"),
        serde(rename = "unorm10-10-10-2")
    )]
    Unorm10_10_10_2 = 43,
    #[cfg_attr(
        any(feature = "serialize", feature = "deserialize"),
        serde(rename = "unorm8x4-bgra")
    )]
    Unorm8x4Bgra = 44,
}
/// One attribute within a [`VertexBufferMapping`].
// NOTE(review): consumed by the writer (not visible here), presumably for the
// vertex pulling transform enabled by `PipelineOptions::vertex_pulling_transform`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct AttributeMapping {
    // Presumably the `@location` of the vertex input this attribute feeds — confirm.
    pub shader_location: u32,
    // Presumably the byte offset of the attribute within one buffer element — confirm.
    pub offset: u32,
    /// Data format of the attribute.
    pub format: VertexFormat,
}
/// Layout of one vertex buffer and the attributes read from it.
// NOTE(review): consumed by the writer (not visible here); the field meanings
// below are inferred from their names and should be confirmed in writer.rs.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct VertexBufferMapping {
    // Presumably the buffer's identifier/slot.
    pub id: u32,
    // Presumably the distance in bytes between consecutive elements.
    pub stride: u32,
    // Presumably true for per-vertex stepping, false for per-instance stepping.
    pub indexed_by_vertex: bool,
    /// Attributes read from this buffer.
    pub attributes: Vec<AttributeMapping>,
}
/// Per-pipeline options for the MSL backend; missing fields deserialize to defaults.
#[derive(Debug, Default, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(feature = "deserialize", serde(default))]
pub struct PipelineOptions {
    // Entry point to translate, identified by stage and name; the same pair is
    // reported by `Error::EntryPointNotFound`. NOTE(review): `None` presumably
    // means "translate all entry points" — confirm in writer.rs.
    pub entry_point: Option<(ir::ShaderStage, String)>,
    // NOTE(review): interpreted by the writer (not visible here); presumably
    // controls emitting a `point_size` output — confirm.
    pub allow_and_force_point_size: bool,
    // NOTE(review): presumably enables reading vertex inputs manually from the
    // buffers described in `vertex_buffer_mappings` — confirm.
    pub vertex_pulling_transform: bool,
    /// Vertex buffer layouts available to the pipeline.
    pub vertex_buffer_mappings: Vec<VertexBufferMapping>,
}
impl Options {
fn resolve_local_binding(
&self,
binding: &crate::Binding,
mode: LocationMode,
) -> Result<ResolvedBinding, Error> {
match *binding {
crate::Binding::BuiltIn(mut built_in) => {
match built_in {
crate::BuiltIn::Position { ref mut invariant } => {
if *invariant && self.lang_version < (2, 1) {
return Err(Error::UnsupportedAttribute("invariant".to_string()));
}
if !matches!(mode, LocationMode::VertexOutput) {
*invariant = false;
}
}
crate::BuiltIn::BaseInstance if self.lang_version < (1, 2) => {
return Err(Error::UnsupportedAttribute("base_instance".to_string()));
}
crate::BuiltIn::InstanceIndex if self.lang_version < (1, 2) => {
return Err(Error::UnsupportedAttribute("instance_id".to_string()));
}
crate::BuiltIn::PrimitiveIndex if self.lang_version < (2, 2) => {
return Err(Error::UnsupportedAttribute("primitive_id".to_string()));
}
_ => {}
}
Ok(ResolvedBinding::BuiltIn(built_in))
}
crate::Binding::Location {
location,
interpolation,
sampling,
blend_src,
} => match mode {
LocationMode::VertexInput => Ok(ResolvedBinding::Attribute(location)),
LocationMode::FragmentOutput => {
if blend_src.is_some() && self.lang_version < (1, 2) {
return Err(Error::UnsupportedAttribute("blend_src".to_string()));
}
Ok(ResolvedBinding::Color {
location,
blend_src,
})
}
LocationMode::VertexOutput | LocationMode::FragmentInput => {
Ok(ResolvedBinding::User {
prefix: if self.spirv_cross_compatibility {
"locn"
} else {
"loc"
},
index: location,
interpolation: {
let interpolation = interpolation.unwrap();
let sampling = sampling.unwrap_or(crate::Sampling::Center);
Some(ResolvedInterpolation::from_binding(interpolation, sampling))
},
})
}
LocationMode::Uniform => Err(Error::GenericValidation(format!(
"Unexpected Binding::Location({location}) for the Uniform mode"
))),
},
}
}
fn get_entry_point_resources(&self, ep: &crate::EntryPoint) -> Option<&EntryPointResources> {
self.per_entry_point_map.get(&ep.name)
}
fn get_resource_binding_target(
&self,
ep: &crate::EntryPoint,
res_binding: &crate::ResourceBinding,
) -> Option<&BindTarget> {
self.get_entry_point_resources(ep)
.and_then(|res| res.resources.get(res_binding))
}
fn resolve_resource_binding(
&self,
ep: &crate::EntryPoint,
res_binding: &crate::ResourceBinding,
) -> Result<ResolvedBinding, EntryPointError> {
let target = self.get_resource_binding_target(ep, res_binding);
match target {
Some(target) => Ok(ResolvedBinding::Resource(target.clone())),
None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
prefix: "fake",
index: 0,
interpolation: None,
}),
None => Err(EntryPointError::MissingBindTarget(*res_binding)),
}
}
fn resolve_push_constants(
&self,
ep: &crate::EntryPoint,
) -> Result<ResolvedBinding, EntryPointError> {
let slot = self
.get_entry_point_resources(ep)
.and_then(|res| res.push_constant_buffer);
match slot {
Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
buffer: Some(slot),
..Default::default()
})),
None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
prefix: "fake",
index: 0,
interpolation: None,
}),
None => Err(EntryPointError::MissingPushConstants),
}
}
fn resolve_sizes_buffer(
&self,
ep: &crate::EntryPoint,
) -> Result<ResolvedBinding, EntryPointError> {
let slot = self
.get_entry_point_resources(ep)
.and_then(|res| res.sizes_buffer);
match slot {
Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
buffer: Some(slot),
..Default::default()
})),
None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
prefix: "fake",
index: 0,
interpolation: None,
}),
None => Err(EntryPointError::MissingSizesBuffer),
}
}
}
impl ResolvedBinding {
fn as_inline_sampler<'a>(&self, options: &'a Options) -> Option<&'a sampler::InlineSampler> {
match *self {
Self::Resource(BindTarget {
sampler: Some(BindSamplerTarget::Inline(index)),
..
}) => Some(&options.inline_samplers[index as usize]),
_ => None,
}
}
fn try_fmt<W: Write>(&self, out: &mut W) -> Result<(), Error> {
write!(out, " [[")?;
match *self {
Self::BuiltIn(built_in) => {
use crate::BuiltIn as Bi;
let name = match built_in {
Bi::Position { invariant: false } => "position",
Bi::Position { invariant: true } => "position, invariant",
Bi::BaseInstance => "base_instance",
Bi::BaseVertex => "base_vertex",
Bi::ClipDistance => "clip_distance",
Bi::InstanceIndex => "instance_id",
Bi::PointSize => "point_size",
Bi::VertexIndex => "vertex_id",
Bi::FragDepth => "depth(any)",
Bi::PointCoord => "point_coord",
Bi::FrontFacing => "front_facing",
Bi::PrimitiveIndex => "primitive_id",
Bi::SampleIndex => "sample_id",
Bi::SampleMask => "sample_mask",
Bi::GlobalInvocationId => "thread_position_in_grid",
Bi::LocalInvocationId => "thread_position_in_threadgroup",
Bi::LocalInvocationIndex => "thread_index_in_threadgroup",
Bi::WorkGroupId => "threadgroup_position_in_grid",
Bi::WorkGroupSize => "dispatch_threads_per_threadgroup",
Bi::NumWorkGroups => "threadgroups_per_grid",
Bi::NumSubgroups => "simdgroups_per_threadgroup",
Bi::SubgroupId => "simdgroup_index_in_threadgroup",
Bi::SubgroupSize => "threads_per_simdgroup",
Bi::SubgroupInvocationId => "thread_index_in_simdgroup",
Bi::CullDistance | Bi::ViewIndex | Bi::DrawID => {
return Err(Error::UnsupportedBuiltIn(built_in))
}
};
write!(out, "{name}")?;
}
Self::Attribute(index) => write!(out, "attribute({index})")?,
Self::Color {
location,
blend_src,
} => {
if let Some(blend_src) = blend_src {
write!(out, "color({location}) index({blend_src})")?
} else {
write!(out, "color({location})")?
}
}
Self::User {
prefix,
index,
interpolation,
} => {
write!(out, "user({prefix}{index})")?;
if let Some(interpolation) = interpolation {
write!(out, ", ")?;
interpolation.try_fmt(out)?;
}
}
Self::Resource(ref target) => {
if let Some(id) = target.buffer {
write!(out, "buffer({id})")?;
} else if let Some(id) = target.texture {
write!(out, "texture({id})")?;
} else if let Some(BindSamplerTarget::Resource(id)) = target.sampler {
write!(out, "sampler({id})")?;
} else {
return Err(Error::UnimplementedBindTarget(target.clone()));
}
}
}
write!(out, "]]")?;
Ok(())
}
}
impl ResolvedInterpolation {
    /// Combines a naga interpolation mode and sampling into an MSL qualifier.
    ///
    /// `Flat` ignores sampling.
    ///
    /// # Panics
    ///
    /// Panics (`unreachable!`) for a perspective or linear pairing with a sampling
    /// mode other than center, centroid, or sample.
    const fn from_binding(interpolation: crate::Interpolation, sampling: crate::Sampling) -> Self {
        use crate::Interpolation as I;
        use crate::Sampling as S;
        match interpolation {
            I::Flat => Self::Flat,
            I::Perspective => match sampling {
                S::Center => Self::CenterPerspective,
                S::Centroid => Self::CentroidPerspective,
                S::Sample => Self::SamplePerspective,
                _ => unreachable!(),
            },
            I::Linear => match sampling {
                S::Center => Self::CenterNoPerspective,
                S::Centroid => Self::CentroidNoPerspective,
                S::Sample => Self::SampleNoPerspective,
                _ => unreachable!(),
            },
        }
    }

    /// Writes the MSL spelling of this qualifier (e.g. `sample_perspective`) to `out`.
    fn try_fmt<W: Write>(self, out: &mut W) -> Result<(), Error> {
        let spelling = match self {
            Self::Flat => "flat",
            Self::CenterPerspective => "center_perspective",
            Self::CentroidPerspective => "centroid_perspective",
            Self::SamplePerspective => "sample_perspective",
            Self::CenterNoPerspective => "center_no_perspective",
            Self::CentroidNoPerspective => "centroid_no_perspective",
            Self::SampleNoPerspective => "sample_no_perspective",
        };
        Ok(out.write_str(spelling)?)
    }
}
/// Extra information returned alongside the generated MSL text.
pub struct TranslationInfo {
    /// One result per translated entry point: either a name, or the
    /// [`EntryPointError`] that prevented translating that entry point.
    // NOTE(review): presumably indexed like `Module::entry_points` and holding the
    // MSL function name — confirm in writer.rs.
    pub entry_point_names: Vec<Result<String, EntryPointError>>,
}
/// Translates `module` to MSL source text.
///
/// Returns the generated source together with [`TranslationInfo`].
///
/// # Errors
///
/// Propagates any [`Error`] raised by [`Writer::write`].
pub fn write_string(
    module: &crate::Module,
    info: &ModuleInfo,
    options: &Options,
    pipeline_options: &PipelineOptions,
) -> Result<(String, TranslationInfo), Error> {
    let mut writer = Writer::new(String::new());
    let translation = writer.write(module, info, options, pipeline_options)?;
    let source = writer.finish();
    Ok((source, translation))
}
/// Pins the size of [`Error`] so that variants with large inline payloads are
/// noticed; a larger error type makes every `Result<_, Error>` larger.
#[test]
fn test_error_size() {
    let error_size = core::mem::size_of::<Error>();
    assert_eq!(error_size, 40);
}