est_render/gpu/command/
computepass.rs

1use std::{collections::HashMap, hash::{DefaultHasher, Hash, Hasher}, sync::{atomic::AtomicBool, Arc}};
2
3use crate::utils::ArcRef;
4
5use super::{
6    super::{
7        GPUInner,
8        shader::{
9            ComputeShader,
10            bind_group_manager::BindGroupCreateInfo,
11            BindGroupLayout,
12            types::ShaderReflect,
13            ShaderBindingType,
14        },
15        buffer::{
16            Buffer,
17            BufferUsage
18        },
19        pipeline::{
20            compute::ComputePipeline,
21            manager::ComputePipelineDesc,
22        },
23        command::{
24            BindGroupAttachment,
25            BindGroupType
26        }
27    }
28};
29
30#[derive(Clone, Debug)]
31pub struct ComputePass {
32    pub(crate) graphics: ArcRef<GPUInner>,
33    pub(crate) inner: ArcRef<ComputePassInner>,
34}
35
36impl ComputePass {
37    pub(crate) fn new(
38        graphics: ArcRef<GPUInner>, 
39        cmd: ArcRef<wgpu::CommandEncoder>, 
40        atomic_pass: Arc<AtomicBool>
41    ) -> Result<Self, ComputePassBuildError> {
42        let inner = ComputePassInner {
43            cmd,
44            shader: None,
45            atomic_pass,
46
47            queues: Vec::new(),
48            attachments: Vec::new(),
49            push_constant: None,
50
51            #[cfg(any(debug_assertions, feature = "enable-release-validation"))]
52            reflection: None,
53        };
54
55        Ok(ComputePass {
56            graphics,
57            inner: ArcRef::new(inner),
58        })
59    }
60
61    pub fn set_shader(&mut self, shader: Option<&ComputeShader>) {
62        let mut inner = self.inner.borrow_mut();
63
64        match shader {
65            Some(shader) => {
66                let shader_inner = shader.inner.borrow();
67
68                let shader_entry_point = match &shader_inner.reflection {
69                    ShaderReflect::Compute { entry_point, .. } => entry_point.clone(),
70                    _ => panic!("Shader is not a compute shader"),
71                };
72
73                #[cfg(any(debug_assertions, feature = "enable-release-validation"))]
74                {
75                    if shader_entry_point.is_empty() {
76                        panic!("Compute shader entry point is empty");
77                    }
78                }
79
80                let layout = shader_inner.bind_group_layouts.clone();
81
82                let shader_binding = IntermediateComputeBinding {
83                    shader: shader_inner.shader.clone(),
84                    layout,
85                    entry_point: shader_entry_point,
86                };
87
88                inner.shader = Some(ComputeShaderBinding::Intermediate(shader_binding));
89            }
90            None => {
91                inner.shader = None;
92            }
93        }
94    }
95
96    pub fn set_pipeline(&mut self, pipeline: Option<&ComputePipeline>) {
97        let mut inner = self.inner.borrow_mut();
98
99        match pipeline {
100            Some(pipeline) => {
101                inner.shader = Some(ComputeShaderBinding::Pipeline(pipeline.clone()));
102            }
103            None => {
104                inner.shader = None;
105            }
106        }
107    }
108
109    #[cfg(not(target_arch = "wasm32"))]
110    pub fn set_push_constants(&mut self, push_constant: Option<&[u8]>) {
111        let mut inner = self.inner.borrow_mut();
112
113        match push_constant {
114            Some(push_constant) => {
115                let mut push_constant = push_constant.to_vec();
116                if push_constant.len() % 4 != 0 {
117                    push_constant.resize(push_constant.len() + (4 - push_constant.len() % 4), 0);
118                }
119
120                #[cfg(any(debug_assertions, feature = "enable-release-validation"))]
121                {
122                    if inner.shader.is_none() {
123                        panic!("Shader must be set before setting push constants");
124                    }
125
126                    let size = {
127                        let shader_reflection = inner.reflection.as_ref().unwrap();
128
129                        match &shader_reflection {
130                            ShaderReflect::Compute { bindings, .. } => bindings
131                                .iter()
132                                .find_map(|binding| {
133                                    if let ShaderBindingType::PushConstant(size) = binding.ty {
134                                        Some(size)
135                                    } else {
136                                        None
137                                    }
138                                })
139                                .unwrap_or(0),
140                            _ => panic!("Shader is not a compute shader"),
141                        }
142                    };
143
144                    if size == 0 {
145                        panic!("No push constant found in shader");
146                    }
147
148                    if push_constant.len() > size as usize {
149                        panic!("Push constant size is too large");
150                    }
151                }
152
153                inner.push_constant = Some(push_constant);
154            }
155            None => {
156                inner.push_constant = None;
157            }
158        }
159    }
160
161    pub fn set_attachment_buffer(&mut self, group: u32, binding: u32, attachment: Option<&Buffer>) {
162        match attachment {
163            Some(attachment) => {
164                let buffer = attachment.inner.borrow().buffer.clone();
165
166                self.insert_or_replace_attachment(
167                    group,
168                    binding,
169                    BindGroupAttachment {
170                        group,
171                        binding,
172                        attachment: BindGroupType::Storage(buffer),
173                    },
174                );
175            }
176            None => {
177                self.remove_attachment(group, binding);
178            }
179        }
180    }
181
182    pub fn set_attachment_buffer_raw<T>(
183        &mut self,
184        group: u32,
185        binding: u32,
186        attachment: Option<&[T]>,
187        usages: BufferUsage,
188    ) where
189        T: bytemuck::Pod + bytemuck::Zeroable,
190    {
191        match attachment {
192            Some(attachment) => {
193                let buffer = self
194                    .graphics
195                    .borrow_mut()
196                    .create_buffer_with(attachment, usages.into());
197
198                self.insert_or_replace_attachment(
199                    group,
200                    binding,
201                    BindGroupAttachment {
202                        group,
203                        binding,
204                        attachment: BindGroupType::Storage(buffer),
205                    },
206                );
207            }
208            None => {
209                self.remove_attachment(group, binding);
210            }
211        }
212    }
213
214    pub(crate) fn remove_attachment(&mut self, group: u32, binding: u32) {
215        let mut inner = self.inner.borrow_mut();
216
217        inner
218            .attachments
219            .retain(|a| a.group != group || a.binding != binding);
220    }
221
222    pub(crate) fn insert_or_replace_attachment(
223        &mut self,
224        group: u32,
225        binding: u32,
226        attachment: BindGroupAttachment,
227    ) {
228        let mut inner = self.inner.borrow_mut();
229
230        #[cfg(any(debug_assertions, feature = "enable-release-validation"))]
231        {
232            if inner.shader.is_none() {
233                panic!("Shader is not set");
234            }
235
236            match &inner.shader {
237                Some(ComputeShaderBinding::Pipeline(_)) => {
238                    panic!("Cannot insert or replace attachment when using a pipeline shader");
239                }
240                _ => {}
241            }
242
243            let reflection = inner.reflection.as_ref().unwrap();
244
245            let r#type = match reflection {
246                ShaderReflect::Compute { bindings, .. } => bindings
247                    .iter()
248                    .find_map(|shaderbinding| {
249                        if shaderbinding.group == group && shaderbinding.binding == binding {
250                            Some(shaderbinding)
251                        } else {
252                            None
253                        }
254                    })
255                    .unwrap_or_else(|| {
256                        panic!(
257                            "Shader does not have binding group: {} binding: {}",
258                            group, binding
259                        );
260                    }),
261                _ => panic!("Shader is not a compute shader"),
262            };
263
264            if !match r#type.ty {
265                ShaderBindingType::UniformBuffer(_) => {
266                    matches!(attachment.attachment, BindGroupType::Uniform(_))
267                }
268                ShaderBindingType::StorageBuffer(_, _) => {
269                    matches!(attachment.attachment, BindGroupType::Storage(_))
270                }
271                ShaderBindingType::StorageTexture(_) => {
272                    matches!(attachment.attachment, BindGroupType::TextureStorage(_))
273                }
274                ShaderBindingType::Sampler(_) => {
275                    matches!(attachment.attachment, BindGroupType::Sampler(_))
276                }
277                ShaderBindingType::Texture(_) => {
278                    matches!(attachment.attachment, BindGroupType::Texture(_))
279                }
280                ShaderBindingType::PushConstant(_) => {
281                    matches!(attachment.attachment, BindGroupType::Uniform(_))
282                }
283            } {
284                panic!(
285                    "Attachment group: {} binding: {} type: {} not match with shader type: {}",
286                    group, binding, attachment.attachment, r#type.ty
287                );
288            }
289        }
290
291        let index = inner
292            .attachments
293            .iter()
294            .position(|a| a.group == group && a.binding == binding);
295
296        if let Some(index) = index {
297            inner.attachments[index] = attachment;
298        } else {
299            inner.attachments.push(attachment);
300        }
301    }
302
303    pub fn dispatch(&mut self, x: u32, y: u32, z: u32) {
304        #[cfg(any(debug_assertions, feature = "enable-release-validation"))]
305        {
306            let inner = self.inner.borrow();
307
308            if inner.shader.is_none() {
309                panic!("Shader must be set before dispatching");
310            }
311        }
312
313        let (pipeline, bind_group) = self.prepare_pipeline();
314        let mut inner = self.inner.borrow_mut();
315
316        let queue = ComputePassQueue {
317            pipeline,
318            bind_group,
319            ty: DispatchType::Dispatch { x, y, z },
320            push_constant: inner.push_constant.clone(),
321            debug: None,
322        };
323
324        inner.queues.push(queue);
325    }
326
327    pub fn dispatch_indirect(&mut self, buffer: &Buffer, offset: u64) {
328        #[cfg(any(debug_assertions, feature = "enable-release-validation"))]
329        {
330            let inner = self.inner.borrow();
331
332            if inner.shader.is_none() {
333                panic!("Shader must be set before dispatching");
334            }
335        }
336
337        let (pipeline, bind_group) = self.prepare_pipeline();
338        let mut inner = self.inner.borrow_mut();
339
340        let queue = ComputePassQueue {
341            pipeline,
342            bind_group,
343            ty: DispatchType::DispatchIndirect {
344                buffer: buffer.inner.borrow().buffer.clone(),
345                offset,
346            },
347            push_constant: inner.push_constant.clone(),
348            debug: None,
349        };
350
351        inner.queues.push(queue);
352    }
353
354    fn prepare_pipeline(&self) -> (wgpu::ComputePipeline, Vec<(u32, wgpu::BindGroup)>) {
355        let inner = self.inner.borrow();
356
357        match &inner.shader {
358            Some(ComputeShaderBinding::Intermediate(shader_binding)) => {
359                let bind_group_hash_key = {
360                    let mut hasher = std::collections::hash_map::DefaultHasher::new();
361                    hasher.write_u64(1u64); // Compute shader hash id.
362
363                    for attachment in &inner.attachments {
364                        attachment.group.hash(&mut hasher);
365                        attachment.binding.hash(&mut hasher);
366
367                        match &attachment.attachment {
368                            BindGroupType::Uniform(buffer) => {
369                                buffer.hash(&mut hasher);
370                            }
371                            BindGroupType::Storage(buffer) => {
372                                buffer.hash(&mut hasher);
373                            }
374                            BindGroupType::TextureStorage(texture) => {
375                                texture.hash(&mut hasher);
376                            }
377                            BindGroupType::Sampler(sampler) => {
378                                sampler.hash(&mut hasher);
379                            }
380                            BindGroupType::Texture(texture) => {
381                                texture.hash(&mut hasher);
382                            }
383                        }
384                    }
385
386                    hasher.finish()
387                };
388
389                let bind_group_attachments = {
390                    let mut gpu_inner = self.graphics.borrow_mut();
391
392                    match gpu_inner.get_bind_group(bind_group_hash_key) {
393                        Some(bind_group) => bind_group,
394                        None => {
395                            let mut bind_group_attachments: HashMap<
396                                u32,
397                                Vec<wgpu::BindGroupEntry>,
398                            > = inner.attachments.iter().fold(HashMap::new(), |mut map, e| {
399                                let (group, binding, attachment) =
400                                    (e.group, e.binding, &e.attachment);
401
402                                let entry = match attachment {
403                                    BindGroupType::TextureStorage(texture) => {
404                                        wgpu::BindGroupEntry {
405                                            binding,
406                                            resource: wgpu::BindingResource::TextureView(texture),
407                                        }
408                                    }
409                                    BindGroupType::Storage(buffer) => wgpu::BindGroupEntry {
410                                        binding,
411                                        resource: wgpu::BindingResource::Buffer(
412                                            wgpu::BufferBinding {
413                                                buffer,
414                                                offset: 0,
415                                                size: None,
416                                            },
417                                        ),
418                                    },
419                                    _ => panic!("Unsupported bind group type"),
420                                };
421
422                                map.entry(group).or_insert_with(Vec::new).push(entry);
423                                map
424                            });
425
426                            // sort each group attachments
427                            // group, binding
428                            // this is important for the bind group to be created in the correct order
429                            for entries in bind_group_attachments.values_mut() {
430                                entries.sort_by_key(|e| e.binding);
431                            }
432
433                            let bind_group = bind_group_attachments
434                                .iter()
435                                .map(|(group, entries)| {
436                                    let layout = shader_binding
437                                        .layout
438                                        .iter()
439                                        .find(|l| l.group == *group)
440                                        .unwrap();
441
442                                    (layout, entries.as_slice())
443                                })
444                                .collect::<Vec<_>>();
445
446                            let create_info = BindGroupCreateInfo {
447                                entries: bind_group,
448                            };
449
450                            gpu_inner.create_bind_group(bind_group_hash_key, create_info)
451                        }
452                    }
453                };
454
455                let pipeline_hash_key = {
456                    let mut hasher = std::collections::hash_map::DefaultHasher::new();
457                    shader_binding.hash(&mut hasher);
458
459                    hasher.finish()
460                };
461
462                let pipeline = {
463                    let mut gpu_inner = self.graphics.borrow_mut();
464
465                    match gpu_inner.get_compute_pipeline(pipeline_hash_key) {
466                        Some(pipeline) => pipeline,
467                        None => {
468                            let bind_group_layout = shader_binding
469                                .layout
470                                .iter()
471                                .map(|l| l.layout.clone())
472                                .collect::<Vec<_>>();
473
474                            let entry_point = shader_binding.entry_point.as_str();
475
476                            let pipeline_desc = ComputePipelineDesc {
477                                shader_module: shader_binding.shader.clone(),
478                                entry_point: entry_point.to_owned(),
479                                bind_group_layout,
480                            };
481
482                            gpu_inner.create_compute_pipeline(pipeline_hash_key, pipeline_desc)
483                        }
484                    }
485                };
486
487                (pipeline, bind_group_attachments)
488            }
489            Some(ComputeShaderBinding::Pipeline(pipeline)) => {
490                let pipeline_hash_key = {
491                    let mut hasher = DefaultHasher::new();
492                    pipeline.pipeline_desc.hash(&mut hasher);
493
494                    hasher.finish()
495                };
496
497                let wgpu_pipeline = {
498                    let mut gpu_inner = self.graphics.borrow_mut();
499
500                    match gpu_inner.get_compute_pipeline(pipeline_hash_key) {
501                        Some(pipeline) => pipeline,
502                        None => gpu_inner.create_compute_pipeline(
503                            pipeline_hash_key,
504                            pipeline.pipeline_desc.clone(),
505                        ),
506                    }
507                };
508
509                (wgpu_pipeline, pipeline.bind_group.clone())
510            }
511            None => {
512                panic!("Compute shader or pipeline must be set before dispatching");
513            }
514        }
515    }
516
517    fn end(&mut self) {
518        let mut inner = self.inner.borrow_mut();
519
520        let queues = inner.queues.drain(..).collect::<Vec<_>>();
521        let mut cmd = inner.cmd.borrow_mut();
522
523        let mut cpass = cmd.begin_compute_pass(&wgpu::ComputePassDescriptor {
524            label: Some("Compute Pass"),
525            timestamp_writes: None,
526        });
527
528        for queue in queues {
529            cpass.set_pipeline(&queue.pipeline);
530
531            for (bind_group_index, bind_group) in &queue.bind_group {
532                cpass.set_bind_group(*bind_group_index, bind_group, &[]);
533            }
534
535            if let Some(debug) = &queue.debug {
536                cpass.insert_debug_marker(debug);
537            }
538
539            #[cfg(not(target_arch = "wasm32"))]
540            if let Some(push_constant) = &queue.push_constant {
541                cpass.set_push_constants(0, push_constant);
542            }
543
544            match &queue.ty {
545                DispatchType::Dispatch { x, y, z } => {
546                    cpass.dispatch_workgroups(*x, *y, *z);
547                }
548                DispatchType::DispatchIndirect { buffer, offset } => {
549                    cpass.dispatch_workgroups_indirect(buffer, *offset);
550                }
551            }
552        }
553
554        inner.atomic_pass.store(false, std::sync::atomic::Ordering::Relaxed);
555    }
556}
557
558impl Drop for ComputePass {
559    fn drop(&mut self) {
560        self.end();
561    }
562}
563
564#[derive(Clone, Debug)]
565pub(crate) struct ComputePassQueue {
566    pub pipeline: wgpu::ComputePipeline,
567    pub bind_group: Vec<(u32, wgpu::BindGroup)>,
568    pub ty: DispatchType,
569    pub push_constant: Option<Vec<u8>>,
570
571    pub debug: Option<String>,
572}
573
574#[derive(Clone, Debug)]
575pub(crate) struct ComputePassInner {
576    pub cmd: ArcRef<wgpu::CommandEncoder>,
577    pub shader: Option<ComputeShaderBinding>,
578    pub atomic_pass: Arc<AtomicBool>,
579
580    pub queues: Vec<ComputePassQueue>,
581    pub attachments: Vec<BindGroupAttachment>,
582    pub push_constant: Option<Vec<u8>>,
583
584    #[cfg(any(debug_assertions, feature = "enable-release-validation"))]
585    pub reflection: Option<ShaderReflect>,
586}
587
588#[derive(Clone, Debug, Hash, PartialEq, Eq)]
589pub enum DispatchType {
590    Dispatch { x: u32, y: u32, z: u32 },
591    DispatchIndirect { buffer: wgpu::Buffer, offset: u64 },
592}
593
594#[derive(Clone, Debug, Hash)]
595pub(crate) struct IntermediateComputeBinding {
596    pub shader: wgpu::ShaderModule,
597    pub layout: Vec<BindGroupLayout>,
598    pub entry_point: String,
599}
600
601#[derive(Clone, Debug)]
602pub(crate) enum ComputeShaderBinding {
603    Intermediate(IntermediateComputeBinding),
604    Pipeline(ComputePipeline),
605}
606
/// Error type returned by `ComputePass::new`.
///
/// There is currently no failure mode; `None` is a placeholder variant.
#[derive(Debug, Clone)]
pub enum ComputePassBuildError {
    None,
}