Skip to main content

oximedia_gpu/
compute_pass.rs

//! GPU compute pass management — pass types, buffer bindings, pass queues, and
//! batched dispatch commands.
2
3#![allow(dead_code)]
4#![allow(clippy::cast_precision_loss)]
5
/// The kind of workload a compute pass carries out.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PassType {
    /// Video processing (real-time).
    Video,
    /// Audio processing (real-time).
    Audio,
    /// Still-image processing.
    Image,
    /// Post-processing effects.
    PostProcess,
}

impl PassType {
    /// Reports whether passes of this type run in a real-time context.
    ///
    /// `Video` and `Audio` are real-time; `Image` and `PostProcess` are not.
    /// Every variant is listed explicitly so that adding a new one forces a
    /// decision here instead of silently falling into a catch-all.
    #[must_use]
    pub fn is_real_time(&self) -> bool {
        match self {
            Self::Video | Self::Audio => true,
            Self::Image | Self::PostProcess => false,
        }
    }
}
26
/// Associates a shader binding slot with a buffer of a given size and access mode.
#[derive(Debug, Clone)]
pub struct BufferBinding {
    /// Shader binding slot index.
    pub slot: u8,
    /// Buffer size, in bytes.
    pub size_bytes: u32,
    /// `true` when the binding is read-only, i.e. an input buffer.
    pub read_only: bool,
}

impl BufferBinding {
    /// Builds a binding for `slot` spanning `size_bytes` bytes.
    ///
    /// Pass `read_only = true` for an input and `false` for a writable output.
    #[must_use]
    pub fn new(slot: u8, size_bytes: u32, read_only: bool) -> Self {
        Self {
            read_only,
            slot,
            size_bytes,
        }
    }

    /// Reports whether this binding is an input (read-only).
    #[must_use]
    pub fn is_input(&self) -> bool {
        self.read_only
    }

    /// Reports whether this binding is an output (writable); always the
    /// opposite of [`is_input`](Self::is_input).
    #[must_use]
    pub fn is_output(&self) -> bool {
        !self.is_input()
    }
}
61
62/// A single compute pass with a name, type, buffer bindings, and dispatch dimensions.
63#[derive(Debug)]
64pub struct ComputePass {
65    /// Human-readable name for debugging.
66    pub name: String,
67    /// The category of this pass.
68    pub pass_type: PassType,
69    /// Buffer bindings used by this pass.
70    pub bindings: Vec<BufferBinding>,
71    /// Workgroup dispatch dimensions (x, y, z).
72    pub workgroups: (u32, u32, u32),
73}
74
75impl ComputePass {
76    /// Creates a new `ComputePass` with no bindings and a default dispatch of (1, 1, 1).
77    #[must_use]
78    pub fn new(name: impl Into<String>, pt: PassType) -> Self {
79        Self {
80            name: name.into(),
81            pass_type: pt,
82            bindings: Vec::new(),
83            workgroups: (1, 1, 1),
84        }
85    }
86
87    /// Adds a read-only (input) buffer binding on the given slot.
88    pub fn add_input_binding(&mut self, slot: u8, size: u32) {
89        self.bindings.push(BufferBinding::new(slot, size, true));
90    }
91
92    /// Adds a writable (output) buffer binding on the given slot.
93    pub fn add_output_binding(&mut self, slot: u8, size: u32) {
94        self.bindings.push(BufferBinding::new(slot, size, false));
95    }
96
97    /// Total work items = workgroups.x × workgroups.y × workgroups.z.
98    #[must_use]
99    pub fn total_work_items(&self) -> u64 {
100        u64::from(self.workgroups.0) * u64::from(self.workgroups.1) * u64::from(self.workgroups.2)
101    }
102
103    /// Returns the number of bindings attached to this pass.
104    #[must_use]
105    pub fn binding_count(&self) -> usize {
106        self.bindings.len()
107    }
108}
109
110/// An ordered queue of [`ComputePass`] entries.
111#[derive(Debug, Default)]
112pub struct PassQueue {
113    passes: Vec<ComputePass>,
114}
115
116impl PassQueue {
117    /// Creates an empty `PassQueue`.
118    #[must_use]
119    pub fn new() -> Self {
120        Self::default()
121    }
122
123    /// Appends a pass to the queue.
124    pub fn add(&mut self, pass: ComputePass) {
125        self.passes.push(pass);
126    }
127
128    /// Returns references to all passes whose type matches `pt`.
129    #[must_use]
130    pub fn passes_of_type(&self, pt: &PassType) -> Vec<&ComputePass> {
131        self.passes.iter().filter(|p| &p.pass_type == pt).collect()
132    }
133
134    /// Total number of bindings across all passes.
135    #[must_use]
136    pub fn total_bindings(&self) -> usize {
137        self.passes.iter().map(ComputePass::binding_count).sum()
138    }
139
140    /// Returns the number of passes in the queue.
141    #[must_use]
142    pub fn pass_count(&self) -> usize {
143        self.passes.len()
144    }
145}
146
147// ---------------------------------------------------------------------------
148// BatchedComputePass — batched GPU dispatch management
149// ---------------------------------------------------------------------------
150
/// A single queued GPU compute dispatch command.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DispatchCommand {
    /// Identifier of the compute pipeline this dispatch uses.
    pub pipeline_id: u64,
    /// Bind group index to set before dispatching.
    pub bind_group: u32,
    /// Number of workgroups along the X axis.
    pub dispatch_x: u32,
    /// Number of workgroups along the Y axis.
    pub dispatch_y: u32,
    /// Number of workgroups along the Z axis.
    pub dispatch_z: u32,
}

impl DispatchCommand {
    /// Create a new `DispatchCommand` from a pipeline id, bind group index and
    /// per-axis workgroup counts.
    #[must_use]
    pub fn new(
        pipeline_id: u64,
        bind_group: u32,
        dispatch_x: u32,
        dispatch_y: u32,
        dispatch_z: u32,
    ) -> Self {
        Self {
            pipeline_id,
            bind_group,
            dispatch_x,
            dispatch_y,
            dispatch_z,
        }
    }
}

/// Accumulates compute dispatch commands and hands them out in batches.
///
/// Batching reduces command-encoder overhead by coalescing small dispatches
/// and sorting them by pipeline so that pipeline switches are minimised.
///
/// The batcher never drains itself. When the number of pending commands
/// reaches `max_batch_size`, [`submit`](Self::submit) returns `true` to signal
/// that a batch is ready, and the caller is expected to call
/// [`flush`](Self::flush). Until `flush` is called, further submissions keep
/// accumulating with no hard cap, and every one of them returns `true`.
pub struct BatchedComputePass {
    pending: Vec<DispatchCommand>,
    max_batch_size: usize,
    /// Total number of commands that have been flushed across all batches.
    total_flushed: u64,
}

impl BatchedComputePass {
    /// Create a `BatchedComputePass` with the given `max_batch_size`.
    ///
    /// A `max_batch_size` of 0 is treated as 1, so every submit signals that a
    /// batch is ready.
    #[must_use]
    pub fn new(max_batch_size: usize) -> Self {
        Self {
            pending: Vec::new(),
            max_batch_size: max_batch_size.max(1),
            total_flushed: 0,
        }
    }

    /// Queue a dispatch command.
    ///
    /// Returns `true` when the pending queue has reached (or exceeded)
    /// `max_batch_size`; the caller should then drain it with
    /// [`flush`](Self::flush). Nothing is drained here: ignoring a `true`
    /// result simply lets the queue keep growing.
    pub fn submit(&mut self, cmd: DispatchCommand) -> bool {
        self.pending.push(cmd);
        self.pending.len() >= self.max_batch_size
    }

    /// Drain all pending commands, sorted by `pipeline_id` (ascending) to
    /// minimise pipeline state switches on the GPU.
    ///
    /// The sort is stable: commands sharing a `pipeline_id` keep their
    /// submission order. Commands on different pipelines may be reordered
    /// relative to one another. Callers whose dispatches depend on a specific
    /// cross-pipeline order should not rely on this batch ordering.
    ///
    /// Returns the sorted batch, or an empty `Vec` if nothing is pending.
    pub fn flush(&mut self) -> Vec<DispatchCommand> {
        let mut batch = std::mem::take(&mut self.pending);
        // `sort_by_key` is a stable sort (see the note above).
        batch.sort_by_key(|c| c.pipeline_id);
        // usize -> u64 is lossless on every platform Rust supports (usize <= 64 bits).
        self.total_flushed += batch.len() as u64;
        batch
    }

    /// Number of commands currently pending.
    #[must_use]
    pub fn pending_count(&self) -> usize {
        self.pending.len()
    }

    /// Total commands flushed across all batches.
    #[must_use]
    pub fn total_flushed(&self) -> u64 {
        self.total_flushed
    }

    /// Pending-queue length at which [`submit`](Self::submit) starts returning `true`.
    #[must_use]
    pub fn max_batch_size(&self) -> usize {
        self.max_batch_size
    }
}

impl std::fmt::Debug for BatchedComputePass {
    /// Summarises the batcher; `pending` is reported as a count rather than
    /// listing every queued command.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BatchedComputePass")
            .field("pending", &self.pending.len())
            .field("max_batch_size", &self.max_batch_size)
            .field("total_flushed", &self.total_flushed)
            .finish()
    }
}
272
273// ---------------------------------------------------------------------------
274// Unit tests
275// ---------------------------------------------------------------------------
276
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pass_type_video_is_real_time() {
        assert!(PassType::Video.is_real_time());
    }

    #[test]
    fn test_pass_type_audio_is_real_time() {
        assert!(PassType::Audio.is_real_time());
    }

    #[test]
    fn test_pass_type_image_not_real_time() {
        assert!(!PassType::Image.is_real_time());
    }

    #[test]
    fn test_pass_type_post_process_not_real_time() {
        assert!(!PassType::PostProcess.is_real_time());
    }

    #[test]
    fn test_buffer_binding_input() {
        let b = BufferBinding::new(0, 1024, true);
        assert!(b.is_input());
        assert!(!b.is_output());
    }

    #[test]
    fn test_buffer_binding_output() {
        let b = BufferBinding::new(1, 2048, false);
        assert!(b.is_output());
        assert!(!b.is_input());
    }

    #[test]
    fn test_compute_pass_new_defaults() {
        let pass = ComputePass::new("test", PassType::Image);
        assert_eq!(pass.name, "test");
        assert_eq!(pass.workgroups, (1, 1, 1));
        assert_eq!(pass.binding_count(), 0);
    }

    #[test]
    fn test_compute_pass_add_input_binding() {
        let mut pass = ComputePass::new("p", PassType::Video);
        pass.add_input_binding(0, 512);
        assert_eq!(pass.binding_count(), 1);
        assert!(pass.bindings[0].is_input());
    }

    #[test]
    fn test_compute_pass_add_output_binding() {
        let mut pass = ComputePass::new("p", PassType::Video);
        pass.add_output_binding(1, 512);
        assert_eq!(pass.binding_count(), 1);
        assert!(pass.bindings[0].is_output());
    }

    #[test]
    fn test_total_work_items_1x1x1() {
        let pass = ComputePass::new("p", PassType::Audio);
        assert_eq!(pass.total_work_items(), 1);
    }

    #[test]
    fn test_total_work_items_custom() {
        let mut pass = ComputePass::new("p", PassType::Image);
        pass.workgroups = (4, 8, 2);
        assert_eq!(pass.total_work_items(), 64);
    }

    #[test]
    fn test_total_work_items_large_dimensions() {
        // 65535^3 still fits in u64; guards against accidental u32 arithmetic.
        let mut pass = ComputePass::new("p", PassType::Video);
        pass.workgroups = (65_535, 65_535, 65_535);
        assert_eq!(pass.total_work_items(), 65_535_u64.pow(3));
    }

    #[test]
    fn test_pass_queue_add_and_count() {
        let mut q = PassQueue::new();
        q.add(ComputePass::new("a", PassType::Video));
        q.add(ComputePass::new("b", PassType::Image));
        assert_eq!(q.pass_count(), 2);
    }

    #[test]
    fn test_pass_queue_passes_of_type() {
        let mut q = PassQueue::new();
        q.add(ComputePass::new("v1", PassType::Video));
        q.add(ComputePass::new("i1", PassType::Image));
        q.add(ComputePass::new("v2", PassType::Video));
        let videos = q.passes_of_type(&PassType::Video);
        assert_eq!(videos.len(), 2);
    }

    #[test]
    fn test_pass_queue_passes_of_type_empty_result() {
        let mut q = PassQueue::new();
        q.add(ComputePass::new("a", PassType::Audio));
        let results = q.passes_of_type(&PassType::PostProcess);
        assert!(results.is_empty());
    }

    #[test]
    fn test_pass_queue_total_bindings() {
        let mut q = PassQueue::new();
        let mut p1 = ComputePass::new("p1", PassType::Video);
        p1.add_input_binding(0, 256);
        p1.add_output_binding(1, 256);
        let mut p2 = ComputePass::new("p2", PassType::Image);
        p2.add_input_binding(0, 128);
        q.add(p1);
        q.add(p2);
        assert_eq!(q.total_bindings(), 3);
    }

    #[test]
    fn test_pass_queue_empty() {
        let q = PassQueue::new();
        assert_eq!(q.pass_count(), 0);
        assert_eq!(q.total_bindings(), 0);
    }

    // ── BatchedComputePass tests ──────────────────────────────────────────────

    fn make_cmd(pipeline_id: u64, x: u32) -> DispatchCommand {
        DispatchCommand::new(pipeline_id, 0, x, 1, 1)
    }

    #[test]
    fn test_batched_submit_no_auto_flush_below_limit() {
        let mut b = BatchedComputePass::new(5);
        for i in 0..4u32 {
            let flushed = b.submit(make_cmd(1, i));
            assert!(!flushed, "should not auto-flush below max_batch_size");
        }
        assert_eq!(b.pending_count(), 4);
    }

    #[test]
    fn test_batched_submit_auto_flush_at_limit() {
        let mut b = BatchedComputePass::new(5);
        for i in 0..4u32 {
            b.submit(make_cmd(1, i));
        }
        let triggered = b.submit(make_cmd(1, 4));
        assert!(triggered, "5th submit should signal auto-flush");
    }

    #[test]
    fn test_batched_submit_keeps_signalling_past_limit() {
        let mut b = BatchedComputePass::new(2);
        assert!(!b.submit(make_cmd(1, 0)));
        assert!(b.submit(make_cmd(1, 1)));
        // submit() never drains on its own, so the queue keeps growing and
        // every further submit keeps signalling until flush() is called.
        assert!(b.submit(make_cmd(1, 2)));
        assert_eq!(b.pending_count(), 3);
        assert_eq!(b.total_flushed(), 0);
    }

    #[test]
    fn test_batched_flush_returns_all_pending() {
        let mut b = BatchedComputePass::new(10);
        b.submit(make_cmd(3, 1));
        b.submit(make_cmd(1, 2));
        b.submit(make_cmd(2, 3));
        let batch = b.flush();
        assert_eq!(batch.len(), 3);
        assert_eq!(b.pending_count(), 0);
    }

    #[test]
    fn test_batched_flush_sorts_by_pipeline_id() {
        let mut b = BatchedComputePass::new(100);
        b.submit(make_cmd(5, 0));
        b.submit(make_cmd(1, 0));
        b.submit(make_cmd(3, 0));
        b.submit(make_cmd(2, 0));
        let batch = b.flush();
        let ids: Vec<u64> = batch.iter().map(|c| c.pipeline_id).collect();
        assert_eq!(ids, vec![1, 2, 3, 5], "batch must be sorted by pipeline_id");
    }

    #[test]
    fn test_batched_flush_is_stable_within_pipeline() {
        let mut b = BatchedComputePass::new(100);
        b.submit(make_cmd(2, 0));
        b.submit(make_cmd(1, 0));
        b.submit(make_cmd(2, 1));
        b.submit(make_cmd(1, 1));
        b.submit(make_cmd(2, 2));
        let batch = b.flush();
        let order: Vec<(u64, u32)> = batch
            .iter()
            .map(|c| (c.pipeline_id, c.dispatch_x))
            .collect();
        assert_eq!(
            order,
            vec![(1, 0), (1, 1), (2, 0), (2, 1), (2, 2)],
            "submission order must be preserved within a pipeline"
        );
    }

    #[test]
    fn test_batched_flush_empty_returns_empty() {
        let mut b = BatchedComputePass::new(5);
        let batch = b.flush();
        assert!(
            batch.is_empty(),
            "flushing an empty batcher returns empty vec"
        );
    }

    #[test]
    fn test_batched_total_flushed_accumulates() {
        let mut b = BatchedComputePass::new(3);
        for i in 0..6u32 {
            b.submit(make_cmd(1, i));
        }
        // submit() never drains on its own, so this single flush returns all six.
        b.flush();
        assert_eq!(
            b.total_flushed(),
            6,
            "total flushed must equal total submitted"
        );
    }

    #[test]
    fn test_batched_similar_pipeline_ids_adjacent() {
        let mut b = BatchedComputePass::new(100);
        // Mix of pipeline IDs 10 and 20
        b.submit(make_cmd(20, 0));
        b.submit(make_cmd(10, 0));
        b.submit(make_cmd(20, 1));
        b.submit(make_cmd(10, 1));
        let batch = b.flush();
        // Expect: 10, 10, 20, 20
        let ids: Vec<u64> = batch.iter().map(|c| c.pipeline_id).collect();
        assert_eq!(ids[0], 10);
        assert_eq!(ids[1], 10);
        assert_eq!(ids[2], 20);
        assert_eq!(ids[3], 20);
    }

    #[test]
    fn test_batched_max_batch_size_accessor() {
        let b = BatchedComputePass::new(8);
        assert_eq!(b.max_batch_size(), 8);
    }

    #[test]
    fn test_batched_zero_max_batch_size_clamped_to_one() {
        let mut b = BatchedComputePass::new(0);
        assert_eq!(b.max_batch_size(), 1);
        assert!(b.submit(make_cmd(1, 0)), "every submit signals with size 1");
    }

    #[test]
    fn test_batched_debug_fmt() {
        let b = BatchedComputePass::new(4);
        let s = format!("{b:?}");
        assert!(s.contains("BatchedComputePass"));
    }
}