rafx_api/
reflection.rs

1use crate::types::{RafxResourceType, RafxShaderStageFlags};
2use crate::{RafxResult, RafxSamplerDef, RafxShaderStageDef, MAX_DESCRIPTOR_SET_LAYOUTS};
3use fnv::FnvHashMap;
4#[cfg(feature = "serde-support")]
5use serde::{Deserialize, Serialize};
6
/// Location of a resource: the descriptor set index paired with the binding index inside it.
///
/// Used as the key when merging resources that several shader stages declare at the same slot.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct RafxShaderResourceBindingKey {
    /// Descriptor set index.
    pub set: u32,
    /// Binding index within the descriptor set.
    pub binding: u32,
}
13
/// One named member of a uniform block together with its byte offset.
///
/// GL ES 2.0 needs every field of a uniform (including array elements) listed this way.
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct RafxGlUniformMember {
    /// Name of the uniform member.
    pub name: String,
    /// Byte offset of the member.
    pub offset: u32,
}

impl RafxGlUniformMember {
    /// Creates a member description from any value convertible into a `String`.
    pub fn new<T: Into<String>>(
        name: T,
        offset: u32,
    ) -> Self {
        let name: String = name.into();
        Self { name, offset }
    }
}
32
33/// A data source within a shader. Often a descriptor or push constant.
34///
35/// A RafxShaderResource may be specified by hand or generated using rafx-shader-processor
36//TODO: Consider separate type for bindings vs. push constants
37#[derive(Debug, Clone, PartialEq, Eq, Hash)]
38#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
39pub struct RafxShaderResource {
40    pub resource_type: RafxResourceType,
41    pub set_index: u32,
42    pub binding: u32,
43    // Valid only for descriptors (resource_type != ROOT_CONSTANT)
44    // This must remain pub to init the struct as "normal" but in general,
45    // access it via element_count_normalized(). This ensures that if it
46    // is default-initialized to 0, it is treated as 1
47    pub element_count: u32,
48    // Valid only for push constants (resource_type == ROOT_CONSTANT)
49    pub size_in_bytes: u32,
50    pub used_in_shader_stages: RafxShaderStageFlags,
51    // Name is optional
52    //TODO: Add some sort of hashing-friendly option
53    pub name: Option<String>,
54    //pub texture_dimensions: Option<RafxTextureDimension>,
55    // metal stuff?
56
57    //TODO: Generate MSL buffer IDs offline rather than when creating root signature?
58    // What we do now works but requires shader's argument buffer assignments to be assigned in a
59    // very specific way. Would be better if user could provide the argument buffer ID
60
61    // HLSL-specific binding info
62    pub dx12_reg: Option<u32>,
63    pub dx12_space: Option<u32>,
64
65    // Required for GL ES (2.0/3.0) only. Other APIs use set_index and binding. (Rafx shader processor
66    // can produce this metadata automatically)
67    pub gles_name: Option<String>,
68
69    // Required for GL ES (2.0/3.0) only. Every texture must have exactly one sampler associated with it.
70    // Samplers are defined by adding a SAMPLER RafxShaderResource with a valid gl_name. The
71    // gl_sampler_name specified here will reference that sampler. While the GLSL code will not have
72    // a sampler object, rafx API will act as though there is a sampler object. It can be set as if
73    // it was a normal descriptor in a descriptor set. (Rafx shader processor can produce this
74    // metadata automatically)
75    pub gles_sampler_name: Option<String>,
76
77    // Required for GL ES 2.0 only, every field within a uniform must be specified with the byte
78    // offset. This includes elements within arrays. (Rafx shader processor can produce rust structs
79    // and the necessary metadata automatically.)
80    pub gles2_uniform_members: Vec<RafxGlUniformMember>,
81}
82
83impl Default for RafxShaderResource {
84    fn default() -> Self {
85        RafxShaderResource {
86            resource_type: Default::default(),
87            set_index: u32::MAX,
88            binding: u32::MAX,
89            element_count: 0,
90            size_in_bytes: 0,
91            used_in_shader_stages: Default::default(),
92            name: None,
93            dx12_reg: None,
94            dx12_space: None,
95            gles_name: None,
96            gles_sampler_name: None,
97            gles2_uniform_members: Vec::default(),
98        }
99    }
100}
101
102impl RafxShaderResource {
103    pub fn element_count_normalized(&self) -> u32 {
104        // Assume 0 = default of 1
105        self.element_count.max(1)
106    }
107
108    pub fn validate(&self) -> RafxResult<()> {
109        if self.resource_type == RafxResourceType::ROOT_CONSTANT {
110            if self.element_count != 0 {
111                Err(
112                    format!(
113                        "binding (set={:?} binding={:?} name={:?} type={:?}) has non-zero element_count",
114                        self.set_index,
115                        self.binding,
116                        self.name,
117                        self.resource_type
118                    )
119                )?;
120            }
121
122            if self.size_in_bytes == 0 {
123                Err(format!(
124                    "binding (set={:?} binding={:?} name={:?} type={:?}) has zero size_in_bytes",
125                    self.set_index, self.binding, self.name, self.resource_type
126                ))?;
127            }
128
129            if self.set_index != u32::MAX {
130                Err(format!(
131                    "binding (set={:?} binding={:?} name={:?} type={:?}) has set_index != u32::MAX",
132                    self.set_index, self.binding, self.name, self.resource_type
133                ))?;
134            }
135
136            if self.binding != u32::MAX {
137                Err(format!(
138                    "binding (set={:?} binding={:?} name={:?} type={:?}) has binding != u32::MAX",
139                    self.set_index, self.binding, self.name, self.resource_type
140                ))?;
141            }
142        } else {
143            if self.size_in_bytes != 0 {
144                Err(
145                    format!(
146                        "binding (set={:?} binding={:?} name={:?} type={:?}) has non-zero size_in_bytes",
147                        self.set_index,
148                        self.binding,
149                        self.name,
150                        self.resource_type
151                    )
152                )?;
153            }
154
155            if self.set_index == u32::MAX {
156                Err(format!(
157                    "binding (set={:?} binding={:?} name={:?} type={:?}) has binding == u32::MAX",
158                    self.set_index, self.binding, self.name, self.resource_type
159                ))?;
160            }
161
162            if self.binding == u32::MAX {
163                Err(format!(
164                    "binding (set={:?} binding={:?} name={:?} type={:?}) has binding == u32::MAX",
165                    self.set_index, self.binding, self.name, self.resource_type
166                ))?;
167            }
168
169            if self.set_index as usize >= MAX_DESCRIPTOR_SET_LAYOUTS {
170                Err(format!(
171                    "Descriptor (set={:?} binding={:?}) named {:?} has a set index >= 4. This is not supported",
172                    self.set_index, self.binding, self.name,
173                ))?;
174            }
175        }
176
177        Ok(())
178    }
179
180    fn binding_key(&self) -> RafxShaderResourceBindingKey {
181        RafxShaderResourceBindingKey {
182            set: self.set_index,
183            binding: self.binding,
184        }
185    }
186
187    fn verify_compatible_across_stages(
188        &self,
189        other: &Self,
190    ) -> RafxResult<()> {
191        if self.resource_type != other.resource_type {
192            Err(format!(
193                "Pass is using shaders in different stages with different resource_type {:?} and {:?} (set={} binding={})",
194                self.resource_type, other.resource_type,
195                self.set_index,
196                self.binding,
197            ))?;
198        }
199
200        if self.element_count_normalized() != other.element_count_normalized() {
201            Err(format!(
202                "Pass is using shaders in different stages with different element_count {} and {} (set={} binding={})", self.element_count_normalized(), other.element_count_normalized(),
203                self.set_index, self.binding
204            ))?;
205        }
206
207        if self.size_in_bytes != other.size_in_bytes {
208            Err(format!(
209                "Pass is using shaders in different stages with different size_in_bytes {} and {} (set={} binding={})",
210                self.size_in_bytes, other.size_in_bytes,
211                self.set_index, self.binding
212            ))?;
213        }
214
215        if self.gles2_uniform_members != other.gles2_uniform_members {
216            Err(format!(
217                "Pass is using shaders in different stages with different gl_uniform_members (set={} binding={})",
218                self.set_index, self.binding
219            ))?;
220        }
221
222        if self.gles_name != other.gles_name {
223            Err(format!(
224                "Pass is using shaders in different stages with different gles2_name (set={} binding={})",
225                self.set_index, self.binding
226            ))?;
227        }
228
229        if self.dx12_reg != other.dx12_reg {
230            Err(format!(
231                "Pass is using shaders in different stages with different dx12_reg (set={} binding={})",
232                self.set_index, self.binding
233            ))?;
234        }
235
236        if self.dx12_space != other.dx12_space {
237            Err(format!(
238                "Pass is using shaders in different stages with different dx12_space (set={} binding={})",
239                self.set_index, self.binding
240            ))?;
241        }
242
243        if self.gles_sampler_name.is_some()
244            && other.gles_sampler_name.is_some()
245            && self.gles_sampler_name != other.gles_sampler_name
246        {
247            Err(format!(
248                "Pass is using shaders in different stages with different non-None gles2_sampler_name (set={} binding={})",
249                self.set_index, self.binding
250            ))?;
251        }
252
253        Ok(())
254    }
255}
256
257/// Reflection data for a single shader stage
258#[derive(Debug, Clone, PartialEq, Eq, Hash)]
259#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
260pub struct RafxShaderStageReflection {
261    // For now, this doesn't do anything, so commented out
262    //pub vertex_inputs: Vec<RafxVertexInput>,
263    pub shader_stage: RafxShaderStageFlags,
264    pub resources: Vec<RafxShaderResource>,
265    pub compute_threads_per_group: Option<[u32; 3]>,
266    pub entry_point_name: String,
267    // Right now we will infer mappings based on spirv_cross default behavior, but likely will want
268    // to allow providing them explicitly. This isn't implemented yet
269    //pub binding_arg_buffer_mappings: FnvHashMap<(u32, u32), u32>
270}
271
272/// Reflection data for a pipeline, created by merging shader stage reflection data
273#[derive(Debug)]
274pub struct RafxPipelineReflection {
275    pub shader_stages: RafxShaderStageFlags,
276    pub resources: Vec<RafxShaderResource>,
277    pub compute_threads_per_group: Option<[u32; 3]>,
278}
279
280impl RafxPipelineReflection {
281    pub fn from_stages(stages: &[RafxShaderStageDef]) -> RafxResult<RafxPipelineReflection> {
282        let mut unmerged_resources = Vec::default();
283        for stage in stages {
284            assert!(!stage.reflection.shader_stage.is_empty());
285            for resource in &stage.reflection.resources {
286                // The provided resource MAY (but does not need to) have the shader stage flag set.
287                // (Leaving it default empty is fine). It will automatically be set here.
288                if !(resource.used_in_shader_stages - stage.reflection.shader_stage).is_empty() {
289                    let message = format!(
290                        "A resource in shader stage {:?} has other stages {:?} set",
291                        stage.reflection.shader_stage,
292                        resource.used_in_shader_stages - stage.reflection.shader_stage
293                    );
294                    log::error!("{}", message);
295                    Err(message)?;
296                }
297
298                let mut resource = resource.clone();
299                resource.used_in_shader_stages |= stage.reflection.shader_stage;
300                unmerged_resources.push(resource);
301            }
302        }
303
304        let mut compute_threads_per_group = None;
305        for stage in stages {
306            if stage
307                .reflection
308                .shader_stage
309                .intersects(RafxShaderStageFlags::COMPUTE)
310            {
311                compute_threads_per_group = stage.reflection.compute_threads_per_group;
312            }
313        }
314
315        log::trace!("Create RafxPipelineReflection from stages");
316        let mut all_shader_stages = RafxShaderStageFlags::empty();
317        for stage in stages {
318            if all_shader_stages.intersects(stage.reflection.shader_stage) {
319                Err(format!(
320                    "Duplicate shader stage ({}) found when creating RafxPipelineReflection",
321                    (all_shader_stages & stage.reflection.shader_stage).bits()
322                ))?;
323            }
324
325            all_shader_stages |= stage.reflection.shader_stage;
326        }
327
328        let mut merged_resources =
329            FnvHashMap::<RafxShaderResourceBindingKey, RafxShaderResource>::default();
330
331        //TODO: Merge push constants
332
333        //
334        // Merge the resources
335        //
336        for resource in &unmerged_resources {
337            log::trace!(
338                "    Resource {:?} from stage {:?}",
339                resource.name,
340                resource.used_in_shader_stages
341            );
342            let key = resource.binding_key();
343            if let Some(existing_resource) = merged_resources.get_mut(&key) {
344                // verify compatible
345                existing_resource.verify_compatible_across_stages(resource)?;
346
347                log::trace!(
348                    "      Already used in stages {:?} and is compatible, adding stage {:?}",
349                    existing_resource.used_in_shader_stages,
350                    resource.used_in_shader_stages,
351                );
352                existing_resource.used_in_shader_stages |= resource.used_in_shader_stages;
353                if existing_resource.gles_sampler_name.is_none() {
354                    existing_resource.gles_sampler_name = resource.gles_sampler_name.clone();
355                }
356            } else {
357                // insert it
358                log::trace!(
359                    "      Resource not yet used, adding it for stage {:?}",
360                    resource.used_in_shader_stages
361                );
362                assert!(!resource.used_in_shader_stages.is_empty());
363                let old = merged_resources.insert(key, resource.clone());
364                assert!(old.is_none());
365            }
366        }
367
368        let resources = merged_resources.into_iter().map(|(_, v)| v).collect();
369
370        Ok(RafxPipelineReflection {
371            shader_stages: all_shader_stages,
372            compute_threads_per_group,
373            resources,
374        })
375    }
376}
377
378///////////////////////////////////////////////////////////////////////////
379
380//TODO: Rename RafxReflected... to Rafx...Reflection
381#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
382#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
383pub struct RafxReflectedDescriptorSetLayoutBinding {
384    // Basic info required to create the RafxRootSignature
385    pub resource: RafxShaderResource,
386
387    // Samplers created here will be automatically created/bound
388    pub immutable_samplers: Option<Vec<RafxSamplerDef>>,
389
390    // If this is non-zero we will allocate a buffer owned by the descriptor set pool chunk,
391    // and automatically bind it - this makes binding data easy to do without having to manage
392    // buffers.
393    pub internal_buffer_per_descriptor_size: Option<u32>,
394}
395
396//TODO: Rename RafxReflected... to Rafx...Reflection
397#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
398#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
399pub struct RafxReflectedDescriptorSetLayout {
400    // These are NOT indexable by binding (i.e. may be sparse)
401    pub bindings: Vec<RafxReflectedDescriptorSetLayoutBinding>,
402}
403
//TODO: Rename RafxReflected... to Rafx...Reflection
/// A vertex input reflected from a shader entry point.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct RafxReflectedVertexInput {
    /// Name of the input variable.
    pub name: String,
    /// Semantic string attached to the input (presumably HLSL-style, e.g. `POSITION` — confirm).
    pub semantic: String,
    /// Input location index.
    pub location: u32,
}
412
413//TODO: Rename RafxReflected... to Rafx...Reflection
414#[derive(Debug, Clone, PartialEq, Eq, Hash)]
415#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
416pub struct RafxReflectedEntryPoint {
417    // The reflection data used by rafx API
418    pub rafx_api_reflection: RafxShaderStageReflection,
419
420    // Additional reflection data used by the framework level for descriptor sets
421    pub descriptor_set_layouts: Vec<Option<RafxReflectedDescriptorSetLayout>>,
422
423    // Additional reflection data used by the framework level for vertex inputs
424    pub vertex_inputs: Vec<RafxReflectedVertexInput>,
425}