1use std::cmp::Ordering;
10use std::collections::{BinaryHeap, HashMap, HashSet};
11
12use crate::metrics::CostMetric;
13
/// Outcome of a successful A* search.
#[derive(Clone, Debug)]
pub struct AStarResult<N> {
    /// Every node from the start to the goal in travel order; both endpoints are included.
    pub path: Vec<N>,
    /// Accumulated edge cost along `path`.
    pub cost: f64,
}
22
/// One entry of the A* open list.
///
/// `BinaryHeap` is a max-heap, so the `Ord` impl is reversed on `f`: the entry with the
/// smallest estimated total cost compares as the greatest and is popped first.
#[derive(Clone)]
struct OpenEntry<N> {
    /// Node this entry refers to.
    node: N,
    /// Estimated total cost through `node`: `g` plus the heuristic estimate to the goal.
    f: f64,
    /// Cost from the start to `node` along the path that produced this entry.
    g: f64,
}

// Equality is defined through `cmp` so that `a == b` holds exactly when
// `a.cmp(b) == Ordering::Equal`, as the `Ord` contract requires.
impl<N> PartialEq for OpenEntry<N> {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl<N> Eq for OpenEntry<N> {}

impl<N> PartialOrd for OpenEntry<N> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<N> Ord for OpenEntry<N> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Total order on costs: numbers compare as usual, NaN sorts above every number and
        // NaN == NaN. Treating NaN as "equal to everything" would not be transitive and could
        // break the heap invariant.
        fn cost_cmp(a: f64, b: f64) -> Ordering {
            a.partial_cmp(&b)
                .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
        }

        // Lower `f` first (reversed for the max-heap). On equal `f`, prefer the larger `g`:
        // that entry is further along its path, a common A* tie-break that cuts expansions.
        cost_cmp(other.f, self.f).then_with(|| cost_cmp(self.g, other.g))
    }
}
49
50pub fn astar<N, FH, FN, I>(
60 start: N,
61 goal: N,
62 mut heuristic: FH,
63 mut neighbors: FN,
64) -> Option<AStarResult<N>>
65where
66 N: Clone + Eq + std::hash::Hash,
67 FH: FnMut(&N, &N) -> f64,
68 FN: FnMut(&N) -> I,
69 I: IntoIterator<Item = (N, f64)>,
70{
71 let mut open = BinaryHeap::new();
72 let mut g_scores: HashMap<N, f64> = HashMap::new();
73 let mut came_from: HashMap<N, N> = HashMap::new();
74 let mut closed: HashSet<N> = HashSet::new();
75
76 let h = heuristic(&start, &goal);
77 g_scores.insert(start.clone(), 0.0);
78 open.push(OpenEntry {
79 node: start.clone(),
80 f: h,
81 g: 0.0,
82 });
83
84 while let Some(current) = open.pop() {
85 if current.node == goal {
86 let mut path = Vec::new();
87 let mut cur = goal.clone();
88 loop {
89 path.push(cur.clone());
90 match came_from.get(&cur) {
91 Some(prev) => cur = prev.clone(),
92 None => break,
93 }
94 }
95 path.reverse();
96 return Some(AStarResult {
97 path,
98 cost: current.g,
99 });
100 }
101
102 if !closed.insert(current.node.clone()) {
103 continue;
104 }
105
106 for (neighbor, edge_cost) in neighbors(¤t.node) {
107 if closed.contains(&neighbor) {
108 continue;
109 }
110
111 let tentative_g = current.g + edge_cost;
112 let prev_g = g_scores.get(&neighbor).copied().unwrap_or(f64::INFINITY);
113
114 if tentative_g < prev_g {
115 g_scores.insert(neighbor.clone(), tentative_g);
116 came_from.insert(neighbor.clone(), current.node.clone());
117 let h = heuristic(&neighbor, &goal);
118 open.push(OpenEntry {
119 node: neighbor,
120 f: tentative_g + h,
121 g: tentative_g,
122 });
123 }
124 }
125 }
126
127 None
128}
129
130pub fn astar_grid2d(
136 start: (usize, usize),
137 goal: (usize, usize),
138 width: usize,
139 height: usize,
140 walkable: &dyn Fn(usize, usize) -> bool,
141 diagonal: bool,
142) -> Option<AStarResult<(usize, usize)>> {
143 let heuristic = |a: &(usize, usize), b: &(usize, usize)| -> f64 {
144 let dx = (a.0 as f64 - b.0 as f64).abs();
145 let dy = (a.1 as f64 - b.1 as f64).abs();
146 if diagonal {
147 dx.max(dy) } else {
149 dx + dy }
151 };
152
153 let neighbors = |node: &(usize, usize)| -> Vec<((usize, usize), f64)> {
154 let (x, y) = *node;
155 let mut result = Vec::new();
156
157 let deltas: &[(i32, i32)] = if diagonal {
158 &[
159 (-1, -1),
160 (-1, 0),
161 (-1, 1),
162 (0, -1),
163 (0, 1),
164 (1, -1),
165 (1, 0),
166 (1, 1),
167 ]
168 } else {
169 &[(-1, 0), (1, 0), (0, -1), (0, 1)]
170 };
171
172 for &(dx, dy) in deltas {
173 let nx = x as i32 + dx;
174 let ny = y as i32 + dy;
175 if nx >= 0 && ny >= 0 && (nx as usize) < width && (ny as usize) < height {
176 let nx = nx as usize;
177 let ny = ny as usize;
178 if walkable(nx, ny) {
179 let cost = if dx != 0 && dy != 0 {
180 std::f64::consts::SQRT_2
181 } else {
182 1.0
183 };
184 result.push(((nx, ny), cost));
185 }
186 }
187 }
188
189 result
190 };
191
192 astar(start, goal, heuristic, neighbors)
193}
194
195pub struct GridAStarOpts<'a> {
199 pub width: usize,
201 pub height: usize,
203 pub diagonal: bool,
205 pub periodic: bool,
207 pub admissibility: f64,
213 pub walkable: Option<&'a dyn Fn(usize, usize) -> bool>,
216 pub cost_metric: Option<&'a dyn CostMetric>,
219}
220
221impl<'a> GridAStarOpts<'a> {
222 pub fn new(width: usize, height: usize) -> Self {
224 Self {
225 width,
226 height,
227 diagonal: true,
228 periodic: false,
229 admissibility: 0.0,
230 walkable: None,
231 cost_metric: None,
232 }
233 }
234}
235
236impl<'a> std::fmt::Debug for GridAStarOpts<'a> {
237 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238 f.debug_struct("GridAStarOpts")
239 .field("width", &self.width)
240 .field("height", &self.height)
241 .field("diagonal", &self.diagonal)
242 .field("periodic", &self.periodic)
243 .field("admissibility", &self.admissibility)
244 .field("walkable", &self.walkable.is_some())
245 .field("cost_metric", &self.cost_metric.is_some())
246 .finish()
247 }
248}
249
250pub fn astar_grid2d_opts(
284 start: (usize, usize),
285 goal: (usize, usize),
286 opts: &GridAStarOpts<'_>,
287) -> Option<AStarResult<(usize, usize)>> {
288 let width = opts.width;
289 let height = opts.height;
290 let diagonal = opts.diagonal;
291 let periodic = opts.periodic;
292 let admissibility = opts.admissibility;
293
294 let default_metric = crate::metrics::DirectDistance::new();
295 let metric: &dyn CostMetric = opts.cost_metric.unwrap_or(&default_metric);
296
297 let always_walkable = |_: usize, _: usize| true;
298 let walkable: &dyn Fn(usize, usize) -> bool = match &opts.walkable {
299 Some(f) => *f,
300 None => &always_walkable,
301 };
302
303 let start_n = normalize_grid_pos(start, periodic, width, height)?;
305 let goal_n = normalize_grid_pos(goal, periodic, width, height)?;
306 if !walkable(start_n.0, start_n.1) || !walkable(goal_n.0, goal_n.1) {
307 return None;
308 }
309
310 let heuristic = |a: &(usize, usize), b: &(usize, usize)| -> f64 {
311 (1.0 + admissibility) * metric.delta_cost(*a, *b, periodic, width, height, diagonal)
312 };
313
314 let neighbors = |node: &(usize, usize)| -> Vec<((usize, usize), f64)> {
315 let (x, y) = *node;
316 let mut result = Vec::new();
317
318 let deltas: &[(i32, i32)] = if diagonal {
319 &[
320 (-1, -1),
321 (-1, 0),
322 (-1, 1),
323 (0, -1),
324 (0, 1),
325 (1, -1),
326 (1, 0),
327 (1, 1),
328 ]
329 } else {
330 &[(-1, 0), (1, 0), (0, -1), (0, 1)]
331 };
332
333 for &(dx, dy) in deltas {
334 let nx = x as i32 + dx;
335 let ny = y as i32 + dy;
336
337 let neighbor = if periodic {
338 let px = ((nx % width as i32) + width as i32) % width as i32;
339 let py = ((ny % height as i32) + height as i32) % height as i32;
340 Some((px as usize, py as usize))
341 } else if nx >= 0 && ny >= 0 && (nx as usize) < width && (ny as usize) < height {
342 Some((nx as usize, ny as usize))
343 } else {
344 None
345 };
346
347 if let Some(n) = neighbor {
348 if walkable(n.0, n.1) {
349 let cost = metric.delta_cost(*node, n, periodic, width, height, diagonal);
350 result.push((n, cost));
351 }
352 }
353 }
354
355 result
356 };
357
358 astar(start_n, goal_n, heuristic, neighbors)
359}
360
/// Maps a grid position onto the cell it denotes.
///
/// On a periodic grid the coordinates wrap modulo `width` and `height`. Otherwise the position
/// is returned unchanged if it lies inside the grid.
///
/// Returns `None` when the position is outside a non-periodic grid, or when a periodic grid
/// has zero width or height (there is no cell to wrap onto).
fn normalize_grid_pos(
    pos: (usize, usize),
    periodic: bool,
    width: usize,
    height: usize,
) -> Option<(usize, usize)> {
    let (x, y) = pos;
    if periodic {
        // `checked_rem` yields `None` for a zero-length axis instead of panicking on `% 0`.
        Some((x.checked_rem(width)?, y.checked_rem(height)?))
    } else if x < width && y < height {
        Some(pos)
    } else {
        None
    }
}