librashader_reflect/reflect/naga/mod.rs

1#[doc(hidden)]
2pub mod msl;
3
4#[doc(hidden)]
5pub mod spirv;
6
7#[doc(hidden)]
8pub mod wgsl;
9
10use crate::error::{SemanticsErrorKind, ShaderReflectError};
11use std::fmt::Debug;
12
13use crate::front::spirv_passes::lower_samplers;
14use crate::front::SpirvCompilation;
15use crate::reflect::helper::{SemanticErrorBlame, TextureData, UboData};
16use crate::reflect::semantics::{
17    BindingMeta, BindingStage, BufferReflection, MemberOffset, ShaderSemantics, TextureBinding,
18    TextureSemanticMap, TextureSemantics, TextureSizeMeta, TypeInfo, UniformMemberBlock,
19    UniqueSemanticMap, UniqueSemantics, ValidateTypeSemantics, VariableMeta, MAX_BINDINGS_COUNT,
20    MAX_PUSH_BUFFER_SIZE,
21};
22use crate::reflect::{align_uniform_size, ReflectShader, ShaderReflection};
23use librashader_common::map::ShortString;
24use naga::{
25    AddressSpace, Binding, Expression, GlobalVariable, Handle, ImageClass, Module, ResourceBinding,
26    Scalar, ScalarKind, StructMember, TypeInner, VectorSize,
27};
28use rspirv::binary::Assemble;
29use rspirv::dr::Builder;
30use rustc_hash::FxHashSet;
31
/// Reflect under Naga semantics
///
/// The Naga reflector lowers combined image samplers into separate image and
/// sampler bindings that share the same bind point. During lowering, those samplers
/// are moved to the bind group given by [`NagaLoweringOptions::sampler_bind_group`].
///
/// Naga supports WGSL, SPIR-V, and MSL targets.
#[derive(Debug)]
pub struct Naga;
/// The vertex and fragment stages of a shader pass, parsed into naga IR.
///
/// Built from a `SpirvCompilation` via `TryFrom`; the fragment module has its
/// combined image samplers split before parsing.
#[derive(Debug)]
pub(crate) struct NagaReflect {
    /// The parsed vertex stage module.
    pub(crate) vertex: Module,
    /// The parsed fragment stage module (combined image samplers already split).
    pub(crate) fragment: Module,
}
45
/// Options to lower samplers and push constant buffers (PCBs) before code generation.
#[derive(Debug, Default, Clone)]
pub struct NagaLoweringOptions {
    /// Whether to write the PCB as a UBO.
    ///
    /// When `true`, push-constant globals are moved to the uniform address space.
    /// When `false`, they remain in the push-constant address space.
    pub write_pcb_as_ubo: bool,
    /// The bind group to assign samplers to. This is to ensure that samplers will
    /// maintain the same bindings as textures.
    ///
    /// Only fragment samplers whose binding and address space match those of a
    /// fragment image are moved to this group.
    pub sampler_bind_group: u32,
}
55
56impl NagaReflect {
57    pub fn do_lowering(&mut self, options: &NagaLoweringOptions) {
58        if options.write_pcb_as_ubo {
59            for (_, gv) in self.fragment.global_variables.iter_mut() {
60                if gv.space == AddressSpace::PushConstant {
61                    gv.space = AddressSpace::Uniform;
62                }
63            }
64
65            for (_, gv) in self.vertex.global_variables.iter_mut() {
66                if gv.space == AddressSpace::PushConstant {
67                    gv.space = AddressSpace::Uniform;
68                }
69            }
70        } else {
71            for (_, gv) in self.fragment.global_variables.iter_mut() {
72                if gv.space == AddressSpace::PushConstant {
73                    gv.binding = None;
74                }
75            }
76        }
77
78        // Reassign shit.
79        let images = self
80            .fragment
81            .global_variables
82            .iter()
83            .filter(|&(_, gv)| {
84                let ty = &self.fragment.types[gv.ty];
85                match ty.inner {
86                    naga::TypeInner::Image { .. } => true,
87                    naga::TypeInner::BindingArray { base, .. } => {
88                        let ty = &self.fragment.types[base];
89                        matches!(ty.inner, naga::TypeInner::Image { .. })
90                    }
91                    _ => false,
92                }
93            })
94            .map(|(_, gv)| (gv.binding.clone(), gv.space))
95            .collect::<naga::FastHashSet<_>>();
96
97        self.fragment
98            .global_variables
99            .iter_mut()
100            .filter(|(_, gv)| {
101                let ty = &self.fragment.types[gv.ty];
102                match ty.inner {
103                    naga::TypeInner::Sampler { .. } => true,
104                    naga::TypeInner::BindingArray { base, .. } => {
105                        let ty = &self.fragment.types[base];
106                        matches!(ty.inner, naga::TypeInner::Sampler { .. })
107                    }
108                    _ => false,
109                }
110            })
111            .for_each(|(_, gv)| {
112                if images.contains(&(gv.binding.clone(), gv.space)) {
113                    if let Some(binding) = &mut gv.binding {
114                        binding.group = options.sampler_bind_group;
115                    }
116                }
117            });
118    }
119}
120
121impl TryFrom<&SpirvCompilation> for NagaReflect {
122    type Error = ShaderReflectError;
123
124    fn try_from(compile: &SpirvCompilation) -> Result<Self, Self::Error> {
125        fn lower_fragment_shader(builder: &mut Builder) {
126            let mut pass = lower_samplers::LowerCombinedImageSamplerPass::new(builder);
127            pass.ensure_op_type_sampler();
128            pass.do_pass();
129        }
130
131        let options = naga::front::spv::Options {
132            adjust_coordinate_space: true,
133            strict_capabilities: false,
134            block_ctx_dump_prefix: None,
135        };
136
137        let vertex = crate::front::spirv_passes::load_module(&compile.vertex);
138        let fragment = crate::front::spirv_passes::load_module(&compile.fragment);
139
140        let mut fragment = Builder::new_from_module(fragment);
141        lower_fragment_shader(&mut fragment);
142
143        let vertex = vertex.assemble();
144        let fragment = fragment.module().assemble();
145
146        let vertex = naga::front::spv::parse_u8_slice(bytemuck::cast_slice(&vertex), &options)?;
147        let fragment = naga::front::spv::parse_u8_slice(bytemuck::cast_slice(&fragment), &options)?;
148
149        Ok(NagaReflect { vertex, fragment })
150    }
151}
152
153impl ValidateTypeSemantics<&TypeInner> for UniqueSemantics {
154    fn validate_type(&self, ty: &&TypeInner) -> Option<TypeInfo> {
155        let (TypeInner::Vector { .. } | TypeInner::Scalar { .. } | TypeInner::Matrix { .. }) = *ty
156        else {
157            return None;
158        };
159
160        match self {
161            UniqueSemantics::MVP => {
162                if matches!(ty, TypeInner::Matrix { columns, rows, scalar: Scalar { width, .. } } if *columns == VectorSize::Quad
163                    && *rows == VectorSize::Quad && *width == 4)
164                {
165                    return Some(TypeInfo {
166                        size: 4,
167                        columns: 4,
168                    });
169                }
170            }
171            UniqueSemantics::FrameCount
172            | UniqueSemantics::Rotation
173            | UniqueSemantics::CurrentSubFrame
174            | UniqueSemantics::TotalSubFrames
175            | UniqueSemantics::FrameTimeDelta => {
176                // Uint32 == width 4
177                if matches!(ty, TypeInner::Scalar( Scalar { kind, width }) if *kind == ScalarKind::Uint && *width == 4)
178                {
179                    return Some(TypeInfo {
180                        size: 1,
181                        columns: 1,
182                    });
183                }
184            }
185            UniqueSemantics::FrameDirection => {
186                // iint32 == width 4
187                if matches!(ty, TypeInner::Scalar( Scalar { kind, width }) if *kind == ScalarKind::Sint && *width == 4)
188                {
189                    return Some(TypeInfo {
190                        size: 1,
191                        columns: 1,
192                    });
193                }
194            }
195            UniqueSemantics::FloatParameter
196            | UniqueSemantics::OriginalFPS
197            | UniqueSemantics::OriginalAspectRotated
198            | UniqueSemantics::OriginalAspect => {
199                // Float32 == width 4
200                if matches!(ty, TypeInner::Scalar( Scalar { kind, width }) if *kind == ScalarKind::Float && *width == 4)
201                {
202                    return Some(TypeInfo {
203                        size: 1,
204                        columns: 1,
205                    });
206                }
207            }
208            _ => {
209                if matches!(ty, TypeInner::Vector { scalar: Scalar { width, kind }, size } if *kind == ScalarKind::Float && *width == 4 && *size == VectorSize::Quad)
210                {
211                    return Some(TypeInfo {
212                        size: 4,
213                        columns: 1,
214                    });
215                }
216            }
217        };
218
219        None
220    }
221}
222
223impl ValidateTypeSemantics<&TypeInner> for TextureSemantics {
224    fn validate_type(&self, ty: &&TypeInner) -> Option<TypeInfo> {
225        let TypeInner::Vector {
226            scalar: Scalar { width, kind },
227            size,
228        } = ty
229        else {
230            return None;
231        };
232
233        if *kind == ScalarKind::Float && *width == 4 && *size == VectorSize::Quad {
234            return Some(TypeInfo {
235                size: 4,
236                columns: 1,
237            });
238        }
239
240        None
241    }
242}
243
244impl NagaReflect {
245    fn reflect_ubos(
246        &self,
247        vertex_ubo: Option<Handle<GlobalVariable>>,
248        fragment_ubo: Option<Handle<GlobalVariable>>,
249    ) -> Result<Option<BufferReflection<u32>>, ShaderReflectError> {
250        // todo: merge this with the spirv-cross code
251        match (vertex_ubo, fragment_ubo) {
252            (None, None) => Ok(None),
253            (Some(vertex_ubo), Some(fragment_ubo)) => {
254                let vertex_ubo = Self::get_ubo_data(
255                    &self.vertex,
256                    &self.vertex.global_variables[vertex_ubo],
257                    SemanticErrorBlame::Vertex,
258                )?;
259                let fragment_ubo = Self::get_ubo_data(
260                    &self.fragment,
261                    &self.fragment.global_variables[fragment_ubo],
262                    SemanticErrorBlame::Fragment,
263                )?;
264                if vertex_ubo.binding != fragment_ubo.binding {
265                    return Err(ShaderReflectError::MismatchedUniformBuffer {
266                        vertex: vertex_ubo.binding,
267                        fragment: fragment_ubo.binding,
268                    });
269                }
270
271                let size = std::cmp::max(vertex_ubo.size, fragment_ubo.size);
272                Ok(Some(BufferReflection {
273                    binding: vertex_ubo.binding,
274                    size: align_uniform_size(size),
275                    stage_mask: BindingStage::VERTEX | BindingStage::FRAGMENT,
276                }))
277            }
278            (Some(vertex_ubo), None) => {
279                let vertex_ubo = Self::get_ubo_data(
280                    &self.vertex,
281                    &self.vertex.global_variables[vertex_ubo],
282                    SemanticErrorBlame::Vertex,
283                )?;
284                Ok(Some(BufferReflection {
285                    binding: vertex_ubo.binding,
286                    size: align_uniform_size(vertex_ubo.size),
287                    stage_mask: BindingStage::VERTEX,
288                }))
289            }
290            (None, Some(fragment_ubo)) => {
291                let fragment_ubo = Self::get_ubo_data(
292                    &self.fragment,
293                    &self.fragment.global_variables[fragment_ubo],
294                    SemanticErrorBlame::Fragment,
295                )?;
296                Ok(Some(BufferReflection {
297                    binding: fragment_ubo.binding,
298                    size: align_uniform_size(fragment_ubo.size),
299                    stage_mask: BindingStage::FRAGMENT,
300                }))
301            }
302        }
303    }
304
305    fn get_ubo_data(
306        module: &Module,
307        ubo: &GlobalVariable,
308        blame: SemanticErrorBlame,
309    ) -> Result<UboData, ShaderReflectError> {
310        let Some(binding) = &ubo.binding else {
311            return Err(blame.error(SemanticsErrorKind::MissingBinding));
312        };
313
314        if binding.binding >= MAX_BINDINGS_COUNT {
315            return Err(blame.error(SemanticsErrorKind::InvalidBinding(binding.binding)));
316        }
317
318        if binding.group != 0 {
319            return Err(blame.error(SemanticsErrorKind::InvalidDescriptorSet(binding.group)));
320        }
321
322        let ty = &module.types[ubo.ty];
323        let size = ty.inner.size(module.to_ctx());
324        Ok(UboData {
325            // descriptor_set,
326            // id: ubo.id,
327            binding: binding.binding,
328            size,
329        })
330    }
331
332    fn get_next_binding(&self, bind_group: u32) -> u32 {
333        let mut max_bind = 0;
334        for (_, gv) in self.vertex.global_variables.iter() {
335            let Some(binding) = &gv.binding else {
336                continue;
337            };
338            if binding.group != bind_group {
339                continue;
340            }
341            max_bind = std::cmp::max(max_bind, binding.binding);
342        }
343
344        for (_, gv) in self.fragment.global_variables.iter() {
345            let Some(binding) = &gv.binding else {
346                continue;
347            };
348            if binding.group != bind_group {
349                continue;
350            }
351            max_bind = std::cmp::max(max_bind, binding.binding);
352        }
353
354        max_bind + 1
355    }
356
357    fn get_push_size(
358        module: &Module,
359        push: &GlobalVariable,
360        blame: SemanticErrorBlame,
361    ) -> Result<u32, ShaderReflectError> {
362        let ty = &module.types[push.ty];
363        let size = ty.inner.size(module.to_ctx());
364        if size > MAX_PUSH_BUFFER_SIZE {
365            return Err(blame.error(SemanticsErrorKind::InvalidPushBufferSize(size)));
366        }
367        Ok(size)
368    }
369
370    fn reflect_push_constant_buffer(
371        &mut self,
372        vertex_pcb: Option<Handle<GlobalVariable>>,
373        fragment_pcb: Option<Handle<GlobalVariable>>,
374    ) -> Result<Option<BufferReflection<Option<u32>>>, ShaderReflectError> {
375        let binding = self.get_next_binding(0);
376        // Reassign to UBO later if we want during compilation.
377        if let Some(vertex_pcb) = vertex_pcb {
378            let pcb = &mut self.vertex.global_variables[vertex_pcb];
379            pcb.binding = Some(ResourceBinding { group: 0, binding });
380        }
381
382        if let Some(fragment_pcb) = fragment_pcb {
383            let pcb = &mut self.fragment.global_variables[fragment_pcb];
384            pcb.binding = Some(ResourceBinding { group: 0, binding });
385        };
386
387        match (vertex_pcb, fragment_pcb) {
388            (None, None) => Ok(None),
389            (Some(vertex_push), Some(fragment_push)) => {
390                let vertex_size = Self::get_push_size(
391                    &self.vertex,
392                    &self.vertex.global_variables[vertex_push],
393                    SemanticErrorBlame::Vertex,
394                )?;
395                let fragment_size = Self::get_push_size(
396                    &self.fragment,
397                    &self.fragment.global_variables[fragment_push],
398                    SemanticErrorBlame::Fragment,
399                )?;
400
401                let size = std::cmp::max(vertex_size, fragment_size);
402
403                Ok(Some(BufferReflection {
404                    binding: Some(binding),
405                    size: align_uniform_size(size),
406                    stage_mask: BindingStage::VERTEX | BindingStage::FRAGMENT,
407                }))
408            }
409            (Some(vertex_push), None) => {
410                let vertex_size = Self::get_push_size(
411                    &self.vertex,
412                    &self.vertex.global_variables[vertex_push],
413                    SemanticErrorBlame::Vertex,
414                )?;
415                Ok(Some(BufferReflection {
416                    binding: Some(binding),
417                    size: align_uniform_size(vertex_size),
418                    stage_mask: BindingStage::VERTEX,
419                }))
420            }
421            (None, Some(fragment_push)) => {
422                let fragment_size = Self::get_push_size(
423                    &self.fragment,
424                    &self.fragment.global_variables[fragment_push],
425                    SemanticErrorBlame::Fragment,
426                )?;
427                Ok(Some(BufferReflection {
428                    binding: Some(binding),
429                    size: align_uniform_size(fragment_size),
430                    stage_mask: BindingStage::FRAGMENT,
431                }))
432            }
433        }
434    }
435
436    fn validate_semantics(&self) -> Result<(), ShaderReflectError> {
437        // Verify types
438        if self.vertex.global_variables.iter().any(|(_, gv)| {
439            let ty = &self.vertex.types[gv.ty];
440            match ty.inner {
441                TypeInner::Scalar { .. }
442                | TypeInner::Vector { .. }
443                | TypeInner::Matrix { .. }
444                | TypeInner::Struct { .. } => false,
445                _ => true,
446            }
447        }) {
448            return Err(ShaderReflectError::VertexSemanticError(
449                SemanticsErrorKind::InvalidResourceType,
450            ));
451        }
452
453        if self.fragment.global_variables.iter().any(|(_, gv)| {
454            let ty = &self.fragment.types[gv.ty];
455            match ty.inner {
456                TypeInner::Scalar { .. }
457                | TypeInner::Vector { .. }
458                | TypeInner::Matrix { .. }
459                | TypeInner::Struct { .. }
460                | TypeInner::Image { .. }
461                | TypeInner::Sampler { .. } => false,
462                TypeInner::BindingArray { base, .. } => {
463                    let ty = &self.fragment.types[base];
464                    match ty.inner {
465                        TypeInner::Image { class, .. }
466                            if !matches!(class, ImageClass::Storage { .. }) =>
467                        {
468                            false
469                        }
470                        TypeInner::Sampler { .. } => false,
471                        _ => true,
472                    }
473                }
474                _ => true,
475            }
476        }) {
477            return Err(ShaderReflectError::FragmentSemanticError(
478                SemanticsErrorKind::InvalidResourceType,
479            ));
480        }
481
482        // Verify Vertex inputs
483        'vertex: {
484            if self.vertex.entry_points.len() != 1 {
485                return Err(ShaderReflectError::VertexSemanticError(
486                    SemanticsErrorKind::InvalidEntryPointCount(self.vertex.entry_points.len()),
487                ));
488            }
489
490            let vertex_entry_point = &self.vertex.entry_points[0];
491            let vert_inputs = vertex_entry_point.function.arguments.len();
492            if vert_inputs != 2 {
493                return Err(ShaderReflectError::VertexSemanticError(
494                    SemanticsErrorKind::InvalidInputCount(vert_inputs),
495                ));
496            }
497            for input in &vertex_entry_point.function.arguments {
498                let &Some(Binding::Location { location, .. }) = &input.binding else {
499                    return Err(ShaderReflectError::VertexSemanticError(
500                        SemanticsErrorKind::MissingBinding,
501                    ));
502                };
503
504                if location == 0 {
505                    let pos_type = &self.vertex.types[input.ty];
506                    if !matches!(pos_type.inner, TypeInner::Vector { size, ..} if size == VectorSize::Quad)
507                    {
508                        return Err(ShaderReflectError::VertexSemanticError(
509                            SemanticsErrorKind::InvalidLocation(location),
510                        ));
511                    }
512                    break 'vertex;
513                }
514
515                if location == 1 {
516                    let coord_type = &self.vertex.types[input.ty];
517                    if !matches!(coord_type.inner, TypeInner::Vector { size, ..} if size == VectorSize::Bi)
518                    {
519                        return Err(ShaderReflectError::VertexSemanticError(
520                            SemanticsErrorKind::InvalidLocation(location),
521                        ));
522                    }
523                    break 'vertex;
524                }
525
526                return Err(ShaderReflectError::VertexSemanticError(
527                    SemanticsErrorKind::InvalidLocation(location),
528                ));
529            }
530
531            let uniform_buffer_count = self
532                .vertex
533                .global_variables
534                .iter()
535                .filter(|(_, gv)| gv.space == AddressSpace::Uniform)
536                .count();
537
538            if uniform_buffer_count > 1 {
539                return Err(ShaderReflectError::VertexSemanticError(
540                    SemanticsErrorKind::InvalidUniformBufferCount(uniform_buffer_count),
541                ));
542            }
543
544            let push_buffer_count = self
545                .vertex
546                .global_variables
547                .iter()
548                .filter(|(_, gv)| gv.space == AddressSpace::PushConstant)
549                .count();
550
551            if push_buffer_count > 1 {
552                return Err(ShaderReflectError::VertexSemanticError(
553                    SemanticsErrorKind::InvalidPushBufferCount(push_buffer_count),
554                ));
555            }
556        }
557
558        {
559            if self.fragment.entry_points.len() != 1 {
560                return Err(ShaderReflectError::FragmentSemanticError(
561                    SemanticsErrorKind::InvalidEntryPointCount(self.vertex.entry_points.len()),
562                ));
563            }
564
565            let frag_entry_point = &self.fragment.entry_points[0];
566            let Some(frag_output) = &frag_entry_point.function.result else {
567                return Err(ShaderReflectError::FragmentSemanticError(
568                    SemanticsErrorKind::InvalidOutputCount(0),
569                ));
570            };
571
572            let &Some(Binding::Location { location, .. }) = &frag_output.binding else {
573                return Err(ShaderReflectError::FragmentSemanticError(
574                    SemanticsErrorKind::MissingBinding,
575                ));
576            };
577
578            if location != 0 {
579                return Err(ShaderReflectError::FragmentSemanticError(
580                    SemanticsErrorKind::InvalidLocation(location),
581                ));
582            }
583
584            let uniform_buffer_count = self
585                .fragment
586                .global_variables
587                .iter()
588                .filter(|(_, gv)| gv.space == AddressSpace::Uniform)
589                .count();
590
591            if uniform_buffer_count > 1 {
592                return Err(ShaderReflectError::FragmentSemanticError(
593                    SemanticsErrorKind::InvalidUniformBufferCount(uniform_buffer_count),
594                ));
595            }
596
597            let push_buffer_count = self
598                .fragment
599                .global_variables
600                .iter()
601                .filter(|(_, gv)| gv.space == AddressSpace::PushConstant)
602                .count();
603
604            if push_buffer_count > 1 {
605                return Err(ShaderReflectError::FragmentSemanticError(
606                    SemanticsErrorKind::InvalidPushBufferCount(push_buffer_count),
607                ));
608            }
609        }
610
611        Ok(())
612    }
613
614    fn collect_uniform_names(
615        module: &Module,
616        buffer_handle: Handle<GlobalVariable>,
617        blame: SemanticErrorBlame,
618    ) -> Result<FxHashSet<&StructMember>, ShaderReflectError> {
619        let mut names = FxHashSet::default();
620        let ubo = &module.global_variables[buffer_handle];
621
622        let TypeInner::Struct { members, .. } = &module.types[ubo.ty].inner else {
623            return Err(blame.error(SemanticsErrorKind::InvalidResourceType));
624        };
625
626        // struct access is AccessIndex
627        for (_, fun) in module.functions.iter() {
628            for (_, expr) in fun.expressions.iter() {
629                let &Expression::AccessIndex { base, index } = expr else {
630                    continue;
631                };
632
633                let &Expression::GlobalVariable(base) = &fun.expressions[base] else {
634                    continue;
635                };
636
637                if base == buffer_handle {
638                    let member = members
639                        .get(index as usize)
640                        .ok_or(blame.error(SemanticsErrorKind::InvalidRange(index)))?;
641                    names.insert(member);
642                }
643            }
644        }
645
646        Ok(names)
647    }
648
649    fn reflect_buffer_struct_members(
650        module: &Module,
651        resource: Handle<GlobalVariable>,
652        pass_number: usize,
653        semantics: &ShaderSemantics,
654        meta: &mut BindingMeta,
655        offset_type: UniformMemberBlock,
656        blame: SemanticErrorBlame,
657    ) -> Result<(), ShaderReflectError> {
658        let reachable = Self::collect_uniform_names(&module, resource, blame)?;
659
660        let resource = &module.global_variables[resource];
661
662        let TypeInner::Struct { members, .. } = &module.types[resource.ty].inner else {
663            return Err(blame.error(SemanticsErrorKind::InvalidResourceType));
664        };
665
666        for member in members {
667            let Some(name) = member.name.clone() else {
668                return Err(blame.error(SemanticsErrorKind::InvalidRange(member.offset)));
669            };
670
671            if !reachable.contains(member) {
672                continue;
673            }
674
675            let member_type = &module.types[member.ty].inner;
676
677            if let Some(parameter) = semantics.uniform_semantics.unique_semantic(&name) {
678                let Some(typeinfo) = parameter.semantics.validate_type(&member_type) else {
679                    return Err(blame.error(SemanticsErrorKind::InvalidTypeForSemantic(name)));
680                };
681
682                match &parameter.semantics {
683                    UniqueSemantics::FloatParameter => {
684                        let offset = member.offset;
685                        if let Some(meta) = meta.parameter_meta.get_mut::<str>(name.as_ref()) {
686                            if let Some(expected) = meta
687                                .offset
688                                .offset(offset_type)
689                                .filter(|expected| *expected != offset as usize)
690                            {
691                                return Err(ShaderReflectError::MismatchedOffset {
692                                    semantic: name,
693                                    expected,
694                                    received: offset as usize,
695                                    ty: offset_type,
696                                    pass: pass_number,
697                                });
698                            }
699                            if meta.size != typeinfo.size {
700                                return Err(ShaderReflectError::MismatchedSize {
701                                    semantic: name,
702                                    vertex: meta.size,
703                                    fragment: typeinfo.size,
704                                    pass: pass_number,
705                                });
706                            }
707
708                            *meta.offset.offset_mut(offset_type) = Some(offset as usize);
709                        } else {
710                            let name = ShortString::from(name);
711                            meta.parameter_meta.insert(
712                                name.clone(),
713                                VariableMeta {
714                                    id: name,
715                                    offset: MemberOffset::new(offset as usize, offset_type),
716                                    size: typeinfo.size,
717                                },
718                            );
719                        }
720                    }
721                    semantics => {
722                        let offset = member.offset;
723                        if let Some(meta) = meta.unique_meta.get_mut(semantics) {
724                            if let Some(expected) = meta
725                                .offset
726                                .offset(offset_type)
727                                .filter(|expected| *expected != offset as usize)
728                            {
729                                return Err(ShaderReflectError::MismatchedOffset {
730                                    semantic: name,
731                                    expected,
732                                    received: offset as usize,
733                                    ty: offset_type,
734                                    pass: pass_number,
735                                });
736                            }
737                            if meta.size != typeinfo.size * typeinfo.columns {
738                                return Err(ShaderReflectError::MismatchedSize {
739                                    semantic: name,
740                                    vertex: meta.size,
741                                    fragment: typeinfo.size,
742                                    pass: pass_number,
743                                });
744                            }
745
746                            *meta.offset.offset_mut(offset_type) = Some(offset as usize);
747                        } else {
748                            meta.unique_meta.insert(
749                                *semantics,
750                                VariableMeta {
751                                    id: ShortString::from(name),
752                                    offset: MemberOffset::new(offset as usize, offset_type),
753                                    size: typeinfo.size * typeinfo.columns,
754                                },
755                            );
756                        }
757                    }
758                }
759            } else if let Some(texture) = semantics.uniform_semantics.texture_semantic(&name) {
760                let Some(_typeinfo) = texture.semantics.validate_type(&member_type) else {
761                    return Err(blame.error(SemanticsErrorKind::InvalidTypeForSemantic(name)));
762                };
763
764                if let TextureSemantics::PassOutput = texture.semantics {
765                    if texture.index >= pass_number {
766                        return Err(ShaderReflectError::NonCausalFilterChain {
767                            pass: pass_number,
768                            target: texture.index,
769                        });
770                    }
771                }
772
773                let offset = member.offset;
774                if let Some(meta) = meta.texture_size_meta.get_mut(&texture) {
775                    if let Some(expected) = meta
776                        .offset
777                        .offset(offset_type)
778                        .filter(|expected| *expected != offset as usize)
779                    {
780                        return Err(ShaderReflectError::MismatchedOffset {
781                            semantic: name,
782                            expected,
783                            received: offset as usize,
784                            ty: offset_type,
785                            pass: pass_number,
786                        });
787                    }
788
789                    meta.stage_mask.insert(match blame {
790                        SemanticErrorBlame::Vertex => BindingStage::VERTEX,
791                        SemanticErrorBlame::Fragment => BindingStage::FRAGMENT,
792                    });
793
794                    *meta.offset.offset_mut(offset_type) = Some(offset as usize);
795                } else {
796                    meta.texture_size_meta.insert(
797                        texture,
798                        TextureSizeMeta {
799                            offset: MemberOffset::new(offset as usize, offset_type),
800                            stage_mask: match blame {
801                                SemanticErrorBlame::Vertex => BindingStage::VERTEX,
802                                SemanticErrorBlame::Fragment => BindingStage::FRAGMENT,
803                            },
804                            id: ShortString::from(name),
805                        },
806                    );
807                }
808            } else {
809                return Err(blame.error(SemanticsErrorKind::UnknownSemantics(name)));
810            }
811        }
812        Ok(())
813    }
814
815    fn reflect_texture<'a>(
816        &'a self,
817        texture: &'a GlobalVariable,
818    ) -> Result<TextureData<'a>, ShaderReflectError> {
819        let Some(binding) = &texture.binding else {
820            return Err(ShaderReflectError::FragmentSemanticError(
821                SemanticsErrorKind::MissingBinding,
822            ));
823        };
824
825        let Some(name) = texture.name.as_ref() else {
826            return Err(ShaderReflectError::FragmentSemanticError(
827                SemanticsErrorKind::InvalidBinding(binding.binding),
828            ));
829        };
830
831        if binding.group != 0 {
832            return Err(ShaderReflectError::FragmentSemanticError(
833                SemanticsErrorKind::InvalidDescriptorSet(binding.group),
834            ));
835        }
836        if binding.binding >= MAX_BINDINGS_COUNT {
837            return Err(ShaderReflectError::FragmentSemanticError(
838                SemanticsErrorKind::InvalidBinding(binding.binding),
839            ));
840        }
841
842        Ok(TextureData {
843            // id: texture.id,
844            // descriptor_set,
845            name: &name,
846            binding: binding.binding,
847        })
848    }
849
850    // todo: share this with cross
851    fn reflect_texture_metas(
852        &self,
853        texture: TextureData,
854        pass_number: usize,
855        semantics: &ShaderSemantics,
856        meta: &mut BindingMeta,
857    ) -> Result<(), ShaderReflectError> {
858        let Some(semantic) = semantics.texture_semantics.texture_semantic(texture.name) else {
859            return Err(
860                SemanticErrorBlame::Fragment.error(SemanticsErrorKind::UnknownSemantics(
861                    texture.name.to_string(),
862                )),
863            );
864        };
865
866        if semantic.semantics == TextureSemantics::PassOutput && semantic.index >= pass_number {
867            return Err(ShaderReflectError::NonCausalFilterChain {
868                pass: pass_number,
869                target: semantic.index,
870            });
871        }
872
873        meta.texture_meta.insert(
874            semantic,
875            TextureBinding {
876                binding: texture.binding,
877            },
878        );
879        Ok(())
880    }
881}
882
883impl ReflectShader for NagaReflect {
884    fn reflect(
885        &mut self,
886        pass_number: usize,
887        semantics: &ShaderSemantics,
888    ) -> Result<ShaderReflection, ShaderReflectError> {
889        self.validate_semantics()?;
890
891        // Validate verifies that there's only one uniform block.
892        let vertex_ubo = self
893            .vertex
894            .global_variables
895            .iter()
896            .find_map(|(handle, gv)| {
897                if gv.space == AddressSpace::Uniform {
898                    Some(handle)
899                } else {
900                    None
901                }
902            });
903
904        let fragment_ubo = self
905            .fragment
906            .global_variables
907            .iter()
908            .find_map(|(handle, gv)| {
909                if gv.space == AddressSpace::Uniform {
910                    Some(handle)
911                } else {
912                    None
913                }
914            });
915
916        let ubo = self.reflect_ubos(vertex_ubo, fragment_ubo)?;
917
918        let vertex_push = self
919            .vertex
920            .global_variables
921            .iter()
922            .find_map(|(handle, gv)| {
923                if gv.space == AddressSpace::PushConstant {
924                    Some(handle)
925                } else {
926                    None
927                }
928            });
929
930        let fragment_push = self
931            .fragment
932            .global_variables
933            .iter()
934            .find_map(|(handle, gv)| {
935                if gv.space == AddressSpace::PushConstant {
936                    Some(handle)
937                } else {
938                    None
939                }
940            });
941
942        let push_constant = self.reflect_push_constant_buffer(vertex_push, fragment_push)?;
943        let mut meta = BindingMeta::default();
944
945        if let Some(ubo) = vertex_ubo {
946            Self::reflect_buffer_struct_members(
947                &self.vertex,
948                ubo,
949                pass_number,
950                semantics,
951                &mut meta,
952                UniformMemberBlock::Ubo,
953                SemanticErrorBlame::Vertex,
954            )?;
955        }
956
957        if let Some(ubo) = fragment_ubo {
958            Self::reflect_buffer_struct_members(
959                &self.fragment,
960                ubo,
961                pass_number,
962                semantics,
963                &mut meta,
964                UniformMemberBlock::Ubo,
965                SemanticErrorBlame::Fragment,
966            )?;
967        }
968
969        if let Some(push) = vertex_push {
970            Self::reflect_buffer_struct_members(
971                &self.vertex,
972                push,
973                pass_number,
974                semantics,
975                &mut meta,
976                UniformMemberBlock::PushConstant,
977                SemanticErrorBlame::Vertex,
978            )?;
979        }
980
981        if let Some(push) = fragment_push {
982            Self::reflect_buffer_struct_members(
983                &self.fragment,
984                push,
985                pass_number,
986                semantics,
987                &mut meta,
988                UniformMemberBlock::PushConstant,
989                SemanticErrorBlame::Fragment,
990            )?;
991        }
992
993        let mut ubo_bindings = 0u16;
994        if vertex_ubo.is_some() || fragment_ubo.is_some() {
995            ubo_bindings = 1 << ubo.as_ref().expect("UBOs should be present").binding;
996        }
997
998        let textures = self.fragment.global_variables.iter().filter(|(_, gv)| {
999            let ty = &self.fragment.types[gv.ty];
1000            matches!(ty.inner, TypeInner::Image { .. })
1001        });
1002
1003        for (_, texture) in textures {
1004            let texture_data = self.reflect_texture(texture)?;
1005            if ubo_bindings & (1 << texture_data.binding) != 0 {
1006                return Err(ShaderReflectError::BindingInUse(texture_data.binding));
1007            }
1008            ubo_bindings |= 1 << texture_data.binding;
1009
1010            self.reflect_texture_metas(texture_data, pass_number, semantics, &mut meta)?;
1011        }
1012
1013        Ok(ShaderReflection {
1014            ubo,
1015            push_constant,
1016            meta,
1017        })
1018    }
1019
1020    fn validate(&mut self) -> Result<(), ShaderReflectError> {
1021        self.validate_semantics()?;
1022        let vertex_push = self
1023            .vertex
1024            .global_variables
1025            .iter()
1026            .find_map(|(handle, gv)| {
1027                if gv.space == AddressSpace::PushConstant {
1028                    Some(handle)
1029                } else {
1030                    None
1031                }
1032            });
1033
1034        let fragment_push = self
1035            .fragment
1036            .global_variables
1037            .iter()
1038            .find_map(|(handle, gv)| {
1039                if gv.space == AddressSpace::PushConstant {
1040                    Some(handle)
1041                } else {
1042                    None
1043                }
1044            });
1045
1046        self.reflect_push_constant_buffer(vertex_push, fragment_push)?;
1047        Ok(())
1048    }
1049}
1050
#[cfg(test)]
// Every test in this module is currently commented out. The imports are kept for when the tests
// are re-enabled, so silence the `unused_imports` warnings they would otherwise raise in
// `cargo test` builds.
#[allow(unused_imports)]
mod test {
    use crate::front::ShaderInputCompiler;
    use crate::reflect::semantics::{Semantic, TextureSemantics, UniformSemantic};
    use librashader_common::map::FastHashMap;
    use librashader_preprocess::ShaderSource;
    use librashader_presets::ShaderPreset;

    // NOTE: the disabled `test_into` below is unfinished (it contains a truncated line) and
    // will not compile if simply uncommented.
    // #[test]
    // pub fn test_into() {
    //     let result = ShaderSource::load("../test/slang-shaders/misc/shaders/simple_color_controls.slang").unwrap();
    //     let compilation = crate::front::Glslang::compile(&result).unwrap();
    //
    //     crate::front::n
    //     // let mut loader = rspirv::dr::Loader::new();
    //     // rspirv::binary::parse_words(compilation.vertex.as_binary(), &mut loader).unwrap();
    //     // let module = loader.module();
    //     //
    //     // let outputs: Vec<&Instruction> = module
    //     //     .types_global_values
    //     //     .iter()
    //     //     .filter(|i| i.class.opcode == Op::Variable)
    //     //     .collect();
    //
    //     println!("{outputs:#?}");
    // }

    // #[test]
    // pub fn mega_bezel_reflect() {
    //     let preset = ShaderPreset::try_parse(
    //         "../test/shaders_slang/bezel/Mega_Bezel/Presets/MBZ__0__SMOOTH-ADV.slangp",
    //     )
    //         .unwrap();
    //
    //     let mut uniform_semantics: FastHashMap<String, UniformSemantic> = Default::default();
    //     let mut texture_semantics: FastHashMap<String, Semantic<TextureSemantics>> = Default::default();
    //
    //
    //
    //
    // }
}