Skip to main content

oximedia_gpu/
pipeline.rs

1//! GPU processing pipeline management
2//!
3//! Provides a directed-acyclic-graph (DAG) style pipeline for composing GPU
4//! processing stages. Pipeline nodes are connected via edges; the pipeline
5//! validates that the graph is acyclic before execution.
6
7#![allow(dead_code)]
8#![allow(clippy::cast_precision_loss)]
9
/// A stage in the GPU processing pipeline.
///
/// The enum carries no payload, so it is `Copy` and `Hash`: values can be
/// passed around by value and used as keys in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PipelineStage {
    /// Decode compressed media
    Decode,
    /// Colour-space conversion (e.g., YUV → RGB)
    Colorspace,
    /// Image filter (blur, sharpen, …)
    Filter,
    /// Encode to compressed output
    Encode,
    /// Render to display surface
    Display,
}

impl std::fmt::Display for PipelineStage {
    /// Writes the variant name (`"Decode"`, `"Colorspace"`, …).
    ///
    /// Formatting flags on the format spec (width, fill, alignment,
    /// precision) are applied the same way they are for a plain `str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Decode => "Decode",
            Self::Colorspace => "Colorspace",
            Self::Filter => "Filter",
            Self::Encode => "Encode",
            Self::Display => "Display",
        };
        // `pad` rather than `write!`: `write!` would emit the text verbatim
        // and silently ignore requests such as `{:>12}`.
        f.pad(label)
    }
}
36
37/// A single node in the GPU pipeline
38#[derive(Debug, Clone)]
39pub struct PipelineNode {
40    /// Unique identifier for this node
41    pub id: u64,
42    /// The processing stage this node represents
43    pub stage: PipelineStage,
44    /// Human-readable name
45    pub name: String,
46    /// Number of input connections
47    pub input_count: usize,
48    /// Number of output connections
49    pub output_count: usize,
50}
51
52impl PipelineNode {
53    /// Create a new pipeline node
54    pub fn new(id: u64, stage: PipelineStage, name: impl Into<String>) -> Self {
55        Self {
56            id,
57            stage,
58            name: name.into(),
59            input_count: 0,
60            output_count: 0,
61        }
62    }
63}
64
65/// A directed-acyclic-graph GPU processing pipeline
66#[derive(Debug, Clone)]
67pub struct GpuPipeline {
68    nodes: Vec<PipelineNode>,
69    edges: Vec<(u64, u64)>,
70    active: bool,
71}
72
73impl GpuPipeline {
74    /// Create a new empty pipeline
75    #[must_use]
76    pub fn new() -> Self {
77        Self {
78            nodes: Vec::new(),
79            edges: Vec::new(),
80            active: false,
81        }
82    }
83
84    /// Add a node to the pipeline; returns the node id
85    pub fn add_node(&mut self, mut node: PipelineNode) -> u64 {
86        let id = node.id;
87        node.input_count = 0;
88        node.output_count = 0;
89        self.nodes.push(node);
90        id
91    }
92
93    /// Connect two nodes by id (from → to)
94    ///
95    /// # Errors
96    ///
97    /// Returns an error if either node does not exist or if the connection
98    /// would create a cycle.
99    pub fn connect(&mut self, from: u64, to: u64) -> Result<(), String> {
100        if self.find_node(from).is_none() {
101            return Err(format!("Source node {from} not found"));
102        }
103        if self.find_node(to).is_none() {
104            return Err(format!("Target node {to} not found"));
105        }
106        if from == to {
107            return Err("Self-loop not allowed".to_string());
108        }
109        // Check for duplicate edge
110        if self.edges.contains(&(from, to)) {
111            return Err(format!("Edge ({from}, {to}) already exists"));
112        }
113        // Tentatively add and check for cycle
114        self.edges.push((from, to));
115        if self.has_cycle() {
116            self.edges.pop();
117            return Err(format!("Adding edge ({from}, {to}) would create a cycle"));
118        }
119        // Update port counts
120        if let Some(n) = self.nodes.iter_mut().find(|n| n.id == from) {
121            n.output_count += 1;
122        }
123        if let Some(n) = self.nodes.iter_mut().find(|n| n.id == to) {
124            n.input_count += 1;
125        }
126        Ok(())
127    }
128
129    /// Validate the pipeline (no isolated sinks without a source, etc.)
130    ///
131    /// # Errors
132    ///
133    /// Returns an error describing the first validation problem found.
134    pub fn validate(&self) -> Result<(), String> {
135        if self.nodes.is_empty() {
136            return Err("Pipeline has no nodes".to_string());
137        }
138        if self.has_cycle() {
139            return Err("Pipeline contains a cycle".to_string());
140        }
141        Ok(())
142    }
143
144    /// Number of nodes in the pipeline
145    #[must_use]
146    pub fn node_count(&self) -> usize {
147        self.nodes.len()
148    }
149
150    /// Returns `true` if the pipeline is valid (non-empty, acyclic)
151    #[must_use]
152    pub fn is_valid(&self) -> bool {
153        self.validate().is_ok()
154    }
155
156    /// Activate the pipeline for processing
157    pub fn activate(&mut self) {
158        self.active = true;
159    }
160
161    /// Deactivate the pipeline
162    pub fn deactivate(&mut self) {
163        self.active = false;
164    }
165
166    /// Whether the pipeline is currently active
167    #[must_use]
168    pub fn is_active(&self) -> bool {
169        self.active
170    }
171
172    /// Access the node list
173    #[must_use]
174    pub fn nodes(&self) -> &[PipelineNode] {
175        &self.nodes
176    }
177
178    /// Access the edge list
179    #[must_use]
180    pub fn edges(&self) -> &[(u64, u64)] {
181        &self.edges
182    }
183
184    // ----- private helpers -----
185
186    fn find_node(&self, id: u64) -> Option<&PipelineNode> {
187        self.nodes.iter().find(|n| n.id == id)
188    }
189
190    /// Cycle detection via DFS
191    fn has_cycle(&self) -> bool {
192        let node_ids: Vec<u64> = self.nodes.iter().map(|n| n.id).collect();
193        let mut visited = std::collections::HashSet::new();
194        let mut stack = std::collections::HashSet::new();
195
196        for &id in &node_ids {
197            if self.dfs_cycle(id, &mut visited, &mut stack) {
198                return true;
199            }
200        }
201        false
202    }
203
204    fn dfs_cycle(
205        &self,
206        node: u64,
207        visited: &mut std::collections::HashSet<u64>,
208        stack: &mut std::collections::HashSet<u64>,
209    ) -> bool {
210        if stack.contains(&node) {
211            return true;
212        }
213        if visited.contains(&node) {
214            return false;
215        }
216        visited.insert(node);
217        stack.insert(node);
218        for &(from, to) in &self.edges {
219            if from == node && self.dfs_cycle(to, visited, stack) {
220                return true;
221            }
222        }
223        stack.remove(&node);
224        false
225    }
226}
227
228impl Default for GpuPipeline {
229    fn default() -> Self {
230        Self::new()
231    }
232}
233
/// Aggregated performance metrics for a pipeline.
///
/// The default value is an all-zero record, identical to
/// [`PipelineMetrics::new`].
#[derive(Debug, Clone, Default)]
pub struct PipelineMetrics {
    /// Total frames successfully processed
    pub frames_processed: u64,
    /// Average frame processing latency in milliseconds
    pub avg_latency_ms: f64,
    /// Number of frames dropped due to backpressure / overflow
    pub dropped_frames: u64,
    /// GPU utilisation in [0.0, 1.0]
    pub utilization: f64,
}

impl PipelineMetrics {
    /// Create a zeroed metrics record
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Record a new frame with the given latency.
    ///
    /// Folds `latency_ms` into the running mean and increments
    /// `frames_processed` (saturating at `u64::MAX`).
    pub fn record_frame(&mut self, latency_ms: f64) {
        let n = self.frames_processed as f64;
        self.avg_latency_ms = (self.avg_latency_ms * n + latency_ms) / (n + 1.0);
        // The fields are public, so the counter may already sit at the
        // maximum; saturate rather than overflow.
        self.frames_processed = self.frames_processed.saturating_add(1);
    }

    /// Record a dropped frame (the counter saturates at `u64::MAX`)
    pub fn record_drop(&mut self) {
        self.dropped_frames = self.dropped_frames.saturating_add(1);
    }

    /// Drop rate in [0.0, 1.0]: dropped frames divided by all frames seen
    /// (processed + dropped). Returns `0.0` when no frame has been seen.
    #[must_use]
    pub fn drop_rate(&self) -> f64 {
        // Sum in u128: adding the two u64 counters directly can overflow,
        // which panics in debug builds and wraps (yielding a bogus rate) in
        // release builds.
        let total = u128::from(self.frames_processed) + u128::from(self.dropped_frames);
        if total == 0 {
            0.0
        } else {
            self.dropped_frames as f64 / total as f64
        }
    }
}
288
289// ============================================================
290// BarrierBatcher — batched GPU memory barrier management
291// ============================================================
292
/// Direction of a buffer memory barrier.
///
/// Payload-free tag; `Copy` and `Hash` are derived so it can be passed by
/// value and used as a key, in line with `BarrierStrategy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BarrierKind {
    /// Read-after-write hazard: a prior write must complete before a read.
    ReadAfterWrite,
    /// Write-after-read hazard: a prior read must complete before a write.
    WriteAfterRead,
}
301
302/// Represents a logical buffer barrier between two pipeline stages.
303#[derive(Debug, Clone)]
304pub struct BufferBarrier {
305    /// Identifier of the buffer resource.
306    pub buffer_id: u64,
307    /// Kind of hazard this barrier resolves.
308    pub kind: BarrierKind,
309    /// Source pipeline stage (ordering context).
310    pub src_stage: PipelineStage,
311    /// Destination pipeline stage (ordering context).
312    pub dst_stage: PipelineStage,
313}
314
315impl BufferBarrier {
316    /// Create a new `BufferBarrier`.
317    #[must_use]
318    pub fn new(buffer_id: u64, kind: BarrierKind, src: PipelineStage, dst: PipelineStage) -> Self {
319        Self {
320            buffer_id,
321            kind,
322            src_stage: src,
323            dst_stage: dst,
324        }
325    }
326}
327
/// Policy that decides when accumulated barriers are flushed to the encoder.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BarrierStrategy {
    /// Flush immediately after each barrier is added (maximum safety, more overhead).
    Eager,
    /// Flush as soon as the number of pending barriers reaches `N`.
    Batched(usize),
    /// Never flush automatically; wait for an explicit request (e.g. at pass boundaries).
    Deferred,
}
338
/// Summary of one flush event, kept for test assertions and diagnostics.
#[derive(Debug, Clone)]
pub struct FlushRecord {
    /// How many read-after-write barriers this flush contained.
    pub raw_count: usize,
    /// How many write-after-read barriers this flush contained.
    pub war_count: usize,
}
347
348/// Accumulates GPU buffer barriers and issues them to a mock encoder in batches.
349///
350/// In a real GPU engine the `flush` call would translate the accumulated
351/// barriers into a `wgpu::CommandEncoder::insert_debug_marker` / pipeline-
352/// barrier equivalent.  Here we model the encoder with a simple callback so
353/// that the logic can be exercised in pure-CPU unit tests without a GPU device.
354pub struct BarrierBatcher {
355    pending_read_after_write: Vec<BufferBarrier>,
356    pending_write_after_read: Vec<BufferBarrier>,
357    strategy: BarrierStrategy,
358    /// Number of individual barriers that have been batched and submitted.
359    batched_count: u64,
360    /// History of flush events (used for test assertions and diagnostics).
361    flush_log: Vec<FlushRecord>,
362}
363
364impl BarrierBatcher {
365    /// Create a `BarrierBatcher` with the given strategy.
366    #[must_use]
367    pub fn new(strategy: BarrierStrategy) -> Self {
368        Self {
369            pending_read_after_write: Vec::new(),
370            pending_write_after_read: Vec::new(),
371            strategy,
372            batched_count: 0,
373            flush_log: Vec::new(),
374        }
375    }
376
377    /// Add a barrier.  In `Eager` mode this immediately triggers a flush;
378    /// in `Batched(n)` mode a flush is triggered once `n` barriers are pending;
379    /// in `Deferred` mode barriers accumulate until `flush()` is called explicitly.
380    ///
381    /// Returns `true` if a flush occurred as a result of adding this barrier.
382    pub fn add_barrier(&mut self, barrier: BufferBarrier) -> bool {
383        match barrier.kind {
384            BarrierKind::ReadAfterWrite => self.pending_read_after_write.push(barrier),
385            BarrierKind::WriteAfterRead => self.pending_write_after_read.push(barrier),
386        }
387
388        let should_flush = match self.strategy {
389            BarrierStrategy::Eager => true,
390            BarrierStrategy::Batched(n) => self.pending_count() >= n,
391            BarrierStrategy::Deferred => false,
392        };
393
394        if should_flush {
395            self.flush();
396            true
397        } else {
398            false
399        }
400    }
401
402    /// Flush all pending barriers to the (simulated) encoder.
403    ///
404    /// Returns the total number of barriers flushed in this call.
405    /// After flushing, the pending queues are empty.
406    pub fn flush(&mut self) -> usize {
407        let raw = self.pending_read_after_write.len();
408        let war = self.pending_write_after_read.len();
409        let total = raw + war;
410
411        if total == 0 {
412            return 0;
413        }
414
415        // Record this flush for observability.
416        self.flush_log.push(FlushRecord {
417            raw_count: raw,
418            war_count: war,
419        });
420        self.batched_count += total as u64;
421
422        self.pending_read_after_write.clear();
423        self.pending_write_after_read.clear();
424
425        total
426    }
427
428    /// Number of barriers currently waiting to be flushed.
429    #[must_use]
430    pub fn pending_count(&self) -> usize {
431        self.pending_read_after_write.len() + self.pending_write_after_read.len()
432    }
433
434    /// Total number of individual barriers that have been submitted to the encoder.
435    #[must_use]
436    pub fn batched_count(&self) -> u64 {
437        self.batched_count
438    }
439
440    /// Immutable view of the flush history.
441    #[must_use]
442    pub fn flush_log(&self) -> &[FlushRecord] {
443        &self.flush_log
444    }
445
446    /// Active strategy.
447    #[must_use]
448    pub fn strategy(&self) -> BarrierStrategy {
449        self.strategy
450    }
451}
452
453impl std::fmt::Debug for BarrierBatcher {
454    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
455        f.debug_struct("BarrierBatcher")
456            .field("strategy", &self.strategy)
457            .field("pending", &self.pending_count())
458            .field("batched_count", &self.batched_count)
459            .field("flush_events", &self.flush_log.len())
460            .finish()
461    }
462}
463
464// ============================================================
465// Unit tests
466// ============================================================
// Unit tests for the pipeline graph, the metrics record and the barrier
// batcher. Everything here runs on the CPU; no GPU device is involved.
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a node named `node_<id>` for the given stage.
    fn make_node(id: u64, stage: PipelineStage) -> PipelineNode {
        PipelineNode::new(id, stage, format!("node_{id}"))
    }

    // A fresh pipeline has no nodes and is inactive.
    #[test]
    fn test_pipeline_new_is_empty() {
        let p = GpuPipeline::new();
        assert_eq!(p.node_count(), 0);
        assert!(!p.is_active());
    }

    // `add_node` echoes back the id carried by the node.
    #[test]
    fn test_add_node_returns_id() {
        let mut p = GpuPipeline::new();
        let id = p.add_node(make_node(42, PipelineStage::Decode));
        assert_eq!(id, 42);
        assert_eq!(p.node_count(), 1);
    }

    // Connecting two existing nodes records exactly one edge.
    #[test]
    fn test_connect_nodes_ok() {
        let mut p = GpuPipeline::new();
        p.add_node(make_node(1, PipelineStage::Decode));
        p.add_node(make_node(2, PipelineStage::Colorspace));
        assert!(p.connect(1, 2).is_ok());
        assert_eq!(p.edges().len(), 1);
    }

    // An edge to an id that was never added is rejected.
    #[test]
    fn test_connect_missing_node_err() {
        let mut p = GpuPipeline::new();
        p.add_node(make_node(1, PipelineStage::Decode));
        assert!(p.connect(1, 99).is_err());
    }

    // A node cannot be connected to itself.
    #[test]
    fn test_connect_self_loop_err() {
        let mut p = GpuPipeline::new();
        p.add_node(make_node(1, PipelineStage::Filter));
        assert!(p.connect(1, 1).is_err());
    }

    // Adding the same edge twice is rejected the second time.
    #[test]
    fn test_connect_duplicate_edge_err() {
        let mut p = GpuPipeline::new();
        p.add_node(make_node(1, PipelineStage::Decode));
        p.add_node(make_node(2, PipelineStage::Encode));
        p.connect(1, 2).expect("pipeline connection should succeed");
        assert!(p.connect(1, 2).is_err());
    }

    // 1 → 2 → 3 is fine; closing the loop with 3 → 1 must fail.
    #[test]
    fn test_connect_cycle_detected() {
        let mut p = GpuPipeline::new();
        p.add_node(make_node(1, PipelineStage::Decode));
        p.add_node(make_node(2, PipelineStage::Filter));
        p.add_node(make_node(3, PipelineStage::Encode));
        p.connect(1, 2).expect("pipeline connection should succeed");
        p.connect(2, 3).expect("pipeline connection should succeed");
        assert!(p.connect(3, 1).is_err());
    }

    // An empty pipeline does not validate.
    #[test]
    fn test_validate_empty_err() {
        let p = GpuPipeline::new();
        assert!(p.validate().is_err());
    }

    // A single unconnected node is a valid pipeline.
    #[test]
    fn test_validate_single_node_ok() {
        let mut p = GpuPipeline::new();
        p.add_node(make_node(1, PipelineStage::Display));
        assert!(p.validate().is_ok());
        assert!(p.is_valid());
    }

    // `activate` / `deactivate` toggle the flag reported by `is_active`.
    #[test]
    fn test_activate_deactivate() {
        let mut p = GpuPipeline::new();
        p.activate();
        assert!(p.is_active());
        p.deactivate();
        assert!(!p.is_active());
    }

    // A successful connect bumps the source's output count and the
    // target's input count.
    #[test]
    fn test_port_counts_updated() {
        let mut p = GpuPipeline::new();
        p.add_node(make_node(1, PipelineStage::Decode));
        p.add_node(make_node(2, PipelineStage::Encode));
        p.connect(1, 2).expect("pipeline connection should succeed");
        let n1 = p
            .nodes()
            .iter()
            .find(|n| n.id == 1)
            .expect("find should return a result");
        let n2 = p
            .nodes()
            .iter()
            .find(|n| n.id == 2)
            .expect("find should return a result");
        assert_eq!(n1.output_count, 1);
        assert_eq!(n2.input_count, 1);
    }

    // Two frames at 10 ms and 20 ms average to 15 ms.
    #[test]
    fn test_metrics_record_frame() {
        let mut m = PipelineMetrics::new();
        m.record_frame(10.0);
        m.record_frame(20.0);
        assert_eq!(m.frames_processed, 2);
        assert!((m.avg_latency_ms - 15.0).abs() < 1e-9);
    }

    // One processed frame plus one drop gives a 50 % drop rate.
    #[test]
    fn test_metrics_drop_rate() {
        let mut m = PipelineMetrics::new();
        m.record_frame(5.0);
        m.record_drop();
        assert!((m.drop_rate() - 0.5).abs() < 1e-9);
    }

    // `Display` prints the bare variant name.
    #[test]
    fn test_stage_display() {
        assert_eq!(PipelineStage::Decode.to_string(), "Decode");
        assert_eq!(PipelineStage::Display.to_string(), "Display");
    }

    // ── BarrierBatcher tests ──────────────────────────────────────────────────

    /// Read-after-write barrier on `buf_id`, ordered Decode → Filter.
    fn raw_barrier(buf_id: u64) -> BufferBarrier {
        BufferBarrier::new(
            buf_id,
            BarrierKind::ReadAfterWrite,
            PipelineStage::Decode,
            PipelineStage::Filter,
        )
    }

    /// Write-after-read barrier on `buf_id`, ordered Filter → Encode.
    fn war_barrier(buf_id: u64) -> BufferBarrier {
        BufferBarrier::new(
            buf_id,
            BarrierKind::WriteAfterRead,
            PipelineStage::Filter,
            PipelineStage::Encode,
        )
    }

    // Eager mode flushes as part of the very first add.
    #[test]
    fn test_batcher_eager_flushes_immediately() {
        let mut b = BarrierBatcher::new(BarrierStrategy::Eager);
        let flushed = b.add_barrier(raw_barrier(1));
        assert!(flushed, "eager strategy must flush on every add");
        assert_eq!(b.pending_count(), 0, "pending must be 0 after eager flush");
        assert_eq!(b.batched_count(), 1);
    }

    // Eager mode logs one flush record per barrier added.
    #[test]
    fn test_batcher_eager_each_barrier_is_one_flush() {
        let mut b = BarrierBatcher::new(BarrierStrategy::Eager);
        for i in 0..5u64 {
            b.add_barrier(raw_barrier(i));
        }
        assert_eq!(b.flush_log().len(), 5, "5 adds → 5 flushes in eager mode");
        assert_eq!(b.batched_count(), 5);
    }

    // Batched(5) holds barriers back until the fifth one arrives.
    #[test]
    fn test_batcher_batched_accumulates_before_flush() {
        let mut b = BarrierBatcher::new(BarrierStrategy::Batched(5));
        // Add 4 barriers — should not flush yet
        for i in 0..4u64 {
            let flushed = b.add_barrier(raw_barrier(i));
            assert!(!flushed, "should not flush before reaching threshold");
        }
        assert_eq!(b.pending_count(), 4);
        assert_eq!(b.flush_log().len(), 0, "no flushes yet");
        // 5th barrier triggers flush
        let flushed = b.add_barrier(raw_barrier(4));
        assert!(flushed, "5th barrier must trigger flush");
        assert_eq!(b.pending_count(), 0);
        assert_eq!(b.flush_log().len(), 1, "exactly 1 batch flush occurred");
        assert_eq!(b.flush_log()[0].raw_count, 5);
        assert_eq!(b.batched_count(), 5);
    }

    // The pending count restarts after each batch flush.
    #[test]
    fn test_batcher_batched_two_batches() {
        let mut b = BarrierBatcher::new(BarrierStrategy::Batched(3));
        for i in 0..6u64 {
            b.add_barrier(raw_barrier(i));
        }
        assert_eq!(
            b.flush_log().len(),
            2,
            "6 barriers at threshold=3 → 2 flushes"
        );
        assert_eq!(b.batched_count(), 6);
    }

    // Deferred mode only accumulates; nothing is flushed by `add_barrier`.
    #[test]
    fn test_batcher_deferred_does_not_auto_flush() {
        let mut b = BarrierBatcher::new(BarrierStrategy::Deferred);
        for i in 0..10u64 {
            let flushed = b.add_barrier(raw_barrier(i));
            assert!(!flushed, "deferred mode must never auto-flush");
        }
        assert_eq!(b.pending_count(), 10);
        assert_eq!(b.flush_log().len(), 0);
    }

    // An explicit flush drains both queues and reports the total.
    #[test]
    fn test_batcher_manual_flush_clears_pending() {
        let mut b = BarrierBatcher::new(BarrierStrategy::Deferred);
        b.add_barrier(raw_barrier(1));
        b.add_barrier(war_barrier(2));
        assert_eq!(b.pending_count(), 2);
        let flushed_count = b.flush();
        assert_eq!(flushed_count, 2);
        assert_eq!(b.pending_count(), 0);
        assert_eq!(b.batched_count(), 2);
    }

    // Flushing with nothing pending returns 0 and leaves the log empty.
    #[test]
    fn test_batcher_empty_flush_does_nothing() {
        let mut b = BarrierBatcher::new(BarrierStrategy::Deferred);
        let count = b.flush();
        assert_eq!(count, 0, "flush on empty batcher should return 0");
        assert_eq!(
            b.flush_log().len(),
            0,
            "empty flush should not log a record"
        );
    }

    // The flush record keeps separate counts for the two barrier kinds.
    #[test]
    fn test_batcher_mixed_kinds_tracked_separately() {
        let mut b = BarrierBatcher::new(BarrierStrategy::Deferred);
        b.add_barrier(raw_barrier(1));
        b.add_barrier(raw_barrier(2));
        b.add_barrier(war_barrier(3));
        b.flush();
        let record = &b.flush_log()[0];
        assert_eq!(record.raw_count, 2);
        assert_eq!(record.war_count, 1);
    }

    // `strategy()` returns the strategy passed to `new`.
    #[test]
    fn test_batcher_strategy_accessor() {
        let b = BarrierBatcher::new(BarrierStrategy::Batched(8));
        assert_eq!(b.strategy(), BarrierStrategy::Batched(8));
    }

    // The hand-written `Debug` impl names the type in its output.
    #[test]
    fn test_batcher_debug_fmt() {
        let b = BarrierBatcher::new(BarrierStrategy::Eager);
        let s = format!("{b:?}");
        assert!(s.contains("BarrierBatcher"));
    }
}