use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};
use crate::metrics::CostMetric;
/// Outcome of a successful A* search.
#[derive(Debug, Clone, PartialEq)]
pub struct AStarResult<N> {
    /// Nodes from the start to the goal, both endpoints included.
    pub path: Vec<N>,
    /// Sum of the edge costs along `path`.
    pub cost: f64,
}
/// Entry in the A* open list.
///
/// `BinaryHeap` is a max-heap, so the ordering is reversed: the entry with the
/// *smallest* `f` compares as the greatest and is popped first. Ties on `f` are
/// broken in favour of the larger `g` (the entry further along its path), a
/// common A* tie-breaking choice that also makes the pop order deterministic.
#[derive(Clone)]
struct OpenEntry<N> {
    /// The graph node this entry refers to.
    node: N,
    /// Priority: `g + h(node)`.
    f: f64,
    /// Cost from the start to `node` at the time this entry was pushed.
    g: f64,
}

// Equality is derived from `cmp` so that `PartialEq`/`Eq` agree with `Ord`, as
// the `Ord` contract requires. Only the priority matters to the heap; the node
// itself is intentionally not compared.
impl<N> PartialEq for OpenEntry<N> {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl<N> Eq for OpenEntry<N> {}

impl<N> PartialOrd for OpenEntry<N> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<N> Ord for OpenEntry<N> {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` is a total order even for NaN, so a NaN produced by a
        // heuristic or edge cost cannot corrupt the heap invariant the way
        // `partial_cmp(..).unwrap_or(Equal)` could.
        other
            .f
            .total_cmp(&self.f)
            .then_with(|| self.g.total_cmp(&other.g))
    }
}
/// Generic A* search from `start` to `goal`.
///
/// * `heuristic(n, goal)` estimates the remaining cost from `n` to `goal`.
///   Expanded nodes are never re-opened, so the returned path is only
///   guaranteed optimal for a consistent (and hence admissible) heuristic.
/// * `neighbors(n)` yields `(neighbor, edge_cost)` for every edge leaving `n`.
///   Edge costs are presumably non-negative — NOTE(review): confirm with callers.
///
/// Returns the path (both endpoints included) and its total cost, or `None`
/// if the open list is exhausted without reaching `goal`.
pub fn astar<N, FH, FN, I>(
    start: N,
    goal: N,
    mut heuristic: FH,
    mut neighbors: FN,
) -> Option<AStarResult<N>>
where
    N: Clone + Eq + std::hash::Hash,
    FH: FnMut(&N, &N) -> f64,
    FN: FnMut(&N) -> I,
    I: IntoIterator<Item = (N, f64)>,
{
    let mut open = BinaryHeap::new();
    // Best known cost from `start` to each discovered node.
    let mut g_scores: HashMap<N, f64> = HashMap::new();
    // Predecessor of each discovered node on its best known path.
    let mut came_from: HashMap<N, N> = HashMap::new();
    // Nodes that have already been expanded.
    let mut closed: HashSet<N> = HashSet::new();

    let h = heuristic(&start, &goal);
    g_scores.insert(start.clone(), 0.0);
    open.push(OpenEntry {
        node: start,
        f: h,
        g: 0.0,
    });

    while let Some(current) = open.pop() {
        if current.node == goal {
            // Follow predecessor links back to `start`, which has none.
            let mut path = vec![current.node.clone()];
            let mut cur = &current.node;
            while let Some(prev) = came_from.get(cur) {
                path.push(prev.clone());
                cur = prev;
            }
            path.reverse();
            return Some(AStarResult {
                path,
                cost: current.g,
            });
        }

        // A node may be pushed several times as its cost improves; only its
        // first pop is expanded and later duplicates are skipped here.
        if !closed.insert(current.node.clone()) {
            continue;
        }

        for (neighbor, edge_cost) in neighbors(&current.node) {
            if closed.contains(&neighbor) {
                continue;
            }
            let tentative_g = current.g + edge_cost;
            let prev_g = g_scores.get(&neighbor).copied().unwrap_or(f64::INFINITY);
            if tentative_g < prev_g {
                g_scores.insert(neighbor.clone(), tentative_g);
                came_from.insert(neighbor.clone(), current.node.clone());
                let h = heuristic(&neighbor, &goal);
                open.push(OpenEntry {
                    node: neighbor,
                    f: tentative_g + h,
                    g: tentative_g,
                });
            }
        }
    }
    None
}
/// A* on a `width` × `height` grid of `(x, y)` cells.
///
/// Moves go to the 4 orthogonal neighbours (cost 1) and, when `diagonal` is
/// true, also to the 4 diagonal neighbours (cost √2). A move is allowed when
/// the destination is in bounds and `walkable`; diagonal moves do not check
/// the two orthogonally adjacent cells, so they may cut corners.
///
/// The heuristic is Manhattan distance (4-connected) or Chebyshev distance
/// (8-connected), neither of which exceeds the true cost under these moves.
///
/// Returns `None` if `start` or `goal` is out of bounds or not walkable, or if
/// no path exists.
pub fn astar_grid2d(
    start: (usize, usize),
    goal: (usize, usize),
    width: usize,
    height: usize,
    walkable: &dyn Fn(usize, usize) -> bool,
    diagonal: bool,
) -> Option<AStarResult<(usize, usize)>> {
    // Validate endpoints up front; `&&` short-circuits so `walkable` is only
    // ever queried for in-bounds coordinates.
    let usable = |(x, y): (usize, usize)| x < width && y < height && walkable(x, y);
    if !usable(start) || !usable(goal) {
        return None;
    }

    let heuristic = |a: &(usize, usize), b: &(usize, usize)| -> f64 {
        let dx = (a.0 as f64 - b.0 as f64).abs();
        let dy = (a.1 as f64 - b.1 as f64).abs();
        if diagonal {
            dx.max(dy)
        } else {
            dx + dy
        }
    };

    let deltas: &[(isize, isize)] = if diagonal {
        &[
            (-1, -1),
            (-1, 0),
            (-1, 1),
            (0, -1),
            (0, 1),
            (1, -1),
            (1, 0),
            (1, 1),
        ]
    } else {
        &[(-1, 0), (1, 0), (0, -1), (0, 1)]
    };

    let neighbors = |node: &(usize, usize)| -> Vec<((usize, usize), f64)> {
        let (x, y) = *node;
        deltas
            .iter()
            .filter_map(|&(dx, dy)| {
                // Wrapping addition of `-1 as usize` yields `x - 1`, or
                // `usize::MAX` when `x == 0`, which the bounds check rejects.
                // This avoids narrowing coordinates to `i32`.
                let nx = x.wrapping_add(dx as usize);
                let ny = y.wrapping_add(dy as usize);
                if nx >= width || ny >= height || !walkable(nx, ny) {
                    return None;
                }
                let cost = if dx != 0 && dy != 0 {
                    std::f64::consts::SQRT_2
                } else {
                    1.0
                };
                Some(((nx, ny), cost))
            })
            .collect()
    };

    astar(start, goal, heuristic, neighbors)
}
/// Configuration for `astar_grid2d_opts`.
///
/// Every field is a plain value or a shared reference, so the options are
/// `Copy` and can be reused across many searches.
#[derive(Clone, Copy)]
pub struct GridAStarOpts<'a> {
    /// Number of columns; valid `x` coordinates are `0..width`.
    pub width: usize,
    /// Number of rows; valid `y` coordinates are `0..height`.
    pub height: usize,
    /// Also allow the 4 diagonal moves, not just the 4 orthogonal ones.
    pub diagonal: bool,
    /// Wrap around the grid edges on both axes instead of stopping at them.
    pub periodic: bool,
    /// The heuristic is multiplied by `1.0 + admissibility`; `0.0` leaves it
    /// unscaled, while larger values inflate it and may yield suboptimal paths.
    pub admissibility: f64,
    /// Predicate `walkable(x, y)`; `None` treats every cell as walkable.
    pub walkable: Option<&'a dyn Fn(usize, usize) -> bool>,
    /// Metric for step costs and the heuristic; `None` falls back to
    /// `crate::metrics::DirectDistance`.
    pub cost_metric: Option<&'a dyn CostMetric>,
}

impl<'a> GridAStarOpts<'a> {
    /// Options for a `width` × `height` grid: diagonal moves on, no wrapping,
    /// unscaled heuristic, every cell walkable, default cost metric.
    pub fn new(width: usize, height: usize) -> Self {
        Self {
            width,
            height,
            diagonal: true,
            periodic: false,
            admissibility: 0.0,
            walkable: None,
            cost_metric: None,
        }
    }
}

// Manual impl: the trait-object fields have no `Debug`, so only their presence
// is reported.
impl std::fmt::Debug for GridAStarOpts<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GridAStarOpts")
            .field("width", &self.width)
            .field("height", &self.height)
            .field("diagonal", &self.diagonal)
            .field("periodic", &self.periodic)
            .field("admissibility", &self.admissibility)
            .field("walkable", &self.walkable.is_some())
            .field("cost_metric", &self.cost_metric.is_some())
            .finish()
    }
}
/// A* on a 2D grid configured by `GridAStarOpts`.
///
/// `start` and `goal` are normalized first: wrapped when `opts.periodic`,
/// bounds-checked otherwise. Each step goes to one of the 4 (or 8 when
/// `opts.diagonal`) adjacent cells, wrapping around the edges when
/// `opts.periodic`. Only cells accepted by `opts.walkable` (all cells when
/// `None`) are entered. Step costs come from `opts.cost_metric` (default
/// `DirectDistance`). The heuristic is that same metric scaled by
/// `1.0 + opts.admissibility`.
///
/// Returns `None` when:
/// * the grid has a zero dimension;
/// * an endpoint is out of bounds (non-periodic) or not walkable;
/// * no path exists.
pub fn astar_grid2d_opts(
    start: (usize, usize),
    goal: (usize, usize),
    opts: &GridAStarOpts<'_>,
) -> Option<AStarResult<(usize, usize)>> {
    /// Moves `coord` (which must be `< len`) by `delta` in {-1, 0, 1} along an
    /// axis of length `len`. When `periodic`, stepping off one edge re-enters
    /// at the opposite edge; otherwise it yields `None`.
    fn step(coord: usize, delta: isize, len: usize, periodic: bool) -> Option<usize> {
        // `-1 as usize` wraps to `coord - 1`, or to `usize::MAX` at `coord == 0`.
        let moved = coord.wrapping_add(delta as usize);
        if moved < len {
            Some(moved)
        } else if periodic {
            Some(if delta < 0 { len - 1 } else { 0 })
        } else {
            None
        }
    }

    let width = opts.width;
    let height = opts.height;
    let diagonal = opts.diagonal;
    let periodic = opts.periodic;
    let admissibility = opts.admissibility;

    // An empty grid has no cells. Bail out before any `% width` / `% height`
    // (e.g. in periodic normalization) can panic on a zero divisor.
    if width == 0 || height == 0 {
        return None;
    }

    let default_metric = crate::metrics::DirectDistance::new();
    let metric: &dyn CostMetric = opts.cost_metric.unwrap_or(&default_metric);
    let always_walkable = |_: usize, _: usize| true;
    let walkable: &dyn Fn(usize, usize) -> bool = match &opts.walkable {
        Some(f) => *f,
        None => &always_walkable,
    };

    let start_n = normalize_grid_pos(start, periodic, width, height)?;
    let goal_n = normalize_grid_pos(goal, periodic, width, height)?;
    if !walkable(start_n.0, start_n.1) || !walkable(goal_n.0, goal_n.1) {
        return None;
    }

    let heuristic = |a: &(usize, usize), b: &(usize, usize)| -> f64 {
        (1.0 + admissibility) * metric.delta_cost(*a, *b, periodic, width, height, diagonal)
    };

    let deltas: &[(isize, isize)] = if diagonal {
        &[
            (-1, -1),
            (-1, 0),
            (-1, 1),
            (0, -1),
            (0, 1),
            (1, -1),
            (1, 0),
            (1, 1),
        ]
    } else {
        &[(-1, 0), (1, 0), (0, -1), (0, 1)]
    };

    let neighbors = |node: &(usize, usize)| -> Vec<((usize, usize), f64)> {
        let (x, y) = *node;
        deltas
            .iter()
            .filter_map(|&(dx, dy)| {
                let nx = step(x, dx, width, periodic)?;
                let ny = step(y, dy, height, periodic)?;
                Some((nx, ny))
            })
            .filter(|&(nx, ny)| walkable(nx, ny))
            .map(|n| {
                let cost = metric.delta_cost(*node, n, periodic, width, height, diagonal);
                (n, cost)
            })
            .collect()
    };

    astar(start_n, goal_n, heuristic, neighbors)
}
/// Maps `pos` onto a `width` × `height` grid.
///
/// With `periodic`, both coordinates wrap modulo the grid size. Otherwise `pos`
/// is returned unchanged when inside the grid and `None` when outside. A grid
/// with a zero dimension has no cells, so `None` is returned in that case
/// instead of panicking on a modulo by zero.
fn normalize_grid_pos(
    pos: (usize, usize),
    periodic: bool,
    width: usize,
    height: usize,
) -> Option<(usize, usize)> {
    let (x, y) = pos;
    if periodic {
        // `checked_rem` yields `None` for a zero divisor.
        Some((x.checked_rem(width)?, y.checked_rem(height)?))
    } else if x < width && y < height {
        Some(pos)
    } else {
        None
    }
}