use hashbrown::{HashMap, HashSet};
use petgraph::algo;
use petgraph::visit::{
EdgeCount, GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCount, Visitable,
};
use petgraph::Direction::Outgoing;
use std::hash::Hash;
pub fn find_cycle<G>(graph: G, source: Option<G::NodeId>) -> Vec<(G::NodeId, G::NodeId)>
where
G: GraphBase,
G: NodeCount,
G: EdgeCount,
for<'b> &'b G:
GraphBase<NodeId = G::NodeId> + IntoNodeIdentifiers + IntoNeighborsDirected + Visitable,
G::NodeId: Eq + Hash,
{
let mut cycle: Vec<(G::NodeId, G::NodeId)> = Vec::with_capacity(graph.edge_count());
let source_index = match source {
Some(source_value) => source_value,
None => match find_node_in_arbitrary_cycle(&graph) {
Some(node_in_cycle) => node_in_cycle,
None => {
return Vec::new();
}
},
};
let mut stack: Vec<G::NodeId> = vec![source_index];
let mut pred: HashMap<G::NodeId, G::NodeId> = HashMap::new();
let mut visiting = HashSet::new();
let mut visited = HashSet::new();
while !stack.is_empty() {
let mut z = *stack.last().unwrap();
visiting.insert(z);
let children = graph.neighbors_directed(z, Outgoing);
for child in children {
if visiting.contains(&child) {
cycle.push((z, child));
loop {
if z == child {
cycle.reverse();
break;
}
cycle.push((pred[&z], z));
z = pred[&z];
}
return cycle;
}
if !visited.contains(&child) {
stack.push(child);
pred.insert(child, z);
}
}
let top = *stack.last().unwrap();
if top == z {
stack.pop();
visiting.remove(&z);
visited.insert(z);
}
}
cycle
}
/// Pick some node of `graph` that lies on a directed cycle, if any exists.
///
/// The search runs in two stages:
///
/// 1. Compute the strongly connected components with `algo::kosaraju_scc`.
///    If any component has two or more members, return the first node of the
///    first such component.
/// 2. Otherwise, scan nodes in `node_identifiers` order and return the first
///    one that has an outgoing edge to itself.
///
/// `None` means the graph contains no cycle at all.
fn find_node_in_arbitrary_cycle<G>(graph: &G) -> Option<G::NodeId>
where
    G: GraphBase,
    G: NodeCount,
    G: EdgeCount,
    for<'b> &'b G:
        GraphBase<NodeId = G::NodeId> + IntoNodeIdentifiers + IntoNeighborsDirected + Visitable,
    G::NodeId: Eq + Hash,
{
    // Every member of a strongly connected component of size >= 2 is on a cycle.
    let scc_member = algo::kosaraju_scc(&graph)
        .into_iter()
        .find(|component| component.len() > 1)
        .map(|component| component[0]);
    // A single-node component only forms a cycle through a self-loop. The scan
    // runs lazily, only when no larger component was found.
    scc_member.or_else(|| {
        graph.node_identifiers().find(|&node| {
            graph
                .neighbors_directed(node, Outgoing)
                .any(|neighbor| neighbor == node)
        })
    })
}
#[cfg(test)]
mod tests {
    use super::find_cycle;
    use petgraph::prelude::*;

    // Assert that `$cycle` is a non-empty closed walk in `$g`. Every pair must
    // be an edge of the graph, and each edge must end where the next one
    // (cyclically) starts. The non-empty check keeps the macro from passing
    // vacuously when no cycle was found.
    macro_rules! assert_cycle {
        ($g: expr, $cycle: expr) => {{
            assert!(!$cycle.is_empty(), "expected a non-empty cycle");
            for i in 0..$cycle.len() {
                let (s, t) = $cycle[i];
                assert!($g.contains_edge(s, t));
                let (next_s, _) = $cycle[(i + 1) % $cycle.len()];
                assert_eq!(t, next_s);
            }
        }};
    }

    #[test]
    fn test_find_cycle_source() {
        let edge_list = vec![
            (0, 1),
            (3, 0),
            (0, 5),
            (8, 0),
            (1, 2),
            (1, 6),
            (2, 3),
            (3, 4),
            (4, 5),
            (6, 7),
            (7, 8),
            (8, 9),
        ];
        let graph = DiGraph::<i32, i32>::from_edges(edge_list);
        for i in [0, 1, 2, 3].iter() {
            let idx = NodeIndex::new(*i);
            let res = find_cycle(&graph, Some(idx));
            assert_cycle!(graph, res);
            assert_eq!(res[0].0, idx);
        }
        let res = find_cycle(&graph, Some(NodeIndex::new(5)));
        assert_eq!(res, []);
    }

    #[test]
    fn test_self_loop() {
        let edge_list = vec![
            (0, 1),
            (3, 0),
            (0, 5),
            (8, 0),
            (1, 2),
            (1, 6),
            (2, 3),
            (3, 4),
            (4, 5),
            (6, 7),
            (7, 8),
            (8, 9),
        ];
        let mut graph = DiGraph::<i32, i32>::from_edges(edge_list);
        graph.add_edge(NodeIndex::new(1), NodeIndex::new(1), 0);
        let res = find_cycle(&graph, Some(NodeIndex::new(0)));
        assert_eq!(res[0].0, NodeIndex::new(1));
        assert_cycle!(graph, res);
    }

    #[test]
    fn test_self_loop_no_source() {
        let edge_list = vec![(0, 1), (1, 2), (2, 3), (2, 2)];
        let graph = DiGraph::<i32, i32>::from_edges(edge_list);
        let res = find_cycle(&graph, None);
        assert_cycle!(graph, res);
    }

    #[test]
    fn test_cycle_no_source() {
        let edge_list = vec![(0, 1), (1, 2), (2, 3), (3, 4), (4, 2)];
        let graph = DiGraph::<i32, i32>::from_edges(edge_list);
        let res = find_cycle(&graph, None);
        assert_cycle!(graph, res);
    }

    #[test]
    fn test_no_cycle_no_source() {
        let edge_list = vec![(0, 1), (1, 2), (2, 3)];
        let graph = DiGraph::<i32, i32>::from_edges(edge_list);
        let res = find_cycle(&graph, None);
        assert_eq!(res, []);
    }

    #[test]
    fn test_no_cycle_with_repeated_pushes() {
        // Node 1 is a child of 0, 2, 3 and 4. It can therefore be pushed onto
        // the DFS stack several times before it is finished, which leaves stale
        // duplicates on the stack. Those must not produce a cycle.
        let edge_list = vec![
            (0, 2),
            (0, 1),
            (2, 3),
            (2, 1),
            (3, 4),
            (3, 1),
            (4, 1),
            (1, 5),
            (1, 6),
        ];
        let graph = DiGraph::<i32, i32>::from_edges(edge_list);
        for i in 0..7 {
            let res = find_cycle(&graph, Some(NodeIndex::new(i)));
            assert_eq!(res, []);
        }
        assert_eq!(find_cycle(&graph, None), []);
    }

    #[test]
    fn test_cycle_with_repeated_pushes() {
        // Same shape as above, but 5 -> 3 closes cycles through node 1.
        let edge_list = vec![
            (0, 2),
            (0, 1),
            (2, 3),
            (2, 1),
            (3, 4),
            (3, 1),
            (4, 1),
            (1, 5),
            (5, 3),
        ];
        let graph = DiGraph::<i32, i32>::from_edges(edge_list);
        let res = find_cycle(&graph, Some(NodeIndex::new(0)));
        assert_cycle!(graph, res);
        let res = find_cycle(&graph, None);
        assert_cycle!(graph, res);
    }
}