1#![doc = include_str!("../README.md")]
2mod grid;
3mod node;
4mod point;
5
6use std::collections::BinaryHeap;
7
8pub use grid::Grid;
9pub use node::Node;
10pub use point::Point;
11
12#[must_use]
15pub fn astar(grid: &Grid, start: Point, end: Point) -> Option<Vec<Point>> {
16 let width = grid.width();
17 let height = grid.height();
18 let capacity = width * height;
19
20 let mut open_nodes = BinaryHeap::new(); let mut closed_nodes = vec![false; capacity];
22 let mut g_scores = vec![u16::MAX; capacity];
23 let mut all_nodes = Vec::with_capacity(capacity);
24
25 let start_node = Node {
26 point: start,
27 g: 0,
28 h: manhattan_distance(&start, &end),
29 parent_index: None,
30 };
31
32 let start_index = point_to_index(start, width);
33 g_scores[start_index] = 0;
34 all_nodes.push(start_node);
35 open_nodes.push((0, 0));
36
37 while let Some((_f_score, current_index)) = open_nodes.pop() {
38 let current = all_nodes[current_index];
39
40 if current.point == end {
41 return Some(retrace_path(&all_nodes, current_index));
42 }
43
44 let current_point_index = point_to_index(current.point, width);
45 if closed_nodes[current_point_index] {
46 continue;
47 }
48 closed_nodes[current_point_index] = true;
49
50 let current_g = current.g;
51
52 for neighbor_point in get_neighbor_points(grid, current.point) {
53 let neighbor_index = point_to_index(neighbor_point, width);
54 if closed_nodes[neighbor_index] {
55 continue;
56 }
57
58 let tentative_g = current_g + 1;
59
60 if tentative_g >= g_scores[neighbor_index] as isize {
61 continue;
62 }
63
64 let h = manhattan_distance(&neighbor_point, &end);
65 let f = tentative_g + h;
66
67 let neighbor = Node {
68 point: neighbor_point,
69 g: tentative_g,
70 h,
71 parent_index: Some(current_index),
72 };
73
74 g_scores[neighbor_index] = tentative_g as u16;
75 all_nodes.push(neighbor);
76 open_nodes.push((-f, all_nodes.len() - 1));
77 }
78 }
79
80 None
81}
82
83#[inline]
85fn point_to_index(point: Point, width: usize) -> usize {
86 point.y as usize * width + point.x as usize
87}
88
89fn retrace_path(nodes: &[Node], mut current_index: usize) -> Vec<Point> {
91 let start = nodes[0].point;
92 let end = nodes[current_index].point;
93 let initial_capacity = manhattan_distance(&start, &end) as usize;
94 let mut path = Vec::with_capacity(initial_capacity);
95
96 loop {
97 let current = &nodes[current_index];
98 path.push(current.point);
99
100 if let Some(parent_index) = current.parent_index {
101 current_index = parent_index;
102 } else {
103 break;
104 }
105 }
106
107 path.reverse();
108 path
109}
110
111#[inline]
113fn manhattan_distance(a: &Point, b: &Point) -> isize {
114 (a.x - b.x).abs() + (a.y - b.y).abs()
115}
116
117const NEIGHBORS: [(isize, isize); 4] = [(0, 1), (1, 0), (0, -1), (-1, 0)];
119
120fn get_neighbor_points(grid: &Grid, point: Point) -> impl Iterator<Item = Point> + '_ {
121 NEIGHBORS.iter().filter_map(move |&(dx, dy)| {
122 let new_x = point.x + dx;
123 let new_y = point.y + dy;
124 grid.get(new_x, new_y)
125 .filter(|&node| !node)
126 .map(|_| Point::new(new_x, new_y))
127 })
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_valid_path() {
        #[rustfmt::skip]
        let grid = Grid::from_2d(vec![
            vec![false, false, false],
            vec![true,  false, true ],
            vec![false, false, false],
        ]);

        let start = Point::new(0, 0);
        let end = Point::new(2, 2);

        let path = astar(&grid, start, end).unwrap();

        assert_eq!(
            path,
            vec![
                start,
                Point::new(1, 0),
                Point::new(1, 1),
                Point::new(1, 2),
                end
            ]
        );
    }

    #[test]
    fn test_no_valid_path() {
        #[rustfmt::skip]
        let grid = Grid::from_2d(vec![
            vec![false, false, false],
            vec![false, true,  true ],
            vec![false, true,  false],
        ]);

        let start = Point::new(0, 0);
        let end = Point::new(2, 2);

        let path = astar(&grid, start, end);

        assert!(path.is_none());
    }

    #[test]
    fn test_start_equals_end() {
        #[rustfmt::skip]
        let grid = Grid::from_2d(vec![
            vec![false, false],
            vec![false, false],
        ]);

        let start = Point::new(1, 1);

        assert_eq!(astar(&grid, start, start), Some(vec![start]));
    }

    #[test]
    fn test_end_outside_grid() {
        #[rustfmt::skip]
        let grid = Grid::from_2d(vec![
            vec![false, false, false],
            vec![false, false, false],
            vec![false, false, false],
        ]);

        let start = Point::new(0, 0);
        let end = Point::new(5, 5);

        assert!(astar(&grid, start, end).is_none());
    }

    #[test]
    fn test_collidable_neighbors() {
        #[rustfmt::skip]
        let grid = Grid::from_2d(vec![
            vec![false, false, false],
            vec![true,  false, false],
            vec![false, false, false],
        ]);

        let point = Point::new(0, 0);
        let neighbors = get_neighbor_points(&grid, point);
        assert_eq!(neighbors.count(), 1);

        let point2 = Point::new(1, 1);
        let neighbors2 = get_neighbor_points(&grid, point2);
        assert_eq!(neighbors2.count(), 3);

        let point3 = Point::new(2, 1);
        let neighbors3 = get_neighbor_points(&grid, point3);
        assert_eq!(neighbors3.count(), 3);
    }

    #[test]
    fn test_distance() {
        let point1 = Point::new(0, 0);
        let point2 = Point::new(1, 1);
        let point3 = Point::new(2, 2);

        assert_eq!(manhattan_distance(&point1, &point2), 2);
        assert_eq!(manhattan_distance(&point1, &point3), 4);
        assert_eq!(manhattan_distance(&point2, &point3), 2);
    }

    #[test]
    fn test_get_shortest() {
        #[rustfmt::skip]
        let grid = Grid::from_2d(vec![
            vec![false, false, false, false, false],
            vec![true,  true,  true,  false, false],
            vec![false, false, false, false, false],
            vec![false, true,  true,  true,  true ],
            vec![false, false, false, false, false],
        ]);

        let start = Point::new(0, 0);
        let end = Point::new(4, 4);

        let path = astar(&grid, start, end).unwrap();

        assert_eq!(path.first(), Some(&start));
        assert_eq!(path.last(), Some(&end));
        assert_eq!(path.len(), 15);

        // Consecutive points must be orthogonally adjacent.
        for window in path.windows(2) {
            let current = window[0];
            let next = window[1];
            assert_eq!(
                manhattan_distance(&current, &next),
                1,
                "Points {current:?} and {next:?} are not adjacent"
            );
        }

        // No point on the path may be a collidable cell.
        for point in &path {
            assert!(
                !grid.get(point.x, point.y).unwrap(),
                "Path contains collision at {point:?}"
            );
        }
    }
}