//! Compute shader pipeline — general-purpose GPU compute.
//!
//! Wraps `wgpu::ComputePipeline` with bind group layout management and
//! dispatch helpers. Supports single or multiple bind group layouts.

6/// A compute pipeline wrapping `wgpu::ComputePipeline` with bind group management.
7///
8/// The default layout ([`new`](Self::new)) creates storage buffer bindings:
9/// buffer 0 is read-write (output), buffers 1+ are read-only (inputs).
10///
11/// Use [`with_layout`](Self::with_layout) for a single custom bind group,
12/// or [`with_layouts`](Self::with_layouts) for multiple bind groups
13/// (e.g., storage buffers + uniform buffers + textures).
14///
15/// # Examples
16///
17/// ```ignore
18/// use mabda::compute::ComputePipeline;
19///
20/// // 2-buffer pipeline (1 output + 1 input):
21/// let pipeline = ComputePipeline::new(&device, WGSL, "main", 2);
22/// pipeline.dispatch(&device, &queue, &bind_group, [64, 1, 1]);
23/// ```
24pub struct ComputePipeline {
25    pipeline: wgpu::ComputePipeline,
26    bind_group_layouts: Vec<wgpu::BindGroupLayout>,
27}
28
29impl ComputePipeline {
30    /// Create a compute pipeline from WGSL source code.
31    ///
32    /// `entry_point`: the compute shader entry function name.
33    /// `buffer_count`: number of storage buffers in bind group 0 (bindings 0..n).
34    ///
35    /// Buffer 0 is created as read-write (`read_only: false`) and buffers 1+
36    /// are read-only. This matches the common pattern where a single output
37    /// buffer is written by the shader while additional input buffers are
38    /// consumed without modification.
39    pub fn new(
40        device: &wgpu::Device,
41        wgsl_source: &str,
42        entry_point: &str,
43        buffer_count: u32,
44    ) -> Self {
45        let entries: Vec<wgpu::BindGroupLayoutEntry> = (0..buffer_count)
46            .map(|i| wgpu::BindGroupLayoutEntry {
47                binding: i,
48                visibility: wgpu::ShaderStages::COMPUTE,
49                ty: wgpu::BindingType::Buffer {
50                    ty: wgpu::BufferBindingType::Storage { read_only: i > 0 },
51                    has_dynamic_offset: false,
52                    min_binding_size: None,
53                },
54                count: None,
55            })
56            .collect();
57
58        Self::with_layout(device, wgsl_source, entry_point, &entries)
59    }
60
61    /// Create a compute pipeline with a single custom bind group layout.
62    ///
63    /// Use this when you need uniform buffers, mixed read-write patterns,
64    /// or texture bindings alongside storage buffers.
65    pub fn with_layout(
66        device: &wgpu::Device,
67        wgsl_source: &str,
68        entry_point: &str,
69        entries: &[wgpu::BindGroupLayoutEntry],
70    ) -> Self {
71        Self::with_layouts(device, wgsl_source, entry_point, &[entries])
72    }
73
74    /// Create a compute pipeline with multiple bind group layouts.
75    ///
76    /// Each element in `groups` defines the entries for one bind group.
77    /// Group 0 is the first, group 1 the second, etc.
78    ///
79    /// Use this when your shader needs separate bind groups for different
80    /// resource types (e.g., group 0 for storage buffers, group 1 for
81    /// uniform buffers, group 2 for textures).
82    pub fn with_layouts(
83        device: &wgpu::Device,
84        wgsl_source: &str,
85        entry_point: &str,
86        groups: &[&[wgpu::BindGroupLayoutEntry]],
87    ) -> Self {
88        tracing::debug!(
89            entry_point,
90            groups = groups.len(),
91            "creating compute pipeline"
92        );
93        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
94            label: Some("compute_shader"),
95            source: wgpu::ShaderSource::Wgsl(wgsl_source.into()),
96        });
97
98        let bind_group_layouts: Vec<wgpu::BindGroupLayout> = groups
99            .iter()
100            .enumerate()
101            .map(|(i, entries)| {
102                use std::fmt::Write;
103                let mut label = String::with_capacity(20);
104                let _ = write!(label, "compute_layout_{i}");
105                device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
106                    label: Some(&label),
107                    entries,
108                })
109            })
110            .collect();
111
112        let layout_refs: Vec<Option<&wgpu::BindGroupLayout>> =
113            bind_group_layouts.iter().map(Some).collect();
114
115        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
116            label: Some("compute_pipeline_layout"),
117            bind_group_layouts: &layout_refs,
118            immediate_size: 0,
119        });
120
121        let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
122            label: Some("compute_pipeline"),
123            layout: Some(&pipeline_layout),
124            module: &shader,
125            entry_point: Some(entry_point),
126            compilation_options: wgpu::PipelineCompilationOptions::default(),
127            cache: None,
128        });
129
130        Self {
131            pipeline,
132            bind_group_layouts,
133        }
134    }
135
136    /// Get bind group layout by index.
137    ///
138    /// For pipelines created with [`new`](Self::new) or [`with_layout`](Self::with_layout),
139    /// only index 0 is valid.
140    #[must_use]
141    #[inline]
142    pub fn bind_group_layout(&self, index: usize) -> Option<&wgpu::BindGroupLayout> {
143        self.bind_group_layouts.get(index)
144    }
145
146    /// Number of bind group layouts in this pipeline.
147    #[must_use]
148    #[inline]
149    pub fn bind_group_layout_count(&self) -> usize {
150        self.bind_group_layouts.len()
151    }
152
153    /// Get the underlying wgpu compute pipeline.
154    #[must_use]
155    #[inline]
156    pub fn raw(&self) -> &wgpu::ComputePipeline {
157        &self.pipeline
158    }
159
160    /// Dispatch the compute shader with a single bind group.
161    ///
162    /// Creates a command encoder, runs one compute pass, and submits.
163    /// For batched dispatches, use [`encode_dispatch`](Self::encode_dispatch).
164    pub fn dispatch(
165        &self,
166        device: &wgpu::Device,
167        queue: &wgpu::Queue,
168        bind_group: &wgpu::BindGroup,
169        workgroups_x: u32,
170        workgroups_y: u32,
171        workgroups_z: u32,
172    ) {
173        self.dispatch_multi(
174            device,
175            queue,
176            &[bind_group],
177            workgroups_x,
178            workgroups_y,
179            workgroups_z,
180        );
181    }
182
183    /// Dispatch the compute shader with multiple bind groups.
184    ///
185    /// Creates a command encoder, runs one compute pass, and submits.
186    /// Each bind group is set at its corresponding index (0, 1, 2, ...).
187    pub fn dispatch_multi(
188        &self,
189        device: &wgpu::Device,
190        queue: &wgpu::Queue,
191        bind_groups: &[&wgpu::BindGroup],
192        workgroups_x: u32,
193        workgroups_y: u32,
194        workgroups_z: u32,
195    ) {
196        tracing::debug!(
197            workgroups_x,
198            workgroups_y,
199            workgroups_z,
200            groups = bind_groups.len(),
201            "compute dispatch"
202        );
203        let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
204            label: Some("compute_encoder"),
205        });
206
207        self.encode_dispatch_multi(
208            &mut encoder,
209            bind_groups,
210            workgroups_x,
211            workgroups_y,
212            workgroups_z,
213        );
214
215        queue.submit(std::iter::once(encoder.finish()));
216    }
217
218    /// Encode a compute dispatch with a single bind group into an existing encoder.
219    pub fn encode_dispatch(
220        &self,
221        encoder: &mut wgpu::CommandEncoder,
222        bind_group: &wgpu::BindGroup,
223        workgroups_x: u32,
224        workgroups_y: u32,
225        workgroups_z: u32,
226    ) {
227        self.encode_dispatch_multi(
228            encoder,
229            &[bind_group],
230            workgroups_x,
231            workgroups_y,
232            workgroups_z,
233        );
234    }
235
236    /// Encode a compute dispatch with multiple bind groups into an existing encoder.
237    pub fn encode_dispatch_multi(
238        &self,
239        encoder: &mut wgpu::CommandEncoder,
240        bind_groups: &[&wgpu::BindGroup],
241        workgroups_x: u32,
242        workgroups_y: u32,
243        workgroups_z: u32,
244    ) {
245        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
246            label: Some("compute_pass"),
247            timestamp_writes: None,
248        });
249        pass.set_pipeline(&self.pipeline);
250        for (i, bg) in bind_groups.iter().enumerate() {
251            pass.set_bind_group(i as u32, *bg, &[]);
252        }
253        pass.dispatch_workgroups(workgroups_x, workgroups_y, workgroups_z);
254    }
255
256    /// Encode an indirect compute dispatch into an existing encoder.
257    ///
258    /// The `indirect_buffer` must contain a `DispatchIndirect` struct
259    /// (3 × u32: workgroups_x, workgroups_y, workgroups_z) at `indirect_offset`.
260    pub fn encode_dispatch_indirect(
261        &self,
262        encoder: &mut wgpu::CommandEncoder,
263        bind_groups: &[&wgpu::BindGroup],
264        indirect_buffer: &wgpu::Buffer,
265        indirect_offset: u64,
266    ) {
267        tracing::debug!(indirect_offset, "compute indirect dispatch");
268        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
269            label: Some("compute_pass_indirect"),
270            timestamp_writes: None,
271        });
272        pass.set_pipeline(&self.pipeline);
273        for (i, bg) in bind_groups.iter().enumerate() {
274            pass.set_bind_group(i as u32, *bg, &[]);
275        }
276        pass.dispatch_workgroups_indirect(indirect_buffer, indirect_offset);
277    }
278}
279
280/// A double-buffer pair for iterative compute patterns (ping-pong).
281///
282/// Common in FDTD simulation, iterative blur, fluid simulation, and any
283/// algorithm that reads from one buffer and writes to another, then swaps.
284///
285/// # Examples
286///
287/// ```ignore
288/// use mabda::compute::PingPongBuffer;
289///
290/// let mut pp = PingPongBuffer::new(&device, &initial_data, "simulation");
291/// for _ in 0..100 {
292///     pipeline.dispatch_with(pp.read(), pp.write(), workgroups);
293///     pp.swap();
294/// }
295/// ```
296pub struct PingPongBuffer {
297    buffers: [wgpu::Buffer; 2],
298    current: usize,
299}
300
301impl PingPongBuffer {
302    /// Create a ping-pong buffer pair, each with `size` bytes.
303    pub fn new(device: &wgpu::Device, size: u64, label: &str) -> Self {
304        tracing::debug!(size, label, "creating ping-pong buffer pair");
305        let buffers = [
306            device.create_buffer(&wgpu::BufferDescriptor {
307                label: Some(&format!("{label}_a")),
308                size,
309                usage: wgpu::BufferUsages::STORAGE
310                    | wgpu::BufferUsages::COPY_DST
311                    | wgpu::BufferUsages::COPY_SRC,
312                mapped_at_creation: false,
313            }),
314            device.create_buffer(&wgpu::BufferDescriptor {
315                label: Some(&format!("{label}_b")),
316                size,
317                usage: wgpu::BufferUsages::STORAGE
318                    | wgpu::BufferUsages::COPY_DST
319                    | wgpu::BufferUsages::COPY_SRC,
320                mapped_at_creation: false,
321            }),
322        ];
323        Self {
324            buffers,
325            current: 0,
326        }
327    }
328
329    /// The buffer to read from (current source).
330    #[must_use]
331    #[inline]
332    pub fn source(&self) -> &wgpu::Buffer {
333        &self.buffers[self.current]
334    }
335
336    /// The buffer to write to (current destination).
337    #[must_use]
338    #[inline]
339    pub fn dest(&self) -> &wgpu::Buffer {
340        &self.buffers[1 - self.current]
341    }
342
343    /// Swap source and destination buffers.
344    #[inline]
345    pub fn swap(&mut self) {
346        self.current = 1 - self.current;
347    }
348
349    /// Current iteration index (0 or 1).
350    #[must_use]
351    #[inline]
352    pub fn index(&self) -> usize {
353        self.current
354    }
355}
356
357/// Validate workgroup counts against device limits.
358///
359/// Returns `Err(GpuError::WorkgroupLimitExceeded)` if any dimension exceeds
360/// `max_compute_workgroups_per_dimension`.
361pub fn validate_dispatch(
362    limits: &wgpu::Limits,
363    workgroups_x: u32,
364    workgroups_y: u32,
365    workgroups_z: u32,
366) -> crate::error::Result<()> {
367    use crate::error::GpuError;
368    let max = limits.max_compute_workgroups_per_dimension;
369    if workgroups_x > max {
370        return Err(GpuError::WorkgroupLimitExceeded {
371            axis: "x",
372            actual: workgroups_x,
373            limit: max,
374        });
375    }
376    if workgroups_y > max {
377        return Err(GpuError::WorkgroupLimitExceeded {
378            axis: "y",
379            actual: workgroups_y,
380            limit: max,
381        });
382    }
383    if workgroups_z > max {
384        return Err(GpuError::WorkgroupLimitExceeded {
385            axis: "z",
386            actual: workgroups_z,
387            limit: max,
388        });
389    }
390    Ok(())
391}
392
/// Calculate workgroup count for a 1D dispatch.
///
/// Returns `ceil(total / workgroup_size)`.
///
/// # Panics
///
/// Panics if `workgroup_size` is 0 (division by zero in `div_ceil`).
#[must_use]
#[inline]
pub fn workgroups_1d(total: u32, workgroup_size: u32) -> u32 {
    total.div_ceil(workgroup_size)
}
401
/// Calculate workgroup counts for a 2D dispatch.
///
/// Returns `(ceil(width / wg_x), ceil(height / wg_y))`.
#[must_use]
#[inline]
pub fn workgroups_2d(width: u32, height: u32, wg_x: u32, wg_y: u32) -> (u32, u32) {
    let groups_x = width.div_ceil(wg_x);
    let groups_y = height.div_ceil(wg_y);
    (groups_x, groups_y)
}
410
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: the type exists and has a computable size (no GPU needed).
    #[test]
    fn compute_pipeline_types() {
        let _size = std::mem::size_of::<ComputePipeline>();
    }

    #[test]
    fn workgroups_1d_exact() {
        assert_eq!(workgroups_1d(256, 256), 1);
        assert_eq!(workgroups_1d(512, 256), 2);
    }

    #[test]
    fn workgroups_1d_remainder() {
        assert_eq!(workgroups_1d(257, 256), 2);
        assert_eq!(workgroups_1d(1, 256), 1);
    }

    #[test]
    fn workgroups_2d_exact() {
        assert_eq!(workgroups_2d(32, 32, 16, 16), (2, 2));
    }

    #[test]
    fn workgroups_2d_remainder() {
        assert_eq!(workgroups_2d(33, 17, 16, 16), (3, 2));
    }

    // Edge cases: a single element needs one workgroup; zero elements need none.
    #[test]
    fn workgroups_1d_single() {
        assert_eq!(workgroups_1d(1, 64), 1);
        assert_eq!(workgroups_1d(0, 64), 0);
    }

    #[test]
    fn workgroups_2d_single() {
        assert_eq!(workgroups_2d(1, 1, 8, 8), (1, 1));
        assert_eq!(workgroups_2d(0, 0, 8, 8), (0, 0));
    }

    #[test]
    fn validate_dispatch_within_limits() {
        let limits = wgpu::Limits {
            max_compute_workgroups_per_dimension: 65535,
            ..Default::default()
        };
        assert!(validate_dispatch(&limits, 100, 100, 1).is_ok());
        // The limit itself is inclusive.
        assert!(validate_dispatch(&limits, 65535, 65535, 65535).is_ok());
    }

    #[test]
    fn validate_dispatch_exceeds_limits() {
        let limits = wgpu::Limits {
            max_compute_workgroups_per_dimension: 65535,
            ..Default::default()
        };
        // Each axis is checked independently.
        assert!(validate_dispatch(&limits, 65536, 1, 1).is_err());
        assert!(validate_dispatch(&limits, 1, 65536, 1).is_err());
        assert!(validate_dispatch(&limits, 1, 1, 65536).is_err());
    }

    // NOTE(review): `contains("x")` is a weak check — it could match incidental
    // letters in the error message (e.g. "exceeded"); verify the Display format
    // of GpuError::WorkgroupLimitExceeded actually names the axis.
    #[test]
    fn validate_dispatch_error_contains_axis() {
        let limits = wgpu::Limits {
            max_compute_workgroups_per_dimension: 100,
            ..Default::default()
        };
        let err = validate_dispatch(&limits, 200, 1, 1).unwrap_err();
        assert!(err.to_string().contains("x"));
        let err = validate_dispatch(&limits, 1, 200, 1).unwrap_err();
        assert!(err.to_string().contains("y"));
    }

    #[test]
    fn workgroups_1d_large() {
        assert_eq!(workgroups_1d(1_000_000, 256), 3907);
        // div_ceil must not overflow even when the rounded-up product exceeds u32::MAX.
        assert_eq!(workgroups_1d(u32::MAX, 256), 16_777_216);
    }

    #[test]
    fn ping_pong_swap() {
        // Verify swap logic without GPU
        let mut current = 0usize;
        assert_eq!(current, 0);
        assert_eq!(1 - current, 1);
        current = 1 - current;
        assert_eq!(current, 1);
        assert_eq!(1 - current, 0);
        current = 1 - current;
        assert_eq!(current, 0);
    }

    #[test]
    fn ping_pong_types() {
        let _size = std::mem::size_of::<PingPongBuffer>();
    }

    // Acquire a real device/queue, or None if context creation fails
    // (e.g. no GPU adapter available) — GPU tests then silently no-op.
    fn try_gpu() -> Option<(wgpu::Device, wgpu::Queue)> {
        let ctx = pollster::block_on(crate::context::GpuContext::new()).ok()?;
        Some((ctx.device, ctx.queue))
    }

    // WGSL shader matching the default `ComputePipeline::new` layout:
    // binding 0 read-write output, binding 1 read-only input.
    const DOUBLE_SHADER: &str = r#"
        @group(0) @binding(0) var<storage, read_write> output: array<f32>;
        @group(0) @binding(1) var<storage, read> input: array<f32>;

        @compute @workgroup_size(64)
        fn main(@builtin(global_invocation_id) id: vec3u) {
            if id.x < arrayLength(&input) {
                output[id.x] = input[id.x] * 2.0;
            }
        }
    "#;

    #[test]
    fn gpu_compute_pipeline_create() {
        let Some((device, _queue)) = try_gpu() else {
            return;
        };
        let pipeline = ComputePipeline::new(&device, DOUBLE_SHADER, "main", 2);
        assert_eq!(pipeline.bind_group_layout_count(), 1);
        assert!(pipeline.bind_group_layout(0).is_some());
        assert!(pipeline.bind_group_layout(1).is_none());
    }

    #[test]
    fn gpu_compute_dispatch_roundtrip() {
        let Some((device, queue)) = try_gpu() else {
            return;
        };
        let pipeline = ComputePipeline::new(&device, DOUBLE_SHADER, "main", 2);

        let input: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        let input_buf = crate::buffer::create_storage_buffer(
            &device,
            bytemuck::cast_slice(&input),
            "input",
            true,
        );
        // 16 bytes = 4 × f32 output slots.
        let output_buf = crate::buffer::create_storage_buffer_empty(&device, 16, "output", false);

        // Binding order mirrors the default layout: 0 = output, 1 = input.
        let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("test_bg"),
            layout: pipeline.bind_group_layout(0).unwrap(),
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: output_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: input_buf.as_entire_binding(),
                },
            ],
        });

        pipeline.dispatch(&device, &queue, &bind_group, 1, 1, 1);

        let result: Vec<f32> =
            crate::buffer::read_buffer_typed(&device, &queue, &output_buf, 4).unwrap();
        assert_eq!(result, vec![2.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn gpu_ping_pong_buffer() {
        let Some((device, _queue)) = try_gpu() else {
            return;
        };
        let mut pp = PingPongBuffer::new(&device, 64, "pp_test");
        assert_eq!(pp.index(), 0);
        // Compare by address: swap must exchange which buffer each accessor returns.
        let src0 = pp.source() as *const _;
        let dst0 = pp.dest() as *const _;
        pp.swap();
        assert_eq!(pp.index(), 1);
        // After swap, source/dest are swapped
        assert_eq!(src0, pp.dest() as *const _);
        assert_eq!(dst0, pp.source() as *const _);
    }
}