use num_traits::Zero;
use std::collections::{BinaryHeap, HashMap};
use std::collections::hash_map::Entry::{Occupied, Vacant};
use std::cmp::Ordering;
use std::hash::Hash;
use std::rc::Rc;
use std::borrow::Borrow;
use super::reverse_path;
/// Search for a lowest-cost path from `start` to a node accepted by `success`,
/// using the A* algorithm.
///
/// - `start` is the node the search begins at.
/// - `neighbours` returns, for a node, its successors paired with the cost of
///   moving to each of them.
/// - `heuristic` returns an estimate of the remaining cost from a node to a
///   goal. As usual for A*, the result is only guaranteed to be a cheapest
///   path when this estimate never exceeds the real remaining cost.
/// - `success` tells whether a node is a goal.
///
/// Returns `None` when the frontier is exhausted without reaching a goal.
/// Otherwise returns the path rebuilt by `reverse_path` from the recorded
/// parent links, together with the total cost accumulated to reach the goal.
pub fn astar<N, C, FN, IN, FH, FS>(
    start: &N,
    neighbours: FN,
    heuristic: FH,
    success: FS,
) -> Option<(Vec<N>, C)>
where
    N: Eq + Hash + Clone,
    C: Zero + Ord + Copy,
    FN: Fn(&N) -> IN,
    IN: IntoIterator<Item = (N, C)>,
    FH: Fn(&N) -> C,
    FS: Fn(&N) -> bool,
{
    // Frontier ordered by estimated total cost; `cost` is the real cost paid
    // to reach the node carried in `payload`.
    let mut to_see = BinaryHeap::new();
    to_see.push(SmallestCostHolder {
        estimated_cost: heuristic(start),
        cost: Zero::zero(),
        payload: start.clone(),
    });
    // For every node reached so far: its best known parent and the cost of
    // reaching it through that parent.
    let mut parents: HashMap<N, (N, C)> = HashMap::new();
    while let Some(SmallestCostHolder {
        cost,
        payload: node,
        ..
    }) = to_see.pop()
    {
        if success(&node) {
            let parents = parents.into_iter().map(|(n, (p, _))| (n, p)).collect();
            return Some((reverse_path(&parents, node), cost));
        }
        // A node may sit in the heap several times if a better route to it
        // was found after it was first queued; only the best entry matters.
        if let Some(&(_, c)) = parents.get(&node) {
            if cost > c {
                continue;
            }
        }
        for (neighbour, move_cost) in neighbours(&node) {
            let new_cost = cost + move_cost;
            // The start node is never given a parent.
            if neighbour == *start {
                continue;
            }
            match parents.entry(neighbour.clone()) {
                Vacant(e) => {
                    e.insert((node.clone(), new_cost));
                }
                Occupied(mut e) => {
                    if e.get().1 > new_cost {
                        e.insert((node.clone(), new_cost));
                    } else {
                        // The known route is at least as good: nothing to queue.
                        continue;
                    }
                }
            }
            to_see.push(SmallestCostHolder {
                estimated_cost: new_cost + heuristic(&neighbour),
                // Record the cost of reaching `neighbour` itself (not its
                // parent) so that ties between equal estimates are broken on
                // the right value.
                cost: new_cost,
                payload: neighbour,
            });
        }
    }
    None
}
/// One link of a singly linked path: a node plus a shared handle on the link
/// that precedes it.
///
/// Links are reference-counted so that several paths can share a common
/// prefix without copying it.
#[derive(PartialEq, Eq, Hash)]
struct PathNode<N: Clone> {
    /// The node stored at this position of the path.
    node: Rc<N>,
    /// The previous link, or `None` when this link has no predecessor.
    parent: Option<Rc<PathNode<N>>>,
}
pub fn astar_bag<N, C, FN, IN, FH, FS>(
start: &N,
neighbours: FN,
heuristic: FH,
success: FS,
) -> (Vec<Vec<N>>, C)
where
N: Eq + Hash + Clone,
C: Zero + Ord + Copy,
FN: Fn(&N) -> IN,
IN: IntoIterator<Item = (N, C)>,
FH: Fn(&N) -> C,
FS: Fn(&N) -> bool,
{
let mut to_see = BinaryHeap::new();
let mut out = Vec::new();
to_see.push(SmallestCostHolder {
estimated_cost: heuristic(start),
cost: Zero::zero(),
payload: Rc::new(PathNode {
node: Rc::new(start.clone()),
parent: None,
}),
});
let mut lowest_cost = HashMap::new();
let mut min_cost = None;
while let Some(SmallestCostHolder {
cost,
estimated_cost,
payload,
}) = to_see.pop()
{
let pn: &PathNode<N> = Rc::borrow(&payload);
if let Some(mc) = min_cost {
if estimated_cost > mc {
break;
}
}
if success(pn.node.borrow()) {
min_cost = Some(cost);
out.push(mk_path(&payload));
continue;
}
if let Some(c) = lowest_cost.get(&pn.node) {
if cost > *c {
continue;
}
}
for (neighbour, move_cost) in neighbours(pn.node.borrow()) {
let new_cost = cost + move_cost;
if neighbour != *start {
let new_predicted_cost = new_cost + heuristic(&neighbour);
let nrc = Rc::new(neighbour);
match lowest_cost.entry(Rc::clone(&nrc)) {
Vacant(e) => {
e.insert(new_cost);
}
Occupied(mut e) => if *e.get() > new_cost {
e.insert(new_cost);
},
}
to_see.push(SmallestCostHolder {
estimated_cost: new_predicted_cost,
cost: new_cost,
payload: Rc::new(PathNode {
node: nrc,
parent: Some(Rc::clone(&payload)),
}),
});
}
}
}
(out, min_cost.unwrap_or_else(Zero::zero))
}
fn mk_path<N>(mut pl: &Rc<PathNode<N>>) -> Vec<N>
where
N: Clone,
{
let mut path = Vec::new();
loop {
let cur: &PathNode<N> = Rc::borrow(pl);
let n: &N = cur.node.borrow();
path.push(n.clone());
match cur.parent {
Some(ref p) => pl = p,
None => break,
}
}
path.reverse();
path
}
/// Entry for a `BinaryHeap` that makes the heap yield the entry with the
/// smallest `estimated_cost` first.
///
/// Among entries with the same estimate, the one with the largest `cost` is
/// yielded first. `payload` is carried along and takes no part in equality or
/// ordering.
struct SmallestCostHolder<K, P> {
    estimated_cost: K,
    cost: K,
    payload: P,
}

impl<K: PartialEq, P> PartialEq for SmallestCostHolder<K, P> {
    /// Two holders are equal when both of their costs match.
    fn eq(&self, rhs: &Self) -> bool {
        self.estimated_cost == rhs.estimated_cost && self.cost == rhs.cost
    }
}

impl<K: PartialEq, P> Eq for SmallestCostHolder<K, P> {}

impl<K: Ord, P> PartialOrd for SmallestCostHolder<K, P> {
    /// Always `Some`: defers to the total order below.
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, rhs))
    }
}

impl<K: Ord, P> Ord for SmallestCostHolder<K, P> {
    /// Estimates are compared in reverse (a smaller estimate ranks higher);
    /// equal estimates fall back to comparing `cost` in natural order.
    fn cmp(&self, rhs: &Self) -> Ordering {
        rhs.estimated_cost
            .cmp(&self.estimated_cost)
            .then_with(|| self.cost.cmp(&rhs.cost))
    }
}