blade_graphics/
shader.rs

1use once_cell::sync::Lazy;
2use std::borrow::Cow;
3
4impl From<naga::ShaderStage> for super::ShaderVisibility {
5    fn from(stage: naga::ShaderStage) -> Self {
6        match stage {
7            naga::ShaderStage::Compute => Self::COMPUTE,
8            naga::ShaderStage::Vertex => Self::VERTEX,
9            naga::ShaderStage::Fragment => Self::FRAGMENT,
10            _ => Self::empty(),
11        }
12    }
13}
14
15impl super::Context {
16    pub fn try_create_shader(
17        &self,
18        desc: super::ShaderDesc,
19    ) -> Result<super::Shader, &'static str> {
20        let module = naga::front::wgsl::parse_str(desc.source).map_err(|e| {
21            e.emit_to_stderr_with_path(desc.source, "");
22            "compilation failed"
23        })?;
24
25        let device_caps = self.capabilities();
26
27        // Bindings are set up at pipeline creation, ignore here
28        let flags = naga::valid::ValidationFlags::all() ^ naga::valid::ValidationFlags::BINDINGS;
29        let mut caps = naga::valid::Capabilities::empty();
30        caps.set(
31            naga::valid::Capabilities::RAY_QUERY | naga::valid::Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
32            !device_caps.ray_query.is_empty(),
33        );
34        let info = naga::valid::Validator::new(flags, caps)
35            .validate(&module)
36            .map_err(|e| {
37                crate::util::emit_annotated_error(&e, "", desc.source);
38                crate::util::print_err(&e);
39                "validation failed"
40            })?;
41
42        Ok(super::Shader {
43            module,
44            info,
45            source: desc.source.to_owned(),
46        })
47    }
48
49    pub fn create_shader(&self, desc: super::ShaderDesc) -> super::Shader {
50        self.try_create_shader(desc).unwrap()
51    }
52}
53
54pub static EMPTY_CONSTANTS: Lazy<super::PipelineConstants> = Lazy::new(Default::default);
55
56impl super::Shader {
57    pub fn at<'a>(&'a self, entry_point: &'a str) -> super::ShaderFunction<'a> {
58        super::ShaderFunction {
59            shader: self,
60            entry_point,
61            constants: Lazy::force(&EMPTY_CONSTANTS),
62        }
63    }
64
65    pub fn with_constants<'a>(
66        &'a self,
67        entry_point: &'a str,
68        constants: &'a super::PipelineConstants,
69    ) -> super::ShaderFunction<'a> {
70        super::ShaderFunction {
71            shader: self,
72            entry_point,
73            constants,
74        }
75    }
76
77    pub fn resolve_constants<'a>(
78        &'a self,
79        constants: &super::PipelineConstants,
80    ) -> (naga::Module, Cow<'a, naga::valid::ModuleInfo>) {
81        let (module, info) =
82            naga::back::pipeline_constants::process_overrides(&self.module, &self.info, constants)
83                .unwrap();
84        (module.into_owned(), info)
85    }
86
87    pub fn get_struct_size(&self, struct_name: &str) -> u32 {
88        match self
89            .module
90            .types
91            .iter()
92            .find(|&(_, ty)| ty.name.as_deref() == Some(struct_name))
93        {
94            Some((_, ty)) => match ty.inner {
95                naga::TypeInner::Struct { members: _, span } => span,
96                _ => panic!("Type '{struct_name}' is not a struct in the shader"),
97            },
98            None => panic!("Struct '{struct_name}' is not found in the shader"),
99        }
100    }
101
102    pub fn check_struct_size<T>(&self) {
103        use std::{any::type_name, mem::size_of};
104        let name = type_name::<T>().rsplit("::").next().unwrap();
105        assert_eq!(
106            size_of::<T>(),
107            self.get_struct_size(name) as usize,
108            "Host struct '{name}' size doesn't match the shader"
109        );
110    }
111
112    pub(crate) fn fill_resource_bindings(
113        module: &mut naga::Module,
114        sd_infos: &mut [crate::ShaderDataInfo],
115        naga_stage: naga::ShaderStage,
116        ep_info: &naga::valid::FunctionInfo,
117        group_layouts: &[&crate::ShaderDataLayout],
118    ) {
119        let mut layouter = naga::proc::Layouter::default();
120        layouter.update(module.to_ctx()).unwrap();
121
122        for (handle, var) in module.global_variables.iter_mut() {
123            if ep_info[handle].is_empty() {
124                continue;
125            }
126            let var_access = match var.space {
127                naga::AddressSpace::Storage { access } => access,
128                naga::AddressSpace::Uniform | naga::AddressSpace::Handle => {
129                    naga::StorageAccess::empty()
130                }
131                _ => continue,
132            };
133
134            assert_eq!(var.binding, None);
135            let var_name = var.name.as_ref().unwrap();
136            for (group_index, (&layout, info)) in
137                group_layouts.iter().zip(sd_infos.iter_mut()).enumerate()
138            {
139                if let Some((binding_index, &(_, proto_binding))) = layout
140                    .bindings
141                    .iter()
142                    .enumerate()
143                    .find(|&(_, &(name, _))| name == var_name)
144                {
145                    let (expected_proto, access) = match module.types[var.ty].inner {
146                        naga::TypeInner::Image {
147                            class: naga::ImageClass::Storage { access, format: _ },
148                            ..
149                        } => (crate::ShaderBinding::Texture, access),
150                        naga::TypeInner::Image { .. } => {
151                            (crate::ShaderBinding::Texture, naga::StorageAccess::empty())
152                        }
153                        naga::TypeInner::Sampler { .. } => {
154                            (crate::ShaderBinding::Sampler, naga::StorageAccess::empty())
155                        }
156                        naga::TypeInner::AccelerationStructure { vertex_return: _ } => (
157                            crate::ShaderBinding::AccelerationStructure,
158                            naga::StorageAccess::empty(),
159                        ),
160                        naga::TypeInner::BindingArray { base, size: _ } => {
161                            //Note: we could extract the count from `size` for more rigor
162                            let count = match proto_binding {
163                                crate::ShaderBinding::TextureArray { count } => count,
164                                crate::ShaderBinding::BufferArray { count } => count,
165                                _ => 0,
166                            };
167                            let proto = match module.types[base].inner {
168                                naga::TypeInner::Image { .. } => {
169                                    crate::ShaderBinding::TextureArray { count }
170                                }
171                                naga::TypeInner::Struct { .. } => {
172                                    crate::ShaderBinding::BufferArray { count }
173                                }
174                                ref other => panic!("Unsupported binding array for {:?}", other),
175                            };
176                            (proto, var_access)
177                        }
178                        _ => {
179                            let type_layout = &layouter[var.ty];
180                            let proto = if var_access.is_empty()
181                                && proto_binding != crate::ShaderBinding::Buffer
182                            {
183                                crate::ShaderBinding::Plain {
184                                    size: type_layout.size,
185                                }
186                            } else {
187                                crate::ShaderBinding::Buffer
188                            };
189                            (proto, var_access)
190                        }
191                    };
192                    assert_eq!(
193                        proto_binding, expected_proto,
194                        "Mismatched type for binding '{}'",
195                        var_name
196                    );
197                    assert_eq!(var.binding, None);
198                    var.binding = Some(naga::ResourceBinding {
199                        group: group_index as u32,
200                        binding: binding_index as u32,
201                    });
202                    info.visibility |= naga_stage.into();
203                    info.binding_access[binding_index] |= access;
204                    break;
205                }
206            }
207
208            assert!(
209                var.binding.is_some(),
210                "Unable to resolve binding for '{}' in stage '{:?}'",
211                var_name,
212                naga_stage,
213            );
214        }
215    }
216
217    pub(crate) fn fill_vertex_locations(
218        module: &mut naga::Module,
219        selected_ep_index: usize,
220        fetch_states: &[crate::VertexFetchState],
221    ) -> Vec<crate::VertexAttributeMapping> {
222        let mut attribute_mappings = Vec::new();
223        for (ep_index, ep) in module.entry_points.iter().enumerate() {
224            let mut location = 0;
225            if ep.stage != naga::ShaderStage::Vertex {
226                continue;
227            }
228
229            for argument in ep.function.arguments.iter() {
230                if argument.binding.is_some() {
231                    continue;
232                }
233
234                let arg_name = match argument.name {
235                    Some(ref name) => name.as_str(),
236                    None => "?",
237                };
238                let mut ty = module.types[argument.ty].clone();
239                let members = match ty.inner {
240                    naga::TypeInner::Struct {
241                        ref mut members, ..
242                    } => members,
243                    ref other => {
244                        log::error!("Unexpected type for '{}': {:?}", arg_name, other);
245                        continue;
246                    }
247                };
248
249                if ep_index == selected_ep_index {
250                    log::debug!("Processing vertex argument: {}", arg_name);
251
252                    'member: for member in members.iter_mut() {
253                        let member_name = match member.name {
254                            Some(ref name) => name.as_str(),
255                            None => "?",
256                        };
257                        if let Some(ref binding) = member.binding {
258                            log::warn!(
259                                "Member '{}' alread has binding: {:?}",
260                                member_name,
261                                binding
262                            );
263                            continue;
264                        }
265                        let binding = naga::Binding::Location {
266                            location: attribute_mappings.len() as u32,
267                            interpolation: None,
268                            sampling: None,
269                            blend_src: None,
270                        };
271                        for (buffer_index, vertex_fetch) in fetch_states.iter().enumerate() {
272                            for (attribute_index, &(at_name, _)) in
273                                vertex_fetch.layout.attributes.iter().enumerate()
274                            {
275                                if at_name == member_name {
276                                    log::debug!("Assigning location({}) for member '{}' to be using input {}:{}",
277                                        attribute_mappings.len(), member_name, buffer_index, attribute_index);
278                                    member.binding = Some(binding);
279                                    attribute_mappings.push(crate::VertexAttributeMapping {
280                                        buffer_index,
281                                        attribute_index,
282                                    });
283                                    continue 'member;
284                                }
285                            }
286                        }
287                        assert_ne!(
288                            member.binding, None,
289                            "Field {} is not covered by the vertex fetch layouts!",
290                            member_name
291                        );
292                    }
293                } else {
294                    // Just fill out the locations for the module to be valid
295                    for member in members.iter_mut() {
296                        if member.binding.is_none() {
297                            member.binding = Some(naga::Binding::Location {
298                                location,
299                                interpolation: None,
300                                sampling: None,
301                                blend_src: None,
302                            });
303                            location += 1;
304                        }
305                    }
306                }
307
308                module.types.replace(argument.ty, ty);
309            }
310        }
311        attribute_mappings
312    }
313}