// advanced_algorithms/graph/astar.rs
use crate::graph::Graph;
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::cmp::Ordering;

/// Priority-queue entry for the A* open set.
///
/// The ordering is reversed on `f_score` so that `BinaryHeap`, which is a
/// max-heap, pops the entry with the *smallest* `f_score` first.
#[derive(Debug, Clone)]
struct State {
    /// Index of the graph node this entry refers to.
    node: usize,
    /// Priority key: accumulated cost plus the heuristic estimate.
    f_score: f64,
    /// Accumulated path cost from the start node for this entry.
    g_score: f64,
}

impl Eq for State {}

impl PartialEq for State {
    /// Equality is derived from [`Ord::cmp`] so the two can never disagree;
    /// only `f_score` and `node` take part, `g_score` is ignored.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Ord for State {
    /// Orders by *descending* `f_score`, breaking ties by ascending `node`.
    ///
    /// A NaN `f_score` is treated as worse than every real score (and equal to
    /// any other NaN), which keeps the order total and transitive so the heap
    /// invariants hold even if a heuristic misbehaves.
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .f_score
            .partial_cmp(&self.f_score)
            // `partial_cmp` only fails when at least one side is NaN; rank the
            // NaN side lowest so it is popped after every finite entry.
            .unwrap_or_else(|| other.f_score.is_nan().cmp(&self.f_score.is_nan()))
            .then_with(|| self.node.cmp(&other.node))
    }
}

impl PartialOrd for State {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

58pub fn find_path<H>(
75 graph: &Graph,
76 start: usize,
77 goal: usize,
78 heuristic: H,
79) -> Option<(f64, Vec<usize>)>
80where
81 H: Fn(usize) -> f64,
82{
83 let mut open_set = BinaryHeap::new();
84 let mut came_from: HashMap<usize, usize> = HashMap::new();
85 let mut g_score: HashMap<usize, f64> = HashMap::new();
86 let mut closed_set: HashSet<usize> = HashSet::new();
87
88 g_score.insert(start, 0.0);
89 open_set.push(State {
90 node: start,
91 f_score: heuristic(start),
92 g_score: 0.0,
93 });
94
95 while let Some(State { node, g_score: current_g, .. }) = open_set.pop() {
96 if closed_set.contains(&node) {
98 continue;
99 }
100
101 if node == goal {
103 return Some((current_g, reconstruct_path(&came_from, start, goal)));
104 }
105
106 closed_set.insert(node);
107
108 for &(neighbor, weight) in graph.neighbors(node) {
110 if closed_set.contains(&neighbor) {
111 continue;
112 }
113
114 let tentative_g = current_g + weight;
115
116 if let Some(&existing_g) = g_score.get(&neighbor) {
118 if tentative_g >= existing_g {
119 continue;
120 }
121 }
122
123 came_from.insert(neighbor, node);
125 g_score.insert(neighbor, tentative_g);
126
127 let h = heuristic(neighbor);
128 open_set.push(State {
129 node: neighbor,
130 f_score: tentative_g + h,
131 g_score: tentative_g,
132 });
133 }
134 }
135
136 None }
138
/// Rebuilds the route from `start` to `goal` by walking the predecessor map.
///
/// Starting at `goal`, each node's predecessor is looked up in `came_from`
/// until `start` is reached. If a node on the way has no recorded predecessor
/// the walk stops there, so the returned path then begins at that node rather
/// than at `start`.
///
/// The returned vector is ordered from the first node reached by the walk
/// (normally `start`) to `goal`, both inclusive.
fn reconstruct_path(
    came_from: &HashMap<usize, usize>,
    start: usize,
    goal: usize,
) -> Vec<usize> {
    // Walk goal -> start; the sequence ends at `start` or at the first node
    // that has no predecessor entry.
    let mut path: Vec<usize> = std::iter::successors(Some(goal), |&node| {
        if node == start {
            None
        } else {
            came_from.get(&node).copied()
        }
    })
    .collect();

    // The walk produced the nodes goal-first; callers expect start-first.
    path.reverse();
    path
}

161pub fn find_path_bounded<H>(
163 graph: &Graph,
164 start: usize,
165 goal: usize,
166 heuristic: H,
167 max_cost: f64,
168) -> Option<(f64, Vec<usize>)>
169where
170 H: Fn(usize) -> f64,
171{
172 let mut open_set = BinaryHeap::new();
173 let mut came_from: HashMap<usize, usize> = HashMap::new();
174 let mut g_score: HashMap<usize, f64> = HashMap::new();
175 let mut closed_set: HashSet<usize> = HashSet::new();
176
177 g_score.insert(start, 0.0);
178 open_set.push(State {
179 node: start,
180 f_score: heuristic(start),
181 g_score: 0.0,
182 });
183
184 while let Some(State { node, g_score: current_g, f_score }) = open_set.pop() {
185 if f_score > max_cost {
187 return None;
188 }
189
190 if closed_set.contains(&node) {
191 continue;
192 }
193
194 if node == goal {
195 return Some((current_g, reconstruct_path(&came_from, start, goal)));
196 }
197
198 closed_set.insert(node);
199
200 for &(neighbor, weight) in graph.neighbors(node) {
201 if closed_set.contains(&neighbor) {
202 continue;
203 }
204
205 let tentative_g = current_g + weight;
206
207 if tentative_g > max_cost {
209 continue;
210 }
211
212 if let Some(&existing_g) = g_score.get(&neighbor) {
213 if tentative_g >= existing_g {
214 continue;
215 }
216 }
217
218 came_from.insert(neighbor, node);
219 g_score.insert(neighbor, tentative_g);
220
221 let h = heuristic(neighbor);
222 open_set.push(State {
223 node: neighbor,
224 f_score: tentative_g + h,
225 g_score: tentative_g,
226 });
227 }
228 }
229
230 None
231}
232
/// Integer coordinate on a 2-D grid, with distance helpers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GridPos {
    /// Horizontal coordinate.
    pub x: i32,
    /// Vertical coordinate.
    pub y: i32,
}

impl GridPos {
    /// Creates a position from its `x` and `y` coordinates.
    pub fn new(x: i32, y: i32) -> Self {
        GridPos { x, y }
    }

    /// Returns the Manhattan (taxicab) distance `|dx| + |dy|` to `other`.
    ///
    /// The differences are taken in `f64`, where every `i32` difference is
    /// exactly representable, so positions at opposite ends of the `i32`
    /// range cannot overflow.
    pub fn manhattan_distance(&self, other: &GridPos) -> f64 {
        let dx = f64::from(self.x) - f64::from(other.x);
        let dy = f64::from(self.y) - f64::from(other.y);
        dx.abs() + dy.abs()
    }

    /// Returns the straight-line distance `sqrt(dx² + dy²)` to `other`.
    ///
    /// As with [`GridPos::manhattan_distance`], the subtraction happens in
    /// `f64` to avoid `i32` overflow for far-apart positions.
    pub fn euclidean_distance(&self, other: &GridPos) -> f64 {
        let dx = f64::from(self.x) - f64::from(other.x);
        let dy = f64::from(self.y) - f64::from(other.y);
        (dx * dx + dy * dy).sqrt()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Four-node graph where the three-hop route (cost 3) beats the direct
    /// edge (cost 5).
    fn line_graph_with_shortcut() -> Graph {
        let mut graph = Graph::new(4);
        graph.add_edge(0, 1, 1.0);
        graph.add_edge(1, 2, 1.0);
        graph.add_edge(2, 3, 1.0);
        graph.add_edge(0, 3, 5.0);
        graph
    }

    #[test]
    fn test_simple_path() {
        let graph = line_graph_with_shortcut();

        // A zero heuristic makes the search behave like Dijkstra's algorithm.
        let heuristic = |_: usize| 0.0;

        let result = find_path(&graph, 0, 3, heuristic).unwrap();

        assert_eq!(result.0, 3.0);
        assert_eq!(result.1, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_with_heuristic() {
        let mut graph = Graph::new(5);
        graph.add_edge(0, 1, 2.0);
        graph.add_edge(0, 2, 1.0);
        graph.add_edge(1, 3, 1.0);
        graph.add_edge(2, 3, 5.0);
        graph.add_edge(3, 4, 1.0);

        let heuristic = |node: usize| (4 - node) as f64;

        let result = find_path(&graph, 0, 4, heuristic).unwrap();

        // 0 -> 1 -> 3 -> 4 costs 4; the alternative through node 2 costs 7.
        assert_eq!(result.0, 4.0);
        assert_eq!(result.1, vec![0, 1, 3, 4]);
    }

    #[test]
    fn test_no_path() {
        let mut graph = Graph::new(3);
        graph.add_edge(0, 1, 1.0);

        let heuristic = |_: usize| 0.0;

        let result = find_path(&graph, 0, 2, heuristic);
        assert!(result.is_none());
    }

    #[test]
    fn test_start_is_goal() {
        let mut graph = Graph::new(2);
        graph.add_edge(0, 1, 1.0);

        let result = find_path(&graph, 0, 0, |_: usize| 0.0).unwrap();

        assert_eq!(result.0, 0.0);
        assert_eq!(result.1, vec![0]);
    }

    #[test]
    fn test_bounded_within_budget() {
        let graph = line_graph_with_shortcut();

        let result = find_path_bounded(&graph, 0, 3, |_: usize| 0.0, 3.0).unwrap();

        assert_eq!(result.0, 3.0);
        assert_eq!(result.1, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_bounded_over_budget() {
        let graph = line_graph_with_shortcut();

        let result = find_path_bounded(&graph, 0, 3, |_: usize| 0.0, 2.5);

        assert!(result.is_none());
    }

    #[test]
    fn test_manhattan_distance() {
        let p1 = GridPos::new(0, 0);
        let p2 = GridPos::new(3, 4);

        assert_eq!(p1.manhattan_distance(&p2), 7.0);
        assert!((p1.euclidean_distance(&p2) - 5.0).abs() < 0.001);
    }
}