// grid_search/search.rs

1use best::BestMapNonEmpty;
2use config::*;
3use direction::*;
4use distance_map::*;
5use error::*;
6use grid::*;
7use grid_2d::*;
8use metadata::*;
9use num_traits::{One, Zero};
10use path::{self, PathNode};
11use std::cmp::Ordering;
12use std::collections::BinaryHeap;
13use std::ops::{Add, Sub};
14
15#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
16#[derive(Debug, Clone, Copy)]
17pub(crate) struct SearchNode<Cost> {
18    pub(crate) seen: u64,
19    pub(crate) visited: u64,
20    pub(crate) coord: Coord,
21    pub(crate) from_parent: Option<Direction>,
22    pub(crate) cost: Cost,
23}
24
25impl<Cost: Zero> SearchNode<Cost> {
26    fn new(coord: Coord) -> Self {
27        Self {
28            seen: 0,
29            visited: 0,
30            coord,
31            from_parent: None,
32            cost: Zero::zero(),
33        }
34    }
35}
36
37impl<Cost> PathNode for SearchNode<Cost> {
38    fn from_parent(&self) -> Option<Direction> {
39        self.from_parent
40    }
41    fn coord(&self) -> Coord {
42        self.coord
43    }
44}
45
/// An entry in a search priority queue.
///
/// `BinaryHeap` is a max-heap, so the ordering below is reversed: the entry
/// with the *smallest* `cost` compares as the greatest and is popped first.
/// Only `cost` takes part in comparisons; `node_index` is payload.
///
/// All comparison traits are derived from `Ord::cmp` so they always agree,
/// as the `Ord` contract requires. Incomparable costs (e.g. NaN) compare as
/// equal instead of panicking.
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub(crate) struct PriorityEntry<Cost: PartialOrd<Cost>> {
    /// Index of the node (in a node grid or distance-map grid) this entry refers to.
    pub(crate) node_index: usize,
    /// Priority of the entry; lower values are popped first.
    pub(crate) cost: Cost,
}

impl<Cost: PartialOrd<Cost>> PriorityEntry<Cost> {
    /// Creates an entry for `node_index` with priority `cost`.
    pub(crate) fn new(node_index: usize, cost: Cost) -> Self {
        Self { node_index, cost }
    }
}

impl<Cost: PartialOrd<Cost>> PartialEq for PriorityEntry<Cost> {
    fn eq(&self, other: &Self) -> bool {
        // Defined via `cmp` so `==` is reflexive and agrees with `Ord`,
        // which `Eq` promises even when `Cost` alone cannot (NaN).
        self.cmp(other) == Ordering::Equal
    }
}

impl<Cost: PartialOrd<Cost>> PartialOrd for PriorityEntry<Cost> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Canonical form: must match `Ord::cmp` exactly.
        Some(self.cmp(other))
    }
}

impl<Cost: PartialOrd<Cost>> Eq for PriorityEntry<Cost> {}

impl<Cost: PartialOrd<Cost>> Ord for PriorityEntry<Cost> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed (`other` vs `self`) so the max-heap yields the cheapest entry.
        other
            .cost
            .partial_cmp(&self.cost)
            .unwrap_or(Ordering::Equal)
    }
}
81
82#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
83#[derive(Debug, Clone)]
84pub struct SearchContext<Cost: PartialOrd<Cost>> {
85    pub(crate) seq: u64,
86    pub(crate) priority_queue: BinaryHeap<PriorityEntry<Cost>>,
87    pub(crate) node_grid: Grid<SearchNode<Cost>>,
88}
89
90impl<Cost: PartialOrd<Cost> + Zero> SearchContext<Cost> {
91    pub fn new(size: Size) -> Self {
92        Self {
93            seq: 0,
94            node_grid: Grid::new_fn(size, SearchNode::new),
95            priority_queue: BinaryHeap::new(),
96        }
97    }
98    pub fn width(&self) -> u32 {
99        self.node_grid.width()
100    }
101    pub fn height(&self) -> u32 {
102        self.node_grid.height()
103    }
104    pub fn size(&self) -> Size {
105        self.node_grid.size()
106    }
107}
108
109impl<Cost: Copy + Add<Cost> + PartialOrd<Cost> + Zero> SearchContext<Cost> {
110    pub(crate) fn init<G, F>(
111        &mut self,
112        start: Coord,
113        predicate: F,
114        grid: &G,
115        config: SearchConfig,
116        path: &mut Vec<Direction>,
117    ) -> Result<PriorityEntry<Cost>, Result<SearchMetadata<Cost>, Error>>
118    where
119        G: SolidGrid,
120        F: Fn(Coord) -> bool,
121    {
122        if let Some(solid) = grid.is_solid(start) {
123            let index = if let Some(index) = self.node_grid.index_of_coord(start) {
124                index
125            } else {
126                return Err(Err(Error::VisitOutsideContext));
127            };
128
129            if solid && !config.allow_solid_start {
130                return Err(Err(Error::StartSolid));
131            };
132
133            if predicate(start) {
134                path.clear();
135                return Err(Ok(SearchMetadata {
136                    num_nodes_visited: 0,
137                    cost: Zero::zero(),
138                    length: 0,
139                }));
140            }
141
142            self.seq += 1;
143            self.priority_queue.clear();
144
145            let node = &mut self.node_grid[index];
146            node.from_parent = None;
147            node.seen = self.seq;
148            node.cost = Zero::zero();
149
150            Ok(PriorityEntry::new(index, Zero::zero()))
151        } else {
152            Err(Err(Error::StartOutsideGrid))
153        }
154    }
155
156    pub(crate) fn search_general<G, V, D, H>(
157        &mut self,
158        grid: &G,
159        start: Coord,
160        goal: Coord,
161        directions: D,
162        heuristic_fn: H,
163        config: SearchConfig,
164        path: &mut Vec<Direction>,
165    ) -> Result<SearchMetadata<Cost>, Error>
166    where
167        G: CostGrid<Cost = Cost>,
168        V: Into<Direction>,
169        D: Copy + IntoIterator<Item = V>,
170        H: Fn(Coord, Coord) -> Cost,
171    {
172        let initial_entry = match self.init(start, |c| c == goal, grid, config, path) {
173            Ok(initial_entry) => initial_entry,
174            Err(result) => return result,
175        };
176
177        self.priority_queue.push(initial_entry);
178
179        let goal_index = self
180            .node_grid
181            .index_of_coord(goal)
182            .ok_or(Error::VisitOutsideContext)?;
183
184        let mut num_nodes_visited = 0;
185
186        while let Some(current_entry) = self.priority_queue.pop() {
187            num_nodes_visited += 1;
188
189            if current_entry.node_index == goal_index {
190                let node = &self.node_grid[goal_index];
191
192                path::make_path_all_adjacent(&self.node_grid, goal_index, path);
193                return Ok(SearchMetadata {
194                    num_nodes_visited,
195                    cost: node.cost,
196                    length: path.len(),
197                });
198            }
199
200            let (current_coord, current_cost) = {
201                let node = &mut self.node_grid[current_entry.node_index];
202                if node.visited == self.seq {
203                    continue;
204                }
205                node.visited = self.seq;
206                (node.coord, node.cost)
207            };
208
209            for d in directions {
210                let direction = d.into();
211                let neighbour_coord = current_coord + direction.coord();
212
213                let neighbour_cost =
214                    if let Some(CostCell::Cost(cost)) = grid.cost(neighbour_coord, direction) {
215                        cost
216                    } else {
217                        continue;
218                    };
219
220                self.see_successor(
221                    current_cost + neighbour_cost,
222                    neighbour_coord,
223                    direction,
224                    &heuristic_fn,
225                    goal,
226                )?;
227            }
228        }
229
230        Err(Error::NoPath)
231    }
232
233    pub(crate) fn see_successor<H>(
234        &mut self,
235        cost: Cost,
236        successor_coord: Coord,
237        direction: Direction,
238        heuristic_fn: H,
239        goal: Coord,
240    ) -> Result<(), Error>
241    where
242        H: Fn(Coord, Coord) -> Cost,
243    {
244        let index = self
245            .node_grid
246            .index_of_coord(successor_coord)
247            .ok_or(Error::VisitOutsideContext)?;
248
249        let node = &mut self.node_grid[index];
250
251        if node.seen != self.seq || node.cost > cost {
252            node.from_parent = Some(direction);
253            node.seen = self.seq;
254            node.cost = cost;
255
256            let heuristic = cost + heuristic_fn(successor_coord, goal);
257            let entry = PriorityEntry::new(index, heuristic);
258            self.priority_queue.push(entry);
259        }
260
261        Ok(())
262    }
263}
264
265impl<Cost> SearchContext<Cost>
266where
267    Cost: Copy + Add + PartialOrd + Zero + One,
268{
269    pub fn populate_distance_map<G, V, D>(
270        &mut self,
271        grid: &G,
272        start: Coord,
273        directions: D,
274        config: SearchConfig,
275        distance_map: &mut DistanceMap<Cost>,
276    ) -> Result<DistanceMapMetadata, Error>
277    where
278        G: CostGrid<Cost = Cost>,
279        V: Into<Direction>,
280        D: Copy + IntoIterator<Item = V>,
281    {
282        if let Some(solid) = grid.is_solid(start) {
283            if solid && !config.allow_solid_start {
284                return Err(Error::StartSolid);
285            };
286
287            let index = distance_map
288                .grid
289                .index_of_coord(start)
290                .ok_or(Error::VisitOutsideDistanceMap)?;
291
292            self.priority_queue.clear();
293            self.priority_queue
294                .push(PriorityEntry::new(index, Zero::zero()));
295
296            distance_map.seq += 1;
297            distance_map.origin = start;
298            let cell = &mut distance_map.grid[index];
299            cell.seen = distance_map.seq;
300            cell.cost = Zero::zero();
301        } else {
302            return Err(Error::StartOutsideGrid);
303        }
304
305        let mut num_nodes_visited = 0;
306
307        while let Some(current_entry) = self.priority_queue.pop() {
308            num_nodes_visited += 1;
309
310            let (current_coord, current_cost) = {
311                let cell = &mut distance_map.grid[current_entry.node_index];
312                if cell.visited == distance_map.seq {
313                    continue;
314                }
315                cell.visited = distance_map.seq;
316                (cell.coord, cell.cost)
317            };
318
319            for d in directions {
320                let direction = d.into();
321                let neighbour_coord = current_coord + direction.coord();
322
323                let neighbour_cost =
324                    if let Some(CostCell::Cost(cost)) = grid.cost(neighbour_coord, direction) {
325                        cost
326                    } else {
327                        continue;
328                    };
329
330                let cost = current_cost + neighbour_cost;
331
332                let index = distance_map
333                    .grid
334                    .index_of_coord(neighbour_coord)
335                    .ok_or(Error::VisitOutsideDistanceMap)?;
336
337                let cell = &mut distance_map.grid[index];
338
339                if cell.seen != distance_map.seq || cell.cost > cost {
340                    cell.direction = direction.opposite();
341                    cell.seen = distance_map.seq;
342                    cell.cost = cost;
343
344                    let entry = PriorityEntry::new(index, cost);
345                    self.priority_queue.push(entry);
346                }
347            }
348        }
349
350        Ok(DistanceMapMetadata { num_nodes_visited })
351    }
352}
353
354impl<Cost> SearchContext<Cost>
355where
356    Cost: Copy + Add + PartialOrd + Zero + One + Sub<Output = Cost>,
357{
358    pub fn best_search_uniform_distance_map<G, V, D>(
359        &mut self,
360        grid: &G,
361        start: Coord,
362        config: SearchConfig,
363        max_depth: Cost,
364        distance_map: &UniformDistanceMap<Cost, D>,
365        path: &mut Vec<Direction>,
366    ) -> Result<SearchMetadata<Cost>, Error>
367    where
368        G: SolidGrid,
369        V: Into<Direction>,
370        D: Copy + IntoIterator<Item = V>,
371    {
372        let mut initial_entry =
373            match self.init(start, |_| max_depth == Zero::zero(), grid, config, path) {
374                Ok(initial_entry) => initial_entry,
375                Err(result) => return result,
376            };
377
378        initial_entry.cost = distance_map
379            .cost(start)
380            .ok_or(Error::InconsistentDistanceMap)?;
381
382        let mut best_map = BestMapNonEmpty::new(initial_entry.cost, initial_entry.node_index);
383        self.priority_queue.push(initial_entry);
384
385        let mut num_nodes_visited = 0;
386
387        while let Some(current_entry) = self.priority_queue.pop() {
388            num_nodes_visited += 1;
389
390            let (current_coord, current_depth) = {
391                let node = &self.node_grid[current_entry.node_index];
392                (node.coord, node.cost)
393            };
394
395            if current_depth >= max_depth {
396                continue;
397            }
398
399            let remaining_depth = max_depth - current_depth;
400            if *best_map.key() + remaining_depth <= current_entry.cost {
401                continue;
402            }
403
404            let next_depth = current_depth + One::one();
405
406            for v in distance_map.directions {
407                let direction = v.into();
408                let offset: Coord = direction.coord();
409                let neighbour_coord = current_coord + offset;
410
411                if let Some(false) = grid.is_solid(neighbour_coord) {
412                } else {
413                    continue;
414                }
415
416                let cost = distance_map
417                    .cost(neighbour_coord)
418                    .ok_or(Error::InconsistentDistanceMap)?;
419
420                let index = self
421                    .node_grid
422                    .index_of_coord(neighbour_coord)
423                    .ok_or(Error::VisitOutsideContext)?;
424
425                {
426                    let node = &mut self.node_grid[index];
427                    if node.seen != self.seq {
428                        node.seen = self.seq;
429                        node.from_parent = Some(direction);
430                        node.cost = next_depth;
431                        self.priority_queue.push(PriorityEntry::new(index, cost));
432                    }
433                }
434
435                best_map.insert_lt(cost, index);
436            }
437        }
438
439        let (cost, index) = best_map.into_key_and_value();
440        path::make_path_all_adjacent(&self.node_grid, index, path);
441        let length = path.len();
442        Ok(SearchMetadata {
443            num_nodes_visited,
444            length,
445            cost,
446        })
447    }
448}