astrelis_render/
compute.rs1use astrelis_core::profiling::profile_function;
7
8use crate::frame::FrameContext;
9
/// Builder for a [`ComputePass`] recorded into a [`FrameContext`]'s command encoder.
///
/// Nothing is recorded until [`ComputePassBuilder::build`] is called.
#[must_use = "a ComputePassBuilder does nothing until `build` is called"]
pub struct ComputePassBuilder<'a> {
    /// Optional debug label forwarded to `wgpu::ComputePassDescriptor::label`.
    label: Option<&'a str>,
}
26
27impl<'a> ComputePassBuilder<'a> {
28 pub fn new() -> Self {
30 Self { label: None }
31 }
32
33 pub fn label(mut self, label: &'a str) -> Self {
35 self.label = Some(label);
36 self
37 }
38
39 pub fn build(self, frame_context: &'a mut FrameContext) -> ComputePass<'a> {
44 let mut encoder = frame_context.encoder.take().unwrap();
45
46 let compute_pass = encoder
47 .begin_compute_pass(&wgpu::ComputePassDescriptor {
48 label: self.label,
49 timestamp_writes: None,
50 })
51 .forget_lifetime();
52
53 ComputePass {
54 context: frame_context,
55 encoder: Some(encoder),
56 pass: Some(compute_pass),
57 }
58 }
59}
60
61impl Default for ComputePassBuilder<'_> {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
67pub struct ComputePass<'a> {
72 pub(crate) context: &'a mut FrameContext,
73 pub(crate) encoder: Option<wgpu::CommandEncoder>,
74 pub(crate) pass: Option<wgpu::ComputePass<'static>>,
75}
76
77impl<'a> ComputePass<'a> {
78 pub fn pass(&mut self) -> &mut wgpu::ComputePass<'static> {
80 self.pass.as_mut().unwrap()
81 }
82
83 pub fn set_pipeline(&mut self, pipeline: &'a wgpu::ComputePipeline) {
85 self.pass().set_pipeline(pipeline);
86 }
87
88 pub fn set_bind_group(
90 &mut self,
91 index: u32,
92 bind_group: &'a wgpu::BindGroup,
93 offsets: &[u32],
94 ) {
95 self.pass().set_bind_group(index, bind_group, offsets);
96 }
97
98 pub fn dispatch_workgroups(&mut self, x: u32, y: u32, z: u32) {
106 self.pass().dispatch_workgroups(x, y, z);
107 }
108
109 pub fn dispatch_workgroups_1d(&mut self, x: u32) {
113 self.dispatch_workgroups(x, 1, 1);
114 }
115
116 pub fn dispatch_workgroups_2d(&mut self, x: u32, y: u32) {
120 self.dispatch_workgroups(x, y, 1);
121 }
122
123 pub fn dispatch_workgroups_indirect(&mut self, buffer: &'a wgpu::Buffer, offset: u64) {
135 self.pass().dispatch_workgroups_indirect(buffer, offset);
136 }
137
138 pub fn insert_debug_marker(&mut self, label: &str) {
140 self.pass().insert_debug_marker(label);
141 }
142
143 pub fn push_debug_group(&mut self, label: &str) {
145 self.pass().push_debug_group(label);
146 }
147
148 pub fn pop_debug_group(&mut self) {
150 self.pass().pop_debug_group();
151 }
152
153 pub fn finish(self) {
155 drop(self);
156 }
157}
158
159impl Drop for ComputePass<'_> {
160 fn drop(&mut self) {
161 profile_function!();
162
163 drop(self.pass.take());
165
166 self.context.encoder = self.encoder.take();
168 }
169}
170
/// Workgroup counts for an indirect compute dispatch.
///
/// `#[repr(C)]` fixes the field order as `x`, `y`, `z`. These are three tightly packed
/// `u32`s, so the value can be uploaded directly into a buffer consumed by
/// `dispatch_workgroups_indirect`.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[repr(C)]
pub struct DispatchIndirect {
    /// Workgroup count along X.
    pub x: u32,
    /// Workgroup count along Y.
    pub y: u32,
    /// Workgroup count along Z.
    pub z: u32,
}
184
185unsafe impl bytemuck::Pod for DispatchIndirect {}
187unsafe impl bytemuck::Zeroable for DispatchIndirect {}
188
189impl DispatchIndirect {
190 pub const fn new(x: u32, y: u32, z: u32) -> Self {
192 Self { x, y, z }
193 }
194
195 pub const fn new_1d(x: u32) -> Self {
197 Self::new(x, 1, 1)
198 }
199
200 pub const fn new_2d(x: u32, y: u32) -> Self {
202 Self::new(x, y, 1)
203 }
204
205 pub const fn size() -> u64 {
207 std::mem::size_of::<Self>() as u64
208 }
209}
210
211pub trait ComputePassExt {
213 fn compute_pass<'a>(&'a mut self, label: &'a str) -> ComputePass<'a>;
215
216 fn compute_pass_unlabeled(&mut self) -> ComputePass<'_>;
218}
219
220impl ComputePassExt for FrameContext {
221 fn compute_pass<'a>(&'a mut self, label: &'a str) -> ComputePass<'a> {
222 ComputePassBuilder::new().label(label).build(self)
223 }
224
225 fn compute_pass_unlabeled(&mut self) -> ComputePass<'_> {
226 ComputePassBuilder::new().build(self)
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dispatch_indirect_size() {
        assert_eq!(DispatchIndirect::size(), 12);
    }

    #[test]
    fn test_dispatch_indirect_new() {
        let cmd = DispatchIndirect::new(4, 5, 6);
        assert_eq!((cmd.x, cmd.y, cmd.z), (4, 5, 6));
    }

    #[test]
    fn test_dispatch_indirect_1d() {
        let cmd = DispatchIndirect::new_1d(64);
        assert_eq!(cmd.x, 64);
        assert_eq!(cmd.y, 1);
        assert_eq!(cmd.z, 1);
    }

    #[test]
    fn test_dispatch_indirect_2d() {
        let cmd = DispatchIndirect::new_2d(32, 32);
        assert_eq!(cmd.x, 32);
        assert_eq!(cmd.y, 32);
        assert_eq!(cmd.z, 1);
    }

    #[test]
    fn test_dispatch_indirect_default_is_zeroed() {
        let zeroed: DispatchIndirect = bytemuck::Zeroable::zeroed();
        assert_eq!(DispatchIndirect::default(), zeroed);
    }

    /// The unsafe `Pod` impl and indirect dispatch both rely on the fields being laid out
    /// as consecutive `x`, `y`, `z` `u32`s with no padding.
    #[test]
    fn test_dispatch_indirect_byte_layout() {
        let cmd = DispatchIndirect::new(1, 2, 3);
        let bytes = bytemuck::bytes_of(&cmd);
        let expected: Vec<u8> = [1u32, 2, 3]
            .iter()
            .flat_map(|v| v.to_ne_bytes())
            .collect();
        assert_eq!(bytes, expected.as_slice());
    }
}