use std::collections::VecDeque;
use crate::utils::graph::{GraphBase, NodeId, Predecessors, Successors};
/// Returns a topological ordering of `graph`'s nodes using Kahn's algorithm.
///
/// Nodes that start with no incoming edges are enqueued in `node_ids()` order.
/// Every other node is enqueued (FIFO) at the moment its last incoming edge is
/// consumed. The result is therefore deterministic for a given graph and
/// iteration order.
///
/// # Returns
///
/// - `Some(order)` when the graph is acyclic. `order` contains every node exactly
///   once, and every edge `u -> v` has `u` before `v`. An empty graph yields
///   `Some(vec![])`.
/// - `None` when a cycle (including a self-loop) keeps some node from ever
///   reaching in-degree zero.
///
/// Each node is enqueued at most once. Each of its outgoing edges is visited
/// once while counting in-degrees and once while releasing successors.
///
/// NOTE(review): `in_degree` is sized to `node_count()` and indexed by
/// `NodeId::index()`. This assumes node indices are dense in `0..node_count`;
/// confirm this holds for graph types that support node removal.
pub fn topological_sort<G>(graph: &G) -> Option<Vec<NodeId>>
where
    G: GraphBase + Successors + Predecessors,
{
    let node_count = graph.node_count();

    // Count in-degrees from the same `successors` relation that the release
    // loop below decrements. Counting through `predecessors` would only be
    // correct if both adjacency views agreed exactly (e.g. on parallel edges).
    // A mismatch would either underflow a counter or strand nodes and wrongly
    // report a cycle.
    let mut in_degree = vec![0usize; node_count];
    for node in graph.node_ids() {
        for successor in graph.successors(node) {
            in_degree[successor.index()] += 1;
        }
    }

    // Seed the queue with every source node (in-degree zero).
    let mut queue: VecDeque<NodeId> = VecDeque::new();
    for node in graph.node_ids() {
        if in_degree[node.index()] == 0 {
            queue.push_back(node);
        }
    }

    let mut order = Vec::with_capacity(node_count);
    while let Some(node) = queue.pop_front() {
        order.push(node);
        for successor in graph.successors(node) {
            let remaining = &mut in_degree[successor.index()];
            *remaining -= 1;
            // The edge just consumed was this successor's last unprocessed
            // incoming edge, so it is now ready to be emitted.
            if *remaining == 0 {
                queue.push_back(successor);
            }
        }
    }

    // Nodes on a cycle, or reachable only through one, never reach in-degree
    // zero. A short result therefore means the graph is cyclic.
    (order.len() == node_count).then_some(order)
}
#[cfg(test)]
mod tests {
    use crate::utils::graph::{algorithms::topological::topological_sort, DirectedGraph, NodeId};

    /// Returns the index of `node` within `order`.
    ///
    /// Panics if `node` is missing, which would itself indicate a broken
    /// topological order.
    fn position(order: &[NodeId], node: NodeId) -> usize {
        order
            .iter()
            .position(|&x| x == node)
            .expect("every node must appear in the topological order")
    }

    #[test]
    fn test_topological_sort_empty_graph() {
        let graph: DirectedGraph<(), ()> = DirectedGraph::new();
        assert_eq!(topological_sort(&graph), Some(Vec::new()));
    }

    #[test]
    fn test_topological_sort_single_node() {
        let mut graph: DirectedGraph<(), ()> = DirectedGraph::new();
        let a = graph.add_node(());
        assert_eq!(topological_sort(&graph), Some(vec![a]));
    }

    #[test]
    fn test_topological_sort_linear() {
        let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new();
        let a = graph.add_node("A");
        let b = graph.add_node("B");
        let c = graph.add_node("C");
        graph.add_edge(a, b, ()).unwrap();
        graph.add_edge(b, c, ()).unwrap();
        // A chain admits exactly one valid ordering.
        assert_eq!(topological_sort(&graph), Some(vec![a, b, c]));
    }

    #[test]
    fn test_topological_sort_diamond() {
        let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new();
        let a = graph.add_node("A");
        let b = graph.add_node("B");
        let c = graph.add_node("C");
        let d = graph.add_node("D");
        graph.add_edge(a, b, ()).unwrap();
        graph.add_edge(a, c, ()).unwrap();
        graph.add_edge(b, d, ()).unwrap();
        graph.add_edge(c, d, ()).unwrap();

        let order = topological_sort(&graph).expect("diamond graph is acyclic");
        assert_eq!(order.len(), 4);
        assert!(position(&order, a) < position(&order, b));
        assert!(position(&order, a) < position(&order, c));
        assert!(position(&order, b) < position(&order, d));
        assert!(position(&order, c) < position(&order, d));
    }

    #[test]
    fn test_topological_sort_simple_cycle() {
        let mut graph: DirectedGraph<(), ()> = DirectedGraph::new();
        let a = graph.add_node(());
        let b = graph.add_node(());
        let c = graph.add_node(());
        graph.add_edge(a, b, ()).unwrap();
        graph.add_edge(b, c, ()).unwrap();
        graph.add_edge(c, a, ()).unwrap();
        assert!(topological_sort(&graph).is_none());
    }

    #[test]
    fn test_topological_sort_self_loop() {
        let mut graph: DirectedGraph<(), ()> = DirectedGraph::new();
        let a = graph.add_node(());
        graph.add_edge(a, a, ()).unwrap();
        assert!(topological_sort(&graph).is_none());
    }

    #[test]
    fn test_topological_sort_disconnected_components() {
        let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new();
        let a = graph.add_node("A");
        let b = graph.add_node("B");
        let c = graph.add_node("C");
        let d = graph.add_node("D");
        graph.add_edge(a, b, ()).unwrap();
        graph.add_edge(c, d, ()).unwrap();

        let order = topological_sort(&graph).expect("disconnected DAG is acyclic");
        assert_eq!(order.len(), 4);
        assert!(position(&order, a) < position(&order, b));
        assert!(position(&order, c) < position(&order, d));
    }

    #[test]
    fn test_topological_sort_partial_cycle() {
        let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new();
        let a = graph.add_node("A");
        let b = graph.add_node("B");
        let c = graph.add_node("C");
        let d = graph.add_node("D");
        graph.add_edge(a, b, ()).unwrap();
        graph.add_edge(b, c, ()).unwrap();
        graph.add_edge(c, d, ()).unwrap();
        // Closes the cycle b -> c -> d -> b; `a` alone stays acyclic.
        graph.add_edge(d, b, ()).unwrap();
        assert!(topological_sort(&graph).is_none());
    }

    #[test]
    fn test_topological_sort_multiple_valid_orderings() {
        let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new();
        let a = graph.add_node("A");
        let b = graph.add_node("B");
        let c = graph.add_node("C");
        graph.add_edge(a, c, ()).unwrap();
        graph.add_edge(b, c, ()).unwrap();

        // Either [a, b, c] or [b, a, c] is valid; only check the constraints.
        let order = topological_sort(&graph).expect("graph is acyclic");
        assert_eq!(order.len(), 3);
        assert!(position(&order, a) < position(&order, c));
        assert!(position(&order, b) < position(&order, c));
    }

    #[test]
    fn test_topological_sort_wide_dag() {
        let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new();
        let root = graph.add_node("Root");
        let children: Vec<NodeId> = (0..5)
            .map(|_| {
                let child = graph.add_node("Child");
                graph.add_edge(root, child, ()).unwrap();
                child
            })
            .collect();

        let order = topological_sort(&graph).expect("wide DAG is acyclic");
        assert_eq!(order.len(), 6);
        assert_eq!(order[0], root);
        for child in &children {
            assert!(order.contains(child));
        }
    }

    #[test]
    fn test_topological_sort_deep_dag() {
        let mut graph: DirectedGraph<usize, ()> = DirectedGraph::new();
        let nodes: Vec<NodeId> = (0..100).map(|i| graph.add_node(i)).collect();
        for pair in nodes.windows(2) {
            graph.add_edge(pair[0], pair[1], ()).unwrap();
        }
        // A 100-node chain has exactly one valid ordering: insertion order.
        assert_eq!(topological_sort(&graph), Some(nodes));
    }
}