#[cfg(feature = "hashbrown")]
use hashbrown::HashSet;
#[cfg(not(feature = "hashbrown"))]
use std::collections::HashSet;
use crate::{Edge, Graph, GraphInterface, NodeID};
#[derive(Clone)]
pub struct DepthFirstSearch<'a, G: GraphInterface> {
graph: &'a G,
start: NodeID,
visited: HashSet<NodeID>,
stack: Vec<NodeID>,
cyclic: bool,
visited_edges: Vec<(NodeID, NodeID)>,
}
impl<'a, G: GraphInterface> DepthFirstSearch<'a, G> {
pub fn new(graph: &'a G, start: NodeID) -> Self {
Self {
graph,
start,
visited: HashSet::new(),
stack: vec![start],
cyclic: false,
visited_edges: Vec::new(),
}
}
}
impl<'a, G: GraphInterface> Iterator for DepthFirstSearch<'a, G> {
type Item = NodeID;
fn next(&mut self) -> Option<Self::Item> {
if let Some(node) = self.stack.pop() {
if self.visited.contains(&node) {
self.cyclic = true;
return self.next();
}
self.visited.insert(node);
let node = self.graph.node(node).unwrap();
for edge in &node.connections {
let edge = self.graph.edge(*edge).unwrap();
if (edge.to != self.start) && !self.visited.contains(&edge.to) {
self.stack.push(edge.to);
self.visited_edges.push((edge.from, edge.to));
}
}
return Some(node.id);
}
None
}
}
impl<'a, G: GraphInterface> std::iter::FusedIterator for DepthFirstSearch<'a, G> {}
pub trait IterDepthFirst<'a, G: GraphInterface> {
fn iter_depth_first(&'a self, start: NodeID) -> DepthFirstSearch<'a, G>;
fn connected_components(&'a self) -> Vec<HashSet<NodeID>>;
}
impl<'a, G: GraphInterface> IterDepthFirst<'a, G> for G {
fn iter_depth_first(&'a self, start: NodeID) -> DepthFirstSearch<'a, G> {
DepthFirstSearch::new(self, start)
}
fn connected_components(&'a self) -> Vec<HashSet<NodeID>> {
let mut visited = HashSet::new();
let mut components = Vec::new();
let mut current_component = 0usize;
for node_id in self.nodes() {
if visited.contains(&node_id) {
continue;
}
for node in self.iter_depth_first(node_id) {
visited.insert(node);
if current_component >= components.len() {
components.push(HashSet::new());
}
components[current_component].insert(node);
}
current_component += 1;
}
components
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::GraphInterface;
#[derive(Clone, Debug)]
enum NodeData {
Int64(i64),
}
impl PartialEq for NodeData {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(NodeData::Int64(a), NodeData::Int64(b)) => a == b,
}
}
}
macro_rules! get_graph {
($graph:ident, $n:expr) => {{
let mut nodes = Vec::new();
for i in 0..$n {
nodes.push(NodeData::Int64(i));
}
let nodes = $graph.add_nodes(&nodes);
if nodes.len() != $n {
panic!("Failed to add nodes");
}
nodes[..].try_into().unwrap()
}};
}
#[test]
fn test_dfs_connected_components() {
let mut graph: Graph<NodeData, ()> = Graph::new();
let [node0, node1, node2, node3, node4] = get_graph!(graph, 5);
let mut components = graph.connected_components();
println!(
"Connected components 1 ({}): {:#?}",
components.len(),
components
);
assert_eq!(components.len(), 5);
assert_eq!(components[0].len(), 1);
graph.add_edges(&[(node0, node1), (node1, node0)]);
components = graph.connected_components();
println!(
"Connected components 2 ({}): {:#?}",
components.len(),
components
);
assert_eq!(components.len(), 4);
assert_eq!(components[0].len(), 2);
graph.add_edges(&[(node2, node3), (node3, node4)]);
components = graph.connected_components();
println!(
"Connected components 3 ({}): {:#?}",
components.len(),
components
);
assert_eq!(components.len(), 2);
assert_eq!(components[1].len(), 3);
}
#[test]
fn test_dfs_iter() {
let mut graph1: Graph<NodeData, ()> = Graph::new();
let [node0, node1, node2, node3, node4] = get_graph!(graph1, 5);
graph1.add_edges(&[
(node0, node1),
(node0, node3),
(node0, node2),
(node1, node0),
(node2, node3),
(node2, node0),
(node2, node4),
(node4, node2),
]);
let mut graph2: Graph<NodeData, ()> = Graph::new();
let [node02, node12, node22, node32, node42] = get_graph!(graph2, 5);
graph2.add_edges(&[
(node02, node32),
(node02, node22),
(node12, node02),
(node22, node32),
(node42, node22),
]);
println!(
"Depth First Search 1 (node count: {}):",
graph1.node_count()
);
println!("Edges: {:#?}", graph1.edges.len());
let mut visited = Vec::new();
let depth_first = graph1.iter_depth_first(node0);
for node in depth_first {
let node = graph1.node(node).unwrap();
visited.push(node);
}
assert_eq!(visited.len(), graph1.node_count());
assert_eq!(visited.len(), 5);
println!(
"Depth First Search 2 (node count: {}):",
graph1.node_count()
);
println!("Edges: {:#?}", graph1.edges.len());
let mut visited = Vec::new();
for node in graph1.iter_depth_first(node0) {
let node = graph1.node(node).unwrap();
visited.push(node);
if node.data == NodeData::Int64(4) {
break;
}
}
assert_ne!(visited.len(), graph1.node_count());
assert_eq!(visited.len(), 3);
println!(
"Depth First Search 3 (node count: {}):",
graph2.node_count()
);
println!("Edges: {:#?}", graph2.edges.len());
let mut visited2 = Vec::new();
for node in graph2.iter_depth_first(node02) {
let node = graph2.node(node).unwrap();
visited2.push(node);
if node.data == NodeData::Int64(4) {
break;
}
}
assert_eq!(visited.len(), visited2.len());
}
}