// oximedia_gpu/compute_pass.rs
#![allow(dead_code)]
#![allow(clippy::cast_precision_loss)]
5
/// Category tag attached to a compute pass.
///
/// The enum carries no data, so it is `Copy` and `Hash`: it can be passed by
/// value and used as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PassType {
    /// Video pass.
    Video,
    /// Audio pass.
    Audio,
    /// Image pass.
    Image,
    /// Post-process pass.
    PostProcess,
}

impl PassType {
    /// Returns `true` for `Video` and `Audio`, the two categories this module
    /// classifies as real-time, and `false` for `Image` and `PostProcess`.
    #[must_use]
    pub fn is_real_time(&self) -> bool {
        // Exhaustive on purpose: adding a variant must force an explicit
        // decision here instead of silently defaulting to "not real-time".
        match self {
            Self::Video | Self::Audio => true,
            Self::Image | Self::PostProcess => false,
        }
    }
}
26
/// Describes one buffer bound to a compute pass.
///
/// The struct is plain data (three small scalar fields), so it derives
/// `Copy`, equality and hashing; bindings can be compared directly and used
/// in `assert_eq!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BufferBinding {
    /// Slot index of the binding.
    pub slot: u8,
    /// Buffer size in bytes.
    pub size_bytes: u32,
    /// `true` marks the binding as an input (see [`BufferBinding::is_input`]);
    /// `false` marks it as an output.
    pub read_only: bool,
}

impl BufferBinding {
    /// Creates a binding from its slot, size in bytes and read-only flag.
    ///
    /// No validation is performed on any of the values.
    #[must_use]
    pub const fn new(slot: u8, size_bytes: u32, read_only: bool) -> Self {
        Self {
            slot,
            size_bytes,
            read_only,
        }
    }

    /// Returns `true` when the binding is read-only.
    #[must_use]
    pub const fn is_input(&self) -> bool {
        self.read_only
    }

    /// Returns `true` when the binding is not read-only.
    ///
    /// Always the exact negation of [`BufferBinding::is_input`].
    #[must_use]
    pub const fn is_output(&self) -> bool {
        !self.read_only
    }
}
61
62#[derive(Debug)]
64pub struct ComputePass {
65 pub name: String,
67 pub pass_type: PassType,
69 pub bindings: Vec<BufferBinding>,
71 pub workgroups: (u32, u32, u32),
73}
74
75impl ComputePass {
76 #[must_use]
78 pub fn new(name: impl Into<String>, pt: PassType) -> Self {
79 Self {
80 name: name.into(),
81 pass_type: pt,
82 bindings: Vec::new(),
83 workgroups: (1, 1, 1),
84 }
85 }
86
87 pub fn add_input_binding(&mut self, slot: u8, size: u32) {
89 self.bindings.push(BufferBinding::new(slot, size, true));
90 }
91
92 pub fn add_output_binding(&mut self, slot: u8, size: u32) {
94 self.bindings.push(BufferBinding::new(slot, size, false));
95 }
96
97 #[must_use]
99 pub fn total_work_items(&self) -> u64 {
100 u64::from(self.workgroups.0) * u64::from(self.workgroups.1) * u64::from(self.workgroups.2)
101 }
102
103 #[must_use]
105 pub fn binding_count(&self) -> usize {
106 self.bindings.len()
107 }
108}
109
110#[derive(Debug, Default)]
112pub struct PassQueue {
113 passes: Vec<ComputePass>,
114}
115
116impl PassQueue {
117 #[must_use]
119 pub fn new() -> Self {
120 Self::default()
121 }
122
123 pub fn add(&mut self, pass: ComputePass) {
125 self.passes.push(pass);
126 }
127
128 #[must_use]
130 pub fn passes_of_type(&self, pt: &PassType) -> Vec<&ComputePass> {
131 self.passes.iter().filter(|p| &p.pass_type == pt).collect()
132 }
133
134 #[must_use]
136 pub fn total_bindings(&self) -> usize {
137 self.passes.iter().map(ComputePass::binding_count).sum()
138 }
139
140 #[must_use]
142 pub fn pass_count(&self) -> usize {
143 self.passes.len()
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    /// `Video` is classified as real-time.
    #[test]
    fn test_pass_type_video_is_real_time() {
        let kind = PassType::Video;
        assert!(kind.is_real_time());
    }

    /// `Audio` is classified as real-time.
    #[test]
    fn test_pass_type_audio_is_real_time() {
        let kind = PassType::Audio;
        assert!(kind.is_real_time());
    }

    /// `Image` is not classified as real-time.
    #[test]
    fn test_pass_type_image_not_real_time() {
        let kind = PassType::Image;
        assert!(!kind.is_real_time());
    }

    /// `PostProcess` is not classified as real-time.
    #[test]
    fn test_pass_type_post_process_not_real_time() {
        let kind = PassType::PostProcess;
        assert!(!kind.is_real_time());
    }

    /// A read-only binding reports itself as an input and not an output.
    #[test]
    fn test_buffer_binding_input() {
        let binding = BufferBinding::new(0, 1024, true);
        assert!(binding.is_input());
        assert!(!binding.is_output());
    }

    /// A writable binding reports itself as an output and not an input.
    #[test]
    fn test_buffer_binding_output() {
        let binding = BufferBinding::new(1, 2048, false);
        assert!(binding.is_output());
        assert!(!binding.is_input());
    }

    /// A fresh pass keeps its name, has `(1, 1, 1)` workgroups and no bindings.
    #[test]
    fn test_compute_pass_new_defaults() {
        let fresh = ComputePass::new("test", PassType::Image);
        assert_eq!(fresh.name, "test");
        assert_eq!(fresh.workgroups, (1, 1, 1));
        assert_eq!(fresh.binding_count(), 0);
    }

    /// Adding an input binding stores exactly one read-only binding.
    #[test]
    fn test_compute_pass_add_input_binding() {
        let mut video_pass = ComputePass::new("p", PassType::Video);
        video_pass.add_input_binding(0, 512);
        assert_eq!(video_pass.binding_count(), 1);
        assert!(video_pass.bindings[0].is_input());
    }

    /// Adding an output binding stores exactly one writable binding.
    #[test]
    fn test_compute_pass_add_output_binding() {
        let mut video_pass = ComputePass::new("p", PassType::Video);
        video_pass.add_output_binding(1, 512);
        assert_eq!(video_pass.binding_count(), 1);
        assert!(video_pass.bindings[0].is_output());
    }

    /// The default `(1, 1, 1)` workgroup count multiplies out to one.
    #[test]
    fn test_total_work_items_1x1x1() {
        let audio_pass = ComputePass::new("p", PassType::Audio);
        assert_eq!(audio_pass.total_work_items(), 1);
    }

    /// A custom workgroup count multiplies out component-wise: 4 * 8 * 2.
    #[test]
    fn test_total_work_items_custom() {
        let mut image_pass = ComputePass::new("p", PassType::Image);
        image_pass.workgroups = (4, 8, 2);
        assert_eq!(image_pass.total_work_items(), 64);
    }

    /// Every added pass is counted.
    #[test]
    fn test_pass_queue_add_and_count() {
        let mut queue = PassQueue::new();
        queue.add(ComputePass::new("a", PassType::Video));
        queue.add(ComputePass::new("b", PassType::Image));
        assert_eq!(queue.pass_count(), 2);
    }

    /// Filtering by type returns only the passes of that type.
    #[test]
    fn test_pass_queue_passes_of_type() {
        let mut queue = PassQueue::new();
        queue.add(ComputePass::new("v1", PassType::Video));
        queue.add(ComputePass::new("i1", PassType::Image));
        queue.add(ComputePass::new("v2", PassType::Video));
        let video_passes = queue.passes_of_type(&PassType::Video);
        assert_eq!(video_passes.len(), 2);
    }

    /// Filtering by a type that was never queued yields nothing.
    #[test]
    fn test_pass_queue_passes_of_type_empty_result() {
        let mut queue = PassQueue::new();
        queue.add(ComputePass::new("a", PassType::Audio));
        let matches = queue.passes_of_type(&PassType::PostProcess);
        assert!(matches.is_empty());
    }

    /// Binding totals are summed across all queued passes (2 + 1).
    #[test]
    fn test_pass_queue_total_bindings() {
        let mut first = ComputePass::new("p1", PassType::Video);
        first.add_input_binding(0, 256);
        first.add_output_binding(1, 256);

        let mut second = ComputePass::new("p2", PassType::Image);
        second.add_input_binding(0, 128);

        let mut queue = PassQueue::new();
        queue.add(first);
        queue.add(second);
        assert_eq!(queue.total_bindings(), 3);
    }

    /// An empty queue reports zero passes and zero bindings.
    #[test]
    fn test_pass_queue_empty() {
        let queue = PassQueue::new();
        assert_eq!(queue.pass_count(), 0);
        assert_eq!(queue.total_bindings(), 0);
    }
}