1use once_cell::sync::Lazy;
2use std::borrow::Cow;
3
4impl From<naga::ShaderStage> for super::ShaderVisibility {
5 fn from(stage: naga::ShaderStage) -> Self {
6 match stage {
7 naga::ShaderStage::Compute => Self::COMPUTE,
8 naga::ShaderStage::Vertex => Self::VERTEX,
9 naga::ShaderStage::Fragment => Self::FRAGMENT,
10 _ => Self::empty(),
11 }
12 }
13}
14
15impl super::Context {
16 pub fn try_create_shader(
17 &self,
18 desc: super::ShaderDesc,
19 ) -> Result<super::Shader, &'static str> {
20 let module = naga::front::wgsl::parse_str(desc.source).map_err(|e| {
21 e.emit_to_stderr_with_path(desc.source, "");
22 "compilation failed"
23 })?;
24
25 let device_caps = self.capabilities();
26
27 let flags = naga::valid::ValidationFlags::all() ^ naga::valid::ValidationFlags::BINDINGS;
29 let mut caps = naga::valid::Capabilities::empty();
30 caps.set(
31 naga::valid::Capabilities::RAY_QUERY | naga::valid::Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
32 !device_caps.ray_query.is_empty(),
33 );
34 let info = naga::valid::Validator::new(flags, caps)
35 .validate(&module)
36 .map_err(|e| {
37 crate::util::emit_annotated_error(&e, "", desc.source);
38 crate::util::print_err(&e);
39 "validation failed"
40 })?;
41
42 Ok(super::Shader {
43 module,
44 info,
45 source: desc.source.to_owned(),
46 })
47 }
48
49 pub fn create_shader(&self, desc: super::ShaderDesc) -> super::Shader {
50 self.try_create_shader(desc).unwrap()
51 }
52}
53
54pub static EMPTY_CONSTANTS: Lazy<super::PipelineConstants> = Lazy::new(Default::default);
55
56impl super::Shader {
57 pub fn at<'a>(&'a self, entry_point: &'a str) -> super::ShaderFunction<'a> {
58 super::ShaderFunction {
59 shader: self,
60 entry_point,
61 constants: Lazy::force(&EMPTY_CONSTANTS),
62 }
63 }
64
65 pub fn with_constants<'a>(
66 &'a self,
67 entry_point: &'a str,
68 constants: &'a super::PipelineConstants,
69 ) -> super::ShaderFunction<'a> {
70 super::ShaderFunction {
71 shader: self,
72 entry_point,
73 constants,
74 }
75 }
76
77 pub fn resolve_constants<'a>(
78 &'a self,
79 constants: &super::PipelineConstants,
80 ) -> (naga::Module, Cow<'a, naga::valid::ModuleInfo>) {
81 let (module, info) =
82 naga::back::pipeline_constants::process_overrides(&self.module, &self.info, constants)
83 .unwrap();
84 (module.into_owned(), info)
85 }
86
87 pub fn get_struct_size(&self, struct_name: &str) -> u32 {
88 match self
89 .module
90 .types
91 .iter()
92 .find(|&(_, ty)| ty.name.as_deref() == Some(struct_name))
93 {
94 Some((_, ty)) => match ty.inner {
95 naga::TypeInner::Struct { members: _, span } => span,
96 _ => panic!("Type '{struct_name}' is not a struct in the shader"),
97 },
98 None => panic!("Struct '{struct_name}' is not found in the shader"),
99 }
100 }
101
102 pub fn check_struct_size<T>(&self) {
103 use std::{any::type_name, mem::size_of};
104 let name = type_name::<T>().rsplit("::").next().unwrap();
105 assert_eq!(
106 size_of::<T>(),
107 self.get_struct_size(name) as usize,
108 "Host struct '{name}' size doesn't match the shader"
109 );
110 }
111
112 pub(crate) fn fill_resource_bindings(
113 module: &mut naga::Module,
114 sd_infos: &mut [crate::ShaderDataInfo],
115 naga_stage: naga::ShaderStage,
116 ep_info: &naga::valid::FunctionInfo,
117 group_layouts: &[&crate::ShaderDataLayout],
118 ) {
119 let mut layouter = naga::proc::Layouter::default();
120 layouter.update(module.to_ctx()).unwrap();
121
122 for (handle, var) in module.global_variables.iter_mut() {
123 if ep_info[handle].is_empty() {
124 continue;
125 }
126 let var_access = match var.space {
127 naga::AddressSpace::Storage { access } => access,
128 naga::AddressSpace::Uniform | naga::AddressSpace::Handle => {
129 naga::StorageAccess::empty()
130 }
131 _ => continue,
132 };
133
134 assert_eq!(var.binding, None);
135 let var_name = var.name.as_ref().unwrap();
136 for (group_index, (&layout, info)) in
137 group_layouts.iter().zip(sd_infos.iter_mut()).enumerate()
138 {
139 if let Some((binding_index, &(_, proto_binding))) = layout
140 .bindings
141 .iter()
142 .enumerate()
143 .find(|&(_, &(name, _))| name == var_name)
144 {
145 let (expected_proto, access) = match module.types[var.ty].inner {
146 naga::TypeInner::Image {
147 class: naga::ImageClass::Storage { access, format: _ },
148 ..
149 } => (crate::ShaderBinding::Texture, access),
150 naga::TypeInner::Image { .. } => {
151 (crate::ShaderBinding::Texture, naga::StorageAccess::empty())
152 }
153 naga::TypeInner::Sampler { .. } => {
154 (crate::ShaderBinding::Sampler, naga::StorageAccess::empty())
155 }
156 naga::TypeInner::AccelerationStructure { vertex_return: _ } => (
157 crate::ShaderBinding::AccelerationStructure,
158 naga::StorageAccess::empty(),
159 ),
160 naga::TypeInner::BindingArray { base, size: _ } => {
161 let count = match proto_binding {
163 crate::ShaderBinding::TextureArray { count } => count,
164 crate::ShaderBinding::BufferArray { count } => count,
165 _ => 0,
166 };
167 let proto = match module.types[base].inner {
168 naga::TypeInner::Image { .. } => {
169 crate::ShaderBinding::TextureArray { count }
170 }
171 naga::TypeInner::Struct { .. } => {
172 crate::ShaderBinding::BufferArray { count }
173 }
174 ref other => panic!("Unsupported binding array for {:?}", other),
175 };
176 (proto, var_access)
177 }
178 _ => {
179 let type_layout = &layouter[var.ty];
180 let proto = if var_access.is_empty()
181 && proto_binding != crate::ShaderBinding::Buffer
182 {
183 crate::ShaderBinding::Plain {
184 size: type_layout.size,
185 }
186 } else {
187 crate::ShaderBinding::Buffer
188 };
189 (proto, var_access)
190 }
191 };
192 assert_eq!(
193 proto_binding, expected_proto,
194 "Mismatched type for binding '{}'",
195 var_name
196 );
197 assert_eq!(var.binding, None);
198 var.binding = Some(naga::ResourceBinding {
199 group: group_index as u32,
200 binding: binding_index as u32,
201 });
202 info.visibility |= naga_stage.into();
203 info.binding_access[binding_index] |= access;
204 break;
205 }
206 }
207
208 assert!(
209 var.binding.is_some(),
210 "Unable to resolve binding for '{}' in stage '{:?}'",
211 var_name,
212 naga_stage,
213 );
214 }
215 }
216
217 pub(crate) fn fill_vertex_locations(
218 module: &mut naga::Module,
219 selected_ep_index: usize,
220 fetch_states: &[crate::VertexFetchState],
221 ) -> Vec<crate::VertexAttributeMapping> {
222 let mut attribute_mappings = Vec::new();
223 for (ep_index, ep) in module.entry_points.iter().enumerate() {
224 let mut location = 0;
225 if ep.stage != naga::ShaderStage::Vertex {
226 continue;
227 }
228
229 for argument in ep.function.arguments.iter() {
230 if argument.binding.is_some() {
231 continue;
232 }
233
234 let arg_name = match argument.name {
235 Some(ref name) => name.as_str(),
236 None => "?",
237 };
238 let mut ty = module.types[argument.ty].clone();
239 let members = match ty.inner {
240 naga::TypeInner::Struct {
241 ref mut members, ..
242 } => members,
243 ref other => {
244 log::error!("Unexpected type for '{}': {:?}", arg_name, other);
245 continue;
246 }
247 };
248
249 if ep_index == selected_ep_index {
250 log::debug!("Processing vertex argument: {}", arg_name);
251
252 'member: for member in members.iter_mut() {
253 let member_name = match member.name {
254 Some(ref name) => name.as_str(),
255 None => "?",
256 };
257 if let Some(ref binding) = member.binding {
258 log::warn!(
259 "Member '{}' alread has binding: {:?}",
260 member_name,
261 binding
262 );
263 continue;
264 }
265 let binding = naga::Binding::Location {
266 location: attribute_mappings.len() as u32,
267 interpolation: None,
268 sampling: None,
269 blend_src: None,
270 };
271 for (buffer_index, vertex_fetch) in fetch_states.iter().enumerate() {
272 for (attribute_index, &(at_name, _)) in
273 vertex_fetch.layout.attributes.iter().enumerate()
274 {
275 if at_name == member_name {
276 log::debug!("Assigning location({}) for member '{}' to be using input {}:{}",
277 attribute_mappings.len(), member_name, buffer_index, attribute_index);
278 member.binding = Some(binding);
279 attribute_mappings.push(crate::VertexAttributeMapping {
280 buffer_index,
281 attribute_index,
282 });
283 continue 'member;
284 }
285 }
286 }
287 assert_ne!(
288 member.binding, None,
289 "Field {} is not covered by the vertex fetch layouts!",
290 member_name
291 );
292 }
293 } else {
294 for member in members.iter_mut() {
296 if member.binding.is_none() {
297 member.binding = Some(naga::Binding::Location {
298 location,
299 interpolation: None,
300 sampling: None,
301 blend_src: None,
302 });
303 location += 1;
304 }
305 }
306 }
307
308 module.types.replace(argument.ty, ty);
309 }
310 }
311 attribute_mappings
312 }
313}