1use core::panic;
2use std::{borrow::Cow, collections::HashMap, hash::Hash};
3
4use wgpu::{BindingType, SamplerBindingType, ShaderRuntimeChecks, ShaderStages, naga::front::wgsl};
5
6use crate::{
7 utils::ArcRef,
8};
9
10use super::{
11 types::{
12 BindGroupLayout, IndexBufferSize,
13 ShaderBindingType, ShaderCullMode,
14 ShaderFrontFace, ShaderPollygonMode,
15 ShaderReflect, ShaderTopology,
16 StorageAccess, VertexInputType,
17 VertexInputReflection,
18 },
19 super::GPUInner,
20};
21
/// Shader input collected by [`GraphicsShaderBuilder`] and consumed by `GraphicsShader::new`.
pub(crate) enum GraphicsShaderSource {
    /// Nothing has been set yet; building fails with "No shader source provided".
    None,
    /// WGSL text for one module that must contain both the vertex and fragment stages.
    Source(String),
    /// WGSL text as `(vertex, fragment)`; an unset half is an empty string.
    SplitSource(String, String),
    /// Precompiled bytes for one module, loaded through `reflection::load_binary_shader`.
    BinarySource(Vec<u8>),
    /// Precompiled bytes as `(vertex, fragment)`; an unset half is an empty vector.
    BinarySplitSource(Vec<u8>, Vec<u8>),
}
29
/// Builder that collects shader source (WGSL text or binary, single module or split
/// vertex/fragment) and compiles it with [`GraphicsShaderBuilder::build`].
pub struct GraphicsShaderBuilder {
    /// GPU context whose device compiles the shader.
    pub(crate) graphics: ArcRef<GPUInner>,
    /// The source selected so far; each setter replaces or updates it.
    pub(crate) source: GraphicsShaderSource,
}
38
39impl GraphicsShaderBuilder {
40 pub(crate) fn new(graphics: ArcRef<GPUInner>) -> Self {
41 Self {
42 graphics,
43 source: GraphicsShaderSource::None,
44 }
45 }
46
47 pub fn set_file(mut self, path: &str) -> Self {
49 let data = std::fs::read_to_string(path);
50 if let Err(err) = data {
51 panic!("Failed to read shader file: {:?}", err);
52 }
53
54 self.source = GraphicsShaderSource::Source(data.unwrap());
55
56 self
57 }
58
59 pub fn set_source(mut self, source: &str) -> Self {
61 self.source = GraphicsShaderSource::Source(source.to_string());
62 self
63 }
64
65 pub fn set_vertex_file(mut self, path: &str) -> Self {
69 let data = std::fs::read_to_string(path);
70 if let Err(err) = data {
71 panic!("Failed to read vertex shader file: {:?}", err);
72 }
73
74 match self.source {
75 GraphicsShaderSource::SplitSource(ref mut vertex_source, _) => {
76 self.source =
77 GraphicsShaderSource::SplitSource(data.unwrap(), vertex_source.clone());
78 }
79 _ => {
80 self.source = GraphicsShaderSource::SplitSource(data.unwrap(), "".to_string());
81 }
82 }
83
84 self
85 }
86
87 pub fn set_fragment_file(mut self, path: &str) -> Self {
91 let data = std::fs::read_to_string(path);
92 if let Err(err) = data {
93 panic!("Failed to read fragment shader file: {:?}", err);
94 }
95
96 match self.source {
97 GraphicsShaderSource::SplitSource(ref mut vertex_source, _) => {
98 self.source =
99 GraphicsShaderSource::SplitSource(vertex_source.clone(), data.unwrap());
100 }
101 _ => {
102 self.source = GraphicsShaderSource::SplitSource("".to_string(), data.unwrap());
103 }
104 }
105
106 self
107 }
108
109 pub fn set_vertex_code(mut self, source: &str) -> Self {
113 match self.source {
114 GraphicsShaderSource::SplitSource(_, ref mut fragment_source) => {
115 self.source =
116 GraphicsShaderSource::SplitSource(source.to_string(), fragment_source.clone());
117 }
118 _ => {
119 self.source = GraphicsShaderSource::SplitSource(source.to_string(), "".to_string());
120 }
121 }
122
123 self
124 }
125
126 pub fn set_fragment_code(mut self, source: &str) -> Self {
130 match self.source {
131 GraphicsShaderSource::SplitSource(ref mut vertex_source, _) => {
132 self.source =
133 GraphicsShaderSource::SplitSource(vertex_source.clone(), source.to_string());
134 }
135 _ => {
136 self.source = GraphicsShaderSource::SplitSource("".to_string(), source.to_string());
137 }
138 }
139
140 self
141 }
142
143 pub fn set_binary_source(mut self, binary: &[u8]) -> Self {
147 self.source = GraphicsShaderSource::BinarySource(binary.to_vec());
148 self
149 }
150
151 pub fn set_binary_file(mut self, path: &str) -> Self {
155 let data = std::fs::read(path);
156 if let Err(err) = data {
157 panic!("Failed to read binary shader file: {:?}", err);
158 }
159
160 self.source = GraphicsShaderSource::BinarySource(data.unwrap());
161 self
162 }
163
164 pub fn set_binary_vertex(mut self, binary: &[u8]) -> Self {
168 match self.source {
169 GraphicsShaderSource::BinarySplitSource(ref mut vertex_bin, _) => {
170 self.source =
171 GraphicsShaderSource::BinarySplitSource(binary.to_vec(), vertex_bin.clone());
172 }
173 _ => {
174 self.source = GraphicsShaderSource::BinarySplitSource(binary.to_vec(), vec![]);
175 }
176 }
177
178 self
179 }
180
181 pub fn set_binary_fragment(mut self, binary: &[u8]) -> Self {
185 match self.source {
186 GraphicsShaderSource::BinarySplitSource(_, ref mut fragment_bin) => {
187 self.source =
188 GraphicsShaderSource::BinarySplitSource(fragment_bin.clone(), binary.to_vec());
189 }
190 _ => {
191 self.source = GraphicsShaderSource::BinarySplitSource(vec![], binary.to_vec());
192 }
193 }
194
195 self
196 }
197
198 pub fn build(self) -> Result<GraphicsShader, String> {
199 GraphicsShader::new(self.graphics, self.source)
200 }
201}
202
/// Compiled wgpu shader module(s) backing a [`GraphicsShader`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GraphicsShaderType {
    /// One module providing both the vertex and the fragment stage.
    GraphicsSingle {
        module: wgpu::ShaderModule,
    },
    /// Separate modules for the vertex and the fragment stage.
    GraphicsSplit {
        vertex_module: wgpu::ShaderModule,
        fragment_module: wgpu::ShaderModule,
    },
}
213
214impl Hash for GraphicsShaderType {
215 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
216 match self {
217 GraphicsShaderType::GraphicsSingle { module } => {
218 module.hash(state);
219 }
220 GraphicsShaderType::GraphicsSplit {
221 vertex_module,
222 fragment_module,
223 } => {
224 vertex_module.hash(state);
225 fragment_module.hash(state);
226 }
227 }
228 }
229}
230
/// Compiled modules of a graphics shader plus the reflection data and bind group layouts
/// derived from them.
#[derive(Clone, Debug, Hash)]
pub(crate) struct GraphicsShaderInner {
    /// The compiled module(s).
    pub ty: GraphicsShaderType,
    /// Reflection per module, as built by `GraphicsShader::new`: a single `VertexFragment`
    /// entry, or a `Vertex` entry followed by a `Fragment` entry.
    pub reflection: Vec<ShaderReflect>,

    /// One layout per bind group referenced by the shader, ordered by group index.
    pub bind_group_layouts: Vec<BindGroupLayout>,
}
238
239impl PartialEq for GraphicsShaderInner {
240 fn eq(&self, other: &Self) -> bool {
241 let ty_equal = self.ty == other.ty;
242
243 let reflection_equal = self.reflection.len() == other.reflection.len()
244 && self
245 .reflection
246 .iter()
247 .zip(&other.reflection)
248 .all(|(a, b)| a == b);
249
250 let layouts_equal = self.bind_group_layouts.len() == other.bind_group_layouts.len()
251 && self
252 .bind_group_layouts
253 .iter()
254 .zip(&other.bind_group_layouts)
255 .all(|(a, b)| {
256 a.group == b.group && a.bindings == b.bindings && a.layout == b.layout
257 });
258
259 ty_equal && reflection_equal && layouts_equal
260 }
261}
262
/// Vertex-buffer layout and primitive state for a graphics shader.
///
/// Built from the vertex stage's reflection by `GraphicsShader::new` and adjusted through the
/// `GraphicsShader::set_*` methods. Presumably consumed when a render pipeline is created;
/// confirm against callers.
#[derive(Clone, Debug, Eq, Hash)]
pub(crate) struct VertexInputDescription {
    /// Index element size; defaults to `U16`. `None` presumably means non-indexed draws.
    pub index: Option<IndexBufferSize>,
    /// Primitive topology; defaults to `TriangleList`.
    pub topology: ShaderTopology,
    /// Face culling; defaults to `None`.
    pub cull_mode: Option<ShaderCullMode>,
    /// Polygon rasterization mode; defaults to `Fill`.
    pub polygon_mode: ShaderPollygonMode,
    /// Front-face winding; defaults to `Clockwise`.
    pub front_face: ShaderFrontFace,
    /// Byte stride between consecutive vertices, taken from the reflected vertex input.
    pub stride: wgpu::BufferAddress,
    /// One attribute per reflected vertex input location.
    pub attributes: Vec<wgpu::VertexAttribute>,
}
273
274impl PartialEq for VertexInputDescription {
275 fn eq(&self, other: &Self) -> bool {
276 self.index == other.index
277 && self.topology == other.topology
278 && self.cull_mode == other.cull_mode
279 && self.polygon_mode == other.polygon_mode
280 && self.front_face == other.front_face
281 && self.stride == other.stride
282 && self.attributes == other.attributes
283 }
284}
285
/// A compiled graphics (vertex + fragment) shader with its reflection data and its
/// vertex-input/primitive state.
///
/// Cloning clones the three `ArcRef` handles. Whether clones then share state depends on
/// `ArcRef`; presumably they do, so `set_*` changes would be visible through every clone.
/// Confirm before relying on it.
#[derive(Clone, Debug, Eq)]
#[allow(unused)]
pub struct GraphicsShader {
    /// GPU context the shader was compiled on.
    pub(crate) graphics: ArcRef<GPUInner>,
    /// Compiled modules, reflection and bind group layouts.
    pub(crate) inner: ArcRef<GraphicsShaderInner>,

    /// Vertex-buffer layout and primitive state, mutated by the `set_*` methods.
    pub(crate) attrib: ArcRef<VertexInputDescription>,
}
294
295impl GraphicsShader {
296 pub(crate) fn new(
297 graphics: ArcRef<GPUInner>,
298 wgls_data: GraphicsShaderSource,
299 ) -> Result<Self, String> {
300 let graphics_ref = graphics.borrow();
301 let device_ref = graphics_ref.device.as_ref().ok_or("Missing device")?;
302
303 fn create_vertex_input_attrib(input: &VertexInputReflection) -> Vec<wgpu::VertexAttribute> {
304 input
305 .attributes
306 .iter()
307 .map(|(location, offset, vtype)| wgpu::VertexAttribute {
308 format: vtype.clone().into(),
309 offset: *offset as wgpu::BufferAddress,
310 shader_location: *location,
311 })
312 .collect()
313 }
314
315 fn create_input_desc(reflection: &ShaderReflect) -> Result<VertexInputDescription, String> {
316 let (vertex_input, stride) = match reflection {
317 ShaderReflect::Vertex { input, .. }
318 | ShaderReflect::VertexFragment {
319 vertex_input: input,
320 ..
321 } => {
322 let input = input.as_ref().ok_or("Missing vertex input")?;
323 (input, input.stride as wgpu::BufferAddress)
324 }
325 _ => return Err("Invalid shader type for vertex input".to_string()),
326 };
327
328 let attributes = create_vertex_input_attrib(vertex_input);
329 Ok(VertexInputDescription {
330 index: Some(IndexBufferSize::U16),
331 stride,
332 attributes,
333 topology: ShaderTopology::TriangleList,
334 cull_mode: None,
335 polygon_mode: ShaderPollygonMode::Fill,
336 front_face: ShaderFrontFace::Clockwise,
337 })
338 }
339
340 fn build_single_shader(
341 device: &wgpu::Device,
342 source: &str,
343 ) -> Result<(wgpu::ShaderModule, ShaderReflect), String> {
344 let module = wgsl::parse_str(source).map_err(|e| format!("Parse error: {e:?}"))?;
345 let reflection = super::reflection::parse(module).map_err(|e| format!("Reflect error: {e:?}"))?;
346 Ok((
347 device.create_shader_module(wgpu::ShaderModuleDescriptor {
348 label: None,
349 source: wgpu::ShaderSource::Wgsl(source.into()),
350 }),
351 reflection,
352 ))
353 }
354
355 fn build_binary_shader(
356 device: &wgpu::Device,
357 binary: &[u8],
358 ) -> Result<(wgpu::ShaderModule, ShaderReflect), String> {
359 let binary_shader = super::reflection::load_binary_shader(binary)
360 .map_err(|e| format!("Binary load error: {e:?}"))?;
361 let spirv_u32 = Cow::Borrowed(bytemuck::cast_slice(&binary_shader.spirv));
362 Ok((
363 unsafe {
366 let desc = wgpu::ShaderModuleDescriptor {
367 label: None,
368 source: wgpu::ShaderSource::SpirV(spirv_u32),
369 };
370
371 let runtime_checks = ShaderRuntimeChecks {
372 bounds_checks: true,
373 force_loop_bounding: false,
374 };
375
376 device.create_shader_module_trusted(desc, runtime_checks)
377 },
378 binary_shader.reflect,
379 ))
380 }
381
382 match wgls_data {
383 GraphicsShaderSource::None => Err("No shader source provided".to_string()),
384
385 GraphicsShaderSource::Source(source) => {
386 let (module, reflection) = build_single_shader(device_ref, &source)?;
387 match reflection {
388 ShaderReflect::VertexFragment { .. } => {
389 let layout = Self::make_group_layout(device_ref, &[reflection.clone()]);
390 let input_desc = create_input_desc(&reflection)?;
391 Ok(Self {
392 graphics: ArcRef::clone(&graphics),
393 inner: ArcRef::new(GraphicsShaderInner {
394 ty: GraphicsShaderType::GraphicsSingle { module },
395 reflection: vec![reflection],
396 bind_group_layouts: layout,
397 }),
398 attrib: ArcRef::new(input_desc),
399 })
400 }
401 _ => Err("Shader source is not VertexFragment shader!".to_string()),
402 }
403 }
404
405 GraphicsShaderSource::SplitSource(vertex_src, fragment_src) => {
406 let (vertex_module, vertex_reflect) = build_single_shader(device_ref, &vertex_src)?;
407 let (fragment_module, fragment_reflect) =
408 build_single_shader(device_ref, &fragment_src)?;
409
410 match (&vertex_reflect, &fragment_reflect) {
411 (ShaderReflect::Vertex { .. }, ShaderReflect::Fragment { .. }) => {
412 let layout = Self::make_group_layout(
413 device_ref,
414 &[vertex_reflect.clone(), fragment_reflect.clone()],
415 );
416 let input_desc = create_input_desc(&vertex_reflect)?;
417 Ok(Self {
418 graphics: ArcRef::clone(&graphics),
419 inner: ArcRef::new(GraphicsShaderInner {
420 ty: GraphicsShaderType::GraphicsSplit {
421 vertex_module,
422 fragment_module,
423 },
424 reflection: vec![vertex_reflect, fragment_reflect],
425 bind_group_layouts: layout,
426 }),
427 attrib: ArcRef::new(input_desc),
428 })
429 }
430 _ => Err("Invalid shader pair for SplitSource".to_string()),
431 }
432 }
433
434 GraphicsShaderSource::BinarySource(binary) => {
435 let (module, reflection) = build_binary_shader(device_ref, &binary)?;
436 match reflection {
437 ShaderReflect::VertexFragment { .. } => {
438 let layout = Self::make_group_layout(device_ref, &[reflection.clone()]);
439 let input_desc = create_input_desc(&reflection)?;
440 Ok(Self {
441 graphics: ArcRef::clone(&graphics),
442 inner: ArcRef::new(GraphicsShaderInner {
443 ty: GraphicsShaderType::GraphicsSingle { module },
444 reflection: vec![reflection],
445 bind_group_layouts: layout,
446 }),
447 attrib: ArcRef::new(input_desc),
448 })
449 }
450 _ => Err("Binary shader is not VertexFragment shader!".to_string()),
451 }
452 }
453
454 GraphicsShaderSource::BinarySplitSource(vertex_bin, fragment_bin) => {
455 let (vertex_module, vertex_reflect) = build_binary_shader(device_ref, &vertex_bin)?;
456 let (fragment_module, fragment_reflect) =
457 build_binary_shader(device_ref, &fragment_bin)?;
458
459 match (&vertex_reflect, &fragment_reflect) {
460 (ShaderReflect::Vertex { .. }, ShaderReflect::Fragment { .. }) => {
461 let layout = Self::make_group_layout(
462 device_ref,
463 &[vertex_reflect.clone(), fragment_reflect.clone()],
464 );
465 let input_desc = create_input_desc(&vertex_reflect)?;
466 Ok(Self {
467 graphics: ArcRef::clone(&graphics),
468 inner: ArcRef::new(GraphicsShaderInner {
469 ty: GraphicsShaderType::GraphicsSplit {
470 vertex_module,
471 fragment_module,
472 },
473 reflection: vec![vertex_reflect, fragment_reflect],
474 bind_group_layouts: layout,
475 }),
476 attrib: ArcRef::new(input_desc),
477 })
478 }
479 _ => Err("Invalid binary shader pair for BinarySplitSource".to_string()),
480 }
481 }
482 }
483 }
484
485 fn make_group_layout(
486 device: &wgpu::Device,
487 reflects: &[ShaderReflect],
488 ) -> Vec<BindGroupLayout> {
489 let mut layouts: HashMap<u32, Vec<wgpu::BindGroupLayoutEntry>> = HashMap::new();
490
491 fn find_existing(
492 layouts: &mut HashMap<u32, Vec<wgpu::BindGroupLayoutEntry>>,
493 group: u32,
494 binding: u32,
495 _ty: wgpu::BindingType,
496 ) -> Option<&mut wgpu::BindGroupLayoutEntry> {
497 layouts.get_mut(&group).and_then(|entries| {
498 entries
499 .iter_mut()
500 .find(|entry| entry.binding == binding && matches!(entry.ty, _ty))
501 })
502 }
503
504 fn create_layout_ty(ty: ShaderBindingType) -> wgpu::BindingType {
505 match ty {
506 ShaderBindingType::UniformBuffer(size) => BindingType::Buffer {
507 ty: wgpu::BufferBindingType::Uniform,
508 has_dynamic_offset: false,
509 min_binding_size: if size == u32::MAX {
510 None
511 } else {
512 wgpu::BufferSize::new(size as u64)
513 },
514 },
515 ShaderBindingType::Texture(multisampled) => BindingType::Texture {
516 sample_type: wgpu::TextureSampleType::Float { filterable: true },
517 view_dimension: wgpu::TextureViewDimension::D2,
518 multisampled,
519 },
520 ShaderBindingType::Sampler(comparison) => BindingType::Sampler(if comparison {
521 SamplerBindingType::Comparison
522 } else {
523 SamplerBindingType::Filtering
524 }),
525 ShaderBindingType::StorageBuffer(size, access) => BindingType::Buffer {
526 ty: wgpu::BufferBindingType::Storage {
527 read_only: access.contains(StorageAccess::READ)
528 && !access.contains(StorageAccess::WRITE),
529 },
530 has_dynamic_offset: false,
531 min_binding_size: if size == u32::MAX {
532 None
533 } else {
534 wgpu::BufferSize::new(size as u64)
535 },
536 },
537 ShaderBindingType::StorageTexture(access) => BindingType::StorageTexture {
538 access: if access.contains(StorageAccess::READ)
539 && access.contains(StorageAccess::WRITE)
540 {
541 wgpu::StorageTextureAccess::ReadWrite
542 } else if access.contains(StorageAccess::READ) {
543 wgpu::StorageTextureAccess::ReadOnly
544 } else if access.contains(StorageAccess::WRITE) {
545 wgpu::StorageTextureAccess::WriteOnly
546 } else if access.contains(StorageAccess::ATOMIC) {
547 wgpu::StorageTextureAccess::Atomic
548 } else {
549 panic!("Invalid storage texture access")
550 },
551 format: wgpu::TextureFormat::Rgba8Unorm,
552 view_dimension: wgpu::TextureViewDimension::D2,
553 },
554 _ => unreachable!(),
555 }
556 }
557
558 for reflect in reflects {
559 match reflect {
560 ShaderReflect::Vertex { bindings, .. } => {
561 for binding in bindings.iter() {
562 let ty = create_layout_ty(binding.ty.clone());
563 let existing =
564 find_existing(&mut layouts, binding.group, binding.binding, ty);
565 if let Some(existing) = existing {
566 existing.visibility |= ShaderStages::VERTEX;
567 crate::dbg_log!(
568 "BindGroupLayout: group {}, binding: {}, ty: {:?} (existing)",
569 binding.group,
570 binding.binding,
571 binding.ty
572 );
573 } else {
574 let layout_desc = wgpu::BindGroupLayoutEntry {
576 ty,
577 binding: binding.binding,
578 visibility: ShaderStages::VERTEX,
579 count: None,
580 };
581
582 let group = layouts.entry(binding.group).or_insert_with(Vec::new);
583
584 crate::dbg_log!(
585 "BindGroupLayout: group {}, binding: {}, ty: {:?}",
586 binding.group,
587 binding.binding,
588 binding.ty
589 );
590 group.push(layout_desc);
591 }
592 }
593 }
594 ShaderReflect::Fragment { bindings, .. } => {
595 for binding in bindings.iter() {
596 let ty = create_layout_ty(binding.ty.clone());
597 let existing =
598 find_existing(&mut layouts, binding.group, binding.binding, ty);
599 if let Some(existing) = existing {
600 existing.visibility |= ShaderStages::FRAGMENT;
601 crate::dbg_log!(
602 "BindGroupLayout: group {}, binding: {}, ty: {:?} (existing)",
603 binding.group,
604 binding.binding,
605 binding.ty
606 );
607 } else {
608 let layout_desc = wgpu::BindGroupLayoutEntry {
610 ty,
611 binding: binding.binding,
612 visibility: ShaderStages::FRAGMENT,
613 count: None,
614 };
615
616 let group = layouts.entry(binding.group).or_insert_with(Vec::new);
617
618 crate::dbg_log!(
619 "BindGroupLayout: group {}, binding: {}, ty: {:?}",
620 binding.group,
621 binding.binding,
622 binding.ty
623 );
624 group.push(layout_desc);
625 }
626 }
627 }
628 ShaderReflect::VertexFragment { bindings, .. } => {
629 for binding in bindings.iter() {
630 let ty = create_layout_ty(binding.ty.clone());
631
632 let layout_desc = wgpu::BindGroupLayoutEntry {
634 ty,
635 binding: binding.binding,
636 visibility: ShaderStages::VERTEX_FRAGMENT,
637 count: None,
638 };
639
640 let group = layouts.entry(binding.group).or_insert_with(Vec::new);
641
642 crate::dbg_log!(
643 "BindGroupLayout: group {}, binding: {}, ty: {:?}",
644 binding.group,
645 binding.binding,
646 binding.ty
647 );
648 group.push(layout_desc);
649 }
650 }
651 _ => continue,
652 }
653 }
654
655 let mut layout_vec = layouts.into_iter().collect::<Vec<_>>();
656 layout_vec.sort_by_key(|(group, _)| *group);
657 layout_vec
658 .into_iter()
659 .map(|(group, layout)| {
660 let label = if !layout.is_empty() {
662 let mut s = format!("BindGroupLayout for group {}, binding: ", group);
663 for (i, entry) in layout.iter().enumerate() {
664 s.push_str(&entry.binding.to_string());
665 if i != layout.len() - 1 {
666 s.push_str(", ");
667 }
668 }
669 Some(s)
670 } else {
671 None
672 };
673
674 let bind_group_layout =
675 device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
676 label: label.as_deref(),
677 entries: &layout,
678 });
679
680 crate::dbg_log!(
681 "Created BindGroupLayout for group {} with {} entries",
682 group,
683 layout.len()
684 );
685
686 BindGroupLayout {
687 group,
688 bindings: layout.iter().map(|entry| entry.binding).collect(),
689 layout: bind_group_layout,
690 }
691 })
692 .collect()
693 }
694
695 pub fn get_uniform_location(&self, name: &str) -> Option<(u32, u32)> {
696 let inner = self.inner.borrow();
697
698 let reflection = &inner.reflection;
699 for reflect in reflection.iter() {
700 match reflect {
701 ShaderReflect::Vertex { bindings, .. } => {
702 if let Some(binding) = bindings.iter().find(|b| {
703 b.name == name && matches!(b.ty, ShaderBindingType::UniformBuffer(_))
704 }) {
705 return Some((binding.group, binding.binding));
706 }
707 }
708 ShaderReflect::Fragment { bindings, .. } => {
709 if let Some(binding) = bindings.iter().find(|b| {
710 b.name == name && matches!(b.ty, ShaderBindingType::UniformBuffer(_))
711 }) {
712 return Some((binding.group, binding.binding));
713 }
714 }
715 ShaderReflect::VertexFragment { bindings, .. } => {
716 if let Some(binding) = bindings.iter().find(|b| {
717 b.name == name && matches!(b.ty, ShaderBindingType::UniformBuffer(_))
718 }) {
719 return Some((binding.group, binding.binding));
720 }
721 }
722 _ => continue,
723 }
724 }
725
726 None
727 }
728
729 pub fn get_uniform_size(&self, group: u32, binding: u32) -> Option<u32> {
730 let inner = self.inner.borrow();
731
732 let reflection = &inner.reflection;
733 for reflect in reflection.iter() {
734 match reflect {
735 ShaderReflect::Vertex { bindings, .. } => {
736 if let Some(binding) = bindings
737 .iter()
738 .find(|b| b.group == group && b.binding == binding)
739 {
740 if let ShaderBindingType::UniformBuffer(size) = binding.ty {
741 return Some(size);
742 }
743 }
744 }
745 ShaderReflect::Fragment { bindings, .. } => {
746 if let Some(binding) = bindings
747 .iter()
748 .find(|b| b.group == group && b.binding == binding)
749 {
750 if let ShaderBindingType::UniformBuffer(size) = binding.ty {
751 return Some(size);
752 }
753 }
754 }
755 ShaderReflect::VertexFragment { bindings, .. } => {
756 if let Some(binding) = bindings
757 .iter()
758 .find(|b| b.group == group && b.binding == binding)
759 {
760 if let ShaderBindingType::UniformBuffer(size) = binding.ty {
761 return Some(size);
762 }
763 }
764 }
765 _ => continue,
766 }
767 }
768
769 None
770 }
771
772 pub fn set_topology(&mut self, topology: ShaderTopology) -> Result<(), String> {
773 self.attrib.borrow_mut().topology = topology;
774 Ok(())
775 }
776
777 pub fn set_cull_mode(&mut self, cull_mode: Option<ShaderCullMode>) -> Result<(), String> {
778 self.attrib.borrow_mut().cull_mode = cull_mode;
779 Ok(())
780 }
781
782 pub fn set_polygon_mode(&mut self, polygon_mode: ShaderPollygonMode) -> Result<(), String> {
783 self.attrib.borrow_mut().polygon_mode = polygon_mode;
784 Ok(())
785 }
786
787 pub fn set_front_face(&mut self, front_face: ShaderFrontFace) -> Result<(), String> {
788 self.attrib.borrow_mut().front_face = front_face;
789 Ok(())
790 }
791
792 pub fn set_vertex_index_ty(&mut self, index_ty: Option<IndexBufferSize>) -> Result<(), String> {
793 self.attrib.borrow_mut().index = index_ty;
794 Ok(())
795 }
796
797 pub fn set_vertex_input(
798 &mut self,
799 location: u32,
800 vtype: VertexInputType,
801 ) -> Result<(), String> {
802 let inner = self.inner.borrow_mut();
803
804 let vertex_input = match inner.reflection.first() {
805 Some(ShaderReflect::Vertex { input, .. }) => input.as_ref(),
806 Some(ShaderReflect::VertexFragment { vertex_input, .. }) => vertex_input.as_ref(),
807 _ => None,
808 };
809
810 if vertex_input.is_none() {
811 return Err("Shader does not have vertex input".to_string());
812 }
813
814 let vertex_input = vertex_input.unwrap();
815
816 let input = vertex_input
817 .attributes
818 .iter()
819 .find(|attr| attr.0 == location);
820 if input.is_none() {
821 return Err(format!("Vertex input location {} not found", location));
822 }
823
824 let (location, _offset, og_vtype) = input.unwrap();
825 if !is_format_conversion_supported(*og_vtype, vtype) {
826 return Err(format!(
827 "Vertex input type {:?} is not supported for location {}",
828 vtype, location
829 ));
830 }
831
832 let mut attrib = self.attrib.borrow_mut();
833 let vertex_input_attrib = attrib
834 .attributes
835 .iter_mut()
836 .find(|attr| attr.shader_location == *location);
837
838 if vertex_input_attrib.is_none() {
839 return Err(format!(
840 "Vertex input location {} not found in shader attributes",
841 location
842 ));
843 }
844
845 let vertex_input_attrib = vertex_input_attrib.unwrap();
846 vertex_input_attrib.format = vtype.into();
847
848 Ok(())
849 }
850}
851
852impl std::hash::Hash for GraphicsShader {
853 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
854 ArcRef::as_ptr(&self.graphics).hash(state);
855 self.inner.hash(state);
856 self.attrib.hash(state);
857 }
858}
859
860#[inline]
863fn is_format_conversion_supported(origin: VertexInputType, target: VertexInputType) -> bool {
864 match origin {
865 VertexInputType::Float32 => match target {
866 VertexInputType::Float32 => true,
867 VertexInputType::Snorm8 => true,
868 VertexInputType::Unorm8 => true,
869 VertexInputType::Snorm16 => true,
870 _ => false,
871 },
872 VertexInputType::Float32x2 => match target {
873 VertexInputType::Float32x2 => true,
874 VertexInputType::Snorm8x2 => true,
875 VertexInputType::Unorm8x2 => true,
876 VertexInputType::Snorm16x2 => true,
877 _ => false,
878 },
879 VertexInputType::Float32x3 => {
880 match target {
881 VertexInputType::Float32x3 => true,
882 _ => false,
884 }
885 }
886 VertexInputType::Float32x4 => match target {
887 VertexInputType::Float32x4 => true,
888 VertexInputType::Snorm8x4 => true,
889 VertexInputType::Unorm8x4 => true,
890 VertexInputType::Snorm16x4 => true,
891 _ => false,
892 },
893 VertexInputType::Uint32 => match target {
894 VertexInputType::Uint32 => true,
895 VertexInputType::Uint16 => true,
896 VertexInputType::Uint8 => true,
897 _ => false,
898 },
899 VertexInputType::Uint32x2 => match target {
900 VertexInputType::Uint32x2 => true,
901 VertexInputType::Uint16x2 => true,
902 VertexInputType::Uint8x2 => true,
903 _ => false,
904 },
905 VertexInputType::Uint32x3 => match target {
906 VertexInputType::Uint32x3 => true,
907 VertexInputType::Uint16x4 => true,
908 VertexInputType::Uint8x4 => true,
909 _ => false,
910 },
911 VertexInputType::Uint32x4 => match target {
912 VertexInputType::Uint32x4 => true,
913 VertexInputType::Uint16x4 => true,
914 VertexInputType::Uint8x4 => true,
915 _ => false,
916 },
917 _ => origin == target,
918 }
919}
920
921impl PartialEq for GraphicsShader {
941 fn eq(&self, other: &Self) -> bool {
942 ArcRef::ptr_eq(&self.graphics, &other.graphics)
943 && ArcRef::ptr_eq(&self.inner, &other.inner)
944 && ArcRef::ptr_eq(&self.attrib, &other.attrib)
945 }
946}