1#![warn(missing_docs)]
28
29use crate::core::{
30 algebra::Vector3,
31 math::{self, PositionProvider},
32 visitor::prelude::*,
33};
34
35use std::{
36 cmp::Ordering,
37 collections::BinaryHeap,
38 fmt::{Display, Formatter},
39 ops::{Deref, DerefMut},
40};
41
/// Per-vertex data used by the A* [`Graph`]: a position, outgoing links and a travel penalty.
#[derive(Clone, Debug, Visit, PartialEq)]
pub struct VertexData {
    /// Position of the vertex. Edge costs and the search heuristic are computed from it.
    pub position: Vector3<f32>,
    /// Indices of the vertices reachable from this one. Links are directed; use
    /// `Graph::link_bidirect` to create a link in both directions.
    pub neighbours: Vec<u32>,
    /// Multiplier applied to the cost of travelling *into* this vertex; `1.0` adds no
    /// extra cost. Skipped by `Visit` serialization.
    #[visit(skip)]
    pub g_penalty: f32,
}
54
55impl Default for VertexData {
56 fn default() -> Self {
57 Self {
58 position: Default::default(),
59 g_penalty: 1f32,
60 neighbours: Default::default(),
61 }
62 }
63}
64
65impl VertexData {
66 pub fn new(position: Vector3<f32>) -> Self {
68 Self {
69 position,
70 g_penalty: 1f32,
71 neighbours: Default::default(),
72 }
73 }
74}
75
/// Requirements for a vertex type stored in a [`Graph`].
///
/// Implementors dereference (mutably) to [`VertexData`], from which the path search reads
/// positions, neighbour links and penalties. They also implement `PositionProvider`,
/// which `Graph::get_closest_vertex_to` presumably relies on via `math::get_closest_point`.
pub trait VertexDataProvider: Deref<Target = VertexData> + DerefMut + PositionProvider {}
79
80#[derive(Default, PartialEq, Debug)]
82pub struct GraphVertex {
83 pub data: VertexData,
85}
86
87impl GraphVertex {
88 pub fn new(position: Vector3<f32>) -> Self {
90 Self {
91 data: VertexData::new(position),
92 }
93 }
94}
95
96impl Deref for GraphVertex {
97 type Target = VertexData;
98
99 fn deref(&self) -> &Self::Target {
100 &self.data
101 }
102}
103
104impl DerefMut for GraphVertex {
105 fn deref_mut(&mut self) -> &mut Self::Target {
106 &mut self.data
107 }
108}
109
110impl PositionProvider for GraphVertex {
111 fn position(&self) -> Vector3<f32> {
112 self.data.position
113 }
114}
115
116impl Visit for GraphVertex {
117 fn visit(&mut self, name: &str, visitor: &mut Visitor) -> VisitResult {
118 self.data.visit(name, visitor)
119 }
120}
121
// `GraphVertex` satisfies every supertrait (`Deref`/`DerefMut` to `VertexData` and
// `PositionProvider`) above, so it can be stored directly in a `Graph<GraphVertex>`.
impl VertexDataProvider for GraphVertex {}
123
/// A graph of vertices connected by directed links that can be searched with A*.
///
/// See [`Graph::build_indexed_path`] and [`Graph::build_positional_path`].
#[derive(Clone, Debug, Visit, PartialEq)]
pub struct Graph<T>
where
    T: VertexDataProvider,
{
    /// Vertices of the graph. Neighbour links are indices into this list.
    pub vertices: Vec<T>,
    /// Maximum number of search-loop iterations (each takes one candidate path off the
    /// search heap) before a search gives up. A negative value removes the limit;
    /// `Graph::new` sets it to `1000`.
    pub max_search_iterations: i32,
}
147
/// Outcome of a path search that did not fail.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum PathKind {
    /// The path reaches the requested destination.
    Full,
    /// The destination was not reached; the path leads to the most promising vertex the
    /// search found instead.
    Partial,
}
162
163fn heuristic(a: Vector3<f32>, b: Vector3<f32>) -> f32 {
164 (a - b).norm_squared()
165}
166
167impl<T: VertexDataProvider> Default for Graph<T> {
168 fn default() -> Self {
169 Self::new()
170 }
171}
172
173impl PositionProvider for VertexData {
174 fn position(&self) -> Vector3<f32> {
175 self.position
176 }
177}
178
/// Reasons a path search can fail.
#[derive(Clone, Debug)]
pub enum PathError {
    /// A vertex index (the start, the destination or a stored neighbour link) is out of
    /// bounds.
    InvalidIndex(usize),

    /// The vertex with the given index lists itself as a neighbour.
    CyclicReferenceFound(usize),

    /// The search reached `Graph::max_search_iterations` (the carried value) before it
    /// could finish.
    HitMaxSearchIterations(i32),

    /// The graph has no vertices.
    Empty,
}

impl Display for PathError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidIndex(v) => write!(f, "Invalid vertex index {v}."),
            Self::CyclicReferenceFound(v) => write!(f, "Cyclical reference was found {v}."),
            Self::HitMaxSearchIterations(v) => write!(
                f,
                "Maximum search iterations ({v}) hit, returning with partial path."
            ),
            Self::Empty => write!(f, "Graph was empty"),
        }
    }
}

// `Debug` and `Display` are already implemented, so the default `Error` methods suffice.
// This lets callers propagate `PathError` as `Box<dyn std::error::Error>`.
impl std::error::Error for PathError {}
226
/// A candidate route explored by the A* search, together with its scores.
///
/// The ordering is reversed so that a [`BinaryHeap`] (a max-heap) yields the path with the
/// *lowest* `f_score` first. Ties go to the path with the smaller heuristic part
/// (`f_score - g_score`), i.e. the one estimated to be closer to the destination.
#[derive(Clone, Debug)]
pub struct PartialPath {
    /// Visited vertex indices, from the start vertex to the current one.
    vertices: Vec<usize>,
    /// Accumulated travel cost from the start.
    g_score: f32,
    /// `g_score` plus the heuristic estimate of the remaining cost.
    f_score: f32,
}

impl Default for PartialPath {
    /// An empty path with the worst possible scores.
    fn default() -> Self {
        Self {
            vertices: Vec::new(),
            g_score: f32::MAX,
            f_score: f32::MAX,
        }
    }
}

impl Ord for PartialPath {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed on purpose: a lower f-score, then a lower heuristic part, is "greater".
        self.f_score
            .total_cmp(&other.f_score)
            .then_with(|| {
                (self.f_score - self.g_score).total_cmp(&(other.f_score - other.g_score))
            })
            .reverse()
    }
}

impl PartialOrd for PartialPath {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl PartialEq for PartialPath {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `==` agrees with `Ord`, as `Eq`/`Ord` require.
        // `f32`'s own `==` disagrees with `total_cmp` for NaN and signed zeros.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for PartialPath {}

impl PartialPath {
    /// Creates a path containing only `start`, with zero travel cost and an unknown
    /// (`f32::MAX`) total estimate.
    pub fn new(start: usize) -> Self {
        Self {
            vertices: vec![start],
            g_score: 0f32,
            f_score: f32::MAX,
        }
    }

    /// Returns a copy of this path extended by `new_vertex`, carrying the given scores.
    pub fn clone_and_add(
        &self,
        new_vertex: usize,
        new_g_score: f32,
        new_f_score: f32,
    ) -> PartialPath {
        // Reserve room for the extra vertex up front. Cloning the vector and then pushing
        // usually allocates twice, because a clone has no spare capacity.
        let mut vertices = Vec::with_capacity(self.vertices.len() + 1);
        vertices.extend_from_slice(&self.vertices);
        vertices.push(new_vertex);

        PartialPath {
            vertices,
            g_score: new_g_score,
            f_score: new_f_score,
        }
    }
}
295
296impl<T: VertexDataProvider> Graph<T> {
297 pub fn new() -> Self {
299 Self {
300 vertices: Default::default(),
301 max_search_iterations: 1000i32,
302 }
303 }
304
305 pub fn set_vertices(&mut self, vertices: Vec<T>) {
309 self.vertices = vertices;
310 }
311
312 pub fn get_closest_vertex_to(&self, point: Vector3<f32>) -> Option<usize> {
318 math::get_closest_point(&self.vertices, point)
319 }
320
321 pub fn link_bidirect(&mut self, a: usize, b: usize) {
325 self.link_unidirect(a, b);
326 self.link_unidirect(b, a);
327 }
328
329 pub fn link_unidirect(&mut self, a: usize, b: usize) {
332 if let Some(vertex_a) = self.vertices.get_mut(a) {
333 if vertex_a.neighbours.iter().all(|n| *n != b as u32) {
334 vertex_a.neighbours.push(b as u32);
335 }
336 }
337 }
338
339 pub fn vertex(&self, index: usize) -> Option<&T> {
341 self.vertices.get(index)
342 }
343
344 pub fn vertex_mut(&mut self, index: usize) -> Option<&mut T> {
346 self.vertices.get_mut(index)
347 }
348
349 pub fn vertices(&self) -> &[T] {
351 &self.vertices
352 }
353
354 pub fn vertices_mut(&mut self) -> &mut [T] {
356 &mut self.vertices
357 }
358
359 pub fn add_vertex(&mut self, vertex: T) -> u32 {
361 let index = self.vertices.len();
362 self.vertices.push(vertex);
365 index as u32
366 }
367
368 pub fn pop_vertex(&mut self) -> Option<T> {
372 if self.vertices.is_empty() {
373 None
374 } else {
375 Some(self.remove_vertex(self.vertices.len() - 1))
376 }
377 }
378
379 pub fn remove_vertex(&mut self, index: usize) -> T {
383 for other_vertex in self.vertices.iter_mut() {
384 if let Some(position) = other_vertex
386 .neighbours
387 .iter()
388 .position(|n| *n == index as u32)
389 {
390 other_vertex.neighbours.remove(position);
391 }
392
393 for neighbour_index in other_vertex.neighbours.iter_mut() {
395 if *neighbour_index > index as u32 {
396 *neighbour_index -= 1;
397 }
398 }
399 }
400
401 self.vertices.remove(index)
402 }
403
404 pub fn insert_vertex(&mut self, index: u32, vertex: T) {
407 self.vertices.insert(index as usize, vertex);
408
409 for other_vertex in self.vertices.iter_mut() {
411 for neighbour_index in other_vertex.neighbours.iter_mut() {
412 if *neighbour_index >= index {
413 *neighbour_index += 1;
414 }
415 }
416 }
417 }
418
419 pub fn build_indexed_path(
432 &self,
433 from: usize,
434 to: usize,
435 path: &mut Vec<usize>,
436 ) -> Result<PathKind, PathError> {
437 path.clear();
438
439 if self.vertices.is_empty() {
440 return Err(PathError::Empty);
441 }
442
443 let end_pos = self
444 .vertices
445 .get(to)
446 .ok_or(PathError::InvalidIndex(to))?
447 .position;
448
449 if from == to {
451 path.push(to);
452 return Ok(PathKind::Full);
453 }
454
455 let mut searched_vertices = vec![false; self.vertices.len()];
457
458 let mut search_heap: BinaryHeap<PartialPath> = BinaryHeap::new();
460
461 search_heap.push(PartialPath::new(from));
463
464 let mut best_path = PartialPath::default();
466
467 let mut search_iteration = 0i32;
469
470 while self.max_search_iterations < 0 || search_iteration < self.max_search_iterations {
471 if search_heap.is_empty() {
473 break;
474 }
475
476 let current_path = search_heap.pop().unwrap();
478
479 let current_index = *current_path.vertices.last().unwrap();
480 let current_vertex = self
481 .vertices
482 .get(current_index)
483 .ok_or(PathError::InvalidIndex(current_index))?;
484
485 if current_path > best_path {
487 best_path = current_path.clone();
488
489 if current_index == to {
491 break;
492 }
493 }
494
495 for i in current_vertex.neighbours.iter() {
497 let neighbour_index = *i as usize;
498
499 if neighbour_index == current_index {
502 return Err(PathError::CyclicReferenceFound(current_index));
503 }
504
505 if searched_vertices[neighbour_index] {
507 continue;
508 }
509
510 let neighbour = self
511 .vertices
512 .get(neighbour_index)
513 .ok_or(PathError::InvalidIndex(neighbour_index))?;
514
515 let neighbour_g_score = current_path.g_score
516 + ((current_vertex.position - neighbour.position).norm_squared()
517 * neighbour.g_penalty);
518
519 let neighbour_f_score = neighbour_g_score + heuristic(neighbour.position, end_pos);
520
521 search_heap.push(current_path.clone_and_add(
522 neighbour_index,
523 neighbour_g_score,
524 neighbour_f_score,
525 ));
526 }
527
528 searched_vertices[current_index] = true;
530
531 search_iteration += 1;
532 }
533
534 path.clone_from(&best_path.vertices);
536 path.reverse();
537
538 if *path.first().unwrap() == to {
539 Ok(PathKind::Full)
540 } else if search_iteration == self.max_search_iterations - 1 {
541 Err(PathError::HitMaxSearchIterations(
542 self.max_search_iterations,
543 ))
544 } else {
545 Ok(PathKind::Partial)
546 }
547 }
548
549 pub fn build_positional_path(
562 &self,
563 from: usize,
564 to: usize,
565 path: &mut Vec<Vector3<f32>>,
566 ) -> Result<PathKind, PathError> {
567 path.clear();
568
569 let mut indices: Vec<usize> = Vec::new();
570 let path_kind = self.build_indexed_path(from, to, &mut indices)?;
571
572 for index in indices.iter() {
574 let vertex = self
575 .vertices
576 .get(*index)
577 .ok_or(PathError::InvalidIndex(*index))?;
578
579 path.push(vertex.position);
580 }
581
582 Ok(path_kind)
583 }
584
585 #[deprecated = "name is too ambiguous use build_positional_path instead"]
600 pub fn build(
601 &self,
602 from: usize,
603 to: usize,
604 path: &mut Vec<Vector3<f32>>,
605 ) -> Result<PathKind, PathError> {
606 self.build_positional_path(from, to, path)
607 }
608}
609
#[cfg(test)]
mod test {
    use crate::rand::Rng;
    use crate::utils::astar::PathError;
    use crate::{
        core::{algebra::Vector3, rand},
        utils::astar::{Graph, GraphVertex, PathKind},
    };
    use std::time::Instant;

    /// Builds `size * size` unlinked vertices on a unit grid in the XY plane. They are stored
    /// row by row, so the vertex at `(x, y)` has index `y * size + x`.
    fn grid_vertices(size: usize) -> Vec<GraphVertex> {
        let mut vertices = Vec::with_capacity(size * size);
        for y in 0..size {
            for x in 0..size {
                vertices.push(GraphVertex::new(Vector3::new(x as f32, y as f32, 0.0)));
            }
        }
        vertices
    }

    /// Links every vertex with `x, y < size - 1` to its right and lower neighbour.
    fn link_grid(pathfinder: &mut Graph<GraphVertex>, size: usize) {
        for y in 0..(size - 1) {
            for x in 0..(size - 1) {
                pathfinder.link_bidirect(y * size + x, y * size + x + 1);
                pathfinder.link_bidirect(y * size + x, (y + 1) * size + x);
            }
        }
    }

    /// Picks a random vertex index with `x, y < size - 1`.
    fn random_grid_index(size: usize) -> usize {
        let mut rng = rand::thread_rng();
        let x = rng.gen_range(0..(size - 1));
        let y = rng.gen_range(0..(size - 1));
        y * size + x
    }

    /// Asserts that every step of `path` moves between adjacent grid points.
    ///
    /// Overlapping windows check *every* consecutive pair; `chunks(2)` would skip every
    /// other step.
    fn assert_steps_are_adjacent(path: &[Vector3<f32>]) {
        for pair in path.windows(2) {
            assert!(pair[0].metric_distance(&pair[1]) <= 2.0f32.sqrt());
        }
    }

    /// Asserts that `path`, which is stored from `to` back to `from`, connects the two
    /// vertices through adjacent grid points.
    fn assert_full_path(
        pathfinder: &Graph<GraphVertex>,
        from: usize,
        to: usize,
        path: &[Vector3<f32>],
    ) {
        assert!(!path.is_empty());
        assert_eq!(
            *path.first().unwrap(),
            pathfinder.vertex(to).unwrap().position
        );
        assert_eq!(
            *path.last().unwrap(),
            pathfinder.vertex(from).unwrap().position
        );
        assert_steps_are_adjacent(path);
    }

    #[test]
    fn astar_random_points() {
        let mut pathfinder = Graph::<GraphVertex>::new();

        let mut path = Vec::new();
        assert!(pathfinder
            .build_positional_path(0, 0, &mut path)
            .is_err_and(|e| matches!(e, PathError::Empty)));
        assert!(path.is_empty());

        let size = 40;
        pathfinder.set_vertices(grid_vertices(size));

        assert!(pathfinder
            .build_positional_path(100000, 99999, &mut path)
            .is_err_and(|e| matches!(e, PathError::InvalidIndex(_))));

        link_grid(&mut pathfinder, size);

        let mut paths_count = 0;

        for _ in 0..1000 {
            let from = random_grid_index(size);
            let to = random_grid_index(size);

            assert!(pathfinder
                .build_positional_path(from, to, &mut path)
                .is_ok());
            assert_full_path(&pathfinder, from, to, &path);

            if path.len() > 1 {
                paths_count += 1;
            }
        }

        assert!(paths_count > 0);
    }

    #[test]
    fn test_remove_vertex() {
        let mut pathfinder = Graph::<GraphVertex>::new();

        pathfinder.add_vertex(GraphVertex::new(Vector3::new(0.0, 0.0, 0.0)));
        pathfinder.add_vertex(GraphVertex::new(Vector3::new(1.0, 0.0, 0.0)));
        pathfinder.add_vertex(GraphVertex::new(Vector3::new(1.0, 1.0, 0.0)));

        pathfinder.link_bidirect(0, 1);
        pathfinder.link_bidirect(1, 2);
        pathfinder.link_bidirect(2, 0);

        pathfinder.remove_vertex(0);

        assert_eq!(pathfinder.vertex(0).unwrap().neighbours, vec![1]);
        assert_eq!(pathfinder.vertex(1).unwrap().neighbours, vec![0]);
        assert_eq!(pathfinder.vertex(2), None);

        pathfinder.remove_vertex(0);

        assert_eq!(pathfinder.vertex(0).unwrap().neighbours, Vec::<u32>::new());
        assert_eq!(pathfinder.vertex(1), None);
        assert_eq!(pathfinder.vertex(2), None);
    }

    #[test]
    fn test_insert_vertex() {
        let mut pathfinder = Graph::new();

        pathfinder.add_vertex(GraphVertex::new(Vector3::new(0.0, 0.0, 0.0)));
        pathfinder.add_vertex(GraphVertex::new(Vector3::new(1.0, 0.0, 0.0)));
        pathfinder.add_vertex(GraphVertex::new(Vector3::new(1.0, 1.0, 0.0)));

        pathfinder.link_bidirect(0, 1);
        pathfinder.link_bidirect(1, 2);
        pathfinder.link_bidirect(2, 0);

        assert_eq!(pathfinder.vertex(0).unwrap().neighbours, vec![1, 2]);
        assert_eq!(pathfinder.vertex(1).unwrap().neighbours, vec![0, 2]);
        assert_eq!(pathfinder.vertex(2).unwrap().neighbours, vec![1, 0]);

        pathfinder.insert_vertex(0, GraphVertex::new(Vector3::new(1.0, 1.0, 1.0)));

        assert_eq!(pathfinder.vertex(0).unwrap().neighbours, Vec::<u32>::new());
        assert_eq!(pathfinder.vertex(1).unwrap().neighbours, vec![2, 3]);
        assert_eq!(pathfinder.vertex(2).unwrap().neighbours, vec![1, 3]);
        assert_eq!(pathfinder.vertex(3).unwrap().neighbours, vec![2, 1]);
    }

    #[ignore = "takes multiple seconds to run"]
    #[test]
    fn astar_complete_grid_benchmark() {
        let start_time = Instant::now();
        let mut path = Vec::new();

        println!();
        for size in [10, 40, 100, 500] {
            println!("benchmarking grid size of: {size}^2");
            let setup_start_time = Instant::now();

            let mut pathfinder = Graph::new();
            pathfinder.set_vertices(grid_vertices(size));
            link_grid(&mut pathfinder, size);

            let setup_complete_time = Instant::now();
            println!(
                "setup in: {:?}",
                setup_complete_time.duration_since(setup_start_time)
            );

            for _ in 0..1000 {
                let from = random_grid_index(size);
                let to = random_grid_index(size);

                assert!(pathfinder
                    .build_positional_path(from, to, &mut path)
                    .is_ok());
                assert_full_path(&pathfinder, from, to, &path);
            }
            println!("paths found in: {:?}", setup_complete_time.elapsed());
            println!(
                "Current size complete in: {:?}\n",
                setup_start_time.elapsed()
            );
        }
        println!("Total time: {:?}\n", start_time.elapsed());
    }

    #[ignore = "takes multiple seconds to run"]
    #[test]
    fn astar_island_benchmark() {
        let start_time = Instant::now();

        let size = 100;
        let mut path = Vec::new();
        let mut pathfinder = Graph::new();
        pathfinder.set_vertices(grid_vertices(size));

        // Same links as a full grid, except nothing crosses from column `size / 2 - 1` to
        // column `size / 2`. This splits the grid into two unconnected halves.
        for y in 0..(size - 1) {
            for x in 0..(size - 1) {
                if x != ((size / 2) - 1) {
                    pathfinder.link_bidirect(y * size + x, y * size + x + 1);
                }
                pathfinder.link_bidirect(y * size + x, (y + 1) * size + x);
            }
        }

        let setup_complete_time = Instant::now();

        println!(
            "\nsetup in: {:?}",
            setup_complete_time.duration_since(start_time)
        );

        for _ in 0..1000 {
            // Start in the left half, aim for the unreachable right half.
            let sx = rand::thread_rng().gen_range(0..((size / 2) - 1));
            let sy = rand::thread_rng().gen_range(0..(size - 1));

            let tx = rand::thread_rng().gen_range((size / 2)..(size - 1));
            let ty = rand::thread_rng().gen_range(0..(size - 1));

            let from = sy * size + sx;
            let to = ty * size + tx;

            let path_result = pathfinder.build_positional_path(from, to, &mut path);

            assert!(matches!(
                path_result,
                Ok(PathKind::Partial) | Err(PathError::HitMaxSearchIterations(_))
            ));
            assert!(!path.is_empty());

            if path.len() > 1 {
                // The most promising reachable vertex lies on the edge of the gap.
                assert_eq!(path.first().unwrap().x as i32, ((size / 2) - 1) as i32);
                assert_eq!(
                    *path.last().unwrap(),
                    pathfinder.vertex(from).unwrap().position
                );
            } else {
                let point = *path.first().unwrap();
                assert_eq!(point, pathfinder.vertex(to).unwrap().position);
                assert_eq!(point, pathfinder.vertex(from).unwrap().position);
            }

            assert_steps_are_adjacent(&path);
        }

        println!("paths found in: {:?}", setup_complete_time.elapsed());
        println!("Total time: {:?}\n", start_time.elapsed());
    }

    #[ignore = "takes multiple seconds to run"]
    #[test]
    fn astar_backwards_travel_benchmark() {
        let start_time = Instant::now();

        let size = 100;
        let mut path = Vec::new();
        let mut pathfinder = Graph::new();
        pathfinder.set_vertices(grid_vertices(size));

        // Leave the diagonal (except the first row) unlinked. Routes must then travel away
        // from the target before they can reach it.
        for y in 0..(size - 1) {
            for x in (0..(size - 1)).rev() {
                if y == 0 || x != y {
                    pathfinder.link_bidirect(y * size + x, y * size + x + 1);
                    pathfinder.link_bidirect(y * size + x, (y + 1) * size + x);
                }
            }
        }

        let setup_complete_time = Instant::now();

        println!(
            "\nsetup in: {:?}",
            setup_complete_time.duration_since(start_time)
        );

        for _ in 0..1000 {
            let from = (size / 2) * size + (size - 1);
            let to = (size - 1) * size + (size / 2);

            assert!(pathfinder
                .build_positional_path(from, to, &mut path)
                .is_ok());
            assert_full_path(&pathfinder, from, to, &path);
        }

        println!("paths found in: {:?}", setup_complete_time.elapsed());
        println!("Total time: {:?}\n", start_time.elapsed());
    }
}