1use crate::{GpuDevice, Result};
7use parking_lot::RwLock;
8use std::collections::HashMap;
9use std::sync::Arc;
10use wgpu::{
11 BindGroupLayout, CommandEncoder, ComputePass, ComputePassDescriptor, ComputePipeline,
12 ComputePipelineDescriptor, PipelineLayoutDescriptor, ShaderModule,
13};
14
15pub struct ComputePipelineHandle {
17 pipeline: ComputePipeline,
18 workgroup_size: (u32, u32, u32),
19 label: String,
20}
21
22impl ComputePipelineHandle {
23 #[must_use]
25 pub fn new(pipeline: ComputePipeline, workgroup_size: (u32, u32, u32), label: String) -> Self {
26 Self {
27 pipeline,
28 workgroup_size,
29 label,
30 }
31 }
32
33 #[must_use]
35 pub fn pipeline(&self) -> &ComputePipeline {
36 &self.pipeline
37 }
38
39 #[must_use]
41 pub fn workgroup_size(&self) -> (u32, u32, u32) {
42 self.workgroup_size
43 }
44
45 #[must_use]
47 pub fn label(&self) -> &str {
48 &self.label
49 }
50}
51
52pub struct ComputePipelineManager {
54 device: Arc<wgpu::Device>,
55 pipelines: RwLock<HashMap<String, Arc<ComputePipelineHandle>>>,
56}
57
58impl ComputePipelineManager {
59 #[must_use]
61 pub fn new(device: &GpuDevice) -> Self {
62 Self {
63 device: Arc::clone(device.device()),
64 pipelines: RwLock::new(HashMap::new()),
65 }
66 }
67
68 #[allow(clippy::too_many_arguments)]
83 pub fn get_or_create(
84 &self,
85 key: &str,
86 label: &str,
87 shader: &ShaderModule,
88 entry_point: &str,
89 bind_group_layout: &BindGroupLayout,
90 workgroup_size: (u32, u32, u32),
91 ) -> Result<Arc<ComputePipelineHandle>> {
92 {
94 let cache = self.pipelines.read();
95 if let Some(pipeline) = cache.get(key) {
96 return Ok(Arc::clone(pipeline));
97 }
98 }
99
100 let pipeline = self.create_pipeline(label, shader, entry_point, bind_group_layout)?;
102 let handle = Arc::new(ComputePipelineHandle::new(
103 pipeline,
104 workgroup_size,
105 label.to_string(),
106 ));
107
108 {
110 let mut cache = self.pipelines.write();
111 cache.insert(key.to_string(), Arc::clone(&handle));
112 }
113
114 Ok(handle)
115 }
116
117 fn create_pipeline(
119 &self,
120 label: &str,
121 shader: &ShaderModule,
122 entry_point: &str,
123 bind_group_layout: &BindGroupLayout,
124 ) -> Result<ComputePipeline> {
125 let pipeline_layout = self
126 .device
127 .create_pipeline_layout(&PipelineLayoutDescriptor {
128 label: Some(&format!("{label} Layout")),
129 bind_group_layouts: &[Some(bind_group_layout)],
130 immediate_size: 0,
131 });
132
133 Ok(self
134 .device
135 .create_compute_pipeline(&ComputePipelineDescriptor {
136 label: Some(label),
137 layout: Some(&pipeline_layout),
138 module: shader,
139 entry_point: Some(entry_point),
140 cache: None,
141 compilation_options: Default::default(),
142 }))
143 }
144
145 pub fn clear_cache(&self) {
147 let mut cache = self.pipelines.write();
148 cache.clear();
149 }
150
151 pub fn cache_size(&self) -> usize {
153 let cache = self.pipelines.read();
154 cache.len()
155 }
156}
157
158pub struct ComputePassBuilder<'a> {
160 encoder: &'a mut CommandEncoder,
161 label: Option<String>,
162}
163
164impl<'a> ComputePassBuilder<'a> {
165 pub fn new(encoder: &'a mut CommandEncoder) -> Self {
167 Self {
168 encoder,
169 label: None,
170 }
171 }
172
173 pub fn with_label(mut self, label: impl Into<String>) -> Self {
175 self.label = Some(label.into());
176 self
177 }
178
179 pub fn execute<F>(self, f: F)
185 where
186 F: FnOnce(&mut ComputePass<'_>),
187 {
188 let mut pass = self.encoder.begin_compute_pass(&ComputePassDescriptor {
189 label: self.label.as_deref(),
190 timestamp_writes: None,
191 });
192
193 f(&mut pass);
194 }
195}
196
/// Helpers that turn element counts into workgroup dispatch counts.
///
/// Every function rounds up, so a trailing partial workgroup is still dispatched.
pub struct DispatchHelper;

impl DispatchHelper {
    /// Number of workgroups needed to cover `count` items with
    /// `workgroup_size` items per group, i.e. `ceil(count / workgroup_size)`.
    ///
    /// # Panics
    ///
    /// Panics if `workgroup_size` is zero.
    #[must_use]
    pub fn dispatch_1d(count: u32, workgroup_size: u32) -> u32 {
        u32::div_ceil(count, workgroup_size)
    }

    /// Per-axis workgroup counts covering a `width` x `height` domain.
    ///
    /// # Panics
    ///
    /// Panics if either component of `workgroup_size` is zero.
    #[must_use]
    pub fn dispatch_2d(width: u32, height: u32, workgroup_size: (u32, u32)) -> (u32, u32) {
        let (group_w, group_h) = workgroup_size;
        (
            Self::dispatch_1d(width, group_w),
            Self::dispatch_1d(height, group_h),
        )
    }

    /// Per-axis workgroup counts covering a `width` x `height` x `depth` domain.
    ///
    /// # Panics
    ///
    /// Panics if any component of `workgroup_size` is zero.
    #[must_use]
    pub fn dispatch_3d(
        width: u32,
        height: u32,
        depth: u32,
        workgroup_size: (u32, u32, u32),
    ) -> (u32, u32, u32) {
        let (group_w, group_h, group_d) = workgroup_size;
        (
            Self::dispatch_1d(width, group_w),
            Self::dispatch_1d(height, group_h),
            Self::dispatch_1d(depth, group_d),
        )
    }
}
259
260pub struct ComputeExecutor<'a> {
262 device: &'a GpuDevice,
263 encoder: CommandEncoder,
264}
265
266impl<'a> ComputeExecutor<'a> {
267 #[must_use]
269 pub fn new(device: &'a GpuDevice, label: &str) -> Self {
270 let encoder = device
271 .device()
272 .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(label) });
273
274 Self { device, encoder }
275 }
276
277 pub fn begin_pass(&mut self, label: &str) -> ComputePassBuilder<'_> {
279 ComputePassBuilder::new(&mut self.encoder).with_label(label)
280 }
281
282 pub fn submit(self) {
293 let command_buffer = self.encoder.finish();
294 self.device.queue().submit(Some(command_buffer));
295 }
296
297 pub fn encoder_mut(&mut self) -> &mut CommandEncoder {
299 &mut self.encoder
300 }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    // Each case is (count, workgroup_size, expected_groups).
    #[test]
    fn test_dispatch_1d() {
        let cases: [(u32, u32, u32); 4] = [(100, 64, 2), (64, 64, 1), (65, 64, 2), (0, 64, 0)];
        for (count, group, expected) in cases {
            assert_eq!(
                DispatchHelper::dispatch_1d(count, group),
                expected,
                "count={count}, workgroup_size={group}"
            );
        }
    }

    // Each case is (width, height, workgroup_size, expected_groups).
    #[test]
    fn test_dispatch_2d() {
        let cases: [(u32, u32, (u32, u32), (u32, u32)); 3] = [
            (100, 100, (16, 16), (7, 7)),
            (16, 16, (16, 16), (1, 1)),
            (17, 17, (16, 16), (2, 2)),
        ];
        for (width, height, group, expected) in cases {
            assert_eq!(
                DispatchHelper::dispatch_2d(width, height, group),
                expected,
                "size=({width}, {height}), workgroup_size={group:?}"
            );
        }
    }

    // Each case is (width, height, depth, workgroup_size, expected_groups).
    #[test]
    fn test_dispatch_3d() {
        let cases: [(u32, u32, u32, (u32, u32, u32), (u32, u32, u32)); 3] = [
            (100, 100, 100, (8, 8, 8), (13, 13, 13)),
            (8, 8, 8, (8, 8, 8), (1, 1, 1)),
            (9, 9, 9, (8, 8, 8), (2, 2, 2)),
        ];
        for (width, height, depth, group, expected) in cases {
            assert_eq!(
                DispatchHelper::dispatch_3d(width, height, depth, group),
                expected,
                "size=({width}, {height}, {depth}), workgroup_size={group:?}"
            );
        }
    }
}