Skip to main content

oximedia_gpu/
compute.rs

1//! Compute pipeline management for GPU operations
2//!
3//! This module provides high-level abstractions for managing compute pipelines,
4//! including pipeline creation, caching, and execution.
5
6use crate::{GpuDevice, Result};
7use parking_lot::RwLock;
8use std::collections::HashMap;
9use std::sync::Arc;
10use wgpu::{
11    BindGroupLayout, CommandEncoder, ComputePass, ComputePassDescriptor, ComputePipeline,
12    ComputePipelineDescriptor, PipelineLayoutDescriptor, ShaderModule,
13};
14
15/// Compute pipeline wrapper with metadata
16pub struct ComputePipelineHandle {
17    pipeline: ComputePipeline,
18    workgroup_size: (u32, u32, u32),
19    label: String,
20}
21
22impl ComputePipelineHandle {
23    /// Create a new compute pipeline handle
24    #[must_use]
25    pub fn new(pipeline: ComputePipeline, workgroup_size: (u32, u32, u32), label: String) -> Self {
26        Self {
27            pipeline,
28            workgroup_size,
29            label,
30        }
31    }
32
33    /// Get the underlying pipeline
34    #[must_use]
35    pub fn pipeline(&self) -> &ComputePipeline {
36        &self.pipeline
37    }
38
39    /// Get the workgroup size
40    #[must_use]
41    pub fn workgroup_size(&self) -> (u32, u32, u32) {
42        self.workgroup_size
43    }
44
45    /// Get the pipeline label
46    #[must_use]
47    pub fn label(&self) -> &str {
48        &self.label
49    }
50}
51
52/// Compute pipeline manager with caching
53pub struct ComputePipelineManager {
54    device: Arc<wgpu::Device>,
55    pipelines: RwLock<HashMap<String, Arc<ComputePipelineHandle>>>,
56}
57
58impl ComputePipelineManager {
59    /// Create a new compute pipeline manager
60    #[must_use]
61    pub fn new(device: &GpuDevice) -> Self {
62        Self {
63            device: Arc::clone(device.device()),
64            pipelines: RwLock::new(HashMap::new()),
65        }
66    }
67
68    /// Get or create a compute pipeline
69    ///
70    /// # Arguments
71    ///
72    /// * `key` - Unique key for caching the pipeline
73    /// * `label` - Human-readable label for debugging
74    /// * `shader` - Compiled shader module
75    /// * `entry_point` - Entry point function name
76    /// * `bind_group_layout` - Bind group layout for resources
77    /// * `workgroup_size` - Workgroup size (x, y, z)
78    ///
79    /// # Errors
80    ///
81    /// Returns an error if pipeline creation fails.
82    #[allow(clippy::too_many_arguments)]
83    pub fn get_or_create(
84        &self,
85        key: &str,
86        label: &str,
87        shader: &ShaderModule,
88        entry_point: &str,
89        bind_group_layout: &BindGroupLayout,
90        workgroup_size: (u32, u32, u32),
91    ) -> Result<Arc<ComputePipelineHandle>> {
92        // Check cache first
93        {
94            let cache = self.pipelines.read();
95            if let Some(pipeline) = cache.get(key) {
96                return Ok(Arc::clone(pipeline));
97            }
98        }
99
100        // Create pipeline
101        let pipeline = self.create_pipeline(label, shader, entry_point, bind_group_layout)?;
102        let handle = Arc::new(ComputePipelineHandle::new(
103            pipeline,
104            workgroup_size,
105            label.to_string(),
106        ));
107
108        // Cache it
109        {
110            let mut cache = self.pipelines.write();
111            cache.insert(key.to_string(), Arc::clone(&handle));
112        }
113
114        Ok(handle)
115    }
116
117    /// Create a new compute pipeline
118    fn create_pipeline(
119        &self,
120        label: &str,
121        shader: &ShaderModule,
122        entry_point: &str,
123        bind_group_layout: &BindGroupLayout,
124    ) -> Result<ComputePipeline> {
125        let pipeline_layout = self
126            .device
127            .create_pipeline_layout(&PipelineLayoutDescriptor {
128                label: Some(&format!("{label} Layout")),
129                bind_group_layouts: &[Some(bind_group_layout)],
130                immediate_size: 0,
131            });
132
133        Ok(self
134            .device
135            .create_compute_pipeline(&ComputePipelineDescriptor {
136                label: Some(label),
137                layout: Some(&pipeline_layout),
138                module: shader,
139                entry_point: Some(entry_point),
140                cache: None,
141                compilation_options: Default::default(),
142            }))
143    }
144
145    /// Clear the pipeline cache
146    pub fn clear_cache(&self) {
147        let mut cache = self.pipelines.write();
148        cache.clear();
149    }
150
151    /// Get number of cached pipelines
152    pub fn cache_size(&self) -> usize {
153        let cache = self.pipelines.read();
154        cache.len()
155    }
156}
157
158/// Compute pass builder for easier command encoding
159pub struct ComputePassBuilder<'a> {
160    encoder: &'a mut CommandEncoder,
161    label: Option<String>,
162}
163
164impl<'a> ComputePassBuilder<'a> {
165    /// Create a new compute pass builder
166    pub fn new(encoder: &'a mut CommandEncoder) -> Self {
167        Self {
168            encoder,
169            label: None,
170        }
171    }
172
173    /// Set the compute pass label
174    pub fn with_label(mut self, label: impl Into<String>) -> Self {
175        self.label = Some(label.into());
176        self
177    }
178
179    /// Begin the compute pass and execute commands
180    ///
181    /// # Arguments
182    ///
183    /// * `f` - Function that configures the compute pass
184    pub fn execute<F>(self, f: F)
185    where
186        F: FnOnce(&mut ComputePass<'_>),
187    {
188        let mut pass = self.encoder.begin_compute_pass(&ComputePassDescriptor {
189            label: self.label.as_deref(),
190            timestamp_writes: None,
191        });
192
193        f(&mut pass);
194    }
195}
196
/// Helpers for turning a workload size into a workgroup dispatch count.
///
/// Every function rounds up, so a partially filled trailing workgroup is still
/// dispatched. That workgroup can therefore contain invocations past the end
/// of the workload.
pub struct DispatchHelper;

impl DispatchHelper {
    /// Returns the number of workgroups needed to cover `count` elements in one dimension.
    ///
    /// Computes `ceil(count / workgroup_size)`; a `count` of zero yields zero.
    ///
    /// # Arguments
    ///
    /// * `count` - Total number of elements
    /// * `workgroup_size` - Workgroup size
    ///
    /// # Panics
    ///
    /// Panics if `workgroup_size` is zero.
    #[must_use]
    pub fn dispatch_1d(count: u32, workgroup_size: u32) -> u32 {
        // `div_ceil` would also panic on zero, but with an opaque
        // "attempt to divide by zero" message; name the real problem instead.
        assert!(workgroup_size != 0, "workgroup_size must be non-zero");
        count.div_ceil(workgroup_size)
    }

    /// Returns the number of workgroups needed to cover a 2D workload.
    ///
    /// # Arguments
    ///
    /// * `width` - Width of the workload
    /// * `height` - Height of the workload
    /// * `workgroup_size` - Workgroup size (x, y)
    ///
    /// # Returns
    ///
    /// Number of workgroups to dispatch (x, y)
    ///
    /// # Panics
    ///
    /// Panics if either workgroup dimension is zero.
    #[must_use]
    pub fn dispatch_2d(width: u32, height: u32, workgroup_size: (u32, u32)) -> (u32, u32) {
        (
            Self::dispatch_1d(width, workgroup_size.0),
            Self::dispatch_1d(height, workgroup_size.1),
        )
    }

    /// Returns the number of workgroups needed to cover a 3D workload.
    ///
    /// # Arguments
    ///
    /// * `width` - Width of the workload
    /// * `height` - Height of the workload
    /// * `depth` - Depth of the workload
    /// * `workgroup_size` - Workgroup size (x, y, z)
    ///
    /// # Returns
    ///
    /// Number of workgroups to dispatch (x, y, z)
    ///
    /// # Panics
    ///
    /// Panics if any workgroup dimension is zero.
    #[must_use]
    pub fn dispatch_3d(
        width: u32,
        height: u32,
        depth: u32,
        workgroup_size: (u32, u32, u32),
    ) -> (u32, u32, u32) {
        (
            Self::dispatch_1d(width, workgroup_size.0),
            Self::dispatch_1d(height, workgroup_size.1),
            Self::dispatch_1d(depth, workgroup_size.2),
        )
    }
}
259
260/// Compute operation executor
261pub struct ComputeExecutor<'a> {
262    device: &'a GpuDevice,
263    encoder: CommandEncoder,
264}
265
266impl<'a> ComputeExecutor<'a> {
267    /// Create a new compute executor
268    #[must_use]
269    pub fn new(device: &'a GpuDevice, label: &str) -> Self {
270        let encoder = device
271            .device()
272            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(label) });
273
274        Self { device, encoder }
275    }
276
277    /// Begin a compute pass
278    pub fn begin_pass(&mut self, label: &str) -> ComputePassBuilder<'_> {
279        ComputePassBuilder::new(&mut self.encoder).with_label(label)
280    }
281
282    // Note: Simple dispatch helper removed due to lifetime complexity.
283    // Use begin_pass() directly for compute dispatches.
284    // Example:
285    // executor.begin_pass("My Compute Pass").execute(|pass| {
286    //     pass.set_pipeline(&pipeline);
287    //     pass.set_bind_group(0, &bind_group, &[]);
288    //     pass.dispatch_workgroups(x, y, z);
289    // });
290
291    /// Finish encoding and submit commands
292    pub fn submit(self) {
293        let command_buffer = self.encoder.finish();
294        self.device.queue().submit(Some(command_buffer));
295    }
296
297    /// Get a mutable reference to the encoder for advanced operations
298    pub fn encoder_mut(&mut self) -> &mut CommandEncoder {
299        &mut self.encoder
300    }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dispatch_1d() {
        assert_eq!(DispatchHelper::dispatch_1d(100, 64), 2);
        assert_eq!(DispatchHelper::dispatch_1d(64, 64), 1);
        assert_eq!(DispatchHelper::dispatch_1d(65, 64), 2);
        assert_eq!(DispatchHelper::dispatch_1d(0, 64), 0);
        // Largest possible workload must not overflow while rounding up.
        assert_eq!(DispatchHelper::dispatch_1d(u32::MAX, 64), 67_108_864);
    }

    #[test]
    fn test_dispatch_2d() {
        assert_eq!(DispatchHelper::dispatch_2d(100, 100, (16, 16)), (7, 7));
        assert_eq!(DispatchHelper::dispatch_2d(16, 16, (16, 16)), (1, 1));
        assert_eq!(DispatchHelper::dispatch_2d(17, 17, (16, 16)), (2, 2));
        // Each axis is rounded independently, including empty axes.
        assert_eq!(DispatchHelper::dispatch_2d(17, 0, (16, 16)), (2, 0));
    }

    #[test]
    fn test_dispatch_3d() {
        assert_eq!(
            DispatchHelper::dispatch_3d(100, 100, 100, (8, 8, 8)),
            (13, 13, 13)
        );
        assert_eq!(DispatchHelper::dispatch_3d(8, 8, 8, (8, 8, 8)), (1, 1, 1));
        assert_eq!(DispatchHelper::dispatch_3d(9, 9, 9, (8, 8, 8)), (2, 2, 2));
        assert_eq!(DispatchHelper::dispatch_3d(0, 9, 1, (8, 8, 8)), (0, 2, 1));
    }

    // A zero workgroup size is invalid; it must panic rather than return a count.
    #[test]
    #[should_panic]
    fn test_dispatch_1d_zero_workgroup_panics() {
        let _ = DispatchHelper::dispatch_1d(10, 0);
    }

    #[test]
    #[should_panic]
    fn test_dispatch_2d_zero_workgroup_panics() {
        let _ = DispatchHelper::dispatch_2d(10, 10, (16, 0));
    }
}