fast_graph/algorithms/
dfs.rs

//! Depth-first traversal and connected-component helpers for `GraphInterface` graphs.
//!
//! # Under development
2#[cfg(feature = "hashbrown")]
3use hashbrown::HashSet;
4#[cfg(not(feature = "hashbrown"))]
5use std::collections::HashSet;
6
7use crate::{Edge, Graph, GraphInterface, NodeID};
8
9/// Under development
10#[derive(Clone)]
11pub struct DepthFirstSearch<'a, G: GraphInterface> {
12    graph: &'a G,
13    start: NodeID,
14    visited: HashSet<NodeID>,
15    stack: Vec<NodeID>,
16    cyclic: bool,
17    visited_edges: Vec<(NodeID, NodeID)>,
18}
19
20impl<'a, G: GraphInterface> DepthFirstSearch<'a, G> {
21    pub fn new(graph: &'a G, start: NodeID) -> Self {
22        Self {
23            graph,
24            start,
25            visited: HashSet::new(),
26            stack: vec![start],
27            cyclic: false,
28            visited_edges: Vec::new(),
29        }
30    }
31}
32
33impl<'a, G: GraphInterface> Iterator for DepthFirstSearch<'a, G> {
34    type Item = NodeID;
35
36    fn next(&mut self) -> Option<Self::Item> {
37        if let Some(node) = self.stack.pop() {
38            if self.visited.contains(&node) {
39                self.cyclic = true;
40                return self.next();
41            }
42            self.visited.insert(node);
43
44            let node = self.graph.node(node).unwrap();
45            for edge in &node.connections {
46                let edge = self.graph.edge(*edge).unwrap();
47                if (edge.to != self.start) && !self.visited.contains(&edge.to) {
48                    self.stack.push(edge.to);
49                    self.visited_edges.push((edge.from, edge.to));
50                }
51                // else if (edge.from != self.start) && !self.visited.contains(&edge.from){
52                //     self.stack.push(edge.from)
53                // }
54            }
55
56            return Some(node.id);
57        }
58        None
59    }
60}
61
62impl<'a, G: GraphInterface> std::iter::FusedIterator for DepthFirstSearch<'a, G> {}
63
64/// Under development
65pub trait IterDepthFirst<'a, G: GraphInterface> {
66    /// Returns a *depth first search* iterator starting from a given node
67    fn iter_depth_first(&'a self, start: NodeID) -> DepthFirstSearch<'a, G>;
68
69    /// Returns a vector of sets of node IDs, where each set is a connected component. \
70    /// Starts a DFS at every node (except if it's already been visited) and marks all reachable nodes as being part of the same component.
71    fn connected_components(&'a self) -> Vec<HashSet<NodeID>>;
72}
73
74impl<'a, G: GraphInterface> IterDepthFirst<'a, G> for G {
75    fn iter_depth_first(&'a self, start: NodeID) -> DepthFirstSearch<'a, G> {
76        DepthFirstSearch::new(self, start)
77    }
78
79    /// Returns a vector of sets of node IDs, where each set is a connected component. \
80    /// Starts a DFS at every node (except if it's already been visited) and marks all reachable nodes as being part of the same component.
81    fn connected_components(&'a self) -> Vec<HashSet<NodeID>> {
82        let mut visited = HashSet::new();
83        let mut components = Vec::new();
84        let mut current_component = 0usize;
85
86        // Starts a DFS at every node
87        for node_id in self.nodes() {
88            // (except if it's already been visited)
89            if visited.contains(&node_id) {
90                continue;
91            }
92            for node in self.iter_depth_first(node_id) {
93                visited.insert(node);
94
95                // and marks all reachable nodes as being part of the same component.
96                if current_component >= components.len() {
97                    components.push(HashSet::new());
98                }
99                components[current_component].insert(node);
100            }
101            current_component += 1;
102        }
103
104        components
105    }
106}
107
/// Tests for depth-first traversal and connected components.
///
/// Several assertions depend on a specific node-iteration and traversal order; see the
/// comments on each assertion.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::GraphInterface;

    /// Minimal node payload used by the tests.
    #[derive(Clone, Debug)]
    enum NodeData {
        Int64(i64),
    }
    impl PartialEq for NodeData {
        fn eq(&self, other: &Self) -> bool {
            match (self, other) {
                (NodeData::Int64(a), NodeData::Int64(b)) => a == b,
            }
        }
    }

    /// Adds `$n` nodes with payloads `Int64(0)` .. `Int64($n - 1)` to `$graph` and returns
    /// their IDs as a fixed-size array. The array length comes from the destructuring
    /// pattern at the call site. Panics if `add_nodes` returns a different number of IDs.
    macro_rules! get_graph {
        ($graph:ident, $n:expr) => {{
            let mut nodes = Vec::new();
            for i in 0..$n {
                nodes.push(NodeData::Int64(i));
            }
            let nodes = $graph.add_nodes(&nodes);
            if nodes.len() != $n {
                panic!("Failed to add nodes");
            }
            nodes[..].try_into().unwrap()
        }};
    }

    #[test]
    fn test_dfs_connected_components() {
        let mut graph: Graph<NodeData, ()> = Graph::new();
        let [node0, node1, node2, node3, node4] = get_graph!(graph, 5);

        // No edges yet: every node is its own component.
        let mut components = graph.connected_components();
        println!(
            "Connected components 1 ({}): {:#?}",
            components.len(),
            components
        );
        assert_eq!(components.len(), 5);
        assert_eq!(components[0].len(), 1);

        // Linking node0 <-> node1 joins them into one component.
        graph.add_edges(&[(node0, node1), (node1, node0)]);

        components = graph.connected_components();
        println!(
            "Connected components 2 ({}): {:#?}",
            components.len(),
            components
        );
        assert_eq!(components.len(), 4);
        // Assumes `nodes()` yields node0 first, so its component is at index 0 — TODO confirm
        // the iteration order is guaranteed.
        assert_eq!(components[0].len(), 2);

        // node2 -> node3 -> node4: a DFS from node2 reaches all three.
        graph.add_edges(&[(node2, node3), (node3, node4)]);

        components = graph.connected_components();
        println!(
            "Connected components 3 ({}): {:#?}",
            components.len(),
            components
        );

        // {node0, node1} and {node2, node3, node4}; index 1 assumes node2 is yielded by
        // `nodes()` after node0/node1 (insertion order) — TODO confirm.
        assert_eq!(components.len(), 2);
        assert_eq!(components[1].len(), 3);
    }

    #[test]
    fn test_dfs_iter() {
        // graph1: every node is reachable from node0 by following edges forward.
        let mut graph1: Graph<NodeData, ()> = Graph::new();
        let [node0, node1, node2, node3, node4] = get_graph!(graph1, 5);

        graph1.add_edges(&[
            (node0, node1),
            (node0, node3),
            (node0, node2),
            (node1, node0),
            (node2, node3),
            (node2, node0),
            (node2, node4),
            (node4, node2),
        ]);

        // graph2: only node02, node22 and node32 are reachable forward from node02; node12
        // and node42 only have edges pointing into that set.
        let mut graph2: Graph<NodeData, ()> = Graph::new();
        let [node02, node12, node22, node32, node42] = get_graph!(graph2, 5);

        graph2.add_edges(&[
            (node02, node32),
            (node02, node22),
            (node12, node02),
            (node22, node32),
            (node42, node22),
        ]);

        println!(
            "Depth First Search 1 (node count: {}):",
            graph1.node_count()
        );
        println!("Edges: {:#?}", graph1.edges.len());
        // A full traversal of graph1 from node0 yields all five nodes exactly once.
        let mut visited = Vec::new();
        let depth_first = graph1.iter_depth_first(node0);
        for node in depth_first {
            let node = graph1.node(node).unwrap();
            //println!("{:?}", node.data);
            visited.push(node);
        }

        assert_eq!(visited.len(), graph1.node_count());
        assert_eq!(visited.len(), 5);

        println!(
            "Depth First Search 2 (node count: {}):",
            graph1.node_count()
        );
        println!("Edges: {:#?}", graph1.edges.len());

        // Stop as soon as the node with payload 4 is yielded. Relies on traversal order:
        // assuming `connections` lists edges in insertion order, node0 pushes node1, node3,
        // node2; node2 is popped next and pushes node3 and node4; node4 is then yielded
        // third.
        let mut visited = Vec::new();
        for node in graph1.iter_depth_first(node0) {
            let node = graph1.node(node).unwrap();
            //println!("{:?}", node.data);
            visited.push(node);

            if node.data == NodeData::Int64(4) {
                break;
            }
        }

        assert_ne!(visited.len(), graph1.node_count());
        assert_eq!(visited.len(), 3);

        println!(
            "Depth First Search 3 (node count: {}):",
            graph2.node_count()
        );
        println!("Edges: {:#?}", graph2.edges.len());
        // node42 is unreachable from node02, so the `break` never fires and the traversal
        // yields exactly the three reachable nodes.
        let mut visited2 = Vec::new();
        for node in graph2.iter_depth_first(node02) {
            let node = graph2.node(node).unwrap();
            //println!("{:?}", node.data);
            visited2.push(node);

            if node.data == NodeData::Int64(4) {
                break;
            }
        }

        // Both counts are 3: the previous traversal stopped after 3 nodes, and graph2 has 3
        // reachable nodes.
        assert_eq!(visited.len(), visited2.len());
    }
}