use super::AlgorithmError;
use crate::graph::{Graph, NodeIndex};
use crate::visited::TriStateVisitedTracker;
/// Recursive DFS step: explores everything reachable from `node`, then records
/// `node` itself.
///
/// `node` is written to `sorted_nodes[*sort_index]` only after every successor
/// has returned. The filled prefix of `sorted_nodes` therefore grows in
/// post-order, and the caller reverses it afterwards.
///
/// # Errors
/// - [`AlgorithmError::CycleDetected`] if `node` is reached while it is still
///   marked "visiting", i.e. it is on the current DFS path.
/// - [`AlgorithmError::VisitedTrackerCapacityExceeded`] if the tracker rejects
///   marking `node` as visiting or visited.
/// - [`AlgorithmError::ResultCapacityExceeded`] if `sorted_nodes` has no slot
///   left at `*sort_index`.
/// - Any error from `graph.outgoing_edges`, converted via `From<G::Error>`.
fn topological_sort_dfs_visit<G, NI, VT>(
    graph: &G,
    node: NI,
    visited: &mut VT,
    sorted_nodes: &mut [NI],
    sort_index: &mut usize,
) -> Result<(), AlgorithmError<NI>>
where
    G: Graph<NI>,
    NI: NodeIndex,
    VT: TriStateVisitedTracker<NI> + ?Sized,
    AlgorithmError<NI>: From<G::Error>,
{
    if visited.is_visiting(&node) {
        // Reaching a node that is still on the active path means a back edge.
        return Err(AlgorithmError::CycleDetected);
    } else if visited.is_visited(&node) {
        // Fully explored earlier and already recorded; nothing more to do.
        return Ok(());
    }

    visited
        .mark_visiting(&node)
        .map_err(|_| AlgorithmError::VisitedTrackerCapacityExceeded)?;

    for successor in graph.outgoing_edges(node)? {
        topological_sort_dfs_visit(graph, successor, visited, sorted_nodes, sort_index)?;
    }

    visited
        .mark_visited(&node)
        .map_err(|_| AlgorithmError::VisitedTrackerCapacityExceeded)?;

    // Post-order write: every successor of `node` is already in the buffer.
    match sorted_nodes.get_mut(*sort_index) {
        Some(slot) => *slot = node,
        None => return Err(AlgorithmError::ResultCapacityExceeded),
    }
    *sort_index += 1;
    Ok(())
}
/// Computes a topological ordering of the nodes of `graph` using depth-first search.
///
/// `visited` is reset before the search starts, so any state left over from a
/// previous run is discarded. DFS is started from every node yielded by
/// `graph.iter_nodes()` that is still unvisited. Each node is recorded once into
/// the front of `sorted_nodes`.
///
/// Returns the filled prefix of `sorted_nodes`. For every edge `u -> v` among
/// the recorded nodes, `u` appears before `v`.
///
/// # Errors
/// - [`AlgorithmError::CycleDetected`] if a cycle is found.
/// - [`AlgorithmError::ResultCapacityExceeded`] if `sorted_nodes` is too short
///   to hold every recorded node.
/// - [`AlgorithmError::VisitedTrackerCapacityExceeded`] if `visited` rejects a
///   state change.
/// - Any error produced by the graph, converted via `From<G::Error>`.
pub fn topological_sort_dfs<'a, G, NI, VT>(
    graph: &G,
    visited: &mut VT,
    sorted_nodes: &'a mut [NI],
) -> Result<&'a [NI], AlgorithmError<NI>>
where
    G: Graph<NI>,
    NI: NodeIndex,
    VT: TriStateVisitedTracker<NI> + ?Sized,
    AlgorithmError<NI>: From<G::Error>,
{
    visited.reset();

    let mut written = 0;
    for start in graph.iter_nodes()? {
        // Nodes already reached from an earlier start were recorded then.
        if !visited.is_unvisited(&start) {
            continue;
        }
        topological_sort_dfs_visit(graph, start, visited, sorted_nodes, &mut written)?;
    }

    // Nodes were recorded in DFS post-order; reversing puts sources first.
    let (ordered, _) = sorted_nodes.split_at_mut(written);
    ordered.reverse();
    Ok(ordered)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::edgelist::edge_list::EdgeList;
    use crate::visited::NodeState;

    #[test]
    fn test_topological_sort_simple() {
        let graph = EdgeList::<8, _, _>::new([(0usize, 1usize), (1, 2)]);
        let mut visited = [NodeState::Unvisited; 8];
        let mut sorted_nodes = [0usize; 8];
        let result =
            topological_sort_dfs(&graph, visited.as_mut_slice(), &mut sorted_nodes).unwrap();
        assert_eq!(result, &[0, 1, 2]);
    }

    #[test]
    fn test_topological_sort_complex() {
        let graph = EdgeList::<8, _, _>::new([(0usize, 1usize), (0, 2), (1, 3), (2, 3)]);
        let mut visited = [NodeState::Unvisited; 8];
        let mut sorted_nodes = [0usize; 8];
        let result =
            topological_sort_dfs(&graph, visited.as_mut_slice(), &mut sorted_nodes).unwrap();
        assert_eq!(result.len(), 4);
        assert_eq!(result[0], 0);
        assert_eq!(result[result.len() - 1], 3);
        assert!(result.contains(&1));
        assert!(result.contains(&2));
    }

    #[test]
    fn test_topological_sort_cycle_detection() {
        let graph = EdgeList::<8, _, _>::new([(0usize, 1usize), (1, 2), (2, 0)]);
        let mut visited = [NodeState::Unvisited; 8];
        let mut sorted_nodes = [0usize; 8];
        let error = topological_sort_dfs(&graph, visited.as_mut_slice(), &mut sorted_nodes);
        assert!(matches!(error, Err(AlgorithmError::CycleDetected)));
    }

    #[test]
    fn test_topological_sort_disconnected() {
        let graph = EdgeList::<8, _, _>::new([(0usize, 1usize), (2, 3)]);
        let mut visited = [NodeState::Unvisited; 8];
        let mut sorted_nodes = [0usize; 8];
        let result =
            topological_sort_dfs(&graph, visited.as_mut_slice(), &mut sorted_nodes).unwrap();
        assert_eq!(result.len(), 4);
        let pos_0 = result.iter().position(|&x| x == 0).unwrap();
        let pos_1 = result.iter().position(|&x| x == 1).unwrap();
        let pos_2 = result.iter().position(|&x| x == 2).unwrap();
        let pos_3 = result.iter().position(|&x| x == 3).unwrap();
        assert!(pos_0 < pos_1);
        assert!(pos_2 < pos_3);
    }

    #[test]
    fn test_topological_sort_result_capacity_exceeded() {
        // Three nodes must be recorded, but the output buffer only has two slots.
        let graph = EdgeList::<8, _, _>::new([(0usize, 1usize), (1, 2)]);
        let mut visited = [NodeState::Unvisited; 8];
        let mut sorted_nodes = [0usize; 2];
        let error = topological_sort_dfs(&graph, visited.as_mut_slice(), &mut sorted_nodes);
        assert!(matches!(error, Err(AlgorithmError::ResultCapacityExceeded)));
    }

    #[test]
    fn test_topological_sort_reuses_visited_tracker() {
        // The tracker is reset on entry, so a second run over the same (now fully
        // visited) tracker must produce the same ordering rather than an empty one.
        let graph = EdgeList::<8, _, _>::new([(0usize, 1usize), (1, 2)]);
        let mut visited = [NodeState::Unvisited; 8];
        let mut first_buf = [0usize; 8];
        let mut second_buf = [0usize; 8];
        let first =
            topological_sort_dfs(&graph, visited.as_mut_slice(), &mut first_buf).unwrap();
        let second =
            topological_sort_dfs(&graph, visited.as_mut_slice(), &mut second_buf).unwrap();
        assert_eq!(first, &[0, 1, 2]);
        assert_eq!(first, second);
    }
}