1use std::collections::{HashMap, HashSet, VecDeque};
7
8use crate::error::{GraphError, GraphResult};
9use crate::frame::FilterFrame;
10use crate::node::{Node, NodeId, NodeRuntime, NodeState, NodeType};
11use crate::port::{Connection, PortId};
12
13#[allow(dead_code)]
15pub struct FilterGraph {
16 nodes: HashMap<NodeId, NodeRuntime>,
18 connections: Vec<Connection>,
20 execution_order: Vec<NodeId>,
22 source_nodes: Vec<NodeId>,
24 sink_nodes: Vec<NodeId>,
26 next_id: u64,
28}
29
30impl FilterGraph {
31 #[must_use]
33 pub fn new() -> Self {
34 Self {
35 nodes: HashMap::new(),
36 connections: Vec::new(),
37 execution_order: Vec::new(),
38 source_nodes: Vec::new(),
39 sink_nodes: Vec::new(),
40 next_id: 0,
41 }
42 }
43
44 #[must_use]
46 pub fn builder() -> GraphBuilder<Empty> {
47 GraphBuilder::new()
48 }
49
50 #[must_use]
52 pub fn node(&self, id: NodeId) -> Option<&dyn Node> {
53 self.nodes.get(&id).map(|r| r.node())
54 }
55
56 pub fn node_mut(&mut self, id: NodeId) -> Option<&mut dyn Node> {
58 self.nodes.get_mut(&id).map(|r| r.node_mut())
59 }
60
61 #[must_use]
63 pub fn node_ids(&self) -> Vec<NodeId> {
64 self.nodes.keys().copied().collect()
65 }
66
67 #[must_use]
69 pub fn execution_order(&self) -> &[NodeId] {
70 &self.execution_order
71 }
72
73 #[must_use]
75 pub fn source_nodes(&self) -> &[NodeId] {
76 &self.source_nodes
77 }
78
79 #[must_use]
81 pub fn sink_nodes(&self) -> &[NodeId] {
82 &self.sink_nodes
83 }
84
85 #[must_use]
87 pub fn connections(&self) -> &[Connection] {
88 &self.connections
89 }
90
91 #[must_use]
93 pub fn is_empty(&self) -> bool {
94 self.nodes.is_empty()
95 }
96
97 #[must_use]
99 pub fn node_count(&self) -> usize {
100 self.nodes.len()
101 }
102
103 pub fn initialize(&mut self) -> GraphResult<()> {
105 for id in &self.execution_order.clone() {
106 if let Some(runtime) = self.nodes.get_mut(id) {
107 runtime.node_mut().initialize()?;
108 }
109 }
110 Ok(())
111 }
112
113 pub fn process_step(&mut self) -> GraphResult<bool> {
117 let mut processed_any = false;
118
119 for id in self.execution_order.clone() {
120 let runtime = self
121 .nodes
122 .get_mut(&id)
123 .ok_or(GraphError::NodeNotFound(id))?;
124
125 if runtime.node().state() == NodeState::Done {
127 continue;
128 }
129
130 runtime.node_mut().set_state(NodeState::Processing)?;
132 runtime.process()?;
133 runtime.node_mut().set_state(NodeState::Idle)?;
134 processed_any = true;
135
136 for conn in &self.connections.clone() {
138 if conn.from_node == id {
139 let frame = {
141 let source = self
142 .nodes
143 .get_mut(&conn.from_node)
144 .ok_or(GraphError::NodeNotFound(conn.from_node))?;
145 source.pop_output(conn.from_port)?
146 };
147
148 if let Some(frame) = frame {
150 let dest = self
151 .nodes
152 .get_mut(&conn.to_node)
153 .ok_or(GraphError::NodeNotFound(conn.to_node))?;
154 dest.push_input(conn.to_port, frame)?;
155 }
156 }
157 }
158 }
159
160 Ok(processed_any)
161 }
162
163 pub fn push_frame(
165 &mut self,
166 node_id: NodeId,
167 port: PortId,
168 frame: FilterFrame,
169 ) -> GraphResult<()> {
170 let runtime = self
171 .nodes
172 .get_mut(&node_id)
173 .ok_or(GraphError::NodeNotFound(node_id))?;
174 runtime.push_input(port, frame)
175 }
176
177 pub fn pull_frame(
179 &mut self,
180 node_id: NodeId,
181 port: PortId,
182 ) -> GraphResult<Option<FilterFrame>> {
183 let runtime = self
184 .nodes
185 .get_mut(&node_id)
186 .ok_or(GraphError::NodeNotFound(node_id))?;
187 runtime.pop_output(port)
188 }
189
190 pub fn reset(&mut self) -> GraphResult<()> {
192 for runtime in self.nodes.values_mut() {
193 runtime.node_mut().reset()?;
194 }
195 Ok(())
196 }
197
198 pub fn flush(&mut self) -> GraphResult<Vec<FilterFrame>> {
200 let mut frames = Vec::new();
201
202 for id in &self.execution_order.clone() {
203 if let Some(runtime) = self.nodes.get_mut(id) {
204 let flushed = runtime.node_mut().flush()?;
205 frames.extend(flushed);
206 }
207 }
208
209 Ok(frames)
210 }
211
212 fn add_node_internal(&mut self, node: Box<dyn Node>) -> NodeId {
214 let id = NodeId(self.next_id);
215 self.next_id += 1;
216
217 match node.node_type() {
219 NodeType::Source => self.source_nodes.push(id),
220 NodeType::Sink => self.sink_nodes.push(id),
221 NodeType::Filter => {}
222 }
223
224 self.nodes.insert(id, NodeRuntime::new(node));
225 id
226 }
227
228 fn add_connection_internal(&mut self, connection: Connection) -> GraphResult<()> {
230 if !self.nodes.contains_key(&connection.from_node) {
232 return Err(GraphError::NodeNotFound(connection.from_node));
233 }
234 if !self.nodes.contains_key(&connection.to_node) {
235 return Err(GraphError::NodeNotFound(connection.to_node));
236 }
237
238 if self.connections.contains(&connection) {
240 return Err(GraphError::ConnectionExists {
241 from_node: connection.from_node,
242 from_port: connection.from_port,
243 to_node: connection.to_node,
244 to_port: connection.to_port,
245 });
246 }
247
248 {
250 let from_node = self
251 .nodes
252 .get(&connection.from_node)
253 .ok_or(GraphError::NodeNotFound(connection.from_node))?;
254 let to_node = self
255 .nodes
256 .get(&connection.to_node)
257 .ok_or(GraphError::NodeNotFound(connection.to_node))?;
258
259 let from_port = from_node.node().output_port(connection.from_port).ok_or(
260 GraphError::PortNotFound {
261 node: connection.from_node,
262 port: connection.from_port,
263 },
264 )?;
265
266 let to_port =
267 to_node
268 .node()
269 .input_port(connection.to_port)
270 .ok_or(GraphError::PortNotFound {
271 node: connection.to_node,
272 port: connection.to_port,
273 })?;
274
275 if from_port.port_type != to_port.port_type {
277 return Err(GraphError::PortTypeMismatch {
278 expected: format!("{:?}", to_port.port_type),
279 actual: format!("{:?}", from_port.port_type),
280 });
281 }
282
283 if !from_port.format.is_compatible(&to_port.format) {
285 return Err(GraphError::IncompatibleFormats {
286 source_format: format!("{}", from_port.format),
287 dest_format: format!("{}", to_port.format),
288 });
289 }
290 }
291
292 self.connections.push(connection);
293 Ok(())
294 }
295
296 fn compute_execution_order(&mut self) -> GraphResult<()> {
298 let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
299 let mut adjacency: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
300
301 for &id in self.nodes.keys() {
303 in_degree.insert(id, 0);
304 adjacency.insert(id, Vec::new());
305 }
306
307 for conn in &self.connections {
309 adjacency
310 .get_mut(&conn.from_node)
311 .ok_or(GraphError::NodeNotFound(conn.from_node))?
312 .push(conn.to_node);
313 *in_degree
314 .get_mut(&conn.to_node)
315 .ok_or(GraphError::NodeNotFound(conn.to_node))? += 1;
316 }
317
318 let mut queue: VecDeque<NodeId> = in_degree
320 .iter()
321 .filter(|(_, °)| deg == 0)
322 .map(|(&id, _)| id)
323 .collect();
324
325 let mut order = Vec::new();
326
327 while let Some(id) = queue.pop_front() {
328 order.push(id);
329
330 let neighbors: Vec<NodeId> = adjacency
331 .get(&id)
332 .ok_or(GraphError::NodeNotFound(id))?
333 .clone();
334 for neighbor in neighbors {
335 let deg = in_degree
336 .get_mut(&neighbor)
337 .ok_or(GraphError::NodeNotFound(neighbor))?;
338 *deg -= 1;
339 if *deg == 0 {
340 queue.push_back(neighbor);
341 }
342 }
343 }
344
345 if order.len() != self.nodes.len() {
347 let cycle_node = in_degree
349 .iter()
350 .find(|(_, °)| deg > 0)
351 .map_or(NodeId(0), |(&id, _)| id);
352 return Err(GraphError::CycleDetected(cycle_node));
353 }
354
355 self.execution_order = order;
356 Ok(())
357 }
358
359 fn validate(&self) -> GraphResult<()> {
361 if self.nodes.is_empty() {
362 return Err(GraphError::EmptyGraph);
363 }
364
365 if self.source_nodes.is_empty() {
366 return Err(GraphError::NoSourceNodes);
367 }
368
369 if self.sink_nodes.is_empty() {
370 return Err(GraphError::NoSinkNodes);
371 }
372
373 for (id, runtime) in &self.nodes {
375 for input in runtime.node().inputs() {
376 if input.required {
377 let connected = self
378 .connections
379 .iter()
380 .any(|c| c.to_node == *id && c.to_port == input.id);
381 if !connected && runtime.node().node_type() != NodeType::Source {
382 return Err(GraphError::ConfigurationError(format!(
383 "Required input '{}' on node {:?} is not connected",
384 input.name, id
385 )));
386 }
387 }
388 }
389 }
390
391 Ok(())
392 }
393}
394
395impl Default for FilterGraph {
396 fn default() -> Self {
397 Self::new()
398 }
399}
400
/// Builder state: no node has been added yet.
#[derive(Debug, Clone, Copy, Default)]
pub struct Empty;

/// Builder state: at least one node has been added, but no connection yet.
#[derive(Debug, Clone, Copy, Default)]
pub struct HasNodes;

/// Builder state: at least one connection has been added.
#[derive(Debug, Clone, Copy, Default)]
pub struct HasConnections;

/// Builder state marker.
///
/// NOTE(review): no `GraphBuilder` impl in this module uses this state.
/// Confirm whether it is still needed.
#[derive(Debug, Clone, Copy, Default)]
pub struct Ready;
410
411pub struct GraphBuilder<State> {
418 graph: FilterGraph,
419 _state: std::marker::PhantomData<State>,
420}
421
422impl GraphBuilder<Empty> {
423 #[must_use]
425 pub fn new() -> Self {
426 Self {
427 graph: FilterGraph::new(),
428 _state: std::marker::PhantomData,
429 }
430 }
431
432 pub fn add_node(mut self, node: Box<dyn Node>) -> (GraphBuilder<HasNodes>, NodeId) {
434 let id = self.graph.add_node_internal(node);
435 (
436 GraphBuilder {
437 graph: self.graph,
438 _state: std::marker::PhantomData,
439 },
440 id,
441 )
442 }
443}
444
445impl Default for GraphBuilder<Empty> {
446 fn default() -> Self {
447 Self::new()
448 }
449}
450
451impl GraphBuilder<HasNodes> {
452 pub fn add_node(mut self, node: Box<dyn Node>) -> (Self, NodeId) {
454 let id = self.graph.add_node_internal(node);
455 (self, id)
456 }
457
458 pub fn connect(
460 mut self,
461 from_node: NodeId,
462 from_port: PortId,
463 to_node: NodeId,
464 to_port: PortId,
465 ) -> GraphResult<GraphBuilder<HasConnections>> {
466 let connection = Connection::new(from_node, from_port, to_node, to_port);
467 self.graph.add_connection_internal(connection)?;
468 Ok(GraphBuilder {
469 graph: self.graph,
470 _state: std::marker::PhantomData,
471 })
472 }
473
474 pub fn build(mut self) -> GraphResult<FilterGraph> {
476 self.graph.validate()?;
477 self.graph.compute_execution_order()?;
478 Ok(self.graph)
479 }
480}
481
482impl GraphBuilder<HasConnections> {
483 pub fn add_node(mut self, node: Box<dyn Node>) -> (Self, NodeId) {
485 let id = self.graph.add_node_internal(node);
486 (self, id)
487 }
488
489 pub fn connect(
491 mut self,
492 from_node: NodeId,
493 from_port: PortId,
494 to_node: NodeId,
495 to_port: PortId,
496 ) -> GraphResult<Self> {
497 let connection = Connection::new(from_node, from_port, to_node, to_port);
498 self.graph.add_connection_internal(connection)?;
499 Ok(self)
500 }
501
502 pub fn build(mut self) -> GraphResult<FilterGraph> {
504 self.graph.validate()?;
505 self.graph.compute_execution_order()?;
506 Ok(self.graph)
507 }
508}
509
510#[allow(dead_code)]
512fn find_paths(graph: &FilterGraph, from: NodeId, to: NodeId) -> Vec<Vec<NodeId>> {
513 let mut paths = Vec::new();
514 let mut current_path = vec![from];
515 let mut visited = HashSet::new();
516
517 find_paths_recursive(graph, from, to, &mut current_path, &mut visited, &mut paths);
518 paths
519}
520
521fn find_paths_recursive(
522 graph: &FilterGraph,
523 current: NodeId,
524 target: NodeId,
525 path: &mut Vec<NodeId>,
526 visited: &mut HashSet<NodeId>,
527 paths: &mut Vec<Vec<NodeId>>,
528) {
529 if current == target {
530 paths.push(path.clone());
531 return;
532 }
533
534 visited.insert(current);
535
536 for conn in graph.connections() {
537 if conn.from_node == current && !visited.contains(&conn.to_node) {
538 path.push(conn.to_node);
539 find_paths_recursive(graph, conn.to_node, target, path, visited, paths);
540 path.pop();
541 }
542 }
543
544 visited.remove(¤t);
545}
546
#[cfg(test)]
mod tests {
    use super::*;
    use crate::filters::video::{NullSink, PassthroughFilter};

    /// A minimal `source -> sink` graph builds and registers one source and one sink.
    #[test]
    fn test_graph_builder() {
        let source = PassthroughFilter::new_source(NodeId(0), "source");
        let sink = NullSink::new(NodeId(0), "sink");

        let (builder, source_id) = GraphBuilder::new().add_node(Box::new(source));
        let (builder, sink_id) = builder.add_node(Box::new(sink));

        let graph = builder
            .connect(source_id, PortId(0), sink_id, PortId(0))
            .expect("operation should succeed")
            .build()
            .expect("operation should succeed");

        assert_eq!(graph.node_count(), 2);
        assert_eq!(graph.source_nodes().len(), 1);
        assert_eq!(graph.sink_nodes().len(), 1);
    }

    /// A linear chain is executed upstream-first.
    #[test]
    fn test_execution_order() {
        let source = PassthroughFilter::new_source(NodeId(0), "source");
        let filter = PassthroughFilter::new(NodeId(0), "filter");
        let sink = NullSink::new(NodeId(0), "sink");

        let (builder, source_id) = GraphBuilder::new().add_node(Box::new(source));
        let (builder, filter_id) = builder.add_node(Box::new(filter));
        let (builder, sink_id) = builder.add_node(Box::new(sink));

        let graph = builder
            .connect(source_id, PortId(0), filter_id, PortId(0))
            .expect("operation should succeed")
            .connect(filter_id, PortId(0), sink_id, PortId(0))
            .expect("operation should succeed")
            .build()
            .expect("operation should succeed");

        let order = graph.execution_order();
        assert_eq!(order.len(), 3);

        let position = |target: NodeId| {
            order
                .iter()
                .position(|&id| id == target)
                .expect("node should appear in the execution order")
        };

        assert!(position(source_id) < position(filter_id));
        assert!(position(filter_id) < position(sink_id));
    }

    /// The typestate builder cannot call `build` in the `Empty` state, so the
    /// empty-graph check is exercised directly on `validate`.
    #[test]
    fn test_empty_graph_error() {
        let graph = FilterGraph::new();
        assert!(graph.is_empty());
        assert!(matches!(graph.validate(), Err(GraphError::EmptyGraph)));
    }

    /// Adding the exact same connection twice is rejected.
    #[test]
    fn test_duplicate_connection_rejected() {
        let source = PassthroughFilter::new_source(NodeId(0), "source");
        let sink = NullSink::new(NodeId(0), "sink");

        let (builder, source_id) = GraphBuilder::new().add_node(Box::new(source));
        let (builder, sink_id) = builder.add_node(Box::new(sink));

        let builder = builder
            .connect(source_id, PortId(0), sink_id, PortId(0))
            .expect("first connection should succeed");
        let result = builder.connect(source_id, PortId(0), sink_id, PortId(0));

        assert!(matches!(result, Err(GraphError::ConnectionExists { .. })));
    }

    /// A feedback loop between two filters is reported as a cycle at build time.
    #[test]
    fn test_cycle_detected() {
        let source = PassthroughFilter::new_source(NodeId(0), "source");
        let first = PassthroughFilter::new(NodeId(0), "first");
        let second = PassthroughFilter::new(NodeId(0), "second");
        let sink = NullSink::new(NodeId(0), "sink");

        let (builder, source_id) = GraphBuilder::new().add_node(Box::new(source));
        let (builder, first_id) = builder.add_node(Box::new(first));
        let (builder, second_id) = builder.add_node(Box::new(second));
        let (builder, sink_id) = builder.add_node(Box::new(sink));

        let result = builder
            .connect(source_id, PortId(0), first_id, PortId(0))
            .expect("operation should succeed")
            .connect(first_id, PortId(0), second_id, PortId(0))
            .expect("operation should succeed")
            .connect(second_id, PortId(0), first_id, PortId(0))
            .expect("operation should succeed")
            .connect(second_id, PortId(0), sink_id, PortId(0))
            .expect("operation should succeed")
            .build();

        assert!(matches!(result, Err(GraphError::CycleDetected(_))));
    }

    /// After `initialize` + `reset`, every node reports `Idle`.
    #[test]
    fn test_graph_reset() {
        let source = PassthroughFilter::new_source(NodeId(0), "source");
        let sink = NullSink::new(NodeId(0), "sink");

        let (builder, source_id) = GraphBuilder::new().add_node(Box::new(source));
        let (builder, sink_id) = builder.add_node(Box::new(sink));

        let mut graph = builder
            .connect(source_id, PortId(0), sink_id, PortId(0))
            .expect("operation should succeed")
            .build()
            .expect("operation should succeed");

        graph.initialize().expect("initialize should succeed");
        graph.reset().expect("reset should succeed");

        for id in graph.node_ids() {
            let node = graph.node(id).expect("node should succeed");
            assert_eq!(node.state(), NodeState::Idle);
        }
    }
}