use ahash::{HashMap, HashMapExt};
use glam::IVec2;
use sark_grids::GridPoint;
use std::collections::hash_map::Entry;
use crate::{min_heap::MinHeap, pathmap::PathMap};
/// Reusable state for grid searches (`astar`, `dijkstra`, `bfs`).
///
/// The buffers live between searches so repeated queries can reuse them;
/// every search method clears them before it starts.
#[derive(Default)]
pub struct Pathfinder {
    /// Best known accumulated cost for each reached cell (`bfs` leaves this empty).
    costs: HashMap<IVec2, i32>,
    /// For each reached cell, the cell it was reached from.
    came_from: HashMap<IVec2, IVec2>,
    /// Cells waiting to be expanded, ordered by priority (lowest first).
    frontier: MinHeap,
    /// Scratch buffer holding the most recently built path.
    path: Vec<IVec2>,
}
impl Pathfinder {
    /// Creates an empty pathfinder; internal buffers grow on demand.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a pathfinder with internal buffers pre-allocated for roughly
    /// `capacity` visited cells.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            frontier: MinHeap::with_capacity(capacity),
            came_from: HashMap::with_capacity(capacity),
            costs: HashMap::with_capacity(capacity),
            // NOTE(review): presumably sized on the assumption that a path is
            // much shorter than the number of cells explored to find it.
            path: Vec::with_capacity(capacity / 4),
        }
    }

    /// Finds a path from `start` to `goal` with A* search.
    ///
    /// Step costs come from `map.cost` and the heuristic from `map.distance`.
    /// Returns the path including both endpoints (`[start]` when
    /// `start == goal`), or `None` if `goal` was not reached. The path also
    /// remains available through [`Pathfinder::path`].
    pub fn astar(
        &mut self,
        map: &impl PathMap,
        start: impl GridPoint,
        goal: impl GridPoint,
    ) -> Option<&[IVec2]> {
        self.clear();
        let start = start.to_ivec2();
        let goal = goal.to_ivec2();
        self.frontier.push(start, 0);
        self.costs.insert(start, 0);
        while let Some(curr) = self.frontier.pop() {
            if curr == goal {
                break;
            }
            // Every cell pushed onto the frontier has a recorded cost.
            let curr_cost = self.costs[&curr];
            for next in map.exits(curr) {
                let new_cost = curr_cost + map.cost(curr, next);
                if self.improve_cost(next, new_cost) {
                    self.frontier
                        .push(next, new_cost + map.distance(goal, next));
                    self.came_from.insert(next, curr);
                }
            }
        }
        self.build_path(start, goal)
    }

    /// Runs Dijkstra's algorithm over `map`.
    ///
    /// With `start = Some(s)` the search expands from `s` and stops once
    /// `goal` is popped; use [`Pathfinder::build_path`] to extract the route.
    /// With `start = None` the search starts at `goal` and expands every
    /// reachable cell. [`Pathfinder::costs`] then holds the cost accumulated
    /// from `goal` outward, and [`Pathfinder::came_from`] holds the search tree.
    pub fn dijkstra(
        &mut self,
        map: &impl PathMap,
        start: Option<impl GridPoint>,
        goal: impl GridPoint,
    ) {
        self.clear();
        let start = start.map(|s| s.to_ivec2());
        let goal = goal.to_ivec2();
        let origin = start.unwrap_or(goal);
        // Early exit only makes sense when there is a separate destination.
        let stop_at_goal = start.is_some();
        self.frontier.push(origin, 0);
        self.costs.insert(origin, 0);
        while let Some(curr) = self.frontier.pop() {
            if stop_at_goal && curr == goal {
                break;
            }
            let curr_cost = self.costs[&curr];
            for next in map.exits(curr) {
                let new_cost = curr_cost + map.cost(curr, next);
                if self.improve_cost(next, new_cost) {
                    self.frontier.push(next, new_cost);
                    self.came_from.insert(next, curr);
                }
            }
        }
    }

    /// Runs an unweighted breadth-first search over `map`.
    ///
    /// The `start` / `goal` semantics match [`Pathfinder::dijkstra`], but step
    /// costs are ignored and [`Pathfinder::costs`] is left empty. Only
    /// [`Pathfinder::came_from`] is filled in.
    pub fn bfs(&mut self, map: &impl PathMap, start: Option<impl GridPoint>, goal: impl GridPoint) {
        self.clear();
        let start = start.map(|s| s.to_ivec2());
        let goal = goal.to_ivec2();
        let origin = start.unwrap_or(goal);
        let stop_at_goal = start.is_some();
        // Strictly increasing priorities turn the min-heap into a FIFO queue.
        // (The frontier length is not usable here: it shrinks as cells are
        // popped, so deeper cells could be expanded before shallower ones.)
        let mut order: i32 = 0;
        self.frontier.push(origin, order);
        while let Some(curr) = self.frontier.pop() {
            if stop_at_goal && curr == goal {
                break;
            }
            for next in map.exits(curr) {
                // The origin is already visited and has no parent. Recording
                // one would put a cycle into `came_from`.
                if next == origin {
                    continue;
                }
                if let Entry::Vacant(slot) = self.came_from.entry(next) {
                    slot.insert(curr);
                    order += 1;
                    self.frontier.push(next, order);
                }
            }
        }
    }

    /// Reconstructs the path from `start` to `goal` by walking
    /// [`Pathfinder::came_from`] backwards from `goal`.
    ///
    /// Returns the path including both endpoints (`[start]` when
    /// `start == goal`). Returns `None` when `goal` does not lead back to
    /// `start` in the current search tree; in that case the stored path is
    /// left empty.
    pub fn build_path(&mut self, start: impl GridPoint, goal: impl GridPoint) -> Option<&[IVec2]> {
        let start = start.to_ivec2();
        let mut curr = goal.to_ivec2();
        self.path.clear();
        self.path.push(curr);
        while curr != start {
            match self.came_from.get(&curr) {
                // A cycle-free walk performs at most `came_from.len()` lookups.
                // Anything longer means `came_from` loops without reaching `start`.
                Some(&prev) if self.path.len() <= self.came_from.len() => {
                    self.path.push(prev);
                    curr = prev;
                }
                _ => {
                    self.path.clear();
                    return None;
                }
            }
        }
        self.path.reverse();
        Some(self.path.as_slice())
    }

    /// Records `cost` for `cell` if no cost is known yet or `cost` is cheaper.
    /// Returns `true` when the recorded cost changed.
    fn improve_cost(&mut self, cell: IVec2, cost: i32) -> bool {
        match self.costs.entry(cell) {
            Entry::Vacant(e) => {
                e.insert(cost);
                true
            }
            Entry::Occupied(mut e) if cost < *e.get() => {
                *e.get_mut() = cost;
                true
            }
            Entry::Occupied(_) => false,
        }
    }

    /// Empties every internal buffer (frontier, search tree, costs, path).
    pub fn clear(&mut self) {
        self.frontier.clear();
        self.came_from.clear();
        self.costs.clear();
        self.path.clear();
    }

    /// Cells that were given a predecessor by the last search. The search
    /// origin normally has no predecessor and so does not appear.
    pub fn visited(&self) -> impl Iterator<Item = &IVec2> {
        self.came_from.keys()
    }

    /// Search tree from the last search: maps each reached cell to the cell
    /// it was reached from.
    pub fn came_from(&self) -> &HashMap<IVec2, IVec2> {
        &self.came_from
    }

    /// Accumulated costs from the last `astar` or `dijkstra` search.
    pub fn costs(&self) -> &HashMap<IVec2, i32> {
        &self.costs
    }

    /// Path from the most recent successful `build_path` / `astar` call.
    /// Empty after [`Pathfinder::clear`] or a failed build.
    pub fn path(&self) -> &[IVec2] {
        &self.path
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::PathMap2d;

    /// Runs A* across an open 10x10 map. Asserts that the route has six cells,
    /// starts at `from` and ends at `to`.
    fn check_six_cell_route(from: [i32; 2], to: [i32; 2]) {
        let map = PathMap2d::new([10, 10]);
        let mut finder = Pathfinder::new();
        let route = finder.astar(&map, from, to).unwrap();
        assert_eq!(6, route.len());
        assert_eq!(from, route[0].to_array());
        assert_eq!(to, route[5].to_array());
    }

    #[test]
    fn right_test() {
        check_six_cell_route([0, 0], [5, 0]);
    }

    #[test]
    fn down_test() {
        check_six_cell_route([5, 5], [5, 0]);
    }

    #[test]
    fn up_test() {
        check_six_cell_route([5, 4], [5, 9]);
    }

    #[test]
    fn left_test() {
        check_six_cell_route([9, 5], [4, 5]);
    }
}