scirs2_spatial/pathplanning/
rrt.rs

1//! Rapidly-exploring Random Tree (RRT) implementation
2//!
3//! The RRT algorithm incrementally builds a space-filling tree that
4//! efficiently explores high-dimensional spaces. It is particularly useful
5//! for motion planning in robotics and other applications with complex
6//! configuration spaces.
7//!
8//! This implementation includes:
9//! - Basic RRT for pathfinding
10//! - RRT* for optimal pathfinding
11//! - RRT-Connect for bi-directional search
12//!
13//! The algorithm works by:
14//! 1. Sampling random points in the space
15//! 2. Finding the nearest node in the tree to the sampled point
16//! 3. Extending the tree toward the sampled point
17//! 4. Repeating until the goal is reached or max iterations are exceeded
18
19use scirs2_core::ndarray::{Array1, ArrayView1};
20use scirs2_core::random::rngs::StdRng;
21use scirs2_core::random::{Rng, SeedableRng};
22use std::fmt::Debug;
23
24use crate::distance::EuclideanDistance;
25use crate::error::{SpatialError, SpatialResult};
26use crate::kdtree::KDTree;
27use crate::pathplanning::astar::Path;
28// use crate::safe_conversions::*;
29
30/// Type alias for the collision checking function
31type CollisionCheckFn = Box<dyn Fn(&Array1<f64>, &Array1<f64>) -> bool>;
32
/// Tunable parameters shared by every RRT variant in this module
#[derive(Debug, Clone)]
pub struct RRTConfig {
    /// Upper limit on the number of sampling iterations per planning call
    pub max_iterations: usize,
    /// Longest distance the tree may grow in a single extension
    pub step_size: f64,
    /// Probability of sampling the goal itself instead of a random point
    pub goal_bias: f64,
    /// Seed for the planner's random number generator; `None` picks a random seed
    pub seed: Option<u64>,
    /// Select the RRT* variant (ignored when `bidirectional` is set)
    pub use_rrt_star: bool,
    /// Radius used by RRT* for parent selection and rewiring;
    /// `None` falls back to twice `step_size`
    pub neighborhood_radius: Option<f64>,
    /// Select the bi-directional variant; takes precedence over `use_rrt_star`
    pub bidirectional: bool,
}

impl Default for RRTConfig {
    /// Plain single-tree RRT with 10 000 iterations, a step of 0.5 and a 5% goal bias.
    fn default() -> Self {
        Self {
            // Search budget and growth parameters
            max_iterations: 10000,
            step_size: 0.5,
            goal_bias: 0.05,
            // Non-reproducible by default
            seed: None,
            // Both optional variants are switched off
            use_rrt_star: false,
            neighborhood_radius: None,
            bidirectional: false,
        }
    }
}
65
66/// Tree node for RRT algorithm
67#[derive(Clone, Debug)]
68struct RRTNode {
69    /// Position in the configuration space
70    position: Array1<f64>,
71    /// Index of parent node
72    parent: Option<usize>,
73    /// Cost from the start (used in RRT*)
74    cost: f64,
75}
76
77/// Rapidly-exploring Random Tree (RRT) planner
78pub struct RRTPlanner {
79    /// Configuration options
80    config: RRTConfig,
81    /// Collision checking function
82    collision_checker: Option<CollisionCheckFn>,
83    /// Random number generator
84    rng: StdRng,
85    /// Dimension of the configuration space
86    dimension: usize,
87    /// Bounds of the configuration space (min, max)
88    bounds: Option<(Array1<f64>, Array1<f64>)>,
89}
90
91impl Debug for RRTPlanner {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        f.debug_struct("RRTPlanner")
94            .field("config", &self.config)
95            .field("dimension", &self.dimension)
96            .field("bounds", &self.bounds)
97            .field("collision_checker", &"<function>")
98            .finish()
99    }
100}
101
102impl Clone for RRTPlanner {
103    fn clone(&self) -> Self {
104        Self {
105            config: self.config.clone(),
106            collision_checker: None, // We can't clone the collision checker function
107            rng: StdRng::seed_from_u64(scirs2_core::random::random()), // Create a new random number generator
108            dimension: self.dimension,
109            bounds: self.bounds.clone(),
110        }
111    }
112}
113
114impl RRTPlanner {
115    /// Create a new RRT planner with the given configuration
116    pub fn new(config: RRTConfig, dimension: usize) -> Self {
117        let seed = config.seed.unwrap_or_else(scirs2_core::random::random);
118        let rng = StdRng::seed_from_u64(seed);
119
120        RRTPlanner {
121            config,
122            collision_checker: None,
123            rng,
124            dimension,
125            bounds: None,
126        }
127    }
128
129    /// Set the collision checking function
130    ///
131    /// The function should return true if there is a collision between the two points
132    pub fn with_collision_checker<F>(mut self, collisionchecker: F) -> Self
133    where
134        F: Fn(&Array1<f64>, &Array1<f64>) -> bool + 'static,
135    {
136        self.collision_checker = Some(Box::new(collisionchecker));
137        self
138    }
139
140    /// Set the bounds of the configuration space
141    pub fn with_bounds(
142        mut self,
143        min_bounds: Array1<f64>,
144        max_bounds: Array1<f64>,
145    ) -> SpatialResult<Self> {
146        if min_bounds.len() != self.dimension || max_bounds.len() != self.dimension {
147            return Err(SpatialError::DimensionError(format!(
148                "Bounds dimensions ({}, {}) don't match planner dimension ({})",
149                min_bounds.len(),
150                max_bounds.len(),
151                self.dimension
152            )));
153        }
154
155        // Ensure min_bounds are actually less than max_bounds
156        for i in 0..self.dimension {
157            if min_bounds[i] >= max_bounds[i] {
158                return Err(SpatialError::ValueError(format!(
159                    "Min bound {} is not less than max bound {} at index {}",
160                    min_bounds[i], max_bounds[i], i
161                )));
162            }
163        }
164
165        self.bounds = Some((min_bounds, max_bounds));
166        Ok(self)
167    }
168
169    /// Sample a random point in the configuration space
170    fn sample_random_point(&mut self) -> SpatialResult<Array1<f64>> {
171        let (min_bounds, max_bounds) = self.bounds.as_ref().ok_or_else(|| {
172            SpatialError::ValueError("Bounds must be set before sampling".to_string())
173        })?;
174        let mut point = Array1::zeros(self.dimension);
175
176        for i in 0..self.dimension {
177            point[i] = self.rng.gen_range(min_bounds[i]..max_bounds[i]);
178        }
179
180        Ok(point)
181    }
182
183    /// Sample a random point with goal bias
184    fn sample_with_goal_bias(&mut self, goal: &ArrayView1<f64>) -> SpatialResult<Array1<f64>> {
185        if self.rng.gen_range(0.0..1.0) < self.config.goal_bias {
186            Ok(goal.to_owned())
187        } else {
188            self.sample_random_point()
189        }
190    }
191
192    /// Find the nearest node in the tree to the given point
193    fn find_nearest_node(
194        &self,
195        point: &ArrayView1<f64>,
196        _nodes: &[RRTNode],
197        kdtree: &KDTree<f64, EuclideanDistance<f64>>,
198    ) -> SpatialResult<usize> {
199        let point_vec = point.to_vec();
200        let (indices, _) = kdtree.query(point_vec.as_slice(), 1)?;
201        Ok(indices[0])
202    }
203
204    /// Compute a new point that is step_size distance from nearest toward randompoint
205    fn steer(&self, nearest: &ArrayView1<f64>, randompoint: &ArrayView1<f64>) -> Array1<f64> {
206        let mut direction = randompoint - nearest;
207        let norm = direction.iter().map(|&x| x * x).sum::<f64>().sqrt();
208
209        if norm < 1e-10 {
210            return nearest.to_owned();
211        }
212
213        // Scale to step_size
214        if norm > self.config.step_size {
215            direction *= self.config.step_size / norm;
216        }
217
218        nearest + direction
219    }
220
221    /// Check if there is a valid path between two points
222    fn is_valid_connection(&self, from: &ArrayView1<f64>, to: &ArrayView1<f64>) -> bool {
223        if let Some(ref collision_checker) = self.collision_checker {
224            !collision_checker(&from.to_owned(), &to.to_owned())
225        } else {
226            true // No collision checker provided, assume valid
227        }
228    }
229
230    /// Find the path from start to goal using RRT
231    pub fn find_path(
232        &mut self,
233        start: ArrayView1<f64>,
234        goal: ArrayView1<f64>,
235        goal_threshold: f64,
236    ) -> SpatialResult<Option<Path<Array1<f64>>>> {
237        if start.len() != self.dimension || goal.len() != self.dimension {
238            return Err(SpatialError::DimensionError(format!(
239                "Start or goal dimensions ({}, {}) don't match planner dimension ({})",
240                start.len(),
241                goal.len(),
242                self.dimension
243            )));
244        }
245
246        if self.bounds.is_none() {
247            return Err(SpatialError::ValueError(
248                "Bounds must be set before planning".to_string(),
249            ));
250        }
251
252        if self.config.bidirectional {
253            self.find_path_bidirectional(start, goal, goal_threshold)
254        } else if self.config.use_rrt_star {
255            self.find_path_rrt_star(start, goal, goal_threshold)
256        } else {
257            self.find_path_basic_rrt(start, goal, goal_threshold)
258        }
259    }
260
261    /// Find the path using basic RRT algorithm
262    fn find_path_basic_rrt(
263        &mut self,
264        start: ArrayView1<f64>,
265        goal: ArrayView1<f64>,
266        goal_threshold: f64,
267    ) -> SpatialResult<Option<Path<Array1<f64>>>> {
268        // Initialize tree with start node
269        let mut nodes = vec![RRTNode {
270            position: start.to_owned(),
271            parent: None,
272            cost: 0.0,
273        }];
274
275        for _ in 0..self.config.max_iterations {
276            // Sample random point with goal bias
277            let randompoint = self.sample_with_goal_bias(&goal)?;
278
279            // Build KDTree for nearest neighbor search
280            let points: Vec<_> = nodes.iter().map(|node| node.position.clone()).collect();
281            let points_array = scirs2_core::ndarray::stack(
282                scirs2_core::ndarray::Axis(0),
283                &points.iter().map(|p| p.view()).collect::<Vec<_>>(),
284            )
285            .expect("Failed to stack points");
286            let kdtree = KDTree::<f64, EuclideanDistance<f64>>::new(&points_array)
287                .expect("Failed to build KDTree");
288
289            // Find nearest node
290            let nearest_idx = self.find_nearest_node(&randompoint.view(), &nodes, &kdtree)?;
291
292            // Create temporary copies to avoid borrowing conflicts
293            let nearest_position = nodes[nearest_idx].position.clone();
294            let nearest_cost = nodes[nearest_idx].cost;
295
296            // Steer toward random point
297            let new_point = self.steer(&nearest_position.view(), &randompoint.view());
298
299            // Check if the connection is valid
300            if self.is_valid_connection(&nearest_position.view(), &new_point.view()) {
301                // Add new node to the tree
302                let new_node = RRTNode {
303                    position: new_point.clone(),
304                    parent: Some(nearest_idx),
305                    cost: nearest_cost
306                        + euclidean_distance(&nearest_position.view(), &new_point.view()),
307                };
308                nodes.push(new_node);
309
310                // Check if we've reached the goal
311                if euclidean_distance(&new_point.view(), &goal) <= goal_threshold {
312                    // Extract the path
313                    return Ok(Some(RRTPlanner::extract_path(&nodes, nodes.len() - 1)));
314                }
315            }
316        }
317
318        // Failed to find a path within max_iterations
319        Ok(None)
320    }
321
322    /// Find the path using RRT* algorithm (optimized version of RRT)
323    fn find_path_rrt_star(
324        &mut self,
325        start: ArrayView1<f64>,
326        goal: ArrayView1<f64>,
327        goal_threshold: f64,
328    ) -> SpatialResult<Option<Path<Array1<f64>>>> {
329        // Initialize tree with start node
330        let mut nodes = vec![RRTNode {
331            position: start.to_owned(),
332            parent: None,
333            cost: 0.0,
334        }];
335
336        // Goal node index, if found
337        let mut goalidx: Option<usize> = None;
338        let neighborhood_radius = self
339            .config
340            .neighborhood_radius
341            .unwrap_or(self.config.step_size * 2.0);
342
343        for _ in 0..self.config.max_iterations {
344            // Sample random point with goal bias
345            let randompoint = self.sample_with_goal_bias(&goal)?;
346
347            // Build KDTree for nearest neighbor search
348            let points: Vec<_> = nodes.iter().map(|node| node.position.clone()).collect();
349            let points_array = scirs2_core::ndarray::stack(
350                scirs2_core::ndarray::Axis(0),
351                &points.iter().map(|p| p.view()).collect::<Vec<_>>(),
352            )
353            .expect("Failed to stack points");
354            let kdtree = KDTree::<f64, EuclideanDistance<f64>>::new(&points_array)
355                .expect("Failed to build KDTree");
356
357            // Find nearest node
358            let nearest_idx = self.find_nearest_node(&randompoint.view(), &nodes, &kdtree)?;
359
360            // Create temporary copies to avoid borrowing conflicts
361            let nearest_position = nodes[nearest_idx].position.clone();
362
363            // Steer toward random point
364            let new_point = self.steer(&nearest_position.view(), &randompoint.view());
365
366            // Check if the connection is valid
367            if self.is_valid_connection(&nearest_position.view(), &new_point.view()) {
368                // Find the best parent for the new node
369                let (parent_idx, cost_from_parent) = self.find_best_parent(
370                    &new_point,
371                    &nodes,
372                    &kdtree,
373                    nearest_idx,
374                    neighborhood_radius,
375                );
376
377                // Add new node to the tree
378                let new_node_idx = nodes.len();
379                let parent_cost = nodes[parent_idx].cost;
380                let new_node = RRTNode {
381                    position: new_point.clone(),
382                    parent: Some(parent_idx),
383                    cost: parent_cost + cost_from_parent,
384                };
385                nodes.push(new_node);
386
387                // Rewire the tree (RRT* optimization)
388                self.rewire_tree(&mut nodes, new_node_idx, &kdtree, neighborhood_radius);
389
390                // Check if we've reached the goal
391                let dist_to_goal = euclidean_distance(&new_point.view(), &goal);
392                if dist_to_goal <= goal_threshold {
393                    // Update goal index if we found a better path
394                    let new_cost = nodes[new_node_idx].cost + dist_to_goal;
395                    if let Some(idx) = goalidx {
396                        if new_cost < nodes[idx].cost {
397                            goalidx = Some(new_node_idx);
398                        }
399                    } else {
400                        goalidx = Some(new_node_idx);
401                    }
402                }
403            }
404        }
405
406        // Extract the path if goal was reached
407        if let Some(idx) = goalidx {
408            Ok(Some(RRTPlanner::extract_path(&nodes, idx)))
409        } else {
410            Ok(None)
411        }
412    }
413
414    /// Find the best parent for a new node
415    fn find_best_parent(
416        &self,
417        new_point: &Array1<f64>,
418        nodes: &[RRTNode],
419        kdtree: &KDTree<f64, EuclideanDistance<f64>>,
420        nearest_idx: usize,
421        radius: f64,
422    ) -> (usize, f64) {
423        let mut best_parent_idx = nearest_idx;
424        let mut best_cost = nodes[nearest_idx].cost
425            + euclidean_distance(&nodes[nearest_idx].position.view(), &new_point.view());
426
427        // Find all nodes within the neighborhood
428        let (near_indices, near_distances) = kdtree
429            .query_radius(new_point.as_slice().expect("Operation failed"), radius)
430            .expect("KDTree query failed");
431
432        // Check each nearby node as a potential parent
433        for (_idx, &nodeidx) in near_indices.iter().enumerate() {
434            let node = &nodes[nodeidx];
435            let dist = near_distances[_idx];
436
437            // Calculate the cost if we came through this node
438            let cost_from_start = node.cost + dist;
439
440            // Update best parent if this path is cheaper
441            if cost_from_start < best_cost
442                && self.is_valid_connection(&node.position.view(), &new_point.view())
443            {
444                best_parent_idx = nodeidx;
445                best_cost = cost_from_start;
446            }
447        }
448
449        // Return the best parent and the cost from that parent
450        let cost_from_parent =
451            euclidean_distance(&nodes[best_parent_idx].position.view(), &new_point.view());
452        (best_parent_idx, cost_from_parent)
453    }
454
455    /// Rewire the tree to optimize paths (RRT* step)
456    fn rewire_tree(
457        &self,
458        nodes: &mut [RRTNode],
459        new_node_idx: usize,
460        kdtree: &KDTree<f64, EuclideanDistance<f64>>,
461        radius: f64,
462    ) {
463        // Create temporary copies to avoid borrowing conflicts
464        let new_point = nodes[new_node_idx].position.clone();
465        let new_cost = nodes[new_node_idx].cost;
466
467        // Find all nodes within the neighborhood
468        let (near_indices, near_distances) = kdtree
469            .query_radius(new_point.as_slice().expect("Operation failed"), radius)
470            .expect("KDTree query failed");
471
472        // Check if we can improve the path to any nearby node by going through the new node
473        for (_idx, &nodeidx) in near_indices.iter().enumerate() {
474            // Skip the node itself
475            if nodeidx == new_node_idx {
476                continue;
477            }
478
479            let dist = near_distances[_idx];
480            let cost_through_new = new_cost + dist;
481
482            // Create temporary copy of the position to avoid borrowing conflicts
483            let node_position = nodes[nodeidx].position.clone();
484
485            // If the path through the new node is better, rewire
486            if cost_through_new < nodes[nodeidx].cost
487                && self.is_valid_connection(&new_point.view(), &node_position.view())
488            {
489                nodes[nodeidx].parent = Some(new_node_idx);
490                nodes[nodeidx].cost = cost_through_new;
491
492                // Recursively update costs in the subtree
493                RRTPlanner::update_subtree_costs(nodes, nodeidx);
494            }
495        }
496    }
497
498    /// Update costs in a subtree after rewiring
499    #[allow(clippy::only_used_in_recursion)]
500    fn update_subtree_costs(nodes: &mut [RRTNode], nodeidx: usize) {
501        // Find all children of this node
502        let children: Vec<usize> = nodes
503            .iter()
504            .enumerate()
505            .filter(|(_, node)| node.parent == Some(nodeidx))
506            .map(|(idx, _)| idx)
507            .collect();
508
509        // Update each child's cost and recursively update its subtree
510        for &child_idx in &children {
511            // Create temporary copies to avoid borrowing conflicts
512            let parent_cost = nodes[nodeidx].cost;
513            let parent_position = nodes[nodeidx].position.clone();
514            let child_position = nodes[child_idx].position.clone();
515
516            let edge_cost = euclidean_distance(&parent_position.view(), &child_position.view());
517            nodes[child_idx].cost = parent_cost + edge_cost;
518
519            // Recursively update this child's subtree
520            Self::update_subtree_costs(nodes, child_idx);
521        }
522    }
523
524    /// Find the path using bi-directional RRT (RRT-Connect)
525    fn find_path_bidirectional(
526        &mut self,
527        start: ArrayView1<f64>,
528        goal: ArrayView1<f64>,
529        goal_threshold: f64,
530    ) -> SpatialResult<Option<Path<Array1<f64>>>> {
531        // Initialize two trees (one from start, one from goal)
532        let mut start_tree = vec![RRTNode {
533            position: start.to_owned(),
534            parent: None,
535            cost: 0.0,
536        }];
537
538        let mut goal_tree = vec![RRTNode {
539            position: goal.to_owned(),
540            parent: None,
541            cost: 0.0,
542        }];
543
544        // Tree A starts as the start tree, Tree B as the goal tree
545        let mut tree_a = &mut start_tree;
546        let mut tree_b = &mut goal_tree;
547        let mut a_is_start = true;
548
549        // Indices of connecting nodes between trees, if found
550        let mut connection: Option<(usize, usize)> = None;
551
552        for _ in 0..self.config.max_iterations {
553            // Swap trees every iteration
554            std::mem::swap(&mut tree_a, &mut tree_b);
555            a_is_start = !a_is_start;
556
557            // Sample random point (from tree A's perspective)
558            let target = if a_is_start {
559                goal.to_owned()
560            } else {
561                start.to_owned()
562            };
563            let randompoint = self.sample_with_goal_bias(&target.view())?;
564
565            // Build KDTree for tree A
566            let points_a: Vec<_> = tree_a.iter().map(|node| node.position.clone()).collect();
567            let points_array_a = scirs2_core::ndarray::stack(
568                scirs2_core::ndarray::Axis(0),
569                &points_a.iter().map(|p| p.view()).collect::<Vec<_>>(),
570            )
571            .expect("Failed to stack points");
572            let kdtree_a = KDTree::<f64, EuclideanDistance<f64>>::new(&points_array_a)
573                .expect("Failed to build KDTree");
574
575            // Find nearest node in tree A
576            let nearest_idxa = self.find_nearest_node(&randompoint.view(), tree_a, &kdtree_a)?;
577
578            // Create temporary copy to avoid borrowing conflicts
579            let nearest_position = tree_a[nearest_idxa].position.clone();
580            let nearest_cost = tree_a[nearest_idxa].cost;
581
582            // Steer from nearest in A toward random point
583            let new_point = self.steer(&nearest_position.view(), &randompoint.view());
584
585            // Check if the connection is valid
586            if self.is_valid_connection(&nearest_position.view(), &new_point.view()) {
587                // Add new node to tree A
588                let new_cost =
589                    nearest_cost + euclidean_distance(&nearest_position.view(), &new_point.view());
590                let new_node_idx_a = tree_a.len();
591                tree_a.push(RRTNode {
592                    position: new_point.clone(),
593                    parent: Some(nearest_idxa),
594                    cost: new_cost,
595                });
596
597                // Build KDTree for tree B
598                let points_b: Vec<_> = tree_b.iter().map(|node| node.position.clone()).collect();
599                let points_array_b = scirs2_core::ndarray::stack(
600                    scirs2_core::ndarray::Axis(0),
601                    &points_b.iter().map(|p| p.view()).collect::<Vec<_>>(),
602                )
603                .expect("Failed to stack points");
604                let kdtree_b = KDTree::<f64, EuclideanDistance<f64>>::new(&points_array_b)
605                    .expect("Failed to build KDTree");
606
607                // Find nearest node in tree B
608                let nearest_idxb = self.find_nearest_node(&new_point.view(), tree_b, &kdtree_b)?;
609
610                // Create temporary copy to avoid borrowing conflicts
611                let nearest_position_b = tree_b[nearest_idxb].position.clone();
612
613                // Check if trees can be connected
614                let dist_between_trees =
615                    euclidean_distance(&new_point.view(), &nearest_position_b.view());
616                if dist_between_trees <= goal_threshold
617                    && self.is_valid_connection(&new_point.view(), &nearest_position_b.view())
618                {
619                    // Trees can be connected! Store the connection indices
620                    connection = if a_is_start {
621                        Some((new_node_idx_a, nearest_idxb))
622                    } else {
623                        Some((nearest_idxb, new_node_idx_a))
624                    };
625                    break;
626                }
627            }
628        }
629
630        // Extract the path if trees were connected
631        if let Some((start_idx, goalidx)) = connection {
632            let path = self.extract_bidirectional_path(&start_tree, &goal_tree, start_idx, goalidx);
633            Ok(Some(path))
634        } else {
635            Ok(None)
636        }
637    }
638
639    /// Extract the path from a bi-directional search
640    fn extract_bidirectional_path(
641        &self,
642        start_tree: &[RRTNode],
643        goal_tree: &[RRTNode],
644        start_idx: usize,
645        goalidx: usize,
646    ) -> Path<Array1<f64>> {
647        // Extract path from start to connection point
648        let mut forward_path = Vec::new();
649        let mut current_idx = Some(start_idx);
650        while let Some(_idx) = current_idx {
651            forward_path.push(start_tree[_idx].position.clone());
652            current_idx = start_tree[_idx].parent;
653        }
654        forward_path.reverse(); // Reverse to get start to connection
655
656        // Extract path from goal to connection point
657        let mut backward_path = Vec::new();
658        let mut current_idx = Some(goalidx);
659        while let Some(_idx) = current_idx {
660            backward_path.push(goal_tree[_idx].position.clone());
661            current_idx = goal_tree[_idx].parent;
662        }
663        // No need to reverse - we want connection to goal
664
665        // Combine paths
666        let mut full_path = forward_path;
667        full_path.extend(backward_path);
668
669        // Calculate total cost
670        let mut total_cost = 0.0;
671        for i in 1..full_path.len() {
672            total_cost += euclidean_distance(&full_path[i - 1].view(), &full_path[i].view());
673        }
674
675        Path::new(full_path, total_cost)
676    }
677
678    /// Extract the path from the RRT tree
679    fn extract_path(nodes: &[RRTNode], goalidx: usize) -> Path<Array1<f64>> {
680        let mut path = Vec::new();
681        let mut current_idx = Some(goalidx);
682        let cost = nodes[goalidx].cost;
683
684        while let Some(_idx) = current_idx {
685            path.push(nodes[_idx].position.clone());
686            current_idx = nodes[_idx].parent;
687        }
688
689        // Reverse to get start to goal
690        path.reverse();
691
692        Path::new(path, cost)
693    }
694}
695
696/// Helper function to calculate Euclidean distance between points
697#[allow(dead_code)]
698fn euclidean_distance(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> f64 {
699    let mut sum = 0.0;
700    for i in 0..a.len() {
701        let diff = a[i] - b[i];
702        sum += diff * diff;
703    }
704    sum.sqrt()
705}
706
707/// A 2D RRT planner that works with polygon obstacles
708#[derive(Clone)]
709pub struct RRT2DPlanner {
710    /// The RRT planner
711    planner: RRTPlanner,
712    /// Obstacle polygons (each polygon is a vector of 2D points)
713    obstacles: Vec<Vec<[f64; 2]>>,
714    /// Step size for collision checking
715    _collision_step_size: f64,
716}
717
718impl RRT2DPlanner {
719    /// Create a new 2D RRT planner
720    pub fn new(
721        config: RRTConfig,
722        obstacles: Vec<Vec<[f64; 2]>>,
723        min_bounds: [f64; 2],
724        max_bounds: [f64; 2],
725        collision_step_size: f64,
726    ) -> SpatialResult<Self> {
727        let mut planner = RRTPlanner::new(config, 2);
728        planner = planner.with_bounds(
729            scirs2_core::ndarray::arr1(&min_bounds),
730            scirs2_core::ndarray::arr1(&max_bounds),
731        )?;
732
733        let obstacles_clone = obstacles.clone();
734        planner = planner.with_collision_checker(move |from, to| {
735            Self::check_collision_with_obstacles(from, to, &obstacles_clone, collision_step_size)
736        });
737
738        Ok(RRT2DPlanner {
739            planner,
740            obstacles: obstacles.clone(),
741            _collision_step_size: collision_step_size,
742        })
743    }
744
745    /// Find a path from start to goal
746    pub fn find_path(
747        &mut self,
748        start: [f64; 2],
749        goal: [f64; 2],
750        goal_threshold: f64,
751    ) -> SpatialResult<Option<Path<[f64; 2]>>> {
752        let start_arr = scirs2_core::ndarray::arr1(&start);
753        let goal_arr = scirs2_core::ndarray::arr1(&goal);
754
755        // Check if start or goal are in collision
756        for obstacle in &self.obstacles {
757            if Self::point_in_polygon(&start, obstacle) {
758                return Err(SpatialError::ValueError(
759                    "Start point is inside an obstacle".to_string(),
760                ));
761            }
762            if Self::point_in_polygon(&goal, obstacle) {
763                return Err(SpatialError::ValueError(
764                    "Goal point is inside an obstacle".to_string(),
765                ));
766            }
767        }
768
769        // Find path using RRT
770        let result = self
771            .planner
772            .find_path(start_arr.view(), goal_arr.view(), goal_threshold)?;
773
774        // Convert path to [f64; 2] format
775        if let Some(path) = result {
776            let nodes: Vec<[f64; 2]> = path.nodes.iter().map(|p| [p[0], p[1]]).collect();
777            Ok(Some(Path::new(nodes, path.cost)))
778        } else {
779            Ok(None)
780        }
781    }
782
783    /// Check if a line segment collides with any obstacle
784    fn check_collision_with_obstacles(
785        from: &Array1<f64>,
786        to: &Array1<f64>,
787        obstacles: &[Vec<[f64; 2]>],
788        step_size: f64,
789    ) -> bool {
790        let from_point = [from[0], from[1]];
791        let to_point = [to[0], to[1]];
792
793        // First, check if either endpoint is inside an obstacle
794        for obstacle in obstacles {
795            if Self::point_in_polygon(&from_point, obstacle)
796                || Self::point_in_polygon(&to_point, obstacle)
797            {
798                return true;
799            }
800        }
801
802        // Check if the line segment intersects any obstacle
803        let dx = to[0] - from[0];
804        let dy = to[1] - from[1];
805        let distance = (dx * dx + dy * dy).sqrt();
806
807        if distance < 1e-6 {
808            return false; // Points are too close
809        }
810
811        let steps = (distance / step_size).ceil() as usize;
812
813        for i in 1..steps {
814            let t = i as f64 / steps as f64;
815            let x = from[0] + dx * t;
816            let y = from[1] + dy * t;
817            let point = [x, y];
818
819            for obstacle in obstacles {
820                if Self::point_in_polygon(&point, obstacle) {
821                    return true;
822                }
823            }
824        }
825
826        false
827    }
828
829    /// Check if a point is inside a polygon using ray casting algorithm
830    fn point_in_polygon(point: &[f64; 2], polygon: &[[f64; 2]]) -> bool {
831        if polygon.len() < 3 {
832            return false;
833        }
834
835        let mut inside = false;
836        let mut j = polygon.len() - 1;
837
838        for i in 0..polygon.len() {
839            let xi = polygon[i][0];
840            let yi = polygon[i][1];
841            let xj = polygon[j][0];
842            let yj = polygon[j][1];
843
844            let intersect = ((yi > point[1]) != (yj > point[1]))
845                && (point[0] < (xj - xi) * (point[1] - yi) / (yj - yi) + xi);
846
847            if intersect {
848                inside = !inside;
849            }
850
851            j = i;
852        }
853
854        inside
855    }
856}
857
#[cfg(test)]
mod tests {
    use super::*;

    /// Lower corner of the square workspace used by every test.
    const MIN_BOUNDS: [f64; 2] = [0.0, 0.0];
    /// Upper corner of the square workspace used by every test.
    const MAX_BOUNDS: [f64; 2] = [10.0, 10.0];
    /// Collision step size handed to the 2D planner.
    const COLLISION_STEP: f64 = 0.1;
    /// Distance from the goal within which a path counts as having arrived.
    const GOAL_THRESHOLD: f64 = 0.5;

    /// Plain (non-star, single-tree) RRT configuration with a fixed seed so
    /// that every run is reproducible.
    fn seeded_config(max_iterations: usize, step_size: f64) -> RRTConfig {
        RRTConfig {
            max_iterations,
            step_size,
            goal_bias: 0.1,
            seed: Some(42),
            use_rrt_star: false,
            neighborhood_radius: None,
            bidirectional: false,
        }
    }

    /// Build a planner over the shared workspace, run a single query and
    /// return the resulting path. Panics if construction or planning fails,
    /// or if no path is found.
    fn solve(
        config: RRTConfig,
        obstacles: Vec<Vec<[f64; 2]>>,
        start: [f64; 2],
        goal: [f64; 2],
    ) -> Path<[f64; 2]> {
        let mut planner =
            RRT2DPlanner::new(config, obstacles, MIN_BOUNDS, MAX_BOUNDS, COLLISION_STEP)
                .expect("Operation failed");

        let outcome = planner
            .find_path(start, goal, GOAL_THRESHOLD)
            .expect("Operation failed");

        // A path should be found
        assert!(outcome.is_some());
        outcome.expect("Operation failed")
    }

    /// Assert that `path` begins exactly at `start` and that its final node
    /// lies within `GOAL_THRESHOLD` of `goal`.
    fn assert_connects(path: &Path<[f64; 2]>, start: [f64; 2], goal: [f64; 2]) {
        assert_eq!(path.nodes[0], start);

        let end = path.nodes.last().expect("Operation failed");
        let gap = ((end[0] - goal[0]).powi(2) + (end[1] - goal[1]).powi(2)).sqrt();
        assert!(gap <= GOAL_THRESHOLD);
    }

    #[test]
    fn test_rrt_empty_space() {
        // Basic RRT from (1,1) to (9,9) with no obstacles in the way.
        let start: [f64; 2] = [1.0, 1.0];
        let goal: [f64; 2] = [9.0, 9.0];

        let path = solve(seeded_config(1000, 0.5), vec![], start, goal);

        assert_connects(&path, start, goal);
    }

    #[test]
    fn test_rrt_star_optimization() {
        // Same empty workspace, but with RRT* rewiring switched on.
        let config = RRTConfig {
            use_rrt_star: true,
            neighborhood_radius: Some(1.0),
            ..seeded_config(1000, 0.5)
        };
        let start: [f64; 2] = [1.0, 1.0];
        let goal: [f64; 2] = [9.0, 9.0];

        let path = solve(config, vec![], start, goal);

        assert_connects(&path, start, goal);

        // RRT* should stay reasonably direct: the path cost may exceed the
        // straight-line distance by at most 50%.
        let straight_line =
            ((goal[0] - start[0]).powi(2) + (goal[1] - start[1]).powi(2)).sqrt();
        assert!(path.cost <= straight_line * 1.5);
    }

    #[test]
    fn test_rrt_bidirectional() {
        // Same empty workspace, searched with the bidirectional variant.
        let config = RRTConfig {
            bidirectional: true,
            ..seeded_config(1000, 0.5)
        };
        let start: [f64; 2] = [1.0, 1.0];
        let goal: [f64; 2] = [9.0, 9.0];

        let path = solve(config, vec![], start, goal);

        assert_connects(&path, start, goal);
    }

    #[test]
    fn test_rrt_with_obstacles() {
        // A wall spanning x in [4, 5] and y in [0, 8] separates start and goal.
        let wall = vec![[4.0, 0.0], [5.0, 0.0], [5.0, 8.0], [4.0, 8.0]];
        let start: [f64; 2] = [2.0, 5.0];
        let goal: [f64; 2] = [7.0, 5.0];

        let path = solve(seeded_config(2000, 0.3), vec![wall], start, goal);

        assert_connects(&path, start, goal);

        // No node of the path may lie inside (or on the border of) the wall.
        for node in &path.nodes {
            let inside_wall =
                (4.0..=5.0).contains(&node[0]) && (0.0..=8.0).contains(&node[1]);
            assert!(!inside_wall);
        }
    }
}