Skip to main content

oximedia_graph/
processing_graph.rs

1//! Media processing graph with nodes, edges, topological execution ordering,
2//! and retry-with-exponential-backoff for transient node failures.
3//!
4//! This module models a directed acyclic graph (DAG) of media processing nodes.
5//! Nodes represent processing stages (source, filter, encoder, etc.), and edges
6//! represent data flow between them.
7
/// Classification of a processing node.
///
/// The variant fixes the node's nominal connection limits, reported by
/// `NodeType::max_inputs` and `NodeType::max_outputs`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeType {
    /// Produces media data (e.g. file reader, camera capture).
    Source,
    /// Decodes compressed media into raw frames.
    Decoder,
    /// Transforms media data (e.g. scaler, colour converter).
    Filter,
    /// Encodes raw frames into a compressed format.
    Encoder,
    /// Consumes media data (e.g. file writer, display).
    Sink,
    /// Combines multiple input streams into one.
    Mixer,
    /// Distributes one input stream to multiple outputs.
    Splitter,
}
26
27impl NodeType {
28    /// Maximum number of input connections accepted by this node type.
29    pub fn max_inputs(&self) -> usize {
30        match self {
31            Self::Source => 0,
32            Self::Decoder => 1,
33            Self::Filter => 1,
34            Self::Encoder => 1,
35            Self::Sink => 1,
36            Self::Mixer => 8,
37            Self::Splitter => 1,
38        }
39    }
40
41    /// Maximum number of output connections this node type can produce.
42    pub fn max_outputs(&self) -> usize {
43        match self {
44            Self::Source => 1,
45            Self::Decoder => 1,
46            Self::Filter => 1,
47            Self::Encoder => 1,
48            Self::Sink => 0,
49            Self::Mixer => 1,
50            Self::Splitter => 8,
51        }
52    }
53}
54
55/// A single node in a media [`ProcessingGraph`].
56#[derive(Debug, Clone)]
57pub struct GraphNode {
58    /// Unique identifier for this node within the graph.
59    pub id: u64,
60    /// Human-readable name.
61    pub name: String,
62    /// Functional type of this node.
63    pub node_type: NodeType,
64    /// Whether this node should participate in processing.
65    pub enabled: bool,
66    /// Arbitrary key-value configuration parameters.
67    pub params: Vec<(String, String)>,
68}
69
70impl GraphNode {
71    /// Creates a new, enabled node with no parameters.
72    pub fn new(id: u64, name: &str, node_type: NodeType) -> Self {
73        Self {
74            id,
75            name: name.to_string(),
76            node_type,
77            enabled: true,
78            params: Vec::new(),
79        }
80    }
81
82    /// Returns the value for `key`, or `None` if not set.
83    pub fn get_param(&self, key: &str) -> Option<&str> {
84        self.params
85            .iter()
86            .find(|(k, _)| k == key)
87            .map(|(_, v)| v.as_str())
88    }
89
90    /// Sets (or updates) `key` to `value`.
91    pub fn set_param(&mut self, key: &str, value: &str) {
92        if let Some(entry) = self.params.iter_mut().find(|(k, _)| k == key) {
93            entry.1 = value.to_string();
94        } else {
95            self.params.push((key.to_string(), value.to_string()));
96        }
97    }
98}
99
/// A directed connection from an output port of one node to an input port of
/// another.
///
/// Edges are plain values: two edges with the same endpoints and ports compare
/// equal and hash identically, so duplicates can be detected with a set.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GraphEdge {
    /// Source node identifier.
    pub from_node: u64,
    /// Output port index on the source node.
    pub from_port: u32,
    /// Destination node identifier.
    pub to_node: u64,
    /// Input port index on the destination node.
    pub to_port: u32,
}
112
113impl GraphEdge {
114    /// Returns `true` if this edge goes from `from` to `to`.
115    pub fn connects(&self, from: u64, to: u64) -> bool {
116        self.from_node == from && self.to_node == to
117    }
118}
119
120/// A directed acyclic graph of media processing nodes.
121#[derive(Debug, Default)]
122pub struct ProcessingGraph {
123    /// All nodes in the graph.
124    pub nodes: Vec<GraphNode>,
125    /// All edges in the graph.
126    pub edges: Vec<GraphEdge>,
127    /// When `true` the graph is considered executing and hot-swap is refused.
128    pub is_locked: bool,
129}
130
131impl ProcessingGraph {
132    /// Creates an empty processing graph.
133    pub fn new() -> Self {
134        Self::default()
135    }
136
137    /// Adds `node` to the graph.  Duplicate IDs are allowed but discouraged.
138    pub fn add_node(&mut self, node: GraphNode) {
139        self.nodes.push(node);
140    }
141
142    /// Removes the node with `id` and all edges referencing it.
143    ///
144    /// Returns `true` if a node was removed.
145    pub fn remove_node(&mut self, id: u64) -> bool {
146        let before = self.nodes.len();
147        self.nodes.retain(|n| n.id != id);
148        self.edges.retain(|e| e.from_node != id && e.to_node != id);
149        self.nodes.len() < before
150    }
151
152    /// Adds an edge from `(from, from_port)` to `(to, to_port)`.
153    ///
154    /// Returns `false` if either node does not exist; `true` on success.
155    pub fn connect(&mut self, from: u64, from_port: u32, to: u64, to_port: u32) -> bool {
156        let has_from = self.nodes.iter().any(|n| n.id == from);
157        let has_to = self.nodes.iter().any(|n| n.id == to);
158        if !has_from || !has_to {
159            return false;
160        }
161        self.edges.push(GraphEdge {
162            from_node: from,
163            from_port,
164            to_node: to,
165            to_port,
166        });
167        true
168    }
169
170    /// Removes all edges from node `from` to node `to`.
171    ///
172    /// Returns `true` if at least one edge was removed.
173    pub fn disconnect(&mut self, from: u64, to: u64) -> bool {
174        let before = self.edges.len();
175        self.edges.retain(|e| !e.connects(from, to));
176        self.edges.len() < before
177    }
178
179    /// Returns references to all nodes whose type has zero inputs (source nodes).
180    pub fn source_nodes(&self) -> Vec<&GraphNode> {
181        self.nodes
182            .iter()
183            .filter(|n| n.node_type.max_inputs() == 0)
184            .collect()
185    }
186
187    /// Returns references to all nodes whose type has zero outputs (sink nodes).
188    pub fn sink_nodes(&self) -> Vec<&GraphNode> {
189        self.nodes
190            .iter()
191            .filter(|n| n.node_type.max_outputs() == 0)
192            .collect()
193    }
194
195    /// Returns node IDs in topological execution order (Kahn's algorithm).
196    ///
197    /// Nodes not reachable from any source, or that form cycles, are appended
198    /// in arbitrary order at the end.
199    pub fn execution_order(&self) -> Vec<u64> {
200        use std::collections::{HashMap, VecDeque};
201
202        // Count incoming edges per node (enabled nodes only).
203        let mut in_degree: HashMap<u64, usize> = self
204            .nodes
205            .iter()
206            .filter(|n| n.enabled)
207            .map(|n| (n.id, 0))
208            .collect();
209
210        for edge in &self.edges {
211            if in_degree.contains_key(&edge.from_node) && in_degree.contains_key(&edge.to_node) {
212                *in_degree.entry(edge.to_node).or_insert(0) += 1;
213            }
214        }
215
216        // Seed the queue with zero-in-degree nodes.
217        let mut queue: VecDeque<u64> = in_degree
218            .iter()
219            .filter(|(_, &deg)| deg == 0)
220            .map(|(&id, _)| id)
221            .collect();
222
223        // Sort for determinism.
224        let mut queue_vec: Vec<u64> = queue.drain(..).collect();
225        queue_vec.sort_unstable();
226        queue.extend(queue_vec);
227
228        let mut order = Vec::with_capacity(self.nodes.len());
229
230        while let Some(id) = queue.pop_front() {
231            order.push(id);
232            // Find successors and decrement their in-degree.
233            let mut new_ready: Vec<u64> = self
234                .edges
235                .iter()
236                .filter(|e| e.from_node == id)
237                .filter_map(|e| {
238                    let deg = in_degree.get_mut(&e.to_node)?;
239                    *deg = deg.saturating_sub(1);
240                    if *deg == 0 {
241                        Some(e.to_node)
242                    } else {
243                        None
244                    }
245                })
246                .collect();
247            new_ready.sort_unstable();
248            queue.extend(new_ready);
249        }
250
251        // Append any remaining nodes (disabled or cycle members) in id order.
252        let mut remaining: Vec<u64> = self
253            .nodes
254            .iter()
255            .map(|n| n.id)
256            .filter(|id| !order.contains(id))
257            .collect();
258        remaining.sort_unstable();
259        order.extend(remaining);
260
261        order
262    }
263}
264
265// ─────────────────────────────────────────────────────────────────────────────
266// Retry policy for transient node failures
267// ─────────────────────────────────────────────────────────────────────────────
268
/// Classifies an error as *transient* (worth retrying, such as a temporary
/// resource contention or a momentary decode stall) or permanent.
///
/// Implement this on your custom error type and return it from a node's
/// process function to enable automatic retry via [`RetryPolicy`]. Its
/// `execute` method retries only errors for which
/// [`is_transient`](TransientError::is_transient) returns `true`, and
/// returns any other error immediately.
pub trait TransientError {
    /// Returns `true` when the error represents a transient condition.
    fn is_transient(&self) -> bool;
}
279
/// Policy controlling how many times a failing node operation is attempted
/// and how long to wait between attempts.
///
/// The wait uses *exponential backoff*. After the `k`-th failed attempt
/// (1-indexed), and before the next one, the policy sleeps for
/// `backoff_ms * 2^(k-1)` milliseconds. So the first retry waits
/// `backoff_ms`, the second `2 * backoff_ms`, and so on. No sleep follows the
/// final attempt.
///
/// # Example
/// ```
/// use oximedia_graph::processing_graph::RetryPolicy;
///
/// // Three attempts total, starting with a 10 ms back-off.
/// let policy = RetryPolicy { max_attempts: 3, backoff_ms: 10 };
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryPolicy {
    /// Maximum number of attempts, including the first. A value of `1`
    /// means "try once; do not retry".
    pub max_attempts: u32,
    /// Base back-off in milliseconds. The sleep before 0-indexed attempt `n`
    /// (for `n >= 1`) is `backoff_ms * 2^(n-1)` ms.
    pub backoff_ms: u64,
}
303
304impl Default for RetryPolicy {
305    fn default() -> Self {
306        Self {
307            max_attempts: 3,
308            backoff_ms: 1,
309        }
310    }
311}
312
313impl RetryPolicy {
314    /// Executes `f` up to `max_attempts` times, returning the first `Ok(T)`.
315    ///
316    /// If `f` returns an `Err(e)` where `e.is_transient()` returns `true` and
317    /// there are remaining attempts, the function sleeps for an exponentially
318    /// increasing duration and then calls `f` again.
319    ///
320    /// If `f` returns an `Err(e)` where `e.is_transient()` returns `false`,
321    /// the error is returned immediately (no further retries are made).
322    ///
323    /// If all `max_attempts` are exhausted the last error is returned.
324    pub fn execute<T, E, F>(&self, mut f: F) -> Result<T, E>
325    where
326        E: TransientError,
327        F: FnMut() -> Result<T, E>,
328    {
329        let mut last_err: Option<E> = None;
330        for attempt in 0..self.max_attempts {
331            match f() {
332                Ok(v) => return Ok(v),
333                Err(e) => {
334                    if !e.is_transient() {
335                        return Err(e);
336                    }
337                    // Sleep with exponential back-off before the next attempt.
338                    if attempt + 1 < self.max_attempts {
339                        let sleep_ms = self.backoff_ms.saturating_mul(1u64 << attempt);
340                        std::thread::sleep(std::time::Duration::from_millis(sleep_ms));
341                    }
342                    last_err = Some(e);
343                }
344            }
345        }
346        // `max_attempts >= 1` and we enter at least one iteration, so
347        // `last_err` is always `Some` here.
348        Err(last_err.expect("at least one attempt must have set last_err"))
349    }
350}
351
352// ─────────────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;

    fn source(id: u64) -> GraphNode {
        GraphNode::new(id, &format!("source_{id}"), NodeType::Source)
    }
    fn filter(id: u64) -> GraphNode {
        GraphNode::new(id, &format!("filter_{id}"), NodeType::Filter)
    }
    fn sink(id: u64) -> GraphNode {
        GraphNode::new(id, &format!("sink_{id}"), NodeType::Sink)
    }

    // ── NodeType ─────────────────────────────────────────────────────────────

    #[test]
    fn source_has_zero_inputs() {
        assert_eq!(NodeType::Source.max_inputs(), 0);
    }

    #[test]
    fn sink_has_zero_outputs() {
        assert_eq!(NodeType::Sink.max_outputs(), 0);
    }

    #[test]
    fn mixer_accepts_multiple_inputs() {
        assert!(NodeType::Mixer.max_inputs() > 1);
    }

    #[test]
    fn splitter_produces_multiple_outputs() {
        assert!(NodeType::Splitter.max_outputs() > 1);
    }

    // ── GraphNode ────────────────────────────────────────────────────────────

    #[test]
    fn node_set_and_get_param() {
        let mut n = filter(1);
        n.set_param("width", "1920");
        assert_eq!(n.get_param("width"), Some("1920"));
    }

    #[test]
    fn node_update_existing_param() {
        let mut n = filter(2);
        n.set_param("fps", "24");
        n.set_param("fps", "60");
        assert_eq!(n.get_param("fps"), Some("60"));
        // Only one entry for the key.
        assert_eq!(n.params.iter().filter(|(k, _)| k == "fps").count(), 1);
    }

    #[test]
    fn node_missing_param_returns_none() {
        let n = source(3);
        assert!(n.get_param("nonexistent").is_none());
    }

    // ── GraphEdge ────────────────────────────────────────────────────────────

    #[test]
    fn edge_connects_returns_true_for_matching_pair() {
        let edge = GraphEdge {
            from_node: 1,
            from_port: 0,
            to_node: 2,
            to_port: 0,
        };
        assert!(edge.connects(1, 2));
    }

    #[test]
    fn edge_connects_returns_false_for_reversed_pair() {
        let edge = GraphEdge {
            from_node: 1,
            from_port: 0,
            to_node: 2,
            to_port: 0,
        };
        assert!(!edge.connects(2, 1));
    }

    // ── ProcessingGraph ───────────────────────────────────────────────────────

    #[test]
    fn add_and_remove_node() {
        let mut g = ProcessingGraph::new();
        g.add_node(source(10));
        assert_eq!(g.nodes.len(), 1);
        assert!(g.remove_node(10));
        assert!(g.nodes.is_empty());
    }

    #[test]
    fn remove_node_also_removes_edges() {
        let mut g = ProcessingGraph::new();
        g.add_node(source(1));
        g.add_node(sink(2));
        g.connect(1, 0, 2, 0);
        g.remove_node(1);
        assert!(g.edges.is_empty());
    }

    #[test]
    fn connect_fails_for_missing_node() {
        let mut g = ProcessingGraph::new();
        g.add_node(source(1));
        assert!(!g.connect(1, 0, 99, 0)); // node 99 missing
    }

    #[test]
    fn disconnect_removes_all_matching_edges() {
        let mut g = ProcessingGraph::new();
        g.add_node(source(1));
        g.add_node(sink(2));
        g.connect(1, 0, 2, 0);
        g.connect(1, 0, 2, 1);
        assert!(g.disconnect(1, 2));
        assert!(g.edges.is_empty());
    }

    #[test]
    fn source_nodes_returns_only_sources() {
        let mut g = ProcessingGraph::new();
        g.add_node(source(1));
        g.add_node(filter(2));
        g.add_node(sink(3));
        let srcs: Vec<u64> = g.source_nodes().into_iter().map(|n| n.id).collect();
        assert_eq!(srcs, vec![1]);
    }

    #[test]
    fn sink_nodes_returns_only_sinks() {
        let mut g = ProcessingGraph::new();
        g.add_node(source(1));
        g.add_node(sink(2));
        let sinks: Vec<u64> = g.sink_nodes().into_iter().map(|n| n.id).collect();
        assert_eq!(sinks, vec![2]);
    }

    #[test]
    fn execution_order_linear_pipeline() {
        // source(1) -> filter(2) -> sink(3)
        let mut g = ProcessingGraph::new();
        g.add_node(source(1));
        g.add_node(filter(2));
        g.add_node(sink(3));
        g.connect(1, 0, 2, 0);
        g.connect(2, 0, 3, 0);
        let order = g.execution_order();
        assert_eq!(order, vec![1, 2, 3]);
    }

    #[test]
    fn execution_order_independent_nodes_are_included() {
        let mut g = ProcessingGraph::new();
        g.add_node(source(1));
        g.add_node(source(2));
        let order = g.execution_order();
        assert_eq!(order.len(), 2);
    }

    #[test]
    fn execution_order_empty_graph_is_empty() {
        let g = ProcessingGraph::new();
        assert!(g.execution_order().is_empty());
    }

    #[test]
    fn execution_order_diamond_releases_batch_in_id_order() {
        // 1 -> {2, 3} -> 4, with the 1 -> 3 edge added before 1 -> 2.
        let mut g = ProcessingGraph::new();
        g.add_node(source(1));
        g.add_node(filter(2));
        g.add_node(filter(3));
        g.add_node(sink(4));
        g.connect(1, 0, 3, 0);
        g.connect(1, 0, 2, 0);
        g.connect(2, 0, 4, 0);
        g.connect(3, 0, 4, 1);
        assert_eq!(g.execution_order(), vec![1, 2, 3, 4]);
    }

    #[test]
    fn execution_order_appends_cycle_members_and_disabled_nodes_in_id_order() {
        // 1 -> 2 <-> 3 -> 4, plus a disabled, unconnected node 9.
        let mut g = ProcessingGraph::new();
        g.add_node(source(1));
        g.add_node(filter(3));
        g.add_node(filter(2));
        g.add_node(sink(4));
        let mut off = filter(9);
        off.enabled = false;
        g.add_node(off);
        g.connect(1, 0, 2, 0);
        g.connect(2, 0, 3, 0);
        g.connect(3, 0, 2, 1);
        g.connect(3, 0, 4, 0);
        assert_eq!(g.execution_order(), vec![1, 2, 3, 4, 9]);
    }

    #[test]
    fn execution_order_ignores_edges_from_disabled_nodes() {
        let mut g = ProcessingGraph::new();
        let mut off = source(1);
        off.enabled = false;
        g.add_node(off);
        g.add_node(sink(2));
        g.connect(1, 0, 2, 0);
        // The sink is not held back by its disabled upstream node, which is
        // listed after the sorted part.
        assert_eq!(g.execution_order(), vec![2, 1]);
    }

    // ── RetryPolicy ───────────────────────────────────────────────────────────

    /// A simple error type used in retry tests.
    #[derive(Debug, PartialEq, Eq)]
    enum TestError {
        Transient(String),
        Fatal(String),
    }

    impl TransientError for TestError {
        fn is_transient(&self) -> bool {
            matches!(self, Self::Transient(_))
        }
    }

    #[test]
    fn test_retry_default_policy() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.max_attempts, 3);
        assert_eq!(policy.backoff_ms, 1);
    }

    #[test]
    fn test_retry_succeeds_on_second_attempt() {
        let call_count = std::cell::Cell::new(0u32);
        let policy = RetryPolicy {
            max_attempts: 3,
            backoff_ms: 0,
        };

        let result = policy.execute(|| {
            let n = call_count.get();
            call_count.set(n + 1);
            if n == 0 {
                // First call fails with a transient error.
                Err(TestError::Transient("temp".to_string()))
            } else {
                // Second call succeeds.
                Ok(42u32)
            }
        });

        assert_eq!(result, Ok(42));
        assert_eq!(call_count.get(), 2, "should have been called exactly twice");
    }

    #[test]
    fn test_retry_exhausted() {
        let call_count = std::cell::Cell::new(0u32);
        let policy = RetryPolicy {
            max_attempts: 3,
            backoff_ms: 0,
        };

        let result: Result<u32, TestError> = policy.execute(|| {
            call_count.set(call_count.get() + 1);
            Err(TestError::Transient("always fails".to_string()))
        });

        assert!(result.is_err(), "all attempts exhausted; must return Err");
        assert_eq!(
            call_count.get(),
            3,
            "execute must invoke f exactly max_attempts times"
        );
    }

    #[test]
    fn test_retry_exhausted_returns_last_error() {
        let call_count = std::cell::Cell::new(0u32);
        let policy = RetryPolicy {
            max_attempts: 3,
            backoff_ms: 0,
        };

        let result: Result<u32, TestError> = policy.execute(|| {
            let n = call_count.get();
            call_count.set(n + 1);
            Err(TestError::Transient(format!("attempt {n}")))
        });

        assert_eq!(result, Err(TestError::Transient("attempt 2".to_string())));
    }

    #[test]
    fn test_retry_single_attempt_does_not_retry() {
        let call_count = std::cell::Cell::new(0u32);
        let policy = RetryPolicy {
            max_attempts: 1,
            backoff_ms: 0,
        };

        let result: Result<u32, TestError> = policy.execute(|| {
            call_count.set(call_count.get() + 1);
            Err(TestError::Transient("once".to_string()))
        });

        assert_eq!(result, Err(TestError::Transient("once".to_string())));
        assert_eq!(call_count.get(), 1, "max_attempts = 1 means no retry");
    }

    #[test]
    fn test_retry_fatal_error_no_retry() {
        let call_count = std::cell::Cell::new(0u32);
        let policy = RetryPolicy {
            max_attempts: 5,
            backoff_ms: 0,
        };

        let result: Result<u32, TestError> = policy.execute(|| {
            call_count.set(call_count.get() + 1);
            Err(TestError::Fatal("unrecoverable".to_string()))
        });

        assert!(result.is_err());
        // Fatal error on the very first call must prevent further retries.
        assert_eq!(
            call_count.get(),
            1,
            "fatal error must halt retries immediately"
        );
    }
}