Skip to main content

converge_optimization/graph/
dijkstra.rs

1//! Dijkstra's shortest path algorithm
2//!
3//! Finds shortest paths from a source node to all other nodes
4//! in a graph with non-negative edge weights.
5//!
6//! Time complexity: O((V + E) log V) using a binary heap.
7
8use super::{Graph, NodeId, Path};
9use crate::{Cost, Error, Result};
10use petgraph::visit::EdgeRef;
11use std::cmp::Reverse;
12use std::collections::{BinaryHeap, HashMap};
13
14/// Find shortest path between two nodes, returning the full node sequence.
15pub fn shortest_path<N, E>(
16    graph: &Graph<N, E>,
17    source: NodeId,
18    target: NodeId,
19    edge_cost: impl Fn(&E) -> Cost,
20) -> Result<Option<Path>> {
21    let (distances, predecessors) = dijkstra_with_paths(graph, source, edge_cost)?;
22
23    match distances.get(&target) {
24        Some(&cost) => Ok(Some(reconstruct_path(&predecessors, source, target, cost))),
25        None => Ok(None),
26    }
27}
28
29/// Run Dijkstra's algorithm from a source node
30///
31/// Returns a map from node to shortest distance from source.
32pub fn dijkstra<N, E>(
33    graph: &Graph<N, E>,
34    source: NodeId,
35    edge_cost: impl Fn(&E) -> Cost,
36) -> Result<HashMap<NodeId, Cost>> {
37    let mut distances: HashMap<NodeId, Cost> = HashMap::new();
38    let mut heap: BinaryHeap<Reverse<(Cost, NodeId)>> = BinaryHeap::new();
39
40    distances.insert(source, 0);
41    heap.push(Reverse((0, source)));
42
43    while let Some(Reverse((cost, node))) = heap.pop() {
44        // Skip if we've found a better path
45        if let Some(&best) = distances.get(&node) {
46            if cost > best {
47                continue;
48            }
49        }
50
51        // Explore neighbors
52        for edge in graph.edges(node) {
53            let edge_weight = edge_cost(edge.weight());
54            if edge_weight < 0 {
55                return Err(Error::invalid_input(
56                    "Dijkstra requires non-negative edge weights",
57                ));
58            }
59
60            let next = edge.target();
61            let next_cost = cost + edge_weight;
62
63            let is_better = distances.get(&next).map_or(true, |&d| next_cost < d);
64
65            if is_better {
66                distances.insert(next, next_cost);
67                heap.push(Reverse((next_cost, next)));
68            }
69        }
70    }
71
72    Ok(distances)
73}
74
75/// Find shortest paths from source to all nodes, returning predecessors
76pub fn dijkstra_with_paths<N, E>(
77    graph: &Graph<N, E>,
78    source: NodeId,
79    edge_cost: impl Fn(&E) -> Cost,
80) -> Result<(HashMap<NodeId, Cost>, HashMap<NodeId, NodeId>)> {
81    let mut distances: HashMap<NodeId, Cost> = HashMap::new();
82    let mut predecessors: HashMap<NodeId, NodeId> = HashMap::new();
83    let mut heap: BinaryHeap<Reverse<(Cost, NodeId)>> = BinaryHeap::new();
84
85    distances.insert(source, 0);
86    heap.push(Reverse((0, source)));
87
88    while let Some(Reverse((cost, node))) = heap.pop() {
89        if let Some(&best) = distances.get(&node) {
90            if cost > best {
91                continue;
92            }
93        }
94
95        for edge in graph.edges(node) {
96            let edge_weight = edge_cost(edge.weight());
97            if edge_weight < 0 {
98                return Err(Error::invalid_input(
99                    "Dijkstra requires non-negative edge weights",
100                ));
101            }
102
103            let next = edge.target();
104            let next_cost = cost + edge_weight;
105
106            let is_better = distances.get(&next).map_or(true, |&d| next_cost < d);
107
108            if is_better {
109                distances.insert(next, next_cost);
110                predecessors.insert(next, node);
111                heap.push(Reverse((next_cost, next)));
112            }
113        }
114    }
115
116    Ok((distances, predecessors))
117}
118
119/// Reconstruct path from predecessors map
120pub fn reconstruct_path(
121    predecessors: &HashMap<NodeId, NodeId>,
122    source: NodeId,
123    target: NodeId,
124    total_cost: Cost,
125) -> Path {
126    let mut path = vec![target];
127    let mut current = target;
128
129    while current != source {
130        if let Some(&pred) = predecessors.get(&current) {
131            path.push(pred);
132            current = pred;
133        } else {
134            // No path exists
135            return Path::empty();
136        }
137    }
138
139    path.reverse();
140    Path {
141        nodes: path,
142        cost: total_cost,
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use petgraph::graph::DiGraph;

    #[test]
    fn test_simple_dijkstra() {
        let mut graph: DiGraph<(), i64> = DiGraph::new();
        let a = graph.add_node(());
        let b = graph.add_node(());
        let c = graph.add_node(());

        graph.add_edge(a, b, 1);
        graph.add_edge(b, c, 2);
        graph.add_edge(a, c, 5);

        let distances = dijkstra(&graph, a, |&w| w).unwrap();

        assert_eq!(distances[&a], 0);
        assert_eq!(distances[&b], 1);
        assert_eq!(distances[&c], 3); // a->b->c = 1+2 < a->c = 5
    }

    #[test]
    fn test_unreachable() {
        let mut graph: DiGraph<(), i64> = DiGraph::new();
        let a = graph.add_node(());
        let b = graph.add_node(());

        // No edge from a to b
        let distances = dijkstra(&graph, a, |&w| w).unwrap();

        assert_eq!(distances.get(&a), Some(&0));
        assert_eq!(distances.get(&b), None);
    }

    #[test]
    fn test_negative_weight_is_rejected() {
        let mut graph: DiGraph<(), i64> = DiGraph::new();
        let a = graph.add_node(());
        let b = graph.add_node(());
        graph.add_edge(a, b, -1);

        // Both entry points validate weights on the edges they explore.
        assert!(dijkstra(&graph, a, |&w| w).is_err());
        assert!(dijkstra_with_paths(&graph, a, |&w| w).is_err());
    }

    #[test]
    fn test_dijkstra_with_paths_predecessors() {
        let mut graph: DiGraph<(), i64> = DiGraph::new();
        let a = graph.add_node(());
        let b = graph.add_node(());
        let c = graph.add_node(());

        graph.add_edge(a, b, 1);
        graph.add_edge(b, c, 2);
        graph.add_edge(a, c, 5);

        let (distances, predecessors) = dijkstra_with_paths(&graph, a, |&w| w).unwrap();

        assert_eq!(distances[&c], 3);
        // c is first reached directly from a (cost 5), then improved via b.
        assert_eq!(predecessors.get(&c), Some(&b));
        assert_eq!(predecessors.get(&b), Some(&a));
        assert_eq!(predecessors.get(&a), None);
    }

    #[test]
    fn test_shortest_path_nodes_and_cost() {
        let mut graph: DiGraph<(), i64> = DiGraph::new();
        let a = graph.add_node(());
        let b = graph.add_node(());
        let c = graph.add_node(());

        graph.add_edge(a, b, 1);
        graph.add_edge(b, c, 2);
        graph.add_edge(a, c, 5);

        let path = shortest_path(&graph, a, c, |&w| w)
            .unwrap()
            .expect("c is reachable from a");

        assert_eq!(path.nodes, vec![a, b, c]);
        assert_eq!(path.cost, 3);
    }

    #[test]
    fn test_shortest_path_to_source_itself() {
        let mut graph: DiGraph<(), i64> = DiGraph::new();
        let a = graph.add_node(());

        let path = shortest_path(&graph, a, a, |&w| w)
            .unwrap()
            .expect("the source always reaches itself");

        assert_eq!(path.nodes, vec![a]);
        assert_eq!(path.cost, 0);
    }

    #[test]
    fn test_shortest_path_unreachable() {
        let mut graph: DiGraph<(), i64> = DiGraph::new();
        let a = graph.add_node(());
        let b = graph.add_node(());

        // No edge from a to b
        assert!(shortest_path(&graph, a, b, |&w| w).unwrap().is_none());
    }
}