1use crate::path::Path;
2use crate::{DiscoveredSet, Graph, NodeCost};
3use std::collections::{HashMap, HashSet};
4use std::hash::Hash;
5use std::ops::Add;
6
7pub struct AStar<'a, Node, Cost, DS, G>
9where
10 Node: Copy + Eq + Hash,
11 Cost: Copy + Ord + Default + Add<Cost, Output = Cost>,
12 DS: DiscoveredSet<Node, Cost>,
13 G: Graph<Node, Cost>,
14{
15 graph: &'a G,
16 discovered: DS,
17 closed: HashSet<Node>,
18 previous: HashMap<Node, Node>,
19 g_scores: HashMap<Node, Cost>,
20 neighbors: Vec<NodeCost<Node, Cost>>,
21 path: Path<Node, Cost>,
22}
23
24impl<'a, Node, Cost, DS, G> AStar<'a, Node, Cost, DS, G>
25where
26 Node: Copy + Eq + Hash,
27 Cost: Copy + Ord + Default + Add<Cost, Output = Cost>,
28 DS: DiscoveredSet<Node, Cost>,
29 G: Graph<Node, Cost>,
30{
31 pub fn new(graph: &'a G, discovered: DS) -> Self {
35 Self {
36 graph,
37 discovered,
38 closed: HashSet::new(),
39 previous: HashMap::new(),
40 g_scores: HashMap::new(),
41 neighbors: Vec::new(),
42 path: Path::new(0),
43 }
44 }
45}
46
47impl<'a, Node, Cost, DS, G> AStar<'a, Node, Cost, DS, G>
48where
49 Node: Copy + Eq + Hash,
50 Cost: Copy + Ord + Default + Add<Cost, Output = Cost>,
51 DS: DiscoveredSet<Node, Cost>,
52 G: Graph<Node, Cost>,
53{
54 pub fn search(&mut self, start: Node, goal: Node) -> Option<&Path<Node, Cost>> {
58 self.reset();
59
60 let h_score: Cost = self.graph.heuristic(start, goal);
61 self.discovered.push(NodeCost::new(start, h_score));
62 self.g_scores.insert(start, Cost::default());
63
64 while let Some(current) = self.next() {
65 if current == goal {
66 return Some(self.reconstruct_path(current));
67 }
68
69 if !self.closed.insert(current) {
71 continue;
72 }
73
74 let current_g_score: Cost = self
75 .g_scores
76 .get(¤t)
77 .copied()
78 .expect("invalid impl: there must be a g-score for discovered nodes");
79
80 self.neighbors.clear();
81 self.graph.neighbors(current, &mut self.neighbors);
82 for neighbor in &self.neighbors {
83 let (neighbor, edge_cost): (Node, Cost) = (neighbor.node, neighbor.cost);
84 let g_score: Cost = current_g_score + edge_cost;
85 if self.is_best_g_score(neighbor, g_score) {
86 self.previous.insert(neighbor, current);
87 self.g_scores.insert(neighbor, g_score);
88
89 let h_score: Cost = self.graph.heuristic(neighbor, goal);
90 let f_score: Cost = g_score + h_score;
91 self.discovered.push(NodeCost::new(neighbor, f_score));
92 }
93 }
94 }
95
96 None
97 }
98
99 fn reset(&mut self) {
101 self.discovered.clear();
102 self.closed.clear();
103 self.previous.clear();
104 self.g_scores.clear();
105 }
106
107 fn next(&mut self) -> Option<Node> {
109 self.discovered
110 .pop()
111 .map(|node_cost: NodeCost<Node, Cost>| node_cost.node)
112 }
113
114 fn is_best_g_score(&self, node: Node, g_score: Cost) -> bool {
116 if let Some(current_g_score) = self.g_scores.get(&node) {
117 g_score < *current_g_score
118 } else {
119 true
120 }
121 }
122
123 fn reconstruct_path(&mut self, mut last: Node) -> &Path<Node, Cost> {
125 self.path.clear();
126 self.path
127 .set_cost(self.g_scores.get(&last).copied().unwrap_or(Cost::default()));
128 self.path.push(last);
129 while let Some(current) = self.previous.get(&last) {
130 self.path.push(*current);
131 last = *current;
132 }
133 &self.path
134 }
135}