est_render/gpu/shader/
compute.rs1use std::collections::HashMap;
2
3use wgpu::{BindingType, SamplerBindingType, naga::front::wgsl};
4
5use crate::utils::ArcRef;
6use super::{
7 super::GPUInner,
8 types::{
9 ShaderReflect, BindGroupLayout,
10 ShaderBindingType, StorageAccess,
11 }
12};
13
14pub struct ComputeShaderBuilder {
15 pub(crate) graphics: ArcRef<GPUInner>,
16 pub(crate) wgls_data: String,
17}
18
19impl ComputeShaderBuilder {
20 pub(crate) fn new(graphics: ArcRef<GPUInner>) -> Self {
21 Self {
22 graphics,
23 wgls_data: String::new(),
24 }
25 }
26
27 pub fn set_file(mut self, path: &str) -> Self {
28 let data = std::fs::read_to_string(path);
29 if let Err(err) = data {
30 panic!("Failed to read shader file: {:?}", err);
31 }
32
33 self.wgls_data = data.unwrap();
34 self
35 }
36
37 pub fn set_source(mut self, source: &str) -> Self {
38 self.wgls_data = source.to_string();
39 self
40 }
41
42 pub fn build(self) -> Result<ComputeShader, String> {
43 ComputeShader::new(self.graphics, &self.wgls_data)
44 }
45}
46
47pub(crate) struct ComputeShaderInner {
48 pub shader: wgpu::ShaderModule,
49 pub reflection: ShaderReflect,
50
51 pub bind_group_layouts: Vec<BindGroupLayout>,
52}
53
54#[allow(unused)]
55#[derive(Clone, Debug)]
56pub struct ComputeShader {
57 pub(crate) graphics: ArcRef<GPUInner>,
58 pub(crate) inner: ArcRef<ComputeShaderInner>,
59}
60
61impl ComputeShader {
62 pub(crate) fn new(graphics: ArcRef<GPUInner>, wgls_data: &str) -> Result<Self, String> {
63 if graphics.borrow().is_invalid {
64 panic!("Graphics context is invalid");
65 }
66
67 let module = wgsl::parse_str(wgls_data);
68 if let Err(err) = module {
69 return Err(format!("Failed to parse shader: {:?}", err));
70 }
71
72 let module = module.unwrap();
73 let reflect = super::reflection::parse(module);
74
75 if reflect.is_err() {
76 return Err(format!("Failed to reflect shader: {:?}", reflect.err()));
77 }
78
79 let reflect = reflect.unwrap();
80
81 let graphics_ref = graphics.borrow();
82 let device_ref = graphics_ref.device();
83
84 let shader = device_ref.create_shader_module(wgpu::ShaderModuleDescriptor {
85 label: None,
86 source: wgpu::ShaderSource::Wgsl(wgls_data.into()),
87 });
88
89 let bind_group_layouts = Self::make_group_layout(device_ref, &[reflect.clone()]);
90
91 let inner = ComputeShaderInner {
92 shader,
93 reflection: reflect,
94 bind_group_layouts,
95 };
96
97 Ok(Self {
98 graphics: ArcRef::clone(&graphics),
99 inner: ArcRef::new(inner),
100 })
101 }
102
103 fn create_layout_ty(ty: ShaderBindingType) -> wgpu::BindingType {
104 match ty {
105 ShaderBindingType::UniformBuffer(size) => BindingType::Buffer {
106 ty: wgpu::BufferBindingType::Uniform,
107 has_dynamic_offset: false,
108 min_binding_size: if size == u32::MAX {
109 None
110 } else {
111 wgpu::BufferSize::new(size as u64)
112 },
113 },
114 ShaderBindingType::Texture(multisampled) => BindingType::Texture {
115 sample_type: wgpu::TextureSampleType::Float { filterable: true },
116 view_dimension: wgpu::TextureViewDimension::D2,
117 multisampled,
118 },
119 ShaderBindingType::Sampler(comparison) => BindingType::Sampler(if comparison {
120 SamplerBindingType::Comparison
121 } else {
122 SamplerBindingType::Filtering
123 }),
124 ShaderBindingType::StorageBuffer(size, access) => BindingType::Buffer {
125 ty: wgpu::BufferBindingType::Storage {
126 read_only: access.contains(StorageAccess::READ)
127 && !access.contains(StorageAccess::WRITE),
128 },
129 has_dynamic_offset: false,
130 min_binding_size: if size == u32::MAX {
131 None
132 } else {
133 wgpu::BufferSize::new(size as u64)
134 },
135 },
136 ShaderBindingType::StorageTexture(access) => BindingType::StorageTexture {
137 access: if access.contains(StorageAccess::READ)
138 && access.contains(StorageAccess::WRITE)
139 {
140 wgpu::StorageTextureAccess::ReadWrite
141 } else if access.contains(StorageAccess::READ) {
142 wgpu::StorageTextureAccess::ReadOnly
143 } else if access.contains(StorageAccess::WRITE) {
144 wgpu::StorageTextureAccess::WriteOnly
145 } else if access.contains(StorageAccess::ATOMIC) {
146 wgpu::StorageTextureAccess::Atomic
147 } else {
148 panic!("Invalid storage texture access")
149 },
150 format: wgpu::TextureFormat::Rgba8Unorm,
151 view_dimension: wgpu::TextureViewDimension::D2,
152 },
153 _ => unreachable!(),
154 }
155 }
156
157 fn make_group_layout(
158 device: &wgpu::Device,
159 reflects: &[ShaderReflect],
160 ) -> Vec<BindGroupLayout> {
161 let mut layouts: HashMap<u32, Vec<wgpu::BindGroupLayoutEntry>> = HashMap::new();
162
163 for reflect in reflects {
164 match reflect {
165 ShaderReflect::Compute { bindings, .. } => {
166 for binding in bindings.iter() {
167 let ty = Self::create_layout_ty(binding.ty.clone());
168
169 let layout_desc = wgpu::BindGroupLayoutEntry {
171 ty,
172 binding: binding.binding,
173 visibility: wgpu::ShaderStages::COMPUTE,
174 count: None,
175 };
176
177 let group = layouts.entry(binding.group).or_insert_with(Vec::new);
178
179 group.push(layout_desc);
180 }
181 }
182 _ => continue,
183 }
184 }
185
186 layouts
187 .into_iter()
188 .map(|(group, layout)| {
189 let label = if !layout.is_empty() {
191 let mut s = format!("BindGroupLayout for group {}, binding: ", group);
192 for (i, entry) in layout.iter().enumerate() {
193 s.push_str(&entry.binding.to_string());
194 if i != layout.len() - 1 {
195 s.push_str(", ");
196 }
197 }
198 Some(s)
199 } else {
200 None
201 };
202
203 let bind_group_layout =
204 device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
205 label: label.as_deref(),
206 entries: &layout,
207 });
208
209 BindGroupLayout {
210 group,
211 bindings: layout.iter().map(|entry| entry.binding).collect(),
212 layout: bind_group_layout,
213 }
214 })
215 .collect()
216 }
217
218 pub fn get_uniform_location(&self, name: &str) -> Option<(u32, u32)> {
219 let reflection = self.inner.borrow().reflection.clone();
220 match reflection {
221 ShaderReflect::Compute { bindings, .. } => bindings.iter().find_map(|binding| {
222 if binding.name == name && matches!(binding.ty, ShaderBindingType::UniformBuffer(_))
223 {
224 Some((binding.group, binding.binding))
225 } else {
226 None
227 }
228 }),
229 _ => None,
230 }
231 }
232}