1use std::{
2 collections::HashMap,
3 hash::{DefaultHasher, Hash, Hasher},
4};
5
6use crate::utils::ArcRef;
7
8use super::{
9 manager::ComputePipelineDesc,
10 super::{
11 GPUInner,
12 texture::{Texture, TextureSampler},
13 buffer::Buffer,
14 command::{
15 BindGroupAttachment,
16 utils::BindGroupType,
17 computepass::IntermediateComputeBinding
18 },
19 shader::{
20 bind_group_manager::BindGroupCreateInfo,
21 types::{ShaderReflect, ShaderBindingType},
22 compute::ComputeShader,
23 },
24 },
25};
26
27#[derive(Debug, Clone, Hash)]
28pub struct ComputePipeline {
29 pub(crate) bind_group: Vec<(u32, wgpu::BindGroup)>,
30 pub(crate) pipeline_desc: ComputePipelineDesc,
31}
32
33#[derive(Debug, Clone)]
34pub struct ComputePipelineBuilder {
35 pub(crate) gpu: ArcRef<GPUInner>,
36 pub(crate) attachments: Vec<BindGroupAttachment>,
37 pub(crate) shader: Option<IntermediateComputeBinding>,
38 pub(crate) shader_reflection: Option<ShaderReflect>,
39}
40
41impl ComputePipelineBuilder {
42 pub(crate) fn new(gpu: ArcRef<GPUInner>) -> Self {
43 Self {
44 gpu,
45 attachments: Vec::new(),
46 shader: None,
47 shader_reflection: None,
48 }
49 }
50
51 #[inline]
52 pub fn set_shader(mut self, shader: Option<&ComputeShader>) -> Self {
53 match shader {
54 Some(shader) => {
55 let shader_inner = shader.inner.borrow();
56 let shader_module = shader_inner.shader.clone();
57 let layout = shader_inner.bind_group_layouts.clone();
58
59 let shader_reflect = shader_inner.reflection.clone();
60 let entry_point = match &shader_reflect {
61 ShaderReflect::Compute { entry_point, .. } => entry_point.clone(),
62 _ => panic!("Shader must be a compute shader"),
63 };
64
65 let shader_binding = IntermediateComputeBinding {
66 shader: shader_module,
67 entry_point,
68 layout,
69 };
70
71 self.shader = Some(shader_binding);
72 self.shader_reflection = Some(shader_reflect);
73 }
74 None => {
75 self.shader = None;
76 self.shader_reflection = None;
77 }
78 }
79
80 self
81 }
82
83 #[inline]
84 pub fn set_attachment_sampler(
85 mut self,
86 group: u32,
87 binding: u32,
88 sampler: Option<&TextureSampler>,
89 ) -> Self {
90 match sampler {
91 Some(sampler) => {
92 let attachment = {
93 let gpu_inner = self.gpu.borrow();
94
95 BindGroupAttachment {
96 group,
97 binding,
98 attachment: BindGroupType::Sampler(
99 sampler.make_wgpu(gpu_inner.device()),
100 ),
101 }
102 };
103
104 self.insert_or_replace_attachment(group, binding, attachment);
105 }
106 None => {
107 self.remove_attachment(group, binding);
108 }
109 }
110
111 self
112 }
113
114 #[inline]
115 pub fn set_attachment_texture(
116 mut self,
117 group: u32,
118 binding: u32,
119 texture: Option<&Texture>,
120 ) -> Self {
121 match texture {
122 Some(texture) => {
123 let attachment = {
124 BindGroupAttachment {
125 group,
126 binding,
127 attachment: BindGroupType::Texture(
128 texture.inner.borrow().wgpu_view.clone(),
129 ),
130 }
131 };
132
133 self.insert_or_replace_attachment(group, binding, attachment);
134 }
135 None => {
136 self.remove_attachment(group, binding);
137 }
138 }
139
140 self
141 }
142
143 #[inline]
144 pub fn set_attachment_texture_storage(
145 mut self,
146 group: u32,
147 binding: u32,
148 texture: Option<&Texture>,
149 ) -> Self {
150 match texture {
151 Some(texture) => {
152 let inner = texture.inner.borrow();
153 let attachment = BindGroupAttachment {
154 group,
155 binding,
156 attachment: BindGroupType::TextureStorage(inner.wgpu_view.clone()),
157 };
158
159 self.insert_or_replace_attachment(group, binding, attachment);
160 }
161 None => {
162 self.remove_attachment(group, binding);
163 }
164 }
165
166 self
167 }
168
169 #[inline]
170 pub fn set_attachment_uniform(
171 mut self,
172 group: u32,
173 binding: u32,
174 buffer: Option<&Buffer>,
175 ) -> Self {
176 match buffer {
177 Some(buffer) => {
178 let inner = buffer.inner.borrow();
179 let attachment = BindGroupAttachment {
180 group,
181 binding,
182 attachment: BindGroupType::Uniform(inner.buffer.clone()),
183 };
184
185 self.insert_or_replace_attachment(group, binding, attachment);
186 }
187 None => {
188 self.remove_attachment(group, binding);
189 }
190 }
191
192 self
193 }
194
195 #[inline]
196 pub fn set_attachment_uniform_vec<T>(
197 mut self,
198 group: u32,
199 binding: u32,
200 buffer: Option<Vec<T>>,
201 ) -> Self
202 where
203 T: bytemuck::Pod + bytemuck::Zeroable,
204 {
205 match buffer {
206 Some(buffer) => {
207 let attachment = {
208 let mut inner = self.gpu.borrow_mut();
209
210 let buffer = inner.create_buffer_with(&buffer, wgpu::BufferUsages::COPY_DST);
211 BindGroupAttachment {
212 group,
213 binding,
214 attachment: BindGroupType::Uniform(buffer),
215 }
216 };
217
218 self.insert_or_replace_attachment(group, binding, attachment);
219 }
220 None => {
221 self.remove_attachment(group, binding);
222 }
223 }
224
225 self
226 }
227
228 #[inline]
229 pub fn set_attachment_uniform_raw<T>(
230 mut self,
231 group: u32,
232 binding: u32,
233 buffer: Option<&[T]>,
234 ) -> Self
235 where
236 T: bytemuck::Pod + bytemuck::Zeroable,
237 {
238 match buffer {
239 Some(buffer) => {
240 let mut inner = self.gpu.borrow_mut();
241
242 let buffer = inner.create_buffer_with(&buffer, wgpu::BufferUsages::COPY_DST);
243 let attachment = BindGroupAttachment {
244 group,
245 binding,
246 attachment: BindGroupType::Uniform(buffer),
247 };
248
249 drop(inner);
250
251 self.insert_or_replace_attachment(group, binding, attachment);
252 }
253 None => {
254 self.remove_attachment(group, binding);
255 }
256 }
257
258 self
259 }
260
261 #[inline]
262 pub fn set_attachment_storage(
263 mut self,
264 group: u32,
265 binding: u32,
266 buffer: Option<&Buffer>,
267 ) -> Self {
268 match buffer {
269 Some(buffer) => {
270 let inner = buffer.inner.borrow();
271 let attachment = BindGroupAttachment {
272 group,
273 binding,
274 attachment: BindGroupType::Storage(inner.buffer.clone()),
275 };
276
277 self.insert_or_replace_attachment(group, binding, attachment);
278 }
279 None => {
280 self.remove_attachment(group, binding);
281 }
282 }
283
284 self
285 }
286
287 #[inline]
288 pub fn set_attachment_storage_raw<T>(
289 mut self,
290 group: u32,
291 binding: u32,
292 buffer: Option<&[T]>,
293 ) -> Self
294 where
295 T: bytemuck::Pod + bytemuck::Zeroable,
296 {
297 match buffer {
298 Some(buffer) => {
299 let mut inner = self.gpu.borrow_mut();
300
301 let buffer = inner.create_buffer_with(&buffer, wgpu::BufferUsages::COPY_DST);
302 let attachment = BindGroupAttachment {
303 group,
304 binding,
305 attachment: BindGroupType::Storage(buffer),
306 };
307
308 drop(inner);
309
310 self.insert_or_replace_attachment(group, binding, attachment);
311 }
312 None => {
313 self.remove_attachment(group, binding);
314 }
315 }
316
317 self
318 }
319
320 #[inline]
321 pub fn set_attachment_storage_vec<T>(
322 mut self,
323 group: u32,
324 binding: u32,
325 buffer: Option<Vec<T>>,
326 ) -> Self
327 where
328 T: bytemuck::Pod + bytemuck::Zeroable,
329 {
330 match buffer {
331 Some(buffer) => {
332 let mut inner = self.gpu.borrow_mut();
333
334 let buffer = inner.create_buffer_with(&buffer, wgpu::BufferUsages::COPY_DST);
335 let attachment = BindGroupAttachment {
336 group,
337 binding,
338 attachment: BindGroupType::Storage(buffer),
339 };
340
341 drop(inner);
342
343 self.insert_or_replace_attachment(group, binding, attachment);
344 }
345 None => {
346 self.remove_attachment(group, binding);
347 }
348 }
349
350 self
351 }
352
353 #[inline]
354 pub(crate) fn remove_attachment(&mut self, group: u32, binding: u32) {
355 self.attachments
356 .retain(|a| a.group != group || a.binding != binding);
357 }
358
359 pub(crate) fn insert_or_replace_attachment(
360 &mut self,
361 group: u32,
362 binding: u32,
363 attachment: BindGroupAttachment,
364 ) {
365 let index = self
366 .attachments
367 .iter()
368 .position(|a| a.group == group && a.binding == binding);
369
370 if let Some(index) = index {
371 self.attachments[index] = attachment;
372 } else {
373 self.attachments.push(attachment);
374 }
375 }
376
377 pub fn build(self) -> Result<ComputePipeline, CompuitePipelineError> {
378 if self.shader.is_none() {
379 return Err(CompuitePipelineError::ShaderNotSet);
380 }
381
382 let shader_binding = self.shader.unwrap();
383 for attachment in &self.attachments {
384 let r#type = {
385 let shader_reflection = self.shader_reflection.as_ref().unwrap();
386
387 match shader_reflection {
388 ShaderReflect::Compute { bindings, .. } => bindings
389 .iter()
390 .find(|b| b.group == attachment.group && b.binding == attachment.binding),
391 _ => None,
392 }
393 };
394
395 if r#type.is_none() {
396 return Err(CompuitePipelineError::AttachmentNotSet(
397 attachment.group,
398 attachment.binding,
399 ));
400 }
401
402 let r#type = r#type.unwrap();
403
404 if !match r#type.ty {
405 ShaderBindingType::UniformBuffer(_) => {
406 matches!(attachment.attachment, BindGroupType::Uniform(_))
407 }
408 ShaderBindingType::StorageBuffer(_, _) => {
409 matches!(attachment.attachment, BindGroupType::Storage(_))
410 }
411 ShaderBindingType::StorageTexture(_) => {
412 matches!(attachment.attachment, BindGroupType::TextureStorage(_))
413 }
414 ShaderBindingType::Sampler(_) => {
415 matches!(attachment.attachment, BindGroupType::Sampler(_))
416 }
417 ShaderBindingType::Texture(_) => {
418 matches!(attachment.attachment, BindGroupType::Texture(_))
419 }
420 ShaderBindingType::PushConstant(_) => {
421 matches!(attachment.attachment, BindGroupType::Uniform(_))
422 }
423 } {
424 return Err(CompuitePipelineError::InvalidAttachmentType(
425 attachment.group,
426 attachment.binding,
427 r#type.ty,
428 ));
429 }
430 }
431
432 let bind_group_hash_key = {
433 let mut hasher = DefaultHasher::new();
434 hasher.write_u64(0u64); for attachment in &self.attachments {
437 attachment.group.hash(&mut hasher);
438 attachment.binding.hash(&mut hasher);
439 match &attachment.attachment {
440 BindGroupType::Uniform(uniform) => {
441 uniform.hash(&mut hasher);
442 }
443 BindGroupType::Texture(texture) => {
444 texture.hash(&mut hasher);
445 }
446 BindGroupType::TextureStorage(texture) => texture.hash(&mut hasher),
447 BindGroupType::Sampler(sampler) => sampler.hash(&mut hasher),
448 BindGroupType::Storage(storage) => storage.hash(&mut hasher),
449 }
450 }
451
452 hasher.finish()
453 };
454
455 let bind_group_attachments = {
456 let mut gpu_inner = self.gpu.borrow_mut();
457
458 match gpu_inner.get_bind_group(bind_group_hash_key) {
459 Some(bind_group) => bind_group,
460 None => {
461 let mut bind_group_attachments: HashMap<u32, Vec<wgpu::BindGroupEntry>> =
462 self.attachments.iter().fold(HashMap::new(), |mut map, e| {
463 let (group, binding, attachment) = (e.group, e.binding, &e.attachment);
464 let entry = match attachment {
465 BindGroupType::Uniform(buffer) => wgpu::BindGroupEntry {
466 binding,
467 resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
468 buffer,
469 offset: 0,
470 size: None,
471 }),
472 },
473 BindGroupType::Texture(texture) => wgpu::BindGroupEntry {
474 binding,
475 resource: wgpu::BindingResource::TextureView(texture),
476 },
477 BindGroupType::Sampler(sampler) => wgpu::BindGroupEntry {
478 binding,
479 resource: wgpu::BindingResource::Sampler(sampler),
480 },
481 BindGroupType::Storage(buffer) => wgpu::BindGroupEntry {
482 binding,
483 resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
484 buffer,
485 offset: 0,
486 size: None,
487 }),
488 },
489 BindGroupType::TextureStorage(texture) => wgpu::BindGroupEntry {
490 binding,
491 resource: wgpu::BindingResource::TextureView(texture),
492 },
493 };
494
495 map.entry(group).or_insert_with(Vec::new).push(entry);
496 map
497 });
498
499 for entries in bind_group_attachments.values_mut() {
503 entries.sort_by_key(|e| e.binding);
504 }
505
506 let bind_group = bind_group_attachments
507 .iter()
508 .map(|(group, entries)| {
509 let layout = shader_binding
510 .layout
511 .iter()
512 .find(|l| l.group == *group)
513 .unwrap();
514
515 (layout, entries.as_slice())
516 })
517 .collect::<Vec<_>>();
518
519 let create_info = BindGroupCreateInfo {
520 entries: bind_group,
521 };
522
523 gpu_inner.create_bind_group(bind_group_hash_key, create_info)
524 }
525 }
526 };
527
528 let layout = shader_binding
529 .layout
530 .iter()
531 .map(|l| l.layout.clone())
532 .collect::<Vec<_>>();
533
534 let pipeline_desc = ComputePipelineDesc {
535 shader_module: shader_binding.shader,
536 entry_point: shader_binding.entry_point,
537 bind_group_layout: layout,
538 };
539
540 let pipeline = ComputePipeline {
541 bind_group: bind_group_attachments,
542 pipeline_desc,
543 };
544
545 Ok(pipeline)
546 }
547}
548
549#[derive(Debug, Clone, Copy)]
550pub enum CompuitePipelineError {
551 ShaderNotSet,
552 InvalidShaderType,
553 AttachmentNotSet(u32, u32),
554 InvalidAttachmentType(u32, u32, ShaderBindingType),
555}