1#![warn(missing_docs)]
28
29use crate::core::{
30 algebra::Vector3,
31 math::{self, PositionProvider},
32 visitor::prelude::*,
33};
34
35use std::{
36 cmp::Ordering,
37 collections::BinaryHeap,
38 fmt::{Display, Formatter},
39 ops::{Deref, DerefMut},
40};
41
/// Per-vertex data used by the A* search: where the vertex is, which vertices it links to, and
/// how costly it is to enter.
#[derive(Clone, Debug, Visit, PartialEq)]
pub struct VertexData {
    /// Position of the vertex; used both for edge costs and for the distance heuristic.
    pub position: Vector3<f32>,
    /// Indices (into the owning graph's vertex list) of the vertices reachable from this one.
    pub neighbours: Vec<u32>,
    /// Multiplier applied to the cost of every edge that ends at this vertex; larger values make
    /// the search less willing to travel through it. Skipped by the derived `Visit` impl.
    #[visit(skip)]
    pub g_penalty: f32,
}
54
55impl Default for VertexData {
56 fn default() -> Self {
57 Self {
58 position: Default::default(),
59 g_penalty: 1f32,
60 neighbours: Default::default(),
61 }
62 }
63}
64
65impl VertexData {
66 pub fn new(position: Vector3<f32>) -> Self {
68 Self {
69 position,
70 g_penalty: 1f32,
71 neighbours: Default::default(),
72 }
73 }
74}
75
/// Requirements for a type that can be stored as a vertex of a [`Graph`].
///
/// Implementors expose their [`VertexData`] through `Deref`/`DerefMut`, report a position via
/// [`PositionProvider`], can be default-constructed, and can be visited with [`Visit`].
pub trait VertexDataProvider:
    Deref<Target = VertexData> + DerefMut + PositionProvider + Default + Visit + 'static
{
}
82
/// Default vertex type for [`Graph`]: a thin wrapper around [`VertexData`] with no extra state.
#[derive(Default, PartialEq, Debug)]
pub struct GraphVertex {
    /// The wrapped vertex data.
    pub data: VertexData,
}
89
90impl GraphVertex {
91 pub fn new(position: Vector3<f32>) -> Self {
93 Self {
94 data: VertexData::new(position),
95 }
96 }
97}
98
99impl Deref for GraphVertex {
100 type Target = VertexData;
101
102 fn deref(&self) -> &Self::Target {
103 &self.data
104 }
105}
106
107impl DerefMut for GraphVertex {
108 fn deref_mut(&mut self) -> &mut Self::Target {
109 &mut self.data
110 }
111}
112
113impl PositionProvider for GraphVertex {
114 fn position(&self) -> Vector3<f32> {
115 self.data.position
116 }
117}
118
119impl Visit for GraphVertex {
120 fn visit(&mut self, name: &str, visitor: &mut Visitor) -> VisitResult {
121 self.data.visit(name, visitor)
122 }
123}
124
/// Marker impl: `GraphVertex` already satisfies every supertrait bound of [`VertexDataProvider`].
impl VertexDataProvider for GraphVertex {}
126
/// A directed graph of vertices that can be searched with A*.
///
/// Edges are stored on the vertices themselves as neighbour indices (see [`VertexData`]);
/// bidirectional links are simply two directed links.
#[derive(Clone, Debug, Visit, PartialEq)]
pub struct Graph<T>
where
    T: VertexDataProvider,
{
    /// All vertices of the graph; a vertex's position in this list is its index.
    pub vertices: Vec<T>,
    /// Upper bound on the number of vertex expansions a single path search may perform.
    /// A negative value removes the limit.
    pub max_search_iterations: i32,
}
150
/// Describes how much of the requested route a built path covers.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum PathKind {
    /// The path ends at the requested destination vertex.
    Full,
    /// The destination was not reached; the path leads to the best-ranked vertex the search
    /// expanded instead.
    Partial,
}
165
166fn heuristic(a: Vector3<f32>, b: Vector3<f32>) -> f32 {
167 (a - b).norm_squared()
168}
169
170impl<T: VertexDataProvider> Default for Graph<T> {
171 fn default() -> Self {
172 Self::new()
173 }
174}
175
176impl PositionProvider for VertexData {
177 fn position(&self) -> Vector3<f32> {
178 self.position
179 }
180}
181
/// Reasons why a path could not be (fully) built.
#[derive(Clone, Debug)]
pub enum PathError {
    /// A vertex index was out of range for the graph's vertex list; carries the offending index.
    InvalidIndex(usize),

    /// A vertex lists itself as its own neighbour; carries the index of that vertex.
    CyclicReferenceFound(usize),

    /// The search stopped because it reached the graph's iteration limit (the carried value).
    /// `build_indexed_path` still writes the best partial path it found before returning this.
    HitMaxSearchIterations(i32),

    /// The graph contains no vertices.
    Empty,
}
207
208impl Display for PathError {
209 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
210 match self {
211 PathError::InvalidIndex(v) => {
212 write!(f, "Invalid vertex index {v}.")
213 }
214 PathError::CyclicReferenceFound(v) => {
215 write!(f, "Cyclical reference was found {v}.")
216 }
217 PathError::HitMaxSearchIterations(v) => {
218 write!(
219 f,
220 "Maximum search iterations ({v}) hit, returning with partial path."
221 )
222 }
223 PathError::Empty => {
224 write!(f, "Graph was empty")
225 }
226 }
227 }
228}
229
/// A path explored by the A* search, together with its scores.
#[derive(Clone)]
pub struct PartialPath {
    /// Vertex indices from the start vertex to the last vertex reached, in travel order.
    vertices: Vec<usize>,
    /// Accumulated cost of travelling along `vertices`.
    g_score: f32,
    /// `g_score` plus the heuristic estimate from the last vertex to the destination.
    f_score: f32,
}
237
238impl Default for PartialPath {
239 fn default() -> Self {
240 Self {
241 vertices: Vec::new(),
242 g_score: f32::MAX,
243 f_score: f32::MAX,
244 }
245 }
246}
247
248impl Ord for PartialPath {
249 fn cmp(&self, other: &Self) -> Ordering {
251 (self.f_score.total_cmp(&other.f_score))
252 .then((self.f_score - self.g_score).total_cmp(&(other.f_score - other.g_score)))
253 .reverse()
254 }
255}
256
257impl PartialOrd for PartialPath {
258 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
260 Some(self.cmp(other))
261 }
262}
263
264impl PartialEq for PartialPath {
265 fn eq(&self, other: &Self) -> bool {
267 self.f_score == other.f_score && self.g_score == other.g_score
268 }
269}
270
/// Required by `Ord`, which `BinaryHeap<PartialPath>` needs for its ordering.
impl Eq for PartialPath {}
272
273impl PartialPath {
274 pub fn new(start: usize) -> Self {
276 Self {
277 vertices: vec![start],
278 g_score: 0f32,
279 f_score: f32::MAX,
280 }
281 }
282
283 pub fn clone_and_add(
285 &self,
286 new_vertex: usize,
287 new_g_score: f32,
288 new_f_score: f32,
289 ) -> PartialPath {
290 let mut clone = self.clone();
291 clone.vertices.push(new_vertex);
292 clone.g_score = new_g_score;
293 clone.f_score = new_f_score;
294
295 clone
296 }
297}
298
299impl<T: VertexDataProvider> Graph<T> {
300 pub fn new() -> Self {
302 Self {
303 vertices: Default::default(),
304 max_search_iterations: 1000i32,
305 }
306 }
307
308 pub fn set_vertices(&mut self, vertices: Vec<T>) {
312 self.vertices = vertices;
313 }
314
315 pub fn get_closest_vertex_to(&self, point: Vector3<f32>) -> Option<usize> {
321 math::get_closest_point(&self.vertices, point)
322 }
323
324 pub fn link_bidirect(&mut self, a: usize, b: usize) {
328 self.link_unidirect(a, b);
329 self.link_unidirect(b, a);
330 }
331
332 pub fn link_unidirect(&mut self, a: usize, b: usize) {
335 if let Some(vertex_a) = self.vertices.get_mut(a) {
336 if vertex_a.neighbours.iter().all(|n| *n != b as u32) {
337 vertex_a.neighbours.push(b as u32);
338 }
339 }
340 }
341
342 pub fn vertex(&self, index: usize) -> Option<&T> {
344 self.vertices.get(index)
345 }
346
347 pub fn vertex_mut(&mut self, index: usize) -> Option<&mut T> {
349 self.vertices.get_mut(index)
350 }
351
352 pub fn vertices(&self) -> &[T] {
354 &self.vertices
355 }
356
357 pub fn vertices_mut(&mut self) -> &mut [T] {
359 &mut self.vertices
360 }
361
362 pub fn add_vertex(&mut self, vertex: T) -> u32 {
364 let index = self.vertices.len();
365 self.vertices.push(vertex);
368 index as u32
369 }
370
371 pub fn pop_vertex(&mut self) -> Option<T> {
375 if self.vertices.is_empty() {
376 None
377 } else {
378 Some(self.remove_vertex(self.vertices.len() - 1))
379 }
380 }
381
382 pub fn remove_vertex(&mut self, index: usize) -> T {
386 for other_vertex in self.vertices.iter_mut() {
387 if let Some(position) = other_vertex
389 .neighbours
390 .iter()
391 .position(|n| *n == index as u32)
392 {
393 other_vertex.neighbours.remove(position);
394 }
395
396 for neighbour_index in other_vertex.neighbours.iter_mut() {
398 if *neighbour_index > index as u32 {
399 *neighbour_index -= 1;
400 }
401 }
402 }
403
404 self.vertices.remove(index)
405 }
406
407 pub fn insert_vertex(&mut self, index: u32, vertex: T) {
410 self.vertices.insert(index as usize, vertex);
411
412 for other_vertex in self.vertices.iter_mut() {
414 for neighbour_index in other_vertex.neighbours.iter_mut() {
415 if *neighbour_index >= index {
416 *neighbour_index += 1;
417 }
418 }
419 }
420 }
421
422 pub fn build_indexed_path(
435 &self,
436 from: usize,
437 to: usize,
438 path: &mut Vec<usize>,
439 ) -> Result<PathKind, PathError> {
440 path.clear();
441
442 if self.vertices.is_empty() {
443 return Err(PathError::Empty);
444 }
445
446 let end_pos = self
447 .vertices
448 .get(to)
449 .ok_or(PathError::InvalidIndex(to))?
450 .position;
451
452 if from == to {
454 path.push(to);
455 return Ok(PathKind::Full);
456 }
457
458 let mut searched_vertices = vec![false; self.vertices.len()];
460
461 let mut search_heap: BinaryHeap<PartialPath> = BinaryHeap::new();
463
464 search_heap.push(PartialPath::new(from));
466
467 let mut best_path = PartialPath::default();
469
470 let mut search_iteration = 0i32;
472
473 while self.max_search_iterations < 0 || search_iteration < self.max_search_iterations {
474 if search_heap.is_empty() {
476 break;
477 }
478
479 let current_path = search_heap.pop().unwrap();
481
482 let current_index = *current_path.vertices.last().unwrap();
483 let current_vertex = self
484 .vertices
485 .get(current_index)
486 .ok_or(PathError::InvalidIndex(current_index))?;
487
488 if current_path > best_path {
490 best_path = current_path.clone();
491
492 if current_index == to {
494 break;
495 }
496 }
497
498 for i in current_vertex.neighbours.iter() {
500 let neighbour_index = *i as usize;
501
502 if neighbour_index == current_index {
505 return Err(PathError::CyclicReferenceFound(current_index));
506 }
507
508 if searched_vertices[neighbour_index] {
510 continue;
511 }
512
513 let neighbour = self
514 .vertices
515 .get(neighbour_index)
516 .ok_or(PathError::InvalidIndex(neighbour_index))?;
517
518 let neighbour_g_score = current_path.g_score
519 + ((current_vertex.position - neighbour.position).norm_squared()
520 * neighbour.g_penalty);
521
522 let neighbour_f_score = neighbour_g_score + heuristic(neighbour.position, end_pos);
523
524 search_heap.push(current_path.clone_and_add(
525 neighbour_index,
526 neighbour_g_score,
527 neighbour_f_score,
528 ));
529 }
530
531 searched_vertices[current_index] = true;
533
534 search_iteration += 1;
535 }
536
537 path.clone_from(&best_path.vertices);
539 path.reverse();
540
541 if *path.first().unwrap() == to {
542 Ok(PathKind::Full)
543 } else if search_iteration == self.max_search_iterations - 1 {
544 Err(PathError::HitMaxSearchIterations(
545 self.max_search_iterations,
546 ))
547 } else {
548 Ok(PathKind::Partial)
549 }
550 }
551
552 pub fn build_positional_path(
565 &self,
566 from: usize,
567 to: usize,
568 path: &mut Vec<Vector3<f32>>,
569 ) -> Result<PathKind, PathError> {
570 path.clear();
571
572 let mut indices: Vec<usize> = Vec::new();
573 let path_kind = self.build_indexed_path(from, to, &mut indices)?;
574
575 for index in indices.iter() {
577 let vertex = self
578 .vertices
579 .get(*index)
580 .ok_or(PathError::InvalidIndex(*index))?;
581
582 path.push(vertex.position);
583 }
584
585 Ok(path_kind)
586 }
587
588 #[deprecated = "name is too ambiguous use build_positional_path instead"]
603 pub fn build(
604 &self,
605 from: usize,
606 to: usize,
607 path: &mut Vec<Vector3<f32>>,
608 ) -> Result<PathKind, PathError> {
609 self.build_positional_path(from, to, path)
610 }
611}
612
#[cfg(test)]
mod test {
    use crate::rand::Rng;
    use crate::utils::astar::PathError;
    use crate::{
        core::{algebra::Vector3, rand},
        utils::astar::{Graph, GraphVertex, PathKind},
    };
    use std::time::Instant;

    /// Builds `size * size` unlinked vertices on the XY plane, row by row, so the vertex at
    /// column `x`, row `y` has index `y * size + x`.
    fn grid_vertices(size: usize) -> Vec<GraphVertex> {
        (0..size)
            .flat_map(|y| {
                (0..size).map(move |x| GraphVertex::new(Vector3::new(x as f32, y as f32, 0.0)))
            })
            .collect()
    }

    /// Links every vertex with column and row below `size - 1` to its right and lower neighbour.
    ///
    /// The last row and column only receive links from their neighbours, and the far corner
    /// stays isolated. This is why the tests pick endpoints in `0..size - 1`.
    fn link_grid(pathfinder: &mut Graph<GraphVertex>, size: usize) {
        for y in 0..(size - 1) {
            for x in 0..(size - 1) {
                pathfinder.link_bidirect(y * size + x, y * size + x + 1);
                pathfinder.link_bidirect(y * size + x, (y + 1) * size + x);
            }
        }
    }

    /// Checks that `path` starts at `to` and ends at `from` (paths are returned end-to-start).
    fn assert_path_endpoints(
        pathfinder: &Graph<GraphVertex>,
        from: usize,
        to: usize,
        path: &[Vector3<f32>],
    ) {
        let to_position = pathfinder.vertex(to).unwrap().position;
        let from_position = pathfinder.vertex(from).unwrap().position;

        if path.len() > 1 {
            assert_eq!(*path.first().unwrap(), to_position);
            assert_eq!(*path.last().unwrap(), from_position);
        } else {
            let point = *path.first().unwrap();
            assert_eq!(point, to_position);
            assert_eq!(point, from_position);
        }
    }

    /// Checks that every pair of consecutive points in `path` is at most one diagonal grid
    /// step apart. `windows(2)` visits every consecutive pair (0-1, 1-2, 2-3, ...), not just
    /// disjoint ones.
    fn assert_steps_are_adjacent(path: &[Vector3<f32>]) {
        for pair in path.windows(2) {
            assert!(pair[0].metric_distance(&pair[1]) <= 2.0f32.sqrt());
        }
    }

    #[test]
    fn astar_random_points() {
        let mut pathfinder = Graph::<GraphVertex>::new();

        let mut path = Vec::new();
        assert!(pathfinder
            .build_positional_path(0, 0, &mut path)
            .is_err_and(|e| matches!(e, PathError::Empty)));
        assert!(path.is_empty());

        let size = 40;

        pathfinder.set_vertices(grid_vertices(size));

        assert!(pathfinder
            .build_positional_path(100000, 99999, &mut path)
            .is_err_and(|e| matches!(e, PathError::InvalidIndex(_))));

        link_grid(&mut pathfinder, size);

        let mut paths_count = 0;

        for _ in 0..1000 {
            let sx = rand::thread_rng().gen_range(0..(size - 1));
            let sy = rand::thread_rng().gen_range(0..(size - 1));

            let tx = rand::thread_rng().gen_range(0..(size - 1));
            let ty = rand::thread_rng().gen_range(0..(size - 1));

            let from = sy * size + sx;
            let to = ty * size + tx;

            assert!(pathfinder
                .build_positional_path(from, to, &mut path)
                .is_ok());
            assert!(!path.is_empty());

            if path.len() > 1 {
                paths_count += 1;
            }

            assert_path_endpoints(&pathfinder, from, to, &path);
            assert_steps_are_adjacent(&path);
        }

        assert!(paths_count > 0);
    }

    #[test]
    fn test_remove_vertex() {
        let mut pathfinder = Graph::<GraphVertex>::new();

        pathfinder.add_vertex(GraphVertex::new(Vector3::new(0.0, 0.0, 0.0)));
        pathfinder.add_vertex(GraphVertex::new(Vector3::new(1.0, 0.0, 0.0)));
        pathfinder.add_vertex(GraphVertex::new(Vector3::new(1.0, 1.0, 0.0)));

        pathfinder.link_bidirect(0, 1);
        pathfinder.link_bidirect(1, 2);
        pathfinder.link_bidirect(2, 0);

        pathfinder.remove_vertex(0);

        assert_eq!(pathfinder.vertex(0).unwrap().neighbours, vec![1]);
        assert_eq!(pathfinder.vertex(1).unwrap().neighbours, vec![0]);
        assert_eq!(pathfinder.vertex(2), None);

        pathfinder.remove_vertex(0);

        assert_eq!(pathfinder.vertex(0).unwrap().neighbours, Vec::<u32>::new());
        assert_eq!(pathfinder.vertex(1), None);
        assert_eq!(pathfinder.vertex(2), None);
    }

    #[test]
    fn test_insert_vertex() {
        let mut pathfinder = Graph::new();

        pathfinder.add_vertex(GraphVertex::new(Vector3::new(0.0, 0.0, 0.0)));
        pathfinder.add_vertex(GraphVertex::new(Vector3::new(1.0, 0.0, 0.0)));
        pathfinder.add_vertex(GraphVertex::new(Vector3::new(1.0, 1.0, 0.0)));

        pathfinder.link_bidirect(0, 1);
        pathfinder.link_bidirect(1, 2);
        pathfinder.link_bidirect(2, 0);

        assert_eq!(pathfinder.vertex(0).unwrap().neighbours, vec![1, 2]);
        assert_eq!(pathfinder.vertex(1).unwrap().neighbours, vec![0, 2]);
        assert_eq!(pathfinder.vertex(2).unwrap().neighbours, vec![1, 0]);

        pathfinder.insert_vertex(0, GraphVertex::new(Vector3::new(1.0, 1.0, 1.0)));

        assert_eq!(pathfinder.vertex(0).unwrap().neighbours, Vec::<u32>::new());
        assert_eq!(pathfinder.vertex(1).unwrap().neighbours, vec![2, 3]);
        assert_eq!(pathfinder.vertex(2).unwrap().neighbours, vec![1, 3]);
        assert_eq!(pathfinder.vertex(3).unwrap().neighbours, vec![2, 1]);
    }

    #[ignore = "takes multiple seconds to run"]
    #[test]
    fn astar_complete_grid_benchmark() {
        let start_time = Instant::now();
        let mut path = Vec::new();

        println!();
        for size in [10, 40, 100, 500] {
            println!("benchmarking grid size of: {size}^2");
            let setup_start_time = Instant::now();

            let mut pathfinder = Graph::new();
            pathfinder.set_vertices(grid_vertices(size));
            link_grid(&mut pathfinder, size);

            let setup_complete_time = Instant::now();
            println!(
                "setup in: {:?}",
                setup_complete_time.duration_since(setup_start_time)
            );

            for _ in 0..1000 {
                let sx = rand::thread_rng().gen_range(0..(size - 1));
                let sy = rand::thread_rng().gen_range(0..(size - 1));

                let tx = rand::thread_rng().gen_range(0..(size - 1));
                let ty = rand::thread_rng().gen_range(0..(size - 1));

                let from = sy * size + sx;
                let to = ty * size + tx;

                assert!(pathfinder
                    .build_positional_path(from, to, &mut path)
                    .is_ok());
                assert!(!path.is_empty());

                assert_path_endpoints(&pathfinder, from, to, &path);
                assert_steps_are_adjacent(&path);
            }
            println!("paths found in: {:?}", setup_complete_time.elapsed());
            println!(
                "Current size complete in: {:?}\n",
                setup_start_time.elapsed()
            );
        }
        println!("Total time: {:?}\n", start_time.elapsed());
    }

    #[ignore = "takes multiple seconds to run"]
    #[test]
    fn astar_island_benchmark() {
        let start_time = Instant::now();

        let size = 100;
        let mut path = Vec::new();
        let mut pathfinder = Graph::new();

        pathfinder.set_vertices(grid_vertices(size));

        // Same as `link_grid`, except no horizontal links cross from column `size / 2 - 1` to
        // column `size / 2`. This splits the grid into two unreachable halves.
        for y in 0..(size - 1) {
            for x in 0..(size - 1) {
                if x != ((size / 2) - 1) {
                    pathfinder.link_bidirect(y * size + x, y * size + x + 1);
                }
                pathfinder.link_bidirect(y * size + x, (y + 1) * size + x);
            }
        }

        let setup_complete_time = Instant::now();

        println!(
            "\nsetup in: {:?}",
            setup_complete_time.duration_since(start_time)
        );

        for _ in 0..1000 {
            // Start in the left half ...
            let sx = rand::thread_rng().gen_range(0..((size / 2) - 1));
            let sy = rand::thread_rng().gen_range(0..(size - 1));

            // ... and target the unreachable right half.
            let tx = rand::thread_rng().gen_range((size / 2)..(size - 1));
            let ty = rand::thread_rng().gen_range(0..(size - 1));

            let from = sy * size + sx;
            let to = ty * size + tx;

            let path_result = pathfinder.build_positional_path(from, to, &mut path);

            let is_result_expected = path_result.as_ref().is_ok_and(|k| k.eq(&PathKind::Partial))
                || path_result.is_err_and(|e| matches!(e, PathError::HitMaxSearchIterations(_)));

            assert!(is_result_expected);
            assert!(!path.is_empty());

            if path.len() > 1 {
                // The best partial path should end at the island's edge column.
                assert_eq!(path.first().unwrap().x as i32, ((size / 2) - 1) as i32);
                assert_eq!(
                    *path.last().unwrap(),
                    pathfinder.vertex(from).unwrap().position
                );
            } else {
                let point = *path.first().unwrap();
                assert_eq!(point, pathfinder.vertex(to).unwrap().position);
                assert_eq!(point, pathfinder.vertex(from).unwrap().position);
            }

            assert_steps_are_adjacent(&path);
        }

        println!("paths found in: {:?}", setup_complete_time.elapsed());
        println!("Total time: {:?}\n", start_time.elapsed());
    }

    #[ignore = "takes multiple seconds to run"]
    #[test]
    fn astar_backwards_travel_benchmark() {
        let start_time = Instant::now();

        let size = 100;
        let mut path = Vec::new();
        let mut pathfinder = Graph::new();

        pathfinder.set_vertices(grid_vertices(size));

        // Diagonal vertices (`x == y`, except on the first row) get no outgoing right/down links.
        // This forces detours that point away from the target.
        for y in 0..(size - 1) {
            for x in (0..(size - 1)).rev() {
                if y == 0 || x != y {
                    pathfinder.link_bidirect(y * size + x, y * size + x + 1);
                    pathfinder.link_bidirect(y * size + x, (y + 1) * size + x);
                }
            }
        }

        let setup_complete_time = Instant::now();

        println!(
            "\nsetup in: {:?}",
            setup_complete_time.duration_since(start_time)
        );

        for _ in 0..1000 {
            let from = (size / 2) * size + (size - 1);
            let to = (size - 1) * size + (size / 2);

            assert!(pathfinder
                .build_positional_path(from, to, &mut path)
                .is_ok());
            assert!(!path.is_empty());

            assert_path_endpoints(&pathfinder, from, to, &path);
            assert_steps_are_adjacent(&path);
        }

        println!("paths found in: {:?}", setup_complete_time.elapsed());
        println!("Total time: {:?}\n", start_time.elapsed());
    }
}