use std::collections::BinaryHeap;
use std::hash::Hash;
use hashbrown::hash_map::Entry;
use hashbrown::HashMap;
use petgraph::algo::Measure;
use petgraph::visit::{EdgeRef, GraphBase, IntoEdges, VisitMap, Visitable};
use crate::min_scored::MinScored;
/// Shape of a finished search as returned by `astar`: `Some((cost, path))`
/// holding the accumulated cost `K` and the nodes `N` along the path, or `None`
/// when the search ran out of nodes without reaching a goal.
type AstarOutput<K, N> = Option<(K, Vec<N>)>;
/// A* shortest-path search with fallible callbacks.
///
/// Starting from `start`, nodes are taken from a priority queue keyed by
/// "best known cost so far + `estimate_cost(node)`" until one of them satisfies
/// `is_goal` or the queue runs dry.
///
/// # Arguments
///
/// * `graph` - the graph to search.
/// * `start` - the node the search begins at; its cost is `K::default()`.
/// * `is_goal` - called on every node taken from the queue; returning
///   `Ok(true)` ends the search at that node.
/// * `edge_cost` - cost of traversing a single edge.
/// * `estimate_cost` - heuristic estimate of the remaining cost from a node.
///
/// # Returns
///
/// * `Ok(Some((cost, path)))` - `cost` is the best recorded cost of the goal
///   node and `path` lists the nodes from `start` to that node inclusive.
/// * `Ok(None)` - the queue was exhausted without `is_goal` ever returning
///   `Ok(true)`.
///
/// # Errors
///
/// The first `Err` returned by `is_goal`, `edge_cost` or `estimate_cost` is
/// propagated unchanged and aborts the search.
pub fn astar<G, F, H, K, IsGoal, E>(
    graph: G,
    start: G::NodeId,
    mut is_goal: IsGoal,
    mut edge_cost: F,
    mut estimate_cost: H,
) -> Result<AstarOutput<K, G::NodeId>, E>
where
    G: IntoEdges + Visitable,
    IsGoal: FnMut(G::NodeId) -> Result<bool, E>,
    G::NodeId: Eq + Hash,
    F: FnMut(G::EdgeRef) -> Result<K, E>,
    H: FnMut(G::NodeId) -> Result<K, E>,
    K: Measure + Copy,
{
    // Nodes whose outgoing edges have already been explored.
    let mut expanded = graph.visit_map();
    // Pending nodes. `MinScored` (from `crate::min_scored`) is presumably
    // ordered so that the heap yields the smallest score first.
    let mut frontier = BinaryHeap::new();
    // Cheapest cost found so far for every node that has been reached.
    let mut best_cost = HashMap::new();
    let mut predecessors = PathTracker::<G>::new();

    best_cost.insert(start, K::default());
    frontier.push(MinScored(estimate_cost(start)?, start));

    while let Some(MinScored(_, current)) = frontier.pop() {
        if is_goal(current)? {
            let total = best_cost[&current];
            let route = predecessors.reconstruct_path_to(current);
            return Ok(Some((total, route)));
        }

        // A node may sit in the queue several times; only its first
        // appearance is expanded (`visit` reports `false` on a repeat).
        if !expanded.visit(current) {
            continue;
        }

        // Always present: a node is given a cost before it is queued.
        let current_cost = best_cost[&current];

        for edge in graph.edges(current) {
            let neighbor = edge.target();
            if expanded.is_visited(&neighbor) {
                continue;
            }

            let tentative = current_cost + edge_cost(edge)?;

            // Keep the cheaper of the recorded and the tentative cost, and
            // remember how the neighbor was reached whenever the record changes.
            let settled = match best_cost.entry(neighbor) {
                Entry::Vacant(slot) => {
                    slot.insert(tentative);
                    predecessors.set_predecessor(neighbor, current);
                    tentative
                }
                Entry::Occupied(slot) => {
                    let known = *slot.get();
                    if tentative < known {
                        *slot.into_mut() = tentative;
                        predecessors.set_predecessor(neighbor, current);
                        tentative
                    } else {
                        known
                    }
                }
            };

            let priority = settled + estimate_cost(neighbor)?;
            frontier.push(MinScored(priority, neighbor));
        }
    }

    Ok(None)
}
struct PathTracker<G>
where
G: GraphBase,
G::NodeId: Eq + Hash,
{
came_from: HashMap<G::NodeId, G::NodeId>,
}
impl<G> PathTracker<G>
where
G: GraphBase,
G::NodeId: Eq + Hash,
{
fn new() -> PathTracker<G> {
PathTracker {
came_from: HashMap::new(),
}
}
fn set_predecessor(&mut self, node: G::NodeId, previous: G::NodeId) {
self.came_from.insert(node, previous);
}
fn reconstruct_path_to(&self, last: G::NodeId) -> Vec<G::NodeId> {
let mut path = vec![last];
let mut current = last;
while let Some(&previous) = self.came_from.get(¤t) {
path.push(previous);
current = previous;
}
path.reverse();
path
}
}
#[cfg(test)]
mod tests {
    use crate::dictmap::DictMap;
    use crate::shortest_path::{astar, dijkstra};
    use crate::Result;
    use petgraph::graph::NodeIndex;
    use petgraph::Graph;

    /// With a heuristic that is always zero, `astar` must report the same cost
    /// as `dijkstra`, and must return `None` for a goal that cannot be reached.
    #[test]
    fn test_astar_null_heuristic() {
        let mut graph = Graph::new();
        let a = graph.add_node("A");
        let b = graph.add_node("B");
        let c = graph.add_node("C");
        let d = graph.add_node("D");
        let e = graph.add_node("E");
        let f = graph.add_node("F");
        for (source, target, weight) in [
            (a, b, 7),
            (c, a, 9),
            (a, d, 14),
            (b, c, 10),
            (d, c, 2),
            (d, e, 9),
            (b, f, 15),
            (c, f, 11),
            (e, f, 6),
        ] {
            graph.add_edge(source, target, weight);
        }

        // A -> D -> E costs 14 + 9 = 23.
        let a_to_e: Result<Option<(i32, Vec<NodeIndex>)>> = astar(
            &graph,
            a,
            |node| Ok(node == e),
            |edge| Ok(*edge.weight()),
            |_| Ok(0),
        );
        assert_eq!(a_to_e.unwrap(), Some((23, vec![a, d, e])));

        let distances: Result<DictMap<NodeIndex, i32>> =
            dijkstra(&graph, a, Some(e), |edge| Ok(*edge.weight()), None);
        assert_eq!(distances.unwrap()[&e], 23);

        // E only leads to F, and F has no outgoing edges, so B is unreachable.
        let e_to_b: Result<Option<(i32, Vec<NodeIndex>)>> = astar(
            &graph,
            e,
            |node| Ok(node == b),
            |edge| Ok(*edge.weight()),
            |_| Ok(0),
        );
        assert_eq!(e_to_b.unwrap(), None);
    }

    /// Node weights are 2-D coordinates and the heuristic is the Manhattan
    /// distance to the goal; the resulting costs must match `dijkstra` for
    /// every node of the graph.
    #[test]
    fn test_astar_manhattan_heuristic() {
        let mut graph = Graph::new();
        let a = graph.add_node((0., 0.));
        let b = graph.add_node((2., 0.));
        let c = graph.add_node((1., 1.));
        let d = graph.add_node((0., 2.));
        let e = graph.add_node((3., 3.));
        let f = graph.add_node((4., 2.));
        // This node is never given an edge, so it stays isolated.
        let _ = graph.add_node((5., 5.));
        for (source, target, weight) in [
            (a, b, 2.),
            (a, d, 4.),
            (b, c, 1.),
            (b, f, 7.),
            (c, e, 5.),
            (e, f, 1.),
            (d, e, 1.),
        ] {
            graph.add_edge(source, target, weight);
        }

        // Builds a heuristic closure measuring the Manhattan distance to `goal`.
        let manhattan_to = |goal: NodeIndex| {
            let graph = &graph;
            move |node: NodeIndex| -> Result<f32> {
                let (node_x, node_y): (f32, f32) = graph[node];
                let (goal_x, goal_y): (f32, f32) = graph[goal];
                Ok((goal_x - node_x).abs() + (goal_y - node_y).abs())
            }
        };

        // A -> D -> E -> F costs 4 + 1 + 1 = 6.
        let a_to_f = astar(
            &graph,
            a,
            |node| Ok(node == f),
            |edge| Ok(*edge.weight()),
            manhattan_to(f),
        );
        assert_eq!(a_to_f.unwrap(), Some((6., vec![a, d, e, f])));

        let distances: Result<DictMap<NodeIndex, f32>> =
            dijkstra(&graph, a, None, |edge| Ok(*edge.weight()), None);
        let distances = distances.unwrap();

        for goal in graph.node_indices() {
            let found = astar(
                &graph,
                a,
                |node| Ok(node == goal),
                |edge| Ok(*edge.weight()),
                manhattan_to(goal),
            );
            assert_eq!(
                distances.get(&goal).copied(),
                found.unwrap().map(|(cost, _)| cost)
            );
        }
    }

    /// Counts how often `edge_cost` is invoked to check that no node is
    /// expanded more than once.
    #[test]
    fn test_astar_runtime_optimal() {
        let mut graph = Graph::new();
        let a = graph.add_node("A");
        let b = graph.add_node("B");
        let c = graph.add_node("C");
        let d = graph.add_node("D");
        let e = graph.add_node("E");
        for (source, target, weight) in [(a, b, 2), (a, c, 3), (b, d, 3), (c, d, 1), (d, e, 1)] {
            graph.add_edge(source, target, weight);
        }

        let mut cost_lookups = 0;
        let _: Result<Option<(i32, Vec<NodeIndex>)>> = astar(
            &graph,
            a,
            |node| Ok(node == e),
            |edge| {
                cost_lookups += 1;
                Ok(*edge.weight())
            },
            |_| Ok(0),
        );

        // D is queued twice (via B and via C) but must be expanded only once,
        // so each of the five edges is priced exactly once. A second expansion
        // of D would price (D, E) again and raise the count to 6.
        assert_eq!(cost_lookups, 5);
    }
}