// est_render/gpu/pipeline/compute.rs

1use std::{
2    collections::HashMap,
3    hash::{DefaultHasher, Hash, Hasher},
4};
5
6use crate::utils::ArcRef;
7
8use super::{
9    manager::ComputePipelineDesc,
10    super::{
11        GPUInner,
12        texture::{Texture, TextureSampler},
13        buffer::Buffer,
14        command::{
15            BindGroupAttachment,
16            utils::BindGroupType,
17            computepass::IntermediateComputeBinding
18        },
19        shader::{
20            bind_group_manager::BindGroupCreateInfo,
21            types::{ShaderReflect, ShaderBindingType},
22            compute::ComputeShader,
23        },
24    },
25};
26
/// A compute pipeline produced by [`ComputePipelineBuilder::build`].
///
/// Bundles the bind groups built from the builder's attachments with the description
/// (shader module, entry point, bind group layouts) used to create the pipeline.
#[derive(Debug, Clone, Hash)]
pub struct ComputePipeline {
    /// Bind groups returned by the GPU bind-group cache / creator; the `u32` is presumably
    /// the bind group index each one is bound at (TODO confirm against `create_bind_group`).
    pub(crate) bind_group: Vec<(u32, wgpu::BindGroup)>,
    /// Shader module, entry point and bind group layouts for the pipeline.
    pub(crate) pipeline_desc: ComputePipelineDesc,
}
32
/// Builder that collects a compute shader and its resource attachments, then validates
/// them against the shader's reflection data in [`ComputePipelineBuilder::build`].
#[derive(Debug, Clone)]
pub struct ComputePipelineBuilder {
    /// Handle to the GPU state used to create samplers, buffers and bind groups.
    pub(crate) gpu: ArcRef<GPUInner>,
    /// Resources bound so far; at most one entry per (group, binding) pair.
    pub(crate) attachments: Vec<BindGroupAttachment>,
    /// Shader module, entry point and bind group layouts copied from the selected shader.
    pub(crate) shader: Option<IntermediateComputeBinding>,
    /// Reflection data of the selected shader; kept in sync with `shader` by `set_shader`.
    pub(crate) shader_reflection: Option<ShaderReflect>,
}
40
41impl ComputePipelineBuilder {
42    pub(crate) fn new(gpu: ArcRef<GPUInner>) -> Self {
43        Self {
44            gpu,
45            attachments: Vec::new(),
46            shader: None,
47            shader_reflection: None,
48        }
49    }
50
51    #[inline]
52    pub fn set_shader(mut self, shader: Option<&ComputeShader>) -> Self {
53        match shader {
54            Some(shader) => {
55                let shader_inner = shader.inner.borrow();
56                let shader_module = shader_inner.shader.clone();
57                let layout = shader_inner.bind_group_layouts.clone();
58
59                let shader_reflect = shader_inner.reflection.clone();
60                let entry_point = match &shader_reflect {
61                    ShaderReflect::Compute { entry_point, .. } => entry_point.clone(),
62                    _ => panic!("Shader must be a compute shader"),
63                };
64
65                let shader_binding = IntermediateComputeBinding {
66                    shader: shader_module,
67                    entry_point,
68                    layout,
69                };
70
71                self.shader = Some(shader_binding);
72                self.shader_reflection = Some(shader_reflect);
73            }
74            None => {
75                self.shader = None;
76                self.shader_reflection = None;
77            }
78        }
79
80        self
81    }
82
83    #[inline]
84    pub fn set_attachment_sampler(
85        mut self,
86        group: u32,
87        binding: u32,
88        sampler: Option<&TextureSampler>,
89    ) -> Self {
90        match sampler {
91            Some(sampler) => {
92                let attachment = {
93                    let gpu_inner = self.gpu.borrow();
94
95                    BindGroupAttachment {
96                        group,
97                        binding,
98                        attachment: BindGroupType::Sampler(
99                            sampler.make_wgpu(gpu_inner.device()),
100                        ),
101                    }
102                };
103
104                self.insert_or_replace_attachment(group, binding, attachment);
105            }
106            None => {
107                self.remove_attachment(group, binding);
108            }
109        }
110
111        self
112    }
113
114    #[inline]
115    pub fn set_attachment_texture(
116        mut self,
117        group: u32,
118        binding: u32,
119        texture: Option<&Texture>,
120    ) -> Self {
121        match texture {
122            Some(texture) => {
123                let attachment = {
124                    BindGroupAttachment {
125                        group,
126                        binding,
127                        attachment: BindGroupType::Texture(
128                            texture.inner.borrow().wgpu_view.clone(),
129                        ),
130                    }
131                };
132
133                self.insert_or_replace_attachment(group, binding, attachment);
134            }
135            None => {
136                self.remove_attachment(group, binding);
137            }
138        }
139
140        self
141    }
142
143    #[inline]
144    pub fn set_attachment_texture_storage(
145        mut self,
146        group: u32,
147        binding: u32,
148        texture: Option<&Texture>,
149    ) -> Self {
150        match texture {
151            Some(texture) => {
152                let inner = texture.inner.borrow();
153                let attachment = BindGroupAttachment {
154                    group,
155                    binding,
156                    attachment: BindGroupType::TextureStorage(inner.wgpu_view.clone()),
157                };
158
159                self.insert_or_replace_attachment(group, binding, attachment);
160            }
161            None => {
162                self.remove_attachment(group, binding);
163            }
164        }
165
166        self
167    }
168
169    #[inline]
170    pub fn set_attachment_uniform(
171        mut self,
172        group: u32,
173        binding: u32,
174        buffer: Option<&Buffer>,
175    ) -> Self {
176        match buffer {
177            Some(buffer) => {
178                let inner = buffer.inner.borrow();
179                let attachment = BindGroupAttachment {
180                    group,
181                    binding,
182                    attachment: BindGroupType::Uniform(inner.buffer.clone()),
183                };
184
185                self.insert_or_replace_attachment(group, binding, attachment);
186            }
187            None => {
188                self.remove_attachment(group, binding);
189            }
190        }
191
192        self
193    }
194
195    #[inline]
196    pub fn set_attachment_uniform_vec<T>(
197        mut self,
198        group: u32,
199        binding: u32,
200        buffer: Option<Vec<T>>,
201    ) -> Self
202    where
203        T: bytemuck::Pod + bytemuck::Zeroable,
204    {
205        match buffer {
206            Some(buffer) => {
207                let attachment = {
208                    let mut inner = self.gpu.borrow_mut();
209
210                    let buffer = inner.create_buffer_with(&buffer, wgpu::BufferUsages::COPY_DST);
211                    BindGroupAttachment {
212                        group,
213                        binding,
214                        attachment: BindGroupType::Uniform(buffer),
215                    }
216                };
217
218                self.insert_or_replace_attachment(group, binding, attachment);
219            }
220            None => {
221                self.remove_attachment(group, binding);
222            }
223        }
224
225        self
226    }
227
228    #[inline]
229    pub fn set_attachment_uniform_raw<T>(
230        mut self,
231        group: u32,
232        binding: u32,
233        buffer: Option<&[T]>,
234    ) -> Self
235    where
236        T: bytemuck::Pod + bytemuck::Zeroable,
237    {
238        match buffer {
239            Some(buffer) => {
240                let mut inner = self.gpu.borrow_mut();
241
242                let buffer = inner.create_buffer_with(&buffer, wgpu::BufferUsages::COPY_DST);
243                let attachment = BindGroupAttachment {
244                    group,
245                    binding,
246                    attachment: BindGroupType::Uniform(buffer),
247                };
248
249                drop(inner);
250
251                self.insert_or_replace_attachment(group, binding, attachment);
252            }
253            None => {
254                self.remove_attachment(group, binding);
255            }
256        }
257
258        self
259    }
260
261    #[inline]
262    pub fn set_attachment_storage(
263        mut self,
264        group: u32,
265        binding: u32,
266        buffer: Option<&Buffer>,
267    ) -> Self {
268        match buffer {
269            Some(buffer) => {
270                let inner = buffer.inner.borrow();
271                let attachment = BindGroupAttachment {
272                    group,
273                    binding,
274                    attachment: BindGroupType::Storage(inner.buffer.clone()),
275                };
276
277                self.insert_or_replace_attachment(group, binding, attachment);
278            }
279            None => {
280                self.remove_attachment(group, binding);
281            }
282        }
283
284        self
285    }
286
287    #[inline]
288    pub fn set_attachment_storage_raw<T>(
289        mut self,
290        group: u32,
291        binding: u32,
292        buffer: Option<&[T]>,
293    ) -> Self
294    where
295        T: bytemuck::Pod + bytemuck::Zeroable,
296    {
297        match buffer {
298            Some(buffer) => {
299                let mut inner = self.gpu.borrow_mut();
300
301                let buffer = inner.create_buffer_with(&buffer, wgpu::BufferUsages::COPY_DST);
302                let attachment = BindGroupAttachment {
303                    group,
304                    binding,
305                    attachment: BindGroupType::Storage(buffer),
306                };
307
308                drop(inner);
309
310                self.insert_or_replace_attachment(group, binding, attachment);
311            }
312            None => {
313                self.remove_attachment(group, binding);
314            }
315        }
316
317        self
318    }
319
320    #[inline]
321    pub fn set_attachment_storage_vec<T>(
322        mut self,
323        group: u32,
324        binding: u32,
325        buffer: Option<Vec<T>>,
326    ) -> Self
327    where
328        T: bytemuck::Pod + bytemuck::Zeroable,
329    {
330        match buffer {
331            Some(buffer) => {
332                let mut inner = self.gpu.borrow_mut();
333
334                let buffer = inner.create_buffer_with(&buffer, wgpu::BufferUsages::COPY_DST);
335                let attachment = BindGroupAttachment {
336                    group,
337                    binding,
338                    attachment: BindGroupType::Storage(buffer),
339                };
340
341                drop(inner);
342
343                self.insert_or_replace_attachment(group, binding, attachment);
344            }
345            None => {
346                self.remove_attachment(group, binding);
347            }
348        }
349
350        self
351    }
352
353    #[inline]
354    pub(crate) fn remove_attachment(&mut self, group: u32, binding: u32) {
355        self.attachments
356            .retain(|a| a.group != group || a.binding != binding);
357    }
358
359    pub(crate) fn insert_or_replace_attachment(
360        &mut self,
361        group: u32,
362        binding: u32,
363        attachment: BindGroupAttachment,
364    ) {
365        let index = self
366            .attachments
367            .iter()
368            .position(|a| a.group == group && a.binding == binding);
369
370        if let Some(index) = index {
371            self.attachments[index] = attachment;
372        } else {
373            self.attachments.push(attachment);
374        }
375    }
376
377    pub fn build(self) -> Result<ComputePipeline, CompuitePipelineError> {
378        if self.shader.is_none() {
379            return Err(CompuitePipelineError::ShaderNotSet);
380        }
381
382        let shader_binding = self.shader.unwrap();
383        for attachment in &self.attachments {
384            let r#type = {
385                let shader_reflection = self.shader_reflection.as_ref().unwrap();
386
387                match shader_reflection {
388                    ShaderReflect::Compute { bindings, .. } => bindings
389                        .iter()
390                        .find(|b| b.group == attachment.group && b.binding == attachment.binding),
391                    _ => None,
392                }
393            };
394
395            if r#type.is_none() {
396                return Err(CompuitePipelineError::AttachmentNotSet(
397                    attachment.group,
398                    attachment.binding,
399                ));
400            }
401
402            let r#type = r#type.unwrap();
403
404            if !match r#type.ty {
405                ShaderBindingType::UniformBuffer(_) => {
406                    matches!(attachment.attachment, BindGroupType::Uniform(_))
407                }
408                ShaderBindingType::StorageBuffer(_, _) => {
409                    matches!(attachment.attachment, BindGroupType::Storage(_))
410                }
411                ShaderBindingType::StorageTexture(_) => {
412                    matches!(attachment.attachment, BindGroupType::TextureStorage(_))
413                }
414                ShaderBindingType::Sampler(_) => {
415                    matches!(attachment.attachment, BindGroupType::Sampler(_))
416                }
417                ShaderBindingType::Texture(_) => {
418                    matches!(attachment.attachment, BindGroupType::Texture(_))
419                }
420                ShaderBindingType::PushConstant(_) => {
421                    matches!(attachment.attachment, BindGroupType::Uniform(_))
422                }
423            } {
424                return Err(CompuitePipelineError::InvalidAttachmentType(
425                    attachment.group,
426                    attachment.binding,
427                    r#type.ty,
428                ));
429            }
430        }
431
432        let bind_group_hash_key = {
433            let mut hasher = DefaultHasher::new();
434            hasher.write_u64(0u64); // Graphics shader hash id
435
436            for attachment in &self.attachments {
437                attachment.group.hash(&mut hasher);
438                attachment.binding.hash(&mut hasher);
439                match &attachment.attachment {
440                    BindGroupType::Uniform(uniform) => {
441                        uniform.hash(&mut hasher);
442                    }
443                    BindGroupType::Texture(texture) => {
444                        texture.hash(&mut hasher);
445                    }
446                    BindGroupType::TextureStorage(texture) => texture.hash(&mut hasher),
447                    BindGroupType::Sampler(sampler) => sampler.hash(&mut hasher),
448                    BindGroupType::Storage(storage) => storage.hash(&mut hasher),
449                }
450            }
451
452            hasher.finish()
453        };
454
455        let bind_group_attachments = {
456            let mut gpu_inner = self.gpu.borrow_mut();
457
458            match gpu_inner.get_bind_group(bind_group_hash_key) {
459                Some(bind_group) => bind_group,
460                None => {
461                    let mut bind_group_attachments: HashMap<u32, Vec<wgpu::BindGroupEntry>> =
462                        self.attachments.iter().fold(HashMap::new(), |mut map, e| {
463                            let (group, binding, attachment) = (e.group, e.binding, &e.attachment);
464                            let entry = match attachment {
465                                BindGroupType::Uniform(buffer) => wgpu::BindGroupEntry {
466                                    binding,
467                                    resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
468                                        buffer,
469                                        offset: 0,
470                                        size: None,
471                                    }),
472                                },
473                                BindGroupType::Texture(texture) => wgpu::BindGroupEntry {
474                                    binding,
475                                    resource: wgpu::BindingResource::TextureView(texture),
476                                },
477                                BindGroupType::Sampler(sampler) => wgpu::BindGroupEntry {
478                                    binding,
479                                    resource: wgpu::BindingResource::Sampler(sampler),
480                                },
481                                BindGroupType::Storage(buffer) => wgpu::BindGroupEntry {
482                                    binding,
483                                    resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
484                                        buffer,
485                                        offset: 0,
486                                        size: None,
487                                    }),
488                                },
489                                BindGroupType::TextureStorage(texture) => wgpu::BindGroupEntry {
490                                    binding,
491                                    resource: wgpu::BindingResource::TextureView(texture),
492                                },
493                            };
494
495                            map.entry(group).or_insert_with(Vec::new).push(entry);
496                            map
497                        });
498
499                    // sort each group attachments
500                    // group, binding
501                    // this is important for the bind group to be created in the correct order
502                    for entries in bind_group_attachments.values_mut() {
503                        entries.sort_by_key(|e| e.binding);
504                    }
505
506                    let bind_group = bind_group_attachments
507                        .iter()
508                        .map(|(group, entries)| {
509                            let layout = shader_binding
510                                .layout
511                                .iter()
512                                .find(|l| l.group == *group)
513                                .unwrap();
514
515                            (layout, entries.as_slice())
516                        })
517                        .collect::<Vec<_>>();
518
519                    let create_info = BindGroupCreateInfo {
520                        entries: bind_group,
521                    };
522
523                    gpu_inner.create_bind_group(bind_group_hash_key, create_info)
524                }
525            }
526        };
527
528        let layout = shader_binding
529            .layout
530            .iter()
531            .map(|l| l.layout.clone())
532            .collect::<Vec<_>>();
533
534        let pipeline_desc = ComputePipelineDesc {
535            shader_module: shader_binding.shader,
536            entry_point: shader_binding.entry_point,
537            bind_group_layout: layout,
538        };
539
540        let pipeline = ComputePipeline {
541            bind_group: bind_group_attachments,
542            pipeline_desc,
543        };
544
545        Ok(pipeline)
546    }
547}
548
/// Errors reported while building a compute pipeline with [`ComputePipelineBuilder::build`].
///
/// NOTE(review): the type name is misspelled ("Compuite") but is public API, so it is kept.
#[derive(Debug, Clone, Copy)]
pub enum CompuitePipelineError {
    /// `build` was called before a shader was set on the builder.
    ShaderNotSet,
    /// Presumably: the shader is not a compute shader — confirm where this is constructed.
    InvalidShaderType,
    /// `(group, binding)` of an attachment that has no matching binding in the shader's
    /// reflection data (despite the name, `build` uses it for unknown slots, not missing ones).
    AttachmentNotSet(u32, u32),
    /// `(group, binding)` of an attachment whose kind does not match the shader's declared
    /// binding type, which is carried as the third field.
    InvalidAttachmentType(u32, u32, ShaderBindingType),
}