use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::f64;
use std::fmt::Debug;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::SeedableRng;
use scirs2_core::random::{Rng, RngExt};
use crate::distance::EuclideanDistance;
use crate::error::{SpatialError, SpatialResult};
use crate::kdtree::KDTree;
use crate::pathplanning::astar::{euclidean_distance, Path};
/// Boxed collision predicate: returns `true` when the given configuration is in collision.
type CollisionCheckFn = Box<dyn Fn(&Array1<f64>) -> bool + 'static>;
/// Tuning knobs for [`PRMPlanner`].
///
/// Start from [`PRMConfig::new`] (identical to `Default::default()`) and chain the
/// `with_*` builder methods to override individual settings.
#[derive(Debug, Clone)]
pub struct PRMConfig {
    /// How many random configurations are drawn while building the roadmap.
    pub num_samples: usize,
    /// Largest distance at which two configurations may be linked by an edge.
    pub connection_radius: f64,
    /// How many of its nearest collision-free neighbours each roadmap node links to
    /// while the roadmap is built.
    pub max_connections: usize,
    /// Seed for the planner's random number generator; `None` means no fixed seed.
    pub seed: Option<u64>,
    /// Goal-bias setting, kept in `[0, 1]` when set through [`PRMConfig::with_goal_bias`].
    /// Not currently read by `PRMPlanner` in this module.
    pub goal_bias: f64,
    /// Goal-threshold setting. Not currently read by `PRMPlanner` in this module.
    pub goal_threshold: f64,
    /// Bidirectional flag. Not currently read by `PRMPlanner` in this module.
    pub bidirectional: bool,
    /// Lazy-evaluation flag. Not currently read by `PRMPlanner` in this module.
    pub lazy_evaluation: bool,
}

impl PRMConfig {
    /// Returns the baseline configuration: 1000 samples, radius 1.0, 10 connections,
    /// no seed, goal bias 0.05, goal threshold 0.1, and both flags disabled.
    pub fn new() -> Self {
        Self {
            num_samples: 1000,
            connection_radius: 1.0,
            max_connections: 10,
            seed: None,
            goal_bias: 0.05,
            goal_threshold: 0.1,
            bidirectional: false,
            lazy_evaluation: false,
        }
    }

    /// Overrides [`PRMConfig::num_samples`].
    pub fn with_num_samples(self, numsamples: usize) -> Self {
        Self {
            num_samples: numsamples,
            ..self
        }
    }

    /// Overrides [`PRMConfig::connection_radius`].
    pub fn with_connection_radius(self, radius: f64) -> Self {
        Self {
            connection_radius: radius,
            ..self
        }
    }

    /// Overrides [`PRMConfig::max_connections`].
    pub fn with_max_connections(self, maxconnections: usize) -> Self {
        Self {
            max_connections: maxconnections,
            ..self
        }
    }

    /// Fixes the RNG seed so roadmap construction is reproducible.
    pub fn with_seed(self, seed: u64) -> Self {
        Self {
            seed: Some(seed),
            ..self
        }
    }

    /// Overrides [`PRMConfig::goal_bias`], clamping the value into `[0, 1]`.
    pub fn with_goal_bias(self, bias: f64) -> Self {
        let goal_bias = bias.clamp(0.0, 1.0);
        Self { goal_bias, ..self }
    }

    /// Overrides [`PRMConfig::goal_threshold`].
    pub fn with_goal_threshold(self, threshold: f64) -> Self {
        Self {
            goal_threshold: threshold,
            ..self
        }
    }

    /// Overrides [`PRMConfig::bidirectional`].
    pub fn with_bidirectional(self, bidirectional: bool) -> Self {
        Self {
            bidirectional,
            ..self
        }
    }

    /// Overrides [`PRMConfig::lazy_evaluation`].
    pub fn with_lazy_evaluation(self, lazyevaluation: bool) -> Self {
        Self {
            lazy_evaluation: lazyevaluation,
            ..self
        }
    }
}

impl Default for PRMConfig {
    /// Same as [`PRMConfig::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// A vertex of the probabilistic roadmap.
#[derive(Debug, Clone)]
struct PRMNode {
    /// Node index as assigned at creation. It is stored but never read (code uses the
    /// position in the planner's node list instead), hence the lint allowance.
    #[allow(dead_code)]
    id: usize,
    /// The configuration this vertex represents.
    config: Array1<f64>,
    /// Adjacent vertices as `(node index, edge cost)` pairs; each index appears at most once.
    neighbors: Vec<(usize, f64)>,
}

impl PRMNode {
    /// Creates an isolated vertex (no neighbours yet).
    fn new(id: usize, config: Array1<f64>) -> Self {
        Self {
            id,
            config,
            neighbors: Vec::new(),
        }
    }

    /// Records an edge to `neighbor_id` with the given `cost`.
    ///
    /// If an edge to that neighbour already exists, the call is a no-op; the existing
    /// cost is kept.
    fn add_neighbor(&mut self, neighbor_id: usize, cost: f64) {
        let already_linked = self.neighbors.iter().any(|&(other, _)| other == neighbor_id);
        if already_linked {
            return;
        }
        self.neighbors.push((neighbor_id, cost));
    }
}
/// An entry in the A* open list.
#[derive(Clone, Debug)]
struct SearchNode {
    /// Roadmap node index this entry refers to.
    id: usize,
    /// Cost from the start node along the path that produced this entry.
    g_cost: f64,
    /// `g_cost` plus the heuristic estimate to the goal; the heap is ordered on this.
    f_cost: f64,
    /// Predecessor node at push time (informational only).
    _parent: Option<usize>,
}

impl Ord for SearchNode {
    /// Orders entries by *reversed* `f_cost`, so that the max-heap `BinaryHeap` pops
    /// the entry with the smallest estimated total cost first.
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` is a genuine total order even when a cost is NaN, unlike
        // `partial_cmp(..).unwrap_or(Equal)`, which makes NaN "equal" to every value
        // and breaks transitivity (and therefore the heap invariant).
        other.f_cost.total_cmp(&self.f_cost)
    }
}

impl PartialOrd for SearchNode {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl PartialEq for SearchNode {
    /// Equality is defined through [`Ord::cmp`] so that `a == b` exactly when
    /// `a.cmp(&b) == Ordering::Equal`, as the `Ord` contract requires.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for SearchNode {}
/// Probabilistic Roadmap (PRM) motion planner over an axis-aligned box.
pub struct PRMPlanner {
    /// Tuning parameters.
    config: PRMConfig,
    /// Per-axis `(lower, upper)` sampling bounds.
    bounds: (Array1<f64>, Array1<f64>),
    /// Dimensionality of the configuration space (length of the bound vectors).
    dimension: usize,
    /// Roadmap vertices; a node's index in this vector is its id.
    nodes: Vec<PRMNode>,
    /// Spatial index over the roadmap vertices, used for radius queries.
    kdtree: Option<KDTree<f64, EuclideanDistance<f64>>>,
    /// Random number generator used for sampling.
    rng: StdRng,
    /// Returns `true` for configurations in collision; `None` treats everything as free.
    collision_checker: Option<CollisionCheckFn>,
    /// Set once the roadmap has been built; later builds are no-ops.
    roadmap_built: bool,
}

impl Debug for PRMPlanner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PRMPlanner")
            .field("config", &self.config)
            .field("bounds", &self.bounds)
            .field("dimension", &self.dimension)
            // Only the vertex count is shown; the full roadmap would be unreadable.
            .field("nodes", &self.nodes.len())
            .field("kdtree", &self.kdtree)
            .field("roadmap_built", &self.roadmap_built)
            // Closures are not `Debug`; report whether one is installed rather than
            // always claiming a function is present.
            .field(
                "collision_checker",
                &self.collision_checker.as_ref().map(|_| "<function>"),
            )
            // `rng` is intentionally omitted.
            .finish_non_exhaustive()
    }
}
impl PRMPlanner {
pub fn new(
config: PRMConfig,
lower_bounds: Array1<f64>,
upper_bounds: Array1<f64>,
) -> SpatialResult<Self> {
let dimension = lower_bounds.len();
if lower_bounds.len() != upper_bounds.len() {
return Err(SpatialError::DimensionError(
"Lower and upper _bounds must have the same dimension".to_string(),
));
}
let seed = config.seed.unwrap_or_else(scirs2_core::random::random);
let rng = StdRng::seed_from_u64(seed);
Ok(PRMPlanner {
config,
bounds: (lower_bounds, upper_bounds),
dimension,
nodes: Vec::new(),
kdtree: None,
rng,
collision_checker: None,
roadmap_built: false,
})
}
pub fn set_collision_checker<F>(&mut self, checker: Box<F>)
where
F: Fn(&Array1<f64>) -> bool + 'static,
{
self.collision_checker = Some(checker);
}
fn sample_config(&mut self) -> Array1<f64> {
let mut config = Array1::zeros(self.dimension);
for i in 0..self.dimension {
let lower = self.bounds.0[i];
let upper = self.bounds.1[i];
config[i] = self.rng.random_range(lower..upper);
}
config
}
#[allow(dead_code)]
fn sample_near(&mut self, target: &Array1<f64>, radius: f64) -> Array1<f64> {
let mut config = Array1::zeros(self.dimension);
for i in 0..self.dimension {
let lower = (target[i] - radius).max(self.bounds.0[i]);
let upper = (target[i] + radius).min(self.bounds.1[i]);
config[i] = self.rng.random_range(lower..upper);
}
config
}
fn is_collision_free(&self, config: &Array1<f64>) -> bool {
match &self.collision_checker {
Some(checker) => !checker(config),
None => true, }
}
fn is_path_collision_free(&self, from: &Array1<f64>, to: &Array1<f64>) -> bool {
const NUM_CHECKS: usize = 10;
for i in 0..=NUM_CHECKS {
let t = i as f64 / NUM_CHECKS as f64;
let mut point = Array1::zeros(self.dimension);
for j in 0..self.dimension {
point[j] = from[j] * (1.0 - t) + to[j] * t;
}
if !self.is_collision_free(&point) {
return false;
}
}
true
}
pub fn build_roadmap(&mut self) -> SpatialResult<()> {
if self.roadmap_built {
return Ok(());
}
self.nodes.clear();
let mut configs = Vec::new();
for _ in 0..self.config.num_samples {
let config = self.sample_config();
if self.is_collision_free(&config) {
configs.push(config);
}
}
for (i, config) in configs.iter().enumerate() {
self.nodes.push(PRMNode::new(i, config.clone()));
}
let mut points = Vec::new();
for node in &self.nodes {
points.push(node.config.clone());
}
let n_points = points.len();
let dim = if n_points > 0 { points[0].len() } else { 0 };
let mut points_array = Array2::<f64>::zeros((n_points, dim));
for (i, p) in points.iter().enumerate() {
points_array.row_mut(i).assign(&p.view());
}
self.kdtree = Some(KDTree::new(&points_array)?);
for i in 0..self.nodes.len() {
let node_config = self.nodes[i].config.clone();
let nearby = match &self.kdtree {
Some(kdtree) => {
let node_slice = node_config.as_slice().ok_or_else(|| {
SpatialError::ComputationError(
"Failed to convert node config to slice (non-contiguous memory layout)"
.into(),
)
})?;
kdtree.query_radius(node_slice, self.config.connection_radius)?
}
None => (Vec::new(), Vec::new()),
};
let mut connections = Vec::new();
let (indices, distances) = nearby;
for (idx, &j) in indices.iter().enumerate() {
let distance = distances[idx];
if i == j {
continue;
}
let from_config = &self.nodes[i].config;
let to_config = &self.nodes[j].config;
if self.is_path_collision_free(from_config, to_config) {
connections.push((j, distance));
}
}
connections.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
connections.truncate(self.config.max_connections);
for (j, distance) in connections {
self.nodes[i].add_neighbor(j, distance);
self.nodes[j].add_neighbor(i, distance); }
}
self.roadmap_built = true;
Ok(())
}
pub fn find_path(
&mut self,
start: &Array1<f64>,
goal: &Array1<f64>,
) -> SpatialResult<Option<Path<Array1<f64>>>> {
if !self.roadmap_built {
self.build_roadmap()?;
}
if !self.is_collision_free(start) {
return Err(SpatialError::ValueError(
"Start configuration is in collision".to_string(),
));
}
if !self.is_collision_free(goal) {
return Err(SpatialError::ValueError(
"Goal configuration is in collision".to_string(),
));
}
let start_id = self.nodes.len();
let goalid = start_id + 1;
let mut start_node = PRMNode::new(start_id, start.clone());
let mut goal_node = PRMNode::new(goalid, goal.clone());
for i in 0..self.nodes.len() {
let node_config = self.nodes[i].config.clone();
let start_distance = euclidean_distance(&start.view(), &node_config.view())?;
if start_distance <= self.config.connection_radius
&& self.is_path_collision_free(start, &node_config)
{
start_node.add_neighbor(i, start_distance);
self.nodes[i].add_neighbor(start_id, start_distance);
}
let goal_distance = euclidean_distance(&goal.view(), &node_config.view())?;
if goal_distance <= self.config.connection_radius
&& self.is_path_collision_free(goal, &node_config)
{
goal_node.add_neighbor(i, goal_distance);
self.nodes[i].add_neighbor(goalid, goal_distance);
}
}
let start_goal_distance = euclidean_distance(&start.view(), &goal.view())?;
if start_goal_distance <= self.config.connection_radius
&& self.is_path_collision_free(start, goal)
{
start_node.add_neighbor(goalid, start_goal_distance);
goal_node.add_neighbor(start_id, start_goal_distance);
}
self.nodes.push(start_node);
self.nodes.push(goal_node);
let path = self.astar_search(start_id, goalid);
self.nodes.pop(); self.nodes.pop();
for node in &mut self.nodes {
node.neighbors.retain(|(id_, _)| *id_ < start_id);
}
match path {
Some((node_path, cost)) => {
let mut configs = Vec::new();
for &id in &node_path {
if id == start_id {
configs.push(start.clone());
} else if id == goalid {
configs.push(goal.clone());
} else {
configs.push(self.nodes[id].config.clone());
}
}
Ok(Some(Path::new(configs, cost)))
}
None => Ok(None),
}
}
fn astar_search(&self, start_id: usize, goalid: usize) -> Option<(Vec<usize>, f64)> {
let mut open_set = BinaryHeap::new();
let mut closed_set = HashSet::new();
let mut came_from = HashMap::new();
let mut g_scores = HashMap::new();
g_scores.insert(start_id, 0.0);
let h_score = euclidean_distance(
&self.nodes[start_id].config.view(),
&self.nodes[goalid].config.view(),
)
.unwrap_or(f64::MAX);
open_set.push(SearchNode {
id: start_id,
g_cost: 0.0,
f_cost: h_score,
_parent: None,
});
while let Some(current) = open_set.pop() {
if current.id == goalid {
let mut path = Vec::new();
let mut current_id = current.id;
path.push(current_id);
while let Some(parent_id) = came_from.get(¤t_id) {
path.push(*parent_id);
current_id = *parent_id;
}
path.reverse();
return Some((path, current.g_cost));
}
if closed_set.contains(¤t.id) {
continue;
}
closed_set.insert(current.id);
for &(_neighborid, edge_cost) in &self.nodes[current.id].neighbors {
if closed_set.contains(&_neighborid) {
continue;
}
let tentative_g_score = g_scores[¤t.id] + edge_cost;
if !g_scores.contains_key(&_neighborid)
|| tentative_g_score < g_scores[&_neighborid]
{
came_from.insert(_neighborid, current.id);
g_scores.insert(_neighborid, tentative_g_score);
let h_score = euclidean_distance(
&self.nodes[_neighborid].config.view(),
&self.nodes[goalid].config.view(),
)
.unwrap_or(f64::MAX);
let f_score = tentative_g_score + h_score;
open_set.push(SearchNode {
id: _neighborid,
g_cost: tentative_g_score,
f_cost: f_score,
_parent: Some(current.id),
});
}
}
}
None
}
pub fn create_2d_with_polygons(
config: PRMConfig,
obstacles: Vec<Vec<[f64; 2]>>,
x_range: (f64, f64),
y_range: (f64, f64),
) -> Self {
let lower_bounds = Array1::from_vec(vec![x_range.0, y_range.0]);
let upper_bounds = Array1::from_vec(vec![x_range.1, y_range.1]);
let collision_checker = Box::new(move |p: &Array1<f64>| {
let point = [p[0], p[1]];
for obstacle in &obstacles {
if point_in_polygon(&point, obstacle) {
return true; }
}
false });
let mut planner = Self::new(config, lower_bounds, upper_bounds)
.expect("Lower and upper bounds should have same dimension (2)");
planner.set_collision_checker(collision_checker);
planner
}
}
/// Convenience wrapper around [`PRMPlanner`] for 2-D workspaces with polygonal obstacles.
#[derive(Debug)]
pub struct PRM2DPlanner {
    /// Underlying planner, configured with a point-in-polygon collision checker.
    planner: PRMPlanner,
    /// Obstacle polygons as vertex lists; kept so endpoint errors can name the endpoint.
    obstacles: Vec<Vec<[f64; 2]>>,
}

impl PRM2DPlanner {
    /// Creates a planner over `x_range × y_range` with the given obstacle polygons.
    pub fn new(
        config: PRMConfig,
        obstacles: Vec<Vec<[f64; 2]>>,
        x_range: (f64, f64),
        y_range: (f64, f64),
    ) -> Self {
        Self {
            planner: PRMPlanner::create_2d_with_polygons(
                config,
                obstacles.clone(),
                x_range,
                y_range,
            ),
            obstacles,
        }
    }

    /// Builds the underlying roadmap (see [`PRMPlanner::build_roadmap`]).
    pub fn build_roadmap(&mut self) -> SpatialResult<()> {
        self.planner.build_roadmap()
    }

    /// Plans a path between two 2-D points.
    ///
    /// # Errors
    ///
    /// Returns [`SpatialError::ValueError`] if an endpoint lies inside an obstacle.
    /// Obstacles are examined in order, checking `start` before `goal` for each one.
    /// Otherwise errors from [`PRMPlanner::find_path`] are propagated.
    pub fn find_path(
        &mut self,
        start: [f64; 2],
        goal: [f64; 2],
    ) -> SpatialResult<Option<Path<Array1<f64>>>> {
        let blocked = self.obstacles.iter().find_map(|polygon| {
            if point_in_polygon(&start, polygon) {
                Some("Start point is inside an obstacle")
            } else if point_in_polygon(&goal, polygon) {
                Some("Goal point is inside an obstacle")
            } else {
                None
            }
        });
        if let Some(message) = blocked {
            return Err(SpatialError::ValueError(message.to_string()));
        }

        let start_array = Array1::from_vec(start.to_vec());
        let goal_array = Array1::from_vec(goal.to_vec());
        self.planner.find_path(&start_array, &goal_array)
    }

    /// The obstacle polygons this planner was created with.
    pub fn obstacles(&self) -> &Vec<Vec<[f64; 2]>> {
        &self.obstacles
    }
}
/// Even-odd (ray-casting) point-in-polygon test.
///
/// Casts a ray from `point` towards +x and counts edge crossings; an odd count means
/// inside. `polygon` is an implicitly closed vertex list (the last vertex connects back
/// to the first). Polygons with fewer than three vertices contain no points. Points
/// exactly on an edge may be classified either way.
fn point_in_polygon(point: &[f64; 2], polygon: &[[f64; 2]]) -> bool {
    let [px, py] = *point;
    let count = polygon.len();
    (0..count).fold(false, |inside, k| {
        let [ax, ay] = polygon[k];
        let [bx, by] = polygon[(k + 1) % count];
        // The first conjunct rules out horizontal edges, so the division is never by zero.
        let crosses = (ay > py) != (by > py) && px < (bx - ax) * (py - ay) / (by - ay) + ax;
        inside ^ crosses
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_point_in_polygon() {
        let square = vec![[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0]];
        assert!(point_in_polygon(&[0.5, 0.5], &square));
        assert!(point_in_polygon(&[0.1, 0.1], &square));
        assert!(point_in_polygon(&[0.9, 0.9], &square));
        assert!(!point_in_polygon(&[-0.1, 0.5], &square));
        assert!(!point_in_polygon(&[0.5, -0.1], &square));
        assert!(!point_in_polygon(&[1.1, 0.5], &square));
        assert!(!point_in_polygon(&[0.5, 1.1], &square));
        let complex = vec![[0.0, 0.0], [1.0, 1.0], [2.0, 0.0], [2.0, 2.0], [0.0, 2.0]];
        assert!(point_in_polygon(&[1.0, 1.5], &complex));
        assert!(!point_in_polygon(&[3.0, 1.0], &complex));
        // Degenerate polygons (fewer than three vertices) enclose nothing.
        assert!(!point_in_polygon(&[0.5, 0.5], &[]));
        assert!(!point_in_polygon(&[0.25, 0.75], &[[0.0, 0.0], [1.0, 1.0]]));
    }

    #[test]
    fn test_prm_config() {
        let config = PRMConfig::new()
            .with_num_samples(500)
            .with_connection_radius(0.8)
            .with_max_connections(5)
            .with_seed(42)
            .with_goal_bias(0.1)
            .with_goal_threshold(0.2);
        assert_eq!(config.num_samples, 500);
        assert_eq!(config.connection_radius, 0.8);
        assert_eq!(config.max_connections, 5);
        assert_eq!(config.seed, Some(42));
        assert_eq!(config.goal_bias, 0.1);
        assert_eq!(config.goal_threshold, 0.2);
        // Goal bias is clamped into [0, 1].
        assert_eq!(PRMConfig::new().with_goal_bias(1.5).goal_bias, 1.0);
        assert_eq!(PRMConfig::new().with_goal_bias(-0.5).goal_bias, 0.0);
    }

    #[test]
    fn test_new_rejects_mismatched_bounds() {
        let result = PRMPlanner::new(PRMConfig::new(), array![0.0, 0.0], array![1.0, 1.0, 1.0]);
        assert!(matches!(result, Err(SpatialError::DimensionError(_))));
    }

    #[test]
    fn test_simple_path() {
        let config = PRMConfig::new()
            .with_num_samples(1000)
            .with_connection_radius(3.0)
            .with_seed(42);
        let lower_bounds = array![0.0, 0.0];
        let upper_bounds = array![10.0, 10.0];
        let mut planner =
            PRMPlanner::new(config, lower_bounds, upper_bounds).expect("Operation failed");
        planner.build_roadmap().expect("Operation failed");
        let start = array![1.0, 1.0];
        let goal = array![9.0, 9.0];
        // Without obstacles `find_path` has no reason to fail; an error must not be
        // mistaken for "no path found".
        let result = planner
            .find_path(&start, &goal)
            .expect("find_path should not error in obstacle-free space");
        match result {
            Some(path) => {
                // The planner returns the exact endpoints, not roadmap approximations.
                assert_eq!(path.nodes.first(), Some(&start));
                assert_eq!(path.nodes.last(), Some(&goal));
                assert!(path.cost < 20.0);
            }
            None => println!(
                "⚠️ No path found in PRM test - this is expected occasionally with random sampling"
            ),
        }
    }

    #[test]
    fn test_2d_planner() {
        let obstacle = vec![[4.0, 4.0], [6.0, 4.0], [6.0, 6.0], [4.0, 6.0]];
        let config = PRMConfig::new()
            .with_num_samples(200)
            .with_connection_radius(2.0)
            .with_seed(42);
        let mut planner = PRM2DPlanner::new(config, vec![obstacle], (0.0, 10.0), (0.0, 10.0));
        planner.build_roadmap().expect("Operation failed");
        let start = [1.0, 5.0];
        let goal = [9.0, 5.0];
        let path = planner.find_path(start, goal).expect("Operation failed");
        assert!(path.is_some());
        let path = path.expect("Operation failed");
        assert!(path.nodes.len() > 2);
        assert_relative_eq!(path.nodes[0][0], start[0], epsilon = 1e-5);
        assert_relative_eq!(path.nodes[0][1], start[1], epsilon = 1e-5);
        let last = path.nodes.last().expect("Operation failed");
        assert_relative_eq!(last[0], goal[0], epsilon = 1e-5);
        assert_relative_eq!(last[1], goal[1], epsilon = 1e-5);
    }

    #[test]
    fn test_2d_planner_rejects_endpoint_in_obstacle() {
        let obstacle = vec![[4.0, 4.0], [6.0, 4.0], [6.0, 6.0], [4.0, 6.0]];
        let mut planner = PRM2DPlanner::new(
            PRMConfig::new().with_seed(7),
            vec![obstacle],
            (0.0, 10.0),
            (0.0, 10.0),
        );
        let start_blocked = planner.find_path([5.0, 5.0], [9.0, 9.0]);
        assert!(matches!(start_blocked, Err(SpatialError::ValueError(_))));
        let goal_blocked = planner.find_path([1.0, 1.0], [5.0, 5.0]);
        assert!(matches!(goal_blocked, Err(SpatialError::ValueError(_))));
    }
}