use once_cell::sync::Lazy;
use std::borrow::Cow;
impl From<naga::ShaderStage> for super::ShaderVisibility {
fn from(stage: naga::ShaderStage) -> Self {
match stage {
naga::ShaderStage::Compute => Self::COMPUTE,
naga::ShaderStage::Vertex => Self::VERTEX,
naga::ShaderStage::Fragment => Self::FRAGMENT,
_ => Self::empty(),
}
}
}
impl super::Context {
    /// Validates `module` with naga, enabling only the optional shader capabilities
    /// that this device reports through `self.capabilities()`.
    ///
    /// `source` is used only to annotate the diagnostic printed on failure.
    ///
    /// # Errors
    /// Returns `Err("validation failed")` if naga rejects the module. The detailed
    /// error is printed before returning.
    fn validate_module(
        &self,
        module: &naga::Module,
        source: &str,
    ) -> Result<naga::valid::ModuleInfo, &'static str> {
        use naga::valid::{Capabilities, ValidationFlags};

        let device_caps = self.capabilities();
        // Validate everything except resource bindings. `difference` clears the flag
        // unconditionally, whereas XOR would toggle it.
        // NOTE(review): bindings are presumably skipped because shader sources do not
        // declare them; `Shader::fill_resource_bindings` asserts they are absent and
        // assigns them afterwards.
        let flags = ValidationFlags::all().difference(ValidationFlags::BINDINGS);

        let mut caps = Capabilities::empty();
        caps.set(
            Capabilities::STORAGE_BUFFER_BINDING_ARRAY
                | Capabilities::TEXTURE_AND_SAMPLER_BINDING_ARRAY
                | Capabilities::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING
                | Capabilities::STORAGE_BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING,
            device_caps.binding_array,
        );
        caps.set(
            Capabilities::RAY_QUERY | Capabilities::ACCELERATION_STRUCTURE_BINDING_ARRAY,
            !device_caps.ray_query.is_empty(),
        );
        caps.set(
            Capabilities::DUAL_SOURCE_BLENDING,
            device_caps.dual_source_blending,
        );
        caps.set(Capabilities::SHADER_FLOAT16, device_caps.shader_float16);
        caps.set(
            Capabilities::COOPERATIVE_MATRIX,
            device_caps.cooperative_matrix.is_supported(),
        );

        naga::valid::Validator::new(flags, caps)
            .validate(module)
            .map_err(|e| {
                crate::util::emit_annotated_error(&e, "", source);
                crate::util::print_err(&e);
                "validation failed"
            })
    }

    /// Builds a shader from `desc`, reporting failure as an error string instead of
    /// panicking.
    ///
    /// If `desc.naga_module` is provided it is used directly; otherwise `desc.source`
    /// is parsed as WGSL. The module is then validated against this device's
    /// capabilities. The source text is stored in the returned shader either way.
    ///
    /// # Errors
    /// - `"compilation failed"` if WGSL parsing fails (the annotated parse error is
    ///   printed to stderr).
    /// - `"validation failed"` if naga validation fails.
    pub fn try_create_shader(
        &self,
        desc: super::ShaderDesc,
    ) -> Result<super::Shader, &'static str> {
        let module = match desc.naga_module {
            Some(module) => module,
            None => naga::front::wgsl::parse_str(desc.source).map_err(|e| {
                eprintln!("{}", e.emit_to_string_with_path(desc.source, ""));
                "compilation failed"
            })?,
        };
        let info = self.validate_module(&module, desc.source)?;
        Ok(super::Shader {
            module,
            info,
            source: desc.source.to_owned(),
        })
    }

    /// Builds a shader from `desc`.
    ///
    /// # Panics
    /// Panics if the shader fails to compile or validate. The detailed diagnostic
    /// is printed before the panic.
    pub fn create_shader(&self, desc: super::ShaderDesc) -> super::Shader {
        self.try_create_shader(desc)
            .expect("shader creation failed")
    }
}
/// Shared, lazily created empty set of pipeline constants.
///
/// `Shader::at` hands out a reference to it when no constant overrides are given.
pub static EMPTY_CONSTANTS: Lazy<super::PipelineConstants> =
    Lazy::new(super::PipelineConstants::default);
impl super::Shader {
    /// Selects `entry_point` of this shader without any pipeline-constant overrides.
    pub fn at<'a>(&'a self, entry_point: &'a str) -> super::ShaderFunction<'a> {
        super::ShaderFunction {
            shader: self,
            entry_point,
            constants: Lazy::force(&EMPTY_CONSTANTS),
        }
    }

    /// Selects `entry_point` of this shader and pairs it with the given
    /// pipeline-constant overrides.
    pub fn with_constants<'a>(
        &'a self,
        entry_point: &'a str,
        constants: &'a super::PipelineConstants,
    ) -> super::ShaderFunction<'a> {
        super::ShaderFunction {
            shader: self,
            entry_point,
            constants,
        }
    }

    /// Applies pipeline-constant overrides to this shader's module via naga's
    /// `process_overrides`.
    ///
    /// Returns an owned copy of the resulting module together with its (possibly
    /// borrowed) validation info.
    ///
    /// # Panics
    /// Panics if naga fails to process the overrides.
    pub fn resolve_constants<'a>(
        &'a self,
        constants: &super::PipelineConstants,
    ) -> (naga::Module, Cow<'a, naga::valid::ModuleInfo>) {
        let (module, info) = naga::back::pipeline_constants::process_overrides(
            &self.module,
            &self.info,
            None,
            constants,
        )
        .expect("failed to process pipeline constant overrides");
        (module.into_owned(), info)
    }

    /// Returns the `span` that naga recorded for the struct type named `struct_name`.
    ///
    /// `check_struct_size` compares this value against a host type's size in bytes.
    ///
    /// # Panics
    /// Panics if no type in the module has that name, or if the named type is not
    /// a struct.
    pub fn get_struct_size(&self, struct_name: &str) -> u32 {
        let found = self
            .module
            .types
            .iter()
            .find(|&(_, ty)| ty.name.as_deref() == Some(struct_name));
        match found {
            Some((_, ty)) => match ty.inner {
                naga::TypeInner::Struct { span, .. } => span,
                _ => panic!("Type '{struct_name}' is not a struct in the shader"),
            },
            None => panic!("Struct '{struct_name}' is not found in the shader"),
        }
    }

    /// Asserts that the host type `T` has the same size as the shader struct with
    /// the same (unqualified) name.
    ///
    /// The shader-side name is derived from `std::any::type_name::<T>()`. That output
    /// is best-effort, not a stable format.
    ///
    /// # Panics
    /// Panics if the sizes differ, or if the shader has no struct with that name.
    pub fn check_struct_size<T>(&self) {
        use std::{any::type_name, mem::size_of};
        let full_name = type_name::<T>();
        // Drop generic arguments (e.g. `app::Params<app::Foo>`) before taking the last
        // path segment. Otherwise the segment would be a fragment like `Foo>`.
        let without_generics = full_name.split('<').next().unwrap_or(full_name);
        let name = without_generics
            .rsplit("::")
            .next()
            .unwrap_or(without_generics);
        assert_eq!(
            size_of::<T>(),
            self.get_struct_size(name) as usize,
            "Host struct '{name}' size doesn't match the shader"
        );
    }

    /// Assigns group/binding slots to the global resources used by one entry point.
    ///
    /// The function considers each global variable that meets both conditions:
    /// - it is in the storage, uniform, or handle address space;
    /// - `ep_info` reports some use of it.
    ///
    /// For each such variable, it looks for a binding with the same name in
    /// `group_layouts`. Groups are searched in order and the first match wins.
    /// On a match it:
    /// - checks that the layout's declared binding kind agrees with the kind implied
    ///   by the variable's type;
    /// - writes the group/binding indices into the module;
    /// - ORs the stage visibility and the variable's access mode into the entry of
    ///   `sd_infos` paired (by position) with that group layout.
    ///
    /// # Panics
    /// Panics in any of these cases:
    /// - the type layouter fails;
    /// - a used variable already has a binding or has no name;
    /// - a used variable matches no layout binding;
    /// - its type disagrees with the layout's binding kind.
    pub(crate) fn fill_resource_bindings(
        module: &mut naga::Module,
        sd_infos: &mut [crate::ShaderDataInfo],
        naga_stage: naga::ShaderStage,
        ep_info: &naga::valid::FunctionInfo,
        group_layouts: &[&crate::ShaderDataLayout],
    ) {
        // Type sizes are needed to describe plain (non-buffer) uniform data.
        let mut layouter = naga::proc::Layouter::default();
        layouter.update(module.to_ctx()).unwrap();
        for (handle, var) in module.global_variables.iter_mut() {
            // Skip globals this entry point never touches.
            if ep_info[handle].is_empty() {
                continue;
            }
            let var_access = match var.space {
                naga::AddressSpace::Storage { access } => access,
                naga::AddressSpace::Uniform | naga::AddressSpace::Handle => {
                    naga::StorageAccess::empty()
                }
                // Other address spaces are not bound through shader data layouts.
                _ => continue,
            };
            assert_eq!(var.binding, None);
            let var_name = var.name.as_ref().unwrap();
            for (group_index, (&layout, info)) in
                group_layouts.iter().zip(sd_infos.iter_mut()).enumerate()
            {
                if let Some((binding_index, &(_, proto_binding))) = layout
                    .bindings
                    .iter()
                    .enumerate()
                    .find(|&(_, &(name, _))| name == var_name)
                {
                    // Derive the binding kind implied by the shader-side type, and the
                    // access mode to record for it.
                    let (expected_proto, access) = match module.types[var.ty].inner {
                        naga::TypeInner::Image {
                            class: naga::ImageClass::Storage { access, format: _ },
                            ..
                        } => (crate::ShaderBinding::Texture, access),
                        naga::TypeInner::Image { .. } => {
                            (crate::ShaderBinding::Texture, naga::StorageAccess::empty())
                        }
                        naga::TypeInner::Sampler { .. } => {
                            (crate::ShaderBinding::Sampler, naga::StorageAccess::empty())
                        }
                        naga::TypeInner::AccelerationStructure { vertex_return: _ } => (
                            crate::ShaderBinding::AccelerationStructure,
                            naga::StorageAccess::empty(),
                        ),
                        naga::TypeInner::BindingArray { base, size: _ } => {
                            // The element count comes from the layout. A non-array
                            // layout binding yields 0 here, so it cannot match and
                            // the assertion below reports the mismatch.
                            let count = match proto_binding {
                                crate::ShaderBinding::TextureArray { count } => count,
                                crate::ShaderBinding::BufferArray { count } => count,
                                crate::ShaderBinding::AccelerationStructureArray { count } => {
                                    count
                                }
                                _ => 0,
                            };
                            let proto = match module.types[base].inner {
                                naga::TypeInner::Image { .. } => {
                                    crate::ShaderBinding::TextureArray { count }
                                }
                                naga::TypeInner::Struct { .. } => {
                                    crate::ShaderBinding::BufferArray { count }
                                }
                                naga::TypeInner::AccelerationStructure { .. } => {
                                    crate::ShaderBinding::AccelerationStructureArray { count }
                                }
                                ref other => panic!("Unsupported binding array for {:?}", other),
                            };
                            (proto, var_access)
                        }
                        _ => {
                            // Data with no storage access (e.g. uniform space) becomes a
                            // plain block of the type's size, unless the layout declares
                            // it as a buffer. Everything else is a buffer binding.
                            let type_layout = &layouter[var.ty];
                            let proto = if var_access.is_empty()
                                && proto_binding != crate::ShaderBinding::Buffer
                            {
                                crate::ShaderBinding::Plain {
                                    size: type_layout.size,
                                }
                            } else {
                                crate::ShaderBinding::Buffer
                            };
                            (proto, var_access)
                        }
                    };
                    assert_eq!(
                        proto_binding, expected_proto,
                        "Mismatched type for binding '{}'",
                        var_name
                    );
                    var.binding = Some(naga::ResourceBinding {
                        group: group_index as u32,
                        binding: binding_index as u32,
                    });
                    info.visibility |= naga_stage.into();
                    info.binding_access[binding_index] |= access;
                    break;
                }
            }
            assert!(
                var.binding.is_some(),
                "Unable to resolve binding for '{}' in stage '{:?}'",
                var_name,
                naga_stage,
            );
        }
    }

    /// Assigns `@location` bindings to the struct-argument members of one vertex
    /// entry point, matching each member by name against the attributes declared in
    /// `fetch_states`.
    ///
    /// Locations are numbered in the order members are assigned, so entry `i` of
    /// the returned vector describes the buffer/attribute feeding location `i`.
    /// Each patched struct type is cloned, modified, and written back over the
    /// original type handle. Arguments that already have a binding are left alone.
    /// Non-struct unbound arguments are logged and skipped, and members that already
    /// have a binding are logged and skipped.
    ///
    /// Returns an empty vector if `selected_ep_index` is out of range or does not
    /// refer to a vertex entry point.
    ///
    /// # Panics
    /// Panics if an unbound struct member matches no attribute in `fetch_states`.
    pub(crate) fn fill_vertex_locations(
        module: &mut naga::Module,
        selected_ep_index: usize,
        fetch_states: &[crate::VertexFetchState],
    ) -> Vec<crate::VertexAttributeMapping> {
        let ep = match module.entry_points.get(selected_ep_index) {
            Some(ep) if ep.stage == naga::ShaderStage::Vertex => ep,
            _ => return Vec::new(),
        };
        let mut attribute_mappings = Vec::new();
        for argument in ep.function.arguments.iter() {
            if argument.binding.is_some() {
                continue;
            }
            let arg_name = argument.name.as_deref().unwrap_or("?");
            let mut ty = module.types[argument.ty].clone();
            let members = match ty.inner {
                naga::TypeInner::Struct {
                    ref mut members, ..
                } => members,
                ref other => {
                    log::error!("Unexpected type for '{}': {:?}", arg_name, other);
                    continue;
                }
            };
            log::debug!("Processing vertex argument: {}", arg_name);
            for member in members.iter_mut() {
                let member_name = member.name.as_deref().unwrap_or("?");
                if let Some(ref binding) = member.binding {
                    log::warn!(
                        "Member '{}' already has binding: {:?}",
                        member_name,
                        binding
                    );
                    continue;
                }
                // First attribute with this name, scanning buffers in order.
                let (buffer_index, attribute_index) = fetch_states
                    .iter()
                    .enumerate()
                    .find_map(|(buffer_index, vertex_fetch)| {
                        vertex_fetch
                            .layout
                            .attributes
                            .iter()
                            .position(|&(at_name, _)| at_name == member_name)
                            .map(|attribute_index| (buffer_index, attribute_index))
                    })
                    .unwrap_or_else(|| {
                        panic!(
                            "Field {} is not covered by the vertex fetch layouts!",
                            member_name
                        )
                    });
                // Dense numbering: a member's location equals its mapping's index.
                let location = attribute_mappings.len() as u32;
                log::debug!(
                    "Assigning location({}) for member '{}' to be using input {}:{}",
                    location,
                    member_name,
                    buffer_index,
                    attribute_index
                );
                member.binding = Some(naga::Binding::Location {
                    location,
                    interpolation: None,
                    sampling: None,
                    blend_src: None,
                    per_primitive: false,
                });
                attribute_mappings.push(crate::VertexAttributeMapping {
                    buffer_index,
                    attribute_index,
                });
            }
            module.types.replace(argument.ty, ty);
        }
        attribute_mappings
    }
}