// oximedia_graph/dependency_graph.rs

1#![allow(dead_code)]
//! Dependency analysis for graph nodes.
//!
//! This module provides tools for analyzing dependencies between nodes
//! in a filter graph: computing a valid execution order, finding the critical
//! (longest cost-weighted) path, and grouping nodes into parallelizable levels.

8use std::collections::{HashMap, HashSet, VecDeque};
9
/// Opaque numeric handle that identifies a node within a dependency graph.
///
/// The wrapped `u64` is public so callers can mint ids directly, e.g. `DepNodeId(3)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DepNodeId(pub u64);

impl std::fmt::Display for DepNodeId {
    /// Formats the id as `node_<value>` (for example `node_42`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(raw) = *self;
        write!(f, "node_{}", raw)
    }
}

20/// A node in the dependency graph.
21#[derive(Debug, Clone)]
22pub struct DepNode {
23    /// Unique identifier.
24    pub id: DepNodeId,
25    /// Human-readable label.
26    pub label: String,
27    /// Estimated execution cost (arbitrary units).
28    pub cost: f64,
29}
30
31impl DepNode {
32    /// Create a new dependency node.
33    pub fn new(id: u64, label: &str, cost: f64) -> Self {
34        Self {
35            id: DepNodeId(id),
36            label: label.to_string(),
37            cost,
38        }
39    }
40}
41
42/// A directed dependency graph for analyzing execution order.
43#[derive(Debug, Default)]
44pub struct DependencyGraph {
45    /// All nodes, keyed by their ID.
46    nodes: HashMap<DepNodeId, DepNode>,
47    /// Adjacency list: node -> set of dependents (successors).
48    forward_edges: HashMap<DepNodeId, HashSet<DepNodeId>>,
49    /// Reverse adjacency: node -> set of dependencies (predecessors).
50    reverse_edges: HashMap<DepNodeId, HashSet<DepNodeId>>,
51}
52
53impl DependencyGraph {
54    /// Create an empty dependency graph.
55    pub fn new() -> Self {
56        Self::default()
57    }
58
59    /// Add a node to the graph.
60    pub fn add_node(&mut self, node: DepNode) {
61        let id = node.id;
62        self.nodes.insert(id, node);
63        self.forward_edges.entry(id).or_default();
64        self.reverse_edges.entry(id).or_default();
65    }
66
67    /// Add a dependency edge: `from` must complete before `to`.
68    ///
69    /// Returns `true` if the edge was newly added.
70    pub fn add_edge(&mut self, from: DepNodeId, to: DepNodeId) -> bool {
71        self.forward_edges.entry(from).or_default().insert(to);
72        self.reverse_edges.entry(to).or_default().insert(from)
73    }
74
75    /// Number of nodes in the graph.
76    pub fn node_count(&self) -> usize {
77        self.nodes.len()
78    }
79
80    /// Number of edges in the graph.
81    pub fn edge_count(&self) -> usize {
82        self.forward_edges.values().map(|s| s.len()).sum()
83    }
84
85    /// Get the direct dependencies (predecessors) of a node.
86    pub fn dependencies_of(&self, id: DepNodeId) -> Vec<DepNodeId> {
87        self.reverse_edges
88            .get(&id)
89            .map(|s| s.iter().copied().collect())
90            .unwrap_or_default()
91    }
92
93    /// Get the direct dependents (successors) of a node.
94    pub fn dependents_of(&self, id: DepNodeId) -> Vec<DepNodeId> {
95        self.forward_edges
96            .get(&id)
97            .map(|s| s.iter().copied().collect())
98            .unwrap_or_default()
99    }
100
101    /// Find root nodes (no dependencies).
102    pub fn roots(&self) -> Vec<DepNodeId> {
103        self.nodes
104            .keys()
105            .filter(|id| self.reverse_edges.get(id).map_or(true, HashSet::is_empty))
106            .copied()
107            .collect()
108    }
109
110    /// Find leaf nodes (no dependents).
111    pub fn leaves(&self) -> Vec<DepNodeId> {
112        self.nodes
113            .keys()
114            .filter(|id| self.forward_edges.get(id).map_or(true, HashSet::is_empty))
115            .copied()
116            .collect()
117    }
118
119    /// Compute topological ordering using Kahn's algorithm.
120    ///
121    /// Returns `None` if the graph has a cycle.
122    pub fn topological_order(&self) -> Option<Vec<DepNodeId>> {
123        let mut in_degree: HashMap<DepNodeId, usize> = HashMap::new();
124        for id in self.nodes.keys() {
125            in_degree.insert(*id, self.reverse_edges.get(id).map_or(0, HashSet::len));
126        }
127
128        let mut queue: VecDeque<DepNodeId> = in_degree
129            .iter()
130            .filter(|(_, &deg)| deg == 0)
131            .map(|(id, _)| *id)
132            .collect();
133
134        let mut order = Vec::with_capacity(self.nodes.len());
135
136        while let Some(node) = queue.pop_front() {
137            order.push(node);
138            if let Some(successors) = self.forward_edges.get(&node) {
139                for &succ in successors {
140                    if let Some(deg) = in_degree.get_mut(&succ) {
141                        *deg -= 1;
142                        if *deg == 0 {
143                            queue.push_back(succ);
144                        }
145                    }
146                }
147            }
148        }
149
150        if order.len() == self.nodes.len() {
151            Some(order)
152        } else {
153            None
154        }
155    }
156
157    /// Compute all transitive dependencies of a node.
158    pub fn transitive_dependencies(&self, id: DepNodeId) -> HashSet<DepNodeId> {
159        let mut visited = HashSet::new();
160        let mut stack = vec![id];
161        while let Some(current) = stack.pop() {
162            if let Some(deps) = self.reverse_edges.get(&current) {
163                for &dep in deps {
164                    if visited.insert(dep) {
165                        stack.push(dep);
166                    }
167                }
168            }
169        }
170        visited
171    }
172
173    /// Compute the depth of each node (longest path from a root).
174    pub fn compute_depths(&self) -> HashMap<DepNodeId, u32> {
175        let mut depths: HashMap<DepNodeId, u32> = HashMap::new();
176        if let Some(order) = self.topological_order() {
177            for &node in &order {
178                let max_pred_depth = self
179                    .reverse_edges
180                    .get(&node)
181                    .map(|preds| {
182                        preds
183                            .iter()
184                            .filter_map(|p| depths.get(p))
185                            .max()
186                            .copied()
187                            .unwrap_or(0)
188                    })
189                    .unwrap_or(0);
190                let depth = if self
191                    .reverse_edges
192                    .get(&node)
193                    .map_or(true, HashSet::is_empty)
194                {
195                    0
196                } else {
197                    max_pred_depth + 1
198                };
199                depths.insert(node, depth);
200            }
201        }
202        depths
203    }
204
205    /// Group nodes into parallelizable levels.
206    ///
207    /// Nodes at the same level have no inter-dependencies and can run concurrently.
208    pub fn parallel_levels(&self) -> Vec<Vec<DepNodeId>> {
209        let depths = self.compute_depths();
210        if depths.is_empty() {
211            return Vec::new();
212        }
213        let max_depth = depths.values().copied().max().unwrap_or(0);
214        let mut levels = vec![Vec::new(); (max_depth + 1) as usize];
215        for (id, depth) in &depths {
216            levels[*depth as usize].push(*id);
217        }
218        levels
219    }
220
221    /// Compute the critical path (longest weighted path through the graph).
222    #[allow(clippy::cast_precision_loss)]
223    pub fn critical_path(&self) -> (Vec<DepNodeId>, f64) {
224        let order = match self.topological_order() {
225            Some(o) => o,
226            None => return (Vec::new(), 0.0),
227        };
228
229        let mut dist: HashMap<DepNodeId, f64> = HashMap::new();
230        let mut prev: HashMap<DepNodeId, DepNodeId> = HashMap::new();
231
232        for &node in &order {
233            let node_cost = self.nodes.get(&node).map_or(0.0, |n| n.cost);
234            let max_pred = self.reverse_edges.get(&node).and_then(|preds| {
235                preds
236                    .iter()
237                    .filter_map(|p| dist.get(p).map(|d| (*p, *d)))
238                    .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
239            });
240
241            let total = if let Some((pred_id, pred_dist)) = max_pred {
242                prev.insert(node, pred_id);
243                pred_dist + node_cost
244            } else {
245                node_cost
246            };
247            dist.insert(node, total);
248        }
249
250        // Find the node with maximum distance
251        let end_node = dist
252            .iter()
253            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
254            .map(|(id, _)| *id);
255
256        let end_node = match end_node {
257            Some(n) => n,
258            None => return (Vec::new(), 0.0),
259        };
260
261        let total_cost = dist[&end_node];
262
263        // Trace back the path
264        let mut path = vec![end_node];
265        let mut current = end_node;
266        while let Some(&pred) = prev.get(&current) {
267            path.push(pred);
268            current = pred;
269        }
270        path.reverse();
271
272        (path, total_cost)
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a graph from `(id, label, cost)` node triples, then `(from, to)` edge pairs.
    fn build(nodes: &[(u64, &str, f64)], edges: &[(u64, u64)]) -> DependencyGraph {
        let mut graph = DependencyGraph::new();
        for &(id, label, cost) in nodes {
            graph.add_node(DepNode::new(id, label, cost));
        }
        for &(from, to) in edges {
            graph.add_edge(DepNodeId(from), DepNodeId(to));
        }
        graph
    }

    /// Chain A(1) -> B(2) -> C(3).
    fn make_linear_graph() -> DependencyGraph {
        build(&[(0, "A", 1.0), (1, "B", 2.0), (2, "C", 3.0)], &[(0, 1), (1, 2)])
    }

    /// Diamond A(1) -> {B(2), C(4)} -> D(1).
    fn make_diamond_graph() -> DependencyGraph {
        build(
            &[(0, "A", 1.0), (1, "B", 2.0), (2, "C", 4.0), (3, "D", 1.0)],
            &[(0, 1), (0, 2), (1, 3), (2, 3)],
        )
    }

    /// Index of node `raw` within `order`; panics if it is missing.
    fn position_of(order: &[DepNodeId], raw: u64) -> usize {
        order
            .iter()
            .position(|&x| x == DepNodeId(raw))
            .expect("iter should succeed")
    }

    #[test]
    fn test_add_node_and_edge() {
        let graph = make_linear_graph();
        assert_eq!((graph.node_count(), graph.edge_count()), (3, 2));
    }

    #[test]
    fn test_roots_and_leaves() {
        let graph = make_linear_graph();
        assert_eq!(graph.roots(), vec![DepNodeId(0)]);
        assert_eq!(graph.leaves(), vec![DepNodeId(2)]);
    }

    #[test]
    fn test_dependencies_of() {
        let graph = make_linear_graph();
        assert_eq!(graph.dependencies_of(DepNodeId(2)), vec![DepNodeId(1)]);
    }

    #[test]
    fn test_dependents_of() {
        let graph = make_linear_graph();
        assert_eq!(graph.dependents_of(DepNodeId(0)), vec![DepNodeId(1)]);
    }

    #[test]
    fn test_topological_order() {
        let order = make_linear_graph()
            .topological_order()
            .expect("topological_order should succeed");
        assert_eq!(order.len(), 3);
        // Each node must come after its single dependency: A before B, B before C.
        assert!(position_of(&order, 0) < position_of(&order, 1));
        assert!(position_of(&order, 1) < position_of(&order, 2));
    }

    #[test]
    fn test_topological_order_diamond() {
        let order = make_diamond_graph()
            .topological_order()
            .expect("topological_order should succeed");
        assert_eq!(order.len(), 4);
        assert!(position_of(&order, 0) < position_of(&order, 3));
    }

    #[test]
    fn test_transitive_dependencies() {
        let trans = make_linear_graph().transitive_dependencies(DepNodeId(2));
        assert_eq!(trans.len(), 2);
        assert!([DepNodeId(0), DepNodeId(1)].iter().all(|id| trans.contains(id)));
    }

    #[test]
    fn test_compute_depths() {
        let depths = make_linear_graph().compute_depths();
        for (raw, expected) in [(0, 0), (1, 1), (2, 2)] {
            assert_eq!(depths[&DepNodeId(raw)], expected);
        }
    }

    #[test]
    fn test_parallel_levels_diamond() {
        // Level 0: A, level 1: B and C, level 2: D.
        let sizes: Vec<usize> = make_diamond_graph()
            .parallel_levels()
            .iter()
            .map(Vec::len)
            .collect();
        assert_eq!(sizes, vec![1, 2, 1]);
    }

    #[test]
    fn test_critical_path_linear() {
        let (path, cost) = make_linear_graph().critical_path();
        assert_eq!(path.len(), 3);
        assert!((cost - 6.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_critical_path_diamond() {
        // Heaviest route is A(1) -> C(4) -> D(1) = 6.0.
        let (path, cost) = make_diamond_graph().critical_path();
        assert!((cost - 6.0).abs() < f64::EPSILON);
        assert!(path.contains(&DepNodeId(0)));
        assert!(path.contains(&DepNodeId(3)));
    }

    #[test]
    fn test_empty_graph() {
        let graph = DependencyGraph::new();
        assert_eq!(graph.node_count(), 0);
        assert_eq!(graph.edge_count(), 0);
        assert!(graph.roots().is_empty());
        assert!(graph.leaves().is_empty());
        let (path, cost) = graph.critical_path();
        assert!(path.is_empty());
        assert!(cost.abs() < f64::EPSILON);
    }

    #[test]
    fn test_dep_node_id_display() {
        assert_eq!(DepNodeId(42).to_string(), "node_42");
    }
}