converge_optimization/graph/
dijkstra.rs1use super::{Graph, NodeId, Path};
9use crate::{Cost, Error, Result};
10use petgraph::visit::EdgeRef;
11use std::cmp::Reverse;
12use std::collections::{BinaryHeap, HashMap};
13
14pub fn shortest_path<N, E>(
16 graph: &Graph<N, E>,
17 source: NodeId,
18 target: NodeId,
19 edge_cost: impl Fn(&E) -> Cost,
20) -> Result<Option<Path>> {
21 let (distances, predecessors) = dijkstra_with_paths(graph, source, edge_cost)?;
22
23 match distances.get(&target) {
24 Some(&cost) => Ok(Some(reconstruct_path(&predecessors, source, target, cost))),
25 None => Ok(None),
26 }
27}
28
29pub fn dijkstra<N, E>(
33 graph: &Graph<N, E>,
34 source: NodeId,
35 edge_cost: impl Fn(&E) -> Cost,
36) -> Result<HashMap<NodeId, Cost>> {
37 let mut distances: HashMap<NodeId, Cost> = HashMap::new();
38 let mut heap: BinaryHeap<Reverse<(Cost, NodeId)>> = BinaryHeap::new();
39
40 distances.insert(source, 0);
41 heap.push(Reverse((0, source)));
42
43 while let Some(Reverse((cost, node))) = heap.pop() {
44 if let Some(&best) = distances.get(&node) {
46 if cost > best {
47 continue;
48 }
49 }
50
51 for edge in graph.edges(node) {
53 let edge_weight = edge_cost(edge.weight());
54 if edge_weight < 0 {
55 return Err(Error::invalid_input(
56 "Dijkstra requires non-negative edge weights",
57 ));
58 }
59
60 let next = edge.target();
61 let next_cost = cost + edge_weight;
62
63 let is_better = distances.get(&next).map_or(true, |&d| next_cost < d);
64
65 if is_better {
66 distances.insert(next, next_cost);
67 heap.push(Reverse((next_cost, next)));
68 }
69 }
70 }
71
72 Ok(distances)
73}
74
75pub fn dijkstra_with_paths<N, E>(
77 graph: &Graph<N, E>,
78 source: NodeId,
79 edge_cost: impl Fn(&E) -> Cost,
80) -> Result<(HashMap<NodeId, Cost>, HashMap<NodeId, NodeId>)> {
81 let mut distances: HashMap<NodeId, Cost> = HashMap::new();
82 let mut predecessors: HashMap<NodeId, NodeId> = HashMap::new();
83 let mut heap: BinaryHeap<Reverse<(Cost, NodeId)>> = BinaryHeap::new();
84
85 distances.insert(source, 0);
86 heap.push(Reverse((0, source)));
87
88 while let Some(Reverse((cost, node))) = heap.pop() {
89 if let Some(&best) = distances.get(&node) {
90 if cost > best {
91 continue;
92 }
93 }
94
95 for edge in graph.edges(node) {
96 let edge_weight = edge_cost(edge.weight());
97 if edge_weight < 0 {
98 return Err(Error::invalid_input(
99 "Dijkstra requires non-negative edge weights",
100 ));
101 }
102
103 let next = edge.target();
104 let next_cost = cost + edge_weight;
105
106 let is_better = distances.get(&next).map_or(true, |&d| next_cost < d);
107
108 if is_better {
109 distances.insert(next, next_cost);
110 predecessors.insert(next, node);
111 heap.push(Reverse((next_cost, next)));
112 }
113 }
114 }
115
116 Ok((distances, predecessors))
117}
118
119pub fn reconstruct_path(
121 predecessors: &HashMap<NodeId, NodeId>,
122 source: NodeId,
123 target: NodeId,
124 total_cost: Cost,
125) -> Path {
126 let mut path = vec![target];
127 let mut current = target;
128
129 while current != source {
130 if let Some(&pred) = predecessors.get(¤t) {
131 path.push(pred);
132 current = pred;
133 } else {
134 return Path::empty();
136 }
137 }
138
139 path.reverse();
140 Path {
141 nodes: path,
142 cost: total_cost,
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use petgraph::graph::DiGraph;

    /// The two-hop route (1 + 2) must beat the direct edge of weight 5.
    #[test]
    fn test_simple_dijkstra() {
        let mut g = DiGraph::<(), i64>::new();
        let start = g.add_node(());
        let middle = g.add_node(());
        let end = g.add_node(());

        g.add_edge(start, middle, 1);
        g.add_edge(middle, end, 2);
        g.add_edge(start, end, 5);

        let dist = dijkstra(&g, start, |&weight| weight).unwrap();

        assert_eq!(dist.get(&start), Some(&0));
        assert_eq!(dist.get(&middle), Some(&1));
        assert_eq!(dist.get(&end), Some(&3));
    }

    /// A node with no incoming route must be absent from the distance map.
    #[test]
    fn test_unreachable() {
        let mut g = DiGraph::<(), i64>::new();
        let origin = g.add_node(());
        let island = g.add_node(());

        let dist = dijkstra(&g, origin, |&weight| weight).unwrap();

        assert_eq!(dist.get(&origin), Some(&0));
        assert!(dist.get(&island).is_none());
    }
}