1use std::cmp::Ordering;
62use std::collections::{BinaryHeap, HashMap, HashSet};
63use std::f64;
64use std::fmt::Debug;
65
66use scirs2_core::ndarray::{Array1, Array2};
67use scirs2_core::random::rngs::StdRng;
68use scirs2_core::random::Rng;
69use scirs2_core::random::SeedableRng;
70
71use crate::distance::EuclideanDistance;
72use crate::error::{SpatialError, SpatialResult};
73use crate::kdtree::KDTree;
74use crate::pathplanning::astar::{euclidean_distance, Path};
75
/// Boxed collision predicate: returns `true` when the given configuration is
/// IN COLLISION. `PRMPlanner::is_collision_free` negates its result.
type CollisionCheckFn = Box<dyn Fn(&Array1<f64>) -> bool>;
78
/// Tuning parameters for [`PRMPlanner`], assembled with the chainable
/// `with_*` builder methods starting from [`PRMConfig::new`].
#[derive(Debug, Clone)]
pub struct PRMConfig {
    /// Number of random configurations drawn when building the roadmap;
    /// only the collision-free ones become roadmap nodes.
    pub num_samples: usize,
    /// Maximum straight-line distance between two configurations that may
    /// be joined by an edge.
    pub connection_radius: f64,
    /// Maximum number of edges each roadmap node initiates (nearest first).
    pub max_connections: usize,
    /// Seed for the sampler; `None` lets the planner pick a random seed.
    pub seed: Option<u64>,
    /// Presumably the probability of sampling the goal directly; kept in
    /// `[0, 1]` by `with_goal_bias`.
    /// NOTE(review): not read by `PRMPlanner` in this module.
    pub goal_bias: f64,
    /// Goal tolerance. NOTE(review): not read by `PRMPlanner` in this module.
    pub goal_threshold: f64,
    /// Bidirectional-search flag.
    /// NOTE(review): not read by `PRMPlanner` in this module.
    pub bidirectional: bool,
    /// Lazy edge-evaluation flag.
    /// NOTE(review): not read by `PRMPlanner` in this module.
    pub lazy_evaluation: bool,
}

impl PRMConfig {
    /// Returns the baseline configuration.
    ///
    /// Defaults: 1000 samples, radius 1.0, at most 10 connections per node,
    /// no fixed seed, goal bias 0.05, goal threshold 0.1. Bidirectional
    /// search and lazy evaluation are off.
    pub fn new() -> Self {
        Self {
            num_samples: 1000,
            connection_radius: 1.0,
            max_connections: 10,
            seed: None,
            goal_bias: 0.05,
            goal_threshold: 0.1,
            bidirectional: false,
            lazy_evaluation: false,
        }
    }

    /// Replaces the number of samples drawn for the roadmap.
    pub fn with_num_samples(self, numsamples: usize) -> Self {
        Self {
            num_samples: numsamples,
            ..self
        }
    }

    /// Replaces the edge connection radius.
    pub fn with_connection_radius(self, radius: f64) -> Self {
        Self {
            connection_radius: radius,
            ..self
        }
    }

    /// Replaces the per-node connection cap.
    pub fn with_max_connections(self, maxconnections: usize) -> Self {
        Self {
            max_connections: maxconnections,
            ..self
        }
    }

    /// Fixes the sampler seed for reproducible roadmaps.
    pub fn with_seed(self, seed: u64) -> Self {
        Self {
            seed: Some(seed),
            ..self
        }
    }

    /// Replaces the goal bias, clamping it into `[0, 1]`.
    pub fn with_goal_bias(self, bias: f64) -> Self {
        let goal_bias = bias.clamp(0.0, 1.0);
        Self { goal_bias, ..self }
    }

    /// Replaces the goal threshold.
    pub fn with_goal_threshold(self, threshold: f64) -> Self {
        Self {
            goal_threshold: threshold,
            ..self
        }
    }

    /// Toggles the bidirectional flag.
    pub fn with_bidirectional(self, bidirectional: bool) -> Self {
        Self {
            bidirectional,
            ..self
        }
    }

    /// Toggles the lazy-evaluation flag.
    pub fn with_lazy_evaluation(self, lazyevaluation: bool) -> Self {
        Self {
            lazy_evaluation: lazyevaluation,
            ..self
        }
    }
}

impl Default for PRMConfig {
    /// Same as [`PRMConfig::new`].
    fn default() -> Self {
        Self::new()
    }
}
169
170#[derive(Debug, Clone)]
172struct PRMNode {
173 #[allow(dead_code)]
175 id: usize,
176 config: Array1<f64>,
178 neighbors: Vec<(usize, f64)>,
180}
181
182impl PRMNode {
183 fn new(id: usize, config: Array1<f64>) -> Self {
185 PRMNode {
186 id,
187 config,
188 neighbors: Vec::new(),
189 }
190 }
191
192 fn add_neighbor(&mut self, _neighborid: usize, cost: f64) {
194 if !self.neighbors.iter().any(|(id_, _)| *id_ == _neighborid) {
196 self.neighbors.push((_neighborid, cost));
197 }
198 }
199}
200
/// Entry in the A* open set.
///
/// The ordering is reversed on `f_cost`, so Rust's max-heap `BinaryHeap`
/// pops the entry with the *lowest* estimated total cost first.
/// Ties on `f_cost` are broken by node id (smaller id first). This makes the
/// order total and lets equality agree with it, as the `Ord` contract requires.
#[derive(Clone, Debug)]
struct SearchNode {
    /// Index of the roadmap node this entry refers to.
    id: usize,
    /// Cost from the start to this node at the time the entry was pushed.
    g_cost: f64,
    /// `g_cost` plus the heuristic estimate to the goal; the heap key.
    f_cost: f64,
    /// Predecessor at push time. Not read anywhere in this module.
    _parent: Option<usize>,
}

impl Ord for SearchNode {
    fn cmp(&self, other: &Self) -> Ordering {
        // `other` vs `self` (not the reverse) turns the max-heap into a
        // min-heap on f_cost. `total_cmp` gives NaN a fixed place in the order
        // instead of comparing "equal" to every other cost.
        other
            .f_cost
            .total_cmp(&self.f_cost)
            .then_with(|| other.id.cmp(&self.id))
    }
}

impl PartialOrd for SearchNode {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl PartialEq for SearchNode {
    /// Equal exactly when `cmp` says so (same f_cost bit pattern and same id).
    /// Comparing by id alone would contradict `Ord`.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for SearchNode {}
238
/// Probabilistic roadmap (PRM) planner over an axis-aligned box in
/// configuration space.
///
/// The roadmap is built once, either explicitly or lazily on the first
/// `find_path` call. It is then reused by later queries.
pub struct PRMPlanner {
    /// Planner settings.
    config: PRMConfig,
    /// `(lower, upper)` corners of the sampling box, one entry per dimension.
    bounds: (Array1<f64>, Array1<f64>),
    /// Number of configuration-space dimensions (length of the bound vectors).
    dimension: usize,
    /// Roadmap nodes; a node's id equals its index in this vector.
    nodes: Vec<PRMNode>,
    /// Spatial index over `nodes`, used for radius queries while connecting the roadmap.
    kdtree: Option<KDTree<f64, EuclideanDistance<f64>>>,
    /// Random source for sampling, seeded from `config.seed` when one is given.
    rng: StdRng,
    /// Optional predicate returning `true` for configurations in collision;
    /// `None` treats every configuration as free.
    collision_checker: Option<CollisionCheckFn>,
    /// Set once `build_roadmap` completes; later calls return immediately.
    roadmap_built: bool,
}
259
260impl Debug for PRMPlanner {
261 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262 f.debug_struct("PRMPlanner")
263 .field("config", &self.config)
264 .field("bounds", &self.bounds)
265 .field("dimension", &self.dimension)
266 .field("nodes", &self.nodes.len())
267 .field("kdtree", &self.kdtree)
268 .field("roadmap_built", &self.roadmap_built)
269 .field("collision_checker", &"<function>")
270 .finish()
271 }
272}
273
274impl PRMPlanner {
275 pub fn new(
277 config: PRMConfig,
278 lower_bounds: Array1<f64>,
279 upper_bounds: Array1<f64>,
280 ) -> SpatialResult<Self> {
281 let dimension = lower_bounds.len();
282
283 if lower_bounds.len() != upper_bounds.len() {
284 return Err(SpatialError::DimensionError(
285 "Lower and upper _bounds must have the same dimension".to_string(),
286 ));
287 }
288
289 let seed = config.seed.unwrap_or_else(scirs2_core::random::random);
291 let rng = StdRng::seed_from_u64(seed);
292
293 Ok(PRMPlanner {
294 config,
295 bounds: (lower_bounds, upper_bounds),
296 dimension,
297 nodes: Vec::new(),
298 kdtree: None,
299 rng,
300 collision_checker: None,
301 roadmap_built: false,
302 })
303 }
304
305 pub fn set_collision_checker<F>(&mut self, checker: Box<F>)
307 where
308 F: Fn(&Array1<f64>) -> bool + 'static,
309 {
310 self.collision_checker = Some(checker);
311 }
312
313 fn sample_config(&mut self) -> Array1<f64> {
315 let mut config = Array1::zeros(self.dimension);
316
317 for i in 0..self.dimension {
318 let lower = self.bounds.0[i];
319 let upper = self.bounds.1[i];
320 config[i] = self.rng.gen_range(lower..upper);
321 }
322
323 config
324 }
325
326 #[allow(dead_code)]
328 fn sample_near(&mut self, target: &Array1<f64>, radius: f64) -> Array1<f64> {
329 let mut config = Array1::zeros(self.dimension);
330
331 for i in 0..self.dimension {
332 let lower = (target[i] - radius).max(self.bounds.0[i]);
333 let upper = (target[i] + radius).min(self.bounds.1[i]);
334 config[i] = self.rng.gen_range(lower..upper);
335 }
336
337 config
338 }
339
340 fn is_collision_free(&self, config: &Array1<f64>) -> bool {
342 match &self.collision_checker {
343 Some(checker) => !checker(config),
344 None => true, }
346 }
347
348 fn is_path_collision_free(&self, from: &Array1<f64>, to: &Array1<f64>) -> bool {
350 const NUM_CHECKS: usize = 10;
353
354 for i in 0..=NUM_CHECKS {
355 let t = i as f64 / NUM_CHECKS as f64;
356
357 let mut point = Array1::zeros(self.dimension);
359 for j in 0..self.dimension {
360 point[j] = from[j] * (1.0 - t) + to[j] * t;
361 }
362
363 if !self.is_collision_free(&point) {
364 return false;
365 }
366 }
367
368 true
369 }
370
371 pub fn build_roadmap(&mut self) -> SpatialResult<()> {
373 if self.roadmap_built {
374 return Ok(());
375 }
376
377 self.nodes.clear();
379
380 let mut configs = Vec::new();
382 for _ in 0..self.config.num_samples {
383 let config = self.sample_config();
384
385 if self.is_collision_free(&config) {
386 configs.push(config);
387 }
388 }
389
390 for (i, config) in configs.iter().enumerate() {
392 self.nodes.push(PRMNode::new(i, config.clone()));
393 }
394
395 let mut points = Vec::new();
397 for node in &self.nodes {
398 points.push(node.config.clone());
399 }
400
401 let n_points = points.len();
403 let dim = if n_points > 0 { points[0].len() } else { 0 };
404 let mut points_array = Array2::<f64>::zeros((n_points, dim));
405 for (i, p) in points.iter().enumerate() {
406 points_array.row_mut(i).assign(&p.view());
407 }
408
409 self.kdtree = Some(KDTree::new(&points_array)?);
411
412 for i in 0..self.nodes.len() {
414 let node_config = self.nodes[i].config.clone();
415
416 let nearby = match &self.kdtree {
418 Some(kdtree) => {
419 let node_slice = node_config.as_slice().ok_or_else(|| {
421 SpatialError::ComputationError(
422 "Failed to convert node config to slice (non-contiguous memory layout)"
423 .into(),
424 )
425 })?;
426 kdtree.query_radius(node_slice, self.config.connection_radius)?
427 }
428 None => (Vec::new(), Vec::new()),
429 };
430
431 let mut connections = Vec::new();
433
434 let (indices, distances) = nearby;
435 for (idx, &j) in indices.iter().enumerate() {
436 let distance = distances[idx];
437 if i == j {
439 continue;
440 }
441
442 let from_config = &self.nodes[i].config;
443 let to_config = &self.nodes[j].config;
444
445 if self.is_path_collision_free(from_config, to_config) {
447 connections.push((j, distance));
448 }
449 }
450
451 connections.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
453 connections.truncate(self.config.max_connections);
454
455 for (j, distance) in connections {
457 self.nodes[i].add_neighbor(j, distance);
458 self.nodes[j].add_neighbor(i, distance); }
460 }
461
462 self.roadmap_built = true;
463 Ok(())
464 }
465
466 pub fn find_path(
468 &mut self,
469 start: &Array1<f64>,
470 goal: &Array1<f64>,
471 ) -> SpatialResult<Option<Path<Array1<f64>>>> {
472 if !self.roadmap_built {
474 self.build_roadmap()?;
475 }
476
477 if !self.is_collision_free(start) {
479 return Err(SpatialError::ValueError(
480 "Start configuration is in collision".to_string(),
481 ));
482 }
483
484 if !self.is_collision_free(goal) {
485 return Err(SpatialError::ValueError(
486 "Goal configuration is in collision".to_string(),
487 ));
488 }
489
490 let start_id = self.nodes.len();
492 let goalid = start_id + 1;
493
494 let mut start_node = PRMNode::new(start_id, start.clone());
495 let mut goal_node = PRMNode::new(goalid, goal.clone());
496
497 for i in 0..self.nodes.len() {
499 let node_config = self.nodes[i].config.clone();
500
501 let start_distance = euclidean_distance(&start.view(), &node_config.view())?;
503 if start_distance <= self.config.connection_radius
504 && self.is_path_collision_free(start, &node_config)
505 {
506 start_node.add_neighbor(i, start_distance);
507 self.nodes[i].add_neighbor(start_id, start_distance);
508 }
509
510 let goal_distance = euclidean_distance(&goal.view(), &node_config.view())?;
512 if goal_distance <= self.config.connection_radius
513 && self.is_path_collision_free(goal, &node_config)
514 {
515 goal_node.add_neighbor(i, goal_distance);
516 self.nodes[i].add_neighbor(goalid, goal_distance);
517 }
518 }
519
520 let start_goal_distance = euclidean_distance(&start.view(), &goal.view())?;
522 if start_goal_distance <= self.config.connection_radius
523 && self.is_path_collision_free(start, goal)
524 {
525 start_node.add_neighbor(goalid, start_goal_distance);
526 goal_node.add_neighbor(start_id, start_goal_distance);
527 }
528
529 self.nodes.push(start_node);
531 self.nodes.push(goal_node);
532
533 let path = self.astar_search(start_id, goalid);
535
536 self.nodes.pop(); self.nodes.pop(); for node in &mut self.nodes {
542 node.neighbors.retain(|(id_, _)| *id_ < start_id);
543 }
544
545 match path {
547 Some((node_path, cost)) => {
548 let mut configs = Vec::new();
549 for &id in &node_path {
550 if id == start_id {
551 configs.push(start.clone());
552 } else if id == goalid {
553 configs.push(goal.clone());
554 } else {
555 configs.push(self.nodes[id].config.clone());
556 }
557 }
558
559 Ok(Some(Path::new(configs, cost)))
560 }
561 None => Ok(None),
562 }
563 }
564
565 fn astar_search(&self, start_id: usize, goalid: usize) -> Option<(Vec<usize>, f64)> {
567 let mut open_set = BinaryHeap::new();
568 let mut closed_set = HashSet::new();
569 let mut came_from = HashMap::new();
570 let mut g_scores = HashMap::new();
571
572 g_scores.insert(start_id, 0.0);
574
575 let h_score = euclidean_distance(
577 &self.nodes[start_id].config.view(),
578 &self.nodes[goalid].config.view(),
579 )
580 .unwrap_or(f64::MAX);
581
582 open_set.push(SearchNode {
583 id: start_id,
584 g_cost: 0.0,
585 f_cost: h_score,
586 _parent: None,
587 });
588
589 while let Some(current) = open_set.pop() {
590 if current.id == goalid {
592 let mut path = Vec::new();
594 let mut current_id = current.id;
595
596 path.push(current_id);
597
598 while let Some(parent_id) = came_from.get(¤t_id) {
599 path.push(*parent_id);
600 current_id = *parent_id;
601 }
602
603 path.reverse();
604
605 return Some((path, current.g_cost));
606 }
607
608 if closed_set.contains(¤t.id) {
610 continue;
611 }
612
613 closed_set.insert(current.id);
615
616 for &(_neighborid, edge_cost) in &self.nodes[current.id].neighbors {
618 if closed_set.contains(&_neighborid) {
620 continue;
621 }
622
623 let tentative_g_score = g_scores[¤t.id] + edge_cost;
625
626 if !g_scores.contains_key(&_neighborid)
628 || tentative_g_score < g_scores[&_neighborid]
629 {
630 came_from.insert(_neighborid, current.id);
632 g_scores.insert(_neighborid, tentative_g_score);
633
634 let h_score = euclidean_distance(
636 &self.nodes[_neighborid].config.view(),
637 &self.nodes[goalid].config.view(),
638 )
639 .unwrap_or(f64::MAX);
640
641 let f_score = tentative_g_score + h_score;
642
643 open_set.push(SearchNode {
645 id: _neighborid,
646 g_cost: tentative_g_score,
647 f_cost: f_score,
648 _parent: Some(current.id),
649 });
650 }
651 }
652 }
653
654 None
656 }
657
658 pub fn create_2d_with_polygons(
660 config: PRMConfig,
661 obstacles: Vec<Vec<[f64; 2]>>,
662 x_range: (f64, f64),
663 y_range: (f64, f64),
664 ) -> Self {
665 let lower_bounds = Array1::from_vec(vec![x_range.0, y_range.0]);
666 let upper_bounds = Array1::from_vec(vec![x_range.1, y_range.1]);
667
668 let collision_checker = Box::new(move |p: &Array1<f64>| {
670 let point = [p[0], p[1]];
671
672 for obstacle in &obstacles {
674 if point_in_polygon(&point, obstacle) {
675 return true; }
677 }
678
679 false });
681
682 let mut planner = Self::new(config, lower_bounds, upper_bounds)
683 .expect("Lower and upper bounds should have same dimension (2)");
684 planner.set_collision_checker(collision_checker);
685
686 planner
687 }
688}
689
690#[derive(Debug)]
692pub struct PRM2DPlanner {
693 planner: PRMPlanner,
695 obstacles: Vec<Vec<[f64; 2]>>,
697}
698
699impl PRM2DPlanner {
700 pub fn new(
702 config: PRMConfig,
703 obstacles: Vec<Vec<[f64; 2]>>,
704 x_range: (f64, f64),
705 y_range: (f64, f64),
706 ) -> Self {
707 let planner =
708 PRMPlanner::create_2d_with_polygons(config, obstacles.clone(), x_range, y_range);
709
710 PRM2DPlanner { planner, obstacles }
711 }
712
713 pub fn build_roadmap(&mut self) -> SpatialResult<()> {
715 self.planner.build_roadmap()
716 }
717
718 pub fn find_path(
720 &mut self,
721 start: [f64; 2],
722 goal: [f64; 2],
723 ) -> SpatialResult<Option<Path<Array1<f64>>>> {
724 let start_array = Array1::from_vec(vec![start[0], start[1]]);
725 let goal_array = Array1::from_vec(vec![goal[0], goal[1]]);
726
727 for obstacle in &self.obstacles {
729 if point_in_polygon(&start, obstacle) {
730 return Err(SpatialError::ValueError(
731 "Start point is inside an obstacle".to_string(),
732 ));
733 }
734
735 if point_in_polygon(&goal, obstacle) {
736 return Err(SpatialError::ValueError(
737 "Goal point is inside an obstacle".to_string(),
738 ));
739 }
740 }
741
742 self.planner.find_path(&start_array, &goal_array)
743 }
744
745 pub fn obstacles(&self) -> &Vec<Vec<[f64; 2]>> {
747 &self.obstacles
748 }
749}
750
/// Even-odd (ray-casting) test: is `point` inside the closed `polygon`?
///
/// Casts a ray toward +x and counts how many edges it crosses; an odd count
/// means inside. The polygon is implicitly closed (last vertex joins the
/// first). Horizontal edges are skipped by the straddle test, so no division
/// by zero occurs. An empty polygon contains nothing. Points exactly on an
/// edge may be classified either way.
fn point_in_polygon(point: &[f64; 2], polygon: &[[f64; 2]]) -> bool {
    let [px, py] = *point;

    // Pair each vertex with its successor, wrapping around at the end.
    let edges = polygon.iter().zip(polygon.iter().cycle().skip(1));

    edges.fold(false, |inside, (&[xa, ya], &[xb, yb])| {
        let straddles = (ya > py) != (yb > py);
        if straddles && px < (xb - xa) * (py - ya) / (yb - ya) + xa {
            !inside
        } else {
            inside
        }
    })
}
772
/// Unit tests for polygon containment, configuration builders and end-to-end
/// planning with both `PRMPlanner` and `PRM2DPlanner`.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Even-odd containment on a convex square and a concave (V-notched) polygon.
    #[test]
    fn test_point_in_polygon() {
        let square = vec![[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0]];

        // Interior points.
        assert!(point_in_polygon(&[0.5, 0.5], &square));
        assert!(point_in_polygon(&[0.1, 0.1], &square));
        assert!(point_in_polygon(&[0.9, 0.9], &square));

        // Just outside each side.
        assert!(!point_in_polygon(&[-0.1, 0.5], &square));
        assert!(!point_in_polygon(&[0.5, -0.1], &square));
        assert!(!point_in_polygon(&[1.1, 0.5], &square));
        assert!(!point_in_polygon(&[0.5, 1.1], &square));

        // Concave polygon: a V-shaped notch cut into the bottom of a 2x2 square.
        let complex = vec![[0.0, 0.0], [1.0, 1.0], [2.0, 0.0], [2.0, 2.0], [0.0, 2.0]];

        // Above the notch, inside the polygon.
        assert!(point_in_polygon(&[1.0, 1.5], &complex));

        // Entirely to the right of the polygon.
        assert!(!point_in_polygon(&[3.0, 1.0], &complex));
    }

    /// Every builder method stores the value it is given.
    #[test]
    fn test_prm_config() {
        let config = PRMConfig::new()
            .with_num_samples(500)
            .with_connection_radius(0.8)
            .with_max_connections(5)
            .with_seed(42)
            .with_goal_bias(0.1)
            .with_goal_threshold(0.2);

        assert_eq!(config.num_samples, 500);
        assert_eq!(config.connection_radius, 0.8);
        assert_eq!(config.max_connections, 5);
        assert_eq!(config.seed, Some(42));
        assert_eq!(config.goal_bias, 0.1);
        assert_eq!(config.goal_threshold, 0.2);
    }

    /// Planning across an obstacle-free 10x10 box with a fixed seed.
    #[test]
    fn test_simple_path() {
        let config = PRMConfig::new()
            .with_num_samples(1000)
            .with_connection_radius(3.0)
            .with_seed(42);

        let lower_bounds = array![0.0, 0.0];
        let upper_bounds = array![10.0, 10.0];

        let mut planner =
            PRMPlanner::new(config, lower_bounds, upper_bounds).expect("Operation failed");

        planner.build_roadmap().expect("Operation failed");

        let start = array![1.0, 1.0];
        let goal = array![9.0, 9.0];

        // NOTE(review): this branch silently passes when no path (or an error)
        // comes back. The seed is fixed, so the outcome is deterministic and
        // the test could assert `Ok(Some(_))` instead.
        if let Ok(Some(path)) = planner.find_path(&start, &goal) {
            assert_eq!(path.nodes[0], start);

            let last = path.nodes.last().expect("Operation failed");
            let dx = last[0] - goal[0];
            let dy = last[1] - goal[1];
            let dist = (dx * dx + dy * dy).sqrt();

            assert!(dist < 3.0);

            // The straight-line distance is ~11.3; 20.0 allows for roadmap detours.
            assert!(path.cost < 20.0);
        } else {
            println!(
                "⚠️ No path found in PRM test - this is expected occasionally with random sampling"
            );
        }
    }

    /// The 2-D wrapper must route around a square obstacle lying on the
    /// straight start-goal line.
    #[test]
    fn test_2d_planner() {
        let obstacle = vec![[4.0, 4.0], [6.0, 4.0], [6.0, 6.0], [4.0, 6.0]];

        let config = PRMConfig::new()
            .with_num_samples(200)
            .with_connection_radius(2.0)
            .with_seed(42);

        let mut planner = PRM2DPlanner::new(config, vec![obstacle], (0.0, 10.0), (0.0, 10.0));

        planner.build_roadmap().expect("Operation failed");

        // The segment between these points crosses the obstacle, so no direct edge exists.
        let start = [1.0, 5.0];
        let goal = [9.0, 5.0];

        let path = planner.find_path(start, goal).expect("Operation failed");

        assert!(path.is_some());

        let path = path.expect("Operation failed");

        // At least one intermediate roadmap node is required.
        assert!(path.nodes.len() > 2);

        assert_relative_eq!(path.nodes[0][0], start[0], epsilon = 1e-5);
        assert_relative_eq!(path.nodes[0][1], start[1], epsilon = 1e-5);

        let last = path.nodes.last().expect("Operation failed");
        assert_relative_eq!(last[0], goal[0], epsilon = 1e-5);
        assert_relative_eq!(last[1], goal[1], epsilon = 1e-5);
    }
}