use {
crate::{SEED, reflection::Type},
ahash::{HashMap, RandomState},
core::{
mem,
num::NonZero,
ops::{Deref, DerefMut},
},
std::collections::{BTreeMap, BTreeSet},
};
/// Data stored for the root of a component (the payload of [`Node::Root`]).
#[derive(Clone, Debug)]
pub struct Metadata {
/// Cardinality recorded for this root. When two roots are merged, the root
/// with the greater value survives and the two values are summed, with
/// `None` counted as 1.
pub cardinality: Option<NonZero<usize>>,
/// Outgoing edges of this root; the Tarjan traversal follows them as successors.
pub edges: BTreeSet<Type>,
/// The type this metadata belongs to; root lookup returns this value.
pub ty: Type,
}
/// One entry of the union-find forest, keyed by a `Type` in the node map.
#[derive(Debug)]
pub enum Node {
/// A non-root element linked to another element; root lookup rewrites this
/// link to point directly at the root it found.
Parent(Type),
/// The root of a component, together with the component's metadata.
Root(Metadata),
}
/// Union-find structure over `Type`s whose components can be collapsed into
/// strongly connected components by `tarjan`.
#[derive(Debug)]
pub struct StronglyConnectedComponents {
/// Maps every registered type to its union-find node.
nodes: HashMap<Type, Node>,
}
/// Per-vertex bookkeeping used while running Tarjan's algorithm.
#[derive(Debug, Eq, PartialEq)]
pub struct TarjanMetadata {
/// Value of the traversal counter when the vertex was first visited.
index: usize,
/// Smallest index found so far via the vertex's successors; starts equal to `index`.
lowlink: usize,
}
impl StronglyConnectedComponents {
    /// Merges the components containing `lhs` and `rhs` into a single component.
    ///
    /// # Panics
    /// Panics if either element is not registered.
    #[inline]
    pub fn merge(&mut self, lhs: Type, rhs: Type) {
        merge(&mut self.nodes, lhs, rhs)
    }

    /// Creates an empty structure whose node map uses a `RandomState` seeded
    /// from the crate-level `SEED`.
    #[inline]
    #[must_use]
    pub fn new() -> Self {
        Self {
            nodes: HashMap::with_hasher(RandomState::with_seed(usize::from(SEED))),
        }
    }

    /// Returns the root of the component containing `element`, or `None` if
    /// `element` is not registered.
    ///
    /// Takes `&mut self` because the lookup rewrites parent links to point
    /// directly at the root.
    #[inline]
    pub fn root(&mut self, element: Type) -> Option<Type> {
        root(&mut self.nodes, element)
    }

    /// Runs Tarjan's algorithm over every registered vertex and merges the
    /// members of each strongly connected component that it finds.
    ///
    /// # Panics
    /// Panics if an edge leads to a type that is not registered.
    #[inline]
    pub fn tarjan(&mut self) {
        let mut index = 0;
        let mut metadata: BTreeMap<Type, TarjanMetadata> = BTreeMap::new();
        let mut stack = vec![];
        // Snapshot the keys: the traversal mutates `self.nodes` while merging.
        let vertices: Vec<Type> = self.nodes.keys().copied().collect();
        for vertex in vertices {
            // A vertex reached through an earlier traversal is already done;
            // visiting it again would overwrite its metadata and push it twice.
            if !metadata.contains_key(&vertex) {
                let () = self.tarjan_dfs(vertex, &mut metadata, &mut stack, &mut index);
            }
        }
    }

    /// Depth-first step of Tarjan's algorithm, started at `vertex`.
    ///
    /// * `metadata` - index/lowlink of every vertex visited so far; `vertex`
    ///   must not be in it yet.
    /// * `stack` - vertices visited but not yet assigned to a finished component.
    /// * `index` - next unused discovery index; incremented by this call.
    ///
    /// When `vertex` turns out to be the root of a component (its lowlink
    /// still equals its index), every vertex above it on the stack is popped
    /// and merged with it.
    ///
    /// # Panics
    /// Panics if `vertex` or one of its transitive successors is not
    /// registered, and (in debug builds) if `vertex` was already visited.
    #[inline]
    #[expect(clippy::expect_used, clippy::panic, reason = "internal invariants")]
    pub fn tarjan_dfs(
        &mut self,
        vertex: Type,
        metadata: &mut BTreeMap<Type, TarjanMetadata>,
        stack: &mut Vec<Type>,
        index: &mut usize,
    ) {
        let overwritten: Option<_> = metadata.insert(
            vertex,
            TarjanMetadata {
                index: *index,
                lowlink: *index,
            },
        );
        debug_assert_eq!(overwritten, None, "internal `pbt` error: TOCTOU");
        {
            #![expect(clippy::arithmetic_side_effects, reason = "constrained by hardware")]
            *index += 1;
        }
        stack.push(vertex);
        let Some(node) = self.nodes.get(&vertex) else {
            panic!("internal `pbt` error: unregistered SCC element `{vertex:#?}`")
        };
        // NOTE(review): a `Parent` vertex returns here while still on the stack.
        // This is only reachable when components were merged before the
        // traversal started; confirm whether that case needs to be supported.
        let Node::Root(Metadata { ref edges, .. }) = *node else {
            return;
        };
        // Copy the successors so `self` can be borrowed mutably while recursing.
        let edges: Vec<Type> = edges.iter().copied().collect();
        for successor in edges {
            if let Some(&TarjanMetadata {
                index: successor_index,
                ..
            }) = metadata.get(&successor)
            {
                // Only a successor that is still on the stack belongs to the
                // component being built. One that was already popped sits in a
                // finished component and must not lower this vertex's lowlink.
                if stack.iter().rev().any(|&t| t == successor) {
                    let this_lowlink = &mut metadata
                        .get_mut(&vertex)
                        .expect("internal `pbt` error: schrodinger's metadata")
                        .lowlink;
                    let new_lowlink = (*this_lowlink).min(successor_index);
                    *this_lowlink = new_lowlink;
                }
            } else {
                // Unvisited successor: recurse, then inherit its lowlink.
                let () = self.tarjan_dfs(successor, metadata, stack, index);
                let successor_lowlink = metadata
                    .get(&successor)
                    .expect("internal `pbt` error: schrodinger's metadata")
                    .lowlink;
                let this_lowlink = &mut metadata
                    .get_mut(&vertex)
                    .expect("internal `pbt` error: schrodinger's metadata")
                    .lowlink;
                let new_lowlink = (*this_lowlink).min(successor_lowlink);
                *this_lowlink = new_lowlink;
            }
        }
        let metadata = metadata
            .get(&vertex)
            .expect("internal `pbt` error: schrodinger's metadata");
        if metadata.index == metadata.lowlink {
            // `vertex` is the root of its component: everything above it on
            // the stack (and `vertex` itself) forms that component.
            'pop: loop {
                let popped = stack
                    .pop()
                    .expect("internal `pbt` error: schrodinger's stack");
                let () = self.merge(vertex, popped);
                if popped == vertex {
                    break 'pop;
                }
            }
        }
    }
}
/// The default value is the empty structure produced by `new`.
impl Default for StronglyConnectedComponents {
/// Delegates to `Self::new`.
#[inline]
fn default() -> Self {
Self::new()
}
}
/// Exposes the underlying node map for read access.
impl Deref for StronglyConnectedComponents {
    type Target = HashMap<Type, Node>;

    /// Borrows the map from each registered type to its node.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let Self { ref nodes } = *self;
        nodes
    }
}
/// Exposes the underlying node map for mutation.
impl DerefMut for StronglyConnectedComponents {
    /// Mutably borrows the map from each registered type to its node.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { ref mut nodes } = *self;
        nodes
    }
}
/// Unions the components containing `lhs` and `rhs`.
///
/// Both elements are resolved to their roots first; if they share a root,
/// nothing changes. Otherwise the root with the greater `cardinality` survives
/// (`lhs`'s root on a tie; `None` orders below any `Some`), and the other root
/// becomes a `Node::Parent` pointing at it. The survivor's new metadata holds:
///
/// * the sum of both cardinalities, with `None` counted as 1;
/// * the union of both edge sets, each edge replaced by its current root,
///   minus edges that resolve to either merged root (they would be
///   self-edges of the merged component).
///
/// # Panics
/// Panics if `lhs`, `rhs`, or any edge target is not registered in `nodes`.
#[inline]
#[expect(clippy::expect_used, clippy::panic, reason = "internal invariants")]
fn merge(nodes: &mut HashMap<Type, Node>, lhs: Type, rhs: Type) {
    let mut lhs = root(nodes, lhs).expect("internal `pbt` error: merging unregistered SCC element");
    let mut rhs = root(nodes, rhs).expect("internal `pbt` error: merging unregistered SCC element");
    if lhs == rhs {
        return;
    }
    let Some(&Node::Root(ref lhs_meta)) = nodes.get(&lhs) else {
        panic!("internal `pbt` error: `scc::root` is not idempotent")
    };
    let Some(&Node::Root(ref rhs_meta)) = nodes.get(&rhs) else {
        panic!("internal `pbt` error: `scc::root` is not idempotent")
    };
    // From here on `lhs` names the surviving root and `rhs` the absorbed one.
    // `lhs_meta`/`rhs_meta` keep their original sides, which is fine because
    // they are only combined symmetrically below.
    if rhs_meta.cardinality > lhs_meta.cardinality {
        let () = mem::swap(&mut lhs, &mut rhs);
    }
    // Copy the edges out so `nodes` can be borrowed mutably to resolve roots.
    let edges: Vec<_> = lhs_meta
        .edges
        .iter()
        .chain(&rhs_meta.edges)
        .copied()
        .collect();
    let meta = Metadata {
        #[expect(clippy::arithmetic_side_effects, reason = "constrained by hardware")]
        cardinality: NonZero::new(
            lhs_meta.cardinality.map_or(1, NonZero::get) + rhs_meta.cardinality.map_or(1, NonZero::get),
        ),
        edges: edges
            .into_iter()
            .map(|edge| {
                root(nodes, edge).expect("internal `pbt` error: invalid (transitive) parent in SCC")
            })
            // `rhs` is still a root while edges are resolved, so edges into the
            // absorbed component resolve to `rhs`; drop them along with edges
            // into `lhs`, since both become self-edges after the merge.
            .filter(|&target| target != lhs && target != rhs)
            .collect(),
        ty: lhs,
    };
    let _: Option<Node> = nodes.insert(rhs, Node::Parent(lhs));
    let _: Option<Node> = nodes.insert(lhs, Node::Root(meta));
}
#[inline]
#[expect(
clippy::expect_used,
clippy::unwrap_in_result,
reason = "internal invariants"
)]
fn root(nodes: &mut HashMap<Type, Node>, element: Type) -> Option<Type> {
let node = nodes.get(&element)?;
let parent = match *node {
Node::Root(ref metadata) => return Some(metadata.ty),
Node::Parent(parent) => parent,
};
let root =
root(nodes, parent).expect("internal `pbt` error: invalid (transitive) parent in SCC");
let () = debug_assert_ne!(root, element, "internal `pbt` error: union-find cycle");
let _: Option<Node> = nodes.insert(element, Node::Parent(root));
Some(root)
}