dsa/algorithms/
graph_traversal.rs

1use std::cmp::Reverse;
2use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
3use std::hash::Hash;
4use std::rc::Rc;
5use crate::data_structures::graph::Graph;
6use crate::data_structures::tree::TreeNode;
7
8/// Performs a breadth-first search (BFS) on the given graph starting from the `start` node.
9///
10/// This function explores the graph level by level, visiting each node starting from the `start` node and
11/// moving outward to its neighbors. It returns the nodes in the order they were visited.
12///
13/// # Examples
14/// ```
15/// use std::rc::Rc;
16/// use dsa::data_structures::graph::Graph;
17/// use dsa::data_structures::tree::TreeNode;
18/// use dsa::algorithms::graph_traversal::breadth_first_search;
19///
20/// let mut graph = Graph::new();
21/// let node1 = Rc::new(TreeNode::new(1));
22/// let node2 = Rc::new(TreeNode::new(2));
23/// let node3 = Rc::new(TreeNode::new(3));
24///
25/// graph.add_edge(Rc::clone(&node1), Rc::clone(&node2), Some(1));
26/// graph.add_edge(Rc::clone(&node1), Rc::clone(&node3), Some(1));
27///
28/// let result = breadth_first_search(&graph, Rc::clone(&node1));
29/// assert_eq!(result, vec![1, 2, 3]);
30/// ```
31///
32/// # Parameters
33/// - `graph`: A reference to the graph in which the search will take place.
34/// - `start`: The node from which the search begins.
35///
36/// # Returns
37/// A vector containing the nodes in the order they were visited.
38pub fn breadth_first_search<T: Eq + Hash + Clone>(
39    graph: &Graph<T>, 
40    start: Rc<TreeNode<T>>
41) -> Vec<T> {
42    let mut visited = HashSet::new();
43    let mut queue = VecDeque::new();
44    let mut result = Vec::new();
45
46    queue.push_back(Rc::clone(&start));
47    visited.insert(Rc::clone(&start));
48
49    while let Some(node) = queue.pop_front() {
50        result.push(node.value.clone());
51        if let Some(neighbors) = graph.graph.get(&node) {
52            for (neighbor, _) in neighbors {
53                if !visited.contains(neighbor) {
54                    queue.push_back(Rc::clone(neighbor));
55                    visited.insert(Rc::clone(neighbor));
56                }
57            }
58        }
59    }
60
61    result
62}
63
64/// Performs a depth-first search (DFS) on the given graph starting from the `start` node.
65///
66/// This function explores the graph by going as deep as possible along each branch before backtracking.
67/// It returns the nodes in the order they were visited.
68///
69/// # Examples
70///
71/// ```
72/// use std::rc::Rc;
73/// use dsa::algorithms::graph_traversal::depth_first_search;
74/// use dsa::data_structures::graph::Graph;
75/// use dsa::data_structures::tree::TreeNode;
76/// 
77/// let mut graph = Graph::new();
78/// let node1 = Rc::new(TreeNode::new(1));
79/// let node2 = Rc::new(TreeNode::new(2));
80/// let node3 = Rc::new(TreeNode::new(3));
81///
82/// graph.add_edge(Rc::clone(&node1), Rc::clone(&node2), Some(1));
83/// graph.add_edge(Rc::clone(&node1), Rc::clone(&node3), Some(1));
84///
85/// let result = depth_first_search(&graph, Rc::clone(&node1));
86/// assert_eq!(result, vec![1, 2, 3]);
87/// ```
88///
89/// # Parameters
90/// - `graph`: A reference to the graph in which the search will take place.
91/// - `start`: The node from which the search begins.
92///
93/// # Returns
94/// A vector containing the values of the nodes in the order they were visited.
95pub fn depth_first_search<T: Eq + Hash + Clone>(
96    graph: &Graph<T>, 
97    start: Rc<TreeNode<T>>
98) -> Vec<T> {
99    let mut visited = HashSet::new();
100    let mut result = Vec::new();
101    depth_first_search_helper(graph, start, &mut visited, &mut result);
102    result
103}
104
105/// Helper function for depth-first search (DFS) that recursively visits nodes in the graph.
106///
107/// This function is used internally in `depth_first_search` to recursively traverse the graph.
108///
109/// # Parameters
110/// - `graph`: A reference to the graph in which the search will take place.
111/// - `node`: The current node being visited.
112/// - `visited`: A mutable set tracking the visited nodes.
113/// - `result`: A mutable vector to store the nodes in the order they were visited.
114fn depth_first_search_helper<'a, T: Eq + Hash + Clone>(
115    graph: &'a Graph<T>, 
116    node: Rc<TreeNode<T>>,
117    visited: &mut HashSet<Rc<TreeNode<T>>>, 
118    result: &mut Vec<T>,
119) {
120    if visited.contains(&node) {
121        return;
122    }
123    visited.insert(Rc::clone(&node));
124    result.push(node.value.clone());
125    if let Some(neighbors) = graph.graph.get(&node) {
126        for(neighbor, _) in neighbors {
127            depth_first_search_helper(graph, Rc::clone(neighbor), visited, result);
128        }
129    }
130}
131
132/// Performs Dijkstra's algorithm to find the shortest path from the `start` node to all other nodes in the graph.
133///
134/// This function computes the shortest distance from the start node to every other node in the graph and returns
135/// a map of nodes to their corresponding shortest distances.
136///
137/// # Examples
138///
139/// ```
140/// use std::collections::HashMap;
141/// use std::rc::Rc;
142/// use dsa::algorithms::graph_traversal::dijkstra;
143/// use dsa::data_structures::graph::Graph;
144/// use dsa::data_structures::tree::TreeNode;
145///
146/// let mut graph = Graph::new();
147/// let node1 = Rc::new(TreeNode::new(1));
148/// let node2 = Rc::new(TreeNode::new(2));
149/// let node3 = Rc::new(TreeNode::new(3));
150///
151/// graph.add_edge(Rc::clone(&node1), Rc::clone(&node2), Some(5));
152/// graph.add_edge(Rc::clone(&node2), Rc::clone(&node3), Some(10));
153///
154/// let result = dijkstra(&graph, Rc::clone(&node1));
155///
156/// let expected = {
157///     let mut map = HashMap::new();
158///     map.insert(Rc::clone(&node1), 0);
159///     map.insert(Rc::clone(&node2), 5);
160///     map.insert(Rc::clone(&node3), 15);
161///     map
162/// };
163///
164/// assert_eq!(result, expected);
165/// ```
166///
167/// # Parameters
168/// - `graph`: A reference to the graph on which Dijkstra's algorithm is to be run.
169/// - `start`: The node from which the shortest paths are calculated.
170///
171/// # Returns
172/// A map of nodes to their shortest distances from the `start` node.
173pub fn dijkstra<T: Eq + Hash + Clone + Ord>(
174    graph: &Graph<T>,
175    start: Rc<TreeNode<T>>
176) -> HashMap<Rc<TreeNode<T>>, u32> {
177    let mut distances = HashMap::new();
178    let mut priority_queue = BinaryHeap::new();
179    
180    distances.insert(Rc::clone(&start), 0);
181    priority_queue.push(Reverse((0, Rc::clone(&start))));
182    
183    while let Some(Reverse((current_distance, current_node))) = priority_queue.pop() {
184        if current_distance < *distances.get(&current_node).unwrap_or(&u32::MAX) {
185            continue;
186        }
187        
188        if let Some(neighbors) = graph.graph.get(&*current_node) {
189            for (neighbor, weight) in neighbors {
190                if let Some(weight) = weight {
191                    let new_distance = current_distance + *weight as u32;
192                    if new_distance < *distances.get(neighbor).unwrap_or(&u32::MAX) {
193                        distances.insert(neighbor.clone(), new_distance);
194                        priority_queue.push(Reverse((new_distance, neighbor.clone())));
195                    }
196                }
197            }
198        }
199    }
200
201    distances
202}
203
204/// Performs the Bellman-Ford algorithm to find the shortest paths from the `start` node to all other nodes in the graph,
205/// even when there are negative edge weights.
206///
207/// This function computes the shortest distance from the start node to every other node and checks for negative weight cycles.
208///
209/// # Examples
210///
211/// ```
212/// use std::collections::HashMap;
213/// use std::rc::Rc;
214/// use dsa::data_structures::graph::Graph;
215/// use dsa::algorithms::graph_traversal::bellman_ford;
216/// use dsa::data_structures::tree::TreeNode;
217///
218/// let mut graph = Graph::new();
219/// let node1 = Rc::new(TreeNode::new(1));
220/// let node2 = Rc::new(TreeNode::new(2));
221/// let node3 = Rc::new(TreeNode::new(3));
222///
223/// graph.add_edge(Rc::clone(&node1), Rc::clone(&node2), Some(5));
224/// graph.add_edge(Rc::clone(&node2), Rc::clone(&node3), Some(10));
225///
226/// let result = bellman_ford(&graph, Rc::clone(&node1));
227///
228/// let expected = {
229///     let mut map = HashMap::new();
230///     map.insert(Rc::clone(&node1), 0);
231///     map.insert(Rc::clone(&node2), 5);
232///     map.insert(Rc::clone(&node3), 15);
233///     Ok(map)
234/// };
235///
236/// assert_eq!(result, expected);
237/// ```
238///
239/// # Parameters
240/// - `graph`: A reference to the graph on which Bellman-Ford algorithm is to be run.
241/// - `start`: The node from which the shortest paths are calculated.
242///
243/// # Returns
244/// A `Result` containing a map of nodes to their shortest distances, or an error if a negative weight cycle is detected.
245pub fn bellman_ford<T: Eq + Hash + Clone + Ord>(
246    graph: &Graph<T>,
247    start: Rc<TreeNode<T>>
248) -> Result<HashMap<Rc<TreeNode<T>>, i32>, String> {
249    let mut distances = HashMap::new();
250    distances.insert(Rc::clone(&start), 0);
251    let num_vertices: i32 = graph.graph.len() as i32;
252
253    for _ in 0..num_vertices - 1 {
254        let mut updates = Vec::new();
255        for (node, neighbors) in &graph.graph {
256            if let Some(current_distance) = distances.get(node) {
257                for (neighbor, weight) in neighbors {
258                    let new_distance = current_distance + weight.unwrap();
259                    updates.push((neighbor, new_distance));
260                }
261            }
262        }
263        for (neighbor, new_distance) in updates {
264            let existing_distance = distances.entry(Rc::clone(neighbor)).or_insert(i32::MAX);
265            if new_distance < *existing_distance {
266                *existing_distance = new_distance;
267            }
268        }
269    }
270
271    for (node, neighbors) in &graph.graph {
272        if let Some(&current_distance) = distances.get(node) {
273            for (neighbor, weight) in neighbors {
274                let new_distance = current_distance + weight.unwrap();
275                if let Some(existing_distance) = distances.get(neighbor) {
276                    if new_distance < *existing_distance {
277                        return Err("Graph contains a negative weight cycle! \
278                            Bellman-Ford will not be accurate for this graph".to_string());
279                    }
280                }
281            }
282        }
283    }
284
285    Ok(distances)
286}