// astrelis_render/compute.rs

//! Compute pass management with an ergonomic builder pattern.
//!
//! This module provides a `ComputePassBuilder` that mirrors the ergonomics of
//! `RenderPassBuilder` for compute shader operations.

use astrelis_core::profiling::profile_function;

use crate::frame::FrameContext;

/// Builder for creating compute passes.
///
/// Currently only a debug label can be configured.
///
/// # Example
///
/// ```ignore
/// let mut compute_pass = ComputePassBuilder::new()
///     .label("My Compute Pass")
///     .build(&mut frame);
///
/// compute_pass.set_pipeline(&pipeline);
/// compute_pass.set_bind_group(0, &bind_group, &[]);
/// compute_pass.dispatch_workgroups(64, 64, 1);
/// ```
#[derive(Debug, Clone)]
#[must_use = "a ComputePassBuilder does nothing until `build` is called"]
pub struct ComputePassBuilder<'a> {
    /// Optional debug label forwarded to the wgpu compute pass descriptor.
    label: Option<&'a str>,
}

27impl<'a> ComputePassBuilder<'a> {
28    /// Create a new compute pass builder.
29    pub fn new() -> Self {
30        Self { label: None }
31    }
32
33    /// Set a debug label for the compute pass.
34    pub fn label(mut self, label: &'a str) -> Self {
35        self.label = Some(label);
36        self
37    }
38
39    /// Build the compute pass.
40    ///
41    /// This takes the command encoder from the FrameContext and returns it
42    /// when the ComputePass is dropped.
43    pub fn build(self, frame_context: &'a mut FrameContext) -> ComputePass<'a> {
44        let mut encoder = frame_context.encoder.take().unwrap();
45
46        let compute_pass = encoder
47            .begin_compute_pass(&wgpu::ComputePassDescriptor {
48                label: self.label,
49                timestamp_writes: None,
50            })
51            .forget_lifetime();
52
53        ComputePass {
54            context: frame_context,
55            encoder: Some(encoder),
56            pass: Some(compute_pass),
57        }
58    }
59}
60
61impl Default for ComputePassBuilder<'_> {
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
67/// A compute pass wrapper that automatically returns the encoder to the frame context.
68///
69/// This struct mirrors `RenderPass` in its lifecycle management - it takes the
70/// encoder from `FrameContext` and returns it when dropped.
71pub struct ComputePass<'a> {
72    pub(crate) context: &'a mut FrameContext,
73    pub(crate) encoder: Option<wgpu::CommandEncoder>,
74    pub(crate) pass: Option<wgpu::ComputePass<'static>>,
75}
76
77impl<'a> ComputePass<'a> {
78    /// Get the underlying wgpu compute pass.
79    pub fn wgpu_pass(&mut self) -> &mut wgpu::ComputePass<'static> {
80        self.pass.as_mut().unwrap()
81    }
82
83    /// Get raw access to the underlying wgpu compute pass.
84    ///
85    /// This is an alias for [`wgpu_pass()`](Self::wgpu_pass) for consistency with `RenderPass::raw_pass()`.
86    pub fn raw_pass(&mut self) -> &mut wgpu::ComputePass<'static> {
87        self.pass.as_mut().unwrap()
88    }
89
90    /// Get the graphics context.
91    pub fn graphics_context(&self) -> &crate::context::GraphicsContext {
92        &self.context.context
93    }
94
95    /// Set the compute pipeline to use.
96    pub fn set_pipeline(&mut self, pipeline: &'a wgpu::ComputePipeline) {
97        self.wgpu_pass().set_pipeline(pipeline);
98    }
99
100    /// Set a bind group.
101    pub fn set_bind_group(
102        &mut self,
103        index: u32,
104        bind_group: &'a wgpu::BindGroup,
105        offsets: &[u32],
106    ) {
107        self.wgpu_pass().set_bind_group(index, bind_group, offsets);
108    }
109
110    /// Dispatch workgroups.
111    ///
112    /// # Arguments
113    ///
114    /// * `x` - Number of workgroups in the X dimension
115    /// * `y` - Number of workgroups in the Y dimension
116    /// * `z` - Number of workgroups in the Z dimension
117    pub fn dispatch_workgroups(&mut self, x: u32, y: u32, z: u32) {
118        self.wgpu_pass().dispatch_workgroups(x, y, z);
119    }
120
121    /// Dispatch workgroups with a 1D configuration.
122    ///
123    /// Equivalent to `dispatch_workgroups(x, 1, 1)`.
124    pub fn dispatch_workgroups_1d(&mut self, x: u32) {
125        self.dispatch_workgroups(x, 1, 1);
126    }
127
128    /// Dispatch workgroups with a 2D configuration.
129    ///
130    /// Equivalent to `dispatch_workgroups(x, y, 1)`.
131    pub fn dispatch_workgroups_2d(&mut self, x: u32, y: u32) {
132        self.dispatch_workgroups(x, y, 1);
133    }
134
135    /// Dispatch workgroups indirectly from a buffer.
136    ///
137    /// The buffer should contain a `DispatchIndirect` struct:
138    /// ```ignore
139    /// #[repr(C)]
140    /// struct DispatchIndirect {
141    ///     x: u32,
142    ///     y: u32,
143    ///     z: u32,
144    /// }
145    /// ```
146    pub fn dispatch_workgroups_indirect(&mut self, buffer: &'a wgpu::Buffer, offset: u64) {
147        self.wgpu_pass().dispatch_workgroups_indirect(buffer, offset);
148    }
149
150    /// Insert a debug marker.
151    pub fn insert_debug_marker(&mut self, label: &str) {
152        self.wgpu_pass().insert_debug_marker(label);
153    }
154
155    /// Push a debug group.
156    pub fn push_debug_group(&mut self, label: &str) {
157        self.wgpu_pass().push_debug_group(label);
158    }
159
160    /// Pop a debug group.
161    pub fn pop_debug_group(&mut self) {
162        self.wgpu_pass().pop_debug_group();
163    }
164
165    /// Set push constants for the compute shader.
166    ///
167    /// Push constants are a fast way to pass small amounts of data to shaders
168    /// without using uniform buffers. They require the `PUSH_CONSTANTS` feature
169    /// to be enabled on the device.
170    ///
171    /// # Arguments
172    ///
173    /// * `offset` - Byte offset into the push constant range
174    /// * `data` - Data to set (must be `Pod` for safe byte casting)
175    ///
176    /// # Example
177    ///
178    /// ```ignore
179    /// #[repr(C)]
180    /// #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
181    /// struct ComputeConstants {
182    ///     workgroup_count: u32,
183    ///     time: f32,
184    /// }
185    ///
186    /// let constants = ComputeConstants {
187    ///     workgroup_count: 64,
188    ///     time: 1.5,
189    /// };
190    ///
191    /// pass.set_push_constants(0, &constants);
192    /// ```
193    pub fn set_push_constants<T: bytemuck::Pod>(&mut self, offset: u32, data: &T) {
194        self.wgpu_pass()
195            .set_push_constants(offset, bytemuck::bytes_of(data));
196    }
197
198    /// Set push constants from raw bytes.
199    ///
200    /// Use this when you need more control over the data layout.
201    pub fn set_push_constants_raw(&mut self, offset: u32, data: &[u8]) {
202        self.wgpu_pass().set_push_constants(offset, data);
203    }
204
205    /// Finish the compute pass, returning the encoder to the frame context.
206    pub fn finish(self) {
207        drop(self);
208    }
209}
210
211impl Drop for ComputePass<'_> {
212    fn drop(&mut self) {
213        profile_function!();
214
215        // End the compute pass
216        drop(self.pass.take());
217
218        // Return the encoder to the frame context
219        self.context.encoder = self.encoder.take();
220    }
221}
222
223/// Indirect dispatch command.
224///
225/// This matches the layout expected by `wgpu::ComputePass::dispatch_workgroups_indirect`.
226#[repr(C)]
227#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
228pub struct DispatchIndirect {
229    /// Number of workgroups in the X dimension.
230    pub x: u32,
231    /// Number of workgroups in the Y dimension.
232    pub y: u32,
233    /// Number of workgroups in the Z dimension.
234    pub z: u32,
235}
236
237// SAFETY: DispatchIndirect is a repr(C) struct of u32s with no padding
238unsafe impl bytemuck::Pod for DispatchIndirect {}
239unsafe impl bytemuck::Zeroable for DispatchIndirect {}
240
241impl DispatchIndirect {
242    /// Create a new dispatch command.
243    pub const fn new(x: u32, y: u32, z: u32) -> Self {
244        Self { x, y, z }
245    }
246
247    /// Create a 1D dispatch command.
248    pub const fn new_1d(x: u32) -> Self {
249        Self::new(x, 1, 1)
250    }
251
252    /// Create a 2D dispatch command.
253    pub const fn new_2d(x: u32, y: u32) -> Self {
254        Self::new(x, y, 1)
255    }
256
257    /// Size of the command in bytes.
258    pub const fn size() -> u64 {
259        std::mem::size_of::<Self>() as u64
260    }
261}
262
263impl FrameContext {
264    /// Create a compute pass with a label.
265    pub fn compute_pass<'a>(&'a mut self, label: &'a str) -> ComputePass<'a> {
266        ComputePassBuilder::new().label(label).build(self)
267    }
268
269    /// Create a compute pass without a label.
270    pub fn compute_pass_unlabeled(&mut self) -> ComputePass<'_> {
271        ComputePassBuilder::new().build(self)
272    }
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dispatch_indirect_size() {
        // Verify the struct matches wgpu's expected layout
        assert_eq!(DispatchIndirect::size(), 12); // 3 u32s = 12 bytes
    }

    #[test]
    fn test_dispatch_indirect_new() {
        let cmd = DispatchIndirect::new(1, 2, 3);
        assert_eq!(cmd, DispatchIndirect { x: 1, y: 2, z: 3 });
    }

    #[test]
    fn test_dispatch_indirect_default_is_zero() {
        assert_eq!(DispatchIndirect::default(), DispatchIndirect::new(0, 0, 0));
    }

    #[test]
    fn test_dispatch_indirect_1d() {
        let cmd = DispatchIndirect::new_1d(64);
        assert_eq!(cmd.x, 64);
        assert_eq!(cmd.y, 1);
        assert_eq!(cmd.z, 1);
    }

    #[test]
    fn test_dispatch_indirect_2d() {
        let cmd = DispatchIndirect::new_2d(32, 32);
        assert_eq!(cmd.x, 32);
        assert_eq!(cmd.y, 32);
        assert_eq!(cmd.z, 1);
    }

    #[test]
    fn test_dispatch_indirect_byte_layout() {
        // The GPU reads x, y, z as consecutive u32s, so the Pod bytes must appear in
        // exactly that order with no padding.
        let cmd = DispatchIndirect::new(7, 8, 9);
        let expected: Vec<u8> = [7u32, 8, 9]
            .iter()
            .flat_map(|v| v.to_ne_bytes())
            .collect();
        assert_eq!(bytemuck::bytes_of(&cmd), expected.as_slice());
    }

    #[test]
    fn test_compute_pass_builder_label() {
        assert_eq!(ComputePassBuilder::new().label, None);
        assert_eq!(ComputePassBuilder::default().label, None);
        assert_eq!(ComputePassBuilder::new().label("blur").label, Some("blur"));
    }
}