/// A compiled compute pipeline together with the bind group layouts it was
/// created from, so callers can build matching `wgpu::BindGroup`s.
pub struct ComputePipeline {
    // The underlying wgpu pipeline object; exposed read-only via `raw()`.
    pipeline: wgpu::ComputePipeline,
    // One layout per bind group slot, in slot order (index 0 = @group(0)).
    bind_group_layouts: Vec<wgpu::BindGroupLayout>,
}
28
29impl ComputePipeline {
30 pub fn new(
40 device: &wgpu::Device,
41 wgsl_source: &str,
42 entry_point: &str,
43 buffer_count: u32,
44 ) -> Self {
45 let entries: Vec<wgpu::BindGroupLayoutEntry> = (0..buffer_count)
46 .map(|i| wgpu::BindGroupLayoutEntry {
47 binding: i,
48 visibility: wgpu::ShaderStages::COMPUTE,
49 ty: wgpu::BindingType::Buffer {
50 ty: wgpu::BufferBindingType::Storage { read_only: i > 0 },
51 has_dynamic_offset: false,
52 min_binding_size: None,
53 },
54 count: None,
55 })
56 .collect();
57
58 Self::with_layout(device, wgsl_source, entry_point, &entries)
59 }
60
61 pub fn with_layout(
66 device: &wgpu::Device,
67 wgsl_source: &str,
68 entry_point: &str,
69 entries: &[wgpu::BindGroupLayoutEntry],
70 ) -> Self {
71 Self::with_layouts(device, wgsl_source, entry_point, &[entries])
72 }
73
74 pub fn with_layouts(
83 device: &wgpu::Device,
84 wgsl_source: &str,
85 entry_point: &str,
86 groups: &[&[wgpu::BindGroupLayoutEntry]],
87 ) -> Self {
88 tracing::debug!(
89 entry_point,
90 groups = groups.len(),
91 "creating compute pipeline"
92 );
93 let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
94 label: Some("compute_shader"),
95 source: wgpu::ShaderSource::Wgsl(wgsl_source.into()),
96 });
97
98 let bind_group_layouts: Vec<wgpu::BindGroupLayout> = groups
99 .iter()
100 .enumerate()
101 .map(|(i, entries)| {
102 use std::fmt::Write;
103 let mut label = String::with_capacity(20);
104 let _ = write!(label, "compute_layout_{i}");
105 device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
106 label: Some(&label),
107 entries,
108 })
109 })
110 .collect();
111
112 let layout_refs: Vec<Option<&wgpu::BindGroupLayout>> =
113 bind_group_layouts.iter().map(Some).collect();
114
115 let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
116 label: Some("compute_pipeline_layout"),
117 bind_group_layouts: &layout_refs,
118 immediate_size: 0,
119 });
120
121 let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
122 label: Some("compute_pipeline"),
123 layout: Some(&pipeline_layout),
124 module: &shader,
125 entry_point: Some(entry_point),
126 compilation_options: wgpu::PipelineCompilationOptions::default(),
127 cache: None,
128 });
129
130 Self {
131 pipeline,
132 bind_group_layouts,
133 }
134 }
135
136 #[must_use]
141 #[inline]
142 pub fn bind_group_layout(&self, index: usize) -> Option<&wgpu::BindGroupLayout> {
143 self.bind_group_layouts.get(index)
144 }
145
146 #[must_use]
148 #[inline]
149 pub fn bind_group_layout_count(&self) -> usize {
150 self.bind_group_layouts.len()
151 }
152
153 #[must_use]
155 #[inline]
156 pub fn raw(&self) -> &wgpu::ComputePipeline {
157 &self.pipeline
158 }
159
160 pub fn dispatch(
165 &self,
166 device: &wgpu::Device,
167 queue: &wgpu::Queue,
168 bind_group: &wgpu::BindGroup,
169 workgroups_x: u32,
170 workgroups_y: u32,
171 workgroups_z: u32,
172 ) {
173 self.dispatch_multi(
174 device,
175 queue,
176 &[bind_group],
177 workgroups_x,
178 workgroups_y,
179 workgroups_z,
180 );
181 }
182
183 pub fn dispatch_multi(
188 &self,
189 device: &wgpu::Device,
190 queue: &wgpu::Queue,
191 bind_groups: &[&wgpu::BindGroup],
192 workgroups_x: u32,
193 workgroups_y: u32,
194 workgroups_z: u32,
195 ) {
196 tracing::debug!(
197 workgroups_x,
198 workgroups_y,
199 workgroups_z,
200 groups = bind_groups.len(),
201 "compute dispatch"
202 );
203 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
204 label: Some("compute_encoder"),
205 });
206
207 self.encode_dispatch_multi(
208 &mut encoder,
209 bind_groups,
210 workgroups_x,
211 workgroups_y,
212 workgroups_z,
213 );
214
215 queue.submit(std::iter::once(encoder.finish()));
216 }
217
218 pub fn encode_dispatch(
220 &self,
221 encoder: &mut wgpu::CommandEncoder,
222 bind_group: &wgpu::BindGroup,
223 workgroups_x: u32,
224 workgroups_y: u32,
225 workgroups_z: u32,
226 ) {
227 self.encode_dispatch_multi(
228 encoder,
229 &[bind_group],
230 workgroups_x,
231 workgroups_y,
232 workgroups_z,
233 );
234 }
235
236 pub fn encode_dispatch_multi(
238 &self,
239 encoder: &mut wgpu::CommandEncoder,
240 bind_groups: &[&wgpu::BindGroup],
241 workgroups_x: u32,
242 workgroups_y: u32,
243 workgroups_z: u32,
244 ) {
245 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
246 label: Some("compute_pass"),
247 timestamp_writes: None,
248 });
249 pass.set_pipeline(&self.pipeline);
250 for (i, bg) in bind_groups.iter().enumerate() {
251 pass.set_bind_group(i as u32, *bg, &[]);
252 }
253 pass.dispatch_workgroups(workgroups_x, workgroups_y, workgroups_z);
254 }
255
256 pub fn encode_dispatch_indirect(
261 &self,
262 encoder: &mut wgpu::CommandEncoder,
263 bind_groups: &[&wgpu::BindGroup],
264 indirect_buffer: &wgpu::Buffer,
265 indirect_offset: u64,
266 ) {
267 tracing::debug!(indirect_offset, "compute indirect dispatch");
268 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
269 label: Some("compute_pass_indirect"),
270 timestamp_writes: None,
271 });
272 pass.set_pipeline(&self.pipeline);
273 for (i, bg) in bind_groups.iter().enumerate() {
274 pass.set_bind_group(i as u32, *bg, &[]);
275 }
276 pass.dispatch_workgroups_indirect(indirect_buffer, indirect_offset);
277 }
278}
279
/// A pair of equally-sized storage buffers that alternate roles between
/// source and destination, for iterative compute passes that read one
/// buffer while writing the other.
pub struct PingPongBuffer {
    // The two interchangeable buffers; both created with identical usage.
    buffers: [wgpu::Buffer; 2],
    // Index (0 or 1) of the buffer currently acting as the source.
    current: usize,
}
300
301impl PingPongBuffer {
302 pub fn new(device: &wgpu::Device, size: u64, label: &str) -> Self {
304 tracing::debug!(size, label, "creating ping-pong buffer pair");
305 let buffers = [
306 device.create_buffer(&wgpu::BufferDescriptor {
307 label: Some(&format!("{label}_a")),
308 size,
309 usage: wgpu::BufferUsages::STORAGE
310 | wgpu::BufferUsages::COPY_DST
311 | wgpu::BufferUsages::COPY_SRC,
312 mapped_at_creation: false,
313 }),
314 device.create_buffer(&wgpu::BufferDescriptor {
315 label: Some(&format!("{label}_b")),
316 size,
317 usage: wgpu::BufferUsages::STORAGE
318 | wgpu::BufferUsages::COPY_DST
319 | wgpu::BufferUsages::COPY_SRC,
320 mapped_at_creation: false,
321 }),
322 ];
323 Self {
324 buffers,
325 current: 0,
326 }
327 }
328
329 #[must_use]
331 #[inline]
332 pub fn source(&self) -> &wgpu::Buffer {
333 &self.buffers[self.current]
334 }
335
336 #[must_use]
338 #[inline]
339 pub fn dest(&self) -> &wgpu::Buffer {
340 &self.buffers[1 - self.current]
341 }
342
343 #[inline]
345 pub fn swap(&mut self) {
346 self.current = 1 - self.current;
347 }
348
349 #[must_use]
351 #[inline]
352 pub fn index(&self) -> usize {
353 self.current
354 }
355}
356
357pub fn validate_dispatch(
362 limits: &wgpu::Limits,
363 workgroups_x: u32,
364 workgroups_y: u32,
365 workgroups_z: u32,
366) -> crate::error::Result<()> {
367 use crate::error::GpuError;
368 let max = limits.max_compute_workgroups_per_dimension;
369 if workgroups_x > max {
370 return Err(GpuError::WorkgroupLimitExceeded {
371 axis: "x",
372 actual: workgroups_x,
373 limit: max,
374 });
375 }
376 if workgroups_y > max {
377 return Err(GpuError::WorkgroupLimitExceeded {
378 axis: "y",
379 actual: workgroups_y,
380 limit: max,
381 });
382 }
383 if workgroups_z > max {
384 return Err(GpuError::WorkgroupLimitExceeded {
385 axis: "z",
386 actual: workgroups_z,
387 limit: max,
388 });
389 }
390 Ok(())
391}
392
/// Number of workgroups needed to cover `total` items with workgroups of
/// `workgroup_size` items each (i.e. the ceiling of the division).
///
/// Returns 0 when `total` is 0; panics on a zero `workgroup_size`
/// (division by zero), matching `u32::div_ceil`.
#[must_use]
#[inline]
pub fn workgroups_1d(total: u32, workgroup_size: u32) -> u32 {
    // Explicit quotient/remainder form avoids the overflow that the naive
    // `(total + size - 1) / size` would hit near u32::MAX.
    let full_groups = total / workgroup_size;
    if total % workgroup_size == 0 {
        full_groups
    } else {
        full_groups + 1
    }
}
401
/// Workgroup counts covering a `width` x `height` grid with workgroups of
/// `wg_x` x `wg_y` invocations: the per-axis ceiling divisions, as
/// `(x_groups, y_groups)`.
#[must_use]
#[inline]
pub fn workgroups_2d(width: u32, height: u32, wg_x: u32, wg_y: u32) -> (u32, u32) {
    let groups_x = width.div_ceil(wg_x);
    let groups_y = height.div_ceil(wg_y);
    (groups_x, groups_y)
}
410
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time smoke check that the type exists and is sized.
    #[test]
    fn compute_pipeline_types() {
        let _size = std::mem::size_of::<ComputePipeline>();
    }

    #[test]
    fn workgroups_1d_exact() {
        assert_eq!(workgroups_1d(256, 256), 1);
        assert_eq!(workgroups_1d(512, 256), 2);
    }

    #[test]
    fn workgroups_1d_remainder() {
        // A single leftover item still needs one extra workgroup.
        assert_eq!(workgroups_1d(257, 256), 2);
        assert_eq!(workgroups_1d(1, 256), 1);
    }

    #[test]
    fn workgroups_2d_exact() {
        assert_eq!(workgroups_2d(32, 32, 16, 16), (2, 2));
    }

    #[test]
    fn workgroups_2d_remainder() {
        assert_eq!(workgroups_2d(33, 17, 16, 16), (3, 2));
    }

    #[test]
    fn workgroups_1d_single() {
        // Zero items dispatch zero workgroups.
        assert_eq!(workgroups_1d(1, 64), 1);
        assert_eq!(workgroups_1d(0, 64), 0);
    }

    #[test]
    fn workgroups_2d_single() {
        assert_eq!(workgroups_2d(1, 1, 8, 8), (1, 1));
        assert_eq!(workgroups_2d(0, 0, 8, 8), (0, 0));
    }

    #[test]
    fn validate_dispatch_within_limits() {
        let limits = wgpu::Limits {
            max_compute_workgroups_per_dimension: 65535,
            ..Default::default()
        };
        assert!(validate_dispatch(&limits, 100, 100, 1).is_ok());
        // The limit itself is allowed (check is strictly-greater-than).
        assert!(validate_dispatch(&limits, 65535, 65535, 65535).is_ok());
    }

    #[test]
    fn validate_dispatch_exceeds_limits() {
        let limits = wgpu::Limits {
            max_compute_workgroups_per_dimension: 65535,
            ..Default::default()
        };
        // One over the limit on each axis independently must fail.
        assert!(validate_dispatch(&limits, 65536, 1, 1).is_err());
        assert!(validate_dispatch(&limits, 1, 65536, 1).is_err());
        assert!(validate_dispatch(&limits, 1, 1, 65536).is_err());
    }

    #[test]
    fn validate_dispatch_error_contains_axis() {
        let limits = wgpu::Limits {
            max_compute_workgroups_per_dimension: 100,
            ..Default::default()
        };
        // The error's Display output names the offending axis.
        let err = validate_dispatch(&limits, 200, 1, 1).unwrap_err();
        assert!(err.to_string().contains("x"));
        let err = validate_dispatch(&limits, 1, 200, 1).unwrap_err();
        assert!(err.to_string().contains("y"));
    }

    #[test]
    fn workgroups_1d_large() {
        assert_eq!(workgroups_1d(1_000_000, 256), 3907);
        // Near-overflow input: div_ceil must not wrap.
        assert_eq!(workgroups_1d(u32::MAX, 256), 16_777_216);
    }

    // Exercises the 0 <-> 1 toggling arithmetic used by PingPongBuffer::swap
    // without needing a GPU device.
    #[test]
    fn ping_pong_swap() {
        let mut current = 0usize;
        assert_eq!(current, 0);
        assert_eq!(1 - current, 1);
        current = 1 - current;
        assert_eq!(current, 1);
        assert_eq!(1 - current, 0);
        current = 1 - current;
        assert_eq!(current, 0);
    }

    #[test]
    fn ping_pong_types() {
        let _size = std::mem::size_of::<PingPongBuffer>();
    }

    // Best-effort GPU acquisition; returns None (and the gpu_* tests become
    // no-ops) when no adapter is available, e.g. in CI.
    fn try_gpu() -> Option<(wgpu::Device, wgpu::Queue)> {
        let ctx = pollster::block_on(crate::context::GpuContext::new()).ok()?;
        Some((ctx.device, ctx.queue))
    }

    // Doubles each input element; matches ComputePipeline::new's convention
    // of a writable binding 0 and read-only higher bindings.
    const DOUBLE_SHADER: &str = r#"
    @group(0) @binding(0) var<storage, read_write> output: array<f32>;
    @group(0) @binding(1) var<storage, read> input: array<f32>;

    @compute @workgroup_size(64)
    fn main(@builtin(global_invocation_id) id: vec3u) {
        if id.x < arrayLength(&input) {
            output[id.x] = input[id.x] * 2.0;
        }
    }
    "#;

    #[test]
    fn gpu_compute_pipeline_create() {
        let Some((device, _queue)) = try_gpu() else {
            return;
        };
        let pipeline = ComputePipeline::new(&device, DOUBLE_SHADER, "main", 2);
        // `new` always produces exactly one bind group layout.
        assert_eq!(pipeline.bind_group_layout_count(), 1);
        assert!(pipeline.bind_group_layout(0).is_some());
        assert!(pipeline.bind_group_layout(1).is_none());
    }

    #[test]
    fn gpu_compute_dispatch_roundtrip() {
        let Some((device, queue)) = try_gpu() else {
            return;
        };
        let pipeline = ComputePipeline::new(&device, DOUBLE_SHADER, "main", 2);

        let input: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        let input_buf = crate::buffer::create_storage_buffer(
            &device,
            bytemuck::cast_slice(&input),
            "input",
            true,
        );
        let output_buf = crate::buffer::create_storage_buffer_empty(&device, 16, "output", false);

        let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("test_bg"),
            layout: pipeline.bind_group_layout(0).unwrap(),
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: output_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: input_buf.as_entire_binding(),
                },
            ],
        });

        pipeline.dispatch(&device, &queue, &bind_group, 1, 1, 1);

        // Read back and verify each element was doubled on the GPU.
        let result: Vec<f32> =
            crate::buffer::read_buffer_typed(&device, &queue, &output_buf, 4).unwrap();
        assert_eq!(result, vec![2.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn gpu_ping_pong_buffer() {
        let Some((device, _queue)) = try_gpu() else {
            return;
        };
        let mut pp = PingPongBuffer::new(&device, 64, "pp_test");
        assert_eq!(pp.index(), 0);
        // Capture buffer identities by address, then confirm swap exchanges
        // the source/dest roles without recreating buffers.
        let src0 = pp.source() as *const _;
        let dst0 = pp.dest() as *const _;
        pp.swap();
        assert_eq!(pp.index(), 1);
        assert_eq!(src0, pp.dest() as *const _);
        assert_eq!(dst0, pp.source() as *const _);
    }
}