alg_grid/
three_dim.rs

1// This Source Code Form is subject to the terms of the Mozilla Public
2// License, v. 2.0. If a copy of the MPL was not distributed with this
3// file, You can obtain one at http://mozilla.org/MPL/2.0/.
4use core::ops::Add;
5use core::cmp::{max, min};
6use heapless::{
7    Vec,
8    binary_heap::{BinaryHeap, Max},
9    spsc::Queue,
10    consts::*
11};
12use map_vec::Map;
13use num_rational::BigRational;
14
/// A point on a three-dimensional integer lattice.
///
/// The derived ordering compares `x` first, then `y`, then `z`, following
/// the order in which the fields are declared.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct Point3D {
    /// Coordinate along the x axis.
    pub x: i32,
    /// Coordinate along the y axis.
    pub y: i32,
    /// Coordinate along the z axis.
    pub z: i32,
}
22
23impl Default for Point3D {
24    fn default() -> Self {
25        Self {
26            x: 0,
27            y: 0,
28            z: 0,
29        }
30    }
31}
32
33impl Add for Point3D {
34    type Output = Self;
35    fn add(self, other: Self) -> Self {
36        Self {
37            x: self.x + other.x,
38            y: self.y + other.y,
39            z: self.z + other.z,
40        }
41    }
42}
43
44/// Trait for implementing a Grid in three dimensions
45pub trait Grid3D {
46    /// The dimensions of the Grid. The only method that must be defined.
47    fn dimensions(&self) -> Point3D;
48
49    /// The lower bound. Defaults to (0, 0, 0).
50    fn lower_bound(&self) -> Point3D {
51        Default::default()
52    }
53
54    /// The upper bound. Defaults to the dimensions itself.
55    fn upper_bound(&self) -> Point3D {
56        let dim = self.dimensions();
57        let low = self.lower_bound();
58        Point3D {
59            x: low.x + dim.x,
60            y: low.y + dim.y,
61            z: low.z + dim.z
62        }
63    }
64
65    /// Check if a point is in the bounds of the grid.
66    fn in_bounds(&self, point: Point3D) -> bool {
67        let low = self.lower_bound();
68        let upp = self.upper_bound();
69        low.x <= point.x && point.x < upp.x &&
70        low.y <= point.y && point.y < upp.y
71    }
72
73    /// Convert a point to an index.
74    /// 
75    /// Useful if you store the grid in a one-dimensional array.
76    fn point3d_to_index(&self, point: Point3D) -> usize {
77        let dim = self.dimensions();
78        (
79            point.z * dim.y * dim.x +
80            point.y * dim.x +
81            point.x
82        ) as usize
83    }
84
85    /// Convert an index to a point.
86    /// 
87    /// Useful if you store the grid in a one-dimensional array.
88    fn index_to_point3d(&self, index: usize) -> Point3D {
89        let dim = self.dimensions();
90        let mut idx = index as i32;
91        let z = idx / (dim.x * dim.y);
92        idx -= z * dim.x * dim.y;
93        let y = idx / dim.x;
94        let x = idx % dim.x;
95        Point3D {
96            x,
97            y,
98            z,
99        }
100    }
101
102    /// Check if a point is traversable.
103    /// 
104    /// Defaults to always, so you may want to implement
105    /// this.
106    #[allow(unused_variables)]
107    fn is_opaque(&self, point: Point3D) -> bool {
108        true
109    }
110
111    /// Get all possible neighbors of the point, regardless if the
112    /// point or its neighbors is in bounds, opaque, or neither.
113    fn get_possible_neighbors(&self, point: Point3D) -> [Point3D; 26] {
114        [
115            // centers
116            point + Point3D {
117                x: 0,
118                y: 1,
119                z: 0,
120            },
121            point + Point3D {
122                x: 0,
123                y: -1,
124                z: 0,
125            },
126            point + Point3D {
127                x: 1,
128                y: 0,
129                z: 0,
130            },
131            point + Point3D {
132                x: -1,
133                y: 0,
134                z: 0,
135            },
136            point + Point3D {
137                x: 0,
138                y: 0,
139                z: 1,
140            },
141            point + Point3D {
142                x: 0,
143                y: 0,
144                z: -1,
145            },
146            // sides
147            point + Point3D {
148                x: 1,
149                y: 1,
150                z: 0,
151            },
152            point + Point3D {
153                x: 1,
154                y: -1,
155                z: 0,
156            },
157            point + Point3D {
158                x: -1,
159                y: -1,
160                z: 0,
161            },
162            point + Point3D {
163                x: -1,
164                y: 0,
165                z: 0,
166            },
167            point + Point3D {
168                x: 0,
169                y: 1,
170                z: 1,
171            },
172            point + Point3D {
173                x: 0,
174                y: 1,
175                z: -1,
176            },
177            point + Point3D {
178                x: 0,
179                y: -1,
180                z: -1,
181            },
182            point + Point3D {
183                x: 0,
184                y: -1,
185                z: 1,
186            },
187            point + Point3D {
188                x: 1,
189                y: 0,
190                z: 1,
191            },
192            point + Point3D {
193                x: 1,
194                y: 0,
195                z: -1,
196            },
197            point + Point3D {
198                x: -1,
199                y: 0,
200                z: -1,
201            },
202            point + Point3D {
203                x: -1,
204                y: 0,
205                z: 1,
206            },
207            // corners
208            point + Point3D {
209                x: 1,
210                y: 1,
211                z: 1,
212            },
213            point + Point3D {
214                x: 1,
215                y: -1,
216                z: 1,
217            },
218            point + Point3D {
219                x: -1,
220                y: -1,
221                z: 1,
222            },
223            point + Point3D {
224                x: -1,
225                y: 1,
226                z: 1,
227            },
228            point + Point3D {
229                x: 1,
230                y: 1,
231                z: -1,
232            },
233            point + Point3D {
234                x: 1,
235                y: -1,
236                z: -1,
237            },
238            point + Point3D {
239                x: -1,
240                y: -1,
241                z: -1,
242            },
243            point + Point3D {
244                x: -1,
245                y: 1,
246                z: -1,
247            },
248        ]
249    }
250
251    /// Check if two points are possible neighbors.
252    /// 
253    /// Does not check if either points are inbounds or non-opaque.
254    fn is_possible_neighbor(&self, p1: Point3D, p2: Point3D) -> bool {
255        self.get_possible_neighbors(p1)
256            .iter()
257            .any(|&x| x == p2)
258    }
259
260    /// Get the neighbors that is in bounds and not opaque.
261    fn get_neighbors(&self, point: Point3D) -> [Option<Point3D>; 26] {
262        let mut arr: [Option<Point3D>; 26] = [None; 26];
263        let possible_neighbors = self.get_possible_neighbors(point);
264        for (i, n) in possible_neighbors.iter().enumerate() {
265            if !self.is_opaque(*n) && self.in_bounds(*n) {
266                arr[i] = Some(*n);
267            }
268        }
269        arr
270    }
271
272    /// Check if two points are neighbors.
273    /// 
274    /// Checks if either points are inbounds or non-opaque.
275    fn is_neighbor(&self, p1: Point3D, p2: Point3D) -> bool {
276        self.get_neighbors(p1)
277            .iter()
278            .any(|&x| x == Some(p2))
279    }
280
281    /// Get the neighbors with the associated cost.
282    /// 
283    /// Defaults to all eight neighbors having a cost of 1.0
284    /// if the neighbor is valid.
285    /// 
286    /// If you want the diagonals to cost sqrt(2), reimplement
287    /// this method yourself.
288    fn get_neighbors_with_cost(&self, point: Point3D) -> [(Option<Point3D>, BigRational); 26] {
289        let mut arr: [(Option<Point3D>, BigRational); 26] = [
290            (None, BigRational::from_float(0.0).unwrap()),
291            (None, BigRational::from_float(0.0).unwrap()),
292            (None, BigRational::from_float(0.0).unwrap()),
293            (None, BigRational::from_float(0.0).unwrap()),
294            (None, BigRational::from_float(0.0).unwrap()),
295            (None, BigRational::from_float(0.0).unwrap()),
296            (None, BigRational::from_float(0.0).unwrap()),
297            (None, BigRational::from_float(0.0).unwrap()),
298            (None, BigRational::from_float(0.0).unwrap()),
299            (None, BigRational::from_float(0.0).unwrap()),
300            (None, BigRational::from_float(0.0).unwrap()),
301            (None, BigRational::from_float(0.0).unwrap()),
302            (None, BigRational::from_float(0.0).unwrap()),
303            (None, BigRational::from_float(0.0).unwrap()),
304            (None, BigRational::from_float(0.0).unwrap()),
305            (None, BigRational::from_float(0.0).unwrap()),
306            (None, BigRational::from_float(0.0).unwrap()),
307            (None, BigRational::from_float(0.0).unwrap()),
308            (None, BigRational::from_float(0.0).unwrap()),
309            (None, BigRational::from_float(0.0).unwrap()),
310            (None, BigRational::from_float(0.0).unwrap()),
311            (None, BigRational::from_float(0.0).unwrap()),
312            (None, BigRational::from_float(0.0).unwrap()),
313            (None, BigRational::from_float(0.0).unwrap()),
314            (None, BigRational::from_float(0.0).unwrap()),
315            (None, BigRational::from_float(0.0).unwrap()),
316        ];
317        let neighbors = self.get_neighbors(point);
318        for (i, n) in neighbors.iter().enumerate() {
319            let one = BigRational::from_float(1.0).unwrap();
320            if n != &None {
321                arr[i] = (*n, one);
322            }
323        }
324        arr
325    }
326}
327
/// Algorithms to get the distance of two points.
///
/// Despite the `2D` in the name, `distance` operates on `Point3D` values
/// and takes all three axes into account.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Distance2D {
    /// Straight-line (Euclidean) distance.
    Pythagoras,
    /// Sum of the absolute per-axis differences.
    Manhattan,
    /// Largest absolute per-axis difference.
    Chebyshev,
}
334
335impl Distance2D {
336    pub fn distance(&self, start: Point3D, end: Point3D) -> f64 {
337        use Distance2D::*;
338        match self {
339            Pythagoras => {
340                let dx = (max(start.x, end.x) - min(start.x, end.x)) as f64;
341                let dy = (max(start.y, end.y) - min(start.y, end.y)) as f64;
342                let dz = (max(start.z, end.z) - min(start.z, end.z)) as f64;
343                libm::sqrt((dx * dx) + (dy * dy) + (dz * dz))
344            }
345            Manhattan => {
346                let dx = (max(start.x, end.x) - min(start.x, end.x)) as f64;
347                let dy = (max(start.y, end.y) - min(start.y, end.y)) as f64;
348                let dz = (max(start.z, end.z) - min(start.z, end.z)) as f64;
349                dx + dy + dz
350            }
351            Chebyshev => {
352                let dx = (max(start.x, end.x) - min(start.x, end.x)) as f64;
353                let dy = (max(start.y, end.y) - min(start.y, end.y)) as f64;
354                let dz = (max(start.z, end.z) - min(start.z, end.z)) as f64;
355                libm::sqrt((dx * dx) + (dy * dy) + (dz * dz))
356            }
357        }
358    }
359}
360
361pub fn breadth_first_search(graph: &dyn Grid3D, start: Point3D, goal: Point3D) -> Map<Point3D, Option<Point3D>> {
362    let mut frontier: Queue<Point3D, U16> = Queue::new();
363    frontier.enqueue(start).expect("Failed to enqueue");
364    let mut came_from = Map::<Point3D, Option<Point3D>>::new();
365    came_from.insert(start, None);
366    while !frontier.is_empty() {
367        let current = frontier.dequeue().unwrap();
368        if current == goal { break }
369        for next in &graph.get_neighbors(current) {
370            let next = next.unwrap();
371            if !came_from.contains_key(&next) {
372                frontier.enqueue(next).expect("Failed to enqueue");
373                came_from.insert(next, Some(current));
374            }
375        }
376    }
377    came_from
378}
379
380pub fn dijkstra_search(
381    graph: &dyn Grid3D,
382    start: Point3D,
383    goal: Point3D
384) -> (
385    Map<Point3D, Option<Point3D>>,
386    Map<Point3D, BigRational>,
387) {
388    let mut frontier: BinaryHeap<(BigRational, Point3D), U16, Max> = BinaryHeap::new();
389    
390    let zero1 = BigRational::from_float(0.0).unwrap();
391    let zero2 = BigRational::from_float(0.0).unwrap();
392    
393    frontier.push((zero1, start)).expect("fail to push to heap");
394    let mut came_from = Map::<Point3D, Option<Point3D>>::new();
395    let mut cost_so_far = Map::<Point3D, BigRational>::new();
396    
397    came_from.insert(start, None);
398    cost_so_far.insert(start, zero2);
399    
400    while !frontier.is_empty() {
401        let current = frontier.pop().unwrap();
402        let point = current.1;
403        
404        if current.1 == goal { break }
405        
406        for (next, cost) in &graph.get_neighbors_with_cost(point) {
407            let next = next.unwrap();
408            let new_cost1 = cost_so_far.get(&point).unwrap() + cost;
409            let new_cost2 = cost_so_far.get(&point).unwrap() + cost;
410            
411            if !cost_so_far.contains_key(&next) || new_cost1 < *cost_so_far.get(&next).unwrap() {
412                cost_so_far.insert(next, new_cost1);
413                
414                let priority = new_cost2;
415                
416                frontier.push((priority, next)).expect("fail to push to heap");
417                came_from.insert(next, Some(point));
418            }
419        }
420    }
421    (came_from, cost_so_far)
422}
423
424pub fn a_star_search(
425    graph: &dyn Grid3D,
426    start: Point3D,
427    goal: Point3D
428) -> (
429    Map<Point3D, Option<Point3D>>,
430    Map<Point3D, BigRational>,
431) {
432    use num_bigint::ToBigInt;
433    
434    let mut frontier: BinaryHeap<(BigRational, Point3D), U16, Max> = BinaryHeap::new();
435    
436    let zero1 = BigRational::from_float(0.0).unwrap();
437    let zero2 = BigRational::from_float(0.0).unwrap();
438    
439    frontier.push((zero1, start)).expect("fail to push to heap");
440    
441    let mut came_from = Map::<Point3D, Option<Point3D>>::new();
442    let mut cost_so_far = Map::<Point3D, BigRational>::new();
443    
444    came_from.insert(start, None);
445    cost_so_far.insert(start, zero2);
446    
447    while !frontier.is_empty() {
448        let current = frontier.pop().unwrap();
449        let point = current.1;
450        
451        if current.1 == goal { break }
452        
453        for (next, cost) in &graph.get_neighbors_with_cost(point) {
454            let next = next.unwrap();
455            
456            let new_cost1 = cost_so_far.get(&point).unwrap() + cost;
457            let new_cost2 = cost_so_far.get(&point).unwrap() + cost;
458            
459            if !cost_so_far.contains_key(&next) || new_cost1 < *cost_so_far.get(&next).unwrap() {
460                cost_so_far.insert(next, new_cost1);
461                
462                let h = BigRational::from_integer(
463                    heuristic(goal, next).to_bigint().unwrap()
464                );
465                let priority = new_cost2 + h;
466                
467                frontier.push((priority, next)).expect("fail to push to heap");
468                came_from.insert(next, Some(point));
469            }
470        }
471    }
472    (came_from, cost_so_far)
473}
474
475
476fn heuristic(p1: Point3D, p2: Point3D) -> i32 {
477    (p1.x - p2.x).abs() + (p1.y - p2.y).abs()
478}
479
480pub fn reconstruct_path(
481    came_from: Map<Point3D, Option<Point3D>>,
482    start: Point3D,
483    goal: Point3D
484) -> Vec<Point3D, U16> {
485    let mut current = goal;
486    let mut path = Vec::<Point3D, U16>::new();
487    while current != start {
488        path.push(current).expect("Cannot push to vector");
489        current = came_from.get(&current).unwrap().unwrap();
490    }
491    path.push(start).expect("Cannot push to vector");
492    #[cfg(feature = "reverse_path")]
493    let path = path.iter()
494        .cloned()
495        .rev()
496        .collect::<Vec<_, U16>>();
497    path
498}
499
500// TODO: Create some tests
#[cfg(test)]
mod tests {
    /// Placeholder check that only confirms the test harness runs.
    #[test]
    fn it_works() {
        let sum = 2 + 2;
        assert_eq!(sum, 4);
    }
}
507}