Skip to main content

a_star/
a_star.rs

1use crate::path::Path;
2use crate::{DiscoveredSet, Graph, NodeCost};
3use std::collections::{HashMap, HashSet};
4use std::hash::Hash;
5use std::ops::Add;
6
7/// Responsible for performing searches using the `A*` algorithm.
8pub struct AStar<'a, Node, Cost, DS, G>
9where
10    Node: Copy + Eq + Hash,
11    Cost: Copy + Ord + Default + Add<Cost, Output = Cost>,
12    DS: DiscoveredSet<Node, Cost>,
13    G: Graph<Node, Cost>,
14{
15    graph: &'a G,
16    discovered: DS,
17    closed: HashSet<Node>,
18    previous: HashMap<Node, Node>,
19    g_scores: HashMap<Node, Cost>,
20    neighbors: Vec<NodeCost<Node, Cost>>,
21    path: Path<Node, Cost>,
22}
23
24impl<'a, Node, Cost, DS, G> AStar<'a, Node, Cost, DS, G>
25where
26    Node: Copy + Eq + Hash,
27    Cost: Copy + Ord + Default + Add<Cost, Output = Cost>,
28    DS: DiscoveredSet<Node, Cost>,
29    G: Graph<Node, Cost>,
30{
31    //! Construction
32
33    /// Creates a new [AStar] instance.
34    pub fn new(graph: &'a G, discovered: DS) -> Self {
35        Self {
36            graph,
37            discovered,
38            closed: HashSet::new(),
39            previous: HashMap::new(),
40            g_scores: HashMap::new(),
41            neighbors: Vec::new(),
42            path: Path::new(0),
43        }
44    }
45}
46
47impl<'a, Node, Cost, DS, G> AStar<'a, Node, Cost, DS, G>
48where
49    Node: Copy + Eq + Hash,
50    Cost: Copy + Ord + Default + Add<Cost, Output = Cost>,
51    DS: DiscoveredSet<Node, Cost>,
52    G: Graph<Node, Cost>,
53{
54    //! Search
55
56    /// Finds the shortest path from the `start` node to the `goal` node, inclusive.
57    pub fn search(&mut self, start: Node, goal: Node) -> Option<&Path<Node, Cost>> {
58        self.reset();
59
60        let h_score: Cost = self.graph.heuristic(start, goal);
61        self.discovered.push(NodeCost::new(start, h_score));
62        self.g_scores.insert(start, Cost::default());
63
64        while let Some(current) = self.next() {
65            if current == goal {
66                return Some(self.reconstruct_path(current));
67            }
68
69            // Skip already-expanded nodes.
70            if !self.closed.insert(current) {
71                continue;
72            }
73
74            let current_g_score: Cost = self
75                .g_scores
76                .get(&current)
77                .copied()
78                .expect("invalid impl: there must be a g-score for discovered nodes");
79
80            self.neighbors.clear();
81            self.graph.neighbors(current, &mut self.neighbors);
82            for neighbor in &self.neighbors {
83                let (neighbor, edge_cost): (Node, Cost) = (neighbor.node, neighbor.cost);
84                let g_score: Cost = current_g_score + edge_cost;
85                if self.is_best_g_score(neighbor, g_score) {
86                    self.previous.insert(neighbor, current);
87                    self.g_scores.insert(neighbor, g_score);
88
89                    let h_score: Cost = self.graph.heuristic(neighbor, goal);
90                    let f_score: Cost = g_score + h_score;
91                    self.discovered.push(NodeCost::new(neighbor, f_score));
92                }
93            }
94        }
95
96        None
97    }
98
99    /// Resets the state for another iteration of the `A*` algorithm.
100    fn reset(&mut self) {
101        self.discovered.clear();
102        self.closed.clear();
103        self.previous.clear();
104        self.g_scores.clear();
105    }
106
107    /// Gets the next node to process.
108    fn next(&mut self) -> Option<Node> {
109        self.discovered
110            .pop()
111            .map(|node_cost: NodeCost<Node, Cost>| node_cost.node)
112    }
113
114    /// Checks if the `g_score` is the best `g_score` known for the `node`.
115    fn is_best_g_score(&self, node: Node, g_score: Cost) -> bool {
116        if let Some(current_g_score) = self.g_scores.get(&node) {
117            g_score < *current_g_score
118        } else {
119            true
120        }
121    }
122
123    /// Reconstructs the path from `start` to `last`.
124    fn reconstruct_path(&mut self, mut last: Node) -> &Path<Node, Cost> {
125        self.path.clear();
126        self.path
127            .set_cost(self.g_scores.get(&last).copied().unwrap_or(Cost::default()));
128        self.path.push(last);
129        while let Some(current) = self.previous.get(&last) {
130            self.path.push(*current);
131            last = *current;
132        }
133        &self.path
134    }
135}