Skip to main content

oximedia_gpu/
compute_pass.rs

1//! GPU compute pass management — pass types, buffer bindings, and pass queues.
2
3#![allow(dead_code)]
4#![allow(clippy::cast_precision_loss)]
5
/// Kind of workload that a compute pass carries out.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PassType {
    /// Real-time video processing.
    Video,
    /// Real-time audio processing.
    Audio,
    /// Processing of still images.
    Image,
    /// Post-processing effects.
    PostProcess,
}

impl PassType {
    /// Reports whether this pass type operates in a real-time context.
    ///
    /// Only [`PassType::Video`] and [`PassType::Audio`] are real-time.
    #[must_use]
    pub fn is_real_time(&self) -> bool {
        // Exhaustive on purpose: a newly added variant must be classified explicitly.
        match self {
            Self::Video | Self::Audio => true,
            Self::Image | Self::PostProcess => false,
        }
    }
}
26
/// Associates a shader binding slot with a logical GPU buffer.
#[derive(Debug, Clone)]
pub struct BufferBinding {
    /// Index of the shader binding slot.
    pub slot: u8,
    /// Buffer size, in bytes.
    pub size_bytes: u32,
    /// `true` when the buffer is bound read-only, i.e. used as an input.
    pub read_only: bool,
}

impl BufferBinding {
    /// Builds a `BufferBinding` from its slot index, byte size, and access mode.
    #[must_use]
    pub fn new(slot: u8, size_bytes: u32, read_only: bool) -> Self {
        Self {
            slot,
            size_bytes,
            read_only,
        }
    }

    /// Returns `true` when the binding is read-only, which marks it as an input.
    #[must_use]
    pub fn is_input(&self) -> bool {
        self.read_only
    }

    /// Returns `true` when the binding is writable, which marks it as an output.
    #[must_use]
    pub fn is_output(&self) -> bool {
        // Every binding is exactly one of input or output.
        !self.is_input()
    }
}
61
62/// A single compute pass with a name, type, buffer bindings, and dispatch dimensions.
63#[derive(Debug)]
64pub struct ComputePass {
65    /// Human-readable name for debugging.
66    pub name: String,
67    /// The category of this pass.
68    pub pass_type: PassType,
69    /// Buffer bindings used by this pass.
70    pub bindings: Vec<BufferBinding>,
71    /// Workgroup dispatch dimensions (x, y, z).
72    pub workgroups: (u32, u32, u32),
73}
74
75impl ComputePass {
76    /// Creates a new `ComputePass` with no bindings and a default dispatch of (1, 1, 1).
77    #[must_use]
78    pub fn new(name: impl Into<String>, pt: PassType) -> Self {
79        Self {
80            name: name.into(),
81            pass_type: pt,
82            bindings: Vec::new(),
83            workgroups: (1, 1, 1),
84        }
85    }
86
87    /// Adds a read-only (input) buffer binding on the given slot.
88    pub fn add_input_binding(&mut self, slot: u8, size: u32) {
89        self.bindings.push(BufferBinding::new(slot, size, true));
90    }
91
92    /// Adds a writable (output) buffer binding on the given slot.
93    pub fn add_output_binding(&mut self, slot: u8, size: u32) {
94        self.bindings.push(BufferBinding::new(slot, size, false));
95    }
96
97    /// Total work items = workgroups.x × workgroups.y × workgroups.z.
98    #[must_use]
99    pub fn total_work_items(&self) -> u64 {
100        u64::from(self.workgroups.0) * u64::from(self.workgroups.1) * u64::from(self.workgroups.2)
101    }
102
103    /// Returns the number of bindings attached to this pass.
104    #[must_use]
105    pub fn binding_count(&self) -> usize {
106        self.bindings.len()
107    }
108}
109
110/// An ordered queue of [`ComputePass`] entries.
111#[derive(Debug, Default)]
112pub struct PassQueue {
113    passes: Vec<ComputePass>,
114}
115
116impl PassQueue {
117    /// Creates an empty `PassQueue`.
118    #[must_use]
119    pub fn new() -> Self {
120        Self::default()
121    }
122
123    /// Appends a pass to the queue.
124    pub fn add(&mut self, pass: ComputePass) {
125        self.passes.push(pass);
126    }
127
128    /// Returns references to all passes whose type matches `pt`.
129    #[must_use]
130    pub fn passes_of_type(&self, pt: &PassType) -> Vec<&ComputePass> {
131        self.passes.iter().filter(|p| &p.pass_type == pt).collect()
132    }
133
134    /// Total number of bindings across all passes.
135    #[must_use]
136    pub fn total_bindings(&self) -> usize {
137        self.passes.iter().map(ComputePass::binding_count).sum()
138    }
139
140    /// Returns the number of passes in the queue.
141    #[must_use]
142    pub fn pass_count(&self) -> usize {
143        self.passes.len()
144    }
145}
146
147// ---------------------------------------------------------------------------
148// Unit tests
149// ---------------------------------------------------------------------------
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Video passes count as real-time.
    #[test]
    fn test_pass_type_video_is_real_time() {
        let kind = PassType::Video;
        assert!(kind.is_real_time());
    }

    /// Audio passes count as real-time.
    #[test]
    fn test_pass_type_audio_is_real_time() {
        let kind = PassType::Audio;
        assert!(kind.is_real_time());
    }

    /// Image passes are not real-time.
    #[test]
    fn test_pass_type_image_not_real_time() {
        let kind = PassType::Image;
        assert!(!kind.is_real_time());
    }

    /// Post-process passes are not real-time.
    #[test]
    fn test_pass_type_post_process_not_real_time() {
        let kind = PassType::PostProcess;
        assert!(!kind.is_real_time());
    }

    /// A read-only binding is an input and not an output.
    #[test]
    fn test_buffer_binding_input() {
        let binding = BufferBinding::new(0, 1024, true);
        assert!(binding.is_input());
        assert!(!binding.is_output());
    }

    /// A writable binding is an output and not an input.
    #[test]
    fn test_buffer_binding_output() {
        let binding = BufferBinding::new(1, 2048, false);
        assert!(binding.is_output());
        assert!(!binding.is_input());
    }

    /// A fresh pass keeps its name, dispatches (1, 1, 1), and has no bindings.
    #[test]
    fn test_compute_pass_new_defaults() {
        let fresh = ComputePass::new("test", PassType::Image);
        assert_eq!(fresh.name, "test");
        assert_eq!(fresh.workgroups, (1, 1, 1));
        assert_eq!(fresh.binding_count(), 0);
    }

    /// Adding an input binding stores exactly one read-only binding.
    #[test]
    fn test_compute_pass_add_input_binding() {
        let mut video_pass = ComputePass::new("p", PassType::Video);
        video_pass.add_input_binding(0, 512);
        assert_eq!(video_pass.binding_count(), 1);
        let stored = &video_pass.bindings[0];
        assert!(stored.is_input());
    }

    /// Adding an output binding stores exactly one writable binding.
    #[test]
    fn test_compute_pass_add_output_binding() {
        let mut video_pass = ComputePass::new("p", PassType::Video);
        video_pass.add_output_binding(1, 512);
        assert_eq!(video_pass.binding_count(), 1);
        let stored = &video_pass.bindings[0];
        assert!(stored.is_output());
    }

    /// The default (1, 1, 1) dispatch yields a single work item.
    #[test]
    fn test_total_work_items_1x1x1() {
        let audio_pass = ComputePass::new("p", PassType::Audio);
        assert_eq!(audio_pass.total_work_items(), 1);
    }

    /// A custom dispatch multiplies all three dimensions.
    #[test]
    fn test_total_work_items_custom() {
        let mut image_pass = ComputePass::new("p", PassType::Image);
        image_pass.workgroups = (4, 8, 2);
        assert_eq!(image_pass.total_work_items(), 64);
    }

    /// Every added pass is counted.
    #[test]
    fn test_pass_queue_add_and_count() {
        let mut queue = PassQueue::new();
        queue.add(ComputePass::new("a", PassType::Video));
        queue.add(ComputePass::new("b", PassType::Image));
        assert_eq!(queue.pass_count(), 2);
    }

    /// Filtering by type returns only the matching passes.
    #[test]
    fn test_pass_queue_passes_of_type() {
        let mut queue = PassQueue::new();
        queue.add(ComputePass::new("v1", PassType::Video));
        queue.add(ComputePass::new("i1", PassType::Image));
        queue.add(ComputePass::new("v2", PassType::Video));
        let video_passes = queue.passes_of_type(&PassType::Video);
        assert_eq!(video_passes.len(), 2);
    }

    /// Filtering by a type that was never queued yields nothing.
    #[test]
    fn test_pass_queue_passes_of_type_empty_result() {
        let mut queue = PassQueue::new();
        queue.add(ComputePass::new("a", PassType::Audio));
        let matches = queue.passes_of_type(&PassType::PostProcess);
        assert!(matches.is_empty());
    }

    /// Binding totals are summed over every queued pass.
    #[test]
    fn test_pass_queue_total_bindings() {
        let mut first = ComputePass::new("p1", PassType::Video);
        first.add_input_binding(0, 256);
        first.add_output_binding(1, 256);

        let mut second = ComputePass::new("p2", PassType::Image);
        second.add_input_binding(0, 128);

        let mut queue = PassQueue::new();
        queue.add(first);
        queue.add(second);
        assert_eq!(queue.total_bindings(), 3);
    }

    /// An empty queue has no passes and no bindings.
    #[test]
    fn test_pass_queue_empty() {
        let queue = PassQueue::new();
        assert_eq!(queue.pass_count(), 0);
        assert_eq!(queue.total_bindings(), 0);
    }
}