est_render/gpu/shader/
compute.rs

1use std::collections::HashMap;
2
3use wgpu::{BindingType, SamplerBindingType, naga::front::wgsl};
4
5use crate::utils::ArcRef;
6use super::{
7    super::GPUInner,
8    types::{
9        ShaderReflect, BindGroupLayout,
10        ShaderBindingType, StorageAccess,
11    }
12};
13
14pub struct ComputeShaderBuilder {
15    pub(crate) graphics: ArcRef<GPUInner>,
16    pub(crate) wgls_data: String,
17}
18
19impl ComputeShaderBuilder {
20    pub(crate) fn new(graphics: ArcRef<GPUInner>) -> Self {
21        Self {
22            graphics,
23            wgls_data: String::new(),
24        }
25    }
26
27    pub fn set_file(mut self, path: &str) -> Self {
28        let data = std::fs::read_to_string(path);
29        if let Err(err) = data {
30            panic!("Failed to read shader file: {:?}", err);
31        }
32
33        self.wgls_data = data.unwrap();
34        self
35    }
36
37    pub fn set_source(mut self, source: &str) -> Self {
38        self.wgls_data = source.to_string();
39        self
40    }
41
42    pub fn build(self) -> Result<ComputeShader, String> {
43        ComputeShader::new(self.graphics, &self.wgls_data)
44    }
45}
46
47pub(crate) struct ComputeShaderInner {
48    pub shader: wgpu::ShaderModule,
49    pub reflection: ShaderReflect,
50
51    pub bind_group_layouts: Vec<BindGroupLayout>,
52}
53
54#[allow(unused)]
55#[derive(Clone, Debug)]
56pub struct ComputeShader {
57    pub(crate) graphics: ArcRef<GPUInner>,
58    pub(crate) inner: ArcRef<ComputeShaderInner>,
59}
60
61impl ComputeShader {
62    pub(crate) fn new(graphics: ArcRef<GPUInner>, wgls_data: &str) -> Result<Self, String> {
63        if graphics.borrow().is_invalid {
64            panic!("Graphics context is invalid");
65        }
66
67        let module = wgsl::parse_str(wgls_data);
68        if let Err(err) = module {
69            return Err(format!("Failed to parse shader: {:?}", err));
70        }
71
72        let module = module.unwrap();
73        let reflect = super::reflection::parse(module);
74
75        if reflect.is_err() {
76            return Err(format!("Failed to reflect shader: {:?}", reflect.err()));
77        }
78
79        let reflect = reflect.unwrap();
80
81        let graphics_ref = graphics.borrow();
82        let device_ref = graphics_ref.device();
83
84        let shader = device_ref.create_shader_module(wgpu::ShaderModuleDescriptor {
85            label: None,
86            source: wgpu::ShaderSource::Wgsl(wgls_data.into()),
87        });
88
89        let bind_group_layouts = Self::make_group_layout(device_ref, &[reflect.clone()]);
90
91        let inner = ComputeShaderInner {
92            shader,
93            reflection: reflect,
94            bind_group_layouts,
95        };
96
97        Ok(Self {
98            graphics: ArcRef::clone(&graphics),
99            inner: ArcRef::new(inner),
100        })
101    }
102
103    fn create_layout_ty(ty: ShaderBindingType) -> wgpu::BindingType {
104        match ty {
105            ShaderBindingType::UniformBuffer(size) => BindingType::Buffer {
106                ty: wgpu::BufferBindingType::Uniform,
107                has_dynamic_offset: false,
108                min_binding_size: if size == u32::MAX {
109                    None
110                } else {
111                    wgpu::BufferSize::new(size as u64)
112                },
113            },
114            ShaderBindingType::Texture(multisampled) => BindingType::Texture {
115                sample_type: wgpu::TextureSampleType::Float { filterable: true },
116                view_dimension: wgpu::TextureViewDimension::D2,
117                multisampled,
118            },
119            ShaderBindingType::Sampler(comparison) => BindingType::Sampler(if comparison {
120                SamplerBindingType::Comparison
121            } else {
122                SamplerBindingType::Filtering
123            }),
124            ShaderBindingType::StorageBuffer(size, access) => BindingType::Buffer {
125                ty: wgpu::BufferBindingType::Storage {
126                    read_only: access.contains(StorageAccess::READ)
127                        && !access.contains(StorageAccess::WRITE),
128                },
129                has_dynamic_offset: false,
130                min_binding_size: if size == u32::MAX {
131                    None
132                } else {
133                    wgpu::BufferSize::new(size as u64)
134                },
135            },
136            ShaderBindingType::StorageTexture(access) => BindingType::StorageTexture {
137                access: if access.contains(StorageAccess::READ)
138                    && access.contains(StorageAccess::WRITE)
139                {
140                    wgpu::StorageTextureAccess::ReadWrite
141                } else if access.contains(StorageAccess::READ) {
142                    wgpu::StorageTextureAccess::ReadOnly
143                } else if access.contains(StorageAccess::WRITE) {
144                    wgpu::StorageTextureAccess::WriteOnly
145                } else if access.contains(StorageAccess::ATOMIC) {
146                    wgpu::StorageTextureAccess::Atomic
147                } else {
148                    panic!("Invalid storage texture access")
149                },
150                format: wgpu::TextureFormat::Rgba8Unorm,
151                view_dimension: wgpu::TextureViewDimension::D2,
152            },
153            _ => unreachable!(),
154        }
155    }
156
157    fn make_group_layout(
158        device: &wgpu::Device,
159        reflects: &[ShaderReflect],
160    ) -> Vec<BindGroupLayout> {
161        let mut layouts: HashMap<u32, Vec<wgpu::BindGroupLayoutEntry>> = HashMap::new();
162
163        for reflect in reflects {
164            match reflect {
165                ShaderReflect::Compute { bindings, .. } => {
166                    for binding in bindings.iter() {
167                        let ty = Self::create_layout_ty(binding.ty.clone());
168
169                        // Push new layout entry
170                        let layout_desc = wgpu::BindGroupLayoutEntry {
171                            ty,
172                            binding: binding.binding,
173                            visibility: wgpu::ShaderStages::COMPUTE,
174                            count: None,
175                        };
176
177                        let group = layouts.entry(binding.group).or_insert_with(Vec::new);
178
179                        group.push(layout_desc);
180                    }
181                }
182                _ => continue,
183            }
184        }
185
186        layouts
187            .into_iter()
188            .map(|(group, layout)| {
189                // Label: "BindGroupLayout for group {group}, binding: {binding} (ex: 0, 1, 2)"
190                let label = if !layout.is_empty() {
191                    let mut s = format!("BindGroupLayout for group {}, binding: ", group);
192                    for (i, entry) in layout.iter().enumerate() {
193                        s.push_str(&entry.binding.to_string());
194                        if i != layout.len() - 1 {
195                            s.push_str(", ");
196                        }
197                    }
198                    Some(s)
199                } else {
200                    None
201                };
202
203                let bind_group_layout =
204                    device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
205                        label: label.as_deref(),
206                        entries: &layout,
207                    });
208
209                BindGroupLayout {
210                    group,
211                    bindings: layout.iter().map(|entry| entry.binding).collect(),
212                    layout: bind_group_layout,
213                }
214            })
215            .collect()
216    }
217
218    pub fn get_uniform_location(&self, name: &str) -> Option<(u32, u32)> {
219        let reflection = self.inner.borrow().reflection.clone();
220        match reflection {
221            ShaderReflect::Compute { bindings, .. } => bindings.iter().find_map(|binding| {
222                if binding.name == name && matches!(binding.ty, ShaderBindingType::UniformBuffer(_))
223                {
224                    Some((binding.group, binding.binding))
225                } else {
226                    None
227                }
228            }),
229            _ => None,
230        }
231    }
232}