1use once_cell::sync::Lazy;
2use std::borrow::Cow;
3
4impl From<naga::ShaderStage> for super::ShaderVisibility {
5 fn from(stage: naga::ShaderStage) -> Self {
6 match stage {
7 naga::ShaderStage::Compute => Self::COMPUTE,
8 naga::ShaderStage::Vertex => Self::VERTEX,
9 naga::ShaderStage::Fragment => Self::FRAGMENT,
10 _ => Self::empty(),
11 }
12 }
13}
14
15impl super::Context {
16 fn validate_module(
17 &self,
18 module: &naga::Module,
19 source: &str,
20 ) -> Result<naga::valid::ModuleInfo, &'static str> {
21 let device_caps = self.capabilities();
22
23 let flags = naga::valid::ValidationFlags::all() ^ naga::valid::ValidationFlags::BINDINGS;
25 let mut caps = naga::valid::Capabilities::empty();
26 caps.set(
27 naga::valid::Capabilities::STORAGE_BUFFER_BINDING_ARRAY
28 | naga::valid::Capabilities::TEXTURE_AND_SAMPLER_BINDING_ARRAY
29 | naga::valid::Capabilities::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING
30 | naga::valid::Capabilities::STORAGE_BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING,
31 device_caps.binding_array,
32 );
33 caps.set(
34 naga::valid::Capabilities::RAY_QUERY
35 | naga::valid::Capabilities::ACCELERATION_STRUCTURE_BINDING_ARRAY,
36 !device_caps.ray_query.is_empty(),
37 );
38 caps.set(
39 naga::valid::Capabilities::DUAL_SOURCE_BLENDING,
40 device_caps.dual_source_blending,
41 );
42 caps.set(
43 naga::valid::Capabilities::SHADER_FLOAT16,
44 device_caps.shader_float16,
45 );
46 caps.set(
47 naga::valid::Capabilities::COOPERATIVE_MATRIX,
48 device_caps.cooperative_matrix.is_supported(),
49 );
50 naga::valid::Validator::new(flags, caps)
51 .validate(module)
52 .map_err(|e| {
53 crate::util::emit_annotated_error(&e, "", source);
54 crate::util::print_err(&e);
55 "validation failed"
56 })
57 }
58
59 pub fn try_create_shader(
60 &self,
61 desc: super::ShaderDesc,
62 ) -> Result<super::Shader, &'static str> {
63 let module = match desc.naga_module {
64 Some(module) => module,
65 None => naga::front::wgsl::parse_str(desc.source).map_err(|e| {
66 eprintln!("{}", e.emit_to_string_with_path(desc.source, ""));
67 "compilation failed"
68 })?,
69 };
70 let info = self.validate_module(&module, desc.source)?;
71 Ok(super::Shader {
72 module,
73 info,
74 source: desc.source.to_owned(),
75 })
76 }
77
78 pub fn create_shader(&self, desc: super::ShaderDesc) -> super::Shader {
79 self.try_create_shader(desc).unwrap()
80 }
81}
82
83pub static EMPTY_CONSTANTS: Lazy<super::PipelineConstants> = Lazy::new(Default::default);
84
85impl super::Shader {
86 pub fn at<'a>(&'a self, entry_point: &'a str) -> super::ShaderFunction<'a> {
87 super::ShaderFunction {
88 shader: self,
89 entry_point,
90 constants: Lazy::force(&EMPTY_CONSTANTS),
91 }
92 }
93
94 pub fn with_constants<'a>(
95 &'a self,
96 entry_point: &'a str,
97 constants: &'a super::PipelineConstants,
98 ) -> super::ShaderFunction<'a> {
99 super::ShaderFunction {
100 shader: self,
101 entry_point,
102 constants,
103 }
104 }
105
106 pub fn resolve_constants<'a>(
107 &'a self,
108 constants: &super::PipelineConstants,
109 ) -> (naga::Module, Cow<'a, naga::valid::ModuleInfo>) {
110 let (module, info) = naga::back::pipeline_constants::process_overrides(
111 &self.module,
112 &self.info,
113 None,
114 constants,
115 )
116 .unwrap();
117 (module.into_owned(), info)
118 }
119
120 pub fn get_struct_size(&self, struct_name: &str) -> u32 {
121 match self
122 .module
123 .types
124 .iter()
125 .find(|&(_, ty)| ty.name.as_deref() == Some(struct_name))
126 {
127 Some((_, ty)) => match ty.inner {
128 naga::TypeInner::Struct { members: _, span } => span,
129 _ => panic!("Type '{struct_name}' is not a struct in the shader"),
130 },
131 None => panic!("Struct '{struct_name}' is not found in the shader"),
132 }
133 }
134
135 pub fn check_struct_size<T>(&self) {
136 use std::{any::type_name, mem::size_of};
137 let name = type_name::<T>().rsplit("::").next().unwrap();
138 assert_eq!(
139 size_of::<T>(),
140 self.get_struct_size(name) as usize,
141 "Host struct '{name}' size doesn't match the shader"
142 );
143 }
144
145 pub(crate) fn fill_resource_bindings(
146 module: &mut naga::Module,
147 sd_infos: &mut [crate::ShaderDataInfo],
148 naga_stage: naga::ShaderStage,
149 ep_info: &naga::valid::FunctionInfo,
150 group_layouts: &[&crate::ShaderDataLayout],
151 ) {
152 let mut layouter = naga::proc::Layouter::default();
153 layouter.update(module.to_ctx()).unwrap();
154
155 for (handle, var) in module.global_variables.iter_mut() {
156 if ep_info[handle].is_empty() {
157 continue;
158 }
159 let var_access = match var.space {
160 naga::AddressSpace::Storage { access } => access,
161 naga::AddressSpace::Uniform | naga::AddressSpace::Handle => {
162 naga::StorageAccess::empty()
163 }
164 _ => continue,
165 };
166
167 assert_eq!(var.binding, None);
168 let var_name = var.name.as_ref().unwrap();
169 for (group_index, (&layout, info)) in
170 group_layouts.iter().zip(sd_infos.iter_mut()).enumerate()
171 {
172 if let Some((binding_index, &(_, proto_binding))) = layout
173 .bindings
174 .iter()
175 .enumerate()
176 .find(|&(_, &(name, _))| name == var_name)
177 {
178 let (expected_proto, access) = match module.types[var.ty].inner {
179 naga::TypeInner::Image {
180 class: naga::ImageClass::Storage { access, format: _ },
181 ..
182 } => (crate::ShaderBinding::Texture, access),
183 naga::TypeInner::Image { .. } => {
184 (crate::ShaderBinding::Texture, naga::StorageAccess::empty())
185 }
186 naga::TypeInner::Sampler { .. } => {
187 (crate::ShaderBinding::Sampler, naga::StorageAccess::empty())
188 }
189 naga::TypeInner::AccelerationStructure { vertex_return: _ } => (
190 crate::ShaderBinding::AccelerationStructure,
191 naga::StorageAccess::empty(),
192 ),
193 naga::TypeInner::BindingArray { base, size: _ } => {
194 let count = match proto_binding {
196 crate::ShaderBinding::TextureArray { count } => count,
197 crate::ShaderBinding::BufferArray { count } => count,
198 crate::ShaderBinding::AccelerationStructureArray { count } => count,
199 _ => 0,
200 };
201 let proto = match module.types[base].inner {
202 naga::TypeInner::Image { .. } => {
203 crate::ShaderBinding::TextureArray { count }
204 }
205 naga::TypeInner::Struct { .. } => {
206 crate::ShaderBinding::BufferArray { count }
207 }
208 naga::TypeInner::AccelerationStructure { .. } => {
209 crate::ShaderBinding::AccelerationStructureArray { count }
210 }
211 ref other => panic!("Unsupported binding array for {:?}", other),
212 };
213 (proto, var_access)
214 }
215 _ => {
216 let type_layout = &layouter[var.ty];
217 let proto = if var_access.is_empty()
218 && proto_binding != crate::ShaderBinding::Buffer
219 {
220 crate::ShaderBinding::Plain {
221 size: type_layout.size,
222 }
223 } else {
224 crate::ShaderBinding::Buffer
225 };
226 (proto, var_access)
227 }
228 };
229 assert_eq!(
230 proto_binding, expected_proto,
231 "Mismatched type for binding '{}'",
232 var_name
233 );
234 assert_eq!(var.binding, None);
235 var.binding = Some(naga::ResourceBinding {
236 group: group_index as u32,
237 binding: binding_index as u32,
238 });
239 info.visibility |= naga_stage.into();
240 info.binding_access[binding_index] |= access;
241 break;
242 }
243 }
244
245 assert!(
246 var.binding.is_some(),
247 "Unable to resolve binding for '{}' in stage '{:?}'",
248 var_name,
249 naga_stage,
250 );
251 }
252 }
253
254 pub(crate) fn fill_vertex_locations(
255 module: &mut naga::Module,
256 selected_ep_index: usize,
257 fetch_states: &[crate::VertexFetchState],
258 ) -> Vec<crate::VertexAttributeMapping> {
259 let mut attribute_mappings = Vec::new();
260 for (ep_index, ep) in module.entry_points.iter().enumerate() {
261 if ep.stage != naga::ShaderStage::Vertex {
262 continue;
263 }
264 if ep_index != selected_ep_index {
265 continue;
266 }
267
268 for argument in ep.function.arguments.iter() {
269 if argument.binding.is_some() {
270 continue;
271 }
272
273 let arg_name = match argument.name {
274 Some(ref name) => name.as_str(),
275 None => "?",
276 };
277 let mut ty = module.types[argument.ty].clone();
278 let members = match ty.inner {
279 naga::TypeInner::Struct {
280 ref mut members, ..
281 } => members,
282 ref other => {
283 log::error!("Unexpected type for '{}': {:?}", arg_name, other);
284 continue;
285 }
286 };
287
288 log::debug!("Processing vertex argument: {}", arg_name);
289 'member: for member in members.iter_mut() {
290 let member_name = match member.name {
291 Some(ref name) => name.as_str(),
292 None => "?",
293 };
294 if let Some(ref binding) = member.binding {
295 log::warn!(
296 "Member '{}' already has binding: {:?}",
297 member_name,
298 binding
299 );
300 continue;
301 }
302 let binding = naga::Binding::Location {
303 location: attribute_mappings.len() as u32,
304 interpolation: None,
305 sampling: None,
306 blend_src: None,
307 per_primitive: false,
308 };
309 for (buffer_index, vertex_fetch) in fetch_states.iter().enumerate() {
310 for (attribute_index, &(at_name, _)) in
311 vertex_fetch.layout.attributes.iter().enumerate()
312 {
313 if at_name == member_name {
314 log::debug!(
315 "Assigning location({}) for member '{}' to be using input {}:{}",
316 attribute_mappings.len(),
317 member_name,
318 buffer_index,
319 attribute_index
320 );
321 member.binding = Some(binding);
322 attribute_mappings.push(crate::VertexAttributeMapping {
323 buffer_index,
324 attribute_index,
325 });
326 continue 'member;
327 }
328 }
329 }
330 assert_ne!(
331 member.binding, None,
332 "Field {} is not covered by the vertex fetch layouts!",
333 member_name
334 );
335 }
336 module.types.replace(argument.ty, ty);
337 }
338 }
339 attribute_mappings
340 }
341}