use alloc::{collections::BinaryHeap, vec, vec::Vec};
use core::hash::Hash;
use hashbrown::hash_map::{
Entry::{Occupied, Vacant},
HashMap,
};
use crate::algo::Measure;
use crate::scored::MinScored;
use crate::visit::{EdgeRef, GraphBase, IntoEdges, Visitable};
/// A* search from `start` to the first dequeued node accepted by `is_goal`.
///
/// Every reached node carries a score triple `(f, h, g)`:
/// * `g` is the accumulated `edge_cost` along the cheapest path found so far,
/// * `h` is `estimate_cost(node)`,
/// * `f` is `g + h`, the primary key nodes are dequeued by.
///
/// # Arguments
/// * `graph` - graph whose outgoing edges are explored via `graph.edges(node)`.
/// * `start` - node the search begins at; its `g` is `K::default()`.
/// * `is_goal` - returns `true` for a node that terminates the search.
/// * `edge_cost` - cost of traversing an edge. Presumably expected to be
///   non-negative, as usual for A* (not checked here).
/// * `estimate_cost` - heuristic estimate of the remaining cost from a node.
///   It is only evaluated for the start node and for neighbors whose `g`
///   actually improved.
///
/// # Returns
/// `Some((cost, path))` where `cost` is the accumulated edge cost (`g`) of the
/// goal node and `path` lists the nodes from `start` to the goal inclusive, or
/// `None` when the queue is exhausted without `is_goal` accepting any node.
///
/// The reported cost is always the real cost of the returned path, even when
/// `estimate_cost` is not zero at the goal. With a heuristic that can
/// overestimate, the returned path may not be the cheapest one.
pub fn astar<G, F, H, K, IsGoal>(
    graph: G,
    start: G::NodeId,
    mut is_goal: IsGoal,
    mut edge_cost: F,
    mut estimate_cost: H,
) -> Option<(K, Vec<G::NodeId>)>
where
    G: IntoEdges + Visitable,
    IsGoal: FnMut(G::NodeId) -> bool,
    G::NodeId: Eq + Hash,
    F: FnMut(G::EdgeRef) -> K,
    H: FnMut(G::NodeId) -> K,
    K: Measure + Copy,
{
    // Priority queue of `(f, h, g)` triples wrapped in `MinScored`.
    let mut visit_next = BinaryHeap::new();
    // Best known `(f, h, g)` per node; the `g` component decides improvements.
    let mut scores = HashMap::new();
    let mut path_tracker = PathTracker::<G>::new();

    let g: K = K::default();
    let h: K = estimate_cost(start);
    let f: K = g + h;
    scores.insert(start, (f, h, g));
    visit_next.push(MinScored((f, h, g), start));

    while let Some(MinScored((f, h, g), node)) = visit_next.pop() {
        if is_goal(node) {
            let path = path_tracker.reconstruct_path_to(node);
            // Every queued node was scored when it was pushed, so indexing
            // cannot fail. Report `g` (the real path cost), not `f`, so a
            // heuristic that is non-zero at the goal cannot skew the result.
            let (_, _, goal_g) = scores[&node];
            return Some((goal_g, path));
        }

        match scores.entry(node) {
            Occupied(mut entry) => {
                let (_, _, old_g) = *entry.get();
                // A cheaper route to `node` was recorded after this queue
                // entry was pushed, so the entry is stale.
                if old_g < g {
                    continue;
                }
                entry.insert((f, h, g));
            }
            Vacant(entry) => {
                entry.insert((f, h, g));
            }
        }

        for edge in graph.edges(node) {
            let neigh = edge.target();
            let neigh_g = g + edge_cost(edge);

            // The heuristic is evaluated only once the neighbor is known to be
            // new or strictly improved; non-improving edges skip it entirely.
            let neigh_score = match scores.entry(neigh) {
                Occupied(mut entry) => {
                    let (_, _, old_neigh_g) = *entry.get();
                    if neigh_g >= old_neigh_g {
                        continue;
                    }
                    let neigh_h = estimate_cost(neigh);
                    let score = (neigh_g + neigh_h, neigh_h, neigh_g);
                    entry.insert(score);
                    score
                }
                Vacant(entry) => {
                    let neigh_h = estimate_cost(neigh);
                    let score = (neigh_g + neigh_h, neigh_h, neigh_g);
                    entry.insert(score);
                    score
                }
            };

            path_tracker.set_predecessor(neigh, node);
            visit_next.push(MinScored(neigh_score, neigh));
        }
    }

    None
}
/// Book-keeping used to rebuild a path once a search has finished.
///
/// Each entry links a node to the node recorded as its predecessor via
/// `set_predecessor`.
struct PathTracker<G: GraphBase>
where
    G::NodeId: Eq + Hash,
{
    /// `came_from[n]` is the node from which `n` was most recently recorded
    /// as reached.
    came_from: HashMap<G::NodeId, G::NodeId>,
}
impl<G> PathTracker<G>
where
    G: GraphBase,
    G::NodeId: Eq + Hash,
{
    /// Creates a tracker with no recorded predecessors.
    fn new() -> PathTracker<G> {
        PathTracker {
            came_from: HashMap::new(),
        }
    }

    /// Records `previous` as the predecessor of `node`, replacing any
    /// predecessor stored for `node` earlier.
    fn set_predecessor(&mut self, node: G::NodeId, previous: G::NodeId) {
        self.came_from.insert(node, previous);
    }

    /// Follows predecessor links backwards from `last` until a node with no
    /// recorded predecessor is reached, and returns the visited nodes in
    /// forward order (that first node at the front, `last` at the back).
    ///
    /// If `last` has no recorded predecessor the result is just `[last]`.
    fn reconstruct_path_to(&self, last: G::NodeId) -> Vec<G::NodeId> {
        let mut path = vec![last];
        let mut current = last;
        while let Some(&previous) = self.came_from.get(&current) {
            path.push(previous);
            current = previous;
        }
        // The walk collected nodes goal-first; flip to start-first order.
        path.reverse();
        path
    }
}