use std::collections::{HashMap, VecDeque};
/// Directed edges of a graph, stored as a flat list of `(source, destination)` pairs.
///
/// NOTE(review): despite the name, this is an edge list rather than a
/// per-node adjacency list, and nothing in the type itself guarantees that the
/// edges are acyclic.
pub type DAGAsAdjacencyList<Node> = Vec<(Node, Node)>;
/// A directed graph described by a node list and an edge list.
///
/// This is the input type of `sort_graph`.
#[derive(Debug, Clone)]
pub struct Graph<Node> {
    /// The nodes of the graph. `sort_graph` starts from these, in this order,
    /// when looking for nodes without incoming edges.
    pub nodes: Vec<Node>,
    /// Directed edges as `(source, destination)` pairs; `sort_graph` emits the
    /// source of an edge before its destination.
    pub edges: DAGAsAdjacencyList<Node>,
}
/// Topologically sorts `graph` using Kahn's algorithm.
///
/// For every edge `(src, dest)` the returned vector contains `src` before
/// `dest`. Nodes that are ready at the same time are emitted in the order in
/// which they were first seen: first the entries of `graph.nodes`, then any
/// node that is only mentioned as an endpoint of an edge.
///
/// Every distinct node appears exactly once in the result, even if it is
/// listed several times in `graph.nodes` or is referenced only by an edge.
///
/// # Errors
///
/// Returns [`SortError::CycleDetected`], carrying a copy of `graph.edges`, when
/// the edges contain a cycle (including a self-loop) and therefore no valid
/// ordering exists.
pub fn sort_graph<Node: std::hash::Hash + Eq + Clone>(
    graph: &Graph<Node>,
) -> Result<Vec<Node>, SortError<Node>> {
    use std::collections::hash_map::Entry;

    let mut dependents_of: HashMap<Node, Vec<Node>> = HashMap::default();
    let mut in_degree: HashMap<Node, usize> = HashMap::default();
    // Distinct nodes in first-seen order; a HashMap alone would lose the order
    // needed for a deterministic result.
    let mut first_seen: Vec<Node> = Vec::with_capacity(graph.nodes.len());

    // Register every node exactly once. Edge endpoints are included so that a
    // node missing from `graph.nodes` still takes part in the sort instead of
    // leaving its dependents stuck (which used to be misreported as a cycle).
    let edge_endpoints = graph.edges.iter().flat_map(|(src, dest)| [src, dest]);
    for node in graph.nodes.iter().chain(edge_endpoints) {
        if let Entry::Vacant(slot) = in_degree.entry(node.clone()) {
            slot.insert(0);
            first_seen.push(node.clone());
        }
    }

    for (src, dest) in &graph.edges {
        dependents_of
            .entry(src.clone())
            .or_default()
            .push(dest.clone());
        // Always present: every edge endpoint was registered above.
        if let Some(count) = in_degree.get_mut(dest) {
            *count += 1;
        }
    }

    // Seed the queue with the nodes that have no incoming edges. Because
    // `first_seen` holds no duplicates, no node can be queued twice.
    let mut queue: VecDeque<Node> = first_seen
        .into_iter()
        .filter(|node| in_degree.get(node) == Some(&0))
        .collect();

    let mut sorted: Vec<Node> = Vec::with_capacity(in_degree.len());
    while let Some(ready) = queue.pop_front() {
        // Entries are removed once a node is emitted, so whatever is left in
        // `in_degree` at the end is exactly the set of unsortable nodes.
        in_degree.remove(&ready);
        if let Some(dependents) = dependents_of.get(&ready) {
            for dependent in dependents {
                if let Some(count) = in_degree.get_mut(dependent) {
                    *count -= 1;
                    if *count == 0 {
                        in_degree.remove(dependent);
                        queue.push_back(dependent.clone());
                    }
                }
            }
        }
        sorted.push(ready);
    }

    if in_degree.is_empty() {
        Ok(sorted)
    } else {
        Err(SortError::CycleDetected(graph.edges.clone()))
    }
}
/// Error returned by `sort_graph` when the graph cannot be ordered.
#[derive(Debug, Eq, PartialEq)]
pub enum SortError<Node> {
    /// No valid ordering exists. The payload is a list of `(source, destination)`
    /// edges; `sort_graph` fills it with a copy of the whole edge list of the
    /// graph it was given, not only the edges that form the cycle.
    CycleDetected(Vec<(Node, Node)>),
}
/// Lets [`SortError`] be used wherever a standard error is expected, e.g. as a
/// `Box<dyn std::error::Error>`.
///
/// No methods are overridden, so the trait's defaults apply. The bounds mirror
/// those of the `Display` implementation, which `Error` requires.
impl<Node: Clone + Ord + std::fmt::Display + std::fmt::Debug> std::error::Error
    for SortError<Node>
{
}
/// Human-readable report for a [`SortError`].
///
/// The report lists every distinct node that occurs in the carried edges, in
/// ascending `Ord` order, followed by one line per edge. An edge whose source
/// sorts before its destination is drawn with `→`; any other edge (including a
/// self-loop) is drawn with `↖`.
impl<Node: Clone + Ord + std::fmt::Display + std::fmt::Debug> std::fmt::Display
    for SortError<Node>
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The enum has a single variant, so this pattern is irrefutable; adding
        // a variant turns it into a compile error that forces an update here.
        let SortError::CycleDetected(edges) = self;

        writeln!(f, "Cycle detected in the following DAG:")?;

        // A BTreeSet both deduplicates the endpoints and yields them sorted.
        let distinct_nodes: std::collections::BTreeSet<&Node> = edges
            .iter()
            .flat_map(|(src, dest)| [src, dest])
            .collect();

        writeln!(f, "Nodes:")?;
        for node in distinct_nodes {
            write!(f, "{} ", node)?;
        }
        // Ends the node line and leaves one blank line before the edge list.
        writeln!(f, "\n")?;

        writeln!(f, "Edges:")?;
        for (src, dest) in edges {
            if src < dest {
                writeln!(f, " {} → {}", src, dest)?;
            } else {
                writeln!(f, " {} ↖ {}", src, dest)?;
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `sorted` is a valid topological ordering of `graph`.
    ///
    /// Checks that every listed node is present, that nothing extra was
    /// emitted, and that the source of every edge comes before its
    /// destination. Only suitable for graphs whose `nodes` list has no
    /// duplicates and names every edge endpoint.
    fn assert_topological_order<Node: PartialEq + std::fmt::Debug>(
        graph: &Graph<Node>,
        sorted: &[Node],
    ) {
        assert_eq!(
            sorted.len(),
            graph.nodes.len(),
            "expected every node exactly once, got {:?}",
            sorted
        );
        for node in &graph.nodes {
            assert!(sorted.contains(node), "{:?} is missing from {:?}", node, sorted);
        }
        for (src, dest) in &graph.edges {
            let src_index = sorted
                .iter()
                .position(|node| node == src)
                .expect("edge source should be in the sorted output");
            let dest_index = sorted
                .iter()
                .position(|node| node == dest)
                .expect("edge destination should be in the sorted output");
            assert!(
                src_index < dest_index,
                "{:?} must come before {:?} in {:?}",
                src,
                dest,
                sorted
            );
        }
    }

    #[test]
    fn test_sort_graph_is_ok_integer() {
        let nodes: Vec<usize> = vec![2, 3, 5, 7, 8, 9, 10, 11];
        let edges: Vec<(usize, usize)> = vec![
            (5, 11),
            (7, 8),
            (7, 11),
            (3, 8),
            (3, 10),
            (11, 2),
            (11, 9),
            (11, 10),
            (8, 9),
        ];
        let graph: Graph<usize> = Graph { nodes, edges };
        let sorted = sort_graph::<usize>(&graph).expect("acyclic graph should sort");
        assert_topological_order(&graph, &sorted);
    }

    #[test]
    fn test_sort_graph_is_err_integer() {
        let nodes: Vec<usize> = vec![2, 3, 5, 7, 8, 9, 10, 11];
        let edges: Vec<(usize, usize)> = vec![
            (5, 11),
            (7, 8),
            (7, 11),
            (3, 8),
            (3, 10),
            (11, 2),
            (11, 9),
            (11, 10),
            (8, 9),
            // Closes the cycle 11 -> 9 -> 11.
            (9, 11),
        ];
        let graph: Graph<usize> = Graph { nodes, edges };
        assert_eq!(
            sort_graph::<usize>(&graph),
            Err(SortError::CycleDetected(graph.edges.clone()))
        );
    }

    #[test]
    fn test_sort_graph_is_ok_strings() {
        let nodes = vec![
            "shirt",
            "hoodie",
            "socks",
            "underwear",
            "pants",
            "shoes",
            "glasses",
            "watch",
            "school",
        ];
        let edges = vec![
            ("shirt", "hoodie"),
            ("hoodie", "school"),
            ("underwear", "pants"),
            ("pants", "shoes"),
            ("socks", "shoes"),
            ("shoes", "school"),
        ];
        let graph: Graph<&str> = Graph { nodes, edges };
        let sorted = sort_graph::<&str>(&graph).expect("acyclic graph should sort");
        assert_topological_order(&graph, &sorted);
    }

    #[test]
    fn test_sort_graph_keeps_node_order_when_independent() {
        let graph = Graph {
            nodes: vec!["first", "second", "third"],
            edges: Vec::new(),
        };
        assert_eq!(
            sort_graph::<&str>(&graph).expect("Expected graph to sort"),
            vec!["first", "second", "third"]
        );
    }

    #[test]
    fn test_is_err_strings() {
        let nodes = vec![
            "shirt",
            "hoodie",
            "socks",
            "underwear",
            "pants",
            "shoes",
            "glasses",
            "watch",
            "school",
        ];
        let edges = vec![
            ("shirt", "hoodie"),
            ("hoodie", "school"),
            // Closes the cycle shirt -> hoodie -> school -> shirt.
            ("school", "shirt"),
            ("underwear", "pants"),
            ("pants", "shoes"),
            ("socks", "shoes"),
            ("shoes", "school"),
        ];
        let graph: Graph<&str> = Graph { nodes, edges };
        assert_eq!(
            sort_graph::<&str>(&graph),
            Err(SortError::CycleDetected(graph.edges.clone()))
        );
    }
}