alg_grid/
two_dim.rs

1// This Source Code Form is subject to the terms of the Mozilla Public
2// License, v. 2.0. If a copy of the MPL was not distributed with this
3// file, You can obtain one at http://mozilla.org/MPL/2.0/.
4use core::ops::Add;
5use core::cmp::{max, min};
6use heapless::{
7    Vec,
8    binary_heap::{BinaryHeap, Max},
9    spsc::Queue,
10    consts::*
11};
12use map_vec::Map;
13use num_rational::BigRational;
14
/// A point on a two-dimensional integer grid.
///
/// The derived ordering is lexicographic: `x` is compared first, then `y`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct Point2D {
    /// The `x` coordinate (declared first, so it dominates the ordering).
    pub x: i32,
    /// The `y` coordinate.
    pub y: i32,
}
21
22impl Default for Point2D {
23    fn default() -> Self {
24        Self {
25            x: 0,
26            y: 0,
27        }
28    }
29}
30
31impl Add for Point2D {
32    type Output = Self;
33    fn add(self, other: Self) -> Self {
34        Self {
35            x: self.x + other.x,
36            y: self.y + other.y,
37        }
38    }
39}
40
41/// Trait for implementing a Grid in two dimensions
42pub trait Grid2D {
43    /// The dimensions of the Grid. The only method that must be defined.
44    fn dimensions(&self) -> Point2D;
45
46    /// The lower bound. Defaults to (0, 0).
47    fn lower_bound(&self) -> Point2D {
48        Default::default()
49    }
50
51    /// The upper bound. Defaults to the dimensions itself.
52    fn upper_bound(&self) -> Point2D {
53        let dim = self.dimensions();
54        let low = self.lower_bound();
55        Point2D {
56            x: low.x + dim.x,
57            y: low.y + dim.y,
58        }
59    }
60
61    /// Check if a point is in the bounds of the grid.
62    fn in_bounds(&self, point: Point2D) -> bool {
63        let low = self.lower_bound();
64        let upp = self.upper_bound();
65        low.x <= point.x && point.x < upp.x &&
66        low.y <= point.y && point.y < upp.y
67    }
68
69    /// Convert a point to an index.
70    /// 
71    /// Useful if you store the grid in a one-dimensional array.
72    fn point2d_to_index(&self, point: Point2D) -> usize {
73        let dim = self.dimensions();
74        (point.y * dim.x + point.x) as usize
75    }
76
77    /// Convert an index to a point.
78    /// 
79    /// Useful if you store the grid in a one-dimensional array.
80    fn index_to_point2d(&self, index: usize) -> Point2D {
81        let dim = self.dimensions();
82        let x = index as i32 % dim.x;
83        let y = index as i32 / dim.x;
84        Point2D {
85            x,
86            y,
87        }
88    }
89
90    /// Check if a point is traversable.
91    /// 
92    /// Defaults to always, so you may want to implement
93    /// this.
94    #[allow(unused_variables)]
95    fn is_opaque(&self, point: Point2D) -> bool {
96        true
97    }
98
99    /// Get all possible neighbors of the point, regardless if the
100    /// point or its neighbors is in bounds, opaque, or neither.
101    fn get_possible_neighbors(&self, point: Point2D) -> [Point2D; 8] {
102        [
103            point + Point2D {
104                x: 0,
105                y: 1,
106            },
107            point + Point2D {
108                x: 1,
109                y: 0,
110            },
111            point + Point2D {
112                x: 0,
113                y: -1,
114            },
115            point + Point2D {
116                x: -1,
117                y: 0,
118            },
119            point + Point2D {
120                x: 1,
121                y: 1,
122            },
123            point + Point2D {
124                x: 1,
125                y: -1,
126            },
127            point + Point2D {
128                x: -1,
129                y: -1,
130            },
131            point + Point2D {
132                x: -1,
133                y: 1,
134            },
135        ]
136    }
137
138    /// Check if two points are possible neighbors.
139    /// 
140    /// Does not check if either points are inbounds or non-opaque.
141    fn is_possible_neighbor(&self, p1: Point2D, p2: Point2D) -> bool {
142        self.get_possible_neighbors(p1)
143            .iter()
144            .any(|&x| x == p2)
145    }
146
147    /// Get the neighbors that is in bounds and not opaque.
148    fn get_neighbors(&self, point: Point2D) -> [Option<Point2D>; 8] {
149        let mut arr: [Option<Point2D>; 8] = [None; 8];
150        let possible_neighbors = self.get_possible_neighbors(point);
151        for (i, n) in possible_neighbors.iter().enumerate() {
152            if !self.is_opaque(*n) && self.in_bounds(*n) {
153                arr[i] = Some(*n);
154            }
155        }
156        arr
157    }
158
159    /// Check if two points are neighbors.
160    /// 
161    /// Checks if either points are inbounds or non-opaque.
162    fn is_neighbor(&self, p1: Point2D, p2: Point2D) -> bool {
163        self.get_neighbors(p1)
164            .iter()
165            .any(|&x| x == Some(p2))
166    }
167
168    /// Get the neighbors with the associated cost.
169    /// 
170    /// Defaults to all eight neighbors having a cost of 1.0
171    /// if the neighbor is valid.
172    /// 
173    /// If you want the diagonals to cost sqrt(2), reimplement
174    /// this method yourself.
175    fn get_neighbors_with_cost(&self, point: Point2D) -> [(Option<Point2D>, BigRational); 8] {
176        let mut arr: [(Option<Point2D>, BigRational); 8] = [
177            (None, BigRational::from_float(0.0).unwrap()),
178            (None, BigRational::from_float(0.0).unwrap()),
179            (None, BigRational::from_float(0.0).unwrap()),
180            (None, BigRational::from_float(0.0).unwrap()),
181            (None, BigRational::from_float(0.0).unwrap()),
182            (None, BigRational::from_float(0.0).unwrap()),
183            (None, BigRational::from_float(0.0).unwrap()),
184            (None, BigRational::from_float(0.0).unwrap()),
185        ];
186        let neighbors = self.get_neighbors(point);
187        for (i, n) in neighbors.iter().enumerate() {
188            let one = BigRational::from_float(1.0).unwrap();
189            if n != &None {
190                arr[i] = (*n, one);
191            }
192        }
193        arr
194    }
195}
196
/// Algorithms to get the distance of two points.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Distance2D {
    /// Straight-line distance: `sqrt(dx^2 + dy^2)`.
    Pythagoras,
    /// Sum of the per-axis distances: `dx + dy`.
    Manhattan,
    /// Largest of the per-axis distances: `max(dx, dy)`.
    Chebyshev,
}
203
204impl Distance2D {
205    pub fn distance(&self, start: Point2D, end: Point2D) -> f64 {
206        use Distance2D::*;
207        match self {
208            Pythagoras => {
209                let dx = (max(start.x, end.x) - min(start.x, end.x)) as f64;
210                let dy = (max(start.y, end.y) - min(start.y, end.y)) as f64;
211                libm::sqrt((dx * dx) + (dy * dy))
212            }
213            Manhattan => {
214                let dx = (max(start.x, end.x) - min(start.x, end.x)) as f64;
215                let dy = (max(start.y, end.y) - min(start.y, end.y)) as f64;
216                dx + dy
217            }
218            Chebyshev => {
219                let dx = (max(start.x, end.x) - min(start.x, end.x)) as f64;
220                let dy = (max(start.y, end.y) - min(start.y, end.y)) as f64;
221                if dx > dy {
222                    (dx - dy) + 1.0 * dy
223                } else {
224                    (dy - dx) + 1.0 * dx
225                }
226            }
227        }
228    }
229}
230
231pub fn breadth_first_search(graph: &dyn Grid2D, start: Point2D, goal: Point2D) -> Map<Point2D, Option<Point2D>> {
232    let mut frontier: Queue<Point2D, U16> = Queue::new();
233    frontier.enqueue(start).expect("Failed to enqueue");
234    let mut came_from = Map::<Point2D, Option<Point2D>>::new();
235    came_from.insert(start, None);
236    while !frontier.is_empty() {
237        let current = frontier.dequeue().unwrap();
238        if current == goal { break }
239        for next in &graph.get_neighbors(current) {
240            let next = next.unwrap();
241            if !came_from.contains_key(&next) {
242                frontier.enqueue(next).expect("Failed to enqueue");
243                came_from.insert(next, Some(current));
244            }
245        }
246    }
247    came_from
248}
249
250pub fn dijkstra_search(
251    graph: &dyn Grid2D,
252    start: Point2D,
253    goal: Point2D
254) -> (
255    Map<Point2D, Option<Point2D>>,
256    Map<Point2D, BigRational>,
257) {
258    let mut frontier: BinaryHeap<(BigRational, Point2D), U16, Max> = BinaryHeap::new();
259    
260    let zero1 = BigRational::from_float(0.0).unwrap();
261    let zero2 = BigRational::from_float(0.0).unwrap();
262    
263    frontier.push((zero1, start)).expect("fail to push to heap");
264    let mut came_from = Map::<Point2D, Option<Point2D>>::new();
265    let mut cost_so_far = Map::<Point2D, BigRational>::new();
266    
267    came_from.insert(start, None);
268    cost_so_far.insert(start, zero2);
269    
270    while !frontier.is_empty() {
271        let current = frontier.pop().unwrap();
272        let point = current.1;
273        
274        if current.1 == goal { break }
275        
276        for (next, cost) in &graph.get_neighbors_with_cost(point) {
277            let next = next.unwrap();
278            let new_cost1 = cost_so_far.get(&point).unwrap() + cost;
279            let new_cost2 = cost_so_far.get(&point).unwrap() + cost;
280            
281            if !cost_so_far.contains_key(&next) || new_cost1 < *cost_so_far.get(&next).unwrap() {
282                cost_so_far.insert(next, new_cost1);
283                
284                let priority = new_cost2;
285                
286                frontier.push((priority, next)).expect("fail to push to heap");
287                came_from.insert(next, Some(point));
288            }
289        }
290    }
291    (came_from, cost_so_far)
292}
293
294pub fn a_star_search(
295    graph: &dyn Grid2D,
296    start: Point2D,
297    goal: Point2D
298) -> (
299    Map<Point2D, Option<Point2D>>,
300    Map<Point2D, BigRational>,
301) {
302    use num_bigint::ToBigInt;
303    
304    let mut frontier: BinaryHeap<(BigRational, Point2D), U16, Max> = BinaryHeap::new();
305    
306    let zero1 = BigRational::from_float(0.0).unwrap();
307    let zero2 = BigRational::from_float(0.0).unwrap();
308    
309    frontier.push((zero1, start)).expect("fail to push to heap");
310    
311    let mut came_from = Map::<Point2D, Option<Point2D>>::new();
312    let mut cost_so_far = Map::<Point2D, BigRational>::new();
313    
314    came_from.insert(start, None);
315    cost_so_far.insert(start, zero2);
316    
317    while !frontier.is_empty() {
318        let current = frontier.pop().unwrap();
319        let point = current.1;
320        
321        if current.1 == goal { break }
322        
323        for (next, cost) in &graph.get_neighbors_with_cost(point) {
324            let next = next.unwrap();
325            
326            let new_cost1 = cost_so_far.get(&point).unwrap() + cost;
327            let new_cost2 = cost_so_far.get(&point).unwrap() + cost;
328            
329            if !cost_so_far.contains_key(&next) || new_cost1 < *cost_so_far.get(&next).unwrap() {
330                cost_so_far.insert(next, new_cost1);
331                
332                let h = BigRational::from_integer(
333                    heuristic(goal, next).to_bigint().unwrap()
334                );
335                let priority = new_cost2 + h;
336                
337                frontier.push((priority, next)).expect("fail to push to heap");
338                came_from.insert(next, Some(point));
339            }
340        }
341    }
342    (came_from, cost_so_far)
343}
344
345
346fn heuristic(p1: Point2D, p2: Point2D) -> i32 {
347    (p1.x - p2.x).abs() + (p1.y - p2.y).abs()
348}
349
350pub fn reconstruct_path(
351    came_from: Map<Point2D, Option<Point2D>>,
352    start: Point2D,
353    goal: Point2D
354) -> Vec<Point2D, U16> {
355    let mut current = goal;
356    let mut path = Vec::<Point2D, U16>::new();
357    while current != start {
358        path.push(current).expect("Cannot push to vector");
359        current = came_from.get(&current).unwrap().unwrap();
360    }
361    path.push(start).expect("Cannot push to vector");
362    #[cfg(feature = "reverse_path")]
363    let path = path.iter()
364        .cloned()
365        .rev()
366        .collect::<Vec<_, U16>>();
367    path
368}
369
#[cfg(test)]
mod tests {
    use super::*;

    /// A 4x3 grid without obstacles, used by the tests below.
    struct OpenGrid;

    impl Grid2D for OpenGrid {
        fn dimensions(&self) -> Point2D {
            Point2D { x: 4, y: 3 }
        }

        // Spelled out so these tests do not depend on the trait's default.
        fn is_opaque(&self, _point: Point2D) -> bool {
            false
        }
    }

    fn p(x: i32, y: i32) -> Point2D {
        Point2D { x, y }
    }

    #[test]
    fn point_arithmetic_and_default() {
        assert_eq!(p(1, 2) + p(-3, 5), p(-2, 7));
        assert_eq!(Point2D::default(), p(0, 0));
    }

    #[test]
    fn bounds_are_half_open() {
        let grid = OpenGrid;
        assert_eq!(grid.lower_bound(), p(0, 0));
        assert_eq!(grid.upper_bound(), p(4, 3));
        assert!(grid.in_bounds(p(0, 0)));
        assert!(grid.in_bounds(p(3, 2)));
        assert!(!grid.in_bounds(p(4, 0)));
        assert!(!grid.in_bounds(p(0, 3)));
        assert!(!grid.in_bounds(p(-1, 0)));
    }

    #[test]
    fn index_round_trip() {
        let grid = OpenGrid;
        for index in 0..12 {
            let point = grid.index_to_point2d(index);
            assert!(grid.in_bounds(point));
            assert_eq!(grid.point2d_to_index(point), index);
        }
        assert_eq!(grid.point2d_to_index(p(1, 2)), 9);
    }

    #[test]
    fn neighbors_respect_bounds() {
        let grid = OpenGrid;
        let corner = grid.get_neighbors(p(0, 0)).iter().filter(|n| n.is_some()).count();
        assert_eq!(corner, 3);
        let inner = grid.get_neighbors(p(1, 1)).iter().filter(|n| n.is_some()).count();
        assert_eq!(inner, 8);
        assert!(grid.is_neighbor(p(0, 0), p(1, 1)));
        assert!(!grid.is_neighbor(p(0, 0), p(2, 0)));
        assert!(grid.is_possible_neighbor(p(0, 0), p(-1, -1)));
        assert!(!grid.is_neighbor(p(0, 0), p(-1, -1)));
    }

    #[test]
    fn distances() {
        let (a, b) = (p(0, 0), p(3, -4));
        assert_eq!(Distance2D::Pythagoras.distance(a, b), 5.0);
        assert_eq!(Distance2D::Manhattan.distance(a, b), 7.0);
        assert_eq!(Distance2D::Chebyshev.distance(a, b), 4.0);
    }

    #[test]
    fn reconstruct_path_follows_predecessors() {
        let (start, mid, goal) = (p(0, 0), p(1, 1), p(2, 1));
        let mut came_from = Map::new();
        came_from.insert(start, None);
        came_from.insert(mid, Some(start));
        came_from.insert(goal, Some(mid));
        let path = reconstruct_path(came_from, start, goal);
        let expected = if cfg!(feature = "reverse_path") {
            [start, mid, goal]
        } else {
            [goal, mid, start]
        };
        assert_eq!(&path[..], &expected[..]);
    }
}