Skip to main content

oximedia_graph/
pipeline_graph.rs

1//! Pipeline-centric graph for media processing.
2#![allow(dead_code)]
3
4use std::collections::{HashMap, HashSet, VecDeque};
5
/// Type of a node in a pipeline graph.
///
/// The type is a small `Copy` value; it also derives `Hash` so node types can
/// be used as keys in hash-based collections (e.g. counting nodes per type).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PipelineNodeType {
    /// Produces frames but has no upstream inputs (e.g., a decoder).
    Source,
    /// Transforms frames (e.g., scaler, colour corrector).
    Transform,
    /// Merges multiple streams into one (e.g., overlay, mix).
    Merge,
    /// Splits one stream into multiple (e.g., tee).
    Split,
    /// Consumes frames and produces no outputs (e.g., encoder, display sink).
    Sink,
}
20
21impl PipelineNodeType {
22    /// Returns `true` for node types that terminate the pipeline (no outputs).
23    #[must_use]
24    pub fn is_terminal(&self) -> bool {
25        matches!(self, Self::Sink)
26    }
27
28    /// Returns `true` for node types that originate data (no required inputs).
29    #[must_use]
30    pub fn is_source(&self) -> bool {
31        matches!(self, Self::Source)
32    }
33}
34
/// Lightweight handle identifying a node inside a [`PipelineGraph`].
///
/// Ids compare by their numeric value (derived `Ord`), so they can be sorted
/// or used as keys in ordered collections such as `BTreeMap`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct PipelineNodeId(pub u32);
38
39/// Metadata stored per node.
40#[derive(Debug, Clone)]
41struct NodeMeta {
42    node_type: PipelineNodeType,
43    label: String,
44}
45
46/// A directed acyclic graph of pipeline nodes.
47pub struct PipelineGraph {
48    nodes: HashMap<PipelineNodeId, NodeMeta>,
49    /// Adjacency list: `edges[a]` = set of nodes that `a` feeds into.
50    edges: HashMap<PipelineNodeId, HashSet<PipelineNodeId>>,
51    next_id: u32,
52}
53
54impl PipelineGraph {
55    /// Create an empty graph.
56    #[must_use]
57    pub fn new() -> Self {
58        Self {
59            nodes: HashMap::new(),
60            edges: HashMap::new(),
61            next_id: 0,
62        }
63    }
64
65    /// Add a new node and return its id.
66    pub fn add_node(
67        &mut self,
68        node_type: PipelineNodeType,
69        label: impl Into<String>,
70    ) -> PipelineNodeId {
71        let id = PipelineNodeId(self.next_id);
72        self.next_id += 1;
73        self.nodes.insert(
74            id,
75            NodeMeta {
76                node_type,
77                label: label.into(),
78            },
79        );
80        self.edges.entry(id).or_default();
81        id
82    }
83
84    /// Connect `from` → `to`.
85    ///
86    /// Returns an error string if either node does not exist.
87    pub fn connect(&mut self, from: PipelineNodeId, to: PipelineNodeId) -> Result<(), String> {
88        if !self.nodes.contains_key(&from) {
89            return Err(format!("Source node {:?} not found", from));
90        }
91        if !self.nodes.contains_key(&to) {
92            return Err(format!("Destination node {:?} not found", to));
93        }
94        self.edges.entry(from).or_default().insert(to);
95        // Ensure destination also has an entry in the adjacency list.
96        self.edges.entry(to).or_default();
97        Ok(())
98    }
99
100    /// Returns `true` if the graph contains no directed cycles (is a valid DAG).
101    #[must_use]
102    pub fn is_valid_dag(&self) -> bool {
103        // Kahn's algorithm
104        let mut in_degree: HashMap<PipelineNodeId, usize> =
105            self.nodes.keys().map(|&k| (k, 0)).collect();
106        for neighbours in self.edges.values() {
107            for &n in neighbours {
108                *in_degree.entry(n).or_insert(0) += 1;
109            }
110        }
111        let mut queue: VecDeque<PipelineNodeId> = in_degree
112            .iter()
113            .filter(|(_, &d)| d == 0)
114            .map(|(&k, _)| k)
115            .collect();
116        let mut visited = 0usize;
117        while let Some(node) = queue.pop_front() {
118            visited += 1;
119            if let Some(neighbours) = self.edges.get(&node) {
120                for &n in neighbours {
121                    let deg = in_degree.entry(n).or_insert(0);
122                    *deg -= 1;
123                    if *deg == 0 {
124                        queue.push_back(n);
125                    }
126                }
127            }
128        }
129        visited == self.nodes.len()
130    }
131
132    /// Returns the ids of all source nodes (nodes with no incoming edges).
133    #[must_use]
134    pub fn sources(&self) -> Vec<PipelineNodeId> {
135        let mut has_incoming: HashSet<PipelineNodeId> = HashSet::new();
136        for neighbours in self.edges.values() {
137            for &n in neighbours {
138                has_incoming.insert(n);
139            }
140        }
141        self.nodes
142            .keys()
143            .copied()
144            .filter(|id| !has_incoming.contains(id))
145            .collect()
146    }
147
148    /// Returns the ids of all sink nodes (nodes with no outgoing edges).
149    #[must_use]
150    pub fn sinks(&self) -> Vec<PipelineNodeId> {
151        self.edges
152            .iter()
153            .filter(|(id, neighbours)| self.nodes.contains_key(id) && neighbours.is_empty())
154            .map(|(&id, _)| id)
155            .collect()
156    }
157
158    /// Total number of nodes.
159    #[must_use]
160    pub fn node_count(&self) -> usize {
161        self.nodes.len()
162    }
163
164    /// Total number of directed edges.
165    #[must_use]
166    pub fn edge_count(&self) -> usize {
167        self.edges.values().map(|s| s.len()).sum()
168    }
169
170    /// Look up the type of a node.
171    #[must_use]
172    pub fn node_type(&self, id: PipelineNodeId) -> Option<PipelineNodeType> {
173        self.nodes.get(&id).map(|m| m.node_type)
174    }
175
176    /// Look up the label of a node.
177    #[must_use]
178    pub fn node_label(&self, id: PipelineNodeId) -> Option<&str> {
179        self.nodes.get(&id).map(|m| m.label.as_str())
180    }
181}
182
183impl Default for PipelineGraph {
184    fn default() -> Self {
185        Self::new()
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the smallest useful pipeline, `src -> sink`, returning the graph
    /// together with the ids of its two nodes.
    fn linear_pair() -> (PipelineGraph, PipelineNodeId, PipelineNodeId) {
        let mut graph = PipelineGraph::new();
        let head = graph.add_node(PipelineNodeType::Source, "src");
        let tail = graph.add_node(PipelineNodeType::Sink, "sink");
        link(&mut graph, head, tail);
        (graph, head, tail)
    }

    /// Adds the edge `from -> to`, failing the test if the graph rejects it.
    fn link(graph: &mut PipelineGraph, from: PipelineNodeId, to: PipelineNodeId) {
        graph.connect(from, to).expect("connect should succeed");
    }

    // --- PipelineNodeType tests ---

    #[test]
    fn test_sink_is_terminal() {
        let kind = PipelineNodeType::Sink;
        assert!(kind.is_terminal());
    }

    #[test]
    fn test_transform_not_terminal() {
        let kind = PipelineNodeType::Transform;
        assert!(!kind.is_terminal());
    }

    #[test]
    fn test_source_is_source() {
        let kind = PipelineNodeType::Source;
        assert!(kind.is_source());
    }

    #[test]
    fn test_merge_not_source() {
        let kind = PipelineNodeType::Merge;
        assert!(!kind.is_source());
    }

    // --- PipelineGraph tests ---

    #[test]
    fn test_add_node_increments_count() {
        let mut graph = PipelineGraph::new();
        let _ = graph.add_node(PipelineNodeType::Source, "s");
        let _ = graph.add_node(PipelineNodeType::Sink, "k");
        assert_eq!(graph.node_count(), 2);
    }

    #[test]
    fn test_connect_increments_edge_count() {
        let (graph, _src, _sink) = linear_pair();
        assert_eq!(graph.edge_count(), 1);
    }

    #[test]
    fn test_connect_nonexistent_node_returns_error() {
        let mut graph = PipelineGraph::new();
        let real = graph.add_node(PipelineNodeType::Source, "s");
        let missing = PipelineNodeId(999);
        let outcome = graph.connect(real, missing);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_is_valid_dag_linear() {
        let (graph, ..) = linear_pair();
        assert!(graph.is_valid_dag());
    }

    #[test]
    fn test_is_valid_dag_cycle_detected() {
        let mut graph = PipelineGraph::new();
        let a = graph.add_node(PipelineNodeType::Source, "a");
        let b = graph.add_node(PipelineNodeType::Transform, "b");
        link(&mut graph, a, b);
        // Closing the loop b -> a makes the graph cyclic.
        link(&mut graph, b, a);
        assert!(!graph.is_valid_dag());
    }

    #[test]
    fn test_sources_returns_nodes_without_incoming() {
        let (graph, src, _) = linear_pair();
        let found = graph.sources();
        assert_eq!(found.len(), 1);
        assert!(found.contains(&src));
    }

    #[test]
    fn test_sinks_returns_nodes_without_outgoing() {
        let (graph, _, sink) = linear_pair();
        let found = graph.sinks();
        assert_eq!(found.len(), 1);
        assert!(found.contains(&sink));
    }

    #[test]
    fn test_node_type_query() {
        let (graph, src, _) = linear_pair();
        assert_eq!(graph.node_type(src), Some(PipelineNodeType::Source));
    }

    #[test]
    fn test_node_label_query() {
        let (graph, src, _) = linear_pair();
        assert_eq!(graph.node_label(src), Some("src"));
    }

    #[test]
    fn test_diamond_graph_is_valid_dag() {
        let mut graph = PipelineGraph::new();
        let a = graph.add_node(PipelineNodeType::Source, "a");
        let b = graph.add_node(PipelineNodeType::Transform, "b");
        let c = graph.add_node(PipelineNodeType::Transform, "c");
        let d = graph.add_node(PipelineNodeType::Sink, "d");
        // `a` fans out to `b` and `c`, which both converge on `d`.
        for (from, to) in [(a, b), (a, c), (b, d), (c, d)] {
            link(&mut graph, from, to);
        }
        assert!(graph.is_valid_dag());
    }

    #[test]
    fn test_multiple_sources_and_sinks() {
        let mut graph = PipelineGraph::new();
        let s1 = graph.add_node(PipelineNodeType::Source, "s1");
        let s2 = graph.add_node(PipelineNodeType::Source, "s2");
        let m = graph.add_node(PipelineNodeType::Merge, "m");
        let k1 = graph.add_node(PipelineNodeType::Sink, "k1");
        let k2 = graph.add_node(PipelineNodeType::Sink, "k2");
        // Two sources merge into `m`, which then feeds two sinks.
        for (from, to) in [(s1, m), (s2, m), (m, k1), (m, k2)] {
            link(&mut graph, from, to);
        }
        assert_eq!(graph.sources().len(), 2);
        assert_eq!(graph.sinks().len(), 2);
    }

    #[test]
    fn test_empty_graph_is_valid_dag() {
        assert!(PipelineGraph::new().is_valid_dag());
    }
}