Skip to main content

graphrust_algos/
dijkstra.rs

//! Dijkstra's shortest-path algorithm.
//!
//! Provides single-source shortest-path computation: distances only, distances plus a
//! predecessor map, and reconstruction of a single start-to-end path.
4
5use graphrust_core::{EdgeWeight, Graph, NodeId};
6use std::collections::{BinaryHeap, HashMap};
7use std::cmp::Ordering;
8
9/// State for the priority queue in Dijkstra's algorithm
10#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
11struct State {
12    distance: OrderedFloat,
13    node: NodeId,
14}
15
/// `f64` wrapper with a total order, used as the `BinaryHeap` key.
///
/// The ordering is deliberately *reversed*: larger values compare as smaller.
/// `std::collections::BinaryHeap` is a max-heap, so the reversal makes it pop
/// the smallest distance first.
///
/// Comparison uses [`f64::total_cmp`], so every value, including NaN, has a
/// well-defined position and `Ord` is a genuine total order. The previous
/// `partial_cmp(..).unwrap_or(Equal)` made NaN "equal" to everything, which
/// breaks transitivity and can corrupt the heap. Equality is defined through
/// the same comparison so `==` and `cmp` never disagree, which keeps `Eq` sound.
#[derive(Copy, Clone, Debug)]
struct OrderedFloat(f64);

impl PartialEq for OrderedFloat {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for OrderedFloat {}

impl PartialOrd for OrderedFloat {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrderedFloat {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reverse ordering for min-heap: compare `other` against `self`.
        other.0.total_cmp(&self.0)
    }
}
34
35/// Computes shortest distances from start node to all reachable nodes.
36///
37/// # Arguments
38/// * `graph` - The graph with edge weights
39/// * `start_node` - The starting node
40///
41/// # Returns
42/// HashMap mapping each reachable node to its shortest distance from start_node
43pub fn dijkstra(graph: &Graph, start_node: NodeId) -> HashMap<NodeId, f64> {
44    let mut distances: HashMap<NodeId, f64> = HashMap::new();
45    let mut heap = BinaryHeap::new();
46
47    distances.insert(start_node, 0.0);
48    heap.push(State {
49        distance: OrderedFloat(0.0),
50        node: start_node,
51    });
52
53    while let Some(State { distance, node }) = heap.pop() {
54        let current_distance = distance.0;
55
56        // If we've already found a shorter path to this node, skip it
57        if let Some(&known_distance) = distances.get(&node) {
58            if current_distance > known_distance {
59                continue;
60            }
61        }
62
63        for &neighbor in graph.neighbors(node) {
64            let edge_weight = graph
65                .edge_weight(node, neighbor)
66                .unwrap_or(EdgeWeight::default());
67            let next_distance = current_distance + edge_weight.as_f64();
68
69            let is_new_or_shorter = distances
70                .get(&neighbor)
71                .map_or(true, |&d| next_distance < d);
72
73            if is_new_or_shorter {
74                distances.insert(neighbor, next_distance);
75                heap.push(State {
76                    distance: OrderedFloat(next_distance),
77                    node: neighbor,
78                });
79            }
80        }
81    }
82
83    distances
84}
85
86/// Computes the shortest path from start to end node.
87///
88/// # Arguments
89/// * `graph` - The graph with edge weights
90/// * `start` - The starting node
91/// * `end` - The destination node
92///
93/// # Returns
94/// Some(path) if a path exists, None otherwise
95pub fn dijkstra_path(graph: &Graph, start: NodeId, end: NodeId) -> Option<Vec<NodeId>> {
96    let (distances, predecessors) = dijkstra_with_predecessors(graph, start);
97
98    if !distances.contains_key(&end) {
99        return None;
100    }
101
102    let mut path = vec![end];
103    let mut current = end;
104
105    while current != start {
106        current = predecessors.get(&current).copied()?;
107        path.push(current);
108    }
109
110    path.reverse();
111    Some(path)
112}
113
114/// Computes shortest distances and predecessor map for path reconstruction.
115///
116/// # Arguments
117/// * `graph` - The graph with edge weights
118/// * `start` - The starting node
119///
120/// # Returns
121/// Tuple of (distances HashMap, predecessors HashMap)
122pub fn dijkstra_with_predecessors(
123    graph: &Graph,
124    start: NodeId,
125) -> (HashMap<NodeId, f64>, HashMap<NodeId, NodeId>) {
126    let mut distances: HashMap<NodeId, f64> = HashMap::new();
127    let mut predecessors: HashMap<NodeId, NodeId> = HashMap::new();
128    let mut heap = BinaryHeap::new();
129
130    distances.insert(start, 0.0);
131    heap.push(State {
132        distance: OrderedFloat(0.0),
133        node: start,
134    });
135
136    while let Some(State { distance, node }) = heap.pop() {
137        let current_distance = distance.0;
138
139        if let Some(&known_distance) = distances.get(&node) {
140            if current_distance > known_distance {
141                continue;
142            }
143        }
144
145        for &neighbor in graph.neighbors(node) {
146            let edge_weight = graph
147                .edge_weight(node, neighbor)
148                .unwrap_or(EdgeWeight::default());
149            let next_distance = current_distance + edge_weight.as_f64();
150
151            let is_new_or_shorter = distances
152                .get(&neighbor)
153                .map_or(true, |&d| next_distance < d);
154
155            if is_new_or_shorter {
156                distances.insert(neighbor, next_distance);
157                predecessors.insert(neighbor, node);
158                heap.push(State {
159                    distance: OrderedFloat(next_distance),
160                    node: neighbor,
161                });
162            }
163        }
164    }
165
166    (distances, predecessors)
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use graphrust_core::GraphBuilder;

    /// Directed 4-node graph where the direct edge 0->2 (weight 4.0) costs
    /// more than the detour 0->1->2 (weight 3.0).
    fn create_weighted_graph() -> Graph {
        GraphBuilder::new(4)
            .directed(true)
            .add_weighted_edge(NodeId(0), NodeId(1), 1.0)
            .add_weighted_edge(NodeId(0), NodeId(2), 4.0)
            .add_weighted_edge(NodeId(1), NodeId(2), 2.0)
            .add_weighted_edge(NodeId(2), NodeId(3), 1.0)
            .build()
    }

    /// Directed 4-node graph with a single edge 0->1; nodes 2 and 3 are unreachable.
    fn create_sparse_graph() -> Graph {
        GraphBuilder::new(4)
            .directed(true)
            .add_weighted_edge(NodeId(0), NodeId(1), 1.0)
            .build()
    }

    #[test]
    fn test_dijkstra_basic() {
        let graph = create_weighted_graph();
        let distances = dijkstra(&graph, NodeId(0));

        assert_eq!(distances.get(&NodeId(0)), Some(&0.0));
        assert_eq!(distances.get(&NodeId(1)), Some(&1.0));
        assert_eq!(distances.get(&NodeId(2)), Some(&3.0)); // 0->1->2
        assert_eq!(distances.get(&NodeId(3)), Some(&4.0)); // 0->1->2->3
    }

    #[test]
    fn test_dijkstra_path() {
        let graph = create_weighted_graph();
        let path = dijkstra_path(&graph, NodeId(0), NodeId(3));

        assert_eq!(path, Some(vec![NodeId(0), NodeId(1), NodeId(2), NodeId(3)]));
    }

    #[test]
    fn test_dijkstra_no_path() {
        let graph = create_sparse_graph();

        let path = dijkstra_path(&graph, NodeId(0), NodeId(3));
        assert_eq!(path, None);
    }

    #[test]
    fn test_dijkstra_unreachable_nodes_are_absent() {
        let graph = create_sparse_graph();
        let distances = dijkstra(&graph, NodeId(0));

        assert_eq!(distances.len(), 2);
        assert!(!distances.contains_key(&NodeId(2)));
        assert!(!distances.contains_key(&NodeId(3)));
    }

    #[test]
    fn test_dijkstra_path_to_self() {
        let graph = create_weighted_graph();
        let path = dijkstra_path(&graph, NodeId(0), NodeId(0));

        assert_eq!(path, Some(vec![NodeId(0)]));
    }

    #[test]
    fn test_dijkstra_with_predecessors() {
        let graph = create_weighted_graph();
        let (distances, predecessors) = dijkstra_with_predecessors(&graph, NodeId(0));

        assert_eq!(distances.get(&NodeId(3)), Some(&4.0));
        assert_eq!(predecessors.get(&NodeId(3)), Some(&NodeId(2)));
        assert_eq!(predecessors.get(&NodeId(2)), Some(&NodeId(1)));
    }

    #[test]
    fn test_start_has_no_predecessor() {
        let graph = create_weighted_graph();
        let (_, predecessors) = dijkstra_with_predecessors(&graph, NodeId(0));

        assert_eq!(predecessors.get(&NodeId(0)), None);
    }

    #[test]
    fn test_dijkstra_agrees_with_predecessor_variant() {
        let graph = create_weighted_graph();
        let (expected, _) = dijkstra_with_predecessors(&graph, NodeId(0));

        assert_eq!(dijkstra(&graph, NodeId(0)), expected);
    }
}
224}