est_render/gpu/shader/
graphics.rs

1use core::panic;
2use std::{borrow::Cow, collections::HashMap, hash::Hash};
3
4use wgpu::{BindingType, SamplerBindingType, ShaderRuntimeChecks, ShaderStages, naga::front::wgsl};
5
6use crate::{
7    utils::ArcRef,
8};
9
10use super::{
11    types::{
12        BindGroupLayout, IndexBufferSize, 
13        ShaderBindingType, ShaderCullMode, 
14        ShaderFrontFace, ShaderPollygonMode, 
15        ShaderReflect, ShaderTopology, 
16        StorageAccess, VertexInputType,
17        VertexInputReflection,
18    },
19    super::GPUInner,
20};
21
/// Where the shader code held by a graphics shader builder comes from.
///
/// The text variants hold WGSL source; the binary variants hold precompiled
/// shader bytes. Two-field variants are always in `(vertex, fragment)` order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum GraphicsShaderSource {
    /// No source has been provided yet.
    None,
    /// One WGSL source providing both the vertex and the fragment stage.
    Source(String),
    /// Separate WGSL sources: `(vertex, fragment)`.
    SplitSource(String, String),
    /// One precompiled binary providing both stages.
    BinarySource(Vec<u8>),
    /// Separate precompiled binaries: `(vertex, fragment)`.
    BinarySplitSource(Vec<u8>, Vec<u8>),
}
29
30/// Builder for creating graphics shaders.
31///
32/// This builder allows you to set the WGSL vertex and fragment shader source code from files or strings.
33/// You can also set the vertex and fragment shader source code separately.
34pub struct GraphicsShaderBuilder {
35    pub(crate) graphics: ArcRef<GPUInner>,
36    pub(crate) source: GraphicsShaderSource,
37}
38
39impl GraphicsShaderBuilder {
40    pub(crate) fn new(graphics: ArcRef<GPUInner>) -> Self {
41        Self {
42            graphics,
43            source: GraphicsShaderSource::None,
44        }
45    }
46
47    /// Sets the WGSL vertex and fragment shader source code from a file.
48    pub fn set_file(mut self, path: &str) -> Self {
49        let data = std::fs::read_to_string(path);
50        if let Err(err) = data {
51            panic!("Failed to read shader file: {:?}", err);
52        }
53
54        self.source = GraphicsShaderSource::Source(data.unwrap());
55
56        self
57    }
58
59    /// Sets the WGSL vertex and fragment shader source code from a string.
60    pub fn set_source(mut self, source: &str) -> Self {
61        self.source = GraphicsShaderSource::Source(source.to_string());
62        self
63    }
64
65    /// Sets the WGSL vertex shader source code from a file.
66    ///
67    /// You need to also set the fragment shader source code using `set_fragment_file` or `set_fragment_code`.
68    pub fn set_vertex_file(mut self, path: &str) -> Self {
69        let data = std::fs::read_to_string(path);
70        if let Err(err) = data {
71            panic!("Failed to read vertex shader file: {:?}", err);
72        }
73
74        match self.source {
75            GraphicsShaderSource::SplitSource(ref mut vertex_source, _) => {
76                self.source =
77                    GraphicsShaderSource::SplitSource(data.unwrap(), vertex_source.clone());
78            }
79            _ => {
80                self.source = GraphicsShaderSource::SplitSource(data.unwrap(), "".to_string());
81            }
82        }
83
84        self
85    }
86
87    /// Sets the WGSL fragment shader source code from a file.
88    ///
89    /// You need to also set the vertex shader source code using `set_vertex_file` or `set_vertex_code`.
90    pub fn set_fragment_file(mut self, path: &str) -> Self {
91        let data = std::fs::read_to_string(path);
92        if let Err(err) = data {
93            panic!("Failed to read fragment shader file: {:?}", err);
94        }
95
96        match self.source {
97            GraphicsShaderSource::SplitSource(ref mut vertex_source, _) => {
98                self.source =
99                    GraphicsShaderSource::SplitSource(vertex_source.clone(), data.unwrap());
100            }
101            _ => {
102                self.source = GraphicsShaderSource::SplitSource("".to_string(), data.unwrap());
103            }
104        }
105
106        self
107    }
108
109    /// Sets the WGSL vertex shader source code from a string.
110    ///
111    /// You need to also set the fragment shader source code using `set_fragment_code` or `set_fragment_file`.
112    pub fn set_vertex_code(mut self, source: &str) -> Self {
113        match self.source {
114            GraphicsShaderSource::SplitSource(_, ref mut fragment_source) => {
115                self.source =
116                    GraphicsShaderSource::SplitSource(source.to_string(), fragment_source.clone());
117            }
118            _ => {
119                self.source = GraphicsShaderSource::SplitSource(source.to_string(), "".to_string());
120            }
121        }
122
123        self
124    }
125
126    /// Sets the WGSL fragment shader source code from a string.
127    ///
128    /// You need to also set the vertex shader source code using `set_vertex_code` or `set_vertex_file`.
129    pub fn set_fragment_code(mut self, source: &str) -> Self {
130        match self.source {
131            GraphicsShaderSource::SplitSource(ref mut vertex_source, _) => {
132                self.source =
133                    GraphicsShaderSource::SplitSource(vertex_source.clone(), source.to_string());
134            }
135            _ => {
136                self.source = GraphicsShaderSource::SplitSource("".to_string(), source.to_string());
137            }
138        }
139
140        self
141    }
142
143    /// Sets the precompiled binary shader source code.
144    ///
145    /// This is useful for using shaders compiled with tools like `glslangValidator` or `shaderc`.
146    pub fn set_binary_source(mut self, binary: &[u8]) -> Self {
147        self.source = GraphicsShaderSource::BinarySource(binary.to_vec());
148        self
149    }
150
151    /// Sets the precompiled binary vertex and fragment shader source code.
152    ///
153    /// This is useful for using shaders compiled with tools like `glslangValidator` or `shaderc`.
154    pub fn set_binary_file(mut self, path: &str) -> Self {
155        let data = std::fs::read(path);
156        if let Err(err) = data {
157            panic!("Failed to read binary shader file: {:?}", err);
158        }
159
160        self.source = GraphicsShaderSource::BinarySource(data.unwrap());
161        self
162    }
163
164    /// Sets the precompiled binary vertex shader source code.
165    ///
166    /// You need to also set the fragment shader source code using `set_binary_fragment`.
167    pub fn set_binary_vertex(mut self, binary: &[u8]) -> Self {
168        match self.source {
169            GraphicsShaderSource::BinarySplitSource(ref mut vertex_bin, _) => {
170                self.source =
171                    GraphicsShaderSource::BinarySplitSource(binary.to_vec(), vertex_bin.clone());
172            }
173            _ => {
174                self.source = GraphicsShaderSource::BinarySplitSource(binary.to_vec(), vec![]);
175            }
176        }
177
178        self
179    }
180
181    /// Sets the precompiled binary fragment shader source code.
182    ///
183    /// You need to also set the vertex shader source code using `set_binary_vertex`.
184    pub fn set_binary_fragment(mut self, binary: &[u8]) -> Self {
185        match self.source {
186            GraphicsShaderSource::BinarySplitSource(_, ref mut fragment_bin) => {
187                self.source =
188                    GraphicsShaderSource::BinarySplitSource(fragment_bin.clone(), binary.to_vec());
189            }
190            _ => {
191                self.source = GraphicsShaderSource::BinarySplitSource(vec![], binary.to_vec());
192            }
193        }
194
195        self
196    }
197
198    pub fn build(self) -> Result<GraphicsShader, String> {
199        GraphicsShader::new(self.graphics, self.source)
200    }
201}
202
203#[derive(Debug, Clone, PartialEq, Eq)]
204pub enum GraphicsShaderType {
205    GraphicsSingle {
206        module: wgpu::ShaderModule,
207    },
208    GraphicsSplit {
209        vertex_module: wgpu::ShaderModule,
210        fragment_module: wgpu::ShaderModule,
211    },
212}
213
214impl Hash for GraphicsShaderType {
215    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
216        match self {
217            GraphicsShaderType::GraphicsSingle { module } => {
218                module.hash(state);
219            }
220            GraphicsShaderType::GraphicsSplit {
221                vertex_module,
222                fragment_module,
223            } => {
224                vertex_module.hash(state);
225                fragment_module.hash(state);
226            }
227        }
228    }
229}
230
231#[derive(Clone, Debug, Hash)]
232pub(crate) struct GraphicsShaderInner {
233    pub ty: GraphicsShaderType,
234    pub reflection: Vec<ShaderReflect>,
235
236    pub bind_group_layouts: Vec<BindGroupLayout>,
237}
238
239impl PartialEq for GraphicsShaderInner {
240    fn eq(&self, other: &Self) -> bool {
241        let ty_equal = self.ty == other.ty;
242
243        let reflection_equal = self.reflection.len() == other.reflection.len()
244            && self
245                .reflection
246                .iter()
247                .zip(&other.reflection)
248                .all(|(a, b)| a == b);
249
250        let layouts_equal = self.bind_group_layouts.len() == other.bind_group_layouts.len()
251            && self
252                .bind_group_layouts
253                .iter()
254                .zip(&other.bind_group_layouts)
255                .all(|(a, b)| {
256                    a.group == b.group && a.bindings == b.bindings && a.layout == b.layout
257                });
258
259        ty_equal && reflection_equal && layouts_equal
260    }
261}
262
263#[derive(Clone, Debug, Eq, Hash)]
264pub(crate) struct VertexInputDescription {
265    pub index: Option<IndexBufferSize>,
266    pub topology: ShaderTopology,
267    pub cull_mode: Option<ShaderCullMode>,
268    pub polygon_mode: ShaderPollygonMode,
269    pub front_face: ShaderFrontFace,
270    pub stride: wgpu::BufferAddress,
271    pub attributes: Vec<wgpu::VertexAttribute>,
272}
273
274impl PartialEq for VertexInputDescription {
275    fn eq(&self, other: &Self) -> bool {
276        self.index == other.index
277            && self.topology == other.topology
278            && self.cull_mode == other.cull_mode
279            && self.polygon_mode == other.polygon_mode
280            && self.front_face == other.front_face
281            && self.stride == other.stride
282            && self.attributes == other.attributes
283    }
284}
285
286#[derive(Clone, Debug, Eq)]
287#[allow(unused)]
288pub struct GraphicsShader {
289    pub(crate) graphics: ArcRef<GPUInner>,
290    pub(crate) inner: ArcRef<GraphicsShaderInner>,
291
292    pub(crate) attrib: ArcRef<VertexInputDescription>,
293}
294
295impl GraphicsShader {
296    pub(crate) fn new(
297        graphics: ArcRef<GPUInner>,
298        wgls_data: GraphicsShaderSource,
299    ) -> Result<Self, String> {
300        let graphics_ref = graphics.borrow();
301        let device_ref = graphics_ref.device.as_ref().ok_or("Missing device")?;
302
303        fn create_vertex_input_attrib(input: &VertexInputReflection) -> Vec<wgpu::VertexAttribute> {
304            input
305                .attributes
306                .iter()
307                .map(|(location, offset, vtype)| wgpu::VertexAttribute {
308                    format: vtype.clone().into(),
309                    offset: *offset as wgpu::BufferAddress,
310                    shader_location: *location,
311                })
312                .collect()
313        }
314
315        fn create_input_desc(reflection: &ShaderReflect) -> Result<VertexInputDescription, String> {
316            let (vertex_input, stride) = match reflection {
317                ShaderReflect::Vertex { input, .. }
318                | ShaderReflect::VertexFragment {
319                    vertex_input: input,
320                    ..
321                } => {
322                    let input = input.as_ref().ok_or("Missing vertex input")?;
323                    (input, input.stride as wgpu::BufferAddress)
324                }
325                _ => return Err("Invalid shader type for vertex input".to_string()),
326            };
327
328            let attributes = create_vertex_input_attrib(vertex_input);
329            Ok(VertexInputDescription {
330                index: Some(IndexBufferSize::U16),
331                stride,
332                attributes,
333                topology: ShaderTopology::TriangleList,
334                cull_mode: None,
335                polygon_mode: ShaderPollygonMode::Fill,
336                front_face: ShaderFrontFace::Clockwise,
337            })
338        }
339
340        fn build_single_shader(
341            device: &wgpu::Device,
342            source: &str,
343        ) -> Result<(wgpu::ShaderModule, ShaderReflect), String> {
344            let module = wgsl::parse_str(source).map_err(|e| format!("Parse error: {e:?}"))?;
345            let reflection = super::reflection::parse(module).map_err(|e| format!("Reflect error: {e:?}"))?;
346            Ok((
347                device.create_shader_module(wgpu::ShaderModuleDescriptor {
348                    label: None,
349                    source: wgpu::ShaderSource::Wgsl(source.into()),
350                }),
351                reflection,
352            ))
353        }
354
355        fn build_binary_shader(
356            device: &wgpu::Device,
357            binary: &[u8],
358        ) -> Result<(wgpu::ShaderModule, ShaderReflect), String> {
359            let binary_shader = super::reflection::load_binary_shader(binary)
360                .map_err(|e| format!("Binary load error: {e:?}"))?;
361            let spirv_u32 = Cow::Borrowed(bytemuck::cast_slice(&binary_shader.spirv));
362            Ok((
363                // SAFETY: All binary shaders are validated and built with our shader compiler (est-shader-compiler).
364                // This used for fast shader loading, so we assume that the binary shader is valid.
365                unsafe {
366                    let desc = wgpu::ShaderModuleDescriptor {
367                        label: None,
368                        source: wgpu::ShaderSource::SpirV(spirv_u32),
369                    };
370
371                    let runtime_checks = ShaderRuntimeChecks {
372                        bounds_checks: true,
373                        force_loop_bounding: false,
374                    };
375
376                    device.create_shader_module_trusted(desc, runtime_checks)
377                },
378                binary_shader.reflect,
379            ))
380        }
381
382        match wgls_data {
383            GraphicsShaderSource::None => Err("No shader source provided".to_string()),
384
385            GraphicsShaderSource::Source(source) => {
386                let (module, reflection) = build_single_shader(device_ref, &source)?;
387                match reflection {
388                    ShaderReflect::VertexFragment { .. } => {
389                        let layout = Self::make_group_layout(device_ref, &[reflection.clone()]);
390                        let input_desc = create_input_desc(&reflection)?;
391                        Ok(Self {
392                            graphics: ArcRef::clone(&graphics),
393                            inner: ArcRef::new(GraphicsShaderInner {
394                                ty: GraphicsShaderType::GraphicsSingle { module },
395                                reflection: vec![reflection],
396                                bind_group_layouts: layout,
397                            }),
398                            attrib: ArcRef::new(input_desc),
399                        })
400                    }
401                    _ => Err("Shader source is not VertexFragment shader!".to_string()),
402                }
403            }
404
405            GraphicsShaderSource::SplitSource(vertex_src, fragment_src) => {
406                let (vertex_module, vertex_reflect) = build_single_shader(device_ref, &vertex_src)?;
407                let (fragment_module, fragment_reflect) =
408                    build_single_shader(device_ref, &fragment_src)?;
409
410                match (&vertex_reflect, &fragment_reflect) {
411                    (ShaderReflect::Vertex { .. }, ShaderReflect::Fragment { .. }) => {
412                        let layout = Self::make_group_layout(
413                            device_ref,
414                            &[vertex_reflect.clone(), fragment_reflect.clone()],
415                        );
416                        let input_desc = create_input_desc(&vertex_reflect)?;
417                        Ok(Self {
418                            graphics: ArcRef::clone(&graphics),
419                            inner: ArcRef::new(GraphicsShaderInner {
420                                ty: GraphicsShaderType::GraphicsSplit {
421                                    vertex_module,
422                                    fragment_module,
423                                },
424                                reflection: vec![vertex_reflect, fragment_reflect],
425                                bind_group_layouts: layout,
426                            }),
427                            attrib: ArcRef::new(input_desc),
428                        })
429                    }
430                    _ => Err("Invalid shader pair for SplitSource".to_string()),
431                }
432            }
433
434            GraphicsShaderSource::BinarySource(binary) => {
435                let (module, reflection) = build_binary_shader(device_ref, &binary)?;
436                match reflection {
437                    ShaderReflect::VertexFragment { .. } => {
438                        let layout = Self::make_group_layout(device_ref, &[reflection.clone()]);
439                        let input_desc = create_input_desc(&reflection)?;
440                        Ok(Self {
441                            graphics: ArcRef::clone(&graphics),
442                            inner: ArcRef::new(GraphicsShaderInner {
443                                ty: GraphicsShaderType::GraphicsSingle { module },
444                                reflection: vec![reflection],
445                                bind_group_layouts: layout,
446                            }),
447                            attrib: ArcRef::new(input_desc),
448                        })
449                    }
450                    _ => Err("Binary shader is not VertexFragment shader!".to_string()),
451                }
452            }
453
454            GraphicsShaderSource::BinarySplitSource(vertex_bin, fragment_bin) => {
455                let (vertex_module, vertex_reflect) = build_binary_shader(device_ref, &vertex_bin)?;
456                let (fragment_module, fragment_reflect) =
457                    build_binary_shader(device_ref, &fragment_bin)?;
458
459                match (&vertex_reflect, &fragment_reflect) {
460                    (ShaderReflect::Vertex { .. }, ShaderReflect::Fragment { .. }) => {
461                        let layout = Self::make_group_layout(
462                            device_ref,
463                            &[vertex_reflect.clone(), fragment_reflect.clone()],
464                        );
465                        let input_desc = create_input_desc(&vertex_reflect)?;
466                        Ok(Self {
467                            graphics: ArcRef::clone(&graphics),
468                            inner: ArcRef::new(GraphicsShaderInner {
469                                ty: GraphicsShaderType::GraphicsSplit {
470                                    vertex_module,
471                                    fragment_module,
472                                },
473                                reflection: vec![vertex_reflect, fragment_reflect],
474                                bind_group_layouts: layout,
475                            }),
476                            attrib: ArcRef::new(input_desc),
477                        })
478                    }
479                    _ => Err("Invalid binary shader pair for BinarySplitSource".to_string()),
480                }
481            }
482        }
483    }
484
485    fn make_group_layout(
486        device: &wgpu::Device,
487        reflects: &[ShaderReflect],
488    ) -> Vec<BindGroupLayout> {
489        let mut layouts: HashMap<u32, Vec<wgpu::BindGroupLayoutEntry>> = HashMap::new();
490
491        fn find_existing(
492            layouts: &mut HashMap<u32, Vec<wgpu::BindGroupLayoutEntry>>,
493            group: u32,
494            binding: u32,
495            _ty: wgpu::BindingType,
496        ) -> Option<&mut wgpu::BindGroupLayoutEntry> {
497            layouts.get_mut(&group).and_then(|entries| {
498                entries
499                    .iter_mut()
500                    .find(|entry| entry.binding == binding && matches!(entry.ty, _ty))
501            })
502        }
503
504        fn create_layout_ty(ty: ShaderBindingType) -> wgpu::BindingType {
505            match ty {
506                ShaderBindingType::UniformBuffer(size) => BindingType::Buffer {
507                    ty: wgpu::BufferBindingType::Uniform,
508                    has_dynamic_offset: false,
509                    min_binding_size: if size == u32::MAX {
510                        None
511                    } else {
512                        wgpu::BufferSize::new(size as u64)
513                    },
514                },
515                ShaderBindingType::Texture(multisampled) => BindingType::Texture {
516                    sample_type: wgpu::TextureSampleType::Float { filterable: true },
517                    view_dimension: wgpu::TextureViewDimension::D2,
518                    multisampled,
519                },
520                ShaderBindingType::Sampler(comparison) => BindingType::Sampler(if comparison {
521                    SamplerBindingType::Comparison
522                } else {
523                    SamplerBindingType::Filtering
524                }),
525                ShaderBindingType::StorageBuffer(size, access) => BindingType::Buffer {
526                    ty: wgpu::BufferBindingType::Storage {
527                        read_only: access.contains(StorageAccess::READ)
528                            && !access.contains(StorageAccess::WRITE),
529                    },
530                    has_dynamic_offset: false,
531                    min_binding_size: if size == u32::MAX {
532                        None
533                    } else {
534                        wgpu::BufferSize::new(size as u64)
535                    },
536                },
537                ShaderBindingType::StorageTexture(access) => BindingType::StorageTexture {
538                    access: if access.contains(StorageAccess::READ)
539                        && access.contains(StorageAccess::WRITE)
540                    {
541                        wgpu::StorageTextureAccess::ReadWrite
542                    } else if access.contains(StorageAccess::READ) {
543                        wgpu::StorageTextureAccess::ReadOnly
544                    } else if access.contains(StorageAccess::WRITE) {
545                        wgpu::StorageTextureAccess::WriteOnly
546                    } else if access.contains(StorageAccess::ATOMIC) {
547                        wgpu::StorageTextureAccess::Atomic
548                    } else {
549                        panic!("Invalid storage texture access")
550                    },
551                    format: wgpu::TextureFormat::Rgba8Unorm,
552                    view_dimension: wgpu::TextureViewDimension::D2,
553                },
554                _ => unreachable!(),
555            }
556        }
557
558        for reflect in reflects {
559            match reflect {
560                ShaderReflect::Vertex { bindings, .. } => {
561                    for binding in bindings.iter() {
562                        let ty = create_layout_ty(binding.ty.clone());
563                        let existing =
564                            find_existing(&mut layouts, binding.group, binding.binding, ty);
565                        if let Some(existing) = existing {
566                            existing.visibility |= ShaderStages::VERTEX;
567                            crate::dbg_log!(
568                                "BindGroupLayout: group {}, binding: {}, ty: {:?} (existing)",
569                                binding.group,
570                                binding.binding,
571                                binding.ty
572                            );
573                        } else {
574                            // Push new layout entry
575                            let layout_desc = wgpu::BindGroupLayoutEntry {
576                                ty,
577                                binding: binding.binding,
578                                visibility: ShaderStages::VERTEX,
579                                count: None,
580                            };
581
582                            let group = layouts.entry(binding.group).or_insert_with(Vec::new);
583
584                            crate::dbg_log!(
585                                "BindGroupLayout: group {}, binding: {}, ty: {:?}",
586                                binding.group,
587                                binding.binding,
588                                binding.ty
589                            );
590                            group.push(layout_desc);
591                        }
592                    }
593                }
594                ShaderReflect::Fragment { bindings, .. } => {
595                    for binding in bindings.iter() {
596                        let ty = create_layout_ty(binding.ty.clone());
597                        let existing =
598                            find_existing(&mut layouts, binding.group, binding.binding, ty);
599                        if let Some(existing) = existing {
600                            existing.visibility |= ShaderStages::FRAGMENT;
601                            crate::dbg_log!(
602                                "BindGroupLayout: group {}, binding: {}, ty: {:?} (existing)",
603                                binding.group,
604                                binding.binding,
605                                binding.ty
606                            );
607                        } else {
608                            // Push new layout entry
609                            let layout_desc = wgpu::BindGroupLayoutEntry {
610                                ty,
611                                binding: binding.binding,
612                                visibility: ShaderStages::FRAGMENT,
613                                count: None,
614                            };
615
616                            let group = layouts.entry(binding.group).or_insert_with(Vec::new);
617
618                            crate::dbg_log!(
619                                "BindGroupLayout: group {}, binding: {}, ty: {:?}",
620                                binding.group,
621                                binding.binding,
622                                binding.ty
623                            );
624                            group.push(layout_desc);
625                        }
626                    }
627                }
628                ShaderReflect::VertexFragment { bindings, .. } => {
629                    for binding in bindings.iter() {
630                        let ty = create_layout_ty(binding.ty.clone());
631
632                        // Push new layout entry
633                        let layout_desc = wgpu::BindGroupLayoutEntry {
634                            ty,
635                            binding: binding.binding,
636                            visibility: ShaderStages::VERTEX_FRAGMENT,
637                            count: None,
638                        };
639
640                        let group = layouts.entry(binding.group).or_insert_with(Vec::new);
641
642                        crate::dbg_log!(
643                            "BindGroupLayout: group {}, binding: {}, ty: {:?}",
644                            binding.group,
645                            binding.binding,
646                            binding.ty
647                        );
648                        group.push(layout_desc);
649                    }
650                }
651                _ => continue,
652            }
653        }
654
655        let mut layout_vec = layouts.into_iter().collect::<Vec<_>>();
656        layout_vec.sort_by_key(|(group, _)| *group);
657        layout_vec
658            .into_iter()
659            .map(|(group, layout)| {
660                // Label: "BindGroupLayout for group {group}, binding: {binding} (ex: 0, 1, 2)"
661                let label = if !layout.is_empty() {
662                    let mut s = format!("BindGroupLayout for group {}, binding: ", group);
663                    for (i, entry) in layout.iter().enumerate() {
664                        s.push_str(&entry.binding.to_string());
665                        if i != layout.len() - 1 {
666                            s.push_str(", ");
667                        }
668                    }
669                    Some(s)
670                } else {
671                    None
672                };
673
674                let bind_group_layout =
675                    device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
676                        label: label.as_deref(),
677                        entries: &layout,
678                    });
679
680                crate::dbg_log!(
681                    "Created BindGroupLayout for group {} with {} entries",
682                    group,
683                    layout.len()
684                );
685
686                BindGroupLayout {
687                    group,
688                    bindings: layout.iter().map(|entry| entry.binding).collect(),
689                    layout: bind_group_layout,
690                }
691            })
692            .collect()
693    }
694
695    pub fn get_uniform_location(&self, name: &str) -> Option<(u32, u32)> {
696        let inner = self.inner.borrow();
697
698        let reflection = &inner.reflection;
699        for reflect in reflection.iter() {
700            match reflect {
701                ShaderReflect::Vertex { bindings, .. } => {
702                    if let Some(binding) = bindings.iter().find(|b| {
703                        b.name == name && matches!(b.ty, ShaderBindingType::UniformBuffer(_))
704                    }) {
705                        return Some((binding.group, binding.binding));
706                    }
707                }
708                ShaderReflect::Fragment { bindings, .. } => {
709                    if let Some(binding) = bindings.iter().find(|b| {
710                        b.name == name && matches!(b.ty, ShaderBindingType::UniformBuffer(_))
711                    }) {
712                        return Some((binding.group, binding.binding));
713                    }
714                }
715                ShaderReflect::VertexFragment { bindings, .. } => {
716                    if let Some(binding) = bindings.iter().find(|b| {
717                        b.name == name && matches!(b.ty, ShaderBindingType::UniformBuffer(_))
718                    }) {
719                        return Some((binding.group, binding.binding));
720                    }
721                }
722                _ => continue,
723            }
724        }
725
726        None
727    }
728
729    pub fn get_uniform_size(&self, group: u32, binding: u32) -> Option<u32> {
730        let inner = self.inner.borrow();
731
732        let reflection = &inner.reflection;
733        for reflect in reflection.iter() {
734            match reflect {
735                ShaderReflect::Vertex { bindings, .. } => {
736                    if let Some(binding) = bindings
737                        .iter()
738                        .find(|b| b.group == group && b.binding == binding)
739                    {
740                        if let ShaderBindingType::UniformBuffer(size) = binding.ty {
741                            return Some(size);
742                        }
743                    }
744                }
745                ShaderReflect::Fragment { bindings, .. } => {
746                    if let Some(binding) = bindings
747                        .iter()
748                        .find(|b| b.group == group && b.binding == binding)
749                    {
750                        if let ShaderBindingType::UniformBuffer(size) = binding.ty {
751                            return Some(size);
752                        }
753                    }
754                }
755                ShaderReflect::VertexFragment { bindings, .. } => {
756                    if let Some(binding) = bindings
757                        .iter()
758                        .find(|b| b.group == group && b.binding == binding)
759                    {
760                        if let ShaderBindingType::UniformBuffer(size) = binding.ty {
761                            return Some(size);
762                        }
763                    }
764                }
765                _ => continue,
766            }
767        }
768
769        None
770    }
771
772    pub fn set_topology(&mut self, topology: ShaderTopology) -> Result<(), String> {
773        self.attrib.borrow_mut().topology = topology;
774        Ok(())
775    }
776
777    pub fn set_cull_mode(&mut self, cull_mode: Option<ShaderCullMode>) -> Result<(), String> {
778        self.attrib.borrow_mut().cull_mode = cull_mode;
779        Ok(())
780    }
781
782    pub fn set_polygon_mode(&mut self, polygon_mode: ShaderPollygonMode) -> Result<(), String> {
783        self.attrib.borrow_mut().polygon_mode = polygon_mode;
784        Ok(())
785    }
786
787    pub fn set_front_face(&mut self, front_face: ShaderFrontFace) -> Result<(), String> {
788        self.attrib.borrow_mut().front_face = front_face;
789        Ok(())
790    }
791
792    pub fn set_vertex_index_ty(&mut self, index_ty: Option<IndexBufferSize>) -> Result<(), String> {
793        self.attrib.borrow_mut().index = index_ty;
794        Ok(())
795    }
796
797    pub fn set_vertex_input(
798        &mut self,
799        location: u32,
800        vtype: VertexInputType,
801    ) -> Result<(), String> {
802        let inner = self.inner.borrow_mut();
803
804        let vertex_input = match inner.reflection.first() {
805            Some(ShaderReflect::Vertex { input, .. }) => input.as_ref(),
806            Some(ShaderReflect::VertexFragment { vertex_input, .. }) => vertex_input.as_ref(),
807            _ => None,
808        };
809
810        if vertex_input.is_none() {
811            return Err("Shader does not have vertex input".to_string());
812        }
813
814        let vertex_input = vertex_input.unwrap();
815
816        let input = vertex_input
817            .attributes
818            .iter()
819            .find(|attr| attr.0 == location);
820        if input.is_none() {
821            return Err(format!("Vertex input location {} not found", location));
822        }
823
824        let (location, _offset, og_vtype) = input.unwrap();
825        if !is_format_conversion_supported(*og_vtype, vtype) {
826            return Err(format!(
827                "Vertex input type {:?} is not supported for location {}",
828                vtype, location
829            ));
830        }
831
832        let mut attrib = self.attrib.borrow_mut();
833        let vertex_input_attrib = attrib
834            .attributes
835            .iter_mut()
836            .find(|attr| attr.shader_location == *location);
837
838        if vertex_input_attrib.is_none() {
839            return Err(format!(
840                "Vertex input location {} not found in shader attributes",
841                location
842            ));
843        }
844
845        let vertex_input_attrib = vertex_input_attrib.unwrap();
846        vertex_input_attrib.format = vtype.into();
847
848        Ok(())
849    }
850}
851
852impl std::hash::Hash for GraphicsShader {
853    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
854        ArcRef::as_ptr(&self.graphics).hash(state);
855        self.inner.hash(state);
856        self.attrib.hash(state);
857    }
858}
859
860// origin is O.G value based on reflection, and target is user-input type
861// example: if origin is Float32 and target is Unorm32, then it is supported
862#[inline]
863fn is_format_conversion_supported(origin: VertexInputType, target: VertexInputType) -> bool {
864    match origin {
865        VertexInputType::Float32 => match target {
866            VertexInputType::Float32 => true,
867            VertexInputType::Snorm8 => true,
868            VertexInputType::Unorm8 => true,
869            VertexInputType::Snorm16 => true,
870            _ => false,
871        },
872        VertexInputType::Float32x2 => match target {
873            VertexInputType::Float32x2 => true,
874            VertexInputType::Snorm8x2 => true,
875            VertexInputType::Unorm8x2 => true,
876            VertexInputType::Snorm16x2 => true,
877            _ => false,
878        },
879        VertexInputType::Float32x3 => {
880            match target {
881                VertexInputType::Float32x3 => true,
882                // normalized types are not supported for 3-component vectors
883                _ => false,
884            }
885        }
886        VertexInputType::Float32x4 => match target {
887            VertexInputType::Float32x4 => true,
888            VertexInputType::Snorm8x4 => true,
889            VertexInputType::Unorm8x4 => true,
890            VertexInputType::Snorm16x4 => true,
891            _ => false,
892        },
893        VertexInputType::Uint32 => match target {
894            VertexInputType::Uint32 => true,
895            VertexInputType::Uint16 => true,
896            VertexInputType::Uint8 => true,
897            _ => false,
898        },
899        VertexInputType::Uint32x2 => match target {
900            VertexInputType::Uint32x2 => true,
901            VertexInputType::Uint16x2 => true,
902            VertexInputType::Uint8x2 => true,
903            _ => false,
904        },
905        VertexInputType::Uint32x3 => match target {
906            VertexInputType::Uint32x3 => true,
907            VertexInputType::Uint16x4 => true,
908            VertexInputType::Uint8x4 => true,
909            _ => false,
910        },
911        VertexInputType::Uint32x4 => match target {
912            VertexInputType::Uint32x4 => true,
913            VertexInputType::Uint16x4 => true,
914            VertexInputType::Uint8x4 => true,
915            _ => false,
916        },
917        _ => origin == target,
918    }
919}
920
921// impl PartialEq for VertexInputDescription {
922//     fn eq(&self, other: &Self) -> bool {
923//         self.index == other.index
924//             && self.topology == other.topology
925//             && self.cull_mode == other.cull_mode
926//             && self.polygon_mode == other.polygon_mode
927//             && self.front_face == other.front_face
928//             && self.stride == other.stride
929//             && self.attributes == other.attributes
930//     }
931// }
932
933// impl PartialEq for GraphicsShaderInner {
934//     fn eq(&self, other: &Self) -> bool {
935//         // self.ty == other.ty && self.reflection == other.reflection
936//         self.ty == other.ty
937//     }
938// }
939
940impl PartialEq for GraphicsShader {
941    fn eq(&self, other: &Self) -> bool {
942        ArcRef::ptr_eq(&self.graphics, &other.graphics)
943            && ArcRef::ptr_eq(&self.inner, &other.inner)
944            && ArcRef::ptr_eq(&self.attrib, &other.attrib)
945    }
946}