// path_finding/search/dijkstra.rs

1use std::collections::{HashMap, HashSet};
2
3use ordered_float::NotNan;
4use priority_queue::DoublePriorityQueue;
5
6use crate::graph::{Edge, Graph};
7use crate::grid::{Direction, Grid};
8use crate::node::{Node, Vec3};
9use crate::path::PathFinding;
10
/// Field-less marker type for a breadth-first-search strategy.
///
/// NOTE(review): no `PathFinding` impl for this type appears in this file —
/// confirm it is implemented elsewhere, or remove it.
#[derive(Debug, Clone, Copy, Default)]
pub struct BreadthFirstSearch {}
12
/// Stateless [`PathFinding`] strategy that runs Dijkstra's algorithm, i.e. the
/// best-first search in this module driven by an always-zero heuristic.
#[derive(Debug, Clone, Copy, Default)]
pub struct Dijkstra {}
14
15pub(crate) fn dijkstra(source: Node,
16                       target: Node,
17                       graph: &Graph,
18                       heuristic: &dyn Fn(&Vec3, &Vec3) -> f32) -> Graph {
19    let mut visited: HashSet<usize> = HashSet::new();
20    let mut node_to_edges: HashMap<usize, Vec<Edge>> = HashMap::new();
21    let mut queue: DoublePriorityQueue<usize, NotNan<f32>> = DoublePriorityQueue::new();
22
23    queue.push(source.id, NotNan::new(0.0).unwrap());
24    node_to_edges.insert(source.id, Vec::new());
25
26    while !visited.contains(&target.id) && !queue.is_empty() {
27        let current = queue.pop_min().unwrap();
28        visited.insert(current.0);
29
30        if let Some(node) = graph.nodes_lookup.get(&current.0) {
31            for edge in &node.edges {
32                let dest_id = edge.destination;
33
34                if !visited.contains(&dest_id) {
35                    let mut cost = current.1 + edge.weight;
36
37                    if graph.position_is_set() {
38                        cost = cost + heuristic(graph.get_position(&edge.destination),
39                            graph.get_position(&target.id), );
40                    }
41
42                    queue.push(edge.destination, cost);
43
44                    let mut from_edges = node_to_edges.get(&current.0).unwrap_or(&Vec::new()).clone();
45                    from_edges.push(edge.clone());
46                    node_to_edges.insert(dest_id, from_edges);
47                }
48            }
49        }
50    }
51
52    return Graph::from(node_to_edges.get(&target.id).cloned().unwrap_or_default().into());
53}
54
55pub(crate) fn dijkstra_grid(source: (usize, usize),
56                            target: (usize, usize),
57                            grid: &Grid,
58                            directions: &[Direction],
59                            heuristic: &dyn Fn(&Vec3, &Vec3) -> f32) -> Graph {
60    let mut visited: HashSet<usize> = HashSet::new();
61    let mut node_to_edges: HashMap<usize, Vec<Edge>> = HashMap::new();
62    let mut queue: DoublePriorityQueue<usize, NotNan<f32>> = DoublePriorityQueue::new();
63
64    let src_id = grid.node_id(source);
65    let trg_id = grid.node_id(target);
66
67    queue.push(src_id, NotNan::new(0.0).unwrap());
68    node_to_edges.insert(src_id, Vec::new());
69
70    while !visited.contains(&trg_id) && !queue.is_empty() {
71        let current = queue.pop_min().unwrap();
72        visited.insert(current.0);
73
74        for direction in directions {
75            let dest_coord = direction.attempt_move(grid.coords(current.0));
76
77            if grid.outside(dest_coord) {
78                continue;
79            }
80
81            let dest_id = grid.node_id(dest_coord);
82
83            if !visited.contains(&dest_id) {
84                let cost = current.1 + grid.cost(dest_id) + heuristic(
85                    &Vec3::from(dest_coord.0 as f32, dest_coord.1 as f32, 0.0),
86                    &Vec3::from(target.0 as f32, target.1 as f32, 0.0),
87                );
88                queue.push(dest_id, cost);
89                let edge = Edge::from(dest_id, current.0, dest_id, grid.cost(dest_id));
90
91                let mut from_edges = node_to_edges.get(&current.0).unwrap_or(&Vec::new()).clone();
92                from_edges.push(edge);
93                node_to_edges.insert(dest_id, from_edges);
94            }
95        }
96    }
97
98    return Graph::from(node_to_edges.get(&trg_id).cloned().unwrap_or_default().into());
99}
100
101fn dijkstra_heuristic(_src: &Vec3, _dest: &Vec3) -> f32 {
102    return 0.0;
103}
104
105impl PathFinding for Dijkstra {
106    fn graph(&self, source: Node, target: Node, graph: &Graph) -> Graph {
107        return dijkstra(source, target, graph, &dijkstra_heuristic);
108    }
109
110    fn grid(&self, source: (usize, usize), target: (usize, usize), grid: &Grid, directions: &[Direction]) -> Graph {
111        return dijkstra_grid(source, target, grid, directions, &dijkstra_heuristic);
112    }
113}
114
#[test]
fn should_find_path_with_dijkstra_between_a_and_b() {
    // The route 0 -> 2 -> 1 (2.0 + 1.0) beats the direct edge 0 -> 1 (4.0).
    let graph = graph();

    let dij = Dijkstra {};
    let path = dij.graph(get_node(0, &graph), get_node(1, &graph), &graph);

    assert_eq!(3.0, calc_cost(&path.edges));
    assert_eq!(2, path.edges.len());
}
126
#[test]
fn should_find_path_with_dijkstra_between_a_and_c() {
    // The direct edge 0 -> 2 (weight 2.0) is already the cheapest route.
    let g = graph();
    let (start, goal) = (get_node(0, &g), get_node(2, &g));

    let found = Dijkstra {}.graph(start, goal, &g);

    assert_eq!(2.0, calc_cost(&found.edges));
    assert_eq!(1, found.edges.len());
}
138
#[test]
fn should_find_path_with_dijkstra_between_a_and_d() {
    // Expected route: 0 -> 2 -> 1 -> 3 (2.0 + 1.0 + 2.0).
    let g = graph();
    let (start, goal) = (get_node(0, &g), get_node(3, &g));

    let found = Dijkstra {}.graph(start, goal, &g);

    assert_eq!(5.0, calc_cost(&found.edges));
    assert_eq!(3, found.edges.len());
}
150
#[test]
fn should_find_path_with_dijkstra_between_a_and_e() {
    // Expected route: 0 -> 2 -> 1 -> 4 (2.0 + 1.0 + 3.0).
    let g = graph();
    let (start, goal) = (get_node(0, &g), get_node(4, &g));

    let found = Dijkstra {}.graph(start, goal, &g);

    assert_eq!(6.0, calc_cost(&found.edges));
    assert_eq!(3, found.edges.len());
}
162
#[test]
fn should_find_path_with_disjoint_graphs() {
    // Node 3 lies in a different component from node 0, so no path exists.
    let g = disjoint_graph();
    let (start, goal) = (get_node(0, &g), get_node(3, &g));

    let found = Dijkstra {}.graph(start, goal, &g);

    assert_eq!(0.0, calc_cost(&found.edges));
    assert_eq!(0, found.edges.len());
}
173
/// Directed test graph over nodes 0..=4 (the tests' "a".."e").
///
/// The costs asserted by the tests from node 0 are: to 1 = 3.0 (0->2->1),
/// to 2 = 2.0, to 3 = 5.0 (0->2->1->3), and to 4 = 6.0 (0->2->1->4).
#[cfg(test)]
fn graph() -> Graph {
    Graph::from(vec![
        Edge::from(0, 0, 1, 4.0),
        Edge::from(1, 0, 2, 2.0),
        Edge::from(2, 1, 2, 3.0),
        Edge::from(3, 1, 3, 2.0),
        Edge::from(4, 1, 4, 3.0),
        Edge::from(5, 2, 1, 1.0),
        Edge::from(6, 2, 3, 4.0),
        Edge::from(7, 2, 4, 5.0),
        Edge::from(8, 4, 3, 1.0),
    ])
}
188
/// Test graph with two unconnected edges, 0 -> 1 and 2 -> 3, so node 3 is not
/// reachable from node 0.
#[cfg(test)]
fn disjoint_graph() -> Graph {
    Graph::from(vec![
        Edge::from(0, 0, 1, 4.0),
        Edge::from(1, 2, 3, 2.0),
    ])
}
196
/// Returns a clone of the node with `id` from `graph`.
///
/// # Panics
/// Panics, naming the id, if `graph` has no such node.
#[cfg(test)]
fn get_node(id: usize, graph: &Graph) -> Node {
    graph
        .nodes_lookup
        .get(&id)
        .cloned()
        .unwrap_or_else(|| panic!("test graph has no node with id {}", id))
}
201
/// Sums the `weight` of every edge in `edges`, in order. Returns `0.0` for an
/// empty slice.
#[cfg(test)]
fn calc_cost(edges: &[Edge]) -> f32 {
    edges.iter().fold(0.0, |total, edge| total + edge.weight)
}