use alloc::vec::Vec;
use core::{hash::BuildHasher, num::NonZeroUsize};
use smallvec::SmallVec;
use crate::schedule::graph::{DiGraph, GraphNodeId};
/// Creates an iterator over the strongly connected components of `graph`.
///
/// Each yielded [`SmallVec`] holds the nodes of one component. The components are
/// computed lazily: work is only performed while the returned iterator is advanced.
///
/// One [`NodeData`] entry is allocated per node up front; each entry stores that node's
/// neighbor iterator so the traversal of a node can be paused and resumed later.
pub(crate) fn new_tarjan_scc<N: GraphNodeId, S: BuildHasher>(
    graph: &DiGraph<N, S>,
) -> impl Iterator<Item = SmallVec<[N; 4]>> + '_ {
    // Per-node traversal state, collected in the order in which `graph.nodes()` yields nodes.
    let per_node_state: Vec<_> = graph
        .nodes()
        .map(|id| NodeData {
            root_index: None,
            pending: None,
            neighbors: graph.neighbors(id),
        })
        .collect();

    TarjanScc {
        graph,
        // A second, independent pass over all nodes: the candidates for starting a traversal.
        unchecked_nodes: graph.nodes(),
        // Visit indices count up from 1 (they are stored as `NonZeroUsize`), while component
        // labels count down from `usize::MAX`.
        index: 1,
        component_count: usize::MAX,
        nodes: per_node_state,
        stack: Vec::new(),
        visitation_stack: Vec::new(),
        start: None,
        index_adjustment: None,
    }
}
/// Traversal state kept for a single graph node during the SCC search.
struct NodeData<N, Neighbors>
where
    N: GraphNodeId,
    Neighbors: Iterator<Item = N>,
{
    /// `None` until the node is first visited. Afterwards it holds the node's visit index,
    /// which may later be lowered to a neighbor's value or replaced by a component label.
    root_index: Option<NonZeroUsize>,
    /// The neighbor whose visit was started from this node and whose result has not yet
    /// been folded back into `root_index`.
    pending: Option<N>,
    /// The neighbors of this node that have not been examined yet.
    neighbors: Neighbors,
}
/// Resumable state of an iterative strongly-connected-components search over a [`DiGraph`].
struct TarjanScc<
    'graph,
    N: GraphNodeId,
    Hasher: BuildHasher,
    AllNodes: Iterator<Item = N>,
    Neighbors: Iterator<Item = N>,
> {
    /// The graph being searched; also used to translate nodes into indices of `nodes`.
    graph: &'graph DiGraph<N, Hasher>,
    /// Nodes that have not yet been considered as the starting point of a traversal.
    unchecked_nodes: AllNodes,
    /// The visit index that will be handed to the next newly visited node.
    index: usize,
    /// The label that will be written into the nodes of the next completed component.
    component_count: usize,
    /// Per-node state, indexed by `graph.to_index(node)`.
    nodes: Vec<NodeData<N, Neighbors>>,
    /// Fully explored nodes that have not yet been emitted as part of a component.
    stack: Vec<N>,
    /// Work list of `(node, is_local_root)` pairs whose visit is started or still unfinished.
    visitation_stack: Vec<(N, bool)>,
    /// Position in `stack` where the most recently returned component begins, if any.
    start: Option<usize>,
    /// Amount to subtract from `index` once the most recently returned component is dropped.
    index_adjustment: Option<usize>,
}
// Inherent methods implementing the resumable search: `next_scc` drives the work list and
// `visit_once` performs one step of visiting a single node.
impl<
'graph,
N: GraphNodeId,
S: BuildHasher,
A: Iterator<Item = N>,
Neighbors: Iterator<Item = N>,
> TarjanScc<'graph, N, S, A, Neighbors>
{
/// Returns the next strongly connected component as a slice of the internal stack, or
/// `None` once every node yielded by `unchecked_nodes` has been processed.
///
/// The returned slice stays on the stack until the following call, which removes it
/// before continuing the search.
fn next_scc(&mut self) -> Option<&[N]> {
// Clean up after the component returned by the previous call (if any): drop its nodes
// from the stack, give back the visit indices counted for it, and move on to the next
// component label.
if let (Some(start), Some(index_adjustment)) =
(self.start.take(), self.index_adjustment.take())
{
self.stack.truncate(start);
self.index -= index_adjustment; self.component_count -= 1;
}
loop {
// Drain the work list first: entries on it belong to a traversal that is still
// in progress and must finish before a new start node is picked.
while let Some((v, v_is_local_root)) = self.visitation_stack.pop() {
// `visit_once` reports where on the stack a completed component begins.
if let Some(start) = self.visit_once(v, v_is_local_root) {
return Some(&self.stack[start..]);
};
}
// The work list is empty: pick the next candidate start node, or finish.
let Some(node) = self.unchecked_nodes.next() else {
break None;
};
// A node with a `root_index` was already reached from an earlier start node.
let visited = self.nodes[self.graph.to_index(node)].root_index.is_some();
if !visited {
self.visitation_stack.push((node, true));
}
}
}
/// Performs one step of visiting `v`.
///
/// Returns `Some(start)` when `v` turns out to be the root of a component; the component
/// then occupies `self.stack[start..]`. Returns `None` when the visit of `v` was suspended
/// to visit an unvisited neighbor first, or when `v` finished without being a root.
///
/// `v_is_local_root` carries over, between resumptions, whether `v` is still considered
/// the root of its component.
fn visit_once(&mut self, v: N, mut v_is_local_root: bool) -> Option<usize> {
let graph_index_v = self.graph.to_index(v);
let node_v = &mut self.nodes[graph_index_v];
// First time `v` is seen: assign it the next visit index.
if node_v.root_index.is_none() {
let v_index = self.index;
node_v.root_index = NonZeroUsize::new(v_index);
self.index += 1;
}
// Resuming after a neighbor `w` was visited on behalf of `v`: fold `w`'s result into
// `v`. The neighbor iterator has already moved past `w`, so this is the only place
// where that edge is taken into account.
if let Some(w) = node_v.pending.take() {
let graph_index_w = self.graph.to_index(w);
if self.nodes[graph_index_w].root_index < self.nodes[graph_index_v].root_index {
self.nodes[graph_index_v].root_index = self.nodes[graph_index_w].root_index;
v_is_local_root = false;
}
}
// Continue with the neighbors of `v` that have not been examined yet.
while let Some(w) = self.nodes[graph_index_v].neighbors.next() {
let graph_index_w = self.graph.to_index(w);
if self.nodes[graph_index_w].root_index.is_none() {
// `w` is unvisited: suspend `v`, schedule `w` to run first (it is pushed last and
// therefore popped first), and remember `w` so its result is picked up on resume.
self.visitation_stack.push((v, v_is_local_root));
self.visitation_stack.push((w, true));
self.nodes[graph_index_v].pending = Some(w);
return None;
}
// `w` was visited before: if it has a smaller `root_index`, `v` adopts that value and
// is no longer treated as a root.
if self.nodes[graph_index_w].root_index < self.nodes[graph_index_v].root_index {
self.nodes[graph_index_v].root_index = self.nodes[graph_index_w].root_index;
v_is_local_root = false;
}
}
// All neighbors examined. A non-root node waits on the stack until its root is found.
if !v_is_local_root {
self.stack.push(v); return None;
}
// `v` is the root of a component. `index_adjustment` counts the component's nodes,
// starting at 1 for `v` itself.
let mut index_adjustment = 1;
// Label for this component. NOTE(review): labels start at `usize::MAX` and count down,
// presumably so that finished nodes never compare smaller than nodes still in progress.
let c = NonZeroUsize::new(self.component_count);
let nodes = &mut self.nodes;
// Scan the stack from the top for the first node whose `root_index` is smaller than
// `v`'s; every node above it belongs to this component and is relabelled with `c`.
let start = self
.stack
.iter()
.rposition(|&w| {
let graph_index_w = self.graph.to_index(w);
if nodes[graph_index_v].root_index > nodes[graph_index_w].root_index {
true
} else {
nodes[graph_index_w].root_index = c;
index_adjustment += 1;
false
}
})
// The component begins just after the node that stopped the scan, or at the bottom of
// the stack if no such node exists.
.map(|x| x + 1)
.unwrap_or_default();
nodes[graph_index_v].root_index = c;
// The root is pushed last so that it is part of the slice handed to the caller.
self.stack.push(v);
// Recorded for the cleanup at the start of the next `next_scc` call.
self.start = Some(start);
self.index_adjustment = Some(index_adjustment);
Some(start)
}
}
/// Exposes the search as an iterator that yields each component as an owned [`SmallVec`].
impl<'graph, N, S, A, Neighbors> Iterator for TarjanScc<'graph, N, S, A, Neighbors>
where
    N: GraphNodeId,
    S: BuildHasher,
    A: Iterator<Item = N>,
    Neighbors: Iterator<Item = N>,
{
    type Item = SmallVec<[N; 4]>;

    /// Computes the next component and copies it out of the internal stack.
    fn next(&mut self) -> Option<Self::Item> {
        self.next_scc().map(SmallVec::from_slice)
    }

    /// Reports an upper bound equal to the number of per-node entries and no lower bound.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let upper_bound = self.nodes.len();
        (0, Some(upper_bound))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schedule::graph::Direction;

    /// Minimal node identifier used to build graphs in these tests.
    #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
    pub struct Node(i32);

    impl GraphNodeId for Node {
        type Adjacent = (Node, Direction);
        type Edge = (Node, Node);

        fn kind(&self) -> &'static str {
            ""
        }
    }

    /// A three-node cycle `1 -> 2 -> 3 -> 1` must be reported as exactly one component
    /// containing all three nodes.
    #[test]
    fn a_b_c_a() {
        let members = [Node(1), Node(2), Node(3)];
        let edges = [
            (Node(1), Node(2)),
            (Node(2), Node(3)),
            (Node(3), Node(1)),
        ];

        let mut graph = DiGraph::<Node>::with_capacity(3, 3);
        for member in members {
            graph.add_node(member);
        }
        for (from, to) in edges {
            graph.add_edge(from, to);
        }

        let mut components = new_tarjan_scc(&graph);
        let first = components.next().unwrap();
        let second = components.next();

        assert_eq!(first.len(), 3);
        for member in members {
            assert!(first.contains(&member));
        }
        assert!(second.is_none());
    }
}