// py_pathfinding/algorithms/astar.rs

1use std::cmp::Ordering;
2
3use std::collections::BinaryHeap;
4use std::collections::HashMap;
5use std::collections::HashSet;
6
7use std::f32::consts::SQRT_2;
8use std::f32::EPSILON;
9
10use std::ops::Sub;
11
12use ndarray::Array;
13
14use ndarray::Array2;
15
16use std::fs::File;
17use std::io::Read;
18
/// Returns the absolute difference between `x` and `y`.
///
/// For ordered operands the smaller value is always subtracted from the larger
/// one, so unsigned types never underflow. When the operands are unordered
/// (e.g. a float `NaN`), `x - y` is returned.
pub fn absdiff<T>(x: T, y: T) -> T
where
    T: Sub<Output = T> + PartialOrd,
{
    let (larger, smaller) = if x < y { (y, x) } else { (x, y) };
    larger - smaller
}
29
30#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
31pub struct Point2d {
32    x: i32,
33    y: i32,
34}
35
36impl Point2d {
37    fn add_neighbor(&self, other: (i32, i32)) -> Point2d {
38        //        let (x, y) = other;
39        Point2d {
40            x: self.x + other.0,
41            y: self.y + other.1,
42        }
43    }
44
45    // These functions are just used for the pathfinding crate
46    fn distance(&self, other: &Self) -> u32 {
47        (absdiff(self.x, other.x) + absdiff(self.y, other.y)) as u32
48    }
49
50    fn successors(&self) -> Vec<(Point2d, u32)> {
51        let x = self.x;
52        let y = self.y;
53        vec![
54            Point2d { x: x + 1, y: y },
55            Point2d { x: x - 1, y: y },
56            Point2d { x: x, y: y + 1 },
57            Point2d { x: x, y: y - 1 },
58            Point2d { x: x + 1, y: y + 1 },
59            Point2d { x: x + 1, y: y - 1 },
60            Point2d { x: x - 1, y: y + 1 },
61            Point2d { x: x - 1, y: y - 1 },
62        ]
63        .into_iter()
64        .map(|p| (p, 1))
65        .collect()
66    }
67}
68
69// https://doc.rust-lang.org/std/collections/binary_heap/
70#[derive(Copy, Clone, Debug)]
71struct Node {
72    cost_to_source: f32,
73    total_estimated_cost: f32,
74    position: Point2d,
75    came_from: Point2d,
76}
77
78impl PartialEq for Node {
79    fn eq(&self, other: &Self) -> bool {
80        self.total_estimated_cost - other.total_estimated_cost < EPSILON
81    }
82}
83
84impl PartialOrd for Node {
85    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
86        other
87            .total_estimated_cost
88            .partial_cmp(&self.total_estimated_cost)
89    }
90}
91
92// The result of this implementation doesnt seem to matter - instead what matters, is that it is implemented
93impl Ord for Node {
94    fn cmp(&self, other: &Self) -> Ordering {
95        other
96            .total_estimated_cost
97            .partial_cmp(&self.total_estimated_cost)
98            .unwrap()
99    }
100}
101
102impl Eq for Node {}
103
104fn manhattan_heuristic(source: &Point2d, target: &Point2d) -> f32 {
105    (absdiff(source.x, target.x) + absdiff(source.y, target.y)) as f32
106}
107
108static SQRT_2_MINUS_2: f32 = SQRT_2 - 2.0;
109
110fn octal_heuristic(source: &Point2d, target: &Point2d) -> f32 {
111    let dx = absdiff(source.x, target.x);
112    let dy = absdiff(source.y, target.y);
113    let min = std::cmp::min(dx, dy);
114    dx as f32 + dy as f32 + SQRT_2_MINUS_2 * min as f32
115}
116
117fn euclidean_heuristic(source: &Point2d, target: &Point2d) -> f32 {
118    let x = source.x - target.x;
119    let xx = x * x;
120    let y = source.y - target.y;
121    let yy = y * y;
122    let _sum = xx + yy;
123    ((xx + yy) as f32).sqrt()
124}
125
126fn no_heuristic(_source: &Point2d, target: &Point2d) -> f32 {
127    0.0
128}
129
130fn construct_path(
131    source: &Point2d,
132    target: &Point2d,
133    nodes_map: &HashMap<Point2d, Point2d>,
134) -> Option<Vec<Point2d>> {
135    let mut path = vec![];
136    path.push(*target);
137    let mut pos = nodes_map.get(&target).unwrap();
138    while pos != source {
139        path.push(*pos);
140        pos = nodes_map.get(pos).unwrap();
141    }
142    path.push(*source);
143    path.reverse();
144    Some(path)
145}
146
147// https://doc.rust-lang.org/std/default/trait.Default.html
148// https://stackoverflow.com/questions/19650265/is-there-a-faster-shorter-way-to-initialize-variables-in-a-rust-struct
149struct PathFinder {
150    allow_diagonal: bool,
151    heuristic: String,
152    grid: Array2<u8>,
153    came_from_grid: Array2<u8>,
154}
155
156// https://medium.com/@nicholas.w.swift/easy-a-star-pathfinding-7e6689c7f7b2
157impl PathFinder {
158    fn update_grid(&mut self, grid: Array2<u8>) {
159        self.grid = grid;
160    }
161
162    fn find_path(&self, source: &Point2d, target: &Point2d) -> Option<Vec<Point2d>> {
163        let mut nodes_map = HashMap::new();
164        let mut closed_list = HashSet::new();
165
166        // Add source
167        let mut heap = BinaryHeap::new();
168        heap.push(Node {
169            cost_to_source: 0.0,
170            total_estimated_cost: 0.0,
171            position: *source,
172            came_from: *source,
173        });
174
175        let neighbors;
176        match self.allow_diagonal {
177            true => {
178                neighbors = vec![
179                    ((0, 1), 1.0, 1),
180                    ((1, 0), 1.0, 1),
181                    ((-1, 0), 1.0, 1),
182                    ((0, -1), 1.0, 1),
183                    ((1, 1), SQRT_2, 1),
184                    ((1, -1), SQRT_2, 1),
185                    ((-1, 1), SQRT_2, 1),
186                    ((-1, -1), SQRT_2, 1),
187                ]
188            }
189            false => {
190                neighbors = vec![
191                    ((0, 1), 1.0, 1),
192                    ((1, 0), 1.0, 1),
193                    ((-1, 0), 1.0, 1),
194                    ((0, -1), 1.0, 1),
195                ]
196            }
197        }
198
199        let heuristic: fn(&Point2d, &Point2d) -> f32;
200        match self.heuristic.as_ref() {
201            "manhattan" => heuristic = manhattan_heuristic,
202            "octal" => heuristic = octal_heuristic,
203            "euclidean" => heuristic = euclidean_heuristic,
204            "none" => heuristic = no_heuristic,
205            _ => heuristic = euclidean_heuristic,
206        }
207
208        while let Some(Node {
209            cost_to_source,
210            position,
211            came_from,
212            ..
213        }) = heap.pop()
214        {
215            // Already checked this position
216            if closed_list.contains(&position) {
217                continue;
218            }
219
220            nodes_map.insert(position, came_from);
221
222            if position == *target {
223                return construct_path(&source, &target, &nodes_map);
224            }
225
226            closed_list.insert(position);
227
228            for (neighbor, real_cost, _cost_estimate) in neighbors.iter() {
229                let new_node = position.add_neighbor(*neighbor);
230                // TODO add cost from grid
231                //  if grid point has value == 0 (or -1?): is wall
232
233                let new_cost_to_source = cost_to_source + *real_cost;
234
235                // Should perhaps check if position is already in open list, but doesnt matter
236                heap.push(Node {
237                    cost_to_source: new_cost_to_source,
238                    total_estimated_cost: new_cost_to_source + heuristic(&new_node, target),
239                    position: new_node,
240                    came_from: position,
241                });
242            }
243        }
244        None
245    }
246}
247
248pub fn grid_setup(size: usize) -> Array2<u8> {
249    // Set up a grid with size 'size' and make the borders a wall (value 1)
250    // https://stackoverflow.com/a/59043086/10882657
251    let mut ndarray = Array2::<u8>::ones((size, size));
252    // Set boundaries
253    for y in 0..size {
254        ndarray[[y, 0]] = 0;
255        ndarray[[y, size - 1]] = 0;
256    }
257    for x in 0..size {
258        ndarray[[0, x]] = 0;
259        ndarray[[size - 1, x]] = 0;
260    }
261    ndarray
262}
263
264pub fn read_grid_from_file(path: String) -> Result<(Array2<u8>, u32, u32), std::io::Error> {
265    let mut file = File::open(path)?;
266    //    let mut data = Vec::new();
267    let mut data = String::new();
268
269    file.read_to_string(&mut data)?;
270    let mut height = 0;
271    let mut width = 0;
272    // Create one dimensional vec
273    let mut my_vec = Vec::new();
274    for line in data.lines() {
275        width = line.len();
276        height += 1;
277        for char in line.chars() {
278            my_vec.push(char as u8 - 48);
279        }
280    }
281
282    let array = Array::from(my_vec).into_shape((height, width)).unwrap();
283    Ok((array, height as u32, width as u32))
284}
285
#[cfg(test)] // Only compiled when running tests / benchmarks
mod tests {
    use super::*;
    // `#[bench]` needs the nightly `test` crate. NOTE(review): presumably enabled
    // via `#![feature(test)]` / `extern crate test;` at the crate root.
    use test::Bencher;

    /// Builds a path finder over `grid` that allows diagonal moves and uses
    /// the Manhattan heuristic.
    fn astar_pf(grid: Array2<u8>) -> PathFinder {
        let came_from_grid = Array::zeros(grid.raw_dim());
        PathFinder {
            allow_diagonal: true,
            heuristic: String::from("manhattan"),
            grid,
            came_from_grid,
        }
    }

    /// Runs a single path search from `source` to `target`.
    fn astar_test(pf: &PathFinder, source: &Point2d, target: &Point2d) -> Option<Vec<Point2d>> {
        pf.find_path(source, target)
    }

    #[bench]
    fn bench_astar_test_from_file(b: &mut Bencher) {
        let (array, _height, _width) = read_grid_from_file(String::from("AutomatonLE.txt"))
            .expect("benchmark map AutomatonLE.txt must be readable");
        // Spawn to spawn
        let source = Point2d { x: 32, y: 51 };
        let target = Point2d { x: 150, y: 129 };
        let pf = astar_pf(array);
        b.iter(|| astar_test(&pf, &source, &target));
    }

    #[bench]
    fn bench_astar_test(b: &mut Bencher) {
        let grid = grid_setup(30);
        let pf = astar_pf(grid);
        let source = Point2d { x: 5, y: 5 };
        let target = Point2d { x: 10, y: 12 };
        b.iter(|| astar_test(&pf, &source, &target));
    }
}