// retworkx/astar.rs

// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

// This module is copied and forked from the upstream petgraph repository,
// specifically:
// https://github.com/petgraph/petgraph/blob/0.5.1/src/astar.rs and
// https://github.com/petgraph/petgraph/blob/0.5.1/src/scored.rs
// The fork was necessary to change the error handling so that Python
// callables can be used for the is_goal, edge_cost and estimate_cost input
// functions, and so that any exception raised in Python is returned to the
// caller instead of causing a panic.

21use std::cmp::Ordering;
22use std::collections::BinaryHeap;
23use std::hash::Hash;
24
25use hashbrown::hash_map::Entry::{Occupied, Vacant};
26use hashbrown::HashMap;
27
28use petgraph::visit::{EdgeRef, GraphBase, IntoEdges, VisitMap, Visitable};
29
30use petgraph::algo::Measure;
31use pyo3::prelude::*;
32
/// A `(score, item)` pair whose ordering is the *reverse* of its score's ordering.
///
/// `BinaryHeap` is a max-heap. Storing `MinScored` entries in it turns it into a min-heap keyed
/// on the score `K`, so `pop` returns the entry with the smallest score first. The item `T`
/// never takes part in any comparison.
///
/// **Note:** the comparison below is a total order (`Ord`) even though `K` only needs to be
/// `PartialOrd`, which is what allows floating-point scores. Two NaN scores compare equal, and
/// an entry with a NaN score compares `Less` than any entry with a non-NaN score, so it is popped
/// from a `BinaryHeap` after all the others.
#[derive(Copy, Clone, Debug)]
pub struct MinScored<K, T>(pub K, pub T);

impl<K: PartialOrd, T> PartialEq for MinScored<K, T> {
    /// Equality is defined by the total order: only the scores are compared.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        Ordering::Equal == self.cmp(other)
    }
}

impl<K: PartialOrd, T> Eq for MinScored<K, T> {}

impl<K: PartialOrd, T> PartialOrd for MinScored<K, T> {
    /// Always `Some`: defers to the total order implemented by `Ord::cmp`.
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<K: PartialOrd, T> Ord for MinScored<K, T> {
    #[inline]
    fn cmp(&self, other: &Self) -> Ordering {
        let (lhs, rhs) = (&self.0, &other.0);
        match lhs.partial_cmp(rhs) {
            // Comparable scores: flip the order so that the smaller score ranks higher.
            Some(order) => order.reverse(),
            // Incomparable scores: a value that is not equal to itself is a NaN.
            None => match (lhs.ne(lhs), rhs.ne(rhs)) {
                (true, true) => Ordering::Equal,
                // Rank NaN lowest so it comes out of the heap last.
                (true, false) => Ordering::Less,
                (false, _) => Ordering::Greater,
            },
        }
    }
}

84/// \[Generic\] A* shortest path algorithm.
85///
86/// Computes the shortest path from `start` to `finish`, including the total path cost.
87///
88/// `finish` is implicitly given via the `is_goal` callback, which should return `true` if the
89/// given node is the finish node.
90///
91/// The function `edge_cost` should return the cost for a particular edge. Edge costs must be
92/// non-negative.
93///
94/// The function `estimate_cost` should return the estimated cost to the finish for a particular
95/// node. For the algorithm to find the actual shortest path, it should be admissible, meaning that
96/// it should never overestimate the actual cost to get to the nearest goal node. Estimate costs
97/// must also be non-negative.
98///
99/// The graph should be `Visitable` and implement `IntoEdges`.
100///
101/// # Example
102/// ```
103/// use petgraph::Graph;
104/// use petgraph::algo::astar;
105///
106/// let mut g = Graph::new();
107/// let a = g.add_node((0., 0.));
108/// let b = g.add_node((2., 0.));
109/// let c = g.add_node((1., 1.));
110/// let d = g.add_node((0., 2.));
111/// let e = g.add_node((3., 3.));
112/// let f = g.add_node((4., 2.));
113/// g.extend_with_edges(&[
114///     (a, b, 2),
115///     (a, d, 4),
116///     (b, c, 1),
117///     (b, f, 7),
118///     (c, e, 5),
119///     (e, f, 1),
120///     (d, e, 1),
121/// ]);
122///
123/// // Graph represented with the weight of each edge
124/// // Edges with '*' are part of the optimal path.
125/// //
126/// //     2       1
127/// // a ----- b ----- c
128/// // | 4*    | 7     |
129/// // d       f       | 5
130/// // | 1*    | 1*    |
131/// // \------ e ------/
132///
133/// let path = astar(&g, a, |finish| finish == f, |e| *e.weight(), |_| 0);
134/// assert_eq!(path, Some((6, vec![a, d, e, f])));
135/// ```
136///
137/// Returns the total cost + the path of subsequent `NodeId` from start to finish, if one was
138/// found.
139pub fn astar<G, F, H, K, IsGoal>(
140    graph: G,
141    start: G::NodeId,
142    mut is_goal: IsGoal,
143    mut edge_cost: F,
144    mut estimate_cost: H,
145) -> PyResult<Option<(K, Vec<G::NodeId>)>>
146where
147    G: IntoEdges + Visitable,
148    IsGoal: FnMut(G::NodeId) -> PyResult<bool>,
149    G::NodeId: Eq + Hash,
150    F: FnMut(G::EdgeRef) -> PyResult<K>,
151    H: FnMut(G::NodeId) -> PyResult<K>,
152    K: Measure + Copy,
153{
154    let mut visited = graph.visit_map();
155    let mut visit_next = BinaryHeap::new();
156    let mut scores = HashMap::new();
157    let mut path_tracker = PathTracker::<G>::new();
158
159    let zero_score = K::default();
160    scores.insert(start, zero_score);
161    let estimate = estimate_cost(start)?;
162    visit_next.push(MinScored(estimate, start));
163
164    while let Some(MinScored(_, node)) = visit_next.pop() {
165        let result = is_goal(node)?;
166        if result {
167            let path = path_tracker.reconstruct_path_to(node);
168            let cost = scores[&node];
169            return Ok(Some((cost, path)));
170        }
171
172        // Don't visit the same node several times, as the first time it was visited it was using
173        // the shortest available path.
174        if !visited.visit(node) {
175            continue;
176        }
177
178        // This lookup can be unwrapped without fear of panic since the node was necessarily scored
179        // before adding him to `visit_next`.
180        let node_score = scores[&node];
181
182        for edge in graph.edges(node) {
183            let next = edge.target();
184            if visited.is_visited(&next) {
185                continue;
186            }
187
188            let cost = edge_cost(edge)?;
189            let mut next_score = node_score + cost;
190
191            match scores.entry(next) {
192                Occupied(ent) => {
193                    let old_score = *ent.get();
194                    if next_score < old_score {
195                        *ent.into_mut() = next_score;
196                        path_tracker.set_predecessor(next, node);
197                    } else {
198                        next_score = old_score;
199                    }
200                }
201                Vacant(ent) => {
202                    ent.insert(next_score);
203                    path_tracker.set_predecessor(next, node);
204                }
205            }
206
207            let estimate = estimate_cost(next)?;
208            let next_estimate_score = next_score + estimate;
209            visit_next.push(MinScored(next_estimate_score, next));
210        }
211    }
212
213    Ok(None)
214}
215
216struct PathTracker<G>
217where
218    G: GraphBase,
219    G::NodeId: Eq + Hash,
220{
221    came_from: HashMap<G::NodeId, G::NodeId>,
222}
223
224impl<G> PathTracker<G>
225where
226    G: GraphBase,
227    G::NodeId: Eq + Hash,
228{
229    fn new() -> PathTracker<G> {
230        PathTracker {
231            came_from: HashMap::new(),
232        }
233    }
234
235    fn set_predecessor(&mut self, node: G::NodeId, previous: G::NodeId) {
236        self.came_from.insert(node, previous);
237    }
238
239    fn reconstruct_path_to(&self, last: G::NodeId) -> Vec<G::NodeId> {
240        let mut path = vec![last];
241
242        let mut current = last;
243        while let Some(&previous) = self.came_from.get(&current) {
244            path.push(previous);
245            current = previous;
246        }
247
248        path.reverse();
249
250        path
251    }
252}