#![warn(missing_docs)]
use crate::core::{
algebra::Vector3,
math::{self, PositionProvider},
visitor::prelude::*,
};
use std::{
cmp::Ordering,
collections::BinaryHeap,
fmt::{Display, Formatter},
ops::{Deref, DerefMut},
};
/// Navigation data stored in every vertex of a [`Graph`]: its position, its outgoing links and
/// the cost multiplier applied when the path finder steps into it.
#[derive(Clone, Debug, Visit, PartialEq)]
pub struct VertexData {
/// Position of the vertex.
pub position: Vector3<f32>,
/// Indices (into [`Graph::vertices`]) of the vertices directly reachable from this one.
pub neighbours: Vec<u32>,
/// Multiplier applied to the squared edge length when a path moves *into* this vertex;
/// values above `1.0` make the vertex more expensive to traverse. Marked `#[visit(skip)]`,
/// so it is not handled by `Visit`.
#[visit(skip)]
pub g_penalty: f32,
}
impl Default for VertexData {
    /// Creates vertex data at the default position, with no neighbours and a neutral
    /// `g_penalty` of `1.0` (identical to [`VertexData::new`] with a default position).
    fn default() -> Self {
        Self::new(Default::default())
    }
}
impl VertexData {
    /// Creates vertex data located at `position`, without neighbours and with a neutral
    /// `g_penalty` of `1.0`.
    pub fn new(position: Vector3<f32>) -> Self {
        let neighbours = Vec::new();
        Self {
            neighbours,
            position,
            g_penalty: 1.0,
        }
    }
}
/// Abstraction over the vertex type stored in a [`Graph`].
///
/// Implementors expose their [`VertexData`] through `Deref`/`DerefMut`, report a position via
/// [`PositionProvider`], and must be default-constructible and visitable with [`Visit`].
pub trait VertexDataProvider:
Deref<Target = VertexData> + DerefMut + PositionProvider + Default + Visit + 'static
{
}
/// Default vertex type for [`Graph`]: a thin wrapper around [`VertexData`].
#[derive(Default, PartialEq, Debug)]
pub struct GraphVertex {
/// The navigation data of this vertex.
pub data: VertexData,
}
impl GraphVertex {
pub fn new(position: Vector3<f32>) -> Self {
Self {
data: VertexData::new(position),
}
}
}
impl Deref for GraphVertex {
type Target = VertexData;
fn deref(&self) -> &Self::Target {
&self.data
}
}
impl DerefMut for GraphVertex {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.data
}
}
impl PositionProvider for GraphVertex {
fn position(&self) -> Vector3<f32> {
self.data.position
}
}
impl Visit for GraphVertex {
    /// Delegates to the wrapped [`VertexData`] under the same `name`.
    fn visit(&mut self, name: &str, visitor: &mut Visitor) -> VisitResult {
        let data = &mut self.data;
        data.visit(name, visitor)
    }
}
// `GraphVertex` meets every supertrait bound of `VertexDataProvider` (Default, Deref,
// DerefMut, PositionProvider, Visit), so the marker impl needs no items.
impl VertexDataProvider for GraphVertex {}
/// A directed graph of vertices that can be searched for paths with A*.
///
/// Vertices are addressed by their index in [`Graph::vertices`]; links are stored in each
/// vertex's `neighbours` list.
#[derive(Clone, Debug, Visit, PartialEq)]
pub struct Graph<T>
where
T: VertexDataProvider,
{
/// All vertices of the graph.
pub vertices: Vec<T>,
/// Maximum number of search iterations a single path query may perform; a negative value
/// removes the limit.
pub max_search_iterations: i32,
}
/// Tells whether a built path reaches its destination.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum PathKind {
/// The path connects the start vertex with the destination vertex.
Full,
/// The destination was not reached; the path leads to the most promising explored vertex.
Partial,
}
/// Estimated remaining cost between `a` and `b`: their squared Euclidean distance.
///
/// It is added to `g` scores that sum squared edge lengths. Because `(x + y)^2 >= x^2 + y^2`,
/// this estimate can exceed the true remaining cost, so found paths are not guaranteed to be
/// the cheapest.
fn heuristic(a: Vector3<f32>, b: Vector3<f32>) -> f32 {
    let offset = a - b;
    offset.norm_squared()
}
impl<T: VertexDataProvider> Default for Graph<T> {
fn default() -> Self {
Self::new()
}
}
impl PositionProvider for VertexData {
    /// Returns the stored position.
    fn position(&self) -> Vector3<f32> {
        let Self { position, .. } = self;
        *position
    }
}
/// Errors that can occur while building a path.
#[derive(Clone, Debug)]
pub enum PathError {
/// A vertex index (start, destination or a stored neighbour) is out of bounds.
InvalidIndex(usize),
/// The vertex with the given index lists itself as one of its neighbours.
CyclicReferenceFound(usize),
/// The search stopped after the given iteration limit without reaching the destination.
HitMaxSearchIterations(i32),
/// The graph has no vertices.
Empty,
}
impl Display for PathError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
PathError::InvalidIndex(v) => {
write!(f, "Invalid vertex index {v}.")
}
PathError::CyclicReferenceFound(v) => {
write!(f, "Cyclical reference was found {v}.")
}
PathError::HitMaxSearchIterations(v) => {
write!(
f,
"Maximum search iterations ({v}) hit, returning with partial path."
)
}
PathError::Empty => {
write!(f, "Graph was empty")
}
}
}
}
/// A candidate path explored by the A* search, together with its scores.
///
/// Ordered so that the path with the lowest `f_score` is the greatest, which makes a
/// `BinaryHeap<PartialPath>` pop the most promising candidate first.
#[derive(Clone)]
pub struct PartialPath {
/// Vertex indices from the start vertex to the last explored one.
vertices: Vec<usize>,
/// Accumulated cost of travelling along `vertices`.
g_score: f32,
/// `g_score` plus the heuristic estimate of the remaining cost to the destination.
f_score: f32,
}
impl Default for PartialPath {
fn default() -> Self {
Self {
vertices: Vec::new(),
g_score: f32::MAX,
f_score: f32::MAX,
}
}
}
impl Ord for PartialPath {
    /// Orders paths so that the *lowest* `f_score` compares as the greatest, so that
    /// `BinaryHeap` (a max-heap) pops the most promising path first.
    ///
    /// Ties are broken the same way on the heuristic part of the score (`f_score - g_score`):
    /// the path estimated to be closer to the destination wins.
    fn cmp(&self, other: &Self) -> Ordering {
        let own_estimate = self.f_score - self.g_score;
        let other_estimate = other.f_score - other.g_score;
        other
            .f_score
            .total_cmp(&self.f_score)
            .then(other_estimate.total_cmp(&own_estimate))
    }
}
impl PartialOrd for PartialPath {
    /// Always `Some`: delegates to the total order defined by the [`Ord`] impl.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl PartialEq for PartialPath {
fn eq(&self, other: &Self) -> bool {
self.f_score == other.f_score && self.g_score == other.g_score
}
}
// Marker impl required so `PartialPath` can implement `Ord` and be stored in a `BinaryHeap`.
impl Eq for PartialPath {}
impl PartialPath {
pub fn new(start: usize) -> Self {
Self {
vertices: vec![start],
g_score: 0f32,
f_score: f32::MAX,
}
}
pub fn clone_and_add(
&self,
new_vertex: usize,
new_g_score: f32,
new_f_score: f32,
) -> PartialPath {
let mut clone = self.clone();
clone.vertices.push(new_vertex);
clone.g_score = new_g_score;
clone.f_score = new_f_score;
clone
}
}
impl<T: VertexDataProvider> Graph<T> {
pub fn new() -> Self {
Self {
vertices: Default::default(),
max_search_iterations: 1000i32,
}
}
pub fn set_vertices(&mut self, vertices: Vec<T>) {
self.vertices = vertices;
}
pub fn get_closest_vertex_to(&self, point: Vector3<f32>) -> Option<usize> {
math::get_closest_point(&self.vertices, point)
}
pub fn link_bidirect(&mut self, a: usize, b: usize) {
self.link_unidirect(a, b);
self.link_unidirect(b, a);
}
pub fn link_unidirect(&mut self, a: usize, b: usize) {
if let Some(vertex_a) = self.vertices.get_mut(a) {
if vertex_a.neighbours.iter().all(|n| *n != b as u32) {
vertex_a.neighbours.push(b as u32);
}
}
}
pub fn vertex(&self, index: usize) -> Option<&T> {
self.vertices.get(index)
}
pub fn vertex_mut(&mut self, index: usize) -> Option<&mut T> {
self.vertices.get_mut(index)
}
pub fn vertices(&self) -> &[T] {
&self.vertices
}
pub fn vertices_mut(&mut self) -> &mut [T] {
&mut self.vertices
}
pub fn add_vertex(&mut self, vertex: T) -> u32 {
let index = self.vertices.len();
self.vertices.push(vertex);
index as u32
}
pub fn pop_vertex(&mut self) -> Option<T> {
if self.vertices.is_empty() {
None
} else {
Some(self.remove_vertex(self.vertices.len() - 1))
}
}
pub fn remove_vertex(&mut self, index: usize) -> T {
for other_vertex in self.vertices.iter_mut() {
if let Some(position) = other_vertex
.neighbours
.iter()
.position(|n| *n == index as u32)
{
other_vertex.neighbours.remove(position);
}
for neighbour_index in other_vertex.neighbours.iter_mut() {
if *neighbour_index > index as u32 {
*neighbour_index -= 1;
}
}
}
self.vertices.remove(index)
}
pub fn insert_vertex(&mut self, index: u32, vertex: T) {
self.vertices.insert(index as usize, vertex);
for other_vertex in self.vertices.iter_mut() {
for neighbour_index in other_vertex.neighbours.iter_mut() {
if *neighbour_index >= index {
*neighbour_index += 1;
}
}
}
}
pub fn build_indexed_path(
&self,
from: usize,
to: usize,
path: &mut Vec<usize>,
) -> Result<PathKind, PathError> {
path.clear();
if self.vertices.is_empty() {
return Err(PathError::Empty);
}
let end_pos = self
.vertices
.get(to)
.ok_or(PathError::InvalidIndex(to))?
.position;
if from == to {
path.push(to);
return Ok(PathKind::Full);
}
let mut searched_vertices = vec![false; self.vertices.len()];
let mut search_heap: BinaryHeap<PartialPath> = BinaryHeap::new();
search_heap.push(PartialPath::new(from));
let mut best_path = PartialPath::default();
let mut search_iteration = 0i32;
while self.max_search_iterations < 0 || search_iteration < self.max_search_iterations {
if search_heap.is_empty() {
break;
}
let current_path = search_heap.pop().unwrap();
let current_index = *current_path.vertices.last().unwrap();
let current_vertex = self
.vertices
.get(current_index)
.ok_or(PathError::InvalidIndex(current_index))?;
if current_path > best_path {
best_path = current_path.clone();
if current_index == to {
break;
}
}
for i in current_vertex.neighbours.iter() {
let neighbour_index = *i as usize;
if neighbour_index == current_index {
return Err(PathError::CyclicReferenceFound(current_index));
}
if searched_vertices[neighbour_index] {
continue;
}
let neighbour = self
.vertices
.get(neighbour_index)
.ok_or(PathError::InvalidIndex(neighbour_index))?;
let neighbour_g_score = current_path.g_score
+ ((current_vertex.position - neighbour.position).norm_squared()
* neighbour.g_penalty);
let neighbour_f_score = neighbour_g_score + heuristic(neighbour.position, end_pos);
search_heap.push(current_path.clone_and_add(
neighbour_index,
neighbour_g_score,
neighbour_f_score,
));
}
searched_vertices[current_index] = true;
search_iteration += 1;
}
path.clone_from(&best_path.vertices);
path.reverse();
if *path.first().unwrap() == to {
Ok(PathKind::Full)
} else if search_iteration == self.max_search_iterations - 1 {
Err(PathError::HitMaxSearchIterations(
self.max_search_iterations,
))
} else {
Ok(PathKind::Partial)
}
}
pub fn build_positional_path(
&self,
from: usize,
to: usize,
path: &mut Vec<Vector3<f32>>,
) -> Result<PathKind, PathError> {
path.clear();
let mut indices: Vec<usize> = Vec::new();
let path_kind = self.build_indexed_path(from, to, &mut indices)?;
for index in indices.iter() {
let vertex = self
.vertices
.get(*index)
.ok_or(PathError::InvalidIndex(*index))?;
path.push(vertex.position);
}
Ok(path_kind)
}
#[deprecated = "name is too ambiguous use build_positional_path instead"]
pub fn build(
&self,
from: usize,
to: usize,
path: &mut Vec<Vector3<f32>>,
) -> Result<PathKind, PathError> {
self.build_positional_path(from, to, path)
}
}
#[cfg(test)]
mod test {
    use crate::rand::Rng;
    use crate::utils::astar::PathError;
    use crate::{
        core::{algebra::Vector3, rand},
        utils::astar::{Graph, GraphVertex, PathKind},
    };
    use std::time::Instant;

    /// Creates `size * size` unlinked vertices on a unit grid in the XY plane; the vertex at
    /// `(x, y)` is stored at index `y * size + x`.
    fn grid_vertices(size: usize) -> Vec<GraphVertex> {
        (0..size)
            .flat_map(|y| {
                (0..size).map(move |x| GraphVertex::new(Vector3::new(x as f32, y as f32, 0.0)))
            })
            .collect()
    }

    /// Links every vertex of a [`grid_vertices`] grid (except those in the last row and last
    /// column) to the next vertex in its row and in its column. The corner vertex
    /// `(size - 1, size - 1)` therefore stays unlinked.
    fn link_grid(pathfinder: &mut Graph<GraphVertex>, size: usize) {
        for y in 0..(size - 1) {
            for x in 0..(size - 1) {
                pathfinder.link_bidirect(y * size + x, y * size + x + 1);
                pathfinder.link_bidirect(y * size + x, (y + 1) * size + x);
            }
        }
    }

    /// Asserts that `path` (stored from the target back to the start) begins at `to` and ends
    /// at `from`. A single-point path must therefore equal both.
    fn assert_path_endpoints(
        pathfinder: &Graph<GraphVertex>,
        path: &[Vector3<f32>],
        from: usize,
        to: usize,
    ) {
        let first = *path.first().expect("path must not be empty");
        let last = *path.last().expect("path must not be empty");
        assert_eq!(first, pathfinder.vertex(to).unwrap().position);
        assert_eq!(last, pathfinder.vertex(from).unwrap().position);
    }

    /// Asserts that every pair of consecutive points in `path` is at most one grid diagonal
    /// apart, i.e. that the path never skips over vertices.
    fn assert_path_contiguous(path: &[Vector3<f32>]) {
        // `windows(2)` checks every consecutive pair; `chunks(2)` would skip (1, 2), (3, 4)...
        for pair in path.windows(2) {
            assert!(pair[0].metric_distance(&pair[1]) <= 2.0f32.sqrt());
        }
    }

    #[test]
    fn astar_random_points() {
        let mut pathfinder = Graph::<GraphVertex>::new();
        let mut path = Vec::new();
        assert!(pathfinder
            .build_positional_path(0, 0, &mut path)
            .is_err_and(|e| matches!(e, PathError::Empty)));
        assert!(path.is_empty());

        let size = 40;
        pathfinder.set_vertices(grid_vertices(size));
        assert!(pathfinder
            .build_positional_path(100000, 99999, &mut path)
            .is_err_and(|e| matches!(e, PathError::InvalidIndex(_))));
        link_grid(&mut pathfinder, size);

        let mut paths_count = 0;
        for _ in 0..1000 {
            // Coordinates stay below `size - 1`, so the unlinked corner is never chosen.
            let sx = rand::thread_rng().gen_range(0..(size - 1));
            let sy = rand::thread_rng().gen_range(0..(size - 1));
            let tx = rand::thread_rng().gen_range(0..(size - 1));
            let ty = rand::thread_rng().gen_range(0..(size - 1));
            let from = sy * size + sx;
            let to = ty * size + tx;
            assert!(pathfinder
                .build_positional_path(from, to, &mut path)
                .is_ok());
            assert!(!path.is_empty());
            if path.len() > 1 {
                paths_count += 1;
            }
            assert_path_endpoints(&pathfinder, &path, from, to);
            assert_path_contiguous(&path);
        }
        assert!(paths_count > 0);
    }

    /// Removing a vertex drops links to it and shifts higher neighbour indices down.
    #[test]
    fn test_remove_vertex() {
        let mut pathfinder = Graph::<GraphVertex>::new();
        pathfinder.add_vertex(GraphVertex::new(Vector3::new(0.0, 0.0, 0.0)));
        pathfinder.add_vertex(GraphVertex::new(Vector3::new(1.0, 0.0, 0.0)));
        pathfinder.add_vertex(GraphVertex::new(Vector3::new(1.0, 1.0, 0.0)));
        pathfinder.link_bidirect(0, 1);
        pathfinder.link_bidirect(1, 2);
        pathfinder.link_bidirect(2, 0);
        pathfinder.remove_vertex(0);
        assert_eq!(pathfinder.vertex(0).unwrap().neighbours, vec![1]);
        assert_eq!(pathfinder.vertex(1).unwrap().neighbours, vec![0]);
        assert_eq!(pathfinder.vertex(2), None);
        pathfinder.remove_vertex(0);
        assert_eq!(pathfinder.vertex(0).unwrap().neighbours, Vec::<u32>::new());
        assert_eq!(pathfinder.vertex(1), None);
        assert_eq!(pathfinder.vertex(2), None);
    }

    /// Inserting a vertex shifts neighbour indices at or above the insertion point up.
    #[test]
    fn test_insert_vertex() {
        let mut pathfinder = Graph::new();
        pathfinder.add_vertex(GraphVertex::new(Vector3::new(0.0, 0.0, 0.0)));
        pathfinder.add_vertex(GraphVertex::new(Vector3::new(1.0, 0.0, 0.0)));
        pathfinder.add_vertex(GraphVertex::new(Vector3::new(1.0, 1.0, 0.0)));
        pathfinder.link_bidirect(0, 1);
        pathfinder.link_bidirect(1, 2);
        pathfinder.link_bidirect(2, 0);
        assert_eq!(pathfinder.vertex(0).unwrap().neighbours, vec![1, 2]);
        assert_eq!(pathfinder.vertex(1).unwrap().neighbours, vec![0, 2]);
        assert_eq!(pathfinder.vertex(2).unwrap().neighbours, vec![1, 0]);
        pathfinder.insert_vertex(0, GraphVertex::new(Vector3::new(1.0, 1.0, 1.0)));
        assert_eq!(pathfinder.vertex(0).unwrap().neighbours, Vec::<u32>::new());
        assert_eq!(pathfinder.vertex(1).unwrap().neighbours, vec![2, 3]);
        assert_eq!(pathfinder.vertex(2).unwrap().neighbours, vec![1, 3]);
        assert_eq!(pathfinder.vertex(3).unwrap().neighbours, vec![2, 1]);
    }

    #[ignore = "takes multiple seconds to run"]
    #[test]
    fn astar_complete_grid_benchmark() {
        let start_time = Instant::now();
        let mut path = Vec::new();
        println!();
        for size in [10, 40, 100, 500] {
            println!("benchmarking grid size of: {size}^2");
            let setup_start_time = Instant::now();
            let mut pathfinder = Graph::new();
            pathfinder.set_vertices(grid_vertices(size));
            link_grid(&mut pathfinder, size);
            let setup_complete_time = Instant::now();
            println!(
                "setup in: {:?}",
                setup_complete_time.duration_since(setup_start_time)
            );
            for _ in 0..1000 {
                let sx = rand::thread_rng().gen_range(0..(size - 1));
                let sy = rand::thread_rng().gen_range(0..(size - 1));
                let tx = rand::thread_rng().gen_range(0..(size - 1));
                let ty = rand::thread_rng().gen_range(0..(size - 1));
                let from = sy * size + sx;
                let to = ty * size + tx;
                assert!(pathfinder
                    .build_positional_path(from, to, &mut path)
                    .is_ok());
                assert!(!path.is_empty());
                assert_path_endpoints(&pathfinder, &path, from, to);
                assert_path_contiguous(&path);
            }
            println!("paths found in: {:?}", setup_complete_time.elapsed());
            println!(
                "Current size complete in: {:?}\n",
                setup_start_time.elapsed()
            );
        }
        println!("Total time: {:?}\n", start_time.elapsed());
    }

    /// The grid is split into two halves with no horizontal link across the middle column, so
    /// every query from the left half to the right half must end with a partial path.
    #[ignore = "takes multiple seconds to run"]
    #[test]
    fn astar_island_benchmark() {
        let start_time = Instant::now();
        let size = 100;
        let mut path = Vec::new();
        let mut pathfinder = Graph::new();
        pathfinder.set_vertices(grid_vertices(size));
        for y in 0..(size - 1) {
            for x in 0..(size - 1) {
                if x != ((size / 2) - 1) {
                    pathfinder.link_bidirect(y * size + x, y * size + x + 1);
                }
                pathfinder.link_bidirect(y * size + x, (y + 1) * size + x);
            }
        }
        let setup_complete_time = Instant::now();
        println!(
            "\nsetup in: {:?}",
            setup_complete_time.duration_since(start_time)
        );
        for _ in 0..1000 {
            let sx = rand::thread_rng().gen_range(0..((size / 2) - 1));
            let sy = rand::thread_rng().gen_range(0..(size - 1));
            let tx = rand::thread_rng().gen_range((size / 2)..(size - 1));
            let ty = rand::thread_rng().gen_range(0..(size - 1));
            let from = sy * size + sx;
            let to = ty * size + tx;
            let path_result = pathfinder.build_positional_path(from, to, &mut path);
            let is_result_expected = path_result.as_ref().is_ok_and(|k| k.eq(&PathKind::Partial))
                || path_result.is_err_and(|e| matches!(e, PathError::HitMaxSearchIterations(_)));
            assert!(is_result_expected);
            assert!(!path.is_empty());
            if path.len() > 1 {
                assert_eq!(path.first().unwrap().x as i32, ((size / 2) - 1) as i32);
                assert_eq!(
                    *path.last().unwrap(),
                    pathfinder.vertex(from).unwrap().position
                );
            } else {
                let point = *path.first().unwrap();
                assert_eq!(point, pathfinder.vertex(to).unwrap().position);
                assert_eq!(point, pathfinder.vertex(from).unwrap().position);
            }
            assert_path_contiguous(&path);
        }
        println!("paths found in: {:?}", setup_complete_time.elapsed());
        println!("Total time: {:?}\n", start_time.elapsed());
    }

    /// Links are omitted on the diagonal `x == y` (except in the first row) and created in
    /// reverse column order; the path must still connect both endpoints.
    #[ignore = "takes multiple seconds to run"]
    #[test]
    fn astar_backwards_travel_benchmark() {
        let start_time = Instant::now();
        let size = 100;
        let mut path = Vec::new();
        let mut pathfinder = Graph::new();
        pathfinder.set_vertices(grid_vertices(size));
        for y in 0..(size - 1) {
            for x in (0..(size - 1)).rev() {
                if y == 0 || x != y {
                    pathfinder.link_bidirect(y * size + x, y * size + x + 1);
                    pathfinder.link_bidirect(y * size + x, (y + 1) * size + x);
                }
            }
        }
        let setup_complete_time = Instant::now();
        println!(
            "\nsetup in: {:?}",
            setup_complete_time.duration_since(start_time)
        );
        for _ in 0..1000 {
            let from = (size / 2) * size + (size - 1);
            let to = (size - 1) * size + (size / 2);
            assert!(pathfinder
                .build_positional_path(from, to, &mut path)
                .is_ok());
            assert!(!path.is_empty());
            assert_path_endpoints(&pathfinder, &path, from, to);
            assert_path_contiguous(&path);
        }
        println!("paths found in: {:?}", setup_complete_time.elapsed());
        println!("Total time: {:?}\n", start_time.elapsed());
    }
}