// graph_algorithms/dijkstra.rs

1use std::{
2    cmp::Ordering,
3    collections::{BinaryHeap, HashMap},
4};
5
6use crate::{GraphAlgorithm, GraphError};
7
/// Dijkstra's shortest-path algorithm over a weighted, directed adjacency list.
///
/// Computes the cheapest path cost from a chosen start node to every other node.
#[derive(Debug, Clone)]
pub struct DijkstraAlgorithm {
    /// Adjacency list: each key is a node id mapped to its outgoing edges, stored as
    /// `(neighbor, weight)` pairs.
    pub graph: HashMap<usize, Vec<(usize, usize)>>,
}
15
/// Entry in the priority queue: a node paired with the cost of the path that reached it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct State {
    /// Total cost of the path that reached `position`.
    cost: usize,

    /// Node this queue entry refers to.
    position: usize,
}
25
26impl Ord for State {
27    /// Compare two states.
28    ///
29    /// # Arguments
30    ///
31    /// - `other`: The other state to compare.
32    ///
33    /// # Returns
34    ///
35    /// Ordering of the two states.
36    fn cmp(&self, other: &Self) -> Ordering {
37        other
38            .cost
39            .cmp(&self.cost)
40            .then_with(|| self.position.cmp(&other.position))
41    }
42}
43
44impl PartialOrd for State {
45    /// Compare two states partially.
46    ///
47    /// # Arguments
48    ///
49    /// - `other`: The other state to compare.
50    ///
51    /// # Returns
52    ///
53    /// Ordering of the two states.
54    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
55        Some(self.cmp(other))
56    }
57}
58
59impl Default for DijkstraAlgorithm {
60    /// Create a new default instance of Dijkstra's Algorithm.
61    ///
62    /// # Returns
63    ///
64    /// New default instance of Dijkstra's Algorithm.
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl DijkstraAlgorithm {
71    /// Create a new instance of Dijkstra's Algorithm.
72    ///
73    /// # Returns
74    ///
75    /// New instance of Dijkstra's Algorithm.
76    pub fn new() -> Self {
77        DijkstraAlgorithm {
78            graph: HashMap::new(),
79        }
80    }
81
82    /// Set the node of the graph.
83    ///
84    /// # Arguments
85    ///
86    /// - `node`: Node of the graph.
87    /// - `edges`: Edges of the node.
88    pub fn set_node(&mut self, node: usize, edges: Vec<(usize, usize)>) {
89        self.graph.insert(node, edges);
90    }
91
92    /// Set the nodes of the graph.
93    ///
94    /// # Arguments
95    ///
96    /// - `nodes`: Vector of nodes and their edges.
97    pub fn set_nodes(&mut self, nodes: Vec<(usize, Vec<(usize, usize)>)>) {
98        for (node, edges) in nodes {
99            self.graph.insert(node, edges);
100        }
101    }
102}
103
104impl GraphAlgorithm for DijkstraAlgorithm {
105    /// Type of node.
106    type Node = usize;
107
108    /// Type of weight.
109    type Weight = Vec<usize>;
110
111    /// Run Dijkstra's Algorithm.
112    ///
113    /// # Arguments
114    ///
115    /// - `start`: Starting node.
116    ///
117    /// # Returns
118    ///
119    /// Vector of the shortest path from the starting node to all other nodes.
120    fn run(&self, start: Option<Self::Node>) -> Result<Self::Weight, GraphError> {
121        let start = start.ok_or(GraphError::MissingStartNode)?;
122
123        let mut priority_queue = BinaryHeap::new();
124        let mut distances = HashMap::new();
125        let mut result = vec![usize::MAX; self.graph.len()];
126
127        distances.insert(start, 0);
128        priority_queue.push(State {
129            cost: 0,
130            position: start,
131        });
132
133        while let Some(state) = priority_queue.pop() {
134            // Determine if the current shortest path is already known.
135            // If it is, skip the current node.
136            if distances
137                .get(&state.position)
138                .map(|&d| state.cost > d)
139                .unwrap_or(false)
140            {
141                continue;
142            }
143
144            if let Some(neighbors) = self.graph.get(&state.position) {
145                for &(neighbor, weight) in neighbors {
146                    let next = State {
147                        cost: state.cost + weight,
148                        position: neighbor,
149                    };
150
151                    // Determine if the new path is shorter than the current shortest path.
152                    // If it is, update the shortest path.
153                    if distances
154                        .get(&neighbor)
155                        .map(|&d| next.cost < d)
156                        .unwrap_or(true)
157                    {
158                        distances.insert(neighbor, next.cost);
159                        priority_queue.push(next);
160                    }
161                }
162            }
163        }
164
165        // Convert the distances to a vector of shortest paths.
166        for (node, dist) in distances.into_iter() {
167            if node < result.len() {
168                result[node] = dist;
169            }
170        }
171
172        Ok(result)
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Distance reported for nodes that cannot be reached from the start node.
    const UNREACHABLE: usize = usize::MAX;

    /// Builds an algorithm instance whose graph contains exactly `nodes`.
    fn build(nodes: Vec<(usize, Vec<(usize, usize)>)>) -> DijkstraAlgorithm {
        let mut algorithm = DijkstraAlgorithm::new();
        algorithm.set_nodes(nodes);
        algorithm
    }

    /// Runs the algorithm from node `0` and unwraps the result.
    fn from_zero(algorithm: &DijkstraAlgorithm) -> Vec<usize> {
        algorithm.run(Some(0)).unwrap()
    }

    #[test]
    fn test_new() {
        let fresh = DijkstraAlgorithm::new();
        let defaulted = DijkstraAlgorithm::default();

        assert!(fresh.graph.is_empty());
        assert!(defaulted.graph.is_empty());
    }

    #[test]
    fn test_missing_start_node() {
        let outcome = DijkstraAlgorithm::new().run(None);

        assert_eq!(outcome, Err(GraphError::MissingStartNode));
    }

    #[test]
    fn test_run() {
        let algorithm = build(vec![
            (0, vec![(1, 1), (2, 4), (3, 7)]),
            (1, vec![(2, 2), (3, 5), (4, 12)]),
            (2, vec![(3, 1), (4, 3)]),
            (3, vec![(4, 2), (5, 8)]),
            (4, vec![(5, 1), (6, 5)]),
            (5, vec![(6, 2), (7, 3)]),
            (6, vec![(7, 1), (8, 4)]),
            (7, vec![(8, 2), (9, 6)]),
            (8, vec![(9, 1)]),
            (9, vec![(10, 2), (11, 3)]),
            (10, vec![(11, 1), (12, 4)]),
            (11, vec![(12, 2), (13, 6)]),
            (12, vec![(13, 1), (14, 5)]),
            (13, vec![(14, 2), (15, 3)]),
            (14, vec![(15, 1), (16, 4)]),
            (15, vec![(16, 2), (17, 6)]),
            (16, vec![(17, 1), (18, 5)]),
            (17, vec![(18, 2), (19, 3)]),
            (18, vec![(19, 1)]),
            (19, vec![]),
        ]);

        // Expected costs equal the running sums along the chain 0 -> 1 -> ... -> 19,
        // whose consecutive edge weights alternate 1, 2.
        let expected = vec![
            0, 1, 3, 4, 6, 7, 9, 10, 12, 13, 15, 16, 18, 19, 21, 22, 24, 25, 27, 28,
        ];
        assert_eq!(from_zero(&algorithm), expected);
    }

    #[test]
    fn test_run_single_node_graph() {
        let algorithm = build(vec![(0, vec![])]);

        assert_eq!(from_zero(&algorithm), vec![0]);
    }

    #[test]
    fn test_run_two_node_graph() {
        let mut algorithm = DijkstraAlgorithm::new();
        algorithm.set_node(0, vec![(1, 1)]);
        algorithm.set_node(1, vec![]);

        assert_eq!(from_zero(&algorithm), vec![0, 1]);
    }

    #[test]
    fn test_run_disconnected_graph() {
        let algorithm = build(vec![(0, vec![]), (1, vec![])]);

        assert_eq!(from_zero(&algorithm), vec![0, UNREACHABLE]);
    }

    #[test]
    fn test_run_simple_path() {
        let algorithm = build(vec![(0, vec![(1, 1)]), (1, vec![(2, 1)]), (2, vec![])]);

        assert_eq!(from_zero(&algorithm), vec![0, 1, 2]);
    }

    #[test]
    fn test_run_multiple_paths() {
        // The two-hop route 0 -> 1 -> 2 (cost 3) beats the direct edge 0 -> 2 (cost 4).
        let algorithm = build(vec![
            (0, vec![(1, 1), (2, 4)]),
            (1, vec![(2, 2)]),
            (2, vec![]),
        ]);

        assert_eq!(from_zero(&algorithm), vec![0, 1, 3]);
    }

    #[test]
    fn test_run_graph_with_cycle() {
        let algorithm = build(vec![
            (0, vec![(1, 1)]),
            (1, vec![(2, 1)]),
            (2, vec![(0, 1)]),
        ]);

        assert_eq!(from_zero(&algorithm), vec![0, 1, 2]);
    }

    #[test]
    fn test_run_graph_with_multiple_shortest_paths() {
        let algorithm = build(vec![
            (0, vec![(1, 1), (2, 1)]),
            (1, vec![(2, 1)]),
            (2, vec![]),
        ]);

        assert_eq!(from_zero(&algorithm), vec![0, 1, 1]);
    }

    #[test]
    fn test_run_large_weights() {
        let algorithm = build(vec![
            (0, vec![(1, 1000)]),
            (1, vec![(2, 1000)]),
            (2, vec![]),
        ]);

        assert_eq!(from_zero(&algorithm), vec![0, 1000, 2000]);
    }

    #[test]
    fn test_run_no_edges() {
        let algorithm = build(vec![(0, vec![]), (1, vec![]), (2, vec![])]);

        assert_eq!(from_zero(&algorithm), vec![0, UNREACHABLE, UNREACHABLE]);
    }

    #[test]
    fn test_run_multiple_edges_to_same_node() {
        // Parallel edges 0 -> 1: only the cheaper one (weight 1) should count.
        let algorithm = build(vec![
            (0, vec![(1, 1), (1, 2)]),
            (1, vec![(2, 1)]),
            (2, vec![]),
        ]);

        assert_eq!(from_zero(&algorithm), vec![0, 1, 2]);
    }

    #[test]
    fn test_run_graph_with_isolated_node() {
        let algorithm = build(vec![
            (0, vec![(1, 1)]),
            (1, vec![(2, 1)]),
            (2, vec![]),
            (3, vec![]),
        ]);

        assert_eq!(from_zero(&algorithm), vec![0, 1, 2, UNREACHABLE]);
    }
}