use std::sync::Arc;
use crate::{
base::planner::PlannerConfig,
time::{Duration, Instant},
};
use rand::{rngs::StdRng, Rng, SeedableRng};
use crate::base::{
error::PlanningError,
goal::{Goal, GoalSampleableRegion},
planner::{Path, Planner},
problem_definition::ProblemDefinition,
space::StateSpace,
state::State,
validity::StateValidityChecker,
};
/// A vertex of the RRT* search tree.
///
/// Nodes live in the planner's `tree` vector and refer to their parent by index,
/// so the tree structure is encoded entirely through `parent_index`.
#[derive(Clone)]
struct Node<S: State> {
    // The configuration this vertex represents.
    state: S,
    // Index (into the planner's `tree`) of this node's parent; the start node is
    // pushed with `None`, every node added during `solve` gets `Some(parent)`.
    parent_index: Option<usize>,
    // Path cost from the start as computed when the node was added or last rewired:
    // the parent's `cost` plus the state-space `distance` to the parent.
    cost: f64,
}
/// RRT* sampling-based motion planner.
///
/// Each iteration samples a state (from the goal region with probability `goal_bias`,
/// otherwise uniformly), steers from the nearest tree node toward it by at most
/// `max_distance`, picks the cheapest collision-free parent among nodes within
/// `search_radius`, and then rewires those nearby nodes through the new node when that
/// lowers their cost. Solving stops at the first new node that satisfies the goal.
pub struct RRTStar<S: State, SP: StateSpace<StateType = S>, G: Goal<S>> {
    /// Maximum length of a single tree edge; samples farther than this from the
    /// nearest node are approached by interpolation.
    pub max_distance: f64,
    /// Probability that an iteration samples from the goal region instead of the
    /// whole space (passed to `Rng::random_bool`, so it must lie in `[0, 1]`).
    pub goal_bias: f64,
    /// Nodes strictly closer than this to a new node are parent / rewiring candidates.
    pub search_radius: f64,
    // Problem being solved; `None` until `setup` is called.
    problem_def: Option<Arc<ProblemDefinition<S, SP, G>>>,
    // State validity (collision) checker; `None` until `setup` is called.
    validity_checker: Option<Arc<dyn StateValidityChecker<S>>>,
    // All tree nodes; after `setup` index 0 is the start node.
    tree: Vec<Node<S>>,
    // Random generator, seeded from `PlannerConfig::seed` by `new` when a seed is given;
    // when it is `None`, `solve` creates an OS-seeded generator instead.
    rng: Option<Box<StdRng>>,
}
impl<S, SP, G> RRTStar<S, SP, G>
where
    S: State + Clone,
    SP: StateSpace<StateType = S>,
    G: Goal<S> + GoalSampleableRegion<S>,
{
    /// Creates an RRT* planner with an empty tree.
    ///
    /// * `max_distance` – longest edge added per iteration; farther samples are
    ///   approached by interpolating from the nearest node.
    /// * `goal_bias` – probability of sampling the goal region instead of the whole
    ///   space; must lie in `[0, 1]` (it is passed to `Rng::random_bool`).
    /// * `search_radius` – nodes strictly closer than this to a new node are considered
    ///   when choosing its parent and when rewiring.
    /// * `config` – when `config.seed` is set, the random generator is seeded from it
    ///   so runs are reproducible.
    ///
    /// `setup` must be called before `solve`.
    pub fn new(
        max_distance: f64,
        goal_bias: f64,
        search_radius: f64,
        config: &PlannerConfig,
    ) -> Self {
        let rng = config.seed.map(|s| Box::new(StdRng::seed_from_u64(s)));
        RRTStar {
            max_distance,
            goal_bias,
            search_radius,
            problem_def: None,
            validity_checker: None,
            tree: Vec::new(),
            rng,
        }
    }

    /// Returns `true` if the straight-line motion `from -> to` is valid.
    ///
    /// The segment is discretised at one tenth of the space's longest valid segment
    /// length and every interpolated state after `from` (including `to`) is checked;
    /// `from` itself is assumed valid. Segments needing at most one step only check
    /// `to`. Returns `false` when the planner has not been set up, and also when the
    /// segment cannot be split into a finite number of steps (e.g. a zero longest
    /// valid segment length): the old `as usize` cast turned that into `usize::MAX`
    /// iterations, a loop that never returned to `solve`'s timeout check.
    fn check_motion(&self, from: &S, to: &S) -> bool {
        let (pd, vc) = match (&self.problem_def, &self.validity_checker) {
            (Some(pd), Some(vc)) => (pd, vc),
            _ => return false,
        };
        let space = &pd.space;
        let resolution = space.get_longest_valid_segment_length() * 0.1;
        let steps = (space.distance(from, to) / resolution).ceil();
        // NaN (e.g. 0 / 0) and anything up to one step only need the endpoint check.
        if steps.is_nan() || steps <= 1.0 {
            return vc.is_valid(to);
        }
        if !steps.is_finite() {
            return false;
        }
        let num_steps = steps as usize;
        let mut probe = from.clone();
        (1..=num_steps).all(|step| {
            space.interpolate(from, to, step as f64 / num_steps as f64, &mut probe);
            vc.is_valid(&probe)
        })
    }

    /// Cost of reaching `current_node` through `neighbour_node`: the neighbour's
    /// accumulated cost plus the distance between the two states.
    ///
    /// Returns `f64::INFINITY` when the planner has not been set up.
    fn cost(&self, current_node: &Node<S>, neighbour_node: &Node<S>) -> f64 {
        match &self.problem_def {
            Some(pd) => {
                neighbour_node.cost + pd.space.distance(&current_node.state, &neighbour_node.state)
            }
            None => f64::INFINITY,
        }
    }

    /// Indices of all tree nodes strictly within `search_radius` of `node`, in
    /// ascending index order. Empty when the planner has not been set up.
    fn find_neighbours(&self, node: &Node<S>) -> Vec<usize> {
        match &self.problem_def {
            Some(pd) => self
                .tree
                .iter()
                .enumerate()
                .filter(|(_, other)| {
                    pd.space.distance(&node.state, &other.state) < self.search_radius
                })
                .map(|(idx, _)| idx)
                .collect(),
            None => Vec::new(),
        }
    }

    /// Builds the start-to-node path ending at tree index `start_node_idx` by
    /// following parent links back to the root and reversing the result.
    fn reconstruct_path(&self, start_node_idx: usize) -> Path<S> {
        let mut path_states: Vec<S> =
            std::iter::successors(Some(start_node_idx), |&idx| self.tree[idx].parent_index)
                .map(|idx| self.tree[idx].state.clone())
                .collect();
        path_states.reverse();
        Path(path_states)
    }
}
impl<S, SP, G> Planner<S, SP, G> for RRTStar<S, SP, G>
where
S: State + Clone,
SP: StateSpace<StateType = S>,
G: Goal<S> + GoalSampleableRegion<S>,
{
fn setup(
&mut self,
problem_def: Arc<ProblemDefinition<S, SP, G>>,
validity_checker: Arc<dyn StateValidityChecker<S>>,
) {
self.problem_def = Some(problem_def);
self.validity_checker = Some(validity_checker);
self.tree.clear();
let start_state = self.problem_def.as_ref().unwrap().start_states[0].clone();
let start_node = Node {
state: start_state,
parent_index: None,
cost: 0.0,
};
self.tree.push(start_node);
}
fn solve(&mut self, timeout: Duration) -> Result<Path<S>, PlanningError> {
let pd = self
.problem_def
.as_ref()
.ok_or(PlanningError::PlannerUninitialised)?;
let goal = &pd.goal;
let start_time = Instant::now();
let mut rng = self
.rng
.take()
.unwrap_or_else(|| Box::new(StdRng::from_os_rng()));
loop {
if start_time.elapsed() > timeout {
return Err(PlanningError::Timeout);
}
let q_rand = if rng.random_bool(self.goal_bias) {
goal.sample_goal(&mut rng).unwrap()
} else {
pd.space.sample_uniform(&mut rng).unwrap()
};
let mut nearest_node_index = 0;
let mut min_dist = pd.space.distance(&self.tree[0].state, &q_rand);
for i in 1..self.tree.len() {
let dist = pd.space.distance(&self.tree[i].state, &q_rand);
if dist < min_dist {
min_dist = dist;
nearest_node_index = i;
}
}
let q_near = &self.tree[nearest_node_index].state;
let mut q_new = q_near.clone();
if min_dist > self.max_distance {
let t = self.max_distance / min_dist;
pd.space.interpolate(q_near, &q_rand, t, &mut q_new);
} else {
q_new = q_rand;
}
if !self.check_motion(q_near, &q_new) {
continue;
}
let temp_node = Node {
state: q_new.clone(),
parent_index: None,
cost: 0.0,
};
let neighbours: Vec<usize> = self.find_neighbours(&temp_node);
let mut best_parent_index = nearest_node_index;
let q_near_node = &self.tree[nearest_node_index];
let mut min_cost = self.cost(&temp_node, q_near_node);
for &neighbour_idx in &neighbours {
let neighbour_node = &self.tree[neighbour_idx];
let cost_via_neighbour = self.cost(&temp_node, neighbour_node);
if cost_via_neighbour < min_cost && self.check_motion(&neighbour_node.state, &q_new)
{
min_cost = cost_via_neighbour;
best_parent_index = neighbour_idx;
}
}
let new_node = Node {
state: q_new.clone(),
parent_index: Some(best_parent_index),
cost: min_cost,
};
self.tree.push(new_node);
let new_node_index = self.tree.len() - 1;
for &neighbour_idx in &neighbours {
let new_node_ref = &self.tree[new_node_index];
let neighbour_node = &self.tree[neighbour_idx];
if new_node_ref.parent_index == Some(neighbour_idx) {
continue;
}
let cost_via_new_node = self.cost(neighbour_node, new_node_ref);
if cost_via_new_node < neighbour_node.cost
&& self.check_motion(&new_node_ref.state, &neighbour_node.state)
{
let mutable_neighbour_node = &mut self.tree[neighbour_idx];
mutable_neighbour_node.parent_index = Some(new_node_index);
mutable_neighbour_node.cost = cost_via_new_node;
}
}
if goal.is_satisfied(&q_new) {
println!("Solution found after {} nodes.", self.tree.len());
return Ok(self.reconstruct_path(self.tree.len() - 1));
}
}
}
}