Skip to main content

astrelis_render/
compute.rs

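//! Compute pass management with an ergonomic builder pattern.
//!
//! This module provides `ComputePassBuilder`, which mirrors the ergonomics of
//! `RenderPassBuilder` for compute shader work, and `ComputePass`, which records
//! into its own command encoder and hands the finished command buffer to the
//! frame when dropped.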
5
6use astrelis_core::profiling::profile_function;
7
8use crate::context::GraphicsContext;
9use crate::frame::Frame;
10
11/// Builder for creating compute passes.
12///
13/// # Example
14///
15/// ```ignore
16/// let mut compute_pass = frame.compute_pass()
17///     .label("My Compute Pass")
18///     .build();
19///
20/// compute_pass.set_pipeline(&pipeline);
21/// compute_pass.set_bind_group(0, &bind_group, &[]);
22/// compute_pass.dispatch_workgroups(64, 64, 1);
23/// ```
24pub struct ComputePassBuilder<'f, 'w> {
25    frame: &'f Frame<'w>,
26    label: Option<String>,
27}
28
29impl<'f, 'w> ComputePassBuilder<'f, 'w> {
30    /// Create a new compute pass builder.
31    pub fn new(frame: &'f Frame<'w>) -> Self {
32        Self { frame, label: None }
33    }
34
35    /// Set a debug label for the compute pass.
36    pub fn label(mut self, label: impl Into<String>) -> Self {
37        self.label = Some(label.into());
38        self
39    }
40
41    /// Build the compute pass.
42    ///
43    /// This creates a new command encoder for this compute pass and returns it
44    /// when the ComputePass is dropped.
45    pub fn build(self) -> ComputePass<'f> {
46        profile_function!();
47
48        let label = self.label.clone();
49        let label_str = label.as_deref();
50
51        // Create encoder for this pass
52        let encoder = self
53            .frame
54            .device()
55            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: label_str });
56
57        let mut encoder = encoder;
58        let compute_pass = encoder
59            .begin_compute_pass(&wgpu::ComputePassDescriptor {
60                label: label_str,
61                timestamp_writes: None,
62            })
63            .forget_lifetime();
64
65        ComputePass {
66            frame: self.frame,
67            encoder: Some(encoder),
68            pass: Some(compute_pass),
69        }
70    }
71}
72
73impl Default for ComputePassBuilder<'_, '_> {
74    fn default() -> Self {
75        unimplemented!("ComputePassBuilder requires a Frame reference")
76    }
77}
78
79/// A compute pass wrapper that automatically returns the command buffer to the frame.
80///
81/// This struct mirrors `RenderPass` in its lifecycle management - it owns its
82/// encoder and pushes the command buffer to the frame when dropped.
83pub struct ComputePass<'f> {
84    /// Reference to the frame (for pushing command buffer on drop).
85    frame: &'f Frame<'f>,
86    /// The command encoder (owned by this pass).
87    pub(crate) encoder: Option<wgpu::CommandEncoder>,
88    /// The active wgpu compute pass.
89    pub(crate) pass: Option<wgpu::ComputePass<'static>>,
90}
91
92impl<'f> ComputePass<'f> {
93    /// Get the underlying wgpu compute pass (mutable).
94    pub fn wgpu_pass(&mut self) -> &mut wgpu::ComputePass<'static> {
95        self.pass.as_mut().expect("ComputePass already consumed")
96    }
97
98    /// Get the underlying wgpu compute pass (immutable).
99    pub fn wgpu_pass_ref(&self) -> &wgpu::ComputePass<'static> {
100        self.pass.as_ref().expect("ComputePass already consumed")
101    }
102
103    /// Get raw access to the underlying wgpu compute pass.
104    ///
105    /// This is an alias for [`wgpu_pass()`](Self::wgpu_pass) for consistency with `RenderPass::raw_pass()`.
106    pub fn raw_pass(&mut self) -> &mut wgpu::ComputePass<'static> {
107        self.pass.as_mut().expect("ComputePass already consumed")
108    }
109
110    /// Get the graphics context.
111    pub fn graphics(&self) -> &GraphicsContext {
112        self.frame.graphics()
113    }
114
115    /// Set the compute pipeline to use.
116    pub fn set_pipeline(&mut self, pipeline: &wgpu::ComputePipeline) {
117        self.wgpu_pass().set_pipeline(pipeline);
118    }
119
120    /// Set a bind group.
121    pub fn set_bind_group(&mut self, index: u32, bind_group: &wgpu::BindGroup, offsets: &[u32]) {
122        self.wgpu_pass().set_bind_group(index, bind_group, offsets);
123    }
124
125    /// Dispatch workgroups.
126    ///
127    /// # Arguments
128    ///
129    /// * `x` - Number of workgroups in the X dimension
130    /// * `y` - Number of workgroups in the Y dimension
131    /// * `z` - Number of workgroups in the Z dimension
132    pub fn dispatch_workgroups(&mut self, x: u32, y: u32, z: u32) {
133        self.wgpu_pass().dispatch_workgroups(x, y, z);
134    }
135
136    /// Dispatch workgroups with a 1D configuration.
137    ///
138    /// Equivalent to `dispatch_workgroups(x, 1, 1)`.
139    pub fn dispatch_workgroups_1d(&mut self, x: u32) {
140        self.dispatch_workgroups(x, 1, 1);
141    }
142
143    /// Dispatch workgroups with a 2D configuration.
144    ///
145    /// Equivalent to `dispatch_workgroups(x, y, 1)`.
146    pub fn dispatch_workgroups_2d(&mut self, x: u32, y: u32) {
147        self.dispatch_workgroups(x, y, 1);
148    }
149
150    /// Dispatch workgroups indirectly from a buffer.
151    ///
152    /// The buffer should contain a `DispatchIndirect` struct:
153    /// ```ignore
154    /// #[repr(C)]
155    /// struct DispatchIndirect {
156    ///     x: u32,
157    ///     y: u32,
158    ///     z: u32,
159    /// }
160    /// ```
161    pub fn dispatch_workgroups_indirect(&mut self, buffer: &wgpu::Buffer, offset: u64) {
162        self.wgpu_pass()
163            .dispatch_workgroups_indirect(buffer, offset);
164    }
165
166    /// Insert a debug marker.
167    pub fn insert_debug_marker(&mut self, label: &str) {
168        self.wgpu_pass().insert_debug_marker(label);
169    }
170
171    /// Push a debug group.
172    pub fn push_debug_group(&mut self, label: &str) {
173        self.wgpu_pass().push_debug_group(label);
174    }
175
176    /// Pop a debug group.
177    pub fn pop_debug_group(&mut self) {
178        self.wgpu_pass().pop_debug_group();
179    }
180
181    /// Set push constants for the compute shader.
182    ///
183    /// Push constants are a fast way to pass small amounts of data to shaders
184    /// without using uniform buffers. They require the `PUSH_CONSTANTS` feature
185    /// to be enabled on the device.
186    ///
187    /// # Arguments
188    ///
189    /// * `offset` - Byte offset into the push constant range
190    /// * `data` - Data to set (must be `Pod` for safe byte casting)
191    ///
192    /// # Example
193    ///
194    /// ```ignore
195    /// #[repr(C)]
196    /// #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
197    /// struct ComputeConstants {
198    ///     workgroup_count: u32,
199    ///     time: f32,
200    /// }
201    ///
202    /// let constants = ComputeConstants {
203    ///     workgroup_count: 64,
204    ///     time: 1.5,
205    /// };
206    ///
207    /// pass.set_push_constants(0, &constants);
208    /// ```
209    pub fn set_push_constants<T: bytemuck::Pod>(&mut self, offset: u32, data: &T) {
210        self.wgpu_pass()
211            .set_push_constants(offset, bytemuck::bytes_of(data));
212    }
213
214    /// Set push constants from raw bytes.
215    ///
216    /// Use this when you need more control over the data layout.
217    pub fn set_push_constants_raw(&mut self, offset: u32, data: &[u8]) {
218        self.wgpu_pass().set_push_constants(offset, data);
219    }
220
221    /// Finish the compute pass, pushing the command buffer to the frame.
222    pub fn finish(self) {
223        drop(self);
224    }
225}
226
227impl Drop for ComputePass<'_> {
228    fn drop(&mut self) {
229        profile_function!();
230
231        // End the compute pass
232        drop(self.pass.take());
233
234        // Finish encoder and push command buffer to frame
235        if let Some(encoder) = self.encoder.take() {
236            let command_buffer = encoder.finish();
237            self.frame.command_buffers.borrow_mut().push(command_buffer);
238        }
239    }
240}
241
/// Arguments for an indirect compute dispatch.
///
/// Three consecutive `u32` workgroup counts, laid out to match what
/// `wgpu::ComputePass::dispatch_workgroups_indirect` reads from its buffer.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct DispatchIndirect {
    /// Workgroup count along the X axis.
    pub x: u32,
    /// Workgroup count along the Y axis.
    pub y: u32,
    /// Workgroup count along the Z axis.
    pub z: u32,
}
255
256// SAFETY: DispatchIndirect is a repr(C) struct of u32s with no padding
257unsafe impl bytemuck::Pod for DispatchIndirect {}
258unsafe impl bytemuck::Zeroable for DispatchIndirect {}
259
260impl DispatchIndirect {
261    /// Create a new dispatch command.
262    pub const fn new(x: u32, y: u32, z: u32) -> Self {
263        Self { x, y, z }
264    }
265
266    /// Create a 1D dispatch command.
267    pub const fn new_1d(x: u32) -> Self {
268        Self::new(x, 1, 1)
269    }
270
271    /// Create a 2D dispatch command.
272    pub const fn new_2d(x: u32, y: u32) -> Self {
273        Self::new(x, y, 1)
274    }
275
276    /// Size of the command in bytes.
277    pub const fn size() -> u64 {
278        std::mem::size_of::<Self>() as u64
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dispatch_indirect_size() {
        // Indirect dispatch arguments are three tightly packed u32s.
        assert_eq!(DispatchIndirect::size(), 12); // 3 u32s = 12 bytes
    }

    #[test]
    fn test_dispatch_indirect_field_order() {
        // Size alone cannot catch reordered fields; the byte layout must be
        // x, y, z in that order to match what the GPU reads.
        let cmd = DispatchIndirect::new(1, 2, 3);
        let words: [u32; 3] = bytemuck::cast(cmd);
        assert_eq!(words, [1, 2, 3]);
    }

    #[test]
    fn test_dispatch_indirect_zero_and_default() {
        let zeroed: DispatchIndirect = bytemuck::Zeroable::zeroed();
        assert_eq!(zeroed, DispatchIndirect::new(0, 0, 0));
        assert_eq!(DispatchIndirect::default(), zeroed);
        assert_eq!(bytemuck::bytes_of(&zeroed), &[0u8; 12]);
    }

    #[test]
    fn test_dispatch_indirect_new() {
        let cmd = DispatchIndirect::new(4, 5, 6);
        assert_eq!(cmd.x, 4);
        assert_eq!(cmd.y, 5);
        assert_eq!(cmd.z, 6);
    }

    #[test]
    fn test_dispatch_indirect_1d() {
        let cmd = DispatchIndirect::new_1d(64);
        assert_eq!(cmd.x, 64);
        assert_eq!(cmd.y, 1);
        assert_eq!(cmd.z, 1);
    }

    #[test]
    fn test_dispatch_indirect_2d() {
        let cmd = DispatchIndirect::new_2d(32, 32);
        assert_eq!(cmd.x, 32);
        assert_eq!(cmd.y, 32);
        assert_eq!(cmd.z, 1);
    }
}