1use once_cell::sync::Lazy;
2use std::borrow::Cow;
3
4impl From<naga::ShaderStage> for super::ShaderVisibility {
5 fn from(stage: naga::ShaderStage) -> Self {
6 match stage {
7 naga::ShaderStage::Compute => Self::COMPUTE,
8 naga::ShaderStage::Vertex => Self::VERTEX,
9 naga::ShaderStage::Fragment => Self::FRAGMENT,
10 _ => Self::empty(),
11 }
12 }
13}
14
15impl super::Context {
16 fn validate_module(
17 &self,
18 module: &naga::Module,
19 source: &str,
20 ) -> Result<naga::valid::ModuleInfo, &'static str> {
21 let device_caps = self.capabilities();
22
23 let flags = naga::valid::ValidationFlags::all() ^ naga::valid::ValidationFlags::BINDINGS;
25 let mut caps = naga::valid::Capabilities::empty();
26 caps.set(
27 naga::valid::Capabilities::STORAGE_BUFFER_BINDING_ARRAY
28 | naga::valid::Capabilities::TEXTURE_AND_SAMPLER_BINDING_ARRAY
29 | naga::valid::Capabilities::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING
30 | naga::valid::Capabilities::STORAGE_BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING,
31 device_caps.binding_array,
32 );
33 caps.set(
34 naga::valid::Capabilities::RAY_QUERY
35 | naga::valid::Capabilities::ACCELERATION_STRUCTURE_BINDING_ARRAY,
36 !device_caps.ray_query.is_empty(),
37 );
38 caps.set(
39 naga::valid::Capabilities::DUAL_SOURCE_BLENDING,
40 device_caps.dual_source_blending,
41 );
42 caps.set(
43 naga::valid::Capabilities::SHADER_FLOAT16,
44 device_caps.shader_float16,
45 );
46 caps.set(
47 naga::valid::Capabilities::COOPERATIVE_MATRIX,
48 device_caps.cooperative_matrix.is_supported(),
49 );
50 caps.set(naga::valid::Capabilities::SUBGROUP, true);
55
56 naga::valid::Validator::new(flags, caps)
57 .validate(module)
58 .map_err(|e| {
59 crate::util::emit_annotated_error(&e, "", source);
60 crate::util::print_err(&e);
61 "validation failed"
62 })
63 }
64
65 pub fn try_create_shader(
66 &self,
67 desc: super::ShaderDesc,
68 ) -> Result<super::Shader, &'static str> {
69 let module = match desc.naga_module {
70 Some(module) => module,
71 None => naga::front::wgsl::parse_str(desc.source).map_err(|e| {
72 eprintln!("{}", e.emit_to_string_with_path(desc.source, ""));
73 "compilation failed"
74 })?,
75 };
76 let info = self.validate_module(&module, desc.source)?;
77 Ok(super::Shader {
78 module,
79 info,
80 source: desc.source.to_owned(),
81 })
82 }
83
84 pub fn create_shader(&self, desc: super::ShaderDesc) -> super::Shader {
85 self.try_create_shader(desc).unwrap()
86 }
87}
88
89pub static EMPTY_CONSTANTS: Lazy<super::PipelineConstants> = Lazy::new(Default::default);
90
91impl super::Shader {
92 pub fn at<'a>(&'a self, entry_point: &'a str) -> super::ShaderFunction<'a> {
93 super::ShaderFunction {
94 shader: self,
95 entry_point,
96 constants: Lazy::force(&EMPTY_CONSTANTS),
97 }
98 }
99
100 pub fn with_constants<'a>(
101 &'a self,
102 entry_point: &'a str,
103 constants: &'a super::PipelineConstants,
104 ) -> super::ShaderFunction<'a> {
105 super::ShaderFunction {
106 shader: self,
107 entry_point,
108 constants,
109 }
110 }
111
112 pub fn resolve_constants<'a>(
113 &'a self,
114 constants: &super::PipelineConstants,
115 ) -> (naga::Module, Cow<'a, naga::valid::ModuleInfo>) {
116 let (module, info) = naga::back::pipeline_constants::process_overrides(
117 &self.module,
118 &self.info,
119 None,
120 constants,
121 )
122 .unwrap();
123 (module.into_owned(), info)
124 }
125
126 pub fn get_struct_size(&self, struct_name: &str) -> u32 {
127 match self
128 .module
129 .types
130 .iter()
131 .find(|&(_, ty)| ty.name.as_deref() == Some(struct_name))
132 {
133 Some((_, ty)) => match ty.inner {
134 naga::TypeInner::Struct { members: _, span } => span,
135 _ => panic!("Type '{struct_name}' is not a struct in the shader"),
136 },
137 None => panic!("Struct '{struct_name}' is not found in the shader"),
138 }
139 }
140
141 pub fn check_struct_size<T>(&self) {
142 use std::{any::type_name, mem::size_of};
143 let name = type_name::<T>().rsplit("::").next().unwrap();
144 assert_eq!(
145 size_of::<T>(),
146 self.get_struct_size(name) as usize,
147 "Host struct '{name}' size doesn't match the shader"
148 );
149 }
150
151 pub(crate) fn fill_resource_bindings(
152 module: &mut naga::Module,
153 sd_infos: &mut [crate::ShaderDataInfo],
154 naga_stage: naga::ShaderStage,
155 ep_info: &naga::valid::FunctionInfo,
156 group_layouts: &[&crate::ShaderDataLayout],
157 ) {
158 let mut layouter = naga::proc::Layouter::default();
159 layouter.update(module.to_ctx()).unwrap();
160
161 for (handle, var) in module.global_variables.iter_mut() {
162 if ep_info[handle].is_empty() {
163 continue;
164 }
165 let var_access = match var.space {
166 naga::AddressSpace::Storage { access } => access,
167 naga::AddressSpace::Uniform | naga::AddressSpace::Handle => {
168 naga::StorageAccess::empty()
169 }
170 _ => continue,
171 };
172
173 assert_eq!(var.binding, None);
174 let var_name = var.name.as_ref().unwrap();
175 for (group_index, (&layout, info)) in
176 group_layouts.iter().zip(sd_infos.iter_mut()).enumerate()
177 {
178 if let Some((binding_index, &(_, proto_binding))) = layout
179 .bindings
180 .iter()
181 .enumerate()
182 .find(|&(_, &(name, _))| name == var_name)
183 {
184 let (expected_proto, access) = match module.types[var.ty].inner {
185 naga::TypeInner::Image {
186 class: naga::ImageClass::Storage { access, format: _ },
187 ..
188 } => (crate::ShaderBinding::Texture, access),
189 naga::TypeInner::Image { .. } => {
190 (crate::ShaderBinding::Texture, naga::StorageAccess::empty())
191 }
192 naga::TypeInner::Sampler { .. } => {
193 (crate::ShaderBinding::Sampler, naga::StorageAccess::empty())
194 }
195 naga::TypeInner::AccelerationStructure { vertex_return: _ } => (
196 crate::ShaderBinding::AccelerationStructure,
197 naga::StorageAccess::empty(),
198 ),
199 naga::TypeInner::BindingArray { base, size: _ } => {
200 let count = match proto_binding {
202 crate::ShaderBinding::TextureArray { count } => count,
203 crate::ShaderBinding::BufferArray { count } => count,
204 crate::ShaderBinding::AccelerationStructureArray { count } => count,
205 _ => 0,
206 };
207 let proto = match module.types[base].inner {
208 naga::TypeInner::Image { .. } => {
209 crate::ShaderBinding::TextureArray { count }
210 }
211 naga::TypeInner::Struct { .. } => {
212 crate::ShaderBinding::BufferArray { count }
213 }
214 naga::TypeInner::AccelerationStructure { .. } => {
215 crate::ShaderBinding::AccelerationStructureArray { count }
216 }
217 ref other => panic!("Unsupported binding array for {:?}", other),
218 };
219 (proto, var_access)
220 }
221 _ => {
222 let type_layout = &layouter[var.ty];
223 let proto = if var_access.is_empty()
224 && proto_binding != crate::ShaderBinding::Buffer
225 {
226 crate::ShaderBinding::Plain {
227 size: type_layout.size,
228 }
229 } else {
230 crate::ShaderBinding::Buffer
231 };
232 (proto, var_access)
233 }
234 };
235 assert_eq!(
236 proto_binding, expected_proto,
237 "Mismatched type for binding '{}'",
238 var_name
239 );
240 assert_eq!(var.binding, None);
241 var.binding = Some(naga::ResourceBinding {
242 group: group_index as u32,
243 binding: binding_index as u32,
244 });
245 info.visibility |= naga_stage.into();
246 info.binding_access[binding_index] |= access;
247 break;
248 }
249 }
250
251 assert!(
252 var.binding.is_some(),
253 "Unable to resolve binding for '{}' in stage '{:?}'",
254 var_name,
255 naga_stage,
256 );
257 }
258 }
259
260 pub(crate) fn fill_vertex_locations(
261 module: &mut naga::Module,
262 selected_ep_index: usize,
263 fetch_states: &[crate::VertexFetchState],
264 ) -> Vec<crate::VertexAttributeMapping> {
265 let mut attribute_mappings = Vec::new();
266 for (ep_index, ep) in module.entry_points.iter().enumerate() {
267 if ep.stage != naga::ShaderStage::Vertex {
268 continue;
269 }
270 if ep_index != selected_ep_index {
271 continue;
272 }
273
274 for argument in ep.function.arguments.iter() {
275 if argument.binding.is_some() {
276 continue;
277 }
278
279 let arg_name = match argument.name {
280 Some(ref name) => name.as_str(),
281 None => "?",
282 };
283 let mut ty = module.types[argument.ty].clone();
284 let members = match ty.inner {
285 naga::TypeInner::Struct {
286 ref mut members, ..
287 } => members,
288 ref other => {
289 log::error!("Unexpected type for '{}': {:?}", arg_name, other);
290 continue;
291 }
292 };
293
294 log::debug!("Processing vertex argument: {}", arg_name);
295 'member: for member in members.iter_mut() {
296 let member_name = match member.name {
297 Some(ref name) => name.as_str(),
298 None => "?",
299 };
300 if let Some(ref binding) = member.binding {
301 log::warn!(
302 "Member '{}' already has binding: {:?}",
303 member_name,
304 binding
305 );
306 continue;
307 }
308 let binding = naga::Binding::Location {
309 location: attribute_mappings.len() as u32,
310 interpolation: None,
311 sampling: None,
312 blend_src: None,
313 per_primitive: false,
314 };
315 for (buffer_index, vertex_fetch) in fetch_states.iter().enumerate() {
316 for (attribute_index, &(at_name, _)) in
317 vertex_fetch.layout.attributes.iter().enumerate()
318 {
319 if at_name == member_name {
320 log::debug!(
321 "Assigning location({}) for member '{}' to be using input {}:{}",
322 attribute_mappings.len(),
323 member_name,
324 buffer_index,
325 attribute_index
326 );
327 member.binding = Some(binding);
328 attribute_mappings.push(crate::VertexAttributeMapping {
329 buffer_index,
330 attribute_index,
331 });
332 continue 'member;
333 }
334 }
335 }
336 assert_ne!(
337 member.binding, None,
338 "Field {} is not covered by the vertex fetch layouts!",
339 member_name
340 );
341 }
342 module.types.replace(argument.ty, ty);
343 }
344 }
345 attribute_mappings
346 }
347}