// librashader_reflect/reflect/cross/mod.rs

1#[doc(hidden)]
2pub mod glsl;
3
4#[doc(hidden)]
5pub mod hlsl;
6
7#[doc(hidden)]
8pub mod msl;
9
10use crate::error::{SemanticsErrorKind, ShaderReflectError};
11use crate::front::SpirvCompilation;
12use crate::reflect::helper::{SemanticErrorBlame, TextureData, UboData};
13use crate::reflect::semantics::{
14    BindingMeta, BindingStage, BufferReflection, MemberOffset, ShaderReflection, ShaderSemantics,
15    TextureBinding, TextureSemanticMap, TextureSemantics, TextureSizeMeta, TypeInfo,
16    UniformMemberBlock, UniqueSemanticMap, UniqueSemantics, ValidateTypeSemantics, VariableMeta,
17    MAX_BINDINGS_COUNT, MAX_PUSH_BUFFER_SIZE,
18};
19use crate::reflect::{align_uniform_size, ReflectShader};
20use librashader_common::map::ShortString;
21use spirv_cross2::compile::CompiledArtifact;
22use spirv_cross2::reflect::{
23    AllResources, BitWidth, DecorationValue, Resource, Scalar, ScalarKind, TypeInner,
24};
25use spirv_cross2::spirv::Decoration;
26use spirv_cross2::Compiler;
27use spirv_cross2::Module;
28use std::fmt::Debug;
29
/// Reflection backend built on SPIRV-Cross.
///
/// Shaders reflected through this type can be cross-compiled to any target
/// SPIRV-Cross supports: GLSL, HLSL, SPIR-V, and MSL.
#[derive(Debug)]
pub struct SpirvCross;
35
36// todo: make this under a mutex
37pub(crate) struct CrossReflect<T>
38where
39    T: spirv_cross2::compile::CompilableTarget,
40{
41    vertex: Compiler<T>,
42    fragment: Compiler<T>,
43}
44
45/// The compiled SPIR-V program after compilation.
46pub struct CompiledProgram<T>
47where
48    T: spirv_cross2::compile::CompilableTarget,
49{
50    pub vertex: CompiledArtifact<T>,
51    pub fragment: CompiledArtifact<T>,
52}
53
54impl ValidateTypeSemantics<TypeInner<'_>> for UniqueSemantics {
55    fn validate_type(&self, ty: &TypeInner) -> Option<TypeInfo> {
56        let (TypeInner::Vector { .. } | TypeInner::Scalar { .. } | TypeInner::Matrix { .. }) = *ty
57        else {
58            return None;
59        };
60
61        match self {
62            UniqueSemantics::MVP => {
63                if matches!(ty, TypeInner::Matrix { columns, rows, scalar: Scalar { size, .. } } if *columns == 4
64                    && *rows == 4 && *size == BitWidth::Word)
65                {
66                    return Some(TypeInfo {
67                        size: 4,
68                        columns: 4,
69                    });
70                }
71            }
72            UniqueSemantics::FrameCount
73            | UniqueSemantics::Rotation
74            | UniqueSemantics::CurrentSubFrame
75            | UniqueSemantics::TotalSubFrames
76            | UniqueSemantics::FrameTimeDelta => {
77                // Uint32 == width 4
78                if matches!(ty, TypeInner::Scalar( Scalar { kind, size }) if *kind == ScalarKind::Uint && *size == BitWidth::Word)
79                {
80                    return Some(TypeInfo {
81                        size: 1,
82                        columns: 1,
83                    });
84                }
85            }
86            UniqueSemantics::FrameDirection => {
87                // iint32 == width 4
88                if matches!(ty, TypeInner::Scalar( Scalar { kind, size }) if *kind == ScalarKind::Int && *size == BitWidth::Word)
89                {
90                    return Some(TypeInfo {
91                        size: 1,
92                        columns: 1,
93                    });
94                }
95            }
96            UniqueSemantics::FloatParameter
97            | UniqueSemantics::OriginalAspectRotated
98            | UniqueSemantics::OriginalAspect
99            | UniqueSemantics::OriginalFPS => {
100                // Float32 == width 4
101                if matches!(ty, TypeInner::Scalar( Scalar { kind, size }) if *kind == ScalarKind::Float && *size == BitWidth::Word)
102                {
103                    return Some(TypeInfo {
104                        size: 1,
105                        columns: 1,
106                    });
107                }
108            }
109            _ => {
110                if matches!(ty, TypeInner::Vector { scalar: Scalar { size, kind }, width: vecwidth, .. }
111                    if *kind == ScalarKind::Float && *size == BitWidth::Word && *vecwidth == 4)
112                {
113                    return Some(TypeInfo {
114                        size: 4,
115                        columns: 1,
116                    });
117                }
118            }
119        };
120
121        None
122    }
123}
124
125impl ValidateTypeSemantics<TypeInner<'_>> for TextureSemantics {
126    fn validate_type(&self, ty: &TypeInner) -> Option<TypeInfo> {
127        let TypeInner::Vector {
128            scalar: Scalar { size, kind },
129            width: vecwidth,
130        } = ty
131        else {
132            return None;
133        };
134
135        if *kind == ScalarKind::Float && *size == BitWidth::Word && *vecwidth == 4 {
136            return Some(TypeInfo {
137                size: 4,
138                columns: 1,
139            });
140        }
141
142        None
143    }
144}
145
146impl<T> TryFrom<&SpirvCompilation> for CrossReflect<T>
147where
148    T: spirv_cross2::compile::CompilableTarget,
149{
150    type Error = ShaderReflectError;
151
152    fn try_from(value: &SpirvCompilation) -> Result<Self, Self::Error> {
153        let vertex_module = Module::from_words(&value.vertex);
154        let fragment_module = Module::from_words(&value.fragment);
155
156        let vertex = Compiler::new(vertex_module)?;
157        let fragment = Compiler::new(fragment_module)?;
158
159        Ok(CrossReflect { vertex, fragment })
160    }
161}
162
163impl<T> CrossReflect<T>
164where
165    T: spirv_cross2::compile::CompilableTarget,
166{
167    fn validate_semantics(
168        &self,
169        vertex_res: &AllResources,
170        fragment_res: &AllResources,
171    ) -> Result<(), ShaderReflectError> {
172        if !vertex_res.sampled_images.is_empty()
173            || !vertex_res.storage_buffers.is_empty()
174            || !vertex_res.subpass_inputs.is_empty()
175            || !vertex_res.storage_images.is_empty()
176            || !vertex_res.atomic_counters.is_empty()
177        {
178            return Err(ShaderReflectError::VertexSemanticError(
179                SemanticsErrorKind::InvalidResourceType,
180            ));
181        }
182
183        if !fragment_res.storage_buffers.is_empty()
184            || !fragment_res.subpass_inputs.is_empty()
185            || !fragment_res.storage_images.is_empty()
186            || !fragment_res.atomic_counters.is_empty()
187        {
188            return Err(ShaderReflectError::FragmentSemanticError(
189                SemanticsErrorKind::InvalidResourceType,
190            ));
191        }
192
193        let vert_inputs = vertex_res.stage_inputs.len();
194        if vert_inputs != 2 {
195            return Err(ShaderReflectError::VertexSemanticError(
196                SemanticsErrorKind::InvalidInputCount(vert_inputs),
197            ));
198        }
199
200        let frag_outputs = fragment_res.stage_outputs.len();
201        if frag_outputs != 1 {
202            return Err(ShaderReflectError::FragmentSemanticError(
203                SemanticsErrorKind::InvalidOutputCount(frag_outputs),
204            ));
205        }
206
207        let Some(DecorationValue::Literal(fragment_location)) = self
208            .fragment
209            .decoration(fragment_res.stage_outputs[0].id, Decoration::Location)?
210        else {
211            return Err(ShaderReflectError::FragmentSemanticError(
212                SemanticsErrorKind::MissingBinding,
213            ));
214        };
215
216        if fragment_location != 0 {
217            return Err(ShaderReflectError::FragmentSemanticError(
218                SemanticsErrorKind::InvalidLocation(fragment_location),
219            ));
220        }
221
222        // Ensure that vertex attributes use location 0 and 1
223        // Verify Vertex inputs
224        'vertex: {
225            let entry_points = self.vertex.entry_points()?;
226            if entry_points.len() != 1 {
227                return Err(ShaderReflectError::VertexSemanticError(
228                    SemanticsErrorKind::InvalidEntryPointCount(entry_points.len()),
229                ));
230            }
231
232            let vert_inputs = vertex_res.stage_inputs.len();
233            if vert_inputs != 2 {
234                return Err(ShaderReflectError::VertexSemanticError(
235                    SemanticsErrorKind::InvalidInputCount(vert_inputs),
236                ));
237            }
238
239            for input in &vertex_res.stage_inputs {
240                let location = self.vertex.decoration(input.id, Decoration::Location)?;
241                let Some(DecorationValue::Literal(location)) = location else {
242                    return Err(ShaderReflectError::VertexSemanticError(
243                        SemanticsErrorKind::MissingBinding,
244                    ));
245                };
246
247                if location == 0 {
248                    let pos_type = &self.vertex.type_description(input.base_type_id)?;
249                    if !matches!(pos_type.inner, TypeInner::Vector { width, ..} if width == 4) {
250                        return Err(ShaderReflectError::VertexSemanticError(
251                            SemanticsErrorKind::InvalidLocation(location),
252                        ));
253                    }
254                    break 'vertex;
255                }
256
257                if location == 1 {
258                    let coord_type = &self.vertex.type_description(input.base_type_id)?;
259                    if !matches!(coord_type.inner, TypeInner::Vector { width, ..} if width == 2) {
260                        return Err(ShaderReflectError::VertexSemanticError(
261                            SemanticsErrorKind::InvalidLocation(location),
262                        ));
263                    }
264                    break 'vertex;
265                }
266
267                return Err(ShaderReflectError::VertexSemanticError(
268                    SemanticsErrorKind::InvalidLocation(location),
269                ));
270            }
271        }
272
273        if vertex_res.uniform_buffers.len() > 1 {
274            return Err(ShaderReflectError::VertexSemanticError(
275                SemanticsErrorKind::InvalidUniformBufferCount(vertex_res.uniform_buffers.len()),
276            ));
277        }
278
279        if vertex_res.push_constant_buffers.len() > 1 {
280            return Err(ShaderReflectError::VertexSemanticError(
281                SemanticsErrorKind::InvalidPushBufferCount(vertex_res.push_constant_buffers.len()),
282            ));
283        }
284
285        if fragment_res.uniform_buffers.len() > 1 {
286            return Err(ShaderReflectError::FragmentSemanticError(
287                SemanticsErrorKind::InvalidUniformBufferCount(fragment_res.uniform_buffers.len()),
288            ));
289        }
290
291        if fragment_res.push_constant_buffers.len() > 1 {
292            return Err(ShaderReflectError::FragmentSemanticError(
293                SemanticsErrorKind::InvalidPushBufferCount(
294                    fragment_res.push_constant_buffers.len(),
295                ),
296            ));
297        }
298        Ok(())
299    }
300}
301
302impl<T> CrossReflect<T>
303where
304    T: spirv_cross2::compile::CompilableTarget,
305{
306    fn get_ubo_data(
307        ast: &Compiler<T>,
308        ubo: &Resource,
309        blame: SemanticErrorBlame,
310    ) -> Result<UboData, ShaderReflectError> {
311        let Some(descriptor_set) = ast
312            .decoration(ubo.id, Decoration::DescriptorSet)?
313            .and_then(|l| l.as_literal())
314        else {
315            return Err(blame.error(SemanticsErrorKind::MissingBinding));
316        };
317
318        let Some(binding) = ast
319            .decoration(ubo.id, Decoration::Binding)?
320            .and_then(|l| l.as_literal())
321        else {
322            return Err(blame.error(SemanticsErrorKind::MissingBinding));
323        };
324
325        if binding >= MAX_BINDINGS_COUNT {
326            return Err(blame.error(SemanticsErrorKind::InvalidBinding(binding)));
327        }
328        if descriptor_set != 0 {
329            return Err(blame.error(SemanticsErrorKind::InvalidDescriptorSet(descriptor_set)));
330        }
331
332        let size = ast.type_description(ubo.base_type_id)?.size_hint.declared() as u32;
333        Ok(UboData { binding, size })
334    }
335
336    fn get_push_size(
337        ast: &Compiler<T>,
338        push: &Resource,
339        blame: SemanticErrorBlame,
340    ) -> Result<u32, ShaderReflectError> {
341        let size = ast
342            .type_description(push.base_type_id)?
343            .size_hint
344            .declared() as u32;
345        if size > MAX_PUSH_BUFFER_SIZE {
346            return Err(blame.error(SemanticsErrorKind::InvalidPushBufferSize(size)));
347        }
348        Ok(size)
349    }
350
351    fn reflect_buffer_range_metas(
352        ast: &Compiler<T>,
353        resource: &Resource,
354        pass_number: usize,
355        semantics: &ShaderSemantics,
356        meta: &mut BindingMeta,
357        offset_type: UniformMemberBlock,
358        blame: SemanticErrorBlame,
359    ) -> Result<(), ShaderReflectError> {
360        let ranges = ast.active_buffer_ranges(resource.id)?;
361        for range in ranges {
362            let Some(name) = ast.member_name(resource.base_type_id, range.index)? else {
363                // member has no name!
364                return Err(blame.error(SemanticsErrorKind::InvalidRange(range.index)));
365            };
366
367            let ubo_type = ast.type_description(resource.base_type_id)?;
368            let range_type = match ubo_type.inner {
369                TypeInner::Struct(struct_def) => {
370                    let range_type = struct_def
371                        .members
372                        .get(range.index as usize)
373                        .ok_or(blame.error(SemanticsErrorKind::InvalidRange(range.index)))?;
374                    ast.type_description(range_type.id)?
375                }
376                _ => return Err(blame.error(SemanticsErrorKind::InvalidResourceType)),
377            };
378
379            if let Some(parameter) = semantics.uniform_semantics.unique_semantic(&name) {
380                let Some(typeinfo) = parameter.semantics.validate_type(&range_type.inner) else {
381                    return Err(
382                        blame.error(SemanticsErrorKind::InvalidTypeForSemantic(name.to_string()))
383                    );
384                };
385
386                match &parameter.semantics {
387                    UniqueSemantics::FloatParameter => {
388                        let offset = range.offset;
389                        if let Some(meta) = meta.parameter_meta.get_mut::<str>(&name.as_ref()) {
390                            if let Some(expected) = meta
391                                .offset
392                                .offset(offset_type)
393                                .filter(|expected| *expected != offset)
394                            {
395                                return Err(ShaderReflectError::MismatchedOffset {
396                                    semantic: name.to_string(),
397                                    expected,
398                                    received: offset,
399                                    ty: offset_type,
400                                    pass: pass_number,
401                                });
402                            }
403                            if meta.size != typeinfo.size {
404                                return Err(ShaderReflectError::MismatchedSize {
405                                    semantic: name.to_string(),
406                                    vertex: meta.size,
407                                    fragment: typeinfo.size,
408                                    pass: pass_number,
409                                });
410                            }
411
412                            *meta.offset.offset_mut(offset_type) = Some(offset);
413                        } else {
414                            let name = ShortString::from(name.as_ref());
415                            meta.parameter_meta.insert(
416                                name.clone(),
417                                VariableMeta {
418                                    id: name,
419                                    offset: MemberOffset::new(offset, offset_type),
420                                    size: typeinfo.size,
421                                },
422                            );
423                        }
424                    }
425                    semantics => {
426                        let offset = range.offset;
427                        if let Some(meta) = meta.unique_meta.get_mut(semantics) {
428                            if let Some(expected) = meta
429                                .offset
430                                .offset(offset_type)
431                                .filter(|expected| *expected != offset)
432                            {
433                                return Err(ShaderReflectError::MismatchedOffset {
434                                    semantic: name.to_string(),
435                                    expected,
436                                    received: offset,
437                                    ty: offset_type,
438                                    pass: pass_number,
439                                });
440                            }
441                            if meta.size != typeinfo.size * typeinfo.columns {
442                                return Err(ShaderReflectError::MismatchedSize {
443                                    semantic: name.to_string(),
444                                    vertex: meta.size,
445                                    fragment: typeinfo.size,
446                                    pass: pass_number,
447                                });
448                            }
449
450                            *meta.offset.offset_mut(offset_type) = Some(offset);
451                        } else {
452                            meta.unique_meta.insert(
453                                *semantics,
454                                VariableMeta {
455                                    id: ShortString::from(name.as_ref()),
456                                    offset: MemberOffset::new(offset, offset_type),
457                                    size: typeinfo.size * typeinfo.columns,
458                                },
459                            );
460                        }
461                    }
462                }
463            } else if let Some(texture) = semantics.uniform_semantics.texture_semantic(&name) {
464                let Some(_typeinfo) = texture.semantics.validate_type(&range_type.inner) else {
465                    return Err(
466                        blame.error(SemanticsErrorKind::InvalidTypeForSemantic(name.to_string()))
467                    );
468                };
469
470                if let TextureSemantics::PassOutput = texture.semantics {
471                    if texture.index >= pass_number {
472                        return Err(ShaderReflectError::NonCausalFilterChain {
473                            pass: pass_number,
474                            target: texture.index,
475                        });
476                    }
477                }
478
479                let offset = range.offset;
480                if let Some(meta) = meta.texture_size_meta.get_mut(&texture) {
481                    if let Some(expected) = meta
482                        .offset
483                        .offset(offset_type)
484                        .filter(|expected| *expected != offset)
485                    {
486                        return Err(ShaderReflectError::MismatchedOffset {
487                            semantic: name.to_string(),
488                            expected,
489                            received: offset,
490                            ty: offset_type,
491                            pass: pass_number,
492                        });
493                    }
494
495                    meta.stage_mask.insert(match blame {
496                        SemanticErrorBlame::Vertex => BindingStage::VERTEX,
497                        SemanticErrorBlame::Fragment => BindingStage::FRAGMENT,
498                    });
499
500                    *meta.offset.offset_mut(offset_type) = Some(offset);
501                } else {
502                    meta.texture_size_meta.insert(
503                        texture,
504                        TextureSizeMeta {
505                            offset: MemberOffset::new(offset, offset_type),
506                            stage_mask: match blame {
507                                SemanticErrorBlame::Vertex => BindingStage::VERTEX,
508                                SemanticErrorBlame::Fragment => BindingStage::FRAGMENT,
509                            },
510                            id: ShortString::from(name.as_ref()),
511                        },
512                    );
513                }
514            } else {
515                return Err(blame.error(SemanticsErrorKind::UnknownSemantics(name.to_string())));
516            }
517        }
518        Ok(())
519    }
520
521    fn reflect_ubos(
522        &mut self,
523        vertex_ubo: Option<&Resource>,
524        fragment_ubo: Option<&Resource>,
525    ) -> Result<Option<BufferReflection<u32>>, ShaderReflectError> {
526        if let Some(vertex_ubo) = vertex_ubo {
527            self.vertex
528                .set_decoration(vertex_ubo.id, Decoration::Binding, Some(0))?;
529        }
530
531        if let Some(fragment_ubo) = fragment_ubo {
532            self.fragment
533                .set_decoration(fragment_ubo.id, Decoration::Binding, Some(0))?;
534        }
535
536        match (vertex_ubo, fragment_ubo) {
537            (None, None) => Ok(None),
538            (Some(vertex_ubo), Some(fragment_ubo)) => {
539                let vertex_ubo =
540                    Self::get_ubo_data(&self.vertex, vertex_ubo, SemanticErrorBlame::Vertex)?;
541                let fragment_ubo =
542                    Self::get_ubo_data(&self.fragment, fragment_ubo, SemanticErrorBlame::Fragment)?;
543                if vertex_ubo.binding != fragment_ubo.binding {
544                    return Err(ShaderReflectError::MismatchedUniformBuffer {
545                        vertex: vertex_ubo.binding,
546                        fragment: fragment_ubo.binding,
547                    });
548                }
549
550                let size = std::cmp::max(vertex_ubo.size, fragment_ubo.size);
551                Ok(Some(BufferReflection {
552                    binding: vertex_ubo.binding,
553                    size: align_uniform_size(size),
554                    stage_mask: BindingStage::VERTEX | BindingStage::FRAGMENT,
555                }))
556            }
557            (Some(vertex_ubo), None) => {
558                let vertex_ubo =
559                    Self::get_ubo_data(&self.vertex, vertex_ubo, SemanticErrorBlame::Vertex)?;
560                Ok(Some(BufferReflection {
561                    binding: vertex_ubo.binding,
562                    size: align_uniform_size(vertex_ubo.size),
563                    stage_mask: BindingStage::VERTEX,
564                }))
565            }
566            (None, Some(fragment_ubo)) => {
567                let fragment_ubo =
568                    Self::get_ubo_data(&self.fragment, fragment_ubo, SemanticErrorBlame::Fragment)?;
569                Ok(Some(BufferReflection {
570                    binding: fragment_ubo.binding,
571                    size: align_uniform_size(fragment_ubo.size),
572                    stage_mask: BindingStage::FRAGMENT,
573                }))
574            }
575        }
576    }
577
578    fn reflect_texture_metas(
579        &self,
580        texture: TextureData,
581        pass_number: usize,
582        semantics: &ShaderSemantics,
583        meta: &mut BindingMeta,
584    ) -> Result<(), ShaderReflectError> {
585        let Some(semantic) = semantics.texture_semantics.texture_semantic(texture.name) else {
586            return Err(
587                SemanticErrorBlame::Fragment.error(SemanticsErrorKind::UnknownSemantics(
588                    texture.name.to_string(),
589                )),
590            );
591        };
592
593        if semantic.semantics == TextureSemantics::PassOutput && semantic.index >= pass_number {
594            return Err(ShaderReflectError::NonCausalFilterChain {
595                pass: pass_number,
596                target: semantic.index,
597            });
598        }
599
600        meta.texture_meta.insert(
601            semantic,
602            TextureBinding {
603                binding: texture.binding,
604            },
605        );
606        Ok(())
607    }
608
609    fn reflect_texture<'a>(
610        &'a self,
611        texture: &'a Resource,
612    ) -> Result<TextureData<'a>, ShaderReflectError> {
613        let Some(descriptor_set) = self
614            .fragment
615            .decoration(texture.id, Decoration::DescriptorSet)?
616            .and_then(|l| l.as_literal())
617        else {
618            return Err(ShaderReflectError::FragmentSemanticError(
619                SemanticsErrorKind::MissingBinding,
620            ));
621        };
622        let Some(binding) = self
623            .fragment
624            .decoration(texture.id, Decoration::Binding)?
625            .and_then(|l| l.as_literal())
626        else {
627            return Err(ShaderReflectError::FragmentSemanticError(
628                SemanticsErrorKind::MissingBinding,
629            ));
630        };
631
632        if descriptor_set != 0 {
633            return Err(ShaderReflectError::FragmentSemanticError(
634                SemanticsErrorKind::InvalidDescriptorSet(descriptor_set),
635            ));
636        }
637        if binding >= MAX_BINDINGS_COUNT {
638            return Err(ShaderReflectError::FragmentSemanticError(
639                SemanticsErrorKind::InvalidBinding(binding),
640            ));
641        }
642
643        Ok(TextureData {
644            // id: texture.id,
645            // descriptor_set,
646            name: &texture.name,
647            binding,
648        })
649    }
650
651    fn reflect_push_constant_buffer(
652        &mut self,
653        vertex_pcb: Option<&Resource>,
654        fragment_pcb: Option<&Resource>,
655    ) -> Result<Option<BufferReflection<Option<u32>>>, ShaderReflectError> {
656        if let Some(vertex_pcb) = vertex_pcb {
657            self.vertex
658                .set_decoration(vertex_pcb.id, Decoration::Binding, Some(1))?;
659        }
660
661        if let Some(fragment_pcb) = fragment_pcb {
662            self.fragment
663                .set_decoration(fragment_pcb.id, Decoration::Binding, Some(1))?;
664        }
665
666        match (vertex_pcb, fragment_pcb) {
667            (None, None) => Ok(None),
668            (Some(vertex_push), Some(fragment_push)) => {
669                let vertex_size =
670                    Self::get_push_size(&self.vertex, vertex_push, SemanticErrorBlame::Vertex)?;
671                let fragment_size = Self::get_push_size(
672                    &self.fragment,
673                    fragment_push,
674                    SemanticErrorBlame::Fragment,
675                )?;
676
677                let size = std::cmp::max(vertex_size, fragment_size);
678
679                Ok(Some(BufferReflection {
680                    binding: None,
681                    size: align_uniform_size(size),
682                    stage_mask: BindingStage::VERTEX | BindingStage::FRAGMENT,
683                }))
684            }
685            (Some(vertex_push), None) => {
686                let vertex_size =
687                    Self::get_push_size(&self.vertex, vertex_push, SemanticErrorBlame::Vertex)?;
688                Ok(Some(BufferReflection {
689                    binding: None,
690                    size: align_uniform_size(vertex_size),
691                    stage_mask: BindingStage::VERTEX,
692                }))
693            }
694            (None, Some(fragment_push)) => {
695                let fragment_size = Self::get_push_size(
696                    &self.fragment,
697                    fragment_push,
698                    SemanticErrorBlame::Fragment,
699                )?;
700                Ok(Some(BufferReflection {
701                    binding: None,
702                    size: align_uniform_size(fragment_size),
703                    stage_mask: BindingStage::FRAGMENT,
704                }))
705            }
706        }
707    }
708}
709
710impl<T> ReflectShader for CrossReflect<T>
711where
712    T: spirv_cross2::compile::CompilableTarget,
713{
714    fn reflect(
715        &mut self,
716        pass_number: usize,
717        semantics: &ShaderSemantics,
718    ) -> Result<ShaderReflection, ShaderReflectError> {
719        let vertex_res = self.vertex.shader_resources()?.all_resources()?;
720        let fragment_res = self.fragment.shader_resources()?.all_resources()?;
721        self.validate_semantics(&vertex_res, &fragment_res)?;
722
723        let vertex_ubo = vertex_res.uniform_buffers.first();
724        let fragment_ubo = fragment_res.uniform_buffers.first();
725
726        let ubo = self.reflect_ubos(vertex_ubo, fragment_ubo)?;
727
728        let vertex_push = vertex_res.push_constant_buffers.first();
729        let fragment_push = fragment_res.push_constant_buffers.first();
730
731        let push_constant = self.reflect_push_constant_buffer(vertex_push, fragment_push)?;
732
733        let mut meta = BindingMeta::default();
734
735        if let Some(ubo) = vertex_ubo {
736            Self::reflect_buffer_range_metas(
737                &self.vertex,
738                ubo,
739                pass_number,
740                semantics,
741                &mut meta,
742                UniformMemberBlock::Ubo,
743                SemanticErrorBlame::Vertex,
744            )?;
745        }
746
747        if let Some(ubo) = fragment_ubo {
748            Self::reflect_buffer_range_metas(
749                &self.fragment,
750                ubo,
751                pass_number,
752                semantics,
753                &mut meta,
754                UniformMemberBlock::Ubo,
755                SemanticErrorBlame::Fragment,
756            )?;
757        }
758
759        if let Some(push) = vertex_push {
760            Self::reflect_buffer_range_metas(
761                &self.vertex,
762                push,
763                pass_number,
764                semantics,
765                &mut meta,
766                UniformMemberBlock::PushConstant,
767                SemanticErrorBlame::Vertex,
768            )?;
769        }
770
771        if let Some(push) = fragment_push {
772            Self::reflect_buffer_range_metas(
773                &self.fragment,
774                push,
775                pass_number,
776                semantics,
777                &mut meta,
778                UniformMemberBlock::PushConstant,
779                SemanticErrorBlame::Fragment,
780            )?;
781        }
782
783        let mut ubo_bindings = 0u16;
784        if vertex_ubo.is_some() || fragment_ubo.is_some() {
785            ubo_bindings = 1 << ubo.as_ref().expect("UBOs should be present").binding;
786        }
787
788        for sampled_image in &fragment_res.sampled_images {
789            let texture_data = self.reflect_texture(sampled_image)?;
790            if ubo_bindings & (1 << texture_data.binding) != 0 {
791                return Err(ShaderReflectError::BindingInUse(texture_data.binding));
792            }
793            ubo_bindings |= 1 << texture_data.binding;
794
795            self.reflect_texture_metas(texture_data, pass_number, semantics, &mut meta)?;
796        }
797
798        Ok(ShaderReflection {
799            ubo,
800            push_constant,
801            meta,
802        })
803    }
804
805    fn validate(&mut self) -> Result<(), ShaderReflectError> {
806        let vertex_res = self.vertex.shader_resources()?.all_resources()?;
807        let fragment_res = self.fragment.shader_resources()?.all_resources()?;
808        self.validate_semantics(&vertex_res, &fragment_res)?;
809        let vertex_ubo = vertex_res.uniform_buffers.first();
810        let fragment_ubo = fragment_res.uniform_buffers.first();
811
812        self.reflect_ubos(vertex_ubo, fragment_ubo)?;
813
814        let vertex_push = vertex_res.push_constant_buffers.first();
815        let fragment_push = fragment_res.push_constant_buffers.first();
816
817        self.reflect_push_constant_buffer(vertex_push, fragment_push)?;
818
819        Ok(())
820    }
821}
822
#[cfg(test)]
mod test {
    // The only test in this module is commented out, which leaves every import
    // below unused. Allow that explicitly so test builds stay warning-free until
    // the test is restored.
    #![allow(unused_imports)]

    use crate::reflect::cross::CrossReflect;
    use crate::reflect::ReflectShader;
    use rustc_hash::FxHashMap;

    use crate::back::hlsl::CrossHlslContext;
    use crate::back::targets::HLSL;
    use crate::back::{CompileShader, ShaderCompilerOutput};
    use crate::front::{Glslang, ShaderInputCompiler};
    use crate::reflect::semantics::{Semantic, ShaderSemantics, UniformSemantic, UniqueSemantics};
    use librashader_common::map::{FastHashMap, ShortString};
    use librashader_preprocess::ShaderSource;

    // NOTE(review): this disabled test refers to `hlsl::Target`,
    // `hlsl::CompilerOptions` and `ShaderModel`, none of which are imported in
    // this module; it presumably needs updating before it can be re-enabled.
    // #[test]
    // pub fn test_into() {
    //     let result = ShaderSource::load("../test/basic.slang").unwrap();
    //     let mut uniform_semantics: FastHashMap<ShortString, UniformSemantic> = Default::default();
    //
    //     for (_index, param) in result.parameters.iter().enumerate() {
    //         uniform_semantics.insert(
    //             param.1.id.clone(),
    //             UniformSemantic::Unique(Semantic {
    //                 semantics: UniqueSemantics::FloatParameter,
    //                 index: (),
    //             }),
    //         );
    //     }
    //     let spirv = Glslang::compile(&result).unwrap();
    //     let mut reflect = CrossReflect::<hlsl::Target>::try_from(&spirv).unwrap();
    //     let shader_reflection = reflect
    //         .reflect(
    //             0,
    //             &ShaderSemantics {
    //                 uniform_semantics,
    //                 texture_semantics: Default::default(),
    //             },
    //         )
    //         .unwrap();
    //     let mut opts = hlsl::CompilerOptions::default();
    //     opts.shader_model = ShaderModel::V3_0;
    //
    //     let compiled: ShaderCompilerOutput<String, CrossHlslContext> =
    //         <CrossReflect<hlsl::Target> as CompileShader<HLSL>>::compile(
    //             reflect,
    //             Some(ShaderModel::V3_0),
    //         )
    //         .unwrap();
    //
    //     println!("{:?}", shader_reflection.meta);
    //     println!("{}", compiled.fragment);
    //     println!("{}", compiled.vertex);
    //
    //     // // eprintln!("{shader_reflection:#?}");
    //     // eprintln!("{}", compiled.fragment)
    //     // let mut loader = rspirv::dr::Loader::new();
    //     // rspirv::binary::parse_words(spirv.fragment.as_binary(), &mut loader).unwrap();
    //     // let module = loader.module();
    //     // println!("{:#}", module.disassemble());
    // }
}