astrelis_render/
compute.rs

//! Compute pass management with an ergonomic builder pattern.
//!
//! This module provides a `ComputePassBuilder` that mirrors the ergonomics of
//! `RenderPassBuilder` for compute shader operations.
5
6use astrelis_core::profiling::profile_function;
7
8use crate::frame::FrameContext;
9
/// Builder for creating compute passes.
///
/// # Example
///
/// ```ignore
/// let mut compute_pass = ComputePassBuilder::new()
///     .label("My Compute Pass")
///     .build(&mut frame);
///
/// compute_pass.set_pipeline(&pipeline);
/// compute_pass.set_bind_group(0, &bind_group, &[]);
/// compute_pass.dispatch_workgroups(64, 64, 1);
/// ```
///
/// For the common cases, `ComputePassExt::compute_pass` and
/// `ComputePassExt::compute_pass_unlabeled` wrap this builder.
#[derive(Debug, Clone)]
#[must_use = "a ComputePassBuilder does nothing until `build` is called"]
pub struct ComputePassBuilder<'a> {
    /// Optional debug label, forwarded to `wgpu::ComputePassDescriptor::label`.
    label: Option<&'a str>,
}
26
27impl<'a> ComputePassBuilder<'a> {
28    /// Create a new compute pass builder.
29    pub fn new() -> Self {
30        Self { label: None }
31    }
32
33    /// Set a debug label for the compute pass.
34    pub fn label(mut self, label: &'a str) -> Self {
35        self.label = Some(label);
36        self
37    }
38
39    /// Build the compute pass.
40    ///
41    /// This takes the command encoder from the FrameContext and returns it
42    /// when the ComputePass is dropped.
43    pub fn build(self, frame_context: &'a mut FrameContext) -> ComputePass<'a> {
44        let mut encoder = frame_context.encoder.take().unwrap();
45
46        let compute_pass = encoder
47            .begin_compute_pass(&wgpu::ComputePassDescriptor {
48                label: self.label,
49                timestamp_writes: None,
50            })
51            .forget_lifetime();
52
53        ComputePass {
54            context: frame_context,
55            encoder: Some(encoder),
56            pass: Some(compute_pass),
57        }
58    }
59}
60
61impl Default for ComputePassBuilder<'_> {
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
67/// A compute pass wrapper that automatically returns the encoder to the frame context.
68///
69/// This struct mirrors `RenderPass` in its lifecycle management - it takes the
70/// encoder from `FrameContext` and returns it when dropped.
71pub struct ComputePass<'a> {
72    pub(crate) context: &'a mut FrameContext,
73    pub(crate) encoder: Option<wgpu::CommandEncoder>,
74    pub(crate) pass: Option<wgpu::ComputePass<'static>>,
75}
76
77impl<'a> ComputePass<'a> {
78    /// Get the underlying wgpu compute pass.
79    pub fn pass(&mut self) -> &mut wgpu::ComputePass<'static> {
80        self.pass.as_mut().unwrap()
81    }
82
83    /// Set the compute pipeline to use.
84    pub fn set_pipeline(&mut self, pipeline: &'a wgpu::ComputePipeline) {
85        self.pass().set_pipeline(pipeline);
86    }
87
88    /// Set a bind group.
89    pub fn set_bind_group(
90        &mut self,
91        index: u32,
92        bind_group: &'a wgpu::BindGroup,
93        offsets: &[u32],
94    ) {
95        self.pass().set_bind_group(index, bind_group, offsets);
96    }
97
98    /// Dispatch workgroups.
99    ///
100    /// # Arguments
101    ///
102    /// * `x` - Number of workgroups in the X dimension
103    /// * `y` - Number of workgroups in the Y dimension
104    /// * `z` - Number of workgroups in the Z dimension
105    pub fn dispatch_workgroups(&mut self, x: u32, y: u32, z: u32) {
106        self.pass().dispatch_workgroups(x, y, z);
107    }
108
109    /// Dispatch workgroups with a 1D configuration.
110    ///
111    /// Equivalent to `dispatch_workgroups(x, 1, 1)`.
112    pub fn dispatch_workgroups_1d(&mut self, x: u32) {
113        self.dispatch_workgroups(x, 1, 1);
114    }
115
116    /// Dispatch workgroups with a 2D configuration.
117    ///
118    /// Equivalent to `dispatch_workgroups(x, y, 1)`.
119    pub fn dispatch_workgroups_2d(&mut self, x: u32, y: u32) {
120        self.dispatch_workgroups(x, y, 1);
121    }
122
123    /// Dispatch workgroups indirectly from a buffer.
124    ///
125    /// The buffer should contain a `DispatchIndirect` struct:
126    /// ```ignore
127    /// #[repr(C)]
128    /// struct DispatchIndirect {
129    ///     x: u32,
130    ///     y: u32,
131    ///     z: u32,
132    /// }
133    /// ```
134    pub fn dispatch_workgroups_indirect(&mut self, buffer: &'a wgpu::Buffer, offset: u64) {
135        self.pass().dispatch_workgroups_indirect(buffer, offset);
136    }
137
138    /// Insert a debug marker.
139    pub fn insert_debug_marker(&mut self, label: &str) {
140        self.pass().insert_debug_marker(label);
141    }
142
143    /// Push a debug group.
144    pub fn push_debug_group(&mut self, label: &str) {
145        self.pass().push_debug_group(label);
146    }
147
148    /// Pop a debug group.
149    pub fn pop_debug_group(&mut self) {
150        self.pass().pop_debug_group();
151    }
152
153    /// Set push constants for the compute shader.
154    ///
155    /// Push constants are a fast way to pass small amounts of data to shaders
156    /// without using uniform buffers. They require the `PUSH_CONSTANTS` feature
157    /// to be enabled on the device.
158    ///
159    /// # Arguments
160    ///
161    /// * `offset` - Byte offset into the push constant range
162    /// * `data` - Data to set (must be `Pod` for safe byte casting)
163    ///
164    /// # Example
165    ///
166    /// ```ignore
167    /// #[repr(C)]
168    /// #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
169    /// struct ComputeConstants {
170    ///     workgroup_count: u32,
171    ///     time: f32,
172    /// }
173    ///
174    /// let constants = ComputeConstants {
175    ///     workgroup_count: 64,
176    ///     time: 1.5,
177    /// };
178    ///
179    /// pass.set_push_constants(0, &constants);
180    /// ```
181    pub fn set_push_constants<T: bytemuck::Pod>(&mut self, offset: u32, data: &T) {
182        self.pass()
183            .set_push_constants(offset, bytemuck::bytes_of(data));
184    }
185
186    /// Set push constants from raw bytes.
187    ///
188    /// Use this when you need more control over the data layout.
189    pub fn set_push_constants_raw(&mut self, offset: u32, data: &[u8]) {
190        self.pass().set_push_constants(offset, data);
191    }
192
193    /// Finish the compute pass, returning the encoder to the frame context.
194    pub fn finish(self) {
195        drop(self);
196    }
197}
198
199impl Drop for ComputePass<'_> {
200    fn drop(&mut self) {
201        profile_function!();
202
203        // End the compute pass
204        drop(self.pass.take());
205
206        // Return the encoder to the frame context
207        self.context.encoder = self.encoder.take();
208    }
209}
210
/// Arguments for an indirect compute dispatch: one workgroup count per axis.
///
/// The `#[repr(C)]` layout (three consecutive `u32` values: x, y, z) is the one
/// `wgpu::ComputePass::dispatch_workgroups_indirect` reads from its buffer.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct DispatchIndirect {
    /// Workgroup count along the X axis.
    pub x: u32,
    /// Workgroup count along the Y axis.
    pub y: u32,
    /// Workgroup count along the Z axis.
    pub z: u32,
}
224
225// SAFETY: DispatchIndirect is a repr(C) struct of u32s with no padding
226unsafe impl bytemuck::Pod for DispatchIndirect {}
227unsafe impl bytemuck::Zeroable for DispatchIndirect {}
228
229impl DispatchIndirect {
230    /// Create a new dispatch command.
231    pub const fn new(x: u32, y: u32, z: u32) -> Self {
232        Self { x, y, z }
233    }
234
235    /// Create a 1D dispatch command.
236    pub const fn new_1d(x: u32) -> Self {
237        Self::new(x, 1, 1)
238    }
239
240    /// Create a 2D dispatch command.
241    pub const fn new_2d(x: u32, y: u32) -> Self {
242        Self::new(x, y, 1)
243    }
244
245    /// Size of the command in bytes.
246    pub const fn size() -> u64 {
247        std::mem::size_of::<Self>() as u64
248    }
249}
250
251/// Helper trait for creating compute passes from FrameContext.
252pub trait ComputePassExt {
253    /// Create a compute pass with a label.
254    fn compute_pass<'a>(&'a mut self, label: &'a str) -> ComputePass<'a>;
255
256    /// Create a compute pass without a label.
257    fn compute_pass_unlabeled(&mut self) -> ComputePass<'_>;
258}
259
260impl ComputePassExt for FrameContext {
261    fn compute_pass<'a>(&'a mut self, label: &'a str) -> ComputePass<'a> {
262        ComputePassBuilder::new().label(label).build(self)
263    }
264
265    fn compute_pass_unlabeled(&mut self) -> ComputePass<'_> {
266        ComputePassBuilder::new().build(self)
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dispatch_indirect_size() {
        // Verify the struct matches wgpu's expected layout
        assert_eq!(DispatchIndirect::size(), 12); // 3 u32s = 12 bytes
    }

    #[test]
    fn test_dispatch_indirect_new() {
        let cmd = DispatchIndirect::new(7, 8, 9);
        assert_eq!(cmd.x, 7);
        assert_eq!(cmd.y, 8);
        assert_eq!(cmd.z, 9);
    }

    #[test]
    fn test_dispatch_indirect_1d() {
        let cmd = DispatchIndirect::new_1d(64);
        assert_eq!(cmd.x, 64);
        assert_eq!(cmd.y, 1);
        assert_eq!(cmd.z, 1);
    }

    #[test]
    fn test_dispatch_indirect_2d() {
        let cmd = DispatchIndirect::new_2d(32, 32);
        assert_eq!(cmd.x, 32);
        assert_eq!(cmd.y, 32);
        assert_eq!(cmd.z, 1);
    }

    #[test]
    fn test_dispatch_indirect_default_is_zeroed() {
        // `Default` and `Zeroable::zeroed` must agree on the all-zero command.
        let zeroed: DispatchIndirect = bytemuck::Zeroable::zeroed();
        assert_eq!(DispatchIndirect::default(), DispatchIndirect::new(0, 0, 0));
        assert_eq!(zeroed, DispatchIndirect::default());
    }

    #[test]
    fn test_dispatch_indirect_byte_layout() {
        // The `Pod` bytes must be x, y, z in declaration order with no padding.
        let cmd = DispatchIndirect::new(1, 2, 3);
        let mut expected: Vec<u8> = Vec::with_capacity(12);
        for value in [1u32, 2, 3].iter() {
            expected.extend_from_slice(&value.to_ne_bytes());
        }
        assert_eq!(bytemuck::bytes_of(&cmd), expected.as_slice());
    }

    #[test]
    fn test_builder_label() {
        assert_eq!(ComputePassBuilder::new().label, None);
        assert_eq!(ComputePassBuilder::default().label, None);
        assert_eq!(ComputePassBuilder::new().label("pass").label, Some("pass"));
    }
}