Skip to main content

blade_graphics/
shader.rs

1use once_cell::sync::Lazy;
2use std::borrow::Cow;
3
4impl From<naga::ShaderStage> for super::ShaderVisibility {
5    fn from(stage: naga::ShaderStage) -> Self {
6        match stage {
7            naga::ShaderStage::Compute => Self::COMPUTE,
8            naga::ShaderStage::Vertex => Self::VERTEX,
9            naga::ShaderStage::Fragment => Self::FRAGMENT,
10            _ => Self::empty(),
11        }
12    }
13}
14
15impl super::Context {
16    fn validate_module(
17        &self,
18        module: &naga::Module,
19        source: &str,
20    ) -> Result<naga::valid::ModuleInfo, &'static str> {
21        let device_caps = self.capabilities();
22
23        // Bindings are set up at pipeline creation, ignore here
24        let flags = naga::valid::ValidationFlags::all() ^ naga::valid::ValidationFlags::BINDINGS;
25        let mut caps = naga::valid::Capabilities::empty();
26        caps.set(
27            naga::valid::Capabilities::STORAGE_BUFFER_BINDING_ARRAY
28                | naga::valid::Capabilities::TEXTURE_AND_SAMPLER_BINDING_ARRAY
29                | naga::valid::Capabilities::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING
30                | naga::valid::Capabilities::STORAGE_BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING,
31            device_caps.binding_array,
32        );
33        caps.set(
34            naga::valid::Capabilities::RAY_QUERY
35                | naga::valid::Capabilities::ACCELERATION_STRUCTURE_BINDING_ARRAY,
36            !device_caps.ray_query.is_empty(),
37        );
38        caps.set(
39            naga::valid::Capabilities::DUAL_SOURCE_BLENDING,
40            device_caps.dual_source_blending,
41        );
42        caps.set(
43            naga::valid::Capabilities::SHADER_FLOAT16,
44            device_caps.shader_float16,
45        );
46        caps.set(
47            naga::valid::Capabilities::COOPERATIVE_MATRIX,
48            device_caps.cooperative_matrix.is_supported(),
49        );
50        naga::valid::Validator::new(flags, caps)
51            .validate(module)
52            .map_err(|e| {
53                crate::util::emit_annotated_error(&e, "", source);
54                crate::util::print_err(&e);
55                "validation failed"
56            })
57    }
58
59    pub fn try_create_shader(
60        &self,
61        desc: super::ShaderDesc,
62    ) -> Result<super::Shader, &'static str> {
63        let module = match desc.naga_module {
64            Some(module) => module,
65            None => naga::front::wgsl::parse_str(desc.source).map_err(|e| {
66                eprintln!("{}", e.emit_to_string_with_path(desc.source, ""));
67                "compilation failed"
68            })?,
69        };
70        let info = self.validate_module(&module, desc.source)?;
71        Ok(super::Shader {
72            module,
73            info,
74            source: desc.source.to_owned(),
75        })
76    }
77
78    pub fn create_shader(&self, desc: super::ShaderDesc) -> super::Shader {
79        self.try_create_shader(desc).unwrap()
80    }
81}
82
83pub static EMPTY_CONSTANTS: Lazy<super::PipelineConstants> = Lazy::new(Default::default);
84
85impl super::Shader {
86    pub fn at<'a>(&'a self, entry_point: &'a str) -> super::ShaderFunction<'a> {
87        super::ShaderFunction {
88            shader: self,
89            entry_point,
90            constants: Lazy::force(&EMPTY_CONSTANTS),
91        }
92    }
93
94    pub fn with_constants<'a>(
95        &'a self,
96        entry_point: &'a str,
97        constants: &'a super::PipelineConstants,
98    ) -> super::ShaderFunction<'a> {
99        super::ShaderFunction {
100            shader: self,
101            entry_point,
102            constants,
103        }
104    }
105
106    pub fn resolve_constants<'a>(
107        &'a self,
108        constants: &super::PipelineConstants,
109    ) -> (naga::Module, Cow<'a, naga::valid::ModuleInfo>) {
110        let (module, info) = naga::back::pipeline_constants::process_overrides(
111            &self.module,
112            &self.info,
113            None,
114            constants,
115        )
116        .unwrap();
117        (module.into_owned(), info)
118    }
119
120    pub fn get_struct_size(&self, struct_name: &str) -> u32 {
121        match self
122            .module
123            .types
124            .iter()
125            .find(|&(_, ty)| ty.name.as_deref() == Some(struct_name))
126        {
127            Some((_, ty)) => match ty.inner {
128                naga::TypeInner::Struct { members: _, span } => span,
129                _ => panic!("Type '{struct_name}' is not a struct in the shader"),
130            },
131            None => panic!("Struct '{struct_name}' is not found in the shader"),
132        }
133    }
134
135    pub fn check_struct_size<T>(&self) {
136        use std::{any::type_name, mem::size_of};
137        let name = type_name::<T>().rsplit("::").next().unwrap();
138        assert_eq!(
139            size_of::<T>(),
140            self.get_struct_size(name) as usize,
141            "Host struct '{name}' size doesn't match the shader"
142        );
143    }
144
145    pub(crate) fn fill_resource_bindings(
146        module: &mut naga::Module,
147        sd_infos: &mut [crate::ShaderDataInfo],
148        naga_stage: naga::ShaderStage,
149        ep_info: &naga::valid::FunctionInfo,
150        group_layouts: &[&crate::ShaderDataLayout],
151    ) {
152        let mut layouter = naga::proc::Layouter::default();
153        layouter.update(module.to_ctx()).unwrap();
154
155        for (handle, var) in module.global_variables.iter_mut() {
156            if ep_info[handle].is_empty() {
157                continue;
158            }
159            let var_access = match var.space {
160                naga::AddressSpace::Storage { access } => access,
161                naga::AddressSpace::Uniform | naga::AddressSpace::Handle => {
162                    naga::StorageAccess::empty()
163                }
164                _ => continue,
165            };
166
167            assert_eq!(var.binding, None);
168            let var_name = var.name.as_ref().unwrap();
169            for (group_index, (&layout, info)) in
170                group_layouts.iter().zip(sd_infos.iter_mut()).enumerate()
171            {
172                if let Some((binding_index, &(_, proto_binding))) = layout
173                    .bindings
174                    .iter()
175                    .enumerate()
176                    .find(|&(_, &(name, _))| name == var_name)
177                {
178                    let (expected_proto, access) = match module.types[var.ty].inner {
179                        naga::TypeInner::Image {
180                            class: naga::ImageClass::Storage { access, format: _ },
181                            ..
182                        } => (crate::ShaderBinding::Texture, access),
183                        naga::TypeInner::Image { .. } => {
184                            (crate::ShaderBinding::Texture, naga::StorageAccess::empty())
185                        }
186                        naga::TypeInner::Sampler { .. } => {
187                            (crate::ShaderBinding::Sampler, naga::StorageAccess::empty())
188                        }
189                        naga::TypeInner::AccelerationStructure { vertex_return: _ } => (
190                            crate::ShaderBinding::AccelerationStructure,
191                            naga::StorageAccess::empty(),
192                        ),
193                        naga::TypeInner::BindingArray { base, size: _ } => {
194                            //Note: we could extract the count from `size` for more rigor
195                            let count = match proto_binding {
196                                crate::ShaderBinding::TextureArray { count } => count,
197                                crate::ShaderBinding::BufferArray { count } => count,
198                                crate::ShaderBinding::AccelerationStructureArray { count } => count,
199                                _ => 0,
200                            };
201                            let proto = match module.types[base].inner {
202                                naga::TypeInner::Image { .. } => {
203                                    crate::ShaderBinding::TextureArray { count }
204                                }
205                                naga::TypeInner::Struct { .. } => {
206                                    crate::ShaderBinding::BufferArray { count }
207                                }
208                                naga::TypeInner::AccelerationStructure { .. } => {
209                                    crate::ShaderBinding::AccelerationStructureArray { count }
210                                }
211                                ref other => panic!("Unsupported binding array for {:?}", other),
212                            };
213                            (proto, var_access)
214                        }
215                        _ => {
216                            let type_layout = &layouter[var.ty];
217                            let proto = if var_access.is_empty()
218                                && proto_binding != crate::ShaderBinding::Buffer
219                            {
220                                crate::ShaderBinding::Plain {
221                                    size: type_layout.size,
222                                }
223                            } else {
224                                crate::ShaderBinding::Buffer
225                            };
226                            (proto, var_access)
227                        }
228                    };
229                    assert_eq!(
230                        proto_binding, expected_proto,
231                        "Mismatched type for binding '{}'",
232                        var_name
233                    );
234                    assert_eq!(var.binding, None);
235                    var.binding = Some(naga::ResourceBinding {
236                        group: group_index as u32,
237                        binding: binding_index as u32,
238                    });
239                    info.visibility |= naga_stage.into();
240                    info.binding_access[binding_index] |= access;
241                    break;
242                }
243            }
244
245            assert!(
246                var.binding.is_some(),
247                "Unable to resolve binding for '{}' in stage '{:?}'",
248                var_name,
249                naga_stage,
250            );
251        }
252    }
253
254    pub(crate) fn fill_vertex_locations(
255        module: &mut naga::Module,
256        selected_ep_index: usize,
257        fetch_states: &[crate::VertexFetchState],
258    ) -> Vec<crate::VertexAttributeMapping> {
259        let mut attribute_mappings = Vec::new();
260        for (ep_index, ep) in module.entry_points.iter().enumerate() {
261            if ep.stage != naga::ShaderStage::Vertex {
262                continue;
263            }
264            if ep_index != selected_ep_index {
265                continue;
266            }
267
268            for argument in ep.function.arguments.iter() {
269                if argument.binding.is_some() {
270                    continue;
271                }
272
273                let arg_name = match argument.name {
274                    Some(ref name) => name.as_str(),
275                    None => "?",
276                };
277                let mut ty = module.types[argument.ty].clone();
278                let members = match ty.inner {
279                    naga::TypeInner::Struct {
280                        ref mut members, ..
281                    } => members,
282                    ref other => {
283                        log::error!("Unexpected type for '{}': {:?}", arg_name, other);
284                        continue;
285                    }
286                };
287
288                log::debug!("Processing vertex argument: {}", arg_name);
289                'member: for member in members.iter_mut() {
290                    let member_name = match member.name {
291                        Some(ref name) => name.as_str(),
292                        None => "?",
293                    };
294                    if let Some(ref binding) = member.binding {
295                        log::warn!(
296                            "Member '{}' already has binding: {:?}",
297                            member_name,
298                            binding
299                        );
300                        continue;
301                    }
302                    let binding = naga::Binding::Location {
303                        location: attribute_mappings.len() as u32,
304                        interpolation: None,
305                        sampling: None,
306                        blend_src: None,
307                        per_primitive: false,
308                    };
309                    for (buffer_index, vertex_fetch) in fetch_states.iter().enumerate() {
310                        for (attribute_index, &(at_name, _)) in
311                            vertex_fetch.layout.attributes.iter().enumerate()
312                        {
313                            if at_name == member_name {
314                                log::debug!(
315                                    "Assigning location({}) for member '{}' to be using input {}:{}",
316                                    attribute_mappings.len(),
317                                    member_name,
318                                    buffer_index,
319                                    attribute_index
320                                );
321                                member.binding = Some(binding);
322                                attribute_mappings.push(crate::VertexAttributeMapping {
323                                    buffer_index,
324                                    attribute_index,
325                                });
326                                continue 'member;
327                            }
328                        }
329                    }
330                    assert_ne!(
331                        member.binding, None,
332                        "Field {} is not covered by the vertex fetch layouts!",
333                        member_name
334                    );
335                }
336                module.types.replace(argument.ty, ty);
337            }
338        }
339        attribute_mappings
340    }
341}