Skip to main content

oximedia_graph/
graph.rs

//! Filter graph builder and execution.
//!
//! The filter graph connects nodes together to form a processing pipeline.
//! Use [`GraphBuilder`] to construct graphs: its type-state parameter makes
//! some invalid sequences, such as building a graph with no nodes, fail to
//! compile.
5
6use std::collections::{HashMap, HashSet, VecDeque};
7
8use crate::error::{GraphError, GraphResult};
9use crate::frame::FilterFrame;
10use crate::node::{Node, NodeId, NodeRuntime, NodeState, NodeType};
11use crate::port::{Connection, PortId};
12
13/// A filter graph that processes media through connected nodes.
14#[allow(dead_code)]
15pub struct FilterGraph {
16    /// Nodes in the graph indexed by ID.
17    nodes: HashMap<NodeId, NodeRuntime>,
18    /// Connections between nodes.
19    connections: Vec<Connection>,
20    /// Topologically sorted node order for execution.
21    execution_order: Vec<NodeId>,
22    /// Source nodes (entry points).
23    source_nodes: Vec<NodeId>,
24    /// Sink nodes (exit points).
25    sink_nodes: Vec<NodeId>,
26    /// Next available node ID.
27    next_id: u64,
28}
29
30impl FilterGraph {
31    /// Create a new empty filter graph.
32    #[must_use]
33    pub fn new() -> Self {
34        Self {
35            nodes: HashMap::new(),
36            connections: Vec::new(),
37            execution_order: Vec::new(),
38            source_nodes: Vec::new(),
39            sink_nodes: Vec::new(),
40            next_id: 0,
41        }
42    }
43
44    /// Create a new graph builder.
45    #[must_use]
46    pub fn builder() -> GraphBuilder<Empty> {
47        GraphBuilder::new()
48    }
49
50    /// Get a node by ID.
51    #[must_use]
52    pub fn node(&self, id: NodeId) -> Option<&dyn Node> {
53        self.nodes.get(&id).map(|r| r.node())
54    }
55
56    /// Get a mutable node by ID.
57    pub fn node_mut(&mut self, id: NodeId) -> Option<&mut dyn Node> {
58        self.nodes.get_mut(&id).map(|r| r.node_mut())
59    }
60
61    /// Get all node IDs.
62    #[must_use]
63    pub fn node_ids(&self) -> Vec<NodeId> {
64        self.nodes.keys().copied().collect()
65    }
66
67    /// Get the execution order.
68    #[must_use]
69    pub fn execution_order(&self) -> &[NodeId] {
70        &self.execution_order
71    }
72
73    /// Get source nodes.
74    #[must_use]
75    pub fn source_nodes(&self) -> &[NodeId] {
76        &self.source_nodes
77    }
78
79    /// Get sink nodes.
80    #[must_use]
81    pub fn sink_nodes(&self) -> &[NodeId] {
82        &self.sink_nodes
83    }
84
85    /// Get connections.
86    #[must_use]
87    pub fn connections(&self) -> &[Connection] {
88        &self.connections
89    }
90
91    /// Check if the graph is empty.
92    #[must_use]
93    pub fn is_empty(&self) -> bool {
94        self.nodes.is_empty()
95    }
96
97    /// Get the number of nodes.
98    #[must_use]
99    pub fn node_count(&self) -> usize {
100        self.nodes.len()
101    }
102
103    /// Initialize all nodes for processing.
104    pub fn initialize(&mut self) -> GraphResult<()> {
105        for id in &self.execution_order.clone() {
106            if let Some(runtime) = self.nodes.get_mut(id) {
107                runtime.node_mut().initialize()?;
108            }
109        }
110        Ok(())
111    }
112
113    /// Process one step of the graph.
114    ///
115    /// Processes nodes in topological order, passing frames through connections.
116    pub fn process_step(&mut self) -> GraphResult<bool> {
117        let mut processed_any = false;
118
119        for id in self.execution_order.clone() {
120            let runtime = self
121                .nodes
122                .get_mut(&id)
123                .ok_or(GraphError::NodeNotFound(id))?;
124
125            // Skip nodes that are done
126            if runtime.node().state() == NodeState::Done {
127                continue;
128            }
129
130            // Process the node
131            runtime.node_mut().set_state(NodeState::Processing)?;
132            runtime.process()?;
133            runtime.node_mut().set_state(NodeState::Idle)?;
134            processed_any = true;
135
136            // Transfer outputs to connected inputs
137            for conn in &self.connections.clone() {
138                if conn.from_node == id {
139                    // Get output from source
140                    let frame = {
141                        let source = self
142                            .nodes
143                            .get_mut(&conn.from_node)
144                            .ok_or(GraphError::NodeNotFound(conn.from_node))?;
145                        source.pop_output(conn.from_port)?
146                    };
147
148                    // Push to destination if we have a frame
149                    if let Some(frame) = frame {
150                        let dest = self
151                            .nodes
152                            .get_mut(&conn.to_node)
153                            .ok_or(GraphError::NodeNotFound(conn.to_node))?;
154                        dest.push_input(conn.to_port, frame)?;
155                    }
156                }
157            }
158        }
159
160        Ok(processed_any)
161    }
162
163    /// Push a frame to a source node.
164    pub fn push_frame(
165        &mut self,
166        node_id: NodeId,
167        port: PortId,
168        frame: FilterFrame,
169    ) -> GraphResult<()> {
170        let runtime = self
171            .nodes
172            .get_mut(&node_id)
173            .ok_or(GraphError::NodeNotFound(node_id))?;
174        runtime.push_input(port, frame)
175    }
176
177    /// Pull a frame from a sink node.
178    pub fn pull_frame(
179        &mut self,
180        node_id: NodeId,
181        port: PortId,
182    ) -> GraphResult<Option<FilterFrame>> {
183        let runtime = self
184            .nodes
185            .get_mut(&node_id)
186            .ok_or(GraphError::NodeNotFound(node_id))?;
187        runtime.pop_output(port)
188    }
189
190    /// Reset all nodes to initial state.
191    pub fn reset(&mut self) -> GraphResult<()> {
192        for runtime in self.nodes.values_mut() {
193            runtime.node_mut().reset()?;
194        }
195        Ok(())
196    }
197
198    /// Flush all nodes.
199    pub fn flush(&mut self) -> GraphResult<Vec<FilterFrame>> {
200        let mut frames = Vec::new();
201
202        for id in &self.execution_order.clone() {
203            if let Some(runtime) = self.nodes.get_mut(id) {
204                let flushed = runtime.node_mut().flush()?;
205                frames.extend(flushed);
206            }
207        }
208
209        Ok(frames)
210    }
211
212    /// Add a node to the graph (internal).
213    fn add_node_internal(&mut self, node: Box<dyn Node>) -> NodeId {
214        let id = NodeId(self.next_id);
215        self.next_id += 1;
216
217        // Classify node type
218        match node.node_type() {
219            NodeType::Source => self.source_nodes.push(id),
220            NodeType::Sink => self.sink_nodes.push(id),
221            NodeType::Filter => {}
222        }
223
224        self.nodes.insert(id, NodeRuntime::new(node));
225        id
226    }
227
228    /// Add a connection between nodes (internal).
229    fn add_connection_internal(&mut self, connection: Connection) -> GraphResult<()> {
230        // Verify nodes exist
231        if !self.nodes.contains_key(&connection.from_node) {
232            return Err(GraphError::NodeNotFound(connection.from_node));
233        }
234        if !self.nodes.contains_key(&connection.to_node) {
235            return Err(GraphError::NodeNotFound(connection.to_node));
236        }
237
238        // Check for duplicate connections
239        if self.connections.contains(&connection) {
240            return Err(GraphError::ConnectionExists {
241                from_node: connection.from_node,
242                from_port: connection.from_port,
243                to_node: connection.to_node,
244                to_port: connection.to_port,
245            });
246        }
247
248        // Verify ports exist and formats are compatible
249        {
250            let from_node = self
251                .nodes
252                .get(&connection.from_node)
253                .ok_or(GraphError::NodeNotFound(connection.from_node))?;
254            let to_node = self
255                .nodes
256                .get(&connection.to_node)
257                .ok_or(GraphError::NodeNotFound(connection.to_node))?;
258
259            let from_port = from_node.node().output_port(connection.from_port).ok_or(
260                GraphError::PortNotFound {
261                    node: connection.from_node,
262                    port: connection.from_port,
263                },
264            )?;
265
266            let to_port =
267                to_node
268                    .node()
269                    .input_port(connection.to_port)
270                    .ok_or(GraphError::PortNotFound {
271                        node: connection.to_node,
272                        port: connection.to_port,
273                    })?;
274
275            // Check port type compatibility
276            if from_port.port_type != to_port.port_type {
277                return Err(GraphError::PortTypeMismatch {
278                    expected: format!("{:?}", to_port.port_type),
279                    actual: format!("{:?}", from_port.port_type),
280                });
281            }
282
283            // Check format compatibility
284            if !from_port.format.is_compatible(&to_port.format) {
285                return Err(GraphError::IncompatibleFormats {
286                    source_format: format!("{}", from_port.format),
287                    dest_format: format!("{}", to_port.format),
288                });
289            }
290        }
291
292        self.connections.push(connection);
293        Ok(())
294    }
295
296    /// Compute topological sort for execution order.
297    fn compute_execution_order(&mut self) -> GraphResult<()> {
298        let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
299        let mut adjacency: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
300
301        // Initialize
302        for &id in self.nodes.keys() {
303            in_degree.insert(id, 0);
304            adjacency.insert(id, Vec::new());
305        }
306
307        // Build adjacency list and in-degrees
308        for conn in &self.connections {
309            adjacency
310                .get_mut(&conn.from_node)
311                .ok_or(GraphError::NodeNotFound(conn.from_node))?
312                .push(conn.to_node);
313            *in_degree
314                .get_mut(&conn.to_node)
315                .ok_or(GraphError::NodeNotFound(conn.to_node))? += 1;
316        }
317
318        // Kahn's algorithm
319        let mut queue: VecDeque<NodeId> = in_degree
320            .iter()
321            .filter(|(_, &deg)| deg == 0)
322            .map(|(&id, _)| id)
323            .collect();
324
325        let mut order = Vec::new();
326
327        while let Some(id) = queue.pop_front() {
328            order.push(id);
329
330            let neighbors: Vec<NodeId> = adjacency
331                .get(&id)
332                .ok_or(GraphError::NodeNotFound(id))?
333                .clone();
334            for neighbor in neighbors {
335                let deg = in_degree
336                    .get_mut(&neighbor)
337                    .ok_or(GraphError::NodeNotFound(neighbor))?;
338                *deg -= 1;
339                if *deg == 0 {
340                    queue.push_back(neighbor);
341                }
342            }
343        }
344
345        // Check for cycle
346        if order.len() != self.nodes.len() {
347            // Find a node that's part of the cycle
348            let cycle_node = in_degree
349                .iter()
350                .find(|(_, &deg)| deg > 0)
351                .map_or(NodeId(0), |(&id, _)| id);
352            return Err(GraphError::CycleDetected(cycle_node));
353        }
354
355        self.execution_order = order;
356        Ok(())
357    }
358
359    /// Validate the graph configuration.
360    fn validate(&self) -> GraphResult<()> {
361        if self.nodes.is_empty() {
362            return Err(GraphError::EmptyGraph);
363        }
364
365        if self.source_nodes.is_empty() {
366            return Err(GraphError::NoSourceNodes);
367        }
368
369        if self.sink_nodes.is_empty() {
370            return Err(GraphError::NoSinkNodes);
371        }
372
373        // Check all required inputs are connected
374        for (id, runtime) in &self.nodes {
375            for input in runtime.node().inputs() {
376                if input.required {
377                    let connected = self
378                        .connections
379                        .iter()
380                        .any(|c| c.to_node == *id && c.to_port == input.id);
381                    if !connected && runtime.node().node_type() != NodeType::Source {
382                        return Err(GraphError::ConfigurationError(format!(
383                            "Required input '{}' on node {:?} is not connected",
384                            input.name, id
385                        )));
386                    }
387                }
388            }
389        }
390
391        Ok(())
392    }
393}
394
395impl Default for FilterGraph {
396    fn default() -> Self {
397        Self::new()
398    }
399}
400
// Type-state markers for the builder. They are zero-sized and only appear as
// the `State` parameter of `GraphBuilder`.

/// Builder state: no nodes have been added yet.
#[derive(Debug, Clone, Copy, Default)]
pub struct Empty;

/// Builder state: the graph has at least one node.
#[derive(Debug, Clone, Copy, Default)]
pub struct HasNodes;

/// Builder state: the graph has at least one connection.
#[derive(Debug, Clone, Copy, Default)]
pub struct HasConnections;

/// Builder state: the graph is ready to build.
///
/// NOTE(review): no `GraphBuilder<Ready>` impl exists in this module, so this
/// marker is currently unused by the builder.
#[derive(Debug, Clone, Copy, Default)]
pub struct Ready;
410
411/// Builder for constructing filter graphs with type-state pattern.
412///
413/// The builder ensures that graphs are constructed correctly:
414/// 1. Add nodes
415/// 2. Connect nodes
416/// 3. Build the graph
417pub struct GraphBuilder<State> {
418    graph: FilterGraph,
419    _state: std::marker::PhantomData<State>,
420}
421
422impl GraphBuilder<Empty> {
423    /// Create a new graph builder.
424    #[must_use]
425    pub fn new() -> Self {
426        Self {
427            graph: FilterGraph::new(),
428            _state: std::marker::PhantomData,
429        }
430    }
431
432    /// Add the first node to the graph.
433    pub fn add_node(mut self, node: Box<dyn Node>) -> (GraphBuilder<HasNodes>, NodeId) {
434        let id = self.graph.add_node_internal(node);
435        (
436            GraphBuilder {
437                graph: self.graph,
438                _state: std::marker::PhantomData,
439            },
440            id,
441        )
442    }
443}
444
445impl Default for GraphBuilder<Empty> {
446    fn default() -> Self {
447        Self::new()
448    }
449}
450
451impl GraphBuilder<HasNodes> {
452    /// Add another node to the graph.
453    pub fn add_node(mut self, node: Box<dyn Node>) -> (Self, NodeId) {
454        let id = self.graph.add_node_internal(node);
455        (self, id)
456    }
457
458    /// Connect two nodes.
459    pub fn connect(
460        mut self,
461        from_node: NodeId,
462        from_port: PortId,
463        to_node: NodeId,
464        to_port: PortId,
465    ) -> GraphResult<GraphBuilder<HasConnections>> {
466        let connection = Connection::new(from_node, from_port, to_node, to_port);
467        self.graph.add_connection_internal(connection)?;
468        Ok(GraphBuilder {
469            graph: self.graph,
470            _state: std::marker::PhantomData,
471        })
472    }
473
474    /// Build the graph without any connections (single node graph).
475    pub fn build(mut self) -> GraphResult<FilterGraph> {
476        self.graph.validate()?;
477        self.graph.compute_execution_order()?;
478        Ok(self.graph)
479    }
480}
481
482impl GraphBuilder<HasConnections> {
483    /// Add another node to the graph.
484    pub fn add_node(mut self, node: Box<dyn Node>) -> (Self, NodeId) {
485        let id = self.graph.add_node_internal(node);
486        (self, id)
487    }
488
489    /// Add another connection.
490    pub fn connect(
491        mut self,
492        from_node: NodeId,
493        from_port: PortId,
494        to_node: NodeId,
495        to_port: PortId,
496    ) -> GraphResult<Self> {
497        let connection = Connection::new(from_node, from_port, to_node, to_port);
498        self.graph.add_connection_internal(connection)?;
499        Ok(self)
500    }
501
502    /// Build the filter graph.
503    pub fn build(mut self) -> GraphResult<FilterGraph> {
504        self.graph.validate()?;
505        self.graph.compute_execution_order()?;
506        Ok(self.graph)
507    }
508}
509
510/// Find all paths between two nodes in the graph.
511#[allow(dead_code)]
512fn find_paths(graph: &FilterGraph, from: NodeId, to: NodeId) -> Vec<Vec<NodeId>> {
513    let mut paths = Vec::new();
514    let mut current_path = vec![from];
515    let mut visited = HashSet::new();
516
517    find_paths_recursive(graph, from, to, &mut current_path, &mut visited, &mut paths);
518    paths
519}
520
521fn find_paths_recursive(
522    graph: &FilterGraph,
523    current: NodeId,
524    target: NodeId,
525    path: &mut Vec<NodeId>,
526    visited: &mut HashSet<NodeId>,
527    paths: &mut Vec<Vec<NodeId>>,
528) {
529    if current == target {
530        paths.push(path.clone());
531        return;
532    }
533
534    visited.insert(current);
535
536    for conn in graph.connections() {
537        if conn.from_node == current && !visited.contains(&conn.to_node) {
538            path.push(conn.to_node);
539            find_paths_recursive(graph, conn.to_node, target, path, visited, paths);
540            path.pop();
541        }
542    }
543
544    visited.remove(&current);
545}
546
#[cfg(test)]
mod tests {
    use super::*;
    use crate::filters::video::{NullSink, PassthroughFilter};

    /// Return a builder holding one source node and one sink node, not yet
    /// connected, together with their IDs.
    fn source_and_sink() -> (GraphBuilder<HasNodes>, NodeId, NodeId) {
        let source = PassthroughFilter::new_source(NodeId(0), "source");
        let sink = NullSink::new(NodeId(0), "sink");

        let (builder, source_id) = GraphBuilder::new().add_node(Box::new(source));
        let (builder, sink_id) = builder.add_node(Box::new(sink));
        (builder, source_id, sink_id)
    }

    #[test]
    fn test_graph_builder() {
        let (builder, source_id, sink_id) = source_and_sink();

        let graph = builder
            .connect(source_id, PortId(0), sink_id, PortId(0))
            .expect("operation should succeed")
            .build()
            .expect("operation should succeed");

        assert_eq!(graph.node_count(), 2);
        assert_eq!(graph.source_nodes().len(), 1);
        assert_eq!(graph.sink_nodes().len(), 1);
    }

    #[test]
    fn test_execution_order() {
        let source = PassthroughFilter::new_source(NodeId(0), "source");
        let filter = PassthroughFilter::new(NodeId(0), "filter");
        let sink = NullSink::new(NodeId(0), "sink");

        let (builder, source_id) = GraphBuilder::new().add_node(Box::new(source));
        let (builder, filter_id) = builder.add_node(Box::new(filter));
        let (builder, sink_id) = builder.add_node(Box::new(sink));

        let graph = builder
            .connect(source_id, PortId(0), filter_id, PortId(0))
            .expect("operation should succeed")
            .connect(filter_id, PortId(0), sink_id, PortId(0))
            .expect("operation should succeed")
            .build()
            .expect("operation should succeed");

        let order = graph.execution_order();
        assert_eq!(order.len(), 3);

        // Source should come before filter, filter before sink.
        let position = |target: NodeId| {
            order
                .iter()
                .position(|&id| id == target)
                .expect("node should be in the execution order")
        };
        assert!(position(source_id) < position(filter_id));
        assert!(position(filter_id) < position(sink_id));
    }

    #[test]
    fn test_empty_graph_error() {
        // `GraphBuilder<Empty>` has no `build` method, so the type state
        // rejects building an empty graph at compile time. The runtime check
        // is still exercised directly here.
        let builder = GraphBuilder::<Empty>::new();
        assert!(builder.graph.is_empty());
        assert!(matches!(
            FilterGraph::new().validate(),
            Err(GraphError::EmptyGraph)
        ));
    }

    #[test]
    fn test_single_node_build_fails() {
        let source = PassthroughFilter::new_source(NodeId(0), "source");
        let (builder, _) = GraphBuilder::new().add_node(Box::new(source));

        assert!(matches!(builder.build(), Err(GraphError::NoSinkNodes)));
    }

    #[test]
    fn test_duplicate_connection_rejected() {
        let (builder, source_id, sink_id) = source_and_sink();

        let builder = builder
            .connect(source_id, PortId(0), sink_id, PortId(0))
            .expect("first connection should succeed");
        let result = builder.connect(source_id, PortId(0), sink_id, PortId(0));

        assert!(matches!(result, Err(GraphError::ConnectionExists { .. })));
    }

    #[test]
    fn test_connect_unknown_node_rejected() {
        let (builder, source_id, _) = source_and_sink();
        let missing = NodeId(99);

        let result = builder.connect(source_id, PortId(0), missing, PortId(0));

        assert!(matches!(result, Err(GraphError::NodeNotFound(id)) if id == missing));
    }

    #[test]
    fn test_cycle_detected() {
        let source = PassthroughFilter::new_source(NodeId(0), "source");
        let first = PassthroughFilter::new(NodeId(0), "first");
        let second = PassthroughFilter::new(NodeId(0), "second");
        let sink = NullSink::new(NodeId(0), "sink");

        let (builder, source_id) = GraphBuilder::new().add_node(Box::new(source));
        let (builder, first_id) = builder.add_node(Box::new(first));
        let (builder, second_id) = builder.add_node(Box::new(second));
        let (builder, sink_id) = builder.add_node(Box::new(sink));

        // first -> second -> first forms a cycle between source and sink.
        let result = builder
            .connect(source_id, PortId(0), first_id, PortId(0))
            .and_then(|b| b.connect(first_id, PortId(0), second_id, PortId(0)))
            .and_then(|b| b.connect(second_id, PortId(0), first_id, PortId(0)))
            .and_then(|b| b.connect(second_id, PortId(0), sink_id, PortId(0)))
            .expect("connections should be accepted")
            .build();

        assert!(matches!(result, Err(GraphError::CycleDetected(_))));
    }

    #[test]
    fn test_graph_reset() {
        let (builder, source_id, sink_id) = source_and_sink();

        let mut graph = builder
            .connect(source_id, PortId(0), sink_id, PortId(0))
            .expect("operation should succeed")
            .build()
            .expect("operation should succeed");

        // Initialize and reset.
        graph.initialize().expect("initialize should succeed");
        graph.reset().expect("reset should succeed");

        // Nodes should be back to idle.
        for id in graph.node_ids() {
            let node = graph.node(id).expect("node should exist");
            assert_eq!(node.state(), NodeState::Idle);
        }
    }
}