use std::{
backtrace::Backtrace,
collections::{HashMap, HashSet, VecDeque},
};
/// Edge list of a [`Graph`].
///
/// Despite the name this is a flat list of `(source, destination)` pairs, not a per-node
/// neighbor map. Each pair requires `source` to be ordered before `destination`.
pub type AdjacencyList<Node> = Vec<(Node, Node)>;

/// A directed graph described by an explicit node list plus an edge list.
///
/// Nodes that only appear as edge endpoints are still treated as part of the graph by
/// [`analyze_graph`] and [`sort_graph`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Graph<Node> {
    /// Explicitly listed nodes. Nodes that have no dependencies are released by the sort in
    /// this order.
    pub nodes: Vec<Node>,
    /// Directed `(source, destination)` edges; `source` must come before `destination`.
    pub edges: AdjacencyList<Node>,
}

/// Outcome of [`analyze_graph`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct GraphAnalysis<Node> {
    /// A dependency order containing every node, or `None` when `cycles` is non-empty.
    pub sorted: Option<Vec<Node>>,
    /// Every strongly connected component that contains a cycle; empty for an acyclic graph.
    pub cycles: Vec<Cycle<Node>>,
}

/// A strongly connected component that contains at least one cycle.
///
/// A component qualifies when it has two or more members, or a single member with a
/// self-loop. Several distinct cycles through the same members are reported together.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Cycle<Node> {
    /// The component's members, in the order the search first reached them.
    pub nodes: Vec<Node>,
    /// Every graph edge whose endpoints are both members, in input order.
    pub edges: AdjacencyList<Node>,
}
/// Analyzes `graph`, returning a dependency order when it is acyclic and every cycle found.
///
/// Edge endpoints missing from `graph.nodes` are included after the listed nodes. When the
/// graph is acyclic, `sorted` holds an order in which every edge's source precedes its
/// destination. Nodes with no dependencies are released in input order, and every other
/// node is released once its last dependency has been emitted. In that case `cycles` is
/// empty. When any cycle exists, `sorted` is `None` and `cycles` lists every strongly
/// connected component that contains a cycle.
pub fn analyze_graph<Node: std::hash::Hash + Eq + Clone>(
    graph: &Graph<Node>,
) -> GraphAnalysis<Node> {
    let nodes = normalized_nodes(graph);
    let cycles = find_cycles(&nodes, &graph.edges);
    // The topological order is discarded whenever a cycle exists, so only compute it for
    // acyclic graphs.
    let sorted = cycles
        .is_empty()
        .then(|| topological_sort(&nodes, &graph.edges));
    GraphAnalysis { sorted, cycles }
}
pub fn sort_graph<Node: std::hash::Hash + Eq + Clone>(
graph: &Graph<Node>,
) -> Result<Vec<Node>, SortError<Node>> {
let analysis = analyze_graph(graph);
if let Some(sorted) = analysis.sorted {
Ok(sorted)
} else {
Err(SortError::cycle_detected(analysis.cycles))
}
}
/// Collects every node of `graph` exactly once, in first-seen order.
///
/// Listed `graph.nodes` come first. Edge endpoints that were not listed follow, in the order
/// they appear in `graph.edges` (source before destination). Repeated entries are dropped.
/// A node listed twice would otherwise be scheduled twice by the topological sort, which
/// emits it twice and releases its dependents early.
fn normalized_nodes<Node: std::hash::Hash + Eq + Clone>(graph: &Graph<Node>) -> Vec<Node> {
    let endpoints = graph
        .edges
        .iter()
        .flat_map(|(source, destination)| [source, destination]);
    let mut seen = HashSet::new();
    graph
        .nodes
        .iter()
        .chain(endpoints)
        .filter(|node| seen.insert(*node))
        .cloned()
        .collect()
}
/// Orders `nodes` with Kahn's algorithm so that every edge's source precedes its destination.
///
/// Nodes with no incoming edges are released in the order they appear in `nodes`. Every
/// other node is queued as soon as its last dependency has been emitted. Nodes that lie on,
/// or downstream of, a cycle never become ready and are left out of the result.
///
/// `nodes` is expected to contain every edge endpoint (see `normalized_nodes`). A node listed
/// more than once is still emitted only once. Parallel edges are counted individually, so a
/// destination is released only after all of them have been satisfied.
fn topological_sort<Node: std::hash::Hash + Eq + Clone>(
    nodes: &[Node],
    edges: &[(Node, Node)],
) -> Vec<Node> {
    let mut dependents: HashMap<Node, Vec<Node>> = HashMap::default();
    let mut in_degree: HashMap<Node, usize> = HashMap::default();
    for node in nodes {
        in_degree.entry(node.clone()).or_insert(0);
    }
    for (src, dest) in edges {
        dependents.entry(src.clone()).or_default().push(dest.clone());
        *in_degree.entry(dest.clone()).or_insert(0) += 1;
    }

    // A node leaves `in_degree` the moment it is queued. Later duplicates of the same node
    // (or later edges into it) then find no entry and are skipped, so each node is queued at
    // most once.
    let mut queue: VecDeque<Node> = VecDeque::default();
    for node in nodes {
        if in_degree.get(node) == Some(&0) {
            in_degree.remove(node);
            queue.push_back(node.clone());
        }
    }

    let mut sorted: Vec<Node> = Vec::with_capacity(nodes.len());
    while let Some(node) = queue.pop_front() {
        if let Some(neighbors) = dependents.get(&node) {
            for neighbor in neighbors {
                if let Some(count) = in_degree.get_mut(neighbor) {
                    *count -= 1;
                    if *count == 0 {
                        in_degree.remove(neighbor);
                        queue.push_back(neighbor.clone());
                    }
                }
            }
        }
        sorted.push(node);
    }
    sorted
}
/// Finds every cycle in the graph formed by `nodes` and `edges`.
///
/// Strongly connected components are computed with Tarjan's algorithm, starting a search
/// from each not-yet-visited entry of `nodes` in order. Each component is then passed to
/// `component_to_cycle` together with the edges that stay inside it. Only components that
/// actually contain a cycle are kept.
fn find_cycles<Node: std::hash::Hash + Eq + Clone>(
    nodes: &[Node],
    edges: &[(Node, Node)],
) -> Vec<Cycle<Node>> {
    let mut adjacency = HashMap::<Node, Vec<Node>>::new();
    for node in nodes {
        adjacency.entry(node.clone()).or_default();
    }
    for (source, destination) in edges {
        adjacency
            .entry(source.clone())
            .or_default()
            .push(destination.clone());
    }
    let mut state = TarjanState {
        adjacency: &adjacency,
        index: 0,
        indexes: HashMap::new(),
        lowlinks: HashMap::new(),
        stack: Vec::new(),
        on_stack: HashSet::new(),
        components: Vec::new(),
    };
    for node in nodes {
        if !state.indexes.contains_key(node) {
            state.strong_connect(node.clone());
        }
    }
    let components = state.components;

    // Bucket every edge under the component that holds both of its endpoints, in a single
    // pass and in the original edge order. Handing the full edge list to `component_to_cycle`
    // for each component would cost O(V * E) on an acyclic graph, where every node is its
    // own component.
    let internal_edges = {
        let component_of: HashMap<&Node, usize> = components
            .iter()
            .enumerate()
            .flat_map(|(index, members)| members.iter().map(move |member| (member, index)))
            .collect();
        let mut buckets: Vec<Vec<(Node, Node)>> = vec![Vec::new(); components.len()];
        for (source, destination) in edges {
            if let (Some(&from), Some(&to)) =
                (component_of.get(source), component_of.get(destination))
            {
                if from == to {
                    buckets[from].push((source.clone(), destination.clone()));
                }
            }
        }
        buckets
    };

    components
        .into_iter()
        .zip(internal_edges)
        .filter_map(|(component, component_edges)| component_to_cycle(component, &component_edges))
        .collect()
}
/// Mutable bookkeeping for Tarjan's strongly-connected-components algorithm.
struct TarjanState<'a, Node> {
    /// Outgoing neighbors of each node.
    adjacency: &'a HashMap<Node, Vec<Node>>,
    /// Next discovery index to hand out.
    index: usize,
    /// Discovery index of every visited node; presence here marks a node as visited.
    indexes: HashMap<Node, usize>,
    /// Tarjan lowlink: smallest discovery index found reachable from the node while the
    /// target was still on `stack`.
    lowlinks: HashMap<Node, usize>,
    /// Visited nodes whose component has not been emitted yet, in discovery order.
    stack: Vec<Node>,
    /// Set mirroring `stack` for constant-time membership checks.
    on_stack: HashSet<Node>,
    /// Finished components, in the order their roots completed.
    components: Vec<Vec<Node>>,
}

impl<Node: std::hash::Hash + Eq + Clone> TarjanState<'_, Node> {
    /// Runs Tarjan's depth-first search from `node`, appending every component it completes
    /// to `components`.
    ///
    /// The search keeps an explicit stack of frames instead of recursing, so very deep graphs
    /// (such as a long dependency chain) cannot overflow the thread's call stack. Nodes are
    /// visited in the same order as the recursive formulation, so `components` comes out
    /// identically.
    fn strong_connect(&mut self, node: Node) {
        // Copy the shared reference out so that neighbor borrows do not hold on to `self`.
        let adjacency = self.adjacency;
        // Each frame is a node under exploration plus the position of its next unexamined
        // neighbor.
        let mut frames: Vec<(Node, usize)> = Vec::new();
        self.discover(node.clone());
        frames.push((node, 0));
        while let Some((current, cursor)) = frames.last_mut() {
            let next = adjacency
                .get(&*current)
                .and_then(|neighbors| neighbors.get(*cursor));
            if let Some(neighbor) = next {
                *cursor += 1;
                if !self.indexes.contains_key(neighbor) {
                    // Equivalent of the recursive call: descend into the unvisited neighbor.
                    self.discover(neighbor.clone());
                    frames.push((neighbor.clone(), 0));
                } else if self.on_stack.contains(neighbor) {
                    let neighbor_index = self.indexes[neighbor];
                    self.lower_lowlink(current, neighbor_index);
                }
            } else {
                // All neighbors examined: finish `current`, then propagate to its parent.
                let (finished, _) = frames
                    .pop()
                    .expect("a frame exists while the loop body runs");
                self.emit_component_if_root(&finished);
                if let Some((parent, _)) = frames.last() {
                    let finished_lowlink = self.lowlinks[&finished];
                    self.lower_lowlink(parent, finished_lowlink);
                }
            }
        }
    }

    /// Gives `node` the next discovery index (also its initial lowlink) and pushes it onto
    /// the component stack.
    fn discover(&mut self, node: Node) {
        self.indexes.insert(node.clone(), self.index);
        self.lowlinks.insert(node.clone(), self.index);
        self.index += 1;
        self.stack.push(node.clone());
        self.on_stack.insert(node);
    }

    /// Lowers the lowlink of `node` to `candidate` when `candidate` is smaller.
    fn lower_lowlink(&mut self, node: &Node, candidate: usize) {
        if let Some(lowlink) = self.lowlinks.get_mut(node) {
            *lowlink = (*lowlink).min(candidate);
        }
    }

    /// If `node` roots a component (its lowlink equals its discovery index), pops that
    /// component off `stack` and appends it to `components` in discovery order.
    fn emit_component_if_root(&mut self, node: &Node) {
        if self.indexes[node] != self.lowlinks[node] {
            return;
        }
        let mut component = Vec::new();
        while let Some(member) = self.stack.pop() {
            self.on_stack.remove(&member);
            let reached_root = member == *node;
            component.push(member);
            if reached_root {
                break;
            }
        }
        component.reverse();
        self.components.push(component);
    }
}
/// Converts one strongly connected component into a [`Cycle`], or `None` if it is acyclic.
///
/// A component counts as a cycle when it has more than one member, or when its only member
/// has a self-loop. The cycle's edges are every entry of `edges` whose endpoints both belong
/// to the component, in their original order. Its members keep the order given in `nodes`.
fn component_to_cycle<Node: std::hash::Hash + Eq + Clone>(
    nodes: Vec<Node>,
    edges: &[(Node, Node)],
) -> Option<Cycle<Node>> {
    let members: HashSet<Node> = nodes.iter().cloned().collect();
    let mut internal = Vec::new();
    for (from, to) in edges {
        if members.contains(from) && members.contains(to) {
            internal.push((from.clone(), to.clone()));
        }
    }
    let has_self_loop = internal.iter().any(|(from, to)| from == to);
    (nodes.len() > 1 || has_self_loop).then(|| Cycle {
        nodes,
        edges: internal,
    })
}
/// Error returned by [`sort_graph`] when the graph contains at least one cycle.
///
/// Carries every cycle that was found plus a [`Backtrace`] captured where the error was built.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SortError<Node> {
    /// Every cyclic strongly connected component detected in the graph.
    cycles: Vec<Cycle<Node>>,
    /// Obtained from [`Backtrace::capture`]; whether frames are actually recorded depends on
    /// the environment. Not serialized: deserialization captures a fresh one instead.
    #[cfg_attr(
        feature = "serde",
        serde(skip, default = "std::backtrace::Backtrace::capture")
    )]
    backtrace: Backtrace,
}

impl<Node> SortError<Node> {
    /// Builds the error for `cycles`, capturing a backtrace at the call site.
    fn cycle_detected(cycles: Vec<Cycle<Node>>) -> Self {
        Self {
            cycles,
            backtrace: Backtrace::capture(),
        }
    }

    /// Returns the cycles that prevented sorting.
    pub fn cycles(&self) -> &[Cycle<Node>] {
        &self.cycles
    }

    /// Returns the backtrace captured when this error was created.
    pub fn backtrace(&self) -> &Backtrace {
        &self.backtrace
    }
}

impl<Node> std::error::Error for SortError<Node> where Node: core::fmt::Debug + core::fmt::Display {}

impl<Node: std::fmt::Display> std::fmt::Display for SortError<Node> {
    /// Lists each cycle's members and edges. A backtrace section is appended only when one
    /// was actually captured.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        writeln!(f, "cycle detected in graph")?;
        for cycle in &self.cycles {
            writeln!(f, "cycle nodes:")?;
            for node in &cycle.nodes {
                write!(f, "{} ", node)?;
            }
            writeln!(f, "\nedges:")?;
            for (src, dest) in &cycle.edges {
                writeln!(f, " {} -> {}", src, dest)?;
            }
        }
        // `Backtrace::capture` records frames only when backtraces are enabled. Otherwise its
        // Display is a placeholder ("disabled backtrace") that adds noise to a user-facing
        // message.
        if matches!(
            self.backtrace.status(),
            std::backtrace::BacktraceStatus::Captured
        ) {
            write!(f, "backtrace:\n{}", self.backtrace)?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `sorted` is a valid dependency order for `graph`: no node is repeated,
    /// every listed node is present, and every edge's source comes before its destination.
    fn assert_valid_order<Node: PartialEq + std::fmt::Debug>(graph: &Graph<Node>, sorted: &[Node]) {
        let position_of = |node: &Node| {
            sorted
                .iter()
                .position(|candidate| candidate == node)
                .unwrap_or_else(|| panic!("{node:?} is missing from {sorted:?}"))
        };
        for (index, node) in sorted.iter().enumerate() {
            assert_eq!(
                position_of(node),
                index,
                "{node:?} appears more than once in {sorted:?}"
            );
        }
        for node in &graph.nodes {
            assert!(sorted.contains(node), "{node:?} is missing from {sorted:?}");
        }
        for (source, destination) in &graph.edges {
            assert!(
                position_of(source) < position_of(destination),
                "edge {source:?} -> {destination:?} is violated by {sorted:?}"
            );
        }
    }

    /// Members of each reported cycle, sorted so assertions do not depend on search order.
    fn sorted_cycle_members<Node: Ord + Clone>(cycles: &[Cycle<Node>]) -> Vec<Vec<Node>> {
        cycles
            .iter()
            .map(|cycle| {
                let mut members = cycle.nodes.clone();
                members.sort();
                members
            })
            .collect()
    }

    #[test]
    fn test_sort_graph_is_ok_integer() {
        let nodes: Vec<usize> = vec![2, 3, 5, 7, 8, 9, 10, 11];
        let edges: Vec<(usize, usize)> = vec![
            (5, 11),
            (7, 8),
            (7, 11),
            (3, 8),
            (3, 10),
            (11, 2),
            (11, 9),
            (11, 10),
            (8, 9),
        ];
        let graph: Graph<usize> = Graph { nodes, edges };
        let sorted = sort_graph::<usize>(&graph).expect("acyclic graph should sort");
        assert_valid_order(&graph, &sorted);
    }

    #[test]
    fn test_sort_graph_is_err_integer() {
        let nodes: Vec<usize> = vec![2, 3, 5, 7, 8, 9, 10, 11];
        let edges: Vec<(usize, usize)> = vec![
            (5, 11),
            (7, 8),
            (7, 11),
            (3, 8),
            (3, 10),
            (11, 2),
            (11, 9),
            (11, 10),
            (8, 9),
            // Closes the cycle 11 -> 9 -> 11.
            (9, 11),
        ];
        let graph: Graph<usize> = Graph { nodes, edges };
        let error = sort_graph::<usize>(&graph).expect_err("cyclic graph should not sort");
        assert_eq!(sorted_cycle_members(error.cycles()), vec![vec![9, 11]]);
    }

    #[test]
    fn test_sort_graph_is_ok_strings() {
        let nodes = vec![
            "shirt",
            "hoodie",
            "socks",
            "underwear",
            "pants",
            "shoes",
            "glasses",
            "watch",
            "school",
        ];
        let edges = vec![
            ("shirt", "hoodie"),
            ("hoodie", "school"),
            ("underwear", "pants"),
            ("pants", "shoes"),
            ("socks", "shoes"),
            ("shoes", "school"),
        ];
        let graph: Graph<&str> = Graph { nodes, edges };
        let sorted = sort_graph::<&str>(&graph).expect("acyclic graph should sort");
        assert_valid_order(&graph, &sorted);
    }

    #[test]
    fn test_sort_graph_keeps_node_order_when_independent() {
        let graph = Graph {
            nodes: vec!["first", "second", "third"],
            edges: Vec::new(),
        };
        assert_eq!(
            sort_graph::<&str>(&graph).expect("Expected graph to sort"),
            vec!["first", "second", "third"]
        );
    }

    #[test]
    fn test_sort_graph_is_err_strings() {
        let nodes = vec![
            "shirt",
            "hoodie",
            "socks",
            "underwear",
            "pants",
            "shoes",
            "glasses",
            "watch",
            "school",
        ];
        let edges = vec![
            ("shirt", "hoodie"),
            ("hoodie", "school"),
            // Closes the cycle shirt -> hoodie -> school -> shirt.
            ("school", "shirt"),
            ("underwear", "pants"),
            ("pants", "shoes"),
            ("socks", "shoes"),
            ("shoes", "school"),
        ];
        let graph: Graph<&str> = Graph { nodes, edges };
        let error = sort_graph::<&str>(&graph).expect_err("cyclic graph should not sort");
        assert_eq!(
            sorted_cycle_members(error.cycles()),
            vec![vec!["hoodie", "school", "shirt"]]
        );
    }

    #[test]
    fn analyze_graph_returns_cycle_participants() {
        let graph = Graph {
            nodes: vec!["database", "orm", "api", "frontend"],
            edges: vec![
                ("database", "orm"),
                ("orm", "api"),
                ("api", "database"),
                ("api", "frontend"),
            ],
        };
        let analysis = analyze_graph(&graph);
        assert_eq!(analysis.sorted, None);
        assert_eq!(analysis.cycles.len(), 1);
        assert_eq!(analysis.cycles[0].nodes, vec!["database", "orm", "api"]);
        assert_eq!(
            analysis.cycles[0].edges,
            vec![("database", "orm"), ("orm", "api"), ("api", "database")]
        );
    }

    #[test]
    fn analyze_graph_returns_self_cycle() {
        let graph = Graph {
            nodes: vec!["project", "author"],
            edges: vec![("project", "project"), ("project", "author")],
        };
        let analysis = analyze_graph(&graph);
        assert_eq!(analysis.sorted, None);
        assert_eq!(analysis.cycles.len(), 1);
        assert_eq!(analysis.cycles[0].nodes, vec!["project"]);
        assert_eq!(analysis.cycles[0].edges, vec![("project", "project")]);
    }
}