Skip to main content

zrx_graph/graph/
traversal.rs

1// Copyright (c) 2025-2026 Zensical and contributors
2
3// SPDX-License-Identifier: MIT
4// All contributions are certified under the DCO
5
6// Permission is hereby granted, free of charge, to any person obtaining a copy
7// of this software and associated documentation files (the "Software"), to
8// deal in the Software without restriction, including without limitation the
9// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10// sell copies of the Software, and to permit persons to whom the Software is
11// furnished to do so, subject to the following conditions:
12
13// The above copyright notice and this permission notice shall be included in
14// all copies or substantial portions of the Software.
15
16// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18// FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE
19// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
22// IN THE SOFTWARE.
23
24// ----------------------------------------------------------------------------
25
26//! Topological traversal.
27
28use ahash::HashSet;
29use std::collections::VecDeque;
30use std::mem;
31
32use super::topology::Topology;
33use super::Graph;
34
35mod error;
36mod into_iter;
37
38pub use error::{Error, Result};
39pub use into_iter::IntoIter;
40
41// ----------------------------------------------------------------------------
42// Structs
43// ----------------------------------------------------------------------------
44
45/// Topological traversal.
46///
47/// This data type manages a topological traversal of a directed acyclic graph
48/// (DAG). It allows visiting nodes in a way that respects their dependencies,
49/// meaning that a node can only be visited after all of its dependencies have
50/// been visited. Visitable nodes can be obtained with [`Traversal::take`].
51///
52/// Note that the traversal itself doesn't know whether it's complete or not,
53/// as it only tracks visitable nodes depending on what has been reported back
54/// to [`Traversal::complete`]. This is because we also need to support partial
55/// traversals that can be resumed, which must be managed by the caller. In case
56/// a traversal starts at an intermediate node, only the nodes and dependencies
57/// reachable from this node are considered, which is necessary for implementing
58/// subgraph traversals that are self-contained, allowing for the creation of
59/// frontiers at any point in the graph.
60#[derive(Clone, Debug)]
61pub struct Traversal {
62    /// Graph topology.
63    topology: Topology,
64    /// Dependency counts.
65    dependencies: Vec<u8>,
66    /// Initial nodes.
67    initial: Vec<usize>,
68    /// Visitable nodes.
69    visitable: VecDeque<usize>,
70}
71
72// ----------------------------------------------------------------------------
73// Implementations
74// ----------------------------------------------------------------------------
75
76impl Traversal {
77    /// Creates a topological traversal.
78    ///
79    /// The given initial nodes are immediately marked as visitable, and thus
80    /// returned by [`Traversal::take`], so the caller must make sure they can
81    /// be processed. Note that the canonical way to create a [`Traversal`] is
82    /// to invoke the [`Graph::traverse`] method.
83    ///
84    /// # Examples
85    ///
86    /// ```
87    /// # use std::error::Error;
88    /// # fn main() -> Result<(), Box<dyn Error>> {
89    /// use zrx_graph::{Graph, Traversal};
90    ///
91    /// // Create graph builder and add nodes
92    /// let mut builder = Graph::builder();
93    /// let a = builder.add_node("a");
94    /// let b = builder.add_node("b");
95    /// let c = builder.add_node("c");
96    ///
97    /// // Create edges between nodes
98    /// builder.add_edge(a, b)?;
99    /// builder.add_edge(b, c)?;
100    ///
101    /// // Create graph from builder
102    /// let graph = builder.build();
103    ///
104    /// // Create topological traversal
105    /// let traversal = Traversal::new(graph.topology(), [a]);
106    /// # Ok(())
107    /// # }
108    /// ```
109    #[must_use]
110    pub fn new<I>(topology: &Topology, initial: I) -> Self
111    where
112        I: AsRef<[usize]>,
113    {
114        let mut visitable: VecDeque<_> =
115            unique(initial.as_ref().iter()).collect();
116
117        // Obtain incoming edges and distance matrix
118        let incoming = topology.incoming();
119        let distance = topology.distance();
120
121        // When doing a topological traversal, we only visit a node once all of
122        // its dependencies have been visited. This means that we need to track
123        // the number of dependencies for each node, which is the number of
124        // incoming edges for that node.
125        let mut dependencies = incoming.degrees().to_vec();
126        for node in incoming {
127            // We must adjust the dependency count for each node for all of its
128            // dependencies that are not reachable from the initial nodes
129            for &dependency in &incoming[node] {
130                let mut iter = initial.as_ref().iter();
131                if !iter.any(|&n| distance[n][dependency] != u8::MAX) {
132                    dependencies[node] -= 1;
133                }
134            }
135        }
136
137        // Retain only the visitable nodes whose dependencies are satisfied,
138        // as we will discover the other initial nodes during traversal
139        visitable.retain(|&n| dependencies[n] == 0);
140        Self {
141            topology: topology.clone(),
142            dependencies,
143            initial: visitable.iter().copied().collect(),
144            visitable,
145        }
146    }
147
148    /// Returns the next visitable node.
149    ///
150    /// # Examples
151    ///
152    /// ```
153    /// # use std::error::Error;
154    /// # fn main() -> Result<(), Box<dyn Error>> {
155    /// use zrx_graph::Graph;
156    ///
157    /// // Create graph builder and add nodes
158    /// let mut builder = Graph::builder();
159    /// let a = builder.add_node("a");
160    /// let b = builder.add_node("b");
161    /// let c = builder.add_node("c");
162    ///
163    /// // Create edges between nodes
164    /// builder.add_edge(a, b)?;
165    /// builder.add_edge(b, c)?;
166    ///
167    /// // Create graph from builder
168    /// let graph = builder.build();
169    ///
170    /// // Create topological traversal
171    /// let mut traversal = graph.traverse([a]);
172    /// while let Some(node) = traversal.take() {
173    ///     println!("{node:?}");
174    ///     traversal.complete(node)?;
175    /// }
176    /// # Ok(())
177    /// # }
178    /// ```
179    #[inline]
180    #[must_use]
181    pub fn take(&mut self) -> Option<usize> {
182        self.visitable.pop_front()
183    }
184
185    /// Marks the given node as visited.
186    ///
187    /// This method marks a node as visited as part of a traversal, which might
188    /// allow visiting dependent nodes when all of their dependencies have been
189    /// satisfied. After marking a node as visited, the next nodes that can be
190    /// visited can be obtained using the [`Traversal::take`] method.
191    ///
192    /// # Errors
193    ///
194    /// If the node has already been marked as visited, [`Error::Completed`] is
195    /// returned. This is likely an error in the caller's business logic.
196    ///
197    /// # Panics
198    ///
199    /// Panics if a node does not exist, as this indicates that there's a bug
200    /// in the code that creates or uses the traversal. While the [`Builder`][]
201    /// is designed to be fallible to ensure the structure is valid, methods
202    /// that operate on [`Graph`] panic on violated invariants.
203    ///
204    /// [`Builder`]: crate::graph::Builder
205    ///
206    /// # Examples
207    ///
208    /// ```
209    /// # use std::error::Error;
210    /// # fn main() -> Result<(), Box<dyn Error>> {
211    /// use zrx_graph::Graph;
212    ///
213    /// // Create graph builder and add nodes
214    /// let mut builder = Graph::builder();
215    /// let a = builder.add_node("a");
216    /// let b = builder.add_node("b");
217    /// let c = builder.add_node("c");
218    ///
219    /// // Create edges between nodes
220    /// builder.add_edge(a, b)?;
221    /// builder.add_edge(b, c)?;
222    ///
223    /// // Create graph from builder
224    /// let graph = builder.build();
225    ///
226    /// // Create topological traversal
227    /// let mut traversal = graph.traverse([a]);
228    /// while let Some(node) = traversal.take() {
229    ///     println!("{node:?}");
230    ///     traversal.complete(node)?;
231    /// }
232    /// # Ok(())
233    /// # }
234    /// ```
235    pub fn complete(&mut self, node: usize) -> Result {
236        if self.dependencies[node] == u8::MAX {
237            return Err(Error::Completed(node));
238        }
239
240        // When the dependency count isn't zero, the traversal converged with
241        // another traversal and restarted at a node occurring before the given
242        // node. In this case, we return an error, and indicate to the caller
243        // that parts of the traversal have to be completed again.
244        if self.dependencies[node] != 0 {
245            return Err(Error::Converged);
246        }
247
248        // Mark node as visited - we can just use the maximum value of `u8` as
249        // a marker, as we don't expect more than 255 dependencies for any node
250        self.dependencies[node] = u8::MAX;
251
252        // Obtain adjacency list of outgoing edges, and decrement the number
253        // of unresolved dependencies for each dependent by one. When the number
254        // of dependencies for a dependent reaches zero, it can be visited, so
255        // we add it to the queue of visitable nodes.
256        let outgoing = self.topology.outgoing();
257        for &dependent in &outgoing[node] {
258            self.dependencies[dependent] -= 1;
259
260            // We satisfied all dependencies, so the dependent can be visited
261            if self.dependencies[dependent] == 0 {
262                self.visitable.push_back(dependent);
263            }
264        }
265
266        // No errors occurred
267        Ok(())
268    }
269
270    /// Attempts to converge with the given traversal.
271    ///
272    /// This method attempts to merge both traversals into a single traversal,
273    /// which is possible when they have common descendants, a condition that
274    /// is always true for directed acyclic graphs with a single component.
275    /// There are several cases to consider when converging two traversals:
276    ///
277    /// - If traversals start from the same set of source nodes, they already
278    ///   converged, so we just restart the traversal at these source nodes.
279    ///
280    /// - If traversals start from different source nodes, yet both have common
281    ///   descendants, we converge at the first layer of common descendants, as
282    ///   all descendants of them must be revisited in the combined traversal.
283    ///   Ancestors of the common descendants that have already been visited in
284    ///   either traversal don't need to be revisited, and thus are carried over
285    ///   from both traversals in their current state.
286    ///
287    /// - If traversals are disjoint, they can't be converged, so we return the
288    ///   given traversal back to the caller wrapped in [`Error::Disjoint`].
289    ///
290    /// # Errors
291    ///
292    /// If the given traversal is from a different graph, [`Error::Mismatch`]
293    /// is returned. Otherwise, if the traversals are disjoint, the traversal
294    /// is returned back to the caller wrapped in [`Error::Disjoint`], so the
295    /// caller can decide how to proceed.
296    ///
297    /// # Examples
298    ///
299    /// ```
300    /// # use std::error::Error;
301    /// # fn main() -> Result<(), Box<dyn Error>> {
302    /// use zrx_graph::Graph;
303    ///
304    /// // Create graph builder and add nodes
305    /// let mut builder = Graph::builder();
306    /// let a = builder.add_node("a");
307    /// let b = builder.add_node("b");
308    /// let c = builder.add_node("c");
309    ///
310    /// // Create edges between nodes
311    /// builder.add_edge(a, c)?;
312    /// builder.add_edge(b, c)?;
313    ///
314    /// // Create graph from builder
315    /// let graph = builder.build();
316    ///
317    /// // Create topological traversal
318    /// let mut traversal = graph.traverse([a]);
319    ///
320    /// // Converge with another topological traversal
321    /// let mut other = graph.traverse([b]);
322    /// assert!(traversal.converge(other).is_ok());
323    /// # Ok(())
324    /// # }
325    /// ```
326    pub fn converge(&mut self, other: Self) -> Result {
327        if self.topology != other.topology {
328            return Err(Error::Mismatch);
329        }
330
331        // Compute the initial nodes for the combined traversal, which is the
332        // union of both traversals with all redundant nodes removed
333        let iter = self.initial.iter().chain(&other.initial);
334        let initial: Vec<_> = unique(iter).collect();
335
336        // Create a temporary graph, so we can compute the first layer of nodes
337        // that are common descendants contained in both traversals
338        let graph = Graph {
339            data: (0..self.topology.incoming().len()).collect(),
340            topology: self.topology.clone(),
341        };
342
343        // If there are no common descendants, the traversals are disjoint and
344        // can't converge, so we return the given traversal back to the caller
345        let mut iter = graph.common_descendants(&initial);
346        let Some(common) = iter.next() else {
347            return Err(Error::Disjoint(other));
348        };
349
350        // Create the combined traversal, and mark all already visited nodes
351        // that are ancestors of the common descendants as visited
352        let prior = mem::replace(self, Self::new(&self.topology, initial));
353
354        // Compute the visitable nodes for the combined traversal, which is the
355        // union of both traversals with all redundant nodes removed. Note that
356        // we must collect them in a temporary vector first, or this loop would
357        // run indefinitely, as we'd be adding visitable nodes again and again.
358        let mut visitable = VecDeque::new();
359        while let Some(node) = self.take() {
360            let p = prior.dependencies[node];
361            let o = other.dependencies[node];
362
363            // If the node has been visited in either traversal, and is not part
364            // of the first layer of common descendants, mark it as visited in
365            // the combined traversal, since we don't need to revisit ancestors
366            // that have already been visited in either traversal
367            if (p == u8::MAX || o == u8::MAX) && !common.contains(&node) {
368                self.complete(node)?;
369            } else {
370                visitable.push_back(node);
371            }
372        }
373
374        // Update visitable nodes
375        self.visitable = visitable;
376
377        // No errors occurred
378        Ok(())
379    }
380}
381
382#[allow(clippy::must_use_candidate)]
383impl Traversal {
384    /// Returns a reference to the graph topology.
385    #[inline]
386    pub fn topology(&self) -> &Topology {
387        &self.topology
388    }
389
390    /// Returns a reference to the initial nodes.
391    #[inline]
392    pub fn initial(&self) -> &[usize] {
393        &self.initial
394    }
395
396    /// Returns the number of visitable nodes.
397    #[inline]
398    pub fn len(&self) -> usize {
399        self.visitable.len()
400    }
401
402    /// Returns whether there are any visitable nodes.
403    #[inline]
404    pub fn is_empty(&self) -> bool {
405        self.visitable.is_empty()
406    }
407}
408
409// ----------------------------------------------------------------------------
410// Functions
411// ----------------------------------------------------------------------------
412
413/// Deduplicates the given nodes while preserving their order.
414#[inline]
415fn unique<'a, I>(iter: I) -> impl Iterator<Item = usize>
416where
417    I: IntoIterator<Item = &'a usize>,
418{
419    let mut nodes = HashSet::default();
420    iter.into_iter() // fmt
421        .copied()
422        .filter(move |&node| nodes.insert(node))
423}
424
425// ----------------------------------------------------------------------------
426// Tests
427// ----------------------------------------------------------------------------
428
#[cfg(test)]
mod tests {

    /// Tests for driving a traversal with `take` and `complete`.
    mod complete {
        use crate::graph;

        #[test]
        fn handles_graph() {
            let graph = graph! {
                "a" => "b", "a" => "c",
                "b" => "d", "b" => "e",
                "c" => "f",
                "d" => "g",
                "e" => "g", "e" => "h",
                "f" => "h",
                "g" => "i",
                "h" => "i",
            };
            for (start, expected) in [
                (0, vec![0, 1, 2, 3, 4, 5, 6, 7, 8]),
                (1, vec![1, 3, 4, 6, 7, 8]),
                (2, vec![2, 5, 7, 8]),
                (3, vec![3, 6, 8]),
                (4, vec![4, 6, 7, 8]),
                (5, vec![5, 7, 8]),
                (6, vec![6, 8]),
                (7, vec![7, 8]),
                (8, vec![8]),
            ] {
                // Record the complete visit order, so a traversal that stops
                // early or yields extra nodes fails the comparison below
                let mut traversal = graph.traverse([start]);
                let mut visited = Vec::new();
                while let Some(node) = traversal.take() {
                    assert!(traversal.complete(node).is_ok());
                    visited.push(node);
                }
                assert_eq!(visited, expected);
            }
        }

        #[test]
        fn handles_multi_graph() {
            let graph = graph! {
                "a" => "b", "a" => "c", "a" => "c",
                "b" => "d", "b" => "e",
                "c" => "f",
                "d" => "g",
                "e" => "g", "e" => "h",
                "f" => "h",
                "g" => "i",
                "h" => "i",
            };
            for (start, expected) in [
                (0, vec![0, 1, 2, 3, 4, 5, 6, 7, 8]),
                (1, vec![1, 3, 4, 6, 7, 8]),
                (2, vec![2, 5, 7, 8]),
                (3, vec![3, 6, 8]),
                (4, vec![4, 6, 7, 8]),
                (5, vec![5, 7, 8]),
                (6, vec![6, 8]),
                (7, vec![7, 8]),
                (8, vec![8]),
            ] {
                // Record the complete visit order, so a traversal that stops
                // early or yields extra nodes fails the comparison below
                let mut traversal = graph.traverse([start]);
                let mut visited = Vec::new();
                while let Some(node) = traversal.take() {
                    assert!(traversal.complete(node).is_ok());
                    visited.push(node);
                }
                assert_eq!(visited, expected);
            }
        }
    }

    /// Tests for merging two traversals with `converge`.
    mod converge {
        use crate::graph;

        #[test]
        fn handles_graph() {
            let graph = graph! {
                "a" => "b", "a" => "c",
                "b" => "d", "b" => "e",
                "c" => "f",
                "d" => "g",
                "e" => "g", "e" => "h",
                "f" => "h",
                "g" => "i",
                "h" => "i",
            };
            for (i, j, descendants) in [
                (vec![0], vec![0], vec![0, 1, 2, 3, 4, 5, 6, 7, 8]),
                (vec![1], vec![0], vec![0, 1, 2, 3, 4, 5, 6, 7, 8]),
                (vec![8], vec![0], vec![0, 1, 2, 3, 4, 5, 6, 7, 8]),
                (vec![1], vec![1], vec![1, 3, 4, 6, 7, 8]),
                (vec![1], vec![2], vec![1, 2, 3, 4, 5, 6, 7, 8]),
                (vec![2], vec![4], vec![2, 4, 5, 6, 7, 8]),
                (vec![4], vec![2], vec![4, 2, 6, 5, 7, 8]),
                (vec![3], vec![5], vec![3, 5, 6, 7, 8]),
                (vec![6], vec![7], vec![6, 7, 8]),
                (vec![8], vec![8], vec![8]),
            ] {
                let mut traversal = graph.traverse(i);
                assert!(traversal.converge(graph.traverse(j)).is_ok());
                assert_eq!(
                    traversal.into_iter().collect::<Vec<_>>(), // fmt
                    descendants
                );
            }
        }

        #[test]
        fn handles_multi_graph() {
            let graph = graph! {
                "a" => "b", "a" => "c", "a" => "c",
                "b" => "d", "b" => "e",
                "c" => "f",
                "d" => "g",
                "e" => "g", "e" => "h",
                "f" => "h",
                "g" => "i",
                "h" => "i",
            };
            for (i, j, descendants) in [
                (vec![0], vec![0], vec![0, 1, 2, 3, 4, 5, 6, 7, 8]),
                (vec![1], vec![0], vec![0, 1, 2, 3, 4, 5, 6, 7, 8]),
                (vec![8], vec![0], vec![0, 1, 2, 3, 4, 5, 6, 7, 8]),
                (vec![1], vec![1], vec![1, 3, 4, 6, 7, 8]),
                (vec![1], vec![2], vec![1, 2, 3, 4, 5, 6, 7, 8]),
                (vec![2], vec![4], vec![2, 4, 5, 6, 7, 8]),
                (vec![4], vec![2], vec![4, 2, 6, 5, 7, 8]),
                (vec![3], vec![5], vec![3, 5, 6, 7, 8]),
                (vec![6], vec![7], vec![6, 7, 8]),
                (vec![8], vec![8], vec![8]),
            ] {
                let mut traversal = graph.traverse(i);
                assert!(traversal.converge(graph.traverse(j)).is_ok());
                assert_eq!(
                    traversal.into_iter().collect::<Vec<_>>(), // fmt
                    descendants
                );
            }
        }
    }
}