astrelis_render/compute.rs
1//! Compute pass management with ergonomic builder pattern.
2//!
3//! This module provides a `ComputePassBuilder` that mirrors the ergonomics of
4//! `RenderPassBuilder` for compute shader operations.
5
6use astrelis_core::profiling::profile_function;
7
8use crate::frame::FrameContext;
9
/// Builder for creating compute passes.
///
/// Configure the pass (currently just an optional debug label), then call
/// `build` with the frame whose command encoder the pass should record into.
///
/// # Example
///
/// ```ignore
/// let mut compute_pass = ComputePassBuilder::new()
///     .label("My Compute Pass")
///     .build(&mut frame);
///
/// compute_pass.set_pipeline(&pipeline);
/// compute_pass.set_bind_group(0, &bind_group, &[]);
/// compute_pass.dispatch_workgroups(64, 64, 1);
/// ```
#[derive(Debug, Clone, Copy)]
#[must_use = "a ComputePassBuilder does nothing until `build` is called"]
pub struct ComputePassBuilder<'a> {
    /// Optional debug label, forwarded to `wgpu::ComputePassDescriptor::label`.
    label: Option<&'a str>,
}
26
27impl<'a> ComputePassBuilder<'a> {
28 /// Create a new compute pass builder.
29 pub fn new() -> Self {
30 Self { label: None }
31 }
32
33 /// Set a debug label for the compute pass.
34 pub fn label(mut self, label: &'a str) -> Self {
35 self.label = Some(label);
36 self
37 }
38
39 /// Build the compute pass.
40 ///
41 /// This takes the command encoder from the FrameContext and returns it
42 /// when the ComputePass is dropped.
43 pub fn build(self, frame_context: &'a mut FrameContext) -> ComputePass<'a> {
44 let mut encoder = frame_context.encoder.take().unwrap();
45
46 let compute_pass = encoder
47 .begin_compute_pass(&wgpu::ComputePassDescriptor {
48 label: self.label,
49 timestamp_writes: None,
50 })
51 .forget_lifetime();
52
53 ComputePass {
54 context: frame_context,
55 encoder: Some(encoder),
56 pass: Some(compute_pass),
57 }
58 }
59}
60
61impl Default for ComputePassBuilder<'_> {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
67/// A compute pass wrapper that automatically returns the encoder to the frame context.
68///
69/// This struct mirrors `RenderPass` in its lifecycle management - it takes the
70/// encoder from `FrameContext` and returns it when dropped.
71pub struct ComputePass<'a> {
72 pub(crate) context: &'a mut FrameContext,
73 pub(crate) encoder: Option<wgpu::CommandEncoder>,
74 pub(crate) pass: Option<wgpu::ComputePass<'static>>,
75}
76
77impl<'a> ComputePass<'a> {
78 /// Get the underlying wgpu compute pass.
79 pub fn wgpu_pass(&mut self) -> &mut wgpu::ComputePass<'static> {
80 self.pass.as_mut().unwrap()
81 }
82
83 /// Get raw access to the underlying wgpu compute pass.
84 ///
85 /// This is an alias for [`wgpu_pass()`](Self::wgpu_pass) for consistency with `RenderPass::raw_pass()`.
86 pub fn raw_pass(&mut self) -> &mut wgpu::ComputePass<'static> {
87 self.pass.as_mut().unwrap()
88 }
89
90 /// Get the graphics context.
91 pub fn graphics_context(&self) -> &crate::context::GraphicsContext {
92 &self.context.context
93 }
94
95 /// Set the compute pipeline to use.
96 pub fn set_pipeline(&mut self, pipeline: &'a wgpu::ComputePipeline) {
97 self.wgpu_pass().set_pipeline(pipeline);
98 }
99
100 /// Set a bind group.
101 pub fn set_bind_group(
102 &mut self,
103 index: u32,
104 bind_group: &'a wgpu::BindGroup,
105 offsets: &[u32],
106 ) {
107 self.wgpu_pass().set_bind_group(index, bind_group, offsets);
108 }
109
110 /// Dispatch workgroups.
111 ///
112 /// # Arguments
113 ///
114 /// * `x` - Number of workgroups in the X dimension
115 /// * `y` - Number of workgroups in the Y dimension
116 /// * `z` - Number of workgroups in the Z dimension
117 pub fn dispatch_workgroups(&mut self, x: u32, y: u32, z: u32) {
118 self.wgpu_pass().dispatch_workgroups(x, y, z);
119 }
120
121 /// Dispatch workgroups with a 1D configuration.
122 ///
123 /// Equivalent to `dispatch_workgroups(x, 1, 1)`.
124 pub fn dispatch_workgroups_1d(&mut self, x: u32) {
125 self.dispatch_workgroups(x, 1, 1);
126 }
127
128 /// Dispatch workgroups with a 2D configuration.
129 ///
130 /// Equivalent to `dispatch_workgroups(x, y, 1)`.
131 pub fn dispatch_workgroups_2d(&mut self, x: u32, y: u32) {
132 self.dispatch_workgroups(x, y, 1);
133 }
134
135 /// Dispatch workgroups indirectly from a buffer.
136 ///
137 /// The buffer should contain a `DispatchIndirect` struct:
138 /// ```ignore
139 /// #[repr(C)]
140 /// struct DispatchIndirect {
141 /// x: u32,
142 /// y: u32,
143 /// z: u32,
144 /// }
145 /// ```
146 pub fn dispatch_workgroups_indirect(&mut self, buffer: &'a wgpu::Buffer, offset: u64) {
147 self.wgpu_pass().dispatch_workgroups_indirect(buffer, offset);
148 }
149
150 /// Insert a debug marker.
151 pub fn insert_debug_marker(&mut self, label: &str) {
152 self.wgpu_pass().insert_debug_marker(label);
153 }
154
155 /// Push a debug group.
156 pub fn push_debug_group(&mut self, label: &str) {
157 self.wgpu_pass().push_debug_group(label);
158 }
159
160 /// Pop a debug group.
161 pub fn pop_debug_group(&mut self) {
162 self.wgpu_pass().pop_debug_group();
163 }
164
165 /// Set push constants for the compute shader.
166 ///
167 /// Push constants are a fast way to pass small amounts of data to shaders
168 /// without using uniform buffers. They require the `PUSH_CONSTANTS` feature
169 /// to be enabled on the device.
170 ///
171 /// # Arguments
172 ///
173 /// * `offset` - Byte offset into the push constant range
174 /// * `data` - Data to set (must be `Pod` for safe byte casting)
175 ///
176 /// # Example
177 ///
178 /// ```ignore
179 /// #[repr(C)]
180 /// #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
181 /// struct ComputeConstants {
182 /// workgroup_count: u32,
183 /// time: f32,
184 /// }
185 ///
186 /// let constants = ComputeConstants {
187 /// workgroup_count: 64,
188 /// time: 1.5,
189 /// };
190 ///
191 /// pass.set_push_constants(0, &constants);
192 /// ```
193 pub fn set_push_constants<T: bytemuck::Pod>(&mut self, offset: u32, data: &T) {
194 self.wgpu_pass()
195 .set_push_constants(offset, bytemuck::bytes_of(data));
196 }
197
198 /// Set push constants from raw bytes.
199 ///
200 /// Use this when you need more control over the data layout.
201 pub fn set_push_constants_raw(&mut self, offset: u32, data: &[u8]) {
202 self.wgpu_pass().set_push_constants(offset, data);
203 }
204
205 /// Finish the compute pass, returning the encoder to the frame context.
206 pub fn finish(self) {
207 drop(self);
208 }
209}
210
211impl Drop for ComputePass<'_> {
212 fn drop(&mut self) {
213 profile_function!();
214
215 // End the compute pass
216 drop(self.pass.take());
217
218 // Return the encoder to the frame context
219 self.context.encoder = self.encoder.take();
220 }
221}
222
/// Workgroup counts for an indirect compute dispatch.
///
/// Write one of these into a GPU buffer and pass that buffer to
/// `ComputePass::dispatch_workgroups_indirect`. The `#[repr(C)]` layout of three
/// consecutive `u32`s matches what `wgpu::ComputePass::dispatch_workgroups_indirect`
/// reads.
#[derive(Debug, Default, Clone, Copy, Hash, PartialEq, Eq)]
#[repr(C)]
pub struct DispatchIndirect {
    /// Workgroup count along the X axis.
    pub x: u32,
    /// Workgroup count along the Y axis.
    pub y: u32,
    /// Workgroup count along the Z axis.
    pub z: u32,
}
236
237// SAFETY: DispatchIndirect is a repr(C) struct of u32s with no padding
238unsafe impl bytemuck::Pod for DispatchIndirect {}
239unsafe impl bytemuck::Zeroable for DispatchIndirect {}
240
241impl DispatchIndirect {
242 /// Create a new dispatch command.
243 pub const fn new(x: u32, y: u32, z: u32) -> Self {
244 Self { x, y, z }
245 }
246
247 /// Create a 1D dispatch command.
248 pub const fn new_1d(x: u32) -> Self {
249 Self::new(x, 1, 1)
250 }
251
252 /// Create a 2D dispatch command.
253 pub const fn new_2d(x: u32, y: u32) -> Self {
254 Self::new(x, y, 1)
255 }
256
257 /// Size of the command in bytes.
258 pub const fn size() -> u64 {
259 std::mem::size_of::<Self>() as u64
260 }
261}
262
263impl FrameContext {
264 /// Create a compute pass with a label.
265 pub fn compute_pass<'a>(&'a mut self, label: &'a str) -> ComputePass<'a> {
266 ComputePassBuilder::new().label(label).build(self)
267 }
268
269 /// Create a compute pass without a label.
270 pub fn compute_pass_unlabeled(&mut self) -> ComputePass<'_> {
271 ComputePassBuilder::new().build(self)
272 }
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dispatch_indirect_size() {
        // wgpu expects three tightly packed u32 workgroup counts: 3 * 4 bytes.
        assert_eq!(DispatchIndirect::size(), 12);
    }

    #[test]
    fn test_dispatch_indirect_1d() {
        let DispatchIndirect { x, y, z } = DispatchIndirect::new_1d(64);
        assert_eq!((x, y, z), (64, 1, 1));
    }

    #[test]
    fn test_dispatch_indirect_2d() {
        let DispatchIndirect { x, y, z } = DispatchIndirect::new_2d(32, 32);
        assert_eq!((x, y, z), (32, 32, 1));
    }
}
300}