1use std::{collections::HashMap, hash::{DefaultHasher, Hash, Hasher}, sync::{atomic::AtomicBool, Arc}};
2
3use crate::utils::ArcRef;
4
5use super::{
6 super::{
7 GPUInner,
8 shader::{
9 ComputeShader,
10 bind_group_manager::BindGroupCreateInfo,
11 BindGroupLayout,
12 types::ShaderReflect,
13 ShaderBindingType,
14 },
15 buffer::{
16 Buffer,
17 BufferUsage
18 },
19 pipeline::{
20 compute::ComputePipeline,
21 manager::ComputePipelineDesc,
22 },
23 command::{
24 BindGroupAttachment,
25 BindGroupType
26 }
27 }
28};
29
/// Records compute work — shader/pipeline binding, buffer attachments and
/// dispatches — and encodes it all into a single wgpu compute pass when the
/// value is dropped (see `end()`).
#[derive(Clone, Debug)]
pub struct ComputePass {
    // Shared GPU context that owns the bind-group and pipeline caches
    // consulted by `prepare_pipeline`.
    pub(crate) graphics: ArcRef<GPUInner>,
    // Mutable pass state: queued dispatches, attachments, bound shader.
    pub(crate) inner: ArcRef<ComputePassInner>,
}
35
36impl ComputePass {
37 pub(crate) fn new(
38 graphics: ArcRef<GPUInner>,
39 cmd: ArcRef<wgpu::CommandEncoder>,
40 atomic_pass: Arc<AtomicBool>
41 ) -> Result<Self, ComputePassBuildError> {
42 let inner = ComputePassInner {
43 cmd,
44 shader: None,
45 atomic_pass,
46
47 queues: Vec::new(),
48 attachments: Vec::new(),
49 push_constant: None,
50
51 #[cfg(any(debug_assertions, feature = "enable-release-validation"))]
52 reflection: None,
53 };
54
55 Ok(ComputePass {
56 graphics,
57 inner: ArcRef::new(inner),
58 })
59 }
60
61 pub fn set_shader(&mut self, shader: Option<&ComputeShader>) {
62 let mut inner = self.inner.borrow_mut();
63
64 match shader {
65 Some(shader) => {
66 let shader_inner = shader.inner.borrow();
67
68 let shader_entry_point = match &shader_inner.reflection {
69 ShaderReflect::Compute { entry_point, .. } => entry_point.clone(),
70 _ => panic!("Shader is not a compute shader"),
71 };
72
73 #[cfg(any(debug_assertions, feature = "enable-release-validation"))]
74 {
75 if shader_entry_point.is_empty() {
76 panic!("Compute shader entry point is empty");
77 }
78 }
79
80 let layout = shader_inner.bind_group_layouts.clone();
81
82 let shader_binding = IntermediateComputeBinding {
83 shader: shader_inner.shader.clone(),
84 layout,
85 entry_point: shader_entry_point,
86 };
87
88 inner.shader = Some(ComputeShaderBinding::Intermediate(shader_binding));
89 }
90 None => {
91 inner.shader = None;
92 }
93 }
94 }
95
96 pub fn set_pipeline(&mut self, pipeline: Option<&ComputePipeline>) {
97 let mut inner = self.inner.borrow_mut();
98
99 match pipeline {
100 Some(pipeline) => {
101 inner.shader = Some(ComputeShaderBinding::Pipeline(pipeline.clone()));
102 }
103 None => {
104 inner.shader = None;
105 }
106 }
107 }
108
109 #[cfg(not(target_arch = "wasm32"))]
110 pub fn set_push_constants(&mut self, push_constant: Option<&[u8]>) {
111 let mut inner = self.inner.borrow_mut();
112
113 match push_constant {
114 Some(push_constant) => {
115 let mut push_constant = push_constant.to_vec();
116 if push_constant.len() % 4 != 0 {
117 push_constant.resize(push_constant.len() + (4 - push_constant.len() % 4), 0);
118 }
119
120 #[cfg(any(debug_assertions, feature = "enable-release-validation"))]
121 {
122 if inner.shader.is_none() {
123 panic!("Shader must be set before setting push constants");
124 }
125
126 let size = {
127 let shader_reflection = inner.reflection.as_ref().unwrap();
128
129 match &shader_reflection {
130 ShaderReflect::Compute { bindings, .. } => bindings
131 .iter()
132 .find_map(|binding| {
133 if let ShaderBindingType::PushConstant(size) = binding.ty {
134 Some(size)
135 } else {
136 None
137 }
138 })
139 .unwrap_or(0),
140 _ => panic!("Shader is not a compute shader"),
141 }
142 };
143
144 if size == 0 {
145 panic!("No push constant found in shader");
146 }
147
148 if push_constant.len() > size as usize {
149 panic!("Push constant size is too large");
150 }
151 }
152
153 inner.push_constant = Some(push_constant);
154 }
155 None => {
156 inner.push_constant = None;
157 }
158 }
159 }
160
161 pub fn set_attachment_buffer(&mut self, group: u32, binding: u32, attachment: Option<&Buffer>) {
162 match attachment {
163 Some(attachment) => {
164 let buffer = attachment.inner.borrow().buffer.clone();
165
166 self.insert_or_replace_attachment(
167 group,
168 binding,
169 BindGroupAttachment {
170 group,
171 binding,
172 attachment: BindGroupType::Storage(buffer),
173 },
174 );
175 }
176 None => {
177 self.remove_attachment(group, binding);
178 }
179 }
180 }
181
182 pub fn set_attachment_buffer_raw<T>(
183 &mut self,
184 group: u32,
185 binding: u32,
186 attachment: Option<&[T]>,
187 usages: BufferUsage,
188 ) where
189 T: bytemuck::Pod + bytemuck::Zeroable,
190 {
191 match attachment {
192 Some(attachment) => {
193 let buffer = self
194 .graphics
195 .borrow_mut()
196 .create_buffer_with(attachment, usages.into());
197
198 self.insert_or_replace_attachment(
199 group,
200 binding,
201 BindGroupAttachment {
202 group,
203 binding,
204 attachment: BindGroupType::Storage(buffer),
205 },
206 );
207 }
208 None => {
209 self.remove_attachment(group, binding);
210 }
211 }
212 }
213
214 pub(crate) fn remove_attachment(&mut self, group: u32, binding: u32) {
215 let mut inner = self.inner.borrow_mut();
216
217 inner
218 .attachments
219 .retain(|a| a.group != group || a.binding != binding);
220 }
221
    /// Registers `attachment` at (`group`, `binding`), replacing any existing
    /// attachment at the same location.
    ///
    /// In validation builds this first checks that a (non-pipeline) shader is
    /// bound, that the shader's reflection declares this (group, binding),
    /// and that the attachment's type matches the declared binding type.
    ///
    /// # Panics
    /// In validation builds, on any of the checks described above.
    pub(crate) fn insert_or_replace_attachment(
        &mut self,
        group: u32,
        binding: u32,
        attachment: BindGroupAttachment,
    ) {
        let mut inner = self.inner.borrow_mut();

        #[cfg(any(debug_assertions, feature = "enable-release-validation"))]
        {
            if inner.shader.is_none() {
                panic!("Shader is not set");
            }

            // Prebuilt pipelines carry their own bind groups; per-attachment
            // setters only make sense for the intermediate shader path.
            match &inner.shader {
                Some(ComputeShaderBinding::Pipeline(_)) => {
                    panic!("Cannot insert or replace attachment when using a pipeline shader");
                }
                _ => {}
            }

            // NOTE(review): this requires `reflection` to have been populated
            // when the shader was bound; if it is still `None` the unwrap
            // panics — confirm the binding path fills it in.
            let reflection = inner.reflection.as_ref().unwrap();

            // Locate the shader's declaration for this (group, binding).
            let r#type = match reflection {
                ShaderReflect::Compute { bindings, .. } => bindings
                    .iter()
                    .find_map(|shaderbinding| {
                        if shaderbinding.group == group && shaderbinding.binding == binding {
                            Some(shaderbinding)
                        } else {
                            None
                        }
                    })
                    .unwrap_or_else(|| {
                        panic!(
                            "Shader does not have binding group: {} binding: {}",
                            group, binding
                        );
                    }),
                _ => panic!("Shader is not a compute shader"),
            };

            // Type-compatibility matrix between the declared shader binding
            // and the attachment actually supplied.
            if !match r#type.ty {
                ShaderBindingType::UniformBuffer(_) => {
                    matches!(attachment.attachment, BindGroupType::Uniform(_))
                }
                ShaderBindingType::StorageBuffer(_, _) => {
                    matches!(attachment.attachment, BindGroupType::Storage(_))
                }
                ShaderBindingType::StorageTexture(_) => {
                    matches!(attachment.attachment, BindGroupType::TextureStorage(_))
                }
                ShaderBindingType::Sampler(_) => {
                    matches!(attachment.attachment, BindGroupType::Sampler(_))
                }
                ShaderBindingType::Texture(_) => {
                    matches!(attachment.attachment, BindGroupType::Texture(_))
                }
                // Push constants travel as uniform-style data here.
                ShaderBindingType::PushConstant(_) => {
                    matches!(attachment.attachment, BindGroupType::Uniform(_))
                }
            } {
                panic!(
                    "Attachment group: {} binding: {} type: {} not match with shader type: {}",
                    group, binding, attachment.attachment, r#type.ty
                );
            }
        }

        // Replace in place when the slot already exists, append otherwise.
        let index = inner
            .attachments
            .iter()
            .position(|a| a.group == group && a.binding == binding);

        if let Some(index) = index {
            inner.attachments[index] = attachment;
        } else {
            inner.attachments.push(attachment);
        }
    }
302
303 pub fn dispatch(&mut self, x: u32, y: u32, z: u32) {
304 #[cfg(any(debug_assertions, feature = "enable-release-validation"))]
305 {
306 let inner = self.inner.borrow();
307
308 if inner.shader.is_none() {
309 panic!("Shader must be set before dispatching");
310 }
311 }
312
313 let (pipeline, bind_group) = self.prepare_pipeline();
314 let mut inner = self.inner.borrow_mut();
315
316 let queue = ComputePassQueue {
317 pipeline,
318 bind_group,
319 ty: DispatchType::Dispatch { x, y, z },
320 push_constant: inner.push_constant.clone(),
321 debug: None,
322 };
323
324 inner.queues.push(queue);
325 }
326
327 pub fn dispatch_indirect(&mut self, buffer: &Buffer, offset: u64) {
328 #[cfg(any(debug_assertions, feature = "enable-release-validation"))]
329 {
330 let inner = self.inner.borrow();
331
332 if inner.shader.is_none() {
333 panic!("Shader must be set before dispatching");
334 }
335 }
336
337 let (pipeline, bind_group) = self.prepare_pipeline();
338 let mut inner = self.inner.borrow_mut();
339
340 let queue = ComputePassQueue {
341 pipeline,
342 bind_group,
343 ty: DispatchType::DispatchIndirect {
344 buffer: buffer.inner.borrow().buffer.clone(),
345 offset,
346 },
347 push_constant: inner.push_constant.clone(),
348 debug: None,
349 };
350
351 inner.queues.push(queue);
352 }
353
    /// Resolves the currently bound shader/pipeline into a concrete
    /// `wgpu::ComputePipeline` plus the `(group index, bind group)` pairs to
    /// attach, creating and caching both through `GPUInner` on first use.
    ///
    /// # Panics
    /// If neither a shader nor a pipeline has been set.
    fn prepare_pipeline(&self) -> (wgpu::ComputePipeline, Vec<(u32, wgpu::BindGroup)>) {
        let inner = self.inner.borrow();

        match &inner.shader {
            Some(ComputeShaderBinding::Intermediate(shader_binding)) => {
                // Bind-group cache key: a constant seed followed by every
                // attachment's (group, binding) and the identity of the
                // attached GPU resource.
                let bind_group_hash_key = {
                    let mut hasher = std::collections::hash_map::DefaultHasher::new();
                    hasher.write_u64(1u64);
                    for attachment in &inner.attachments {
                        attachment.group.hash(&mut hasher);
                        attachment.binding.hash(&mut hasher);

                        match &attachment.attachment {
                            BindGroupType::Uniform(buffer) => {
                                buffer.hash(&mut hasher);
                            }
                            BindGroupType::Storage(buffer) => {
                                buffer.hash(&mut hasher);
                            }
                            BindGroupType::TextureStorage(texture) => {
                                texture.hash(&mut hasher);
                            }
                            BindGroupType::Sampler(sampler) => {
                                sampler.hash(&mut hasher);
                            }
                            BindGroupType::Texture(texture) => {
                                texture.hash(&mut hasher);
                            }
                        }
                    }

                    hasher.finish()
                };

                // Reuse cached bind groups when the attachment set matches;
                // otherwise build wgpu entries grouped by bind-group index.
                let bind_group_attachments = {
                    let mut gpu_inner = self.graphics.borrow_mut();

                    match gpu_inner.get_bind_group(bind_group_hash_key) {
                        Some(bind_group) => bind_group,
                        None => {
                            let mut bind_group_attachments: HashMap<
                                u32,
                                Vec<wgpu::BindGroupEntry>,
                            > = inner.attachments.iter().fold(HashMap::new(), |mut map, e| {
                                let (group, binding, attachment) =
                                    (e.group, e.binding, &e.attachment);

                                let entry = match attachment {
                                    BindGroupType::TextureStorage(texture) => {
                                        wgpu::BindGroupEntry {
                                            binding,
                                            resource: wgpu::BindingResource::TextureView(texture),
                                        }
                                    }
                                    BindGroupType::Storage(buffer) => wgpu::BindGroupEntry {
                                        binding,
                                        resource: wgpu::BindingResource::Buffer(
                                            wgpu::BufferBinding {
                                                buffer,
                                                offset: 0,
                                                size: None,
                                            },
                                        ),
                                    },
                                    // Only storage buffers/textures are wired
                                    // up for compute bind groups here.
                                    _ => panic!("Unsupported bind group type"),
                                };

                                map.entry(group).or_insert_with(Vec::new).push(entry);
                                map
                            });

                            // Order entries by binding index within each group.
                            for entries in bind_group_attachments.values_mut() {
                                entries.sort_by_key(|e| e.binding);
                            }

                            // Pair each group's entries with the shader layout
                            // for that group.
                            // NOTE(review): HashMap iteration order is
                            // unspecified, so groups reach create_bind_group in
                            // arbitrary order — confirm the manager keys its
                            // results by group index rather than position.
                            let bind_group = bind_group_attachments
                                .iter()
                                .map(|(group, entries)| {
                                    let layout = shader_binding
                                        .layout
                                        .iter()
                                        .find(|l| l.group == *group)
                                        .unwrap();

                                    (layout, entries.as_slice())
                                })
                                .collect::<Vec<_>>();

                            let create_info = BindGroupCreateInfo {
                                entries: bind_group,
                            };

                            gpu_inner.create_bind_group(bind_group_hash_key, create_info)
                        }
                    }
                };

                // Pipeline cache key: hash of the intermediate binding
                // (module + layouts + entry point, per its `Hash` derive).
                let pipeline_hash_key = {
                    let mut hasher = std::collections::hash_map::DefaultHasher::new();
                    shader_binding.hash(&mut hasher);

                    hasher.finish()
                };

                // Reuse a cached pipeline or create one from the shader data.
                let pipeline = {
                    let mut gpu_inner = self.graphics.borrow_mut();

                    match gpu_inner.get_compute_pipeline(pipeline_hash_key) {
                        Some(pipeline) => pipeline,
                        None => {
                            let bind_group_layout = shader_binding
                                .layout
                                .iter()
                                .map(|l| l.layout.clone())
                                .collect::<Vec<_>>();

                            let entry_point = shader_binding.entry_point.as_str();

                            let pipeline_desc = ComputePipelineDesc {
                                shader_module: shader_binding.shader.clone(),
                                entry_point: entry_point.to_owned(),
                                bind_group_layout,
                            };

                            gpu_inner.create_compute_pipeline(pipeline_hash_key, pipeline_desc)
                        }
                    }
                };

                (pipeline, bind_group_attachments)
            }
            Some(ComputeShaderBinding::Pipeline(pipeline)) => {
                // Prebuilt pipeline: key the cache on its descriptor hash and
                // reuse the bind groups the pipeline already carries.
                let pipeline_hash_key = {
                    let mut hasher = DefaultHasher::new();
                    pipeline.pipeline_desc.hash(&mut hasher);

                    hasher.finish()
                };

                let wgpu_pipeline = {
                    let mut gpu_inner = self.graphics.borrow_mut();

                    match gpu_inner.get_compute_pipeline(pipeline_hash_key) {
                        Some(pipeline) => pipeline,
                        None => gpu_inner.create_compute_pipeline(
                            pipeline_hash_key,
                            pipeline.pipeline_desc.clone(),
                        ),
                    }
                };

                (wgpu_pipeline, pipeline.bind_group.clone())
            }
            None => {
                panic!("Compute shader or pipeline must be set before dispatching");
            }
        }
    }
516
517 fn end(&mut self) {
518 let mut inner = self.inner.borrow_mut();
519
520 let queues = inner.queues.drain(..).collect::<Vec<_>>();
521 let mut cmd = inner.cmd.borrow_mut();
522
523 let mut cpass = cmd.begin_compute_pass(&wgpu::ComputePassDescriptor {
524 label: Some("Compute Pass"),
525 timestamp_writes: None,
526 });
527
528 for queue in queues {
529 cpass.set_pipeline(&queue.pipeline);
530
531 for (bind_group_index, bind_group) in &queue.bind_group {
532 cpass.set_bind_group(*bind_group_index, bind_group, &[]);
533 }
534
535 if let Some(debug) = &queue.debug {
536 cpass.insert_debug_marker(debug);
537 }
538
539 #[cfg(not(target_arch = "wasm32"))]
540 if let Some(push_constant) = &queue.push_constant {
541 cpass.set_push_constants(0, push_constant);
542 }
543
544 match &queue.ty {
545 DispatchType::Dispatch { x, y, z } => {
546 cpass.dispatch_workgroups(*x, *y, *z);
547 }
548 DispatchType::DispatchIndirect { buffer, offset } => {
549 cpass.dispatch_workgroups_indirect(buffer, *offset);
550 }
551 }
552 }
553
554 inner.atomic_pass.store(false, std::sync::atomic::Ordering::Relaxed);
555 }
556}
557
impl Drop for ComputePass {
    // Flushes queued dispatches via `end()` when the pass goes out of scope,
    // so callers never have to finish the pass explicitly.
    fn drop(&mut self) {
        self.end();
    }
}
563
/// One recorded dispatch: everything `end()` needs to replay it inside the
/// final wgpu compute pass.
#[derive(Clone, Debug)]
pub(crate) struct ComputePassQueue {
    // Pipeline resolved (and cached) by `prepare_pipeline` at record time.
    pub pipeline: wgpu::ComputePipeline,
    // (group index, bind group) pairs to set before dispatching.
    pub bind_group: Vec<(u32, wgpu::BindGroup)>,
    // Direct or indirect dispatch parameters.
    pub ty: DispatchType,
    // Push-constant bytes captured when the dispatch was recorded, if any.
    pub push_constant: Option<Vec<u8>>,

    // Optional debug-marker label inserted before the dispatch.
    pub debug: Option<String>,
}
573
/// Interior state of a `ComputePass`, shared behind an `ArcRef`.
#[derive(Clone, Debug)]
pub(crate) struct ComputePassInner {
    // Command encoder the pass is written into by `end()`.
    pub cmd: ArcRef<wgpu::CommandEncoder>,
    // Currently bound shader or prebuilt pipeline; `None` until one is set.
    pub shader: Option<ComputeShaderBinding>,
    // Flag cleared by `end()` to signal that no pass is in flight anymore.
    pub atomic_pass: Arc<AtomicBool>,

    // Dispatches recorded so far, replayed in order by `end()`.
    pub queues: Vec<ComputePassQueue>,
    // Buffer/texture attachments keyed by (group, binding).
    pub attachments: Vec<BindGroupAttachment>,
    // Pending push-constant bytes applied to subsequent dispatches.
    pub push_constant: Option<Vec<u8>>,

    // Shader reflection used by the validation paths only.
    // NOTE(review): nothing visible in this file assigns it after
    // construction — confirm it is populated when a shader is bound,
    // otherwise the validators that unwrap it will panic.
    #[cfg(any(debug_assertions, feature = "enable-release-validation"))]
    pub reflection: Option<ShaderReflect>,
}
587
/// How a queued dispatch supplies its workgroup counts.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum DispatchType {
    /// Workgroup counts given directly at record time.
    Dispatch { x: u32, y: u32, z: u32 },
    /// Workgroup counts read by the GPU from `buffer` at byte `offset`.
    DispatchIndirect { buffer: wgpu::Buffer, offset: u64 },
}
593
/// Snapshot of a compute shader captured by `set_shader`; its `Hash` derive
/// keys the pipeline cache in `prepare_pipeline`.
#[derive(Clone, Debug, Hash)]
pub(crate) struct IntermediateComputeBinding {
    // Compiled shader module to build the pipeline from.
    pub shader: wgpu::ShaderModule,
    // Bind-group layouts reflected from the shader, one per group.
    pub layout: Vec<BindGroupLayout>,
    // Name of the compute entry point to dispatch.
    pub entry_point: String,
}
600
/// What the pass will dispatch with: shader data that still needs a pipeline
/// built for it, or an already-built pipeline.
#[derive(Clone, Debug)]
pub(crate) enum ComputeShaderBinding {
    /// Shader module + layouts; the pipeline is created lazily on dispatch.
    Intermediate(IntermediateComputeBinding),
    /// Prebuilt pipeline carrying its own bind groups.
    Pipeline(ComputePipeline),
}
606
/// Error produced while constructing a [`ComputePass`].
///
/// Construction currently cannot fail; `None` is a placeholder variant kept
/// so the `Result` return type of `ComputePass::new` stays stable for future
/// fallible setup work.
#[derive(Clone, Debug)]
pub enum ComputePassBuildError {
    /// Placeholder: no error condition.
    None
}

impl std::fmt::Display for ComputePassBuildError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ComputePassBuildError::None => write!(f, "compute pass build error: none"),
        }
    }
}

// Public error types should be usable with `?` / `Box<dyn Error>` consumers.
impl std::error::Error for ComputePassBuildError {}