seastar/
lib.rs

1#![doc = include_str!("../README.md")]
2mod grid;
3mod node;
4mod point;
5
6use std::collections::BinaryHeap;
7
8pub use grid::Grid;
9pub use node::Node;
10pub use point::Point;
11
12/// Attempts to find the shortest path from `start` to `end` using the A*
13/// algorithm. Returns `None` if no path is found.
14#[must_use]
15pub fn astar(grid: &Grid, start: Point, end: Point) -> Option<Vec<Point>> {
16    let width = grid.width();
17    let height = grid.height();
18    let capacity = width * height;
19
20    let mut open_nodes = BinaryHeap::new(); // Min-heap; see `Node` impl
21    let mut closed_nodes = vec![false; capacity];
22    let mut g_scores = vec![u16::MAX; capacity];
23    let mut all_nodes = Vec::with_capacity(capacity);
24
25    let start_node = Node {
26        point: start,
27        g: 0,
28        h: manhattan_distance(&start, &end),
29        parent_index: None,
30    };
31
32    let start_index = point_to_index(start, width);
33    g_scores[start_index] = 0;
34    all_nodes.push(start_node);
35    open_nodes.push((0, 0));
36
37    while let Some((_f_score, current_index)) = open_nodes.pop() {
38        let current = all_nodes[current_index];
39
40        if current.point == end {
41            return Some(retrace_path(&all_nodes, current_index));
42        }
43
44        let current_point_index = point_to_index(current.point, width);
45        if closed_nodes[current_point_index] {
46            continue;
47        }
48        closed_nodes[current_point_index] = true;
49
50        let current_g = current.g;
51
52        for neighbor_point in get_neighbor_points(grid, current.point) {
53            let neighbor_index = point_to_index(neighbor_point, width);
54            if closed_nodes[neighbor_index] {
55                continue;
56            }
57
58            let tentative_g = current_g + 1;
59
60            if tentative_g >= g_scores[neighbor_index] as isize {
61                continue;
62            }
63
64            let h = manhattan_distance(&neighbor_point, &end);
65            let f = tentative_g + h;
66
67            let neighbor = Node {
68                point: neighbor_point,
69                g: tentative_g,
70                h,
71                parent_index: Some(current_index),
72            };
73
74            g_scores[neighbor_index] = tentative_g as u16;
75            all_nodes.push(neighbor);
76            open_nodes.push((-f, all_nodes.len() - 1));
77        }
78    }
79
80    None
81}
82
83/// Converts a `Point` to an index in a 1D vector.
84#[inline]
85fn point_to_index(point: Point, width: usize) -> usize {
86    point.y as usize * width + point.x as usize
87}
88
89/// Returns the path from start to end as a list of points.
90fn retrace_path(nodes: &[Node], mut current_index: usize) -> Vec<Point> {
91    let start = nodes[0].point;
92    let end = nodes[current_index].point;
93    let initial_capacity = manhattan_distance(&start, &end) as usize;
94    let mut path = Vec::with_capacity(initial_capacity);
95
96    loop {
97        let current = &nodes[current_index];
98        path.push(current.point);
99
100        if let Some(parent_index) = current.parent_index {
101            current_index = parent_index;
102        } else {
103            break;
104        }
105    }
106
107    path.reverse();
108    path
109}
110
111/// Shortest distance between two points.
112#[inline]
113fn manhattan_distance(a: &Point, b: &Point) -> isize {
114    (a.x - b.x).abs() + (a.y - b.y).abs()
115}
116
// TODO: Support diagonal movement?
/// Unit offsets `(dx, dy)` for 4-directional movement.
///
/// The order here is the order in which `get_neighbor_points` yields
/// neighbours, which in turn decides which of several equally short paths
/// `astar` returns — keep it stable.
const NEIGHBORS: [(isize, isize); 4] = [
    (0, 1),  // +y
    (1, 0),  // +x
    (0, -1), // -y
    (-1, 0), // -x
];
119
120fn get_neighbor_points(grid: &Grid, point: Point) -> impl Iterator<Item = Point> + '_ {
121    NEIGHBORS.iter().filter_map(move |&(dx, dy)| {
122        let new_x = point.x + dx;
123        let new_y = point.y + dy;
124        grid.get(new_x, new_y)
125            .filter(|&node| !node)
126            .map(|_| Point::new(new_x, new_y))
127    })
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// The only shortest route through the single gap in the wall is found.
    #[test]
    fn test_valid_path() {
        #[rustfmt::skip]
        let grid = Grid::from_2d(vec![
            vec![false, false, false],
            vec![true,  false, true ],
            vec![false, false, false],
        ]);

        let start = Point::new(0, 0);
        let end = Point::new(2, 2);

        let path = astar(&grid, start, end).unwrap();

        assert_eq!(
            path,
            vec![
                start,
                Point::new(1, 0),
                Point::new(1, 1),
                Point::new(1, 2),
                end
            ]
        );
    }

    /// A goal sealed off by blocked cells yields `None`.
    #[test]
    fn test_no_valid_path() {
        #[rustfmt::skip]
        let grid = Grid::from_2d(vec![
            vec![false, false, false],
            vec![false, true,  true ],
            vec![false, true,  false],
        ]);

        let start = Point::new(0, 0);
        let end = Point::new(2, 2);

        let path = astar(&grid, start, end);

        assert!(path.is_none());
    }

    /// When start and goal coincide, the path is just that single point.
    #[test]
    fn test_start_equals_end() {
        #[rustfmt::skip]
        let grid = Grid::from_2d(vec![
            vec![false, false],
            vec![false, false],
        ]);

        let point = Point::new(1, 1);

        assert_eq!(astar(&grid, point, point), Some(vec![point]));
    }

    /// Blocked cells and cells outside the grid are not neighbours.
    #[test]
    fn test_collidable_neighbors() {
        #[rustfmt::skip]
        let grid = Grid::from_2d(vec![
            vec![false, false, false],
            vec![true,  false, false],
            vec![false, false, false],
        ]);

        let point = Point::new(0, 0);
        let neighbors = get_neighbor_points(&grid, point);
        assert_eq!(neighbors.count(), 1);

        let point2 = Point::new(1, 1);
        let neighbors2 = get_neighbor_points(&grid, point2);
        assert_eq!(neighbors2.count(), 3);

        let point3 = Point::new(2, 1);
        let neighbors3 = get_neighbor_points(&grid, point3);
        assert_eq!(neighbors3.count(), 3);
    }

    #[test]
    fn test_distance() {
        let point1 = Point::new(0, 0);
        let point2 = Point::new(1, 1);
        let point3 = Point::new(2, 2);

        assert_eq!(manhattan_distance(&point1, &point2), 2);
        assert_eq!(manhattan_distance(&point1, &point3), 4);
        assert_eq!(manhattan_distance(&point2, &point3), 2);
    }

    /// A winding maze: the path has the optimal length, consists of adjacent
    /// steps only, and never enters a blocked cell.
    #[test]
    fn test_get_shortest() {
        #[rustfmt::skip]
        let grid = Grid::from_2d(vec![
            vec![false, false, false, false, false],
            vec![true,  true,  true,  false, false],
            vec![false, false, false, false, false],
            vec![false, true,  true,  true,  true ],
            vec![false, false, false, false, false],
        ]);

        let start = Point::new(0, 0);
        let end = Point::new(4, 4);

        let path = astar(&grid, start, end).unwrap();

        assert_eq!(path.first(), Some(&start));
        assert_eq!(path.last(), Some(&end));
        assert_eq!(path.len(), 15);

        for window in path.windows(2) {
            let current = window[0];
            let next = window[1];
            assert_eq!(
                manhattan_distance(&current, &next),
                1,
                "Points {current:?} and {next:?} are not adjacent"
            );
        }

        for point in &path {
            assert!(
                !grid.get(point.x, point.y).unwrap(),
                "Path contains collision at {point:?}"
            );
        }
    }
}