use alloc::{vec, vec::Vec};
use core::{
fmt::{self, Debug},
hash::{BuildHasher, Hash},
};
use thiserror::Error;
use bevy_platform::{
collections::{HashMap, HashSet},
hash::FixedHasher,
};
use indexmap::IndexMap;
use smallvec::SmallVec;
use Direction::{Incoming, Outgoing};
/// A node identifier usable as the key type of a [`Graph`].
///
/// Besides the key itself, implementors supply the two associated types the graph actually
/// stores: one for adjacency-list entries and one for edge-set entries.
pub trait GraphNodeId: Copy + Debug + Eq + Hash + Ord {
    /// A neighbor paired with the [`Direction`] of the edge connecting it.
    ///
    /// The graph converts these to and from `(Self, Direction)`.
    type Adjacent: Copy + Debug + From<(Self, Direction)> + Into<(Self, Direction)>;

    /// An edge key built from a pair of node ids.
    ///
    /// The graph converts these to and from `(Self, Self)`.
    type Edge: Copy + Debug + Eq + Hash + From<(Self, Self)> + Into<(Self, Self)>;

    /// Returns a `'static` label for this node; the name suggests it describes the node's
    /// kind (NOTE(review): confirm against implementors).
    fn kind(&self) -> &'static str;
}
/// A [`Graph`] whose edges have a direction.
pub type DiGraph<N, S = FixedHasher> = Graph<true, N, S>;
/// A [`Graph`] whose edges have no direction: `a -- b` and `b -- a` are the same edge.
pub type UnGraph<N, S = FixedHasher> = Graph<false, N, S>;
/// A graph stored as adjacency lists keyed by node ids of type `N`.
///
/// `DIRECTED` selects directed (`true`) or undirected (`false`) edge semantics, and `S` is
/// the hasher used by both internal collections.
#[derive(Clone)]
pub struct Graph<const DIRECTED: bool, N: GraphNodeId, S: BuildHasher = FixedHasher> {
    /// Each node, mapped to its adjacency list of `(neighbor, direction)` entries.
    nodes: IndexMap<N, Vec<N::Adjacent>, S>,
    /// One key per edge present in the graph.
    edges: HashSet<N::Edge, S>,
}
impl<const DIRECTED: bool, N, S> Debug for Graph<DIRECTED, N, S>
where
    N: GraphNodeId,
    S: BuildHasher,
{
    /// Formats the graph as its node-to-adjacency-list map; the edge set is not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        Debug::fmt(&self.nodes, f)
    }
}
impl<const DIRECTED: bool, N: GraphNodeId, S: BuildHasher> Graph<DIRECTED, N, S> {
    /// Creates an empty graph with room for `nodes` nodes and `edges` edges.
    pub fn with_capacity(nodes: usize, edges: usize) -> Self
    where
        S: Default,
    {
        Self {
            nodes: IndexMap::with_capacity_and_hasher(nodes, S::default()),
            edges: HashSet::with_capacity_and_hasher(edges, S::default()),
        }
    }

    /// Returns the key under which the edge `a -> b` is stored in `self.edges`.
    ///
    /// For undirected graphs the endpoints are sorted, so `(a, b)` and `(b, a)` share a key.
    #[inline]
    fn edge_key(a: N, b: N) -> N::Edge {
        let (a, b) = if DIRECTED || a <= b { (a, b) } else { (b, a) };

        N::Edge::from((a, b))
    }

    /// Returns the number of nodes in the graph.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Returns the number of edges in the graph.
    pub fn edge_count(&self) -> usize {
        self.edges.len()
    }

    /// Adds `n` as a node. Does nothing if `n` is already present.
    pub fn add_node(&mut self, n: N) {
        self.nodes.entry(n).or_default();
    }

    /// Removes `n` and every edge touching it. Does nothing if `n` is not present.
    ///
    /// Removal uses `swap_remove`, so the last node in iteration order takes `n`'s position.
    pub fn remove_node(&mut self, n: N) {
        let Some(links) = self.nodes.swap_remove(&n) else {
            return;
        };

        let links = links.into_iter().map(N::Adjacent::into);

        for (succ, dir) in links {
            let edge = if dir == Outgoing {
                Self::edge_key(n, succ)
            } else {
                Self::edge_key(succ, n)
            };
            // Drop the mirrored entry from the neighbor's adjacency list...
            self.remove_single_edge(succ, n, dir.opposite());
            // ...and the edge itself.
            self.edges.remove(&edge);
        }
    }

    /// Returns `true` if `n` is a node in the graph.
    pub fn contains_node(&self, n: N) -> bool {
        self.nodes.contains_key(&n)
    }

    /// Adds the edge `a -> b` (`a -- b` for undirected graphs), inserting missing endpoints
    /// as nodes. Adding an edge that already exists does nothing.
    ///
    /// A self-loop (`a == b`) is recorded once, as an outgoing entry of `a`.
    pub fn add_edge(&mut self, a: N, b: N) {
        if self.edges.insert(Self::edge_key(a, b)) {
            // Only touch the adjacency lists for a genuinely new edge.
            self.nodes
                .entry(a)
                .or_insert_with(|| Vec::with_capacity(1))
                .push(N::Adjacent::from((b, Outgoing)));
            if a != b {
                // Self-loops get no `Incoming` entry.
                self.nodes
                    .entry(b)
                    .or_insert_with(|| Vec::with_capacity(1))
                    .push(N::Adjacent::from((a, Incoming)));
            }
        }
    }

    /// Removes one adjacency entry for `b` from `a`'s list and returns whether one was found.
    ///
    /// In a directed graph the entry must also have direction `dir`; in an undirected graph
    /// only the neighbor is compared. The edge set is not touched.
    fn remove_single_edge(&mut self, a: N, b: N, dir: Direction) -> bool {
        let Some(sus) = self.nodes.get_mut(&a) else {
            return false;
        };

        let Some(index) = sus
            .iter()
            .copied()
            .map(N::Adjacent::into)
            .position(|elt| (DIRECTED && elt == (b, dir)) || (!DIRECTED && elt.0 == b))
        else {
            return false;
        };

        sus.swap_remove(index);

        true
    }

    /// Removes the edge `a -> b` (`a -- b` for undirected graphs) and returns whether it
    /// existed.
    pub fn remove_edge(&mut self, a: N, b: N) -> bool {
        let exist1 = self.remove_single_edge(a, b, Outgoing);
        let exist2 = if a != b {
            self.remove_single_edge(b, a, Incoming)
        } else {
            // A self-loop has only the outgoing entry.
            exist1
        };
        let weight = self.edges.remove(&Self::edge_key(a, b));
        // The adjacency lists and the edge set must agree.
        debug_assert!(exist1 == exist2 && exist1 == weight);
        weight
    }

    /// Returns `true` if the edge `a -> b` (`a -- b` for undirected graphs) exists.
    pub fn contains_edge(&self, a: N, b: N) -> bool {
        self.edges.contains(&Self::edge_key(a, b))
    }

    /// Reserves capacity for `additional` more nodes.
    pub fn reserve_nodes(&mut self, additional: usize) {
        self.nodes.reserve(additional);
    }

    /// Reserves capacity for `additional` more edges.
    pub fn reserve_edges(&mut self, additional: usize) {
        self.edges.reserve(additional);
    }

    /// Iterates over all nodes in their stored order.
    pub fn nodes(&self) -> impl DoubleEndedIterator<Item = N> + ExactSizeIterator<Item = N> + '_ {
        self.nodes.keys().copied()
    }

    /// Iterates over the neighbors of `a`.
    ///
    /// In a directed graph these are the targets of `a`'s outgoing edges. In an undirected
    /// graph they are all nodes connected to `a`. Yields nothing if `a` is not in the graph.
    pub fn neighbors(&self, a: N) -> impl DoubleEndedIterator<Item = N> + '_ {
        let iter = match self.nodes.get(&a) {
            Some(neigh) => neigh.iter(),
            None => [].iter(),
        };

        iter.copied()
            .map(N::Adjacent::into)
            .filter_map(|(n, dir)| (!DIRECTED || dir == Outgoing).then_some(n))
    }

    /// Iterates over the neighbors of `a` connected by an edge in direction `dir`.
    ///
    /// A self-loop on `a` is yielded for either direction. In an undirected graph `dir` is
    /// ignored. Yields nothing if `a` is not in the graph.
    pub fn neighbors_directed(
        &self,
        a: N,
        dir: Direction,
    ) -> impl DoubleEndedIterator<Item = N> + '_ {
        let iter = match self.nodes.get(&a) {
            Some(neigh) => neigh.iter(),
            None => [].iter(),
        };

        iter.copied()
            .map(N::Adjacent::into)
            .filter_map(move |(n, d)| (!DIRECTED || d == dir || n == a).then_some(n))
    }

    /// Iterates over the edges out of `a` as `(a, neighbor)` pairs. For undirected graphs
    /// this is every edge touching `a`.
    ///
    /// # Panics
    ///
    /// Panics if an adjacency entry has no matching edge key. That would mean the internal
    /// bookkeeping is inconsistent.
    pub fn edges(&self, a: N) -> impl DoubleEndedIterator<Item = (N, N)> + '_ {
        self.neighbors(a)
            .map(move |b| match self.edges.get(&Self::edge_key(a, b)) {
                None => unreachable!(),
                Some(_) => (a, b),
            })
    }

    /// Iterates over the edges touching `a` in direction `dir`, as `(source, target)` pairs.
    ///
    /// `Outgoing` yields `(a, neighbor)` and `Incoming` yields `(neighbor, a)`.
    ///
    /// # Panics
    ///
    /// Panics if an adjacency entry has no matching edge key. That would mean the internal
    /// bookkeeping is inconsistent.
    pub fn edges_directed(
        &self,
        a: N,
        dir: Direction,
    ) -> impl DoubleEndedIterator<Item = (N, N)> + '_ {
        self.neighbors_directed(a, dir).map(move |b| {
            let (a, b) = if dir == Incoming { (b, a) } else { (a, b) };

            match self.edges.get(&Self::edge_key(a, b)) {
                None => unreachable!(),
                Some(_) => (a, b),
            }
        })
    }

    /// Iterates over every edge as `(a, b)` pairs, in the edge set's (unspecified) order.
    ///
    /// For undirected graphs each edge appears once, with `a <= b`.
    pub fn all_edges(&self) -> impl ExactSizeIterator<Item = (N, N)> + '_ {
        self.edges.iter().copied().map(N::Edge::into)
    }

    /// Returns the position of `ix` in the node ordering.
    ///
    /// # Panics
    ///
    /// Panics if `ix` is not a node in the graph.
    pub(crate) fn to_index(&self, ix: N) -> usize {
        self.nodes
            .get_index_of(&ix)
            .expect("node passed to `Graph::to_index` is not in the graph")
    }

    /// Converts every node id from `N` to `T`, keeping the graph's structure and node order.
    ///
    /// The conversion must be injective. If two distinct `N` values convert to the same `T`,
    /// their entries are merged and the result is not meaningful.
    ///
    /// # Errors
    ///
    /// Returns the first conversion error encountered.
    pub fn try_convert<T>(self) -> Result<Graph<DIRECTED, T, S>, N::Error>
    where
        N: TryInto<T>,
        T: GraphNodeId,
        S: Default,
    {
        /// Converts a node key and every entry of its adjacency list.
        fn try_convert_node<N: GraphNodeId + TryInto<T>, T: GraphNodeId>(
            (key, adj): (N, Vec<N::Adjacent>),
        ) -> Result<(T, Vec<T::Adjacent>), N::Error> {
            let key = key.try_into()?;
            let adj = adj
                .into_iter()
                .map(|node| {
                    let (id, dir) = node.into();
                    Ok(T::Adjacent::from((id.try_into()?, dir)))
                })
                .collect::<Result<_, N::Error>>()?;
            Ok((key, adj))
        }

        /// Converts both endpoints of an edge key.
        fn try_convert_edge<N: GraphNodeId + TryInto<T>, T: GraphNodeId>(
            edge: N::Edge,
        ) -> Result<T::Edge, N::Error> {
            let (a, b) = edge.into();
            Ok(T::Edge::from((a.try_into()?, b.try_into()?)))
        }

        let nodes = self
            .nodes
            .into_iter()
            .map(try_convert_node::<N, T>)
            .collect::<Result<_, N::Error>>()?;
        let edges = self
            .edges
            .into_iter()
            .map(try_convert_edge::<N, T>)
            .collect::<Result<_, N::Error>>()?;
        Ok(Graph { nodes, edges })
    }
}
impl<const DIRECTED: bool, N: GraphNodeId, S: BuildHasher + Default> Default
    for Graph<DIRECTED, N, S>
{
    /// Returns an empty graph with zero requested capacity for nodes and edges.
    fn default() -> Self {
        let (node_capacity, edge_capacity) = (0, 0);
        Self::with_capacity(node_capacity, edge_capacity)
    }
}
impl<N: GraphNodeId, S: BuildHasher> DiGraph<N, S> {
    /// Returns the nodes in topological order: for every edge `a -> b`, `a` precedes `b`.
    ///
    /// `scratch` is a reusable buffer. It is cleared and, on success, returned as the result,
    /// so callers can recycle its allocation.
    ///
    /// # Errors
    ///
    /// - [`DiGraphToposortError::Loop`] if some node has an edge to itself.
    /// - [`DiGraphToposortError::Cycle`], listing the simple cycles found, if the graph has a
    ///   cycle through two or more nodes.
    pub fn toposort(&self, mut scratch: Vec<N>) -> Result<Vec<N>, DiGraphToposortError<N>> {
        // `iter_sccs` reports a self-loop as an ordinary one-node component, so self-edges
        // must be rejected explicitly.
        if let Some((node, _)) = self.all_edges().find(|(left, right)| left == right) {
            return Err(DiGraphToposortError::Loop(node));
        }

        // `reserve_exact` is relative to `len`, which is 0 after `clear`, so this guarantees
        // room for every node.
        scratch.clear();
        scratch.reserve_exact(self.node_count());
        let mut top_sorted_nodes = scratch;
        let mut sccs_with_cycles = Vec::new();

        for scc in self.iter_sccs() {
            // Every node in a multi-node component can reach every other, so such a
            // component contains at least one cycle.
            top_sorted_nodes.extend_from_slice(&scc);
            if scc.len() > 1 {
                sccs_with_cycles.push(scc);
            }
        }

        if sccs_with_cycles.is_empty() {
            // Components arrive in reverse topological order.
            top_sorted_nodes.reverse();
            Ok(top_sorted_nodes)
        } else {
            let cycles = sccs_with_cycles
                .iter()
                .flat_map(|scc| self.simple_cycles_in_component(scc))
                .collect();
            Err(DiGraphToposortError::Cycle(cycles))
        }
    }

    /// Returns every simple cycle whose nodes all lie in `scc`, each as the list of nodes
    /// along the cycle.
    ///
    /// `scc` is expected to be a strongly connected component of this graph. This is an
    /// iterative form of Johnson's elementary-circuit algorithm:
    /// 1. Pick a root and enumerate all cycles through it with a blocking DFS.
    /// 2. Remove the root.
    /// 3. Repeat on the strongly connected components that remain.
    ///
    /// Self-loops are reported as one-node cycles. An empty `scc` yields no cycles.
    pub fn simple_cycles_in_component(&self, scc: &[N]) -> Vec<Vec<N>> {
        let mut cycles = vec![];
        let mut sccs: Vec<SmallVec<[N; 4]>> = vec![SmallVec::from_slice(scc)];

        while let Some(mut scc) = sccs.pop() {
            // Restrict the search to the nodes and edges of this component.
            let mut subgraph = DiGraph::<N>::with_capacity(scc.len(), 0);
            for &node in &scc {
                subgraph.add_node(node);
            }
            for &node in &scc {
                for successor in self.neighbors(node) {
                    if subgraph.contains_node(successor) {
                        subgraph.add_edge(node, successor);
                    }
                }
            }

            // This pass finds every cycle that starts and ends at `root`.
            let Some(root) = scc.pop() else {
                // An empty component contains no cycles.
                continue;
            };

            // Current DFS path from `root`.
            let mut path = Vec::with_capacity(subgraph.node_count());
            path.push(root);
            // Nodes that must not be entered right now, to avoid re-finding the same cycles.
            let mut blocked: HashSet<N> =
                HashSet::with_capacity_and_hasher(subgraph.node_count(), Default::default());
            blocked.insert(root);
            // `unblock_together[w]` holds nodes to unblock as soon as `w` is unblocked.
            let mut unblock_together: HashMap<N, HashSet<N>> =
                HashMap::with_capacity_and_hasher(subgraph.node_count(), Default::default());
            // Work list for cascading unblocks.
            let mut unblock_stack = Vec::with_capacity(subgraph.node_count());
            // Nodes that lay on a path that closed a cycle; they may be in further cycles.
            let mut maybe_in_more_cycles: HashSet<N> =
                HashSet::with_capacity_and_hasher(subgraph.node_count(), Default::default());
            // DFS stack of `(node, remaining successors)`. The iterators are `Peekable` so
            // exhaustion can be tested without consuming (and losing) a successor.
            let mut stack = Vec::with_capacity(subgraph.node_count());
            stack.push((root, subgraph.neighbors(root).peekable()));

            while !stack.is_empty() {
                let &mut (node, ref mut successors) = stack.last_mut().unwrap();
                if let Some(next) = successors.next() {
                    if next == root {
                        // Closed a cycle through `root`.
                        maybe_in_more_cycles.extend(path.iter());
                        cycles.push(path.clone());
                    } else if !blocked.contains(&next) {
                        // First visit to `next` on this path: descend into it.
                        maybe_in_more_cycles.remove(&next);
                        path.push(next);
                        blocked.insert(next);
                        stack.push((next, subgraph.neighbors(next).peekable()));
                        continue;
                    }
                    // Otherwise `next` is blocked; fall through and try the next successor.
                }

                if successors.peek().is_none() {
                    // `node` has no successors left to explore.
                    if maybe_in_more_cycles.contains(&node) {
                        // Unblock `node` and, transitively, everything waiting on it.
                        unblock_stack.push(node);
                        while let Some(n) = unblock_stack.pop() {
                            if blocked.remove(&n) {
                                let unblock_predecessors = unblock_together.entry(n).or_default();
                                unblock_stack.extend(unblock_predecessors.iter());
                                unblock_predecessors.clear();
                            }
                        }
                    } else {
                        // Keep `node` blocked until one of its successors is unblocked.
                        for successor in subgraph.neighbors(node) {
                            unblock_together.entry(successor).or_default().insert(node);
                        }
                    }

                    path.pop();
                    stack.pop();
                }
            }

            // The iterators in `stack` borrow `subgraph`; release them before mutating it.
            drop(stack);

            // All cycles through `root` are known, so remove it and search what remains.
            subgraph.remove_node(root);
            // Keep components that can still hold a cycle: several nodes, or one node with a
            // self-loop.
            sccs.extend(subgraph.iter_sccs().filter(|scc| {
                scc.len() > 1 || (scc.len() == 1 && subgraph.contains_edge(scc[0], scc[0]))
            }));
        }

        cycles
    }

    /// Iterates over the strongly connected components of the graph.
    ///
    /// `toposort` relies on the components arriving in reverse topological order.
    pub(crate) fn iter_sccs(&self) -> impl Iterator<Item = SmallVec<[N; 4]>> + '_ {
        super::tarjan_scc::new_tarjan_scc(self)
    }
}
#[derive(Error, Debug)]
pub enum DiGraphToposortError<N: GraphNodeId> {
#[error("self-loop detected at node `{0:?}`")]
Loop(N),
#[error("cycles detected: {0:?}")]
Cycle(Vec<Vec<N>>),
}
/// The direction of an edge as seen from one of its endpoints.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum Direction {
    /// The edge leaves the node.
    Outgoing = 0,
    /// The edge enters the node.
    Incoming = 1,
}

impl Direction {
    /// Returns the other direction.
    #[inline]
    pub fn opposite(self) -> Self {
        match self {
            Direction::Incoming => Direction::Outgoing,
            Direction::Outgoing => Direction::Incoming,
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::schedule::{NodeId, SystemKey};

    use super::*;
    use alloc::vec;
    use slotmap::SlotMap;

    /// Collects the graph's nodes in their current storage order.
    fn node_list(graph: &DiGraph<NodeId>) -> Vec<NodeId> {
        graph.nodes().collect()
    }

    /// `remove_node` moves the last node into the removed node's slot.
    #[test]
    fn node_order_preservation() {
        use NodeId::System;

        let mut slotmap = SlotMap::<SystemKey, ()>::with_key();
        let mut graph = DiGraph::<NodeId>::default();

        let sys1 = slotmap.insert(());
        let sys2 = slotmap.insert(());
        let sys3 = slotmap.insert(());
        let sys4 = slotmap.insert(());

        for key in [sys1, sys2, sys3, sys4] {
            graph.add_node(System(key));
        }
        assert_eq!(
            node_list(&graph),
            vec![System(sys1), System(sys2), System(sys3), System(sys4)]
        );

        // sys4, the last node, takes sys1's slot.
        graph.remove_node(System(sys1));
        assert_eq!(
            node_list(&graph),
            vec![System(sys4), System(sys2), System(sys3)]
        );

        // sys3, now last, takes sys4's slot.
        graph.remove_node(System(sys4));
        assert_eq!(node_list(&graph), vec![System(sys3), System(sys2)]);

        graph.remove_node(System(sys2));
        assert_eq!(node_list(&graph), vec![System(sys3)]);

        graph.remove_node(System(sys3));
        assert_eq!(node_list(&graph), Vec::<NodeId>::new());
    }

    /// Components are reported with `sys6` (which only feeds into another component) last.
    #[test]
    fn strongly_connected_components() {
        use NodeId::System;

        let mut slotmap = SlotMap::<SystemKey, ()>::with_key();
        let mut graph = DiGraph::<NodeId>::default();

        let sys1 = slotmap.insert(());
        let sys2 = slotmap.insert(());
        let sys3 = slotmap.insert(());
        let sys4 = slotmap.insert(());
        let sys5 = slotmap.insert(());
        let sys6 = slotmap.insert(());

        // Two 2-cycles sharing sys2, a separate 2-cycle, and sys6 pointing into sys2.
        let edges = [
            (sys1, sys2),
            (sys2, sys1),
            (sys2, sys3),
            (sys3, sys2),
            (sys4, sys5),
            (sys5, sys4),
            (sys6, sys2),
        ];
        for (from, to) in edges {
            graph.add_edge(System(from), System(to));
        }

        let sccs: Vec<Vec<NodeId>> = graph.iter_sccs().map(|scc| scc.to_vec()).collect();

        assert_eq!(
            sccs,
            vec![
                vec![System(sys3), System(sys2), System(sys1)],
                vec![System(sys5), System(sys4)],
                vec![System(sys6)],
            ]
        );
    }
}