1#![allow(dead_code)]
4#![allow(clippy::cast_precision_loss)]
5
/// Category of work performed by a compute pass.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PassType {
    /// Video processing.
    Video,
    /// Audio processing.
    Audio,
    /// Still-image processing.
    Image,
    /// Post-processing.
    PostProcess,
}

impl PassType {
    /// Reports whether this pass type is classified as real-time.
    ///
    /// `Video` and `Audio` are real-time; `Image` and `PostProcess` are not.
    /// The match is exhaustive so adding a variant forces a decision here.
    #[must_use]
    pub fn is_real_time(&self) -> bool {
        match self {
            Self::Video | Self::Audio => true,
            Self::Image | Self::PostProcess => false,
        }
    }
}
26
/// A buffer attached to a numbered binding slot of a compute pass.
#[derive(Debug, Clone)]
pub struct BufferBinding {
    /// Binding slot index.
    pub slot: u8,
    /// Buffer size, in bytes.
    pub size_bytes: u32,
    /// `true` for input (read-only) bindings, `false` for output bindings.
    pub read_only: bool,
}

impl BufferBinding {
    /// Builds a binding for `slot` with the given size and access mode.
    #[must_use]
    pub fn new(slot: u8, size_bytes: u32, read_only: bool) -> Self {
        Self {
            slot,
            size_bytes,
            read_only,
        }
    }

    /// Returns `true` when the binding is read-only, i.e. an input of the pass.
    #[must_use]
    pub fn is_input(&self) -> bool {
        let Self { read_only, .. } = *self;
        read_only
    }

    /// Returns `true` when the binding is writable, i.e. an output of the pass.
    /// Always the logical negation of [`Self::is_input`].
    #[must_use]
    pub fn is_output(&self) -> bool {
        !self.is_input()
    }
}
61
62#[derive(Debug)]
64pub struct ComputePass {
65 pub name: String,
67 pub pass_type: PassType,
69 pub bindings: Vec<BufferBinding>,
71 pub workgroups: (u32, u32, u32),
73}
74
75impl ComputePass {
76 #[must_use]
78 pub fn new(name: impl Into<String>, pt: PassType) -> Self {
79 Self {
80 name: name.into(),
81 pass_type: pt,
82 bindings: Vec::new(),
83 workgroups: (1, 1, 1),
84 }
85 }
86
87 pub fn add_input_binding(&mut self, slot: u8, size: u32) {
89 self.bindings.push(BufferBinding::new(slot, size, true));
90 }
91
92 pub fn add_output_binding(&mut self, slot: u8, size: u32) {
94 self.bindings.push(BufferBinding::new(slot, size, false));
95 }
96
97 #[must_use]
99 pub fn total_work_items(&self) -> u64 {
100 u64::from(self.workgroups.0) * u64::from(self.workgroups.1) * u64::from(self.workgroups.2)
101 }
102
103 #[must_use]
105 pub fn binding_count(&self) -> usize {
106 self.bindings.len()
107 }
108}
109
110#[derive(Debug, Default)]
112pub struct PassQueue {
113 passes: Vec<ComputePass>,
114}
115
116impl PassQueue {
117 #[must_use]
119 pub fn new() -> Self {
120 Self::default()
121 }
122
123 pub fn add(&mut self, pass: ComputePass) {
125 self.passes.push(pass);
126 }
127
128 #[must_use]
130 pub fn passes_of_type(&self, pt: &PassType) -> Vec<&ComputePass> {
131 self.passes.iter().filter(|p| &p.pass_type == pt).collect()
132 }
133
134 #[must_use]
136 pub fn total_bindings(&self) -> usize {
137 self.passes.iter().map(ComputePass::binding_count).sum()
138 }
139
140 #[must_use]
142 pub fn pass_count(&self) -> usize {
143 self.passes.len()
144 }
145}
146
/// A single recorded compute dispatch: the pipeline to run, the bind group to
/// use, and the workgroup counts along each axis.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DispatchCommand {
    /// Pipeline identifier; `BatchedComputePass::flush` orders commands by it.
    pub pipeline_id: u64,
    /// Bind group index for the dispatch.
    pub bind_group: u32,
    /// Workgroup count along x.
    pub dispatch_x: u32,
    /// Workgroup count along y.
    pub dispatch_y: u32,
    /// Workgroup count along z.
    pub dispatch_z: u32,
}

impl DispatchCommand {
    /// Builds a dispatch command from its raw fields.
    #[must_use]
    pub fn new(
        pipeline_id: u64,
        bind_group: u32,
        dispatch_x: u32,
        dispatch_y: u32,
        dispatch_z: u32,
    ) -> Self {
        Self {
            pipeline_id,
            bind_group,
            dispatch_x,
            dispatch_y,
            dispatch_z,
        }
    }
}

/// Accumulates [`DispatchCommand`]s and hands them out in batches ordered by
/// `pipeline_id`.
///
/// [`submit`](Self::submit) only *signals* that the batch is full; it never
/// flushes on its own. If the caller ignores the signal, commands keep
/// accumulating beyond `max_batch_size` until [`flush`](Self::flush) is called.
pub struct BatchedComputePass {
    /// Commands submitted since the last flush, in submission order.
    pending: Vec<DispatchCommand>,
    /// Pending count at which `submit` starts returning `true` (always >= 1).
    max_batch_size: usize,
    /// Running total of commands returned by `flush`.
    total_flushed: u64,
}

impl BatchedComputePass {
    /// Creates an empty batcher. A `max_batch_size` of 0 is clamped to 1.
    #[must_use]
    pub fn new(max_batch_size: usize) -> Self {
        Self {
            pending: Vec::new(),
            max_batch_size: max_batch_size.max(1),
            total_flushed: 0,
        }
    }

    /// Queues `cmd` and returns `true` when the caller should flush, i.e. once
    /// the number of pending commands has reached `max_batch_size`.
    ///
    /// Keeps returning `true` on every later submit until [`Self::flush`] is called.
    pub fn submit(&mut self, cmd: DispatchCommand) -> bool {
        self.pending.push(cmd);
        self.pending.len() >= self.max_batch_size
    }

    /// Drains all pending commands and returns them ordered by `pipeline_id`.
    ///
    /// The sort is stable, so commands that share a pipeline keep their
    /// submission order. Returns an empty `Vec` and leaves the counters
    /// untouched when nothing is pending.
    ///
    /// NOTE(review): commands for different pipelines are reordered relative to
    /// each other; confirm callers never rely on cross-pipeline submission order.
    pub fn flush(&mut self) -> Vec<DispatchCommand> {
        if self.pending.is_empty() {
            return Vec::new();
        }
        let mut batch = std::mem::take(&mut self.pending);
        // Stable sort on purpose: preserves submission order within a pipeline.
        batch.sort_by_key(|cmd| cmd.pipeline_id);
        // usize -> u64 is lossless on all targets Rust supports (<= 64-bit).
        self.total_flushed += batch.len() as u64;
        batch
    }

    /// Returns how many commands are waiting to be flushed.
    #[must_use]
    pub fn pending_count(&self) -> usize {
        self.pending.len()
    }

    /// Returns the total number of commands returned by all `flush` calls so far.
    #[must_use]
    pub fn total_flushed(&self) -> u64 {
        self.total_flushed
    }

    /// Returns the (clamped) batch-size threshold used by `submit`.
    #[must_use]
    pub fn max_batch_size(&self) -> usize {
        self.max_batch_size
    }
}

impl std::fmt::Debug for BatchedComputePass {
    /// Shows the pending *count* rather than every queued command to keep output short.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BatchedComputePass")
            .field("pending", &self.pending.len())
            .field("max_batch_size", &self.max_batch_size)
            .field("total_flushed", &self.total_flushed)
            .finish()
    }
}
272
#[cfg(test)]
mod tests {
    use super::*;

    // ---- PassType ----------------------------------------------------------

    #[test]
    fn test_pass_type_video_is_real_time() {
        assert!(PassType::Video.is_real_time());
    }

    #[test]
    fn test_pass_type_audio_is_real_time() {
        assert!(PassType::Audio.is_real_time());
    }

    #[test]
    fn test_pass_type_image_not_real_time() {
        assert!(!PassType::Image.is_real_time());
    }

    #[test]
    fn test_pass_type_post_process_not_real_time() {
        assert!(!PassType::PostProcess.is_real_time());
    }

    // ---- BufferBinding -----------------------------------------------------

    #[test]
    fn test_buffer_binding_input() {
        let b = BufferBinding::new(0, 1024, true);
        assert!(b.is_input());
        assert!(!b.is_output());
    }

    #[test]
    fn test_buffer_binding_output() {
        let b = BufferBinding::new(1, 2048, false);
        assert!(b.is_output());
        assert!(!b.is_input());
    }

    // ---- ComputePass -------------------------------------------------------

    #[test]
    fn test_compute_pass_new_defaults() {
        let pass = ComputePass::new("test", PassType::Image);
        assert_eq!(pass.name, "test");
        assert_eq!(pass.workgroups, (1, 1, 1));
        assert_eq!(pass.binding_count(), 0);
    }

    #[test]
    fn test_compute_pass_add_input_binding() {
        let mut pass = ComputePass::new("p", PassType::Video);
        pass.add_input_binding(0, 512);
        assert_eq!(pass.binding_count(), 1);
        assert!(pass.bindings[0].is_input());
    }

    #[test]
    fn test_compute_pass_add_output_binding() {
        let mut pass = ComputePass::new("p", PassType::Video);
        pass.add_output_binding(1, 512);
        assert_eq!(pass.binding_count(), 1);
        assert!(pass.bindings[0].is_output());
    }

    #[test]
    fn test_total_work_items_1x1x1() {
        let pass = ComputePass::new("p", PassType::Audio);
        assert_eq!(pass.total_work_items(), 1);
    }

    #[test]
    fn test_total_work_items_custom() {
        let mut pass = ComputePass::new("p", PassType::Image);
        pass.workgroups = (4, 8, 2);
        assert_eq!(pass.total_work_items(), 64);
    }

    // ---- PassQueue ---------------------------------------------------------

    #[test]
    fn test_pass_queue_add_and_count() {
        let mut q = PassQueue::new();
        q.add(ComputePass::new("a", PassType::Video));
        q.add(ComputePass::new("b", PassType::Image));
        assert_eq!(q.pass_count(), 2);
    }

    #[test]
    fn test_pass_queue_passes_of_type() {
        let mut q = PassQueue::new();
        q.add(ComputePass::new("v1", PassType::Video));
        q.add(ComputePass::new("i1", PassType::Image));
        q.add(ComputePass::new("v2", PassType::Video));
        let videos = q.passes_of_type(&PassType::Video);
        assert_eq!(videos.len(), 2);
        // Results come back in queue order.
        assert_eq!(videos[0].name, "v1");
        assert_eq!(videos[1].name, "v2");
    }

    #[test]
    fn test_pass_queue_passes_of_type_empty_result() {
        let mut q = PassQueue::new();
        q.add(ComputePass::new("a", PassType::Audio));
        let results = q.passes_of_type(&PassType::PostProcess);
        assert!(results.is_empty());
    }

    #[test]
    fn test_pass_queue_total_bindings() {
        let mut q = PassQueue::new();
        let mut p1 = ComputePass::new("p1", PassType::Video);
        p1.add_input_binding(0, 256);
        p1.add_output_binding(1, 256);
        let mut p2 = ComputePass::new("p2", PassType::Image);
        p2.add_input_binding(0, 128);
        q.add(p1);
        q.add(p2);
        assert_eq!(q.total_bindings(), 3);
    }

    #[test]
    fn test_pass_queue_empty() {
        let q = PassQueue::new();
        assert_eq!(q.pass_count(), 0);
        assert_eq!(q.total_bindings(), 0);
    }

    // ---- BatchedComputePass ------------------------------------------------

    /// Builds a command for `pipeline_id` whose `dispatch_x` doubles as a tag,
    /// so tests can track submission order through a flush.
    fn make_cmd(pipeline_id: u64, x: u32) -> DispatchCommand {
        DispatchCommand::new(pipeline_id, 0, x, 1, 1)
    }

    #[test]
    fn test_batched_submit_no_auto_flush_below_limit() {
        let mut b = BatchedComputePass::new(5);
        for i in 0..4u32 {
            let flushed = b.submit(make_cmd(1, i));
            assert!(!flushed, "should not auto-flush below max_batch_size");
        }
        assert_eq!(b.pending_count(), 4);
    }

    #[test]
    fn test_batched_submit_auto_flush_at_limit() {
        let mut b = BatchedComputePass::new(5);
        for i in 0..4u32 {
            b.submit(make_cmd(1, i));
        }
        let triggered = b.submit(make_cmd(1, 4));
        assert!(triggered, "5th submit should signal auto-flush");
    }

    #[test]
    fn test_batched_submit_keeps_signalling_past_limit() {
        let mut b = BatchedComputePass::new(2);
        b.submit(make_cmd(1, 0));
        assert!(b.submit(make_cmd(1, 1)));
        assert!(b.submit(make_cmd(1, 2)), "signal persists until flush");
        assert_eq!(b.pending_count(), 3, "submit never flushes by itself");
    }

    #[test]
    fn test_batched_zero_max_batch_size_clamped_to_one() {
        let mut b = BatchedComputePass::new(0);
        assert_eq!(b.max_batch_size(), 1);
        assert!(b.submit(make_cmd(1, 0)), "batch size 1 signals on every submit");
    }

    #[test]
    fn test_batched_flush_returns_all_pending() {
        let mut b = BatchedComputePass::new(10);
        b.submit(make_cmd(3, 1));
        b.submit(make_cmd(1, 2));
        b.submit(make_cmd(2, 3));
        let batch = b.flush();
        assert_eq!(batch.len(), 3);
        assert_eq!(b.pending_count(), 0);
    }

    #[test]
    fn test_batched_flush_sorts_by_pipeline_id() {
        let mut b = BatchedComputePass::new(100);
        b.submit(make_cmd(5, 0));
        b.submit(make_cmd(1, 0));
        b.submit(make_cmd(3, 0));
        b.submit(make_cmd(2, 0));
        let batch = b.flush();
        let ids: Vec<u64> = batch.iter().map(|c| c.pipeline_id).collect();
        assert_eq!(ids, vec![1, 2, 3, 5], "batch must be sorted by pipeline_id");
    }

    #[test]
    fn test_batched_flush_empty_returns_empty() {
        let mut b = BatchedComputePass::new(5);
        let batch = b.flush();
        assert!(
            batch.is_empty(),
            "flushing an empty batcher returns empty vec"
        );
    }

    #[test]
    fn test_batched_flush_empty_does_not_change_total() {
        let mut b = BatchedComputePass::new(2);
        b.submit(make_cmd(1, 0));
        assert_eq!(b.flush().len(), 1);
        assert!(b.flush().is_empty());
        assert_eq!(b.total_flushed(), 1);
    }

    #[test]
    fn test_batched_total_flushed_accumulates() {
        let mut b = BatchedComputePass::new(3);
        for i in 0..6u32 {
            b.submit(make_cmd(1, i));
        }
        let batch = b.flush();
        assert_eq!(batch.len(), 6);
        assert_eq!(
            b.total_flushed(),
            6,
            "total flushed must equal total submitted"
        );
    }

    #[test]
    fn test_batched_similar_pipeline_ids_adjacent() {
        let mut b = BatchedComputePass::new(100);
        b.submit(make_cmd(20, 0));
        b.submit(make_cmd(10, 0));
        b.submit(make_cmd(20, 1));
        b.submit(make_cmd(10, 1));
        let batch = b.flush();
        let order: Vec<(u64, u32)> = batch
            .iter()
            .map(|c| (c.pipeline_id, c.dispatch_x))
            .collect();
        // Same-pipeline commands are adjacent and keep submission order (stable sort).
        assert_eq!(order, vec![(10, 0), (10, 1), (20, 0), (20, 1)]);
    }

    #[test]
    fn test_batched_max_batch_size_accessor() {
        let b = BatchedComputePass::new(8);
        assert_eq!(b.max_batch_size(), 8);
    }

    #[test]
    fn test_batched_debug_fmt() {
        let b = BatchedComputePass::new(4);
        let s = format!("{b:?}");
        assert!(s.contains("BatchedComputePass"));
    }
}
500}