// rafx_framework/resources/reflected_shader.rs

1use crate::resources::resource_lookup::ShaderResource;
2use crate::{
3    ComputePipelineResource, FixedFunctionState, MaterialPassResource, MaterialPassVertexInput,
4    ResourceArc, ResourceLookupSet, SamplerResource, ShaderModuleResource,
5};
6use fnv::{FnvHashMap, FnvHashSet};
7use rafx_api::{
8    RafxImmutableSamplerKey, RafxReflectedDescriptorSetLayout, RafxReflectedEntryPoint, RafxResult,
9    RafxShaderStageFlags,
10};
11use std::sync::Arc;
12
/// Identifies a single descriptor binding inside a pass: the descriptor set it lives in and its
/// binding number within that set.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct SlotLocation {
    /// Index of the descriptor set (layout) that contains the binding.
    pub layout_index: u32,
    /// Binding number within that descriptor set.
    pub binding_index: u32,
}
18
/// Maps a binding's reflected name to every (set, binding) location that name was reflected at.
/// The value is a set because the same name may be recorded at more than one location.
pub type SlotNameLookup = FnvHashMap<String, FnvHashSet<SlotLocation>>;
20
/// Reflection data merged across all of a pass's entry points (shader stages).
pub struct ReflectedShaderMetadata {
    /// Merged descriptor set layouts; position `i` holds the layout for set index `i`.
    pub descriptor_set_layout_defs: Vec<RafxReflectedDescriptorSetLayout>,
    /// Lookup from a binding's reflected name to the (set, binding) locations that use it.
    pub slot_name_lookup: SlotNameLookup,
    /// Vertex inputs taken from the vertex-stage entry point, or `None` when no entry point is
    /// in the vertex stage.
    pub vertex_inputs: Option<Arc<Vec<MaterialPassVertexInput>>>,
}
26
27impl ReflectedShaderMetadata {
28    pub fn new(entry_points: &[&RafxReflectedEntryPoint]) -> RafxResult<ReflectedShaderMetadata> {
29        let mut descriptor_set_layout_defs = Vec::default();
30        let mut slot_name_lookup: SlotNameLookup = Default::default();
31        let mut vertex_inputs = None;
32
33        // We iterate through the entry points we will hit for each stage. Each stage may define
34        // slightly different reflection data/bindings in use.
35        for reflection_data in entry_points {
36            //log::trace!("  Reflection data:\n{:#?}", reflection_data);
37
38            if reflection_data
39                .rafx_api_reflection
40                .shader_stage
41                .intersects(RafxShaderStageFlags::VERTEX)
42            {
43                let inputs: Vec<_> = reflection_data
44                    .vertex_inputs
45                    .iter()
46                    .map(|x| MaterialPassVertexInput {
47                        semantic: x.semantic.clone(),
48                        location: x.location,
49                        gl_attribute_name: x.name.clone(),
50                    })
51                    .collect();
52
53                assert!(vertex_inputs.is_none());
54                vertex_inputs = Some(Arc::new(inputs));
55            }
56
57            // Currently not using push constants and it will be handled in the rafx api layer
58            // for (range_index, range) in reflection_data.push_constants.iter().enumerate() {
59            //     if let Some(existing_range) = push_constant_ranges.get(range_index) {
60            //         if range.push_constant != *existing_range {
61            //             let error = format!(
62            //                 "Load Material Failed - Pass has shaders with conflicting push constants",
63            //             );
64            //             log::error!("{}", error);
65            //             return Err(error)?;
66            //         } else {
67            //             log::trace!("    Range index {} already exists and matches", range_index);
68            //         }
69            //     } else {
70            //         log::trace!("    Add range index {} {:?}", range_index, range);
71            //         push_constant_ranges.push(range.push_constant.clone());
72            //     }
73            // }
74
75            for (set_index, layout) in reflection_data.descriptor_set_layouts.iter().enumerate() {
76                // Expand the layout def to include the given set index
77                while descriptor_set_layout_defs.len() <= set_index {
78                    descriptor_set_layout_defs.push(RafxReflectedDescriptorSetLayout::default());
79                }
80
81                if let Some(layout) = layout.as_ref() {
82                    for binding in &layout.bindings {
83                        let existing_binding = descriptor_set_layout_defs[set_index]
84                            .bindings
85                            .iter_mut()
86                            .find(|x| x.resource.binding == binding.resource.binding);
87
88                        if let Some(existing_binding) = existing_binding {
89                            //
90                            // Binding already exists, just make sure this shader's definition for this binding matches
91                            // the shader that added it originally
92                            //
93                            if existing_binding.resource.resource_type
94                                != binding.resource.resource_type
95                            {
96                                let error = format!(
97                                    "Load Material Failed - Pass is using shaders in different stages with different descriptor types for set={} binding={}",
98                                    set_index,
99                                    binding.resource.binding
100                                );
101                                log::error!("{}", error);
102                                return Err(error)?;
103                            }
104
105                            if existing_binding.resource.element_count_normalized()
106                                != binding.resource.element_count_normalized()
107                            {
108                                let error = format!(
109                                    "Load Material Failed - Pass is using shaders in different stages with different descriptor counts for set={} binding={}",
110                                    set_index,
111                                    binding.resource.binding
112                                );
113                                log::error!("{}", error);
114                                return Err(error)?;
115                            }
116
117                            if existing_binding.immutable_samplers != binding.immutable_samplers {
118                                let error = format!(
119                                    "Load Material Failed - Pass is using shaders in different stages with different immutable samplers for set={} binding={}",
120                                    set_index,
121                                    binding.resource.binding
122                                );
123                                log::error!("{}", error);
124                                return Err(error)?;
125                            }
126
127                            if existing_binding.internal_buffer_per_descriptor_size
128                                != binding.internal_buffer_per_descriptor_size
129                            {
130                                let error = format!(
131                                    "Load Material Failed - Pass is using shaders in different stages with different internal buffer configuration for set={} binding={}",
132                                    set_index,
133                                    binding.resource.binding
134                                );
135                                log::error!("{}", error);
136                                return Err(error)?;
137                            }
138
139                            log::trace!("    Descriptor for binding set={} binding={} already exists, adding stage {:?}", set_index, binding.resource.binding, binding.resource.used_in_shader_stages);
140                            existing_binding.resource.used_in_shader_stages |=
141                                binding.resource.used_in_shader_stages;
142                        } else {
143                            //
144                            // This binding was not bound by a previous shader stage, set it up and apply any configuration from this material
145                            //
146                            log::trace!(
147                                "    Add descriptor binding set={} binding={} for stage {:?}",
148                                set_index,
149                                binding.resource.binding,
150                                binding.resource.used_in_shader_stages
151                            );
152                            let def = binding.clone().into();
153
154                            descriptor_set_layout_defs[set_index].bindings.push(def);
155                        }
156
157                        if let Some(slot_name) = &binding.resource.name {
158                            log::trace!(
159                                "  Assign slot name '{}' to binding set={} binding={}",
160                                slot_name,
161                                set_index,
162                                binding.resource.binding
163                            );
164                            slot_name_lookup
165                                .entry(slot_name.clone())
166                                .or_default()
167                                .insert(SlotLocation {
168                                    layout_index: set_index as u32,
169                                    binding_index: binding.resource.binding,
170                                });
171                        }
172                    }
173                }
174            }
175        }
176
177        Ok(ReflectedShaderMetadata {
178            vertex_inputs,
179            descriptor_set_layout_defs,
180            slot_name_lookup,
181        })
182    }
183}
184
/// A shader resource paired with the reflection metadata merged from its entry points.
pub struct ReflectedShader {
    /// Merged descriptor set layouts, slot names and vertex inputs for the shader's entry points.
    pub metadata: ReflectedShaderMetadata,
    /// The shader resource obtained from the `ResourceLookupSet` for these modules/entry points.
    pub shader: ResourceArc<ShaderResource>,
}
189
190impl ReflectedShader {
191    pub fn new(
192        resources: &ResourceLookupSet,
193        shader_modules: &[ResourceArc<ShaderModuleResource>],
194        entry_points: &[&RafxReflectedEntryPoint],
195    ) -> RafxResult<Self> {
196        let metadata = ReflectedShaderMetadata::new(entry_points)?;
197        let shader = resources.get_or_create_shader(shader_modules, entry_points)?;
198
199        Ok(ReflectedShader { metadata, shader })
200    }
201
202    pub fn create_immutable_samplers<'a>(
203        resources: &'a ResourceLookupSet,
204        descriptor_set_layouts: &'a [RafxReflectedDescriptorSetLayout],
205    ) -> RafxResult<(
206        Vec<RafxImmutableSamplerKey<'a>>,
207        Vec<Vec<ResourceArc<SamplerResource>>>,
208    )> {
209        // Put all samplers into a hashmap so that we avoid collecting duplicates, and keep them
210        // around to prevent the ResourceArcs from dropping out of scope and being destroyed
211        let mut immutable_samplers = FnvHashSet::default();
212
213        // We also need to save vecs of samplers that are immutable
214        let mut immutable_rafx_sampler_lists = Vec::default();
215        let mut immutable_rafx_sampler_keys = Vec::default();
216
217        for (set_index, descriptor_set_layout_def) in descriptor_set_layouts.iter().enumerate() {
218            // Get or create samplers and add them to the two above structures
219            for binding in &descriptor_set_layout_def.bindings {
220                if let Some(sampler_defs) = &binding.immutable_samplers {
221                    let mut samplers = Vec::with_capacity(sampler_defs.len());
222                    for sampler_def in sampler_defs {
223                        let sampler = resources.get_or_create_sampler(sampler_def)?;
224                        samplers.push(sampler.clone());
225                        immutable_samplers.insert(sampler);
226                    }
227
228                    immutable_rafx_sampler_keys.push(RafxImmutableSamplerKey::Binding(
229                        set_index as u32,
230                        binding.resource.binding,
231                    ));
232                    immutable_rafx_sampler_lists.push(samplers);
233                }
234            }
235        }
236        Ok((immutable_rafx_sampler_keys, immutable_rafx_sampler_lists))
237    }
238
239    pub fn load_material_pass(
240        &self,
241        resources: &ResourceLookupSet,
242        fixed_function_state: Arc<FixedFunctionState>,
243        debug_name: Option<&str>,
244    ) -> RafxResult<ResourceArc<MaterialPassResource>> {
245        let vertex_inputs = self
246            .metadata
247            .vertex_inputs
248            .as_ref()
249            .ok_or_else(|| "The material pass does not specify a vertex shader")?
250            .clone();
251
252        //
253        // Root Signature
254        //
255        let (immutable_rafx_sampler_keys, immutable_rafx_sampler_lists) =
256            ReflectedShader::create_immutable_samplers(
257                resources,
258                &self.metadata.descriptor_set_layout_defs,
259            )?;
260
261        let root_signature = resources.get_or_create_root_signature(
262            &[self.shader.clone()],
263            &immutable_rafx_sampler_keys,
264            &immutable_rafx_sampler_lists,
265        )?;
266
267        //
268        // Descriptor set layout
269        //
270        let mut descriptor_set_layouts =
271            Vec::with_capacity(self.metadata.descriptor_set_layout_defs.len());
272
273        for (set_index, descriptor_set_layout_def) in
274            self.metadata.descriptor_set_layout_defs.iter().enumerate()
275        {
276            let descriptor_set_layout = resources.get_or_create_descriptor_set_layout(
277                &root_signature,
278                set_index as u32,
279                &descriptor_set_layout_def,
280            )?;
281            descriptor_set_layouts.push(descriptor_set_layout);
282        }
283
284        //
285        // Create the material pass
286        //
287        resources.get_or_create_material_pass(
288            self.shader.clone(),
289            root_signature,
290            descriptor_set_layouts,
291            fixed_function_state,
292            vertex_inputs.clone(),
293            debug_name,
294        )
295    }
296
297    pub fn load_compute_pipeline(
298        &self,
299        resources: &ResourceLookupSet,
300        //shader_module: &ResourceArc<ShaderModuleResource>,
301        //entry_point: &ReflectedEntryPoint,
302        debug_name: Option<&str>,
303    ) -> RafxResult<ResourceArc<ComputePipelineResource>> {
304        let (immutable_rafx_sampler_keys, immutable_rafx_sampler_lists) =
305            ReflectedShader::create_immutable_samplers(
306                resources,
307                &self.metadata.descriptor_set_layout_defs,
308            )?;
309
310        let root_signature = resources.get_or_create_root_signature(
311            &[self.shader.clone()],
312            &immutable_rafx_sampler_keys,
313            &immutable_rafx_sampler_lists,
314        )?;
315
316        //
317        // Create the push constant ranges
318        //
319
320        // Currently unused, can be handled by the rafx api layer
321        // let mut push_constant_ranges = vec![];
322        // for (range_index, range) in entry_point.push_constants.iter().enumerate() {
323        //     log::trace!("    Add range index {} {:?}", range_index, range);
324        //     push_constant_ranges.push(range.push_constant.clone());
325        // }
326
327        //
328        // Gather the descriptor set bindings
329        //
330        let mut descriptor_set_layout_defs = Vec::default();
331        for (set_index, layout) in self.metadata.descriptor_set_layout_defs.iter().enumerate() {
332            // Expand the layout def to include the given set index
333            while descriptor_set_layout_defs.len() <= set_index {
334                descriptor_set_layout_defs.push(RafxReflectedDescriptorSetLayout::default());
335            }
336
337            for binding in &layout.bindings {
338                log::trace!(
339                    "    Add descriptor binding set={} binding={} for stage {:?}",
340                    set_index,
341                    binding.resource.binding,
342                    binding.resource.used_in_shader_stages
343                );
344                let def = binding.clone().into();
345
346                descriptor_set_layout_defs[set_index].bindings.push(def);
347            }
348        }
349
350        //
351        // Create the descriptor set layout
352        //
353        let mut descriptor_set_layouts = Vec::with_capacity(descriptor_set_layout_defs.len());
354
355        for (set_index, descriptor_set_layout_def) in descriptor_set_layout_defs.iter().enumerate()
356        {
357            let descriptor_set_layout = resources.get_or_create_descriptor_set_layout(
358                &root_signature,
359                set_index as u32,
360                &descriptor_set_layout_def,
361            )?;
362            descriptor_set_layouts.push(descriptor_set_layout);
363        }
364
365        //
366        // Create the compute pipeline
367        //
368        resources.get_or_create_compute_pipeline(
369            &self.shader,
370            &root_signature,
371            descriptor_set_layouts,
372            debug_name,
373        )
374    }
375}