id_graph_sccs/
lib.rs

1//! A small crate for finding the [strongly-connected components](https://en.wikipedia.org/wiki/Strongly_connected_component)
2//! of a directed graph.
3//!
4//! This crate is built on the [`id_collections`](https://crates.io/crates/id_collections) crate,
5//! and is designed to work with graphs in which nodes are labeled by integer indices belonging to a
6//! contiguous range from zero to some upper bound. The edges of the input graph do not need to be
7//! stored in any particular format; the caller provides the outgoing edges for each node via a
8//! callback function which is invoked lazily as the algorithm traverses the graph.
9//!
10//! The implementation of the algorithm does not rely on recursion, so it is safe to run it on
11//! arbitrarily large graphs without risking a stack overflow.
12//!
13//! # Examples
14//!
15//! ```
16//! use id_collections::{id_type, IdVec};
17//! use id_graph_sccs::{find_components, Sccs, Scc, SccKind};
18//!
19//! #[id_type]
20//! struct NodeId(u32);
21//!
22//! #[id_type]
23//! struct SccId(u32);
24//!
25//! // Note: you are not required to store the edges of the input graph in an 'IdVec'; all that
26//! // matters is that you are able to pass a closure to the 'find_components' function which
27//! // returns the edges for a given node.
28//! let mut graph: IdVec<NodeId, Vec<NodeId>> = IdVec::new();
29//!
30//! let node_a = graph.push(Vec::new());
31//! let node_b = graph.push(Vec::new());
32//! let node_c = graph.push(Vec::new());
33//! let node_d = graph.push(Vec::new());
34//!
35//! graph[node_a].extend([node_a, node_b]);
36//! graph[node_b].extend([node_c]);
37//! graph[node_c].extend([node_b, node_d]);
38//!
39//! let sccs: Sccs<SccId, NodeId> = find_components(graph.count(), |node| &graph[node]);
40//!
41//! // We can iterate over 'sccs' to obtain the components of the graph:
42//! let mut components: Vec<Scc<NodeId>> = Vec::new();
43//! for (_scc_id, component) in &sccs {
44//!     components.push(component);
45//! }
46//!
47//! assert_eq!(components.len(), 3);
48//!
49//! assert_eq!(components[0].kind, SccKind::Acyclic);
50//! assert_eq!(components[0].nodes, &[node_d]);
51//!
52//! assert_eq!(components[1].kind, SccKind::Cyclic);
53//! assert!(components[1].nodes.contains(&node_b));
54//! assert!(components[1].nodes.contains(&node_c));
55//!
56//! assert_eq!(components[2].kind, SccKind::Cyclic);
57//! assert_eq!(components[2].nodes, &[node_a]);
58//! ```
59use id_collections::{count::IdRangeIter, id::ToPrimIntUnchecked, Count, Id, IdMap, IdVec};
60use num_traits::{CheckedSub, One, ToPrimitive};
61use std::{borrow::Borrow, fmt::Debug, iter::FusedIterator};
62
/// Classifies a strongly-connected component by whether it contains a cycle.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SccKind {
    /// The component is cycle-free. Such a component always consists of exactly one node.
    Acyclic,
    /// The component has a cycle. It is either a single node with an edge to itself, or a group
    /// of several nodes.
    Cyclic,
}
73
74#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
75struct SccInfo {
76    slice_end: usize,
77    kind: SccKind,
78}
79
80/// A sequence of strongly-connected components in a graph.
81#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
82pub struct Sccs<SccId: Id, NodeId: Id> {
83    raw_nodes: Vec<NodeId>,
84    scc_info: IdVec<SccId, SccInfo>,
85}
86
87impl<SccId: Id + Debug, NodeId: Id + Debug> Debug for Sccs<SccId, NodeId> {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        f.debug_map().entries(self).finish()
90    }
91}
92
93impl<SccId: Id, NodeId: Id> Default for Sccs<SccId, NodeId> {
94    fn default() -> Self {
95        Self::new()
96    }
97}
98
99/// A single strongly-connected component in a graph.
100#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
101pub struct Scc<'a, NodeId: Id> {
102    /// Indicates if the component contains a cycle.
103    pub kind: SccKind,
104    /// The nodes in the component.
105    pub nodes: &'a [NodeId],
106}
107
108impl<SccId: Id, NodeId: Id> Sccs<SccId, NodeId> {
109    /// Create a new, empty sequence of components.
110    pub fn new() -> Self {
111        Sccs {
112            raw_nodes: Vec::new(),
113            scc_info: IdVec::new(),
114        }
115    }
116
117    /// Creates a new, empty sequence of components, with space preallocated to hold `capacity`
118    /// nodes.
119    pub fn with_node_capacity(capacity: usize) -> Self {
120        Sccs {
121            raw_nodes: Vec::with_capacity(capacity),
122            scc_info: IdVec::new(),
123        }
124    }
125
126    /// Pushes a new component of kind [`SccKind::Acyclic`] onto the end of the sequence.
127    ///
128    /// Returns the id of the newly-inserted component.
129    ///
130    /// # Panics
131    ///
132    /// Panics if the length of the sequence overflows `SccId`.
133    pub fn push_acyclic_component(&mut self, node: NodeId) -> SccId {
134        self.raw_nodes.push(node);
135        self.scc_info.push(SccInfo {
136            slice_end: self.raw_nodes.len(),
137            kind: SccKind::Acyclic,
138        })
139    }
140
141    /// Pushes a new component of kind [`SccKind::Cyclic`] onto the end of the sequence.
142    ///
143    /// Returns the id of the newly-inserted component.
144    ///
145    /// # Panics
146    ///
147    /// Panics if `nodes` is empty, or if the length of the sequence overflows `SccId`.
148    pub fn push_cyclic_component(&mut self, nodes: &[NodeId]) -> SccId {
149        if nodes.is_empty() {
150            panic!("SCC must contain at least one node");
151        }
152        self.raw_nodes.extend_from_slice(nodes);
153        self.scc_info.push(SccInfo {
154            slice_end: self.raw_nodes.len(),
155            kind: SccKind::Cyclic,
156        })
157    }
158
159    /// Returns the number of components in the sequence.
160    pub fn count(&self) -> Count<SccId> {
161        self.scc_info.count()
162    }
163
164    /// Returns the component with the given `id`.
165    ///
166    /// # Panics
167    ///
168    /// Panics if `id` is not contained in [`self.count()`](Sccs::count).
169    pub fn component<S: Borrow<SccId>>(&self, id: S) -> Scc<'_, NodeId> {
170        let id = *id.borrow();
171        let prev_id = id
172            .to_index()
173            .checked_sub(&SccId::Index::one())
174            .map(SccId::from_index);
175        let slice_start = match prev_id {
176            Some(prev_id) => self.scc_info[prev_id].slice_end,
177            None => 0,
178        };
179        let info = &self.scc_info[id];
180        Scc {
181            kind: info.kind,
182            nodes: &self.raw_nodes[slice_start..info.slice_end],
183        }
184    }
185}
186
187impl<'a, SccId: Id, NodeId: Id> IntoIterator for &'a Sccs<SccId, NodeId> {
188    type IntoIter = SccsIter<'a, SccId, NodeId>;
189    type Item = (SccId, Scc<'a, NodeId>);
190
191    fn into_iter(self) -> Self::IntoIter {
192        SccsIter {
193            sccs: self,
194            inner: self.count().into_iter(),
195        }
196    }
197}
198
199/// An iterator over the strongly-connected components in a graph.
200///
201/// This type is returned by [`Sccs::into_iter`](struct.Sccs.html#method.into_iter).
202#[derive(Debug)]
203pub struct SccsIter<'a, SccId: Id, NodeId: Id> {
204    sccs: &'a Sccs<SccId, NodeId>,
205    inner: IdRangeIter<SccId>,
206}
207
208impl<'a, SccId: Id, NodeId: Id> Iterator for SccsIter<'a, SccId, NodeId> {
209    type Item = (SccId, Scc<'a, NodeId>);
210
211    #[inline]
212    fn next(&mut self) -> Option<Self::Item> {
213        self.inner.next().map(|id| (id, self.sccs.component(id)))
214    }
215
216    #[inline]
217    fn size_hint(&self) -> (usize, Option<usize>) {
218        self.inner.size_hint()
219    }
220
221    #[inline]
222    fn nth(&mut self, n: usize) -> Option<Self::Item> {
223        self.inner.nth(n).map(|id| (id, self.sccs.component(id)))
224    }
225
226    #[inline]
227    fn last(self) -> Option<Self::Item> {
228        self.inner.last().map(|id| (id, self.sccs.component(id)))
229    }
230}
231
232impl<'a, SccId: Id, NodeId: Id> FusedIterator for SccsIter<'a, SccId, NodeId> {}
233
234impl<'a, SccId: Id, NodeId: Id> DoubleEndedIterator for SccsIter<'a, SccId, NodeId> {
235    #[inline]
236    fn next_back(&mut self) -> Option<Self::Item> {
237        self.inner
238            .next_back()
239            .map(|id| (id, self.sccs.component(id)))
240    }
241
242    #[inline]
243    fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
244        self.inner
245            .nth_back(n)
246            .map(|id| (id, self.sccs.component(id)))
247    }
248}
249
250impl<'a, SccId: Id, NodeId: Id> ExactSizeIterator for SccsIter<'a, SccId, NodeId> {
251    #[inline]
252    fn len(&self) -> usize {
253        // We know the bounds fit inside a 'usize', because they're derived from the size of an
254        // IdVec.
255        self.inner.end_index().to_usize_unchecked() - self.inner.start_index().to_usize_unchecked()
256    }
257}
258
259/// Finds the strongly-connected components in a directed graph in dependency order.
260///
261/// The input graph is given implicitly by the `node_count` and `node_dependencies` arguments. The
262/// nodes of the input graph are given by the set of all ids in the range specified by `node_count`,
263/// and the edges outgoing from each node are given by the `node_dependencies` function. The
264/// `node_dependencies` function is guaranteed to be called exactly once per node in the graph.
265///
266/// The returned sequence of components is guaranteed to be given in *dependency order*, meaning
267/// that if `node_dependencies(n1)` contains a node `n2`, then `n2` is guaranteed to belong either
268/// to the same component as `n1`, or to an earlier component in the sequence.
269///
270/// # Panics
271///
272/// Panics if `node_dependencies` returns a node not contained in the range specified by
273/// `node_count`.
274///
275/// Panics if the number of connected components in the graph overflows `SccId`.
276pub fn find_components<SccId, NodeId, NodeDependenciesFn, NodeDependencies, Dependency>(
277    node_count: Count<NodeId>,
278    mut node_dependencies: NodeDependenciesFn,
279) -> Sccs<SccId, NodeId>
280where
281    SccId: Id,
282    NodeId: Id,
283    NodeDependenciesFn: FnMut(NodeId) -> NodeDependencies,
284    NodeDependencies: IntoIterator<Item = Dependency>,
285    Dependency: Borrow<NodeId>,
286{
287    // We use Tarjan's algorithm, performing the depth-first search using an explicit Vec-based
288    // stack instead of recursion to avoid stack overflows on large graphs.
289
290    #[derive(Clone, Copy, Debug)]
291    enum NodeState {
292        Unvisited,
293        OnSearchStack { index: u32, low_link: u32 },
294        OnSccStack { index: u32 },
295        Complete,
296    }
297
298    #[derive(Clone, Copy)]
299    enum Action<NodeId> {
300        TryVisit {
301            parent: Option<NodeId>,
302            node: NodeId,
303        },
304        FinishVisit {
305            parent: Option<NodeId>,
306            node: NodeId,
307        },
308    }
309
310    let node_capacity = node_count.to_value().to_usize().unwrap_or(usize::MAX);
311
312    let mut sccs = Sccs::with_node_capacity(node_capacity);
313
314    let mut node_states = IdVec::from_count_with(node_count, |_| NodeState::Unvisited);
315    let mut node_self_loops = IdMap::with_capacity(node_capacity);
316    let mut scc_stack = Vec::new();
317    let mut search_stack = Vec::new();
318    let mut next_index = 0;
319
320    for search_root in node_count {
321        search_stack.push(Action::TryVisit {
322            parent: None,
323            node: search_root,
324        });
325        while let Some(action) = search_stack.pop() {
326            match action {
327                Action::TryVisit { parent, node } => {
328                    match node_states[node] {
329                        NodeState::Unvisited => {
330                            node_states[node] = NodeState::OnSearchStack {
331                                index: next_index,
332                                low_link: next_index,
333                            };
334                            next_index += 1;
335                            scc_stack.push(node);
336
337                            search_stack.push(Action::FinishVisit { parent, node });
338                            // We need to explicitly track self-loops so that when we obtain a size-1
339                            // SCC we can determine if it's cyclic or acyclic.
340                            let mut has_self_loop = false;
341                            for dependency in node_dependencies(node) {
342                                let dependency = *dependency.borrow();
343                                if !node_count.contains(dependency) {
344                                    panic!(
345                                        "node id of type {} with index {} is out of bounds for \
346                                         node count {}",
347                                        std::any::type_name::<NodeId>(),
348                                        dependency.to_index(),
349                                        node_count.to_value()
350                                    );
351                                }
352                                if dependency == node {
353                                    has_self_loop = true;
354                                }
355                                search_stack.push(Action::TryVisit {
356                                    parent: Some(node),
357                                    node: dependency,
358                                });
359                            }
360                            node_self_loops.insert_vacant(node, has_self_loop);
361                        }
362
363                        NodeState::OnSearchStack { index, low_link: _ }
364                        | NodeState::OnSccStack { index } => {
365                            if let Some(parent) = parent {
366                                if let NodeState::OnSearchStack {
367                                    index: _,
368                                    low_link: parent_low_link,
369                                } = &mut node_states[parent]
370                                {
371                                    *parent_low_link = (*parent_low_link).min(index);
372                                } else {
373                                    unreachable!("parent should be on search stack");
374                                }
375                            }
376                        }
377
378                        NodeState::Complete => {}
379                    }
380                }
381
382                Action::FinishVisit { parent, node } => {
383                    let (index, low_link) =
384                        if let NodeState::OnSearchStack { index, low_link } = node_states[node] {
385                            (index, low_link)
386                        } else {
387                            unreachable!("node should be on search stack");
388                        };
389
390                    node_states[node] = NodeState::OnSccStack { index };
391
392                    if let Some(parent) = parent {
393                        if let NodeState::OnSearchStack {
394                            index: _,
395                            low_link: parent_low_link,
396                        } = &mut node_states[parent]
397                        {
398                            *parent_low_link = (*parent_low_link).min(low_link);
399                        } else {
400                            unreachable!("parent should be on search stack")
401                        }
402                    }
403
404                    if low_link == index {
405                        let mut scc_start = scc_stack.len();
406                        loop {
407                            scc_start -= 1;
408                            let scc_node = scc_stack[scc_start];
409                            debug_assert!(matches!(
410                                node_states[scc_node],
411                                NodeState::OnSccStack { .. }
412                            ));
413                            node_states[scc_node] = NodeState::Complete;
414                            if scc_node == node {
415                                break;
416                            }
417                        }
418                        let scc_slice = &scc_stack[scc_start..];
419                        if scc_slice.len() == 1 && !node_self_loops[node] {
420                            sccs.push_acyclic_component(scc_slice[0]);
421                        } else {
422                            sccs.push_cyclic_component(scc_slice);
423                        };
424                        scc_stack.truncate(scc_start);
425                    }
426                }
427            }
428        }
429    }
430
431    sccs
432}