#[cfg(feature = "hashbrown")]
use hashbrown::{HashMap, HashSet};
#[cfg(not(feature = "hashbrown"))]
use std::collections::{HashMap, HashSet};

use crate::{Edge, Graph, GraphInterface, NodeID};
8
9#[derive(Clone)]
11pub struct DepthFirstSearch<'a, G: GraphInterface> {
12 graph: &'a G,
13 start: NodeID,
14 visited: HashSet<NodeID>,
15 stack: Vec<NodeID>,
16 cyclic: bool,
17 visited_edges: Vec<(NodeID, NodeID)>,
18}
19
20impl<'a, G: GraphInterface> DepthFirstSearch<'a, G> {
21 pub fn new(graph: &'a G, start: NodeID) -> Self {
22 Self {
23 graph,
24 start,
25 visited: HashSet::new(),
26 stack: vec![start],
27 cyclic: false,
28 visited_edges: Vec::new(),
29 }
30 }
31}
32
33impl<'a, G: GraphInterface> Iterator for DepthFirstSearch<'a, G> {
34 type Item = NodeID;
35
36 fn next(&mut self) -> Option<Self::Item> {
37 if let Some(node) = self.stack.pop() {
38 if self.visited.contains(&node) {
39 self.cyclic = true;
40 return self.next();
41 }
42 self.visited.insert(node);
43
44 let node = self.graph.node(node).unwrap();
45 for edge in &node.connections {
46 let edge = self.graph.edge(*edge).unwrap();
47 if (edge.to != self.start) && !self.visited.contains(&edge.to) {
48 self.stack.push(edge.to);
49 self.visited_edges.push((edge.from, edge.to));
50 }
51 }
55
56 return Some(node.id);
57 }
58 None
59 }
60}
61
62impl<'a, G: GraphInterface> std::iter::FusedIterator for DepthFirstSearch<'a, G> {}
63
64pub trait IterDepthFirst<'a, G: GraphInterface> {
66 fn iter_depth_first(&'a self, start: NodeID) -> DepthFirstSearch<'a, G>;
68
69 fn connected_components(&'a self) -> Vec<HashSet<NodeID>>;
72}
73
74impl<'a, G: GraphInterface> IterDepthFirst<'a, G> for G {
75 fn iter_depth_first(&'a self, start: NodeID) -> DepthFirstSearch<'a, G> {
76 DepthFirstSearch::new(self, start)
77 }
78
79 fn connected_components(&'a self) -> Vec<HashSet<NodeID>> {
82 let mut visited = HashSet::new();
83 let mut components = Vec::new();
84 let mut current_component = 0usize;
85
86 for node_id in self.nodes() {
88 if visited.contains(&node_id) {
90 continue;
91 }
92 for node in self.iter_depth_first(node_id) {
93 visited.insert(node);
94
95 if current_component >= components.len() {
97 components.push(HashSet::new());
98 }
99 components[current_component].insert(node);
100 }
101 current_component += 1;
102 }
103
104 components
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::GraphInterface;

    #[derive(Clone, Debug, PartialEq)]
    enum NodeData {
        Int64(i64),
    }

    /// Adds `$n` nodes holding `Int64(0)..Int64($n - 1)` to `$graph` and
    /// returns their ids as a fixed-size array for destructuring.
    macro_rules! get_graph {
        ($graph:ident, $n:expr) => {{
            let payloads: Vec<NodeData> = (0..$n).map(NodeData::Int64).collect();
            let ids = $graph.add_nodes(&payloads);
            assert!(ids.len() == $n, "Failed to add nodes");
            ids[..].try_into().unwrap()
        }};
    }

    #[test]
    fn test_dfs_connected_components() {
        let mut graph: Graph<NodeData, ()> = Graph::new();
        let [node0, node1, node2, node3, node4] = get_graph!(graph, 5);

        // No edges: every node stands alone.
        let components = graph.connected_components();
        println!(
            "Connected components 1 ({}): {:#?}",
            components.len(),
            components
        );
        assert_eq!(components.len(), 5);
        assert_eq!(components[0].len(), 1);

        // node0 <-> node1 merge into one component.
        graph.add_edges(&[(node0, node1), (node1, node0)]);
        let components = graph.connected_components();
        println!(
            "Connected components 2 ({}): {:#?}",
            components.len(),
            components
        );
        assert_eq!(components.len(), 4);
        assert_eq!(components[0].len(), 2);

        // node2 -> node3 -> node4 chain into a second component.
        graph.add_edges(&[(node2, node3), (node3, node4)]);
        let components = graph.connected_components();
        println!(
            "Connected components 3 ({}): {:#?}",
            components.len(),
            components
        );
        assert_eq!(components.len(), 2);
        assert_eq!(components[1].len(), 3);
    }

    #[test]
    fn test_dfs_iter() {
        let mut graph1: Graph<NodeData, ()> = Graph::new();
        let [node0, node1, node2, node3, node4] = get_graph!(graph1, 5);
        graph1.add_edges(&[
            (node0, node1),
            (node0, node3),
            (node0, node2),
            (node1, node0),
            (node2, node3),
            (node2, node0),
            (node2, node4),
            (node4, node2),
        ]);

        let mut graph2: Graph<NodeData, ()> = Graph::new();
        let [node02, node12, node22, node32, node42] = get_graph!(graph2, 5);
        graph2.add_edges(&[
            (node02, node32),
            (node02, node22),
            (node12, node02),
            (node22, node32),
            (node42, node22),
        ]);

        // A full walk of graph1 from node0 reaches every node.
        println!(
            "Depth First Search 1 (node count: {}):",
            graph1.node_count()
        );
        println!("Edges: {:#?}", graph1.edges.len());
        let full_walk: Vec<_> = graph1
            .iter_depth_first(node0)
            .map(|id| graph1.node(id).unwrap())
            .collect();
        assert_eq!(full_walk.len(), graph1.node_count());
        assert_eq!(full_walk.len(), 5);

        // Stop as soon as the node holding 4 has been recorded.
        println!(
            "Depth First Search 2 (node count: {}):",
            graph1.node_count()
        );
        println!("Edges: {:#?}", graph1.edges.len());
        let mut partial_walk = Vec::new();
        for id in graph1.iter_depth_first(node0) {
            let node = graph1.node(id).unwrap();
            let reached_four = node.data == NodeData::Int64(4);
            partial_walk.push(node);
            if reached_four {
                break;
            }
        }
        assert_ne!(partial_walk.len(), graph1.node_count());
        assert_eq!(partial_walk.len(), 3);

        // Same early-stopping walk on graph2 from node02.
        println!(
            "Depth First Search 3 (node count: {}):",
            graph2.node_count()
        );
        println!("Edges: {:#?}", graph2.edges.len());
        let mut second_walk = Vec::new();
        for id in graph2.iter_depth_first(node02) {
            let node = graph2.node(id).unwrap();
            let reached_four = node.data == NodeData::Int64(4);
            second_walk.push(node);
            if reached_four {
                break;
            }
        }

        assert_eq!(partial_walk.len(), second_walk.len());
    }
}