1use core::str;
2use std::io::{Cursor, Read};
3
4use byteorder_lite::{LittleEndian, ReadBytesExt};
5use wgpu::naga::{
6 AddressSpace, ArraySize, Binding, Module, Scalar, ScalarKind, ShaderStage, TypeInner,
7 VectorSize,
8};
9
10use super::types::{
11 ShaderBindingInfo, ShaderBindingType, ShaderReflect, StorageAccess, VertexInputReflection,
12 VertexInputType,
13};
14
15pub fn is_shader_valid(data: &str) -> bool {
16 match wgpu::naga::front::wgsl::parse_str(data) {
17 Ok(module) => {
18 let res = parse(module);
19 res.is_ok()
20 }
21 Err(err) => {
22 #[cfg(any(debug_assertions, feature = "enable-release-validation"))]
23 eprintln!("Shader validation error: {:?}", err);
24 false
25 }
26 }
27}
28
29pub struct BinaryShader {
30 pub spirv: Vec<u8>,
31 pub reflect: ShaderReflect,
32}
33
/// Fixed 20-byte prefix every binary shader blob must start with; `load_binary_shader`
/// rejects any input whose first 20 bytes differ.
const BINARY_SHADER_MAGIC: [u8; 20] = *b"est-binary-shader-v1";
35
/// Decodes a little-endian `u32` at the cursor, advancing it by four bytes on success.
///
/// Any short read is reported as `"Failed to read u32"`.
fn read_u32(cursor: &mut Cursor<&[u8]>) -> Result<u32, String> {
    let mut raw = [0u8; 4];
    match cursor.read_exact(&mut raw) {
        Ok(()) => Ok(u32::from_le_bytes(raw)),
        Err(_) => Err("Failed to read u32".to_string()),
    }
}
41
42fn read_u64(cursor: &mut Cursor<&[u8]>) -> Result<u64, String> {
43 cursor
44 .read_u64::<LittleEndian>()
45 .map_err(|_| "Failed to read u64".to_string())
46}
47
/// Reads exactly `len` bytes from the cursor into a new `Vec`.
///
/// `len` usually comes straight from the file being parsed, so it is checked against the
/// bytes actually remaining *before* allocating. Without that check a corrupt length
/// (up to `u32::MAX`) would allocate gigabytes, or abort, before the read could fail.
///
/// # Errors
///
/// Returns `"Failed to read bytes"` if fewer than `len` bytes remain.
fn read_bytes(cursor: &mut Cursor<&[u8]>, len: usize) -> Result<Vec<u8>, String> {
    let position = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
    let remaining = cursor.get_ref().len().saturating_sub(position);
    if len > remaining {
        return Err("Failed to read bytes".to_string());
    }

    let mut buf = vec![0; len];
    cursor
        .read_exact(&mut buf)
        .map_err(|_| "Failed to read bytes".to_string())?;
    Ok(buf)
}
55
56fn read_utf8_string(cursor: &mut Cursor<&[u8]>, len: usize) -> Result<String, String> {
57 let bytes = read_bytes(cursor, len)?;
58 String::from_utf8(bytes).map_err(|_| "Invalid UTF-8 string".to_string())
59}
60
61pub fn load_binary_shader(data: &[u8]) -> Result<BinaryShader, String> {
62 let mut cursor = Cursor::new(data);
63
64 let mut magic = [0; 20];
65 cursor
66 .read_exact(&mut magic)
67 .map_err(|_| "Failed to read magic".to_string())?;
68 if magic != BINARY_SHADER_MAGIC {
69 return Err("Invalid shader magic".to_string());
70 }
71
72 let shader_type_id = read_u32(&mut cursor)?;
73
74 let entry_point_sz = read_u32(&mut cursor)?;
75 let entry_point = read_utf8_string(&mut cursor, entry_point_sz as usize)?;
76
77 let binding_count = read_u32(&mut cursor)?;
78 let mut bindings = Vec::with_capacity(binding_count as usize);
79
80 for _ in 0..binding_count {
81 let group = read_u32(&mut cursor)?;
82 let binding = read_u32(&mut cursor)?;
83 let name_sz = read_u32(&mut cursor)?;
84 let name = read_utf8_string(&mut cursor, name_sz as usize)?;
85 let ty = match read_u32(&mut cursor)? {
86 0 => ShaderBindingType::UniformBuffer(read_u32(&mut cursor)?),
87 1 => {
88 let size = read_u32(&mut cursor)?;
89 let access = StorageAccess::from_bits(read_u32(&mut cursor)?)
90 .ok_or("Invalid storage access")?;
91 ShaderBindingType::StorageBuffer(size, access)
92 }
93 2 => {
94 let access = StorageAccess::from_bits(read_u32(&mut cursor)?)
95 .ok_or("Invalid storage texture access")?;
96 ShaderBindingType::StorageTexture(access)
97 }
98 3 => ShaderBindingType::Sampler(read_u32(&mut cursor)? != 0),
99 4 => ShaderBindingType::Texture(read_u32(&mut cursor)? != 0),
100 5 => ShaderBindingType::PushConstant(read_u32(&mut cursor)?),
101 t => return Err(format!("Unknown binding type ID: {}", t)),
102 };
103
104 bindings.push(ShaderBindingInfo {
105 binding,
106 group,
107 name,
108 ty,
109 });
110 }
111
112 let vertex_input = if shader_type_id == 0 || shader_type_id == 2 {
113 let name_sz = read_u32(&mut cursor)?;
114 let name = read_utf8_string(&mut cursor, name_sz as usize)?;
115 let stride = read_u32(&mut cursor)? as u64;
116 let attr_count = read_u32(&mut cursor)?;
117 let mut attributes = Vec::with_capacity(attr_count as usize);
118
119 for _ in 0..attr_count {
120 let location = read_u32(&mut cursor)?;
121 let offset = read_u64(&mut cursor)?;
122 let ty_id = read_u32(&mut cursor)?;
123 let ty = match ty_id {
124 0 => VertexInputType::Float32,
125 1 => VertexInputType::Float32x2,
126 2 => VertexInputType::Float32x3,
127 3 => VertexInputType::Float32x4,
128 4 => VertexInputType::Sint32,
129 5 => VertexInputType::Sint32x2,
130 6 => VertexInputType::Sint32x3,
131 7 => VertexInputType::Sint32x4,
132 8 => VertexInputType::Uint32,
133 9 => VertexInputType::Uint32x2,
134 10 => VertexInputType::Uint32x3,
135 11 => VertexInputType::Uint32x4,
136 _ => return Err(format!("Invalid vertex input type: {}", ty_id)),
137 };
138 attributes.push((location, offset, ty));
139 }
140
141 Some(VertexInputReflection {
142 name,
143 stride,
144 attributes,
145 })
146 } else {
147 None
148 };
149
150 let reflect = match shader_type_id {
151 0 => ShaderReflect::Vertex {
152 entry_point,
153 input: vertex_input,
154 bindings,
155 },
156 1 => ShaderReflect::Fragment {
157 entry_point,
158 bindings,
159 },
160 2 => {
161 let parts: Vec<&str> = entry_point.split(',').collect();
162 if parts.len() != 2 {
163 return Err("Invalid vertex/fragment entry point format".to_string());
164 }
165 ShaderReflect::VertexFragment {
166 vertex_entry_point: parts[0].to_string(),
167 vertex_input,
168 fragment_entry_point: parts[1].to_string(),
169 bindings,
170 }
171 }
172 3 => ShaderReflect::Compute {
173 entry_point,
174 bindings,
175 },
176 t => return Err(format!("Unknown shader type ID: {}", t)),
177 };
178
179 let spirv_sz = read_u32(&mut cursor)?;
180 let spirv = read_bytes(&mut cursor, spirv_sz as usize)?;
181
182 Ok(BinaryShader { spirv, reflect })
183}
184
185pub(crate) fn parse(module: Module) -> Result<ShaderReflect, String> {
186 let mut bindings = Vec::new();
187 for (handle, var) in module.global_variables.iter() {
188 if let Some(binding) = &var.binding {
189 match var.space {
190 AddressSpace::Uniform => {
191 let ty = &module.types[var.ty];
192 let size = get_size(&module, &ty.inner);
193 let var_name = var
194 .name
195 .clone()
196 .unwrap_or_else(|| format!("unnamed_{:?}", handle));
197
198 if size <= 16 {
199 #[cfg(any(debug_assertions, feature = "enable-release-validation"))]
201 return Err(format!(
202 "Uniform variable '{}' is too small ({} bytes), must be at least 16 bytes",
203 var_name, size
204 ));
205 }
206
207 let binding_info = ShaderBindingInfo {
208 binding: binding.binding as u32,
209 group: binding.group as u32,
210 name: var_name,
211 ty: ShaderBindingType::UniformBuffer(size as u32),
212 };
213
214 bindings.push(binding_info);
215 }
216
217 AddressSpace::PushConstant => {
218 let ty = &module.types[var.ty];
219 let size = get_size(&module, &ty.inner);
220 let var_name = var
221 .name
222 .clone()
223 .unwrap_or_else(|| format!("unnamed_{:?}", handle));
224
225 let binding_info = ShaderBindingInfo {
226 binding: binding.binding as u32,
227 group: binding.group as u32,
228 name: var_name,
229 ty: ShaderBindingType::PushConstant(size as u32),
230 };
231
232 bindings.push(binding_info);
233 }
234
235 AddressSpace::Storage { access: _access } => {
236 let ty = &module.types[var.ty];
237 let var_name = var
238 .name
239 .clone()
240 .unwrap_or_else(|| format!("unnamed_{:?}", handle));
241
242 let mut access = StorageAccess::empty();
243 if _access.contains(wgpu::naga::StorageAccess::LOAD) {
244 access |= StorageAccess::READ
245 }
246
247 if _access.contains(wgpu::naga::StorageAccess::STORE) {
248 access |= StorageAccess::WRITE;
249 }
250
251 if _access.contains(wgpu::naga::StorageAccess::ATOMIC) {
252 access |= StorageAccess::ATOMIC;
253 }
254
255 match &ty.inner {
256 TypeInner::Struct {
257 members: _,
258 span: _,
259 } => {
260 let size = get_size(&module, &ty.inner);
261
262 let binding_info = ShaderBindingInfo {
263 binding: binding.binding as u32,
264 group: binding.group as u32,
265 name: var_name,
266 ty: ShaderBindingType::StorageBuffer(size as u32, access),
267 };
268
269 bindings.push(binding_info);
270 }
271
272 TypeInner::Image {
273 dim: _,
274 arrayed: _,
275 class: _,
276 } => {
277 let binding_info = ShaderBindingInfo {
278 binding: binding.binding as u32,
279 group: binding.group as u32,
280 name: var_name,
281 ty: ShaderBindingType::StorageTexture(access),
282 };
283
284 bindings.push(binding_info);
285 }
286
287 TypeInner::Array {
288 base: _,
289 size,
290 stride: _,
291 } => {
292 let count = match size {
293 ArraySize::Constant(size) => size.get(),
294 _ => u32::MAX, };
296
297 let binding_info = ShaderBindingInfo {
298 binding: binding.binding as u32,
299 group: binding.group as u32,
300 name: var_name,
301 ty: ShaderBindingType::StorageBuffer(count, access),
302 };
303
304 bindings.push(binding_info);
305 }
306
307 _ => {}
308 }
309 }
310
311 AddressSpace::Handle => {
312 let ty = &module.types[var.ty];
315 let var_name = var
316 .name
317 .clone()
318 .unwrap_or_else(|| format!("unnamed_{:?}", handle));
319
320 match ty.inner {
321 TypeInner::Sampler { comparison } => {
322 let binding_info = ShaderBindingInfo {
323 binding: binding.binding as u32,
324 group: binding.group as u32,
325 name: var_name,
326 ty: ShaderBindingType::Sampler(comparison),
327 };
328
329 bindings.push(binding_info);
330 }
331
332 TypeInner::Image {
333 dim: _,
334 arrayed: _,
335 class,
336 } => {
337 let binding_info = ShaderBindingInfo {
338 binding: binding.binding as u32,
339 group: binding.group as u32,
340 name: var_name,
341 ty: ShaderBindingType::Texture(match class {
342 wgpu::naga::ImageClass::Sampled { kind: _, multi } => multi,
343 wgpu::naga::ImageClass::Depth { multi } => multi,
344 wgpu::naga::ImageClass::Storage {
345 format: _,
346 access: _,
347 } => {
348 return Err("Storage image should be handled separately"
350 .to_string());
351 }
352 }),
353 };
354
355 bindings.push(binding_info);
356 }
357
358 _ => {}
359 }
360 }
361
362 _ => {}
363 }
364 }
365 }
366
367 bindings.sort_by(|a, b| {
373 if a.group == b.group {
374 a.binding.cmp(&b.binding)
375 } else {
376 a.group.cmp(&b.group)
377 }
378 });
379
380 let mut vertex_entry_point = String::new();
382 let mut fragment_entry_point = String::new();
383 let mut compute_entry_point = String::new();
384
385 let mut vertex_struct_input = None;
386
387 #[allow(unused)]
388 for entry_point in module.entry_points.iter() {
389 match entry_point.stage {
390 ShaderStage::Vertex => {
391 vertex_entry_point = entry_point.name.clone();
392
393 for vertex_input in entry_point.function.arguments.iter() {
403 let ty = &module.types[vertex_input.ty];
404
405 let struct_name = ty
406 .name
407 .clone()
408 .unwrap_or_else(|| format!("unnamed_{:?}", vertex_input.ty));
409
410 let mut attributes = Vec::new();
411 let mut total_size = 0;
412
413 match &ty.inner {
414 TypeInner::Struct { members, span } => {
415 for member in members.iter() {
416 let attribute_name = member
417 .name
418 .clone()
419 .unwrap_or_else(|| format!("unnamed_{:?}", member.ty));
420
421 let ty = &module.types[member.ty];
422 let size = get_size(&module, &ty.inner);
423 let location = member
424 .binding
425 .as_ref()
426 .and_then(|b| match b {
427 Binding::Location {
428 location,
429 interpolation: _,
430 sampling: _,
431 blend_src: _,
432 } => Some(*location as u32),
433 _ => None,
434 })
435 .unwrap_or_else(|| {
436 panic!("Vertex input must have a location binding")
437 });
438
439 match &ty.inner {
440 TypeInner::Scalar(scalar) => {
441 if let Some(vertex_input_type) =
442 mapping_to_vertex_input(scalar, None)
443 {
444 attributes.push((
445 location,
446 total_size as u64,
447 vertex_input_type,
448 ));
449
450 total_size += scalar_size(scalar);
451 } else {
452 return Err(format!(
458 "Unsupported vertex input type: {:?} for member: {}",
459 ty.inner, attribute_name
460 ));
461 }
462 }
463
464 TypeInner::Vector { size, scalar } => {
465 if let Some(vertex_input_type) =
466 mapping_to_vertex_input(scalar, Some(size))
467 {
468 attributes.push((
469 location,
470 total_size as u64,
471 vertex_input_type,
472 ));
473
474 total_size +=
475 vectorsize_as_u32(size) * scalar_size(scalar);
476 } else {
477 return Err(format!(
483 "Unsupported vertex vector input type: {:?} for member: {}",
484 ty.inner, attribute_name
485 ));
486 }
487 }
488
489 _ => {
490 return Err(format!(
496 "Unsupported vertex input type: {:?} for member: {}",
497 ty.inner, attribute_name
498 ));
499 }
500 }
501 }
502 }
503 _ => {}
504 }
505
506 vertex_struct_input = Some(VertexInputReflection {
507 name: struct_name,
508 stride: total_size as u64,
509 attributes,
510 });
511 }
512 }
513 ShaderStage::Fragment => fragment_entry_point = entry_point.name.clone(),
514 ShaderStage::Compute => compute_entry_point = entry_point.name.clone(),
515 _ => {
516 return Err(format!("Unsupported shader stage: {:?}", entry_point.stage));
519 }
520 }
521 }
522
523 if !vertex_entry_point.is_empty() && !fragment_entry_point.is_empty() {
524 return Ok(ShaderReflect::VertexFragment {
525 vertex_entry_point,
526 vertex_input: vertex_struct_input,
527 fragment_entry_point,
528 bindings,
529 });
530 }
531
532 if !vertex_entry_point.is_empty() {
533 return Ok(ShaderReflect::Vertex {
534 entry_point: vertex_entry_point,
535 input: vertex_struct_input,
536 bindings,
537 });
538 }
539
540 if !fragment_entry_point.is_empty() {
541 return Ok(ShaderReflect::Fragment {
542 entry_point: fragment_entry_point,
543 bindings,
544 });
545 }
546
547 if !compute_entry_point.is_empty() {
548 return Ok(ShaderReflect::Compute {
549 entry_point: compute_entry_point,
550 bindings,
551 });
552 }
553
554 Err("No valid entry point found in shader module".to_string())
555}
556
557pub(crate) fn mapping_to_vertex_input(
558 scalar: &Scalar,
559 vector: Option<&VectorSize>,
560) -> Option<VertexInputType> {
561 match scalar.kind {
562 ScalarKind::Float => {
563 if let Some(vector_size) = vector {
564 match vector_size {
565 VectorSize::Bi => Some(VertexInputType::Float32x2),
566 VectorSize::Tri => Some(VertexInputType::Float32x3),
567 VectorSize::Quad => Some(VertexInputType::Float32x4),
568 }
569 } else {
570 Some(VertexInputType::Float32)
571 }
572 }
573 ScalarKind::Sint => {
574 if let Some(vector_size) = vector {
575 match vector_size {
576 VectorSize::Bi => Some(VertexInputType::Sint32x2),
577 VectorSize::Tri => Some(VertexInputType::Sint32x3),
578 VectorSize::Quad => Some(VertexInputType::Sint32x4),
579 }
580 } else {
581 Some(VertexInputType::Sint32)
582 }
583 }
584 ScalarKind::Uint => {
585 if let Some(vector_size) = vector {
586 match vector_size {
587 VectorSize::Bi => Some(VertexInputType::Uint32x2),
588 VectorSize::Tri => Some(VertexInputType::Uint32x3),
589 VectorSize::Quad => Some(VertexInputType::Uint32x4),
590 }
591 } else {
592 Some(VertexInputType::Uint32)
593 }
594 }
595 ScalarKind::Bool => {
596 if let Some(vector_size) = vector {
597 match vector_size {
598 VectorSize::Bi => Some(VertexInputType::Uint32),
599 VectorSize::Tri => Some(VertexInputType::Uint32x3),
600 VectorSize::Quad => Some(VertexInputType::Uint32x4),
601 }
602 } else {
603 Some(VertexInputType::Uint32)
604 }
605 }
606 _ => None,
607 }
608}
609
610#[allow(unused_variables)]
611pub(crate) fn get_size(module: &Module, ty_inner: &TypeInner) -> i32 {
612 match ty_inner {
613 TypeInner::Scalar(scalar) => scalar_size(scalar) as i32,
614
615 TypeInner::Vector { size, scalar } => {
616 let scalar_size = scalar_size(scalar);
617 let vec_size = vectorsize_as_u32(size) * scalar_size;
618 align_to(vec_size, vector_alignment(size)) as i32 }
620
621 TypeInner::Matrix {
622 columns,
623 rows,
624 scalar,
625 } => {
626 let scalar_size = scalar_size(scalar);
627 let row_size = vectorsize_as_u32(rows) * scalar_size;
628 let aligned_row_size = align_to(row_size, 16); (vectorsize_as_u32(columns) * aligned_row_size) as i32
630 }
631
632 TypeInner::Array { base, size, stride } => {
633 let count = match size {
634 ArraySize::Constant(size) => size.get(),
635 _ => u32::MAX, };
637
638 if count == u32::MAX {
639 -1 } else {
641 (count * stride) as i32
642 }
643 }
644
645 TypeInner::Struct { members, span } => {
646 let mut max_alignment = 0;
647 let mut size = 0;
648 for member in members {
649 let ty = &module.types[member.ty];
650
651 let member_size = get_size(module, &ty.inner);
652 let alignment = std140_alignment(module, &ty.inner);
653 size = align_to(size, alignment) + member_size as u32;
654 max_alignment = max_alignment.max(alignment);
655 }
656
657 align_to(size, max_alignment) as i32 }
659
660 _ => 0, }
662}
663
664pub(crate) fn scalar_size(scalar: &Scalar) -> u32 {
665 match scalar.kind {
666 ScalarKind::Float => 4,
667 ScalarKind::Sint => 4,
668 ScalarKind::Uint => 4,
669 ScalarKind::Bool => 4,
670 _ => 0,
671 }
672}
673
674pub(crate) fn vectorsize_as_u32(size: &VectorSize) -> u32 {
675 match size {
676 VectorSize::Bi => 2,
677 VectorSize::Tri => 3,
678 VectorSize::Quad => 4,
679 }
680}
681
682pub(crate) fn std140_alignment(module: &Module, ty_inner: &TypeInner) -> u32 {
683 match ty_inner {
684 TypeInner::Scalar(_) => 4,
685 TypeInner::Vector { size, .. } => vector_alignment(size),
686 TypeInner::Matrix { .. } => 16,
687 TypeInner::Struct { members, .. } => members
688 .iter()
689 .map(|m| {
690 let r#type = &module.types[m.ty];
691 std140_alignment(module, &r#type.inner)
692 })
693 .max()
694 .unwrap_or(1),
695 _ => 1,
696 }
697}
698
699pub(crate) fn vector_alignment(size: &VectorSize) -> u32 {
700 match size {
701 VectorSize::Bi => 8, VectorSize::Tri => 16, VectorSize::Quad => 16, }
705}
706
/// Rounds `size` up to the next multiple of `alignment`.
///
/// Works for any alignment, not only powers of two (the previous bitmask silently gave wrong
/// results otherwise). An alignment of `0` or `1` leaves `size` unchanged instead of
/// underflowing. Overflow near `u32::MAX` panics in debug builds, like plain arithmetic.
pub(crate) fn align_to(size: u32, alignment: u32) -> u32 {
    if alignment <= 1 {
        return size;
    }
    size.next_multiple_of(alignment)
}