Skip to main content

blade_graphics/
shader.rs

1use once_cell::sync::Lazy;
2use std::borrow::Cow;
3
4impl From<naga::ShaderStage> for super::ShaderVisibility {
5    fn from(stage: naga::ShaderStage) -> Self {
6        match stage {
7            naga::ShaderStage::Compute => Self::COMPUTE,
8            naga::ShaderStage::Vertex => Self::VERTEX,
9            naga::ShaderStage::Fragment => Self::FRAGMENT,
10            _ => Self::empty(),
11        }
12    }
13}
14
15impl super::Context {
16    fn validate_module(
17        &self,
18        module: &naga::Module,
19        source: &str,
20    ) -> Result<naga::valid::ModuleInfo, &'static str> {
21        let device_caps = self.capabilities();
22
23        // Bindings are set up at pipeline creation, ignore here
24        let flags = naga::valid::ValidationFlags::all() ^ naga::valid::ValidationFlags::BINDINGS;
25        let mut caps = naga::valid::Capabilities::empty();
26        caps.set(
27            naga::valid::Capabilities::STORAGE_BUFFER_BINDING_ARRAY
28                | naga::valid::Capabilities::TEXTURE_AND_SAMPLER_BINDING_ARRAY
29                | naga::valid::Capabilities::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING
30                | naga::valid::Capabilities::STORAGE_BUFFER_BINDING_ARRAY_NON_UNIFORM_INDEXING,
31            device_caps.binding_array,
32        );
33        caps.set(
34            naga::valid::Capabilities::RAY_QUERY
35                | naga::valid::Capabilities::ACCELERATION_STRUCTURE_BINDING_ARRAY,
36            !device_caps.ray_query.is_empty(),
37        );
38        caps.set(
39            naga::valid::Capabilities::DUAL_SOURCE_BLENDING,
40            device_caps.dual_source_blending,
41        );
42        caps.set(
43            naga::valid::Capabilities::SHADER_FLOAT16,
44            device_caps.shader_float16,
45        );
46        caps.set(
47            naga::valid::Capabilities::COOPERATIVE_MATRIX,
48            device_caps.cooperative_matrix.is_supported(),
49        );
50        // Subgroup operations (subgroupAdd, subgroupBroadcast, etc.) are
51        // broadly supported on modern GPUs via VK_EXT_subgroup_size_control.
52        // Enable unconditionally so naga validates subgroup ops and emits
53        // the correct SPIR-V capabilities (GroupNonUniform, etc.).
54        caps.set(naga::valid::Capabilities::SUBGROUP, true);
55
56        naga::valid::Validator::new(flags, caps)
57            .validate(module)
58            .map_err(|e| {
59                crate::util::emit_annotated_error(&e, "", source);
60                crate::util::print_err(&e);
61                "validation failed"
62            })
63    }
64
65    pub fn try_create_shader(
66        &self,
67        desc: super::ShaderDesc,
68    ) -> Result<super::Shader, &'static str> {
69        let module = match desc.naga_module {
70            Some(module) => module,
71            None => naga::front::wgsl::parse_str(desc.source).map_err(|e| {
72                eprintln!("{}", e.emit_to_string_with_path(desc.source, ""));
73                "compilation failed"
74            })?,
75        };
76        let info = self.validate_module(&module, desc.source)?;
77        Ok(super::Shader {
78            module,
79            info,
80            source: desc.source.to_owned(),
81        })
82    }
83
84    pub fn create_shader(&self, desc: super::ShaderDesc) -> super::Shader {
85        self.try_create_shader(desc).unwrap()
86    }
87}
88
89pub static EMPTY_CONSTANTS: Lazy<super::PipelineConstants> = Lazy::new(Default::default);
90
91impl super::Shader {
92    pub fn at<'a>(&'a self, entry_point: &'a str) -> super::ShaderFunction<'a> {
93        super::ShaderFunction {
94            shader: self,
95            entry_point,
96            constants: Lazy::force(&EMPTY_CONSTANTS),
97        }
98    }
99
100    pub fn with_constants<'a>(
101        &'a self,
102        entry_point: &'a str,
103        constants: &'a super::PipelineConstants,
104    ) -> super::ShaderFunction<'a> {
105        super::ShaderFunction {
106            shader: self,
107            entry_point,
108            constants,
109        }
110    }
111
112    pub fn resolve_constants<'a>(
113        &'a self,
114        constants: &super::PipelineConstants,
115    ) -> (naga::Module, Cow<'a, naga::valid::ModuleInfo>) {
116        let (module, info) = naga::back::pipeline_constants::process_overrides(
117            &self.module,
118            &self.info,
119            None,
120            constants,
121        )
122        .unwrap();
123        (module.into_owned(), info)
124    }
125
126    pub fn get_struct_size(&self, struct_name: &str) -> u32 {
127        match self
128            .module
129            .types
130            .iter()
131            .find(|&(_, ty)| ty.name.as_deref() == Some(struct_name))
132        {
133            Some((_, ty)) => match ty.inner {
134                naga::TypeInner::Struct { members: _, span } => span,
135                _ => panic!("Type '{struct_name}' is not a struct in the shader"),
136            },
137            None => panic!("Struct '{struct_name}' is not found in the shader"),
138        }
139    }
140
141    pub fn check_struct_size<T>(&self) {
142        use std::{any::type_name, mem::size_of};
143        let name = type_name::<T>().rsplit("::").next().unwrap();
144        assert_eq!(
145            size_of::<T>(),
146            self.get_struct_size(name) as usize,
147            "Host struct '{name}' size doesn't match the shader"
148        );
149    }
150
151    pub(crate) fn fill_resource_bindings(
152        module: &mut naga::Module,
153        sd_infos: &mut [crate::ShaderDataInfo],
154        naga_stage: naga::ShaderStage,
155        ep_info: &naga::valid::FunctionInfo,
156        group_layouts: &[&crate::ShaderDataLayout],
157    ) {
158        let mut layouter = naga::proc::Layouter::default();
159        layouter.update(module.to_ctx()).unwrap();
160
161        for (handle, var) in module.global_variables.iter_mut() {
162            if ep_info[handle].is_empty() {
163                continue;
164            }
165            let var_access = match var.space {
166                naga::AddressSpace::Storage { access } => access,
167                naga::AddressSpace::Uniform | naga::AddressSpace::Handle => {
168                    naga::StorageAccess::empty()
169                }
170                _ => continue,
171            };
172
173            assert_eq!(var.binding, None);
174            let var_name = var.name.as_ref().unwrap();
175            for (group_index, (&layout, info)) in
176                group_layouts.iter().zip(sd_infos.iter_mut()).enumerate()
177            {
178                if let Some((binding_index, &(_, proto_binding))) = layout
179                    .bindings
180                    .iter()
181                    .enumerate()
182                    .find(|&(_, &(name, _))| name == var_name)
183                {
184                    let (expected_proto, access) = match module.types[var.ty].inner {
185                        naga::TypeInner::Image {
186                            class: naga::ImageClass::Storage { access, format: _ },
187                            ..
188                        } => (crate::ShaderBinding::Texture, access),
189                        naga::TypeInner::Image { .. } => {
190                            (crate::ShaderBinding::Texture, naga::StorageAccess::empty())
191                        }
192                        naga::TypeInner::Sampler { .. } => {
193                            (crate::ShaderBinding::Sampler, naga::StorageAccess::empty())
194                        }
195                        naga::TypeInner::AccelerationStructure { vertex_return: _ } => (
196                            crate::ShaderBinding::AccelerationStructure,
197                            naga::StorageAccess::empty(),
198                        ),
199                        naga::TypeInner::BindingArray { base, size: _ } => {
200                            //Note: we could extract the count from `size` for more rigor
201                            let count = match proto_binding {
202                                crate::ShaderBinding::TextureArray { count } => count,
203                                crate::ShaderBinding::BufferArray { count } => count,
204                                crate::ShaderBinding::AccelerationStructureArray { count } => count,
205                                _ => 0,
206                            };
207                            let proto = match module.types[base].inner {
208                                naga::TypeInner::Image { .. } => {
209                                    crate::ShaderBinding::TextureArray { count }
210                                }
211                                naga::TypeInner::Struct { .. } => {
212                                    crate::ShaderBinding::BufferArray { count }
213                                }
214                                naga::TypeInner::AccelerationStructure { .. } => {
215                                    crate::ShaderBinding::AccelerationStructureArray { count }
216                                }
217                                ref other => panic!("Unsupported binding array for {:?}", other),
218                            };
219                            (proto, var_access)
220                        }
221                        _ => {
222                            let type_layout = &layouter[var.ty];
223                            let proto = if var_access.is_empty()
224                                && proto_binding != crate::ShaderBinding::Buffer
225                            {
226                                crate::ShaderBinding::Plain {
227                                    size: type_layout.size,
228                                }
229                            } else {
230                                crate::ShaderBinding::Buffer
231                            };
232                            (proto, var_access)
233                        }
234                    };
235                    assert_eq!(
236                        proto_binding, expected_proto,
237                        "Mismatched type for binding '{}'",
238                        var_name
239                    );
240                    assert_eq!(var.binding, None);
241                    var.binding = Some(naga::ResourceBinding {
242                        group: group_index as u32,
243                        binding: binding_index as u32,
244                    });
245                    info.visibility |= naga_stage.into();
246                    info.binding_access[binding_index] |= access;
247                    break;
248                }
249            }
250
251            assert!(
252                var.binding.is_some(),
253                "Unable to resolve binding for '{}' in stage '{:?}'",
254                var_name,
255                naga_stage,
256            );
257        }
258    }
259
260    pub(crate) fn fill_vertex_locations(
261        module: &mut naga::Module,
262        selected_ep_index: usize,
263        fetch_states: &[crate::VertexFetchState],
264    ) -> Vec<crate::VertexAttributeMapping> {
265        let mut attribute_mappings = Vec::new();
266        for (ep_index, ep) in module.entry_points.iter().enumerate() {
267            if ep.stage != naga::ShaderStage::Vertex {
268                continue;
269            }
270            if ep_index != selected_ep_index {
271                continue;
272            }
273
274            for argument in ep.function.arguments.iter() {
275                if argument.binding.is_some() {
276                    continue;
277                }
278
279                let arg_name = match argument.name {
280                    Some(ref name) => name.as_str(),
281                    None => "?",
282                };
283                let mut ty = module.types[argument.ty].clone();
284                let members = match ty.inner {
285                    naga::TypeInner::Struct {
286                        ref mut members, ..
287                    } => members,
288                    ref other => {
289                        log::error!("Unexpected type for '{}': {:?}", arg_name, other);
290                        continue;
291                    }
292                };
293
294                log::debug!("Processing vertex argument: {}", arg_name);
295                'member: for member in members.iter_mut() {
296                    let member_name = match member.name {
297                        Some(ref name) => name.as_str(),
298                        None => "?",
299                    };
300                    if let Some(ref binding) = member.binding {
301                        log::warn!(
302                            "Member '{}' already has binding: {:?}",
303                            member_name,
304                            binding
305                        );
306                        continue;
307                    }
308                    let binding = naga::Binding::Location {
309                        location: attribute_mappings.len() as u32,
310                        interpolation: None,
311                        sampling: None,
312                        blend_src: None,
313                        per_primitive: false,
314                    };
315                    for (buffer_index, vertex_fetch) in fetch_states.iter().enumerate() {
316                        for (attribute_index, &(at_name, _)) in
317                            vertex_fetch.layout.attributes.iter().enumerate()
318                        {
319                            if at_name == member_name {
320                                log::debug!(
321                                    "Assigning location({}) for member '{}' to be using input {}:{}",
322                                    attribute_mappings.len(),
323                                    member_name,
324                                    buffer_index,
325                                    attribute_index
326                                );
327                                member.binding = Some(binding);
328                                attribute_mappings.push(crate::VertexAttributeMapping {
329                                    buffer_index,
330                                    attribute_index,
331                                });
332                                continue 'member;
333                            }
334                        }
335                    }
336                    assert_ne!(
337                        member.binding, None,
338                        "Field {} is not covered by the vertex fetch layouts!",
339                        member_name
340                    );
341                }
342                module.types.replace(argument.ty, ty);
343            }
344        }
345        attribute_mappings
346    }
347}