use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};
/// A cell coordinate on the grid.
///
/// `x` is compared against the grid's width and `y` against its height when
/// checking whether a position lies inside the grid.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Pos {
    /// Horizontal coordinate.
    pub x: usize,
    /// Vertical coordinate.
    pub y: usize,
}

impl Pos {
    /// Builds a position from its horizontal (`x`) and vertical (`y`) coordinates.
    pub fn new(x: usize, y: usize) -> Self {
        Pos { x, y }
    }
}
/// Entry in the A* open set: a position together with the estimated total
/// cost `f = g + h` computed for it at the time it was queued.
#[derive(Clone, Copy, Eq, PartialEq)]
struct AStarNode {
    pos: Pos,
    f_score: usize,
}

impl Ord for AStarNode {
    /// Orders by *descending* `f_score`, so that `BinaryHeap` (a max-heap) pops
    /// the lowest-cost node first.
    ///
    /// Ties are broken on the position so that `cmp` returns `Equal` exactly when
    /// the derived `PartialEq` (which compares `pos` and `f_score`) reports
    /// equality, as the `Ord` contract requires. It also makes the pop order, and
    /// hence the returned path, deterministic.
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .f_score
            .cmp(&self.f_score)
            .then_with(|| self.pos.x.cmp(&other.pos.x))
            .then_with(|| self.pos.y.cmp(&other.pos.y))
    }
}

impl PartialOrd for AStarNode {
    /// Delegates to [`Ord::cmp`] so both orderings always agree.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// A rectangular, 4-connected grid of cells, some of which are blocked,
/// searchable with A* using unit step costs.
pub struct PathGrid {
    /// Number of columns; valid `x` coordinates are `0..width`.
    width: usize,
    /// Number of rows; valid `y` coordinates are `0..height`.
    height: usize,
    /// Cells that cannot be entered.
    blocked: HashSet<Pos>,
}

impl PathGrid {
    /// Creates a `width` x `height` grid with every cell open.
    pub fn new(width: usize, height: usize) -> Self {
        Self {
            width,
            height,
            blocked: HashSet::new(),
        }
    }

    /// Blocks every cell of the `width` x `height` rectangle whose corner with the
    /// smallest coordinates is `(x, y)`.
    ///
    /// The rectangle is clipped to the grid. Cells outside it can never be entered
    /// anyway (see `is_valid`), so they are not stored. Saturating arithmetic means
    /// huge extents cannot overflow or make this loop over cells that do not exist.
    pub fn block_rect(&mut self, x: usize, y: usize, width: usize, height: usize) {
        let x_end = x.saturating_add(width).min(self.width);
        let y_end = y.saturating_add(height).min(self.height);
        for py in y..y_end {
            for px in x..x_end {
                self.blocked.insert(Pos::new(px, py));
            }
        }
    }

    /// Re-opens `pos`; does nothing if it was not blocked.
    // Part of the public API; the allow presumably silences the lint when this
    // crate itself never calls it (e.g. in a binary target) — TODO confirm.
    #[allow(dead_code)]
    pub fn unblock(&mut self, pos: Pos) {
        self.blocked.remove(&pos);
    }

    /// Returns `true` if `pos` lies inside the grid and is not blocked.
    fn is_valid(&self, pos: Pos) -> bool {
        pos.x < self.width && pos.y < self.height && !self.blocked.contains(&pos)
    }

    /// Returns the enterable orthogonal neighbours of `pos`, in the order
    /// right, left, down (+y), up (-y).
    fn neighbors(&self, pos: Pos) -> Vec<Pos> {
        // checked_* rules out under/overflow at the edges; is_valid handles the
        // upper grid bounds and blocked cells.
        let candidates = [
            pos.x.checked_add(1).map(|x| Pos::new(x, pos.y)),
            pos.x.checked_sub(1).map(|x| Pos::new(x, pos.y)),
            pos.y.checked_add(1).map(|y| Pos::new(pos.x, y)),
            pos.y.checked_sub(1).map(|y| Pos::new(pos.x, y)),
        ];
        candidates
            .iter()
            .flatten()
            .copied()
            .filter(|&p| self.is_valid(p))
            .collect()
    }

    /// Manhattan distance between `from` and `to`.
    ///
    /// Moves are unit-cost and orthogonal, so this never overestimates the
    /// remaining cost (admissible) and is consistent.
    fn heuristic(from: Pos, to: Pos) -> usize {
        from.x.abs_diff(to.x) + from.y.abs_diff(to.y)
    }

    /// Finds a shortest path (fewest steps) from `start` to `goal` using A*.
    ///
    /// Returns the path including both endpoints; `Some(vec![start])` when
    /// `start == goal`. Returns `None` if either endpoint is outside the grid or
    /// blocked, or if `goal` is unreachable.
    pub fn find_path(&self, start: Pos, goal: Pos) -> Option<Vec<Pos>> {
        if !self.is_valid(start) || !self.is_valid(goal) {
            return None;
        }

        let mut open_set = BinaryHeap::new();
        let mut came_from: HashMap<Pos, Pos> = HashMap::new();
        let mut g_score: HashMap<Pos, usize> = HashMap::new();
        g_score.insert(start, 0);
        open_set.push(AStarNode {
            pos: start,
            f_score: Self::heuristic(start, goal),
        });

        while let Some(current) = open_set.pop() {
            if current.pos == goal {
                return Some(Self::reconstruct_path(&came_from, goal));
            }

            let current_g = g_score
                .get(&current.pos)
                .copied()
                .expect("every queued position has a recorded g-score");

            // Lazy deletion. A cheaper route to this position was found after this
            // entry was queued, so a better entry is (or was) in the heap.
            // Re-expanding this stale one would only duplicate work.
            if current.f_score > current_g + Self::heuristic(current.pos, goal) {
                continue;
            }

            let tentative_g = current_g + 1;
            for neighbor in self.neighbors(current.pos) {
                let known_g = g_score.get(&neighbor).copied().unwrap_or(usize::MAX);
                if tentative_g < known_g {
                    came_from.insert(neighbor, current.pos);
                    g_score.insert(neighbor, tentative_g);
                    open_set.push(AStarNode {
                        pos: neighbor,
                        f_score: tentative_g + Self::heuristic(neighbor, goal),
                    });
                }
            }
        }

        None
    }

    /// Walks `came_from` back from `goal` and returns the path in start-to-goal order.
    fn reconstruct_path(came_from: &HashMap<Pos, Pos>, goal: Pos) -> Vec<Pos> {
        let mut path = vec![goal];
        let mut pos = goal;
        while let Some(&prev) = came_from.get(&pos) {
            path.push(prev);
            pos = prev;
        }
        path.reverse();
        path
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `path` starts at `start`, ends at `goal`, and that every
    /// consecutive pair of cells is exactly one orthogonal step apart.
    fn assert_walk(path: &[Pos], start: Pos, goal: Pos) {
        assert_eq!(path.first(), Some(&start));
        assert_eq!(path.last(), Some(&goal));
        for step in path.windows(2) {
            let (a, b) = (step[0], step[1]);
            assert_eq!(
                a.x.abs_diff(b.x) + a.y.abs_diff(b.y),
                1,
                "non-adjacent step {:?} -> {:?}",
                a,
                b
            );
        }
    }

    #[test]
    fn test_simple_path() {
        let grid = PathGrid::new(10, 10);
        let (start, goal) = (Pos::new(0, 0), Pos::new(5, 5));
        let path = grid
            .find_path(start, goal)
            .expect("an open grid always has a path");
        assert_walk(&path, start, goal);
        // 10 moves (the Manhattan distance) plus the start cell.
        assert_eq!(path.len(), 11);
    }

    #[test]
    fn test_path_around_obstacle() {
        let mut grid = PathGrid::new(10, 10);
        // Wall at x = 5 covering y = 0..=7; the only crossings are at y = 8 and 9.
        for y in 0..8 {
            grid.block_rect(5, y, 1, 1);
        }
        let (start, goal) = (Pos::new(3, 5), Pos::new(7, 5));
        let path = grid
            .find_path(start, goal)
            .expect("the wall leaves a gap below it");
        assert_walk(&path, start, goal);
        assert!(path.iter().all(|p| p.x != 5 || p.y >= 8));
        // Shortest detour: 4 horizontal moves + 3 down + 3 up = 10 moves.
        assert_eq!(path.len(), 11);
    }

    #[test]
    fn test_no_path() {
        let mut grid = PathGrid::new(10, 10);
        for y in 0..10 {
            grid.block_rect(5, y, 1, 1);
        }
        let path = grid.find_path(Pos::new(3, 5), Pos::new(7, 5));
        assert!(path.is_none());
    }

    #[test]
    fn test_start_equals_goal() {
        let grid = PathGrid::new(3, 3);
        let p = Pos::new(1, 1);
        assert_eq!(grid.find_path(p, p), Some(vec![p]));
    }

    #[test]
    fn test_invalid_endpoints() {
        let mut grid = PathGrid::new(5, 5);
        grid.block_rect(2, 2, 1, 1);
        // Blocked goal, blocked start, and an out-of-bounds goal all yield None.
        assert!(grid.find_path(Pos::new(0, 0), Pos::new(2, 2)).is_none());
        assert!(grid.find_path(Pos::new(2, 2), Pos::new(0, 0)).is_none());
        assert!(grid.find_path(Pos::new(0, 0), Pos::new(5, 0)).is_none());
    }
}