Skip to main content

oximedia_graph/
topological.rs

1//! Topological sorting for directed acyclic graphs.
2//!
3//! This module provides Kahn's algorithm and DFS-based topological sort
4//! implementations for ordering graph nodes such that every directed edge
5//! goes from an earlier node to a later node in the ordering.
6//!
7//! For large graphs (> 1 000 nodes) prefer [`FastTopoSorter`], which uses
8//! integer-indexed adjacency lists and an in-degree array instead of hash maps,
9//! cutting constant-factor overhead by roughly 3–5×.
10
11pub use fast_topo::CycleError;
12pub use fast_topo::FastTopoSorter;
13
14use std::collections::{HashMap, HashSet, VecDeque};
15
/// Identifier of a node inside a [`TopoGraph`].
///
/// A thin newtype over `usize`; equality, ordering and hashing are derived
/// from the wrapped value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct TopoNodeId(
    /// Raw numeric identifier.
    pub usize,
);

impl std::fmt::Display for TopoNodeId {
    /// Renders the identifier as `Node(<value>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(raw) = self;
        f.write_fmt(format_args!("Node({raw})"))
    }
}
28
29/// Error types for topological sort operations.
30#[derive(Debug, Clone, PartialEq, Eq)]
31pub enum TopoError {
32    /// The graph contains a cycle, making topological sort impossible.
33    CycleDetected(
34        /// Nodes involved in the cycle.
35        Vec<TopoNodeId>,
36    ),
37    /// A referenced node does not exist in the graph.
38    NodeNotFound(
39        /// The missing node.
40        TopoNodeId,
41    ),
42    /// The graph is empty.
43    EmptyGraph,
44}
45
46impl std::fmt::Display for TopoError {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        match self {
49            Self::CycleDetected(nodes) => {
50                write!(f, "Cycle detected involving {} nodes", nodes.len())
51            }
52            Self::NodeNotFound(id) => write!(f, "Node {id} not found"),
53            Self::EmptyGraph => write!(f, "Graph is empty"),
54        }
55    }
56}
57
58/// Directed graph structure for topological sorting.
59pub struct TopoGraph {
60    /// Adjacency list: node -> set of successor nodes.
61    adjacency: HashMap<TopoNodeId, HashSet<TopoNodeId>>,
62    /// Reverse adjacency: node -> set of predecessor nodes.
63    reverse: HashMap<TopoNodeId, HashSet<TopoNodeId>>,
64}
65
66impl TopoGraph {
67    /// Create a new empty topological graph.
68    pub fn new() -> Self {
69        Self {
70            adjacency: HashMap::new(),
71            reverse: HashMap::new(),
72        }
73    }
74
75    /// Add a node to the graph.
76    pub fn add_node(&mut self, id: TopoNodeId) {
77        self.adjacency.entry(id).or_default();
78        self.reverse.entry(id).or_default();
79    }
80
81    /// Add a directed edge from `from` to `to`.
82    pub fn add_edge(&mut self, from: TopoNodeId, to: TopoNodeId) {
83        self.add_node(from);
84        self.add_node(to);
85        self.adjacency.entry(from).or_default().insert(to);
86        self.reverse.entry(to).or_default().insert(from);
87    }
88
89    /// Return the number of nodes.
90    pub fn node_count(&self) -> usize {
91        self.adjacency.len()
92    }
93
94    /// Return the number of edges.
95    pub fn edge_count(&self) -> usize {
96        self.adjacency.values().map(|s| s.len()).sum()
97    }
98
99    /// Return the in-degree of a node.
100    pub fn in_degree(&self, id: TopoNodeId) -> usize {
101        self.reverse.get(&id).map_or(0, |s| s.len())
102    }
103
104    /// Return the out-degree of a node.
105    pub fn out_degree(&self, id: TopoNodeId) -> usize {
106        self.adjacency.get(&id).map_or(0, |s| s.len())
107    }
108
109    /// Return all nodes with in-degree zero (source nodes).
110    pub fn sources(&self) -> Vec<TopoNodeId> {
111        let mut sources: Vec<TopoNodeId> = self
112            .adjacency
113            .keys()
114            .filter(|id| self.in_degree(**id) == 0)
115            .copied()
116            .collect();
117        sources.sort();
118        sources
119    }
120
121    /// Return all nodes with out-degree zero (sink nodes).
122    pub fn sinks(&self) -> Vec<TopoNodeId> {
123        let mut sinks: Vec<TopoNodeId> = self
124            .adjacency
125            .keys()
126            .filter(|id| self.out_degree(**id) == 0)
127            .copied()
128            .collect();
129        sinks.sort();
130        sinks
131    }
132
133    /// Perform topological sort using Kahn's algorithm (BFS-based).
134    ///
135    /// Returns nodes in topological order or an error if a cycle exists.
136    pub fn sort_kahn(&self) -> Result<Vec<TopoNodeId>, TopoError> {
137        if self.adjacency.is_empty() {
138            return Err(TopoError::EmptyGraph);
139        }
140
141        let mut in_degrees: HashMap<TopoNodeId, usize> = HashMap::new();
142        for &node in self.adjacency.keys() {
143            in_degrees.insert(node, self.in_degree(node));
144        }
145
146        let mut queue: VecDeque<TopoNodeId> = in_degrees
147            .iter()
148            .filter(|(_, &deg)| deg == 0)
149            .map(|(&id, _)| id)
150            .collect();
151
152        // Sort queue for deterministic output
153        let mut sorted_start: Vec<TopoNodeId> = queue.drain(..).collect();
154        sorted_start.sort();
155        queue.extend(sorted_start);
156
157        let mut result = Vec::with_capacity(self.adjacency.len());
158
159        while let Some(node) = queue.pop_front() {
160            result.push(node);
161            if let Some(successors) = self.adjacency.get(&node) {
162                let mut sorted_succ: Vec<TopoNodeId> = successors.iter().copied().collect();
163                sorted_succ.sort();
164                for succ in sorted_succ {
165                    if let Some(deg) = in_degrees.get_mut(&succ) {
166                        *deg -= 1;
167                        if *deg == 0 {
168                            queue.push_back(succ);
169                        }
170                    }
171                }
172            }
173        }
174
175        if result.len() != self.adjacency.len() {
176            let remaining: Vec<TopoNodeId> = self
177                .adjacency
178                .keys()
179                .filter(|id| !result.contains(id))
180                .copied()
181                .collect();
182            return Err(TopoError::CycleDetected(remaining));
183        }
184
185        Ok(result)
186    }
187
188    /// Perform topological sort using DFS-based algorithm.
189    ///
190    /// Returns nodes in topological order or an error if a cycle exists.
191    pub fn sort_dfs(&self) -> Result<Vec<TopoNodeId>, TopoError> {
192        if self.adjacency.is_empty() {
193            return Err(TopoError::EmptyGraph);
194        }
195
196        let mut visited: HashSet<TopoNodeId> = HashSet::new();
197        let mut in_stack: HashSet<TopoNodeId> = HashSet::new();
198        let mut result: Vec<TopoNodeId> = Vec::new();
199
200        let mut nodes: Vec<TopoNodeId> = self.adjacency.keys().copied().collect();
201        nodes.sort();
202
203        for node in &nodes {
204            if !visited.contains(node)
205                && !Self::dfs_visit(
206                    *node,
207                    &self.adjacency,
208                    &mut visited,
209                    &mut in_stack,
210                    &mut result,
211                )
212            {
213                let cycle_nodes: Vec<TopoNodeId> = in_stack.into_iter().collect();
214                return Err(TopoError::CycleDetected(cycle_nodes));
215            }
216        }
217
218        result.reverse();
219        Ok(result)
220    }
221
222    /// DFS visit helper. Returns false if a cycle is detected.
223    fn dfs_visit(
224        node: TopoNodeId,
225        adjacency: &HashMap<TopoNodeId, HashSet<TopoNodeId>>,
226        visited: &mut HashSet<TopoNodeId>,
227        in_stack: &mut HashSet<TopoNodeId>,
228        result: &mut Vec<TopoNodeId>,
229    ) -> bool {
230        visited.insert(node);
231        in_stack.insert(node);
232
233        if let Some(successors) = adjacency.get(&node) {
234            let mut sorted_succ: Vec<TopoNodeId> = successors.iter().copied().collect();
235            sorted_succ.sort();
236            for succ in sorted_succ {
237                if in_stack.contains(&succ) {
238                    return false;
239                }
240                if !visited.contains(&succ)
241                    && !Self::dfs_visit(succ, adjacency, visited, in_stack, result)
242                {
243                    return false;
244                }
245            }
246        }
247
248        in_stack.remove(&node);
249        result.push(node);
250        true
251    }
252
253    /// Check if the graph is a DAG (has no cycles).
254    pub fn is_dag(&self) -> bool {
255        self.sort_kahn().is_ok()
256    }
257
258    /// Return the longest path length in the DAG.
259    pub fn longest_path(&self) -> Result<usize, TopoError> {
260        let order = self.sort_kahn()?;
261        let mut dist: HashMap<TopoNodeId, usize> = HashMap::new();
262        for &node in &order {
263            dist.insert(node, 0);
264        }
265
266        for &node in &order {
267            let node_dist = dist[&node];
268            if let Some(successors) = self.adjacency.get(&node) {
269                for &succ in successors {
270                    let entry = dist.entry(succ).or_insert(0);
271                    if node_dist + 1 > *entry {
272                        *entry = node_dist + 1;
273                    }
274                }
275            }
276        }
277
278        Ok(dist.values().copied().max().unwrap_or(0))
279    }
280
281    /// Return the depth (longest path from any source) for each node.
282    pub fn node_depths(&self) -> Result<HashMap<TopoNodeId, usize>, TopoError> {
283        let order = self.sort_kahn()?;
284        let mut depths: HashMap<TopoNodeId, usize> = HashMap::new();
285        for &node in &order {
286            depths.insert(node, 0);
287        }
288
289        for &node in &order {
290            let node_depth = depths[&node];
291            if let Some(successors) = self.adjacency.get(&node) {
292                for &succ in successors {
293                    let entry = depths.entry(succ).or_insert(0);
294                    if node_depth + 1 > *entry {
295                        *entry = node_depth + 1;
296                    }
297                }
298            }
299        }
300
301        Ok(depths)
302    }
303
304    /// Check if node `a` can reach node `b` (transitively).
305    pub fn can_reach(&self, a: TopoNodeId, b: TopoNodeId) -> bool {
306        let mut visited: HashSet<TopoNodeId> = HashSet::new();
307        let mut queue = VecDeque::new();
308        queue.push_back(a);
309
310        while let Some(current) = queue.pop_front() {
311            if current == b {
312                return true;
313            }
314            if visited.insert(current) {
315                if let Some(successors) = self.adjacency.get(&current) {
316                    for &succ in successors {
317                        queue.push_back(succ);
318                    }
319                }
320            }
321        }
322
323        false
324    }
325}
326
327impl Default for TopoGraph {
328    fn default() -> Self {
329        Self::new()
330    }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for wrapping a raw index in a [`TopoNodeId`].
    fn n(id: usize) -> TopoNodeId {
        TopoNodeId(id)
    }

    /// Builds a graph by adding the given `(from, to)` edges in order.
    fn graph_of(edges: &[(usize, usize)]) -> TopoGraph {
        let mut graph = TopoGraph::new();
        for &(from, to) in edges {
            graph.add_edge(n(from), n(to));
        }
        graph
    }

    #[test]
    fn test_empty_graph() {
        let empty = TopoGraph::new();
        assert_eq!(empty.node_count(), 0);
        assert!(matches!(empty.sort_kahn(), Err(TopoError::EmptyGraph)));
    }

    #[test]
    fn test_single_node() {
        let mut lone = TopoGraph::new();
        lone.add_node(n(0));
        let sorted = lone.sort_kahn().expect("sort_kahn should succeed");
        assert_eq!(sorted, vec![n(0)]);
    }

    #[test]
    fn test_linear_chain() {
        let chain = graph_of(&[(0, 1), (1, 2), (2, 3)]);
        let sorted = chain.sort_kahn().expect("sort_kahn should succeed");
        assert_eq!(sorted, vec![n(0), n(1), n(2), n(3)]);
    }

    #[test]
    fn test_diamond_graph() {
        let diamond = graph_of(&[(0, 1), (0, 2), (1, 3), (2, 3)]);
        let sorted = diamond.sort_kahn().expect("sort_kahn should succeed");
        // The middle two nodes may come in either order; only the ends are pinned.
        assert_eq!(sorted[0], n(0));
        assert_eq!(sorted[3], n(3));
    }

    #[test]
    fn test_cycle_detection_kahn() {
        let ring = graph_of(&[(0, 1), (1, 2), (2, 0)]);
        let outcome = ring.sort_kahn();
        assert!(matches!(outcome, Err(TopoError::CycleDetected(_))));
    }

    #[test]
    fn test_cycle_detection_dfs() {
        let ring = graph_of(&[(0, 1), (1, 2), (2, 0)]);
        let outcome = ring.sort_dfs();
        assert!(matches!(outcome, Err(TopoError::CycleDetected(_))));
    }

    #[test]
    fn test_dfs_sort_matches_kahn() {
        let diamond = graph_of(&[(0, 1), (0, 2), (1, 3), (2, 3)]);
        let by_kahn = diamond.sort_kahn().expect("sort_kahn should succeed");
        let by_dfs = diamond.sort_dfs().expect("sort_dfs should succeed");
        // Both algorithms must start at the single source and end at the single sink.
        assert_eq!(by_kahn[0], n(0));
        assert_eq!(by_dfs[0], n(0));
        assert_eq!(*by_kahn.last().expect("last should succeed"), n(3));
        assert_eq!(*by_dfs.last().expect("last should succeed"), n(3));
    }

    #[test]
    fn test_sources_and_sinks() {
        let hourglass = graph_of(&[(0, 2), (1, 2), (2, 3), (2, 4)]);
        assert_eq!(hourglass.sources(), vec![n(0), n(1)]);
        assert_eq!(hourglass.sinks(), vec![n(3), n(4)]);
    }

    #[test]
    fn test_in_out_degree() {
        let triangle = graph_of(&[(0, 1), (0, 2), (1, 2)]);
        assert_eq!(triangle.out_degree(n(0)), 2);
        assert_eq!(triangle.in_degree(n(2)), 2);
        assert_eq!(triangle.in_degree(n(0)), 0);
    }

    #[test]
    fn test_is_dag() {
        let mut graph = graph_of(&[(0, 1), (1, 2)]);
        assert!(graph.is_dag());

        // Closing the chain into a ring must flip the answer.
        graph.add_edge(n(2), n(0));
        assert!(!graph.is_dag());
    }

    #[test]
    fn test_longest_path() {
        let triangle = graph_of(&[(0, 1), (1, 2), (0, 2)]);
        assert_eq!(
            triangle.longest_path().expect("longest_path should succeed"),
            2
        );
    }

    #[test]
    fn test_node_depths() {
        let diamond = graph_of(&[(0, 1), (0, 2), (1, 3), (2, 3)]);
        let depth_of = diamond.node_depths().expect("node_depths should succeed");
        assert_eq!(depth_of[&n(0)], 0);
        assert_eq!(depth_of[&n(3)], 2);
    }

    #[test]
    fn test_can_reach() {
        let chain = graph_of(&[(0, 1), (1, 2)]);
        assert!(chain.can_reach(n(0), n(2)));
        assert!(!chain.can_reach(n(2), n(0)));
    }

    #[test]
    fn test_topo_error_display() {
        let empty = TopoError::EmptyGraph;
        assert_eq!(format!("{empty}"), "Graph is empty");
        let missing = TopoError::NodeNotFound(n(5));
        assert!(format!("{missing}").contains("5"));
    }

    #[test]
    fn test_edge_count() {
        let triangle = graph_of(&[(0, 1), (1, 2), (0, 2)]);
        assert_eq!(triangle.edge_count(), 3);
    }

    #[test]
    fn test_node_id_display() {
        let answer = TopoNodeId(42);
        assert_eq!(format!("{answer}"), "Node(42)");
    }
}
503
504// ─────────────────────────────────────────────────────────────────────────────
505// High-performance integer-indexed topological sorter
506// ─────────────────────────────────────────────────────────────────────────────
507
/// Sub-module containing [`FastTopoSorter`] and [`CycleError`].
pub mod fast_topo {
    use std::collections::VecDeque;

    /// Error returned by [`FastTopoSorter::sort`] when a cycle is detected.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub struct CycleError {
        /// Nodes that were not reachable in topological order (i.e., they form
        /// or are downstream of the cycle), in ascending index order.
        pub remaining: Vec<usize>,
    }

    impl std::fmt::Display for CycleError {
        /// Writes a one-line summary including the number of unordered nodes.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(
                f,
                "cycle detected; {} nodes could not be ordered",
                self.remaining.len()
            )
        }
    }

    impl std::error::Error for CycleError {}

    /// Cache-friendly topological sorter for large graphs.
    ///
    /// Uses a `Vec<Vec<usize>>` adjacency list indexed directly by node integer
    /// ID (O(1) random access, no hash overhead) and a `Vec<u32>` in-degree
    /// array.  Kahn's BFS algorithm is used for cycle detection and ordering.
    ///
    /// Parallel edges are stored individually: adding the same edge twice
    /// counts twice in [`FastTopoSorter::edge_count`] and in the in-degree.
    ///
    /// # Example
    ///
    /// ```
    /// use oximedia_graph::topological::FastTopoSorter;
    ///
    /// let mut sorter = FastTopoSorter::new(4);
    /// sorter.add_edge(0, 1);
    /// sorter.add_edge(1, 2);
    /// sorter.add_edge(0, 3);
    /// let order = sorter.sort().expect("DAG must sort cleanly");
    /// assert_eq!(order[0], 0);
    /// ```
    #[derive(Debug, Clone)]
    pub struct FastTopoSorter {
        /// Number of nodes.
        n: usize,
        /// `adjacency[i]` — integer indices of the direct successors of node `i`.
        adjacency: Vec<Vec<usize>>,
        /// `in_degree[i]` — number of incoming edges for node `i`.
        in_degree: Vec<u32>,
    }

    impl FastTopoSorter {
        /// Create a sorter for a graph with exactly `n` nodes (IDs `0..n`).
        #[must_use]
        pub fn new(n: usize) -> Self {
            Self {
                n,
                adjacency: vec![Vec::new(); n],
                in_degree: vec![0u32; n],
            }
        }

        /// Add a directed edge from node `from` to node `to`.
        ///
        /// The in-degree counter saturates at `u32::MAX` instead of wrapping.
        ///
        /// # Panics
        ///
        /// Panics if either `from` or `to` is `>= n`.
        pub fn add_edge(&mut self, from: usize, to: usize) {
            assert!(from < self.n, "node id {from} out of range (n={})", self.n);
            assert!(to < self.n, "node id {to} out of range (n={})", self.n);
            self.adjacency[from].push(to);
            self.in_degree[to] = self.in_degree[to].saturating_add(1);
        }

        /// Run Kahn's algorithm and return nodes in topological order.
        ///
        /// The result is deterministic: zero-in-degree nodes are seeded in
        /// ascending index order and successors are relaxed in insertion order.
        ///
        /// # Errors
        ///
        /// Returns `Err(`[`CycleError`]`)` if the graph contains a cycle; the
        /// `remaining` field lists node IDs whose in-degree never reached zero.
        ///
        /// # Complexity
        ///
        /// O(V + E) time, O(V) extra space.
        pub fn sort(&self) -> Result<Vec<usize>, CycleError> {
            // Work on a mutable copy of in-degrees so `self` stays immutable
            // (callers may want to sort multiple times or inspect the structure).
            let mut deg = self.in_degree.clone();

            // Seed the queue with all zero-in-degree nodes; the ascending range
            // makes the seed order (and thus the output) deterministic.
            let mut queue: VecDeque<usize> = (0..self.n).filter(|&i| deg[i] == 0).collect();

            let mut result = Vec::with_capacity(self.n);

            while let Some(node) = queue.pop_front() {
                result.push(node);
                for &succ in &self.adjacency[node] {
                    // Saturating sub: in_degree should never underflow on a
                    // well-formed graph, but we avoid panics defensively.
                    deg[succ] = deg[succ].saturating_sub(1);
                    if deg[succ] == 0 {
                        queue.push_back(succ);
                    }
                }
            }

            if result.len() != self.n {
                let remaining: Vec<usize> = (0..self.n).filter(|&i| deg[i] > 0).collect();
                return Err(CycleError { remaining });
            }

            Ok(result)
        }

        /// Return the number of nodes in the graph.
        #[must_use]
        pub fn node_count(&self) -> usize {
            self.n
        }

        /// Return the total number of edges.
        #[must_use]
        pub fn edge_count(&self) -> usize {
            self.adjacency.iter().map(Vec::len).sum()
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_fast_topo_simple() {
            // 5-node DAG:  0 → 1 → 3
            //              0 → 2 → 3
            //              3 → 4
            let edges = [(0, 1), (0, 2), (1, 3), (2, 3), (3, 4)];
            let mut s = FastTopoSorter::new(5);
            for &(from, to) in &edges {
                s.add_edge(from, to);
            }

            let order = s.sort().expect("DAG must succeed");
            assert_eq!(order.len(), 5, "all 5 nodes must appear");
            assert_eq!(order[0], 0, "node 0 has no predecessors");
            assert_eq!(
                *order.last().expect("non-empty"),
                4,
                "node 4 is the only sink"
            );

            // Verify topological constraint: for every edge u→v, pos(u) < pos(v).
            let mut pos = vec![0usize; 5];
            for (rank, &node) in order.iter().enumerate() {
                pos[node] = rank;
            }
            for &(from, to) in &edges {
                assert!(pos[from] < pos[to], "edge {from}→{to} violated");
            }
        }

        #[test]
        fn test_fast_topo_cycle_detected() {
            // 0 → 1 → 2 → 0  (simple 3-cycle)
            let mut s = FastTopoSorter::new(3);
            s.add_edge(0, 1);
            s.add_edge(1, 2);
            s.add_edge(2, 0);

            let result = s.sort();
            assert!(result.is_err(), "cycle must produce an error");
            let err = result.expect_err("expected CycleError");
            assert_eq!(
                err.remaining.len(),
                3,
                "all three nodes are stuck in the cycle"
            );
        }

        #[test]
        fn test_fast_topo_large() {
            // 10 000-node linear chain: 0 → 1 → 2 → … → 9999
            let n = 10_000usize;
            let mut s = FastTopoSorter::new(n);
            for i in 0..n - 1 {
                s.add_edge(i, i + 1);
            }

            // No wall-clock assertion here: elapsed time depends on build
            // profile and machine load, which makes such a check flaky.
            // Throughput belongs in a dedicated benchmark.
            let order = s.sort().expect("linear chain must sort cleanly");

            assert_eq!(order.len(), n, "all {n} nodes must appear");

            // Verify the chain is sorted in order.
            for (rank, &node) in order.iter().enumerate() {
                assert_eq!(node, rank, "linear chain must be in strict ascending order");
            }
        }
    }
}