use crate::storage::CsrGraph;
use crate::NodeId;
use anyhow::{anyhow, Result};
/// Per-node visitation state shared by the depth-first searches in this module.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum NodeState {
    /// Not yet reached by any search.
    Unvisited,
    /// Entered but not yet finished: the node is on the current DFS path, so
    /// reaching it again along an edge means the graph contains a cycle.
    InStack,
    /// Every outgoing neighbour has been fully explored.
    Finished,
}
/// Returns `true` if `graph` contains at least one directed cycle.
///
/// A depth-first search is started from every node that no earlier search has
/// already reached, in index order; evaluation stops at the first cycle found.
/// A self-loop counts as a cycle.
#[must_use]
pub fn is_cyclic(graph: &CsrGraph) -> bool {
    let node_count = graph.num_nodes();
    let mut state = vec![NodeState::Unvisited; node_count];
    (0..node_count).any(|root| {
        state[root] == NodeState::Unvisited && has_cycle_from(graph, root, &mut state)
    })
}
/// Depth-first search from `node`; returns `true` if a directed cycle is reachable from it.
///
/// Nodes are marked [`NodeState::InStack`] when entered and [`NodeState::Finished`]
/// once every outgoing neighbour has been explored. Reaching a neighbour that is still
/// `InStack` means an edge back onto the current path, i.e. a cycle (a self-loop
/// counts). Neighbours that are already `Finished` are skipped. When `true` is
/// returned, the nodes on the current path are left in the `InStack` state.
///
/// The search keeps an explicit stack of `(node, remaining neighbours)` frames instead
/// of recursing. Its depth is therefore bounded by heap memory rather than by the
/// thread's call stack, so a long chain of nodes cannot cause a stack overflow. The
/// visiting order is the same as that of a recursive DFS.
///
/// If `outgoing_neighbors` returns an error for a node, that node is treated as having
/// no outgoing edges (best effort): `is_cyclic` returns a plain `bool` and has no way
/// to report the failure.
fn has_cycle_from(graph: &CsrGraph, node: usize, state: &mut [NodeState]) -> bool {
    // Iterator over the successors of `idx`. A failed lookup yields an empty
    // iterator, which gives the best-effort behaviour described above.
    let successors_of = move |idx: usize| {
        // Indices come from `0..num_nodes()` or from the graph's own adjacency
        // lists and are assumed to fit the `u32`-backed `NodeId`.
        #[allow(clippy::cast_possible_truncation)]
        let id = NodeId(idx as u32);
        graph.outgoing_neighbors(id).ok().into_iter().flatten()
    };

    state[node] = NodeState::InStack;
    let mut stack = vec![(node, successors_of(node))];
    while let Some((current, successors)) = stack.last_mut() {
        if let Some(&next) = successors.next() {
            let next = next as usize;
            match state[next] {
                // Back edge onto the current DFS path: cycle found.
                NodeState::InStack => return true,
                NodeState::Unvisited => {
                    state[next] = NodeState::InStack;
                    stack.push((next, successors_of(next)));
                }
                // Already fully explored; nothing new is reachable through it.
                NodeState::Finished => {}
            }
        } else {
            // Every successor of `current` has been explored.
            state[*current] = NodeState::Finished;
            stack.pop();
        }
    }
    false
}
/// Returns the nodes of `graph` in a topological order: for every edge `u -> v`,
/// `u` appears before `v`.
///
/// The order is the reverse of the DFS post-order obtained by searching from each
/// still-unvisited node in index order `0..num_nodes()`.
///
/// # Errors
///
/// Returns an error if the graph contains a directed cycle (no topological order
/// exists), or if looking up a node's outgoing neighbours fails.
pub fn toposort(graph: &CsrGraph) -> Result<Vec<NodeId>> {
    let node_count = graph.num_nodes();
    let mut state = vec![NodeState::Unvisited; node_count];
    let mut post_order = Vec::with_capacity(node_count);
    for root in 0..node_count {
        if state[root] != NodeState::Unvisited {
            continue;
        }
        toposort_dfs(graph, root, &mut state, &mut post_order)?;
    }
    // The reverse of a DFS post-order is a valid topological order.
    post_order.reverse();
    Ok(post_order)
}
/// Depth-first post-order traversal from `node`, used by [`toposort`].
///
/// Each node is appended to `result` once all of its outgoing neighbours are finished,
/// so `result` accumulates a reverse topological order that the caller reverses. Nodes
/// already `Finished` (by this call or an earlier one) are skipped.
///
/// The traversal keeps an explicit stack of `(node, remaining neighbours)` frames
/// instead of recursing, so a long chain of nodes cannot overflow the thread's call
/// stack. The visiting order, and therefore the produced order, is the same as that of
/// a recursive DFS.
///
/// # Errors
///
/// Returns an error if it reaches a neighbour that is still on the current DFS path
/// (the graph has a cycle), or if `outgoing_neighbors` fails for a visited node. On
/// error, `state` and `result` are left partially updated.
fn toposort_dfs(
    graph: &CsrGraph,
    node: usize,
    state: &mut [NodeState],
    result: &mut Vec<NodeId>,
) -> Result<()> {
    // Dense indices come from `0..num_nodes()` or from the graph's own adjacency
    // lists and are assumed to fit the `u32`-backed `NodeId`.
    #[allow(clippy::cast_possible_truncation)]
    let to_id = |idx: usize| NodeId(idx as u32);

    state[node] = NodeState::InStack;
    let mut stack = vec![(node, graph.outgoing_neighbors(to_id(node))?.into_iter())];
    while let Some((current, successors)) = stack.last_mut() {
        if let Some(&next) = successors.next() {
            let next = next as usize;
            match state[next] {
                // Back edge onto the current DFS path: no topological order exists.
                NodeState::InStack => {
                    return Err(anyhow!("Cycle detected: cannot compute topological order"));
                }
                NodeState::Unvisited => {
                    state[next] = NodeState::InStack;
                    let children = graph.outgoing_neighbors(to_id(next))?.into_iter();
                    stack.push((next, children));
                }
                NodeState::Finished => {}
            }
        } else {
            // All successors of `current` are finished: emit it in post-order.
            let done = *current;
            state[done] = NodeState::Finished;
            result.push(to_id(done));
            stack.pop();
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_graph_not_cyclic() {
        let graph = CsrGraph::new();
        assert!(!is_cyclic(&graph));
    }

    #[test]
    fn test_empty_graph_toposort() {
        let graph = CsrGraph::new();
        let order = toposort(&graph).unwrap();
        assert!(order.is_empty());
    }

    /// A single edge `0 -> 1` (two nodes) contains no cycle.
    #[test]
    fn test_single_edge_not_cyclic() {
        let mut graph = CsrGraph::new();
        graph.add_edge(NodeId(0), NodeId(1), 1.0).unwrap();
        assert!(!is_cyclic(&graph));
    }

    #[test]
    fn test_self_loop_is_cyclic() {
        let edges = vec![(NodeId(0), NodeId(0), 1.0)];
        let graph = CsrGraph::from_edge_list(&edges).unwrap();
        assert!(is_cyclic(&graph));
    }

    #[test]
    fn test_simple_dag_not_cyclic() {
        let edges = vec![(NodeId(0), NodeId(1), 1.0), (NodeId(1), NodeId(2), 1.0)];
        let graph = CsrGraph::from_edge_list(&edges).unwrap();
        assert!(!is_cyclic(&graph));
    }

    #[test]
    fn test_simple_cycle() {
        let edges = vec![
            (NodeId(0), NodeId(1), 1.0),
            (NodeId(1), NodeId(2), 1.0),
            (NodeId(2), NodeId(0), 1.0),
        ];
        let graph = CsrGraph::from_edge_list(&edges).unwrap();
        assert!(is_cyclic(&graph));
    }

    #[test]
    fn test_diamond_dag_not_cyclic() {
        let edges = vec![
            (NodeId(0), NodeId(1), 1.0),
            (NodeId(0), NodeId(2), 1.0),
            (NodeId(1), NodeId(3), 1.0),
            (NodeId(2), NodeId(3), 1.0),
        ];
        let graph = CsrGraph::from_edge_list(&edges).unwrap();
        assert!(!is_cyclic(&graph));
    }

    #[test]
    fn test_toposort_simple_chain() {
        let edges = vec![(NodeId(0), NodeId(1), 1.0), (NodeId(1), NodeId(2), 1.0)];
        let graph = CsrGraph::from_edge_list(&edges).unwrap();
        let order = toposort(&graph).unwrap();
        assert_eq!(order, vec![NodeId(0), NodeId(1), NodeId(2)]);
    }

    #[test]
    fn test_toposort_diamond() {
        let edges = vec![
            (NodeId(0), NodeId(1), 1.0),
            (NodeId(0), NodeId(2), 1.0),
            (NodeId(1), NodeId(3), 1.0),
            (NodeId(2), NodeId(3), 1.0),
        ];
        let graph = CsrGraph::from_edge_list(&edges).unwrap();
        let order = toposort(&graph).unwrap();
        let pos = |n: u32| order.iter().position(|&x| x == NodeId(n)).unwrap();
        assert!(pos(0) < pos(1), "0 must come before 1");
        assert!(pos(0) < pos(2), "0 must come before 2");
        assert!(pos(1) < pos(3), "1 must come before 3");
        assert!(pos(2) < pos(3), "2 must come before 3");
    }

    #[test]
    fn test_toposort_fails_on_cycle() {
        let edges = vec![
            (NodeId(0), NodeId(1), 1.0),
            (NodeId(1), NodeId(2), 1.0),
            (NodeId(2), NodeId(0), 1.0),
        ];
        let graph = CsrGraph::from_edge_list(&edges).unwrap();
        let result = toposort(&graph);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Cycle"));
    }

    #[test]
    fn test_toposort_fails_on_self_loop() {
        let edges = vec![(NodeId(0), NodeId(0), 1.0)];
        let graph = CsrGraph::from_edge_list(&edges).unwrap();
        let result = toposort(&graph);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Cycle"));
    }

    #[test]
    fn test_disconnected_components() {
        let edges = vec![(NodeId(0), NodeId(1), 1.0), (NodeId(2), NodeId(3), 1.0)];
        let graph = CsrGraph::from_edge_list(&edges).unwrap();
        assert!(!is_cyclic(&graph));
        let order = toposort(&graph).unwrap();
        assert_eq!(order.len(), 4);
        let pos = |n: u32| order.iter().position(|&x| x == NodeId(n)).unwrap();
        assert!(pos(0) < pos(1));
        assert!(pos(2) < pos(3));
    }

    #[test]
    fn test_complex_dag() {
        let edges = vec![
            (NodeId(0), NodeId(1), 1.0),
            (NodeId(0), NodeId(2), 1.0),
            (NodeId(0), NodeId(3), 1.0),
            (NodeId(1), NodeId(4), 1.0),
            (NodeId(2), NodeId(4), 1.0),
            (NodeId(3), NodeId(4), 1.0),
        ];
        let graph = CsrGraph::from_edge_list(&edges).unwrap();
        assert!(!is_cyclic(&graph));
        let order = toposort(&graph).unwrap();
        // Node 0 is the only source and node 4 the only sink.
        assert_eq!(order[0], NodeId(0));
        assert_eq!(order[4], NodeId(4));
    }

    #[test]
    fn test_two_node_cycle() {
        let edges = vec![(NodeId(0), NodeId(1), 1.0), (NodeId(1), NodeId(0), 1.0)];
        let graph = CsrGraph::from_edge_list(&edges).unwrap();
        assert!(is_cyclic(&graph));
    }

    #[test]
    fn test_cycle_in_subgraph() {
        let edges = vec![
            (NodeId(0), NodeId(1), 1.0),
            (NodeId(1), NodeId(2), 1.0),
            (NodeId(2), NodeId(1), 1.0),
            (NodeId(3), NodeId(4), 1.0),
        ];
        let graph = CsrGraph::from_edge_list(&edges).unwrap();
        assert!(is_cyclic(&graph));
    }

    #[test]
    fn test_toposort_fails_on_cycle_in_subgraph() {
        let edges = vec![
            (NodeId(0), NodeId(1), 1.0),
            (NodeId(1), NodeId(2), 1.0),
            (NodeId(2), NodeId(1), 1.0),
            (NodeId(3), NodeId(4), 1.0),
        ];
        let graph = CsrGraph::from_edge_list(&edges).unwrap();
        assert!(toposort(&graph).is_err());
    }
}