hierarchical_pathfinding/grid/dijkstra.rs

1use super::{Element, Path};
2use crate::{neighbors::Neighborhood, Point, PointMap, PointSet};
3
4use std::cmp::Ordering;
5use std::collections::BinaryHeap;
6
7pub fn dijkstra_search<N: Neighborhood>(
8    neighborhood: &N,
9    mut valid: impl FnMut(Point) -> bool,
10    mut get_cost: impl FnMut(Point) -> isize,
11    start: Point,
12    goals: &[Point],
13    only_closest_goal: bool,
14    size_hint: usize,
15) -> PointMap<Path<Point>> {
16    if get_cost(start) < 0 {
17        return PointMap::default();
18    }
19    let mut visited = PointMap::with_capacity(size_hint);
20    let mut next = BinaryHeap::with_capacity(size_hint / 2);
21    next.push(Element(start, 0));
22    visited.insert(start, (0, start));
23
24    let mut remaining_goals: PointSet = goals.iter().copied().collect();
25
26    let mut goal_costs = PointMap::with_capacity(goals.len());
27
28    let mut all_neighbors = vec![];
29
30    while let Some(Element(current_id, current_cost)) = next.pop() {
31        match current_cost.cmp(&visited[&current_id].0) {
32            Ordering::Greater => continue,
33            Ordering::Equal => {}
34            Ordering::Less => panic!("Binary Heap failed"),
35        }
36
37        if remaining_goals.remove(&current_id) {
38            goal_costs.insert(current_id, current_cost);
39            if only_closest_goal || remaining_goals.is_empty() {
40                break;
41            }
42        }
43
44        let delta_cost = get_cost(current_id);
45        if delta_cost < 0 {
46            continue;
47        }
48        let other_cost = current_cost + delta_cost as usize;
49
50        all_neighbors.clear();
51        neighborhood.get_all_neighbors(current_id, &mut all_neighbors);
52        for &other_id in all_neighbors.iter() {
53            if !valid(other_id) {
54                continue;
55            }
56            if get_cost(other_id) < 0 && !remaining_goals.contains(&other_id) {
57                continue;
58            }
59
60            let mut needs_visit = true;
61            if let Some((prev_cost, prev_id)) = visited.get_mut(&other_id) {
62                if *prev_cost > other_cost {
63                    *prev_cost = other_cost;
64                    *prev_id = current_id;
65                } else {
66                    needs_visit = false;
67                }
68            } else {
69                visited.insert(other_id, (other_cost, current_id));
70            }
71
72            if needs_visit {
73                next.push(Element(other_id, other_cost));
74            }
75        }
76    }
77
78    let mut goal_data = PointMap::with_capacity_and_hasher(goal_costs.len(), Default::default());
79
80    for (&goal, &cost) in goal_costs.iter() {
81        let steps = {
82            let mut steps = vec![];
83            let mut current = goal;
84
85            while current != start {
86                steps.push(current);
87                let (_, prev) = visited[&current];
88                current = prev;
89            }
90            steps.push(start);
91            steps.reverse();
92            steps
93        };
94        goal_data.insert(goal, Path::new(steps, cost));
95    }
96
97    goal_data
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    /// Cost of leaving a tile, indexed by tile type: 0 = empty, 1 = swamp, 2 = wall (impassable).
    const COST_MAP: [isize; 3] = [1, 10, -1];

    /// Tile types (see `COST_MAP`), indexed as `GRID[y][x]`.
    const GRID: [[usize; 5]; 5] = [
        [0, 2, 0, 0, 0],
        [0, 2, 2, 2, 2],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 2, 0],
        [0, 0, 0, 2, 0],
    ];

    fn cost_fn(grid: &[[usize; 5]; 5]) -> impl '_ + FnMut(Point) -> isize {
        move |(x, y)| COST_MAP[grid[y][x]]
    }

    /// Runs `dijkstra_search` over `GRID` with a Manhattan neighborhood and no extra filter.
    fn search(start: Point, goals: &[Point], only_closest_goal: bool) -> PointMap<Path<Point>> {
        use crate::prelude::*;

        // `GRID` is indexed `[y][x]`: a row spans the width, the number of rows is the height.
        let (width, height) = (GRID[0].len(), GRID.len());
        let neighborhood = ManhattanNeighborhood::new(width, height);

        dijkstra_search(
            &neighborhood,
            |_| true,
            cost_fn(&GRID),
            start,
            goals,
            only_closest_goal,
            40,
        )
    }

    #[test]
    fn basic() {
        let start = (0, 0);
        let goals = [(4, 4), (2, 0)];

        let paths = search(start, &goals, false);

        // (4, 4) is reachable by walking around the walls
        assert!(paths.contains_key(&goals[0]));

        // (2, 0) is walled off from the start
        assert!(!paths.contains_key(&goals[1]));
    }

    #[test]
    fn only_closest_goal_stops_at_first_goal() {
        // (0, 4) costs 4 to reach and (4, 4) costs 12, so only (0, 4) should be reported.
        let paths = search((0, 0), &[(4, 4), (0, 4)], true);
        assert_eq!(paths.len(), 1);
        assert!(paths.contains_key(&(0, 4)));
    }

    #[test]
    fn wall_goal_is_reached_but_not_passed_through() {
        // (1, 0) is a wall next to the start. It may end a path, but the search must not
        // continue through it into the walled-off (2, 0).
        let paths = search((0, 0), &[(1, 0), (2, 0)], false);
        assert!(paths.contains_key(&(1, 0)));
        assert!(!paths.contains_key(&(2, 0)));
    }

    #[test]
    fn no_goals_yields_no_paths() {
        assert!(search((0, 0), &[], false).is_empty());
    }

    #[test]
    fn start_on_wall_yields_no_paths() {
        assert!(search((1, 0), &[(0, 0)], false).is_empty());
    }
}