est_render/gpu/shader/
reflection.rs

1use core::str;
2use std::io::{Cursor, Read};
3
4use byteorder_lite::{LittleEndian, ReadBytesExt};
5use wgpu::naga::{
6    AddressSpace, ArraySize, Binding, Module, Scalar, ScalarKind, ShaderStage, TypeInner,
7    VectorSize,
8};
9
10use super::types::{
11    ShaderBindingInfo, ShaderBindingType, ShaderReflect, StorageAccess, VertexInputReflection,
12    VertexInputType,
13};
14
15pub fn is_shader_valid(data: &str) -> bool {
16    match wgpu::naga::front::wgsl::parse_str(data) {
17        Ok(module) => {
18            let res = parse(module);
19            res.is_ok()
20        }
21        Err(err) => {
22            #[cfg(any(debug_assertions, feature = "enable-release-validation"))]
23            eprintln!("Shader validation error: {:?}", err);
24            false
25        }
26    }
27}
28
/// Result of `load_binary_shader`: the shader payload plus its reflection data.
29pub struct BinaryShader {
    /// Raw bytes of the length-prefixed payload stored at the end of the
    /// binary shader blob (named SPIR-V by the field; the loader does not
    /// inspect or validate the contents).
30    pub spirv: Vec<u8>,
    /// Reflection metadata (entry points, bindings and, where present, the
    /// vertex input layout) decoded from the blob header.
31    pub reflect: ShaderReflect,
32}
33
/// Magic bytes that `load_binary_shader` requires at the very start of a
/// binary shader blob; blobs that begin with anything else are rejected.
34const BINARY_SHADER_MAGIC: [u8; 20] = *b"est-binary-shader-v1";
35
36fn read_u32(cursor: &mut Cursor<&[u8]>) -> Result<u32, String> {
37    cursor
38        .read_u32::<LittleEndian>()
39        .map_err(|_| "Failed to read u32".to_string())
40}
41
42fn read_u64(cursor: &mut Cursor<&[u8]>) -> Result<u64, String> {
43    cursor
44        .read_u64::<LittleEndian>()
45        .map_err(|_| "Failed to read u64".to_string())
46}
47
/// Reads exactly `len` bytes from `cursor` into a freshly allocated vector.
///
/// `len` usually comes straight from the (untrusted) input, so it is checked
/// against the number of bytes still available *before* the buffer is
/// allocated. A corrupt length prefix therefore produces an error instead of a
/// multi-gigabyte allocation or a capacity-overflow panic.
///
/// # Errors
///
/// Returns `"Failed to read bytes"` when fewer than `len` bytes remain.
fn read_bytes(cursor: &mut Cursor<&[u8]>, len: usize) -> Result<Vec<u8>, String> {
    let total = cursor.get_ref().len();
    // Clamp the position to the slice length; the clamped value always fits in
    // `usize` because it is no larger than `total`.
    let consumed = cursor.position().min(total as u64) as usize;
    if len > total - consumed {
        return Err("Failed to read bytes".to_string());
    }

    let mut buf = vec![0; len];
    cursor
        .read_exact(&mut buf)
        .map_err(|_| "Failed to read bytes".to_string())?;
    Ok(buf)
}
55
56fn read_utf8_string(cursor: &mut Cursor<&[u8]>, len: usize) -> Result<String, String> {
57    let bytes = read_bytes(cursor, len)?;
58    String::from_utf8(bytes).map_err(|_| "Invalid UTF-8 string".to_string())
59}
60
61pub fn load_binary_shader(data: &[u8]) -> Result<BinaryShader, String> {
62    let mut cursor = Cursor::new(data);
63
64    let mut magic = [0; 20];
65    cursor
66        .read_exact(&mut magic)
67        .map_err(|_| "Failed to read magic".to_string())?;
68    if magic != BINARY_SHADER_MAGIC {
69        return Err("Invalid shader magic".to_string());
70    }
71
72    let shader_type_id = read_u32(&mut cursor)?;
73
74    let entry_point_sz = read_u32(&mut cursor)?;
75    let entry_point = read_utf8_string(&mut cursor, entry_point_sz as usize)?;
76
77    let binding_count = read_u32(&mut cursor)?;
78    let mut bindings = Vec::with_capacity(binding_count as usize);
79
80    for _ in 0..binding_count {
81        let group = read_u32(&mut cursor)?;
82        let binding = read_u32(&mut cursor)?;
83        let name_sz = read_u32(&mut cursor)?;
84        let name = read_utf8_string(&mut cursor, name_sz as usize)?;
85        let ty = match read_u32(&mut cursor)? {
86            0 => ShaderBindingType::UniformBuffer(read_u32(&mut cursor)?),
87            1 => {
88                let size = read_u32(&mut cursor)?;
89                let access = StorageAccess::from_bits(read_u32(&mut cursor)?)
90                    .ok_or("Invalid storage access")?;
91                ShaderBindingType::StorageBuffer(size, access)
92            }
93            2 => {
94                let access = StorageAccess::from_bits(read_u32(&mut cursor)?)
95                    .ok_or("Invalid storage texture access")?;
96                ShaderBindingType::StorageTexture(access)
97            }
98            3 => ShaderBindingType::Sampler(read_u32(&mut cursor)? != 0),
99            4 => ShaderBindingType::Texture(read_u32(&mut cursor)? != 0),
100            5 => ShaderBindingType::PushConstant(read_u32(&mut cursor)?),
101            t => return Err(format!("Unknown binding type ID: {}", t)),
102        };
103
104        bindings.push(ShaderBindingInfo {
105            binding,
106            group,
107            name,
108            ty,
109        });
110    }
111
112    let vertex_input = if shader_type_id == 0 || shader_type_id == 2 {
113        let name_sz = read_u32(&mut cursor)?;
114        let name = read_utf8_string(&mut cursor, name_sz as usize)?;
115        let stride = read_u32(&mut cursor)? as u64;
116        let attr_count = read_u32(&mut cursor)?;
117        let mut attributes = Vec::with_capacity(attr_count as usize);
118
119        for _ in 0..attr_count {
120            let location = read_u32(&mut cursor)?;
121            let offset = read_u64(&mut cursor)?;
122            let ty_id = read_u32(&mut cursor)?;
123            let ty = match ty_id {
124                0 => VertexInputType::Float32,
125                1 => VertexInputType::Float32x2,
126                2 => VertexInputType::Float32x3,
127                3 => VertexInputType::Float32x4,
128                4 => VertexInputType::Sint32,
129                5 => VertexInputType::Sint32x2,
130                6 => VertexInputType::Sint32x3,
131                7 => VertexInputType::Sint32x4,
132                8 => VertexInputType::Uint32,
133                9 => VertexInputType::Uint32x2,
134                10 => VertexInputType::Uint32x3,
135                11 => VertexInputType::Uint32x4,
136                _ => return Err(format!("Invalid vertex input type: {}", ty_id)),
137            };
138            attributes.push((location, offset, ty));
139        }
140
141        Some(VertexInputReflection {
142            name,
143            stride,
144            attributes,
145        })
146    } else {
147        None
148    };
149
150    let reflect = match shader_type_id {
151        0 => ShaderReflect::Vertex {
152            entry_point,
153            input: vertex_input,
154            bindings,
155        },
156        1 => ShaderReflect::Fragment {
157            entry_point,
158            bindings,
159        },
160        2 => {
161            let parts: Vec<&str> = entry_point.split(',').collect();
162            if parts.len() != 2 {
163                return Err("Invalid vertex/fragment entry point format".to_string());
164            }
165            ShaderReflect::VertexFragment {
166                vertex_entry_point: parts[0].to_string(),
167                vertex_input,
168                fragment_entry_point: parts[1].to_string(),
169                bindings,
170            }
171        }
172        3 => ShaderReflect::Compute {
173            entry_point,
174            bindings,
175        },
176        t => return Err(format!("Unknown shader type ID: {}", t)),
177    };
178
179    let spirv_sz = read_u32(&mut cursor)?;
180    let spirv = read_bytes(&mut cursor, spirv_sz as usize)?;
181
182    Ok(BinaryShader { spirv, reflect })
183}
184
185pub(crate) fn parse(module: Module) -> Result<ShaderReflect, String> {
186    let mut bindings = Vec::new();
187    for (handle, var) in module.global_variables.iter() {
188        if let Some(binding) = &var.binding {
189            match var.space {
190                AddressSpace::Uniform => {
191                    let ty = &module.types[var.ty];
192                    let size = get_size(&module, &ty.inner);
193                    let var_name = var
194                        .name
195                        .clone()
196                        .unwrap_or_else(|| format!("unnamed_{:?}", handle));
197
198                    if size <= 16 {
199                        // Uniforms smaller than 16 bytes are not supported
200                        #[cfg(any(debug_assertions, feature = "enable-release-validation"))]
201                        return Err(format!(
202                            "Uniform variable '{}' is too small ({} bytes), must be at least 16 bytes",
203                            var_name, size
204                        ));
205                    }
206
207                    let binding_info = ShaderBindingInfo {
208                        binding: binding.binding as u32,
209                        group: binding.group as u32,
210                        name: var_name,
211                        ty: ShaderBindingType::UniformBuffer(size as u32),
212                    };
213
214                    bindings.push(binding_info);
215                }
216
217                AddressSpace::PushConstant => {
218                    let ty = &module.types[var.ty];
219                    let size = get_size(&module, &ty.inner);
220                    let var_name = var
221                        .name
222                        .clone()
223                        .unwrap_or_else(|| format!("unnamed_{:?}", handle));
224
225                    let binding_info = ShaderBindingInfo {
226                        binding: binding.binding as u32,
227                        group: binding.group as u32,
228                        name: var_name,
229                        ty: ShaderBindingType::PushConstant(size as u32),
230                    };
231
232                    bindings.push(binding_info);
233                }
234
235                AddressSpace::Storage { access: _access } => {
236                    let ty = &module.types[var.ty];
237                    let var_name = var
238                        .name
239                        .clone()
240                        .unwrap_or_else(|| format!("unnamed_{:?}", handle));
241
242                    let mut access = StorageAccess::empty();
243                    if _access.contains(wgpu::naga::StorageAccess::LOAD) {
244                        access |= StorageAccess::READ
245                    }
246
247                    if _access.contains(wgpu::naga::StorageAccess::STORE) {
248                        access |= StorageAccess::WRITE;
249                    }
250
251                    if _access.contains(wgpu::naga::StorageAccess::ATOMIC) {
252                        access |= StorageAccess::ATOMIC;
253                    }
254
255                    match &ty.inner {
256                        TypeInner::Struct {
257                            members: _,
258                            span: _,
259                        } => {
260                            let size = get_size(&module, &ty.inner);
261
262                            let binding_info = ShaderBindingInfo {
263                                binding: binding.binding as u32,
264                                group: binding.group as u32,
265                                name: var_name,
266                                ty: ShaderBindingType::StorageBuffer(size as u32, access),
267                            };
268
269                            bindings.push(binding_info);
270                        }
271
272                        TypeInner::Image {
273                            dim: _,
274                            arrayed: _,
275                            class: _,
276                        } => {
277                            let binding_info = ShaderBindingInfo {
278                                binding: binding.binding as u32,
279                                group: binding.group as u32,
280                                name: var_name,
281                                ty: ShaderBindingType::StorageTexture(access),
282                            };
283
284                            bindings.push(binding_info);
285                        }
286
287                        TypeInner::Array {
288                            base: _,
289                            size,
290                            stride: _,
291                        } => {
292                            let count = match size {
293                                ArraySize::Constant(size) => size.get(),
294                                _ => u32::MAX, // Default with unlimited sizes
295                            };
296
297                            let binding_info = ShaderBindingInfo {
298                                binding: binding.binding as u32,
299                                group: binding.group as u32,
300                                name: var_name,
301                                ty: ShaderBindingType::StorageBuffer(count, access),
302                            };
303
304                            bindings.push(binding_info);
305                        }
306
307                        _ => {}
308                    }
309                }
310
311                AddressSpace::Handle => {
312                    // Check if sampler, sampled texture, or storage texture
313
314                    let ty = &module.types[var.ty];
315                    let var_name = var
316                        .name
317                        .clone()
318                        .unwrap_or_else(|| format!("unnamed_{:?}", handle));
319
320                    match ty.inner {
321                        TypeInner::Sampler { comparison } => {
322                            let binding_info = ShaderBindingInfo {
323                                binding: binding.binding as u32,
324                                group: binding.group as u32,
325                                name: var_name,
326                                ty: ShaderBindingType::Sampler(comparison),
327                            };
328
329                            bindings.push(binding_info);
330                        }
331
332                        TypeInner::Image {
333                            dim: _,
334                            arrayed: _,
335                            class,
336                        } => {
337                            let binding_info = ShaderBindingInfo {
338                                binding: binding.binding as u32,
339                                group: binding.group as u32,
340                                name: var_name,
341                                ty: ShaderBindingType::Texture(match class {
342                                    wgpu::naga::ImageClass::Sampled { kind: _, multi } => multi,
343                                    wgpu::naga::ImageClass::Depth { multi } => multi,
344                                    wgpu::naga::ImageClass::Storage {
345                                        format: _,
346                                        access: _,
347                                    } => {
348                                        // panic!("Storage image should be handled separately")
349                                        return Err("Storage image should be handled separately"
350                                            .to_string());
351                                    }
352                                }),
353                            };
354
355                            bindings.push(binding_info);
356                        }
357
358                        _ => {}
359                    }
360                }
361
362                _ => {}
363            }
364        }
365    }
366
367    // sort the bindings by group first, then by binding
368    // A: 0, 0
369    // B: 0, 1
370    // C: 1, 0
371    // D: 1, 1
372    bindings.sort_by(|a, b| {
373        if a.group == b.group {
374            a.binding.cmp(&b.binding)
375        } else {
376            a.group.cmp(&b.group)
377        }
378    });
379
380    // get entry point
381    let mut vertex_entry_point = String::new();
382    let mut fragment_entry_point = String::new();
383    let mut compute_entry_point = String::new();
384
385    let mut vertex_struct_input = None;
386
387    #[allow(unused)]
388    for entry_point in module.entry_points.iter() {
389        match entry_point.stage {
390            ShaderStage::Vertex => {
391                vertex_entry_point = entry_point.name.clone();
392
393                /**
394                 * Example:
395                 *
396                 * struct VertexInput {
397                 *   @location(0) position: vec3<f32>,
398                 *   @location(1) color: vec4<f32>,
399                 *   @location(2) texCoord: vec2<f32>,
400                 * };
401                 */
402                for vertex_input in entry_point.function.arguments.iter() {
403                    let ty = &module.types[vertex_input.ty];
404
405                    let struct_name = ty
406                        .name
407                        .clone()
408                        .unwrap_or_else(|| format!("unnamed_{:?}", vertex_input.ty));
409
410                    let mut attributes = Vec::new();
411                    let mut total_size = 0;
412
413                    match &ty.inner {
414                        TypeInner::Struct { members, span } => {
415                            for member in members.iter() {
416                                let attribute_name = member
417                                    .name
418                                    .clone()
419                                    .unwrap_or_else(|| format!("unnamed_{:?}", member.ty));
420
421                                let ty = &module.types[member.ty];
422                                let size = get_size(&module, &ty.inner);
423                                let location = member
424                                    .binding
425                                    .as_ref()
426                                    .and_then(|b| match b {
427                                        Binding::Location {
428                                            location,
429                                            interpolation: _,
430                                            sampling: _,
431                                            blend_src: _,
432                                        } => Some(*location as u32),
433                                        _ => None,
434                                    })
435                                    .unwrap_or_else(|| {
436                                        panic!("Vertex input must have a location binding")
437                                    });
438
439                                match &ty.inner {
440                                    TypeInner::Scalar(scalar) => {
441                                        if let Some(vertex_input_type) =
442                                            mapping_to_vertex_input(scalar, None)
443                                        {
444                                            attributes.push((
445                                                location,
446                                                total_size as u64,
447                                                vertex_input_type,
448                                            ));
449
450                                            total_size += scalar_size(scalar);
451                                        } else {
452                                            // #[cfg(any(debug_assertions, feature = "enable-release-validation"))]
453                                            // panic!(
454                                            //     "Unsupported vertex input type: {:?} for member: {}",
455                                            //     ty.inner, attribute_name
456                                            // );
457                                            return Err(format!(
458                                                "Unsupported vertex input type: {:?} for member: {}",
459                                                ty.inner, attribute_name
460                                            ));
461                                        }
462                                    }
463
464                                    TypeInner::Vector { size, scalar } => {
465                                        if let Some(vertex_input_type) =
466                                            mapping_to_vertex_input(scalar, Some(size))
467                                        {
468                                            attributes.push((
469                                                location,
470                                                total_size as u64,
471                                                vertex_input_type,
472                                            ));
473
474                                            total_size +=
475                                                vectorsize_as_u32(size) * scalar_size(scalar);
476                                        } else {
477                                            // #[cfg(any(debug_assertions, feature = "enable-release-validation"))]
478                                            // panic!(
479                                            //     "Unsupported vertex vector input type: {:?} for member: {}",
480                                            //     ty.inner, attribute_name
481                                            // );
482                                            return Err(format!(
483                                                "Unsupported vertex vector input type: {:?} for member: {}",
484                                                ty.inner, attribute_name
485                                            ));
486                                        }
487                                    }
488
489                                    _ => {
490                                        // #[cfg(any(debug_assertions, feature = "enable-release-validation"))]
491                                        // panic!(
492                                        //     "Unsupported vertex input type: {:?} for member: {}",
493                                        //     ty.inner, attribute_name
494                                        // );
495                                        return Err(format!(
496                                            "Unsupported vertex input type: {:?} for member: {}",
497                                            ty.inner, attribute_name
498                                        ));
499                                    }
500                                }
501                            }
502                        }
503                        _ => {}
504                    }
505
506                    vertex_struct_input = Some(VertexInputReflection {
507                        name: struct_name,
508                        stride: total_size as u64,
509                        attributes,
510                    });
511                }
512            }
513            ShaderStage::Fragment => fragment_entry_point = entry_point.name.clone(),
514            ShaderStage::Compute => compute_entry_point = entry_point.name.clone(),
515            _ => {
516                // #[cfg(any(debug_assertions, feature = "enable-release-validation"))]
517                // panic!("Unsupported shader stage: {:?}", entry_point.stage);
518                return Err(format!("Unsupported shader stage: {:?}", entry_point.stage));
519            }
520        }
521    }
522
523    if !vertex_entry_point.is_empty() && !fragment_entry_point.is_empty() {
524        return Ok(ShaderReflect::VertexFragment {
525            vertex_entry_point,
526            vertex_input: vertex_struct_input,
527            fragment_entry_point,
528            bindings,
529        });
530    }
531
532    if !vertex_entry_point.is_empty() {
533        return Ok(ShaderReflect::Vertex {
534            entry_point: vertex_entry_point,
535            input: vertex_struct_input,
536            bindings,
537        });
538    }
539
540    if !fragment_entry_point.is_empty() {
541        return Ok(ShaderReflect::Fragment {
542            entry_point: fragment_entry_point,
543            bindings,
544        });
545    }
546
547    if !compute_entry_point.is_empty() {
548        return Ok(ShaderReflect::Compute {
549            entry_point: compute_entry_point,
550            bindings,
551        });
552    }
553
554    Err("No valid entry point found in shader module".to_string())
555}
556
557pub(crate) fn mapping_to_vertex_input(
558    scalar: &Scalar,
559    vector: Option<&VectorSize>,
560) -> Option<VertexInputType> {
561    match scalar.kind {
562        ScalarKind::Float => {
563            if let Some(vector_size) = vector {
564                match vector_size {
565                    VectorSize::Bi => Some(VertexInputType::Float32x2),
566                    VectorSize::Tri => Some(VertexInputType::Float32x3),
567                    VectorSize::Quad => Some(VertexInputType::Float32x4),
568                }
569            } else {
570                Some(VertexInputType::Float32)
571            }
572        }
573        ScalarKind::Sint => {
574            if let Some(vector_size) = vector {
575                match vector_size {
576                    VectorSize::Bi => Some(VertexInputType::Sint32x2),
577                    VectorSize::Tri => Some(VertexInputType::Sint32x3),
578                    VectorSize::Quad => Some(VertexInputType::Sint32x4),
579                }
580            } else {
581                Some(VertexInputType::Sint32)
582            }
583        }
584        ScalarKind::Uint => {
585            if let Some(vector_size) = vector {
586                match vector_size {
587                    VectorSize::Bi => Some(VertexInputType::Uint32x2),
588                    VectorSize::Tri => Some(VertexInputType::Uint32x3),
589                    VectorSize::Quad => Some(VertexInputType::Uint32x4),
590                }
591            } else {
592                Some(VertexInputType::Uint32)
593            }
594        }
595        ScalarKind::Bool => {
596            if let Some(vector_size) = vector {
597                match vector_size {
598                    VectorSize::Bi => Some(VertexInputType::Uint32),
599                    VectorSize::Tri => Some(VertexInputType::Uint32x3),
600                    VectorSize::Quad => Some(VertexInputType::Uint32x4),
601                }
602            } else {
603                Some(VertexInputType::Uint32)
604            }
605        }
606        _ => None,
607    }
608}
609
610#[allow(unused_variables)]
611pub(crate) fn get_size(module: &Module, ty_inner: &TypeInner) -> i32 {
612    match ty_inner {
613        TypeInner::Scalar(scalar) => scalar_size(scalar) as i32,
614
615        TypeInner::Vector { size, scalar } => {
616            let scalar_size = scalar_size(scalar);
617            let vec_size = vectorsize_as_u32(size) * scalar_size;
618            align_to(vec_size, vector_alignment(size)) as i32 // Ensure correct alignment
619        }
620
621        TypeInner::Matrix {
622            columns,
623            rows,
624            scalar,
625        } => {
626            let scalar_size = scalar_size(scalar);
627            let row_size = vectorsize_as_u32(rows) * scalar_size;
628            let aligned_row_size = align_to(row_size, 16); // Matrices align to 16 bytes per row
629            (vectorsize_as_u32(columns) * aligned_row_size) as i32
630        }
631
632        TypeInner::Array { base, size, stride } => {
633            let count = match size {
634                ArraySize::Constant(size) => size.get(),
635                _ => u32::MAX, // Default with unlimited sizes
636            };
637
638            if count == u32::MAX {
639                -1 // Indicate dynamic array
640            } else {
641                (count * stride) as i32
642            }
643        }
644
645        TypeInner::Struct { members, span } => {
646            let mut max_alignment = 0;
647            let mut size = 0;
648            for member in members {
649                let ty = &module.types[member.ty];
650
651                let member_size = get_size(module, &ty.inner);
652                let alignment = std140_alignment(module, &ty.inner);
653                size = align_to(size, alignment) + member_size as u32;
654                max_alignment = max_alignment.max(alignment);
655            }
656
657            align_to(size, max_alignment) as i32 // Ensure struct is padded to its largest member
658        }
659
660        _ => 0, // Other types like images, samplers, and pointers are not sized
661    }
662}
663
664pub(crate) fn scalar_size(scalar: &Scalar) -> u32 {
665    match scalar.kind {
666        ScalarKind::Float => 4,
667        ScalarKind::Sint => 4,
668        ScalarKind::Uint => 4,
669        ScalarKind::Bool => 4,
670        _ => 0,
671    }
672}
673
674pub(crate) fn vectorsize_as_u32(size: &VectorSize) -> u32 {
675    match size {
676        VectorSize::Bi => 2,
677        VectorSize::Tri => 3,
678        VectorSize::Quad => 4,
679    }
680}
681
682pub(crate) fn std140_alignment(module: &Module, ty_inner: &TypeInner) -> u32 {
683    match ty_inner {
684        TypeInner::Scalar(_) => 4,
685        TypeInner::Vector { size, .. } => vector_alignment(size),
686        TypeInner::Matrix { .. } => 16,
687        TypeInner::Struct { members, .. } => members
688            .iter()
689            .map(|m| {
690                let r#type = &module.types[m.ty];
691                std140_alignment(module, &r#type.inner)
692            })
693            .max()
694            .unwrap_or(1),
695        _ => 1,
696    }
697}
698
699pub(crate) fn vector_alignment(size: &VectorSize) -> u32 {
700    match size {
701        VectorSize::Bi => 8,    // vec2 = 8-byte aligned
702        VectorSize::Tri => 16,  // vec3 = 16-byte aligned
703        VectorSize::Quad => 16, // vec4 = 16-byte aligned
704    }
705}
706
/// Rounds `size` up to the next multiple of `alignment`.
///
/// An `alignment` of 0 or 1 leaves `size` unchanged. Alignments that are not a
/// power of two are supported. If the rounded value would not fit in a `u32`,
/// the result saturates at `u32::MAX` instead of overflowing.
pub(crate) fn align_to(size: u32, alignment: u32) -> u32 {
    if alignment <= 1 {
        return size;
    }

    match size % alignment {
        0 => size,
        remainder => size.saturating_add(alignment - remainder),
    }
}