retworkx/astar.rs
1// Licensed under the Apache License, Version 2.0 (the "License"); you may
2// not use this file except in compliance with the License. You may obtain
3// a copy of the License at
4//
5// http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10// License for the specific language governing permissions and limitations
11// under the License.
12
// This module is copied and forked from the upstream petgraph repository,
// specifically:
// https://github.com/petgraph/petgraph/blob/0.5.1/src/astar.rs and
// https://github.com/petgraph/petgraph/blob/0.5.1/src/scored.rs
// This was necessary to modify the error handling so that Python callables
// can be used for the input functions `is_goal`, `edge_cost`, and
// `estimate_cost`, and so that any exceptions raised in Python are returned
// to the caller instead of causing a panic.
20
21use std::cmp::Ordering;
22use std::collections::BinaryHeap;
23use std::hash::Hash;
24
25use hashbrown::hash_map::Entry::{Occupied, Vacant};
26use hashbrown::HashMap;
27
28use petgraph::visit::{EdgeRef, GraphBase, IntoEdges, VisitMap, Visitable};
29
30use petgraph::algo::Measure;
31use pyo3::prelude::*;
32
/// A score `K` paired with an arbitrary payload `T`, ordered *in reverse* by the score.
///
/// `BinaryHeap` is a max-heap. Because `MinScored` inverts the ordering of its score,
/// popping from a `BinaryHeap<MinScored<K, T>>` yields the entry with the **smallest** score.
/// Only the score takes part in comparisons; the payload is never compared.
///
/// **Note:** `MinScored` implements a total order (`Ord`) even though `K` is only required to
/// be `PartialOrd`, so float types can be used as scores. A `NaN` score sorts after every
/// non-`NaN` score in min-score order (i.e. it is popped last), and two `NaN` scores compare
/// equal.
#[derive(Copy, Clone, Debug)]
pub struct MinScored<K, T>(pub K, pub T);

impl<K: PartialOrd, T> PartialEq for MinScored<K, T> {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        // Equality is defined by the total order below, so NaN == NaN here.
        self.cmp(other).is_eq()
    }
}

impl<K: PartialOrd, T> Eq for MinScored<K, T> {}

impl<K: PartialOrd, T> PartialOrd for MinScored<K, T> {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl<K: PartialOrd, T> Ord for MinScored<K, T> {
    #[inline]
    fn cmp(&self, other: &Self) -> Ordering {
        let (lhs, rhs) = (&self.0, &other.0);

        if let Some(order) = lhs.partial_cmp(rhs) {
            // Comparable scores: flip the natural order so the max-heap surfaces the minimum.
            return order.reverse();
        }

        // Incomparable scores. A value that is not equal to itself is a NaN.
        match (lhs.ne(lhs), rhs.ne(rhs)) {
            // Both NaN: treat them as equal.
            (true, true) => Ordering::Equal,
            // Only `self` is NaN: order it as "less" so it ends up last in min-score order.
            (true, false) => Ordering::Less,
            // Otherwise `self` wins.
            (false, _) => Ordering::Greater,
        }
    }
}
83
84/// \[Generic\] A* shortest path algorithm.
85///
86/// Computes the shortest path from `start` to `finish`, including the total path cost.
87///
88/// `finish` is implicitly given via the `is_goal` callback, which should return `true` if the
89/// given node is the finish node.
90///
91/// The function `edge_cost` should return the cost for a particular edge. Edge costs must be
92/// non-negative.
93///
94/// The function `estimate_cost` should return the estimated cost to the finish for a particular
95/// node. For the algorithm to find the actual shortest path, it should be admissible, meaning that
96/// it should never overestimate the actual cost to get to the nearest goal node. Estimate costs
97/// must also be non-negative.
98///
99/// The graph should be `Visitable` and implement `IntoEdges`.
100///
101/// # Example
102/// ```
103/// use petgraph::Graph;
104/// use petgraph::algo::astar;
105///
106/// let mut g = Graph::new();
107/// let a = g.add_node((0., 0.));
108/// let b = g.add_node((2., 0.));
109/// let c = g.add_node((1., 1.));
110/// let d = g.add_node((0., 2.));
111/// let e = g.add_node((3., 3.));
112/// let f = g.add_node((4., 2.));
113/// g.extend_with_edges(&[
114/// (a, b, 2),
115/// (a, d, 4),
116/// (b, c, 1),
117/// (b, f, 7),
118/// (c, e, 5),
119/// (e, f, 1),
120/// (d, e, 1),
121/// ]);
122///
123/// // Graph represented with the weight of each edge
124/// // Edges with '*' are part of the optimal path.
125/// //
126/// // 2 1
127/// // a ----- b ----- c
128/// // | 4* | 7 |
129/// // d f | 5
130/// // | 1* | 1* |
131/// // \------ e ------/
132///
133/// let path = astar(&g, a, |finish| finish == f, |e| *e.weight(), |_| 0);
134/// assert_eq!(path, Some((6, vec![a, d, e, f])));
135/// ```
136///
137/// Returns the total cost + the path of subsequent `NodeId` from start to finish, if one was
138/// found.
139pub fn astar<G, F, H, K, IsGoal>(
140 graph: G,
141 start: G::NodeId,
142 mut is_goal: IsGoal,
143 mut edge_cost: F,
144 mut estimate_cost: H,
145) -> PyResult<Option<(K, Vec<G::NodeId>)>>
146where
147 G: IntoEdges + Visitable,
148 IsGoal: FnMut(G::NodeId) -> PyResult<bool>,
149 G::NodeId: Eq + Hash,
150 F: FnMut(G::EdgeRef) -> PyResult<K>,
151 H: FnMut(G::NodeId) -> PyResult<K>,
152 K: Measure + Copy,
153{
154 let mut visited = graph.visit_map();
155 let mut visit_next = BinaryHeap::new();
156 let mut scores = HashMap::new();
157 let mut path_tracker = PathTracker::<G>::new();
158
159 let zero_score = K::default();
160 scores.insert(start, zero_score);
161 let estimate = estimate_cost(start)?;
162 visit_next.push(MinScored(estimate, start));
163
164 while let Some(MinScored(_, node)) = visit_next.pop() {
165 let result = is_goal(node)?;
166 if result {
167 let path = path_tracker.reconstruct_path_to(node);
168 let cost = scores[&node];
169 return Ok(Some((cost, path)));
170 }
171
172 // Don't visit the same node several times, as the first time it was visited it was using
173 // the shortest available path.
174 if !visited.visit(node) {
175 continue;
176 }
177
178 // This lookup can be unwrapped without fear of panic since the node was necessarily scored
179 // before adding him to `visit_next`.
180 let node_score = scores[&node];
181
182 for edge in graph.edges(node) {
183 let next = edge.target();
184 if visited.is_visited(&next) {
185 continue;
186 }
187
188 let cost = edge_cost(edge)?;
189 let mut next_score = node_score + cost;
190
191 match scores.entry(next) {
192 Occupied(ent) => {
193 let old_score = *ent.get();
194 if next_score < old_score {
195 *ent.into_mut() = next_score;
196 path_tracker.set_predecessor(next, node);
197 } else {
198 next_score = old_score;
199 }
200 }
201 Vacant(ent) => {
202 ent.insert(next_score);
203 path_tracker.set_predecessor(next, node);
204 }
205 }
206
207 let estimate = estimate_cost(next)?;
208 let next_estimate_score = next_score + estimate;
209 visit_next.push(MinScored(next_estimate_score, next));
210 }
211 }
212
213 Ok(None)
214}
215
216struct PathTracker<G>
217where
218 G: GraphBase,
219 G::NodeId: Eq + Hash,
220{
221 came_from: HashMap<G::NodeId, G::NodeId>,
222}
223
224impl<G> PathTracker<G>
225where
226 G: GraphBase,
227 G::NodeId: Eq + Hash,
228{
229 fn new() -> PathTracker<G> {
230 PathTracker {
231 came_from: HashMap::new(),
232 }
233 }
234
235 fn set_predecessor(&mut self, node: G::NodeId, previous: G::NodeId) {
236 self.came_from.insert(node, previous);
237 }
238
239 fn reconstruct_path_to(&self, last: G::NodeId) -> Vec<G::NodeId> {
240 let mut path = vec![last];
241
242 let mut current = last;
243 while let Some(&previous) = self.came_from.get(¤t) {
244 path.push(previous);
245 current = previous;
246 }
247
248 path.reverse();
249
250 path
251 }
252}