path_planning/rrt_star/
rrt_star_tree.rs

1/* Copyright (C) 2020 Dylan Staatz - All Rights Reserved. */
2
3pub use petgraph::stable_graph::{EdgeIndex, NodeIndex};
4
5////////////////////////////////////////////////////////////////////////////////
6
7use std::collections::{HashSet, VecDeque};
8use std::marker::PhantomData;
9use std::ops::{Index, IndexMut};
10
11use nalgebra::SVector;
12use num_traits::{Float, Zero};
13use petgraph::stable_graph::{
14  DefaultIx, EdgeIndices, EdgeReference, Neighbors, NodeIndices, StableDiGraph,
15  WalkNeighbors,
16};
17use petgraph::visit::EdgeRef;
18use petgraph::Direction;
19use serde::{de::DeserializeOwned, Deserialize, Serialize};
20
21use crate::scalar::Scalar;
22use crate::trajectories::{FullTrajRefOwned, Trajectory};
23
24/// A node in the RrtStar graph. Stores information about a particular state in the state space.
25#[derive(PartialEq, Clone, Debug, Serialize, Deserialize)]
26#[serde(bound(
27  serialize = "X: Serialize",
28  deserialize = "X: DeserializeOwned"
29))]
30pub struct Node<X: Scalar, const N: usize> {
31  /// The point in the state space this node is at (immutable)
32  state: SVector<X, N>,
33  /// The cost data for this node (mutable)
34  pub cost: X,
35}
36
37impl<X: Scalar + Float, const N: usize> Node<X, N> {
38  pub fn new(state: SVector<X, N>) -> Self {
39    Self {
40      state,
41      cost: X::infinity(),
42    }
43  }
44}
45
46impl<X: Scalar, const N: usize> Node<X, N> {
47  fn with_cost(state: SVector<X, N>, cost: X) -> Self {
48    Self { state, cost }
49  }
50
51  pub fn state(&self) -> &SVector<X, N> {
52    &self.state
53  }
54}
55
56/// Iterator over all the node indices of the graph
57pub struct NodeIter<'a, X: Scalar, const N: usize> {
58  nodes: NodeIndices<'a, Node<X, N>>,
59}
60
61impl<'a, X: Scalar, const N: usize> NodeIter<'a, X, N> {
62  fn new<T>(graph: &'a StableDiGraph<Node<X, N>, T>) -> Self
63  where
64    T: Trajectory<X, N>,
65  {
66    Self {
67      nodes: graph.node_indices(),
68    }
69  }
70}
71
72impl<'a, X: Scalar, const N: usize> Iterator for NodeIter<'a, X, N> {
73  type Item = NodeIndex;
74
75  fn next(&mut self) -> Option<Self::Item> {
76    self.nodes.next()
77  }
78}
79
80/// Iterator over all the edge indices of the graph
81pub struct EdgeIter<'a, X, T, const N: usize>
82where
83  T: Trajectory<X, N>,
84{
85  edges: EdgeIndices<'a, T>,
86  phantom_x: PhantomData<X>,
87}
88
89impl<'a, X, T, const N: usize> EdgeIter<'a, X, T, N>
90where
91  X: Scalar,
92  T: Trajectory<X, N>,
93{
94  fn new(graph: &'a StableDiGraph<Node<X, N>, T>) -> Self {
95    Self {
96      edges: graph.edge_indices(),
97      phantom_x: PhantomData,
98    }
99  }
100}
101
102impl<'a, X, T, const N: usize> Iterator for EdgeIter<'a, X, T, N>
103where
104  T: Trajectory<X, N>,
105{
106  type Item = EdgeIndex;
107
108  fn next(&mut self) -> Option<Self::Item> {
109    self.edges.next()
110  }
111}
112
113/// Iterator over edge indices of the graph in the optimal subtree
114pub struct OptimalPathIter<'a, X, T, const N: usize>
115where
116  X: Scalar,
117  T: Trajectory<X, N>,
118{
119  graph: &'a RrtStarTree<X, T, N>,
120  next_node: Option<NodeIndex>,
121}
122
123impl<'a, X, T, const N: usize> OptimalPathIter<'a, X, T, N>
124where
125  X: Scalar,
126  T: Trajectory<X, N>,
127{
128  fn new(graph: &'a RrtStarTree<X, T, N>, node: NodeIndex) -> Self {
129    Self {
130      graph,
131      next_node: Some(node),
132    }
133  }
134
135  pub fn detach(self) -> OptimalPathWalker {
136    OptimalPathWalker::new(self.next_node)
137  }
138}
139
140impl<'a, X, T, const N: usize> Iterator for OptimalPathIter<'a, X, T, N>
141where
142  X: Scalar,
143  T: Trajectory<X, N>,
144{
145  type Item = NodeIndex;
146
147  fn next(&mut self) -> Option<Self::Item> {
148    let node = self.next_node?;
149    match self.graph.parent(node) {
150      Some(parent) => self.next_node = Some(parent),
151      None => self.next_node = None,
152    }
153
154    Some(node)
155  }
156}
157
158/// Iterator over edge indices of the graph in the optimal subtree
159pub struct OptimalPathWalker {
160  next_node: Option<NodeIndex>,
161}
162
163impl OptimalPathWalker {
164  fn new(node: Option<NodeIndex>) -> Self {
165    Self { next_node: node }
166  }
167
168  pub fn next<X, T, const N: usize>(
169    &mut self,
170    g: &RrtStarTree<X, T, N>,
171  ) -> Option<NodeIndex>
172  where
173    X: Scalar,
174    T: Trajectory<X, N>,
175  {
176    let node = self.next_node?;
177    match g.parent(node) {
178      Some(parent) => self.next_node = Some(parent),
179      None => self.next_node = None,
180    }
181
182    Some(node)
183  }
184}
185
186/// Tree structure for holding the optimal tree
187#[derive(Debug, Clone, Serialize, Deserialize)]
188#[serde(bound(
189  serialize = "X: Serialize, T: Serialize",
190  deserialize = "X: DeserializeOwned, T: DeserializeOwned",
191))]
192pub struct RrtStarTree<X, T, const N: usize>
193where
194  X: Scalar,
195  T: Trajectory<X, N>,
196{
197  goal_idx: NodeIndex,
198  graph: StableDiGraph<Node<X, N>, T>,
199  orphans: HashSet<NodeIndex>,
200}
201
202impl<X, T, const N: usize> RrtStarTree<X, T, N>
203where
204  X: Scalar + Zero,
205  T: Trajectory<X, N>,
206{
207  /// Creates new graph with goal as the root node with cost zero
208  pub fn new(goal: SVector<X, N>) -> Self {
209    let mut graph = StableDiGraph::new();
210    let goal_node = Node::with_cost(goal, X::zero());
211    let goal_idx = graph.add_node(goal_node);
212
213    let orphans = HashSet::new();
214
215    Self {
216      goal_idx,
217      graph,
218      orphans,
219    }
220  }
221}
222
223impl<X, T, const N: usize> RrtStarTree<X, T, N>
224where
225  X: Scalar,
226  T: Trajectory<X, N>,
227{
228  /// Returns the number of nodes in the graph
229  pub fn node_count(&self) -> usize {
230    self.graph.node_count()
231  }
232
233  /// Returns the index of the goal node
234  pub fn get_goal_idx(&self) -> NodeIndex {
235    self.goal_idx
236  }
237
238  /// Returns reference to the goal node
239  pub fn get_goal(&self) -> &Node<X, N> {
240    self.get_node(self.goal_idx)
241  }
242
243  /// Returns iterator over all nodes in the graph
244  pub fn all_nodes(&self) -> NodeIter<X, N> {
245    NodeIter::new(&self.graph)
246  }
247
248  /// Returns iterator over all edges in the graph
249  pub fn all_edges(&self) -> EdgeIter<X, T, N> {
250    EdgeIter::new(&self.graph)
251  }
252
253  /// Returns the NodeIndex of the parent of a in the optimal subtree if exists
254  /// Returns None if no parent edge
255  pub fn parent(&self, node: NodeIndex) -> Option<NodeIndex> {
256    Some(self.parent_edge(node)?.target())
257  }
258
259  /// Returns the edge directed at the nodes parent in the optimal subtree if exists
260  fn parent_edge(&self, node: NodeIndex) -> Option<EdgeReference<T>> {
261    self.graph.edges_directed(node, Direction::Outgoing).next()
262  }
263
264  /// Looks up edge from node -> parent to see if parent is the parent of node
265  pub fn is_parent(&self, node: NodeIndex, parent: NodeIndex) -> bool {
266    self.graph.find_edge(node, parent).is_some()
267  }
268
269  /// Returns iterator over all children of a in the optimal subtree.
270  /// Iterator will be empty for leaf nodes
271  pub fn children(&self, node: NodeIndex) -> Neighbors<T, DefaultIx> {
272    self.graph.neighbors_directed(node, Direction::Incoming)
273  }
274
275  /// Returns walker over all children of a in the optimal subtree.
276  /// Iterator will be empty for leaf nodes
277  pub fn children_walker(&self, node: NodeIndex) -> WalkNeighbors<DefaultIx> {
278    self
279      .graph
280      .neighbors_directed(node, Direction::Incoming)
281      .detach()
282  }
283
284  /// Looks up edge from child -> node to see if node is the parent of child
285  pub fn is_child(&self, node: NodeIndex, child: NodeIndex) -> bool {
286    self.is_parent(child, node)
287  }
288
289  /// Add node and children of node to the internal set of orphans
290  ///
291  /// This garentees that the children of any orphan node are also marked as orphans
292  /// as soon as they are added
293  pub fn add_orphan(&mut self, node: NodeIndex) {
294    // Add all childern node as orphans
295    let mut queue = VecDeque::new();
296    queue.push_back(node);
297
298    // While there are still nodes to process
299    while let Some(node) = queue.pop_front() {
300      // Set node as orphan
301      if self.orphans.insert(node) {
302        // if newly inserted, add childern to queue
303        let mut children = self.children_walker(node);
304        while let Some(child_idx) = children.next_node(&self.graph) {
305          queue.push_back(child_idx);
306        }
307      }
308    }
309  }
310
311  /// Remove node from the internal list of orphans, doesn't modifiy the graph
312  pub fn remove_orphan(&mut self, node: NodeIndex) {
313    self.orphans.remove(&node);
314  }
315
316  /// Checks if node is in the internal set of orphans
317  pub fn is_orphan(&self, node: NodeIndex) -> bool {
318    self.orphans.contains(&node)
319  }
320
321  /// Get iterator over all orphans
322  pub fn orphans(&self) -> impl Iterator<Item = NodeIndex> + '_ {
323    self.orphans.iter().map(|&x| x)
324  }
325
326  /// Remove all orphan nodes from the graph and the orphan set
327  pub fn clear_orphans(&mut self) {
328    let orphans: Vec<_> = self.orphans().collect();
329    for orphan_idx in orphans {
330      self.graph.remove_node(orphan_idx);
331    }
332    self.orphans.clear();
333  }
334
335  /// Adds a node to the graph and returns it's index in addition fo the edge index to the parent
336  pub fn add_node(
337    &mut self,
338    node: Node<X, N>,
339    parent: NodeIndex,
340    trajectory: T,
341  ) -> (NodeIndex, EdgeIndex) {
342    let node_idx = self.graph.add_node(node);
343    let edge_idx = self.update_edge(node_idx, parent, trajectory);
344    (node_idx, edge_idx)
345  }
346
347  /// Sets the new parent and removes any existing parent
348  pub fn update_edge(
349    &mut self,
350    node: NodeIndex,
351    new_parent: NodeIndex,
352    new_trajectory: T,
353  ) -> EdgeIndex {
354    self.remove_any_parents(node);
355    self.graph.update_edge(node, new_parent, new_trajectory)
356  }
357
358  /// The given node will have no outgoing edges after this function
359  /// Returns true if any parents were removed
360  fn remove_any_parents(&mut self, node: NodeIndex) -> bool {
361    let edges = self
362      .graph
363      .edges_directed(node, Direction::Outgoing)
364      .map(|edge_ref| edge_ref.id());
365    let edges: Vec<_> = edges.collect();
366
367    let removed = edges.len() > 0;
368    for edge_idx in edges {
369      self.graph.remove_edge(edge_idx);
370    }
371    removed
372  }
373
374  /// Returns an iterator of edges over the optimal path to the goal node
375  /// Returns None if no such path exists
376  pub fn get_optimal_path(
377    &self,
378    node: NodeIndex,
379  ) -> Option<OptimalPathIter<X, T, N>> {
380    match self.is_orphan(node) {
381      true => None,
382      false => Some(OptimalPathIter::new(self, node)),
383    }
384  }
385
386  /// Returns a refernce the specified node
387  ///
388  /// panics if `idx` is invalid
389  pub fn get_node(&self, idx: NodeIndex) -> &Node<X, N> {
390    self.graph.index(idx)
391  }
392
393  /// Returns a mutable refernce the specified node
394  ///
395  /// panics if `idx` is invalid
396  pub fn get_node_mut(&mut self, idx: NodeIndex) -> &mut Node<X, N> {
397    self.graph.index_mut(idx)
398  }
399
400  /// Returns a refernce the specified edge
401  ///
402  /// panics if `idx` is invalid
403  pub fn get_edge(&self, idx: EdgeIndex) -> &T {
404    self.graph.index(idx)
405  }
406
407  /// Returns source and target endpoints of an edge
408  pub fn get_endpoints(&self, idx: EdgeIndex) -> (NodeIndex, NodeIndex) {
409    self.graph.edge_endpoints(idx).unwrap()
410  }
411
412  /// Returns the trajectory stored at the specified edge
413  ///
414  /// panics if `idx` is invalid
415  pub fn get_trajectory(&self, idx: EdgeIndex) -> FullTrajRefOwned<X, T, N> {
416    let (start_idx, end_idx) = self.get_endpoints(idx);
417    let start = self.get_node(start_idx).state();
418    let end = self.get_node(end_idx).state();
419    let traj_data = self.get_edge(idx);
420    FullTrajRefOwned::new(start, end, traj_data)
421  }
422}
423
#[cfg(test)]
mod tests {

  use crate::trajectories::EuclideanTrajectory;

  use super::*;

  /// Sorts node indices so child sets can be compared without depending on
  /// the order in which the underlying graph stores adjacency.
  fn sorted(mut nodes: Vec<NodeIndex>) -> Vec<NodeIndex> {
    nodes.sort();
    nodes
  }

  #[test]
  fn test_rrt_star_tree_parent() {
    let goal_coord = [1.5, 1.5].into();

    let mut g = RrtStarTree::new(goal_coord);
    let goal = g.get_goal_idx();

    let n1_coord = [2.0, 2.0].into();
    let n1 = Node {
      state: n1_coord,
      cost: 0.1,
    };
    let n1 = g.graph.add_node(n1);

    g.update_edge(n1, goal, EuclideanTrajectory::new());

    let n2_coord = [-2.0, -2.0].into();
    let n2 = Node {
      state: n2_coord,
      cost: 0.5,
    };
    let n2 = g.graph.add_node(n2);

    g.update_edge(n2, goal, EuclideanTrajectory::new());

    // Parents
    assert_eq!(g.parent(goal), None);
    assert_eq!(g.parent(n1), Some(goal));
    assert_eq!(g.parent(n2), Some(goal));

    // Re-parenting replaces the old parent edge instead of adding a second one
    g.update_edge(n2, n1, EuclideanTrajectory::new());
    assert_eq!(g.parent(n2), Some(n1));
    assert!(g.is_parent(n2, n1));
    assert!(!g.is_parent(n2, goal));
  }

  #[test]
  fn test_rrt_star_tree_children() {
    let goal_coord = [1.5, 1.5].into();

    let mut g = RrtStarTree::new(goal_coord);
    let goal = g.get_goal_idx();

    let n1_coord = [2.0, 2.0].into();
    let n1 = Node {
      state: n1_coord,
      cost: 0.1,
    };
    let n1 = g.graph.add_node(n1);

    g.update_edge(n1, goal, EuclideanTrajectory::new());

    let n2_coord = [-2.0, -2.0].into();
    let n2 = Node {
      state: n2_coord,
      cost: 0.5,
    };
    let n2 = g.graph.add_node(n2);

    g.update_edge(n2, goal, EuclideanTrajectory::new());

    // Children: `children` does not promise an order, so compare as sets
    assert_eq!(sorted(g.children(goal).collect()), sorted(vec![n1, n2]));

    let mut iter = g.children(n1);
    assert_eq!(iter.next(), None);

    let mut iter = g.children(n2);
    assert_eq!(iter.next(), None);
  }

  #[test]
  fn test_rrt_star_tree_children_walker() {
    let goal_coord = [1.5, 1.5].into();

    let mut g = RrtStarTree::new(goal_coord);
    let goal = g.get_goal_idx();

    let n1_coord = [2.0, 2.0].into();
    let n1 = Node {
      state: n1_coord,
      cost: 0.1,
    };
    let n1 = g.graph.add_node(n1);

    g.update_edge(n1, goal, EuclideanTrajectory::new());

    let n2_coord = [-2.0, -2.0].into();
    let n2 = Node {
      state: n2_coord,
      cost: 0.5,
    };
    let n2 = g.graph.add_node(n2);

    g.update_edge(n2, goal, EuclideanTrajectory::new());

    // Children walker: drain it, then compare as a set
    let mut walker = g.children_walker(goal);
    let mut children = Vec::new();
    while let Some(child) = walker.next_node(&g.graph) {
      children.push(child);
    }
    assert_eq!(sorted(children), sorted(vec![n1, n2]));

    let mut iter = g.children_walker(n1);
    assert_eq!(iter.next_node(&g.graph), None);

    let mut iter = g.children_walker(n2);
    assert_eq!(iter.next_node(&g.graph), None);
  }

  #[test]
  fn test_rrt_star_tree_optimal_path() {
    let mut g = RrtStarTree::new([1.5, 1.5].into());
    let goal = g.get_goal_idx();

    let (n1, _) = g.add_node(
      Node::with_cost([2.0, 2.0].into(), 0.1),
      goal,
      EuclideanTrajectory::new(),
    );
    let (n2, _) = g.add_node(
      Node::with_cost([3.0, 3.0].into(), 0.2),
      n1,
      EuclideanTrajectory::new(),
    );

    assert!(g.is_parent(n2, n1));
    assert!(g.is_child(n1, n2));
    assert!(!g.is_parent(n1, n2));

    // The path starts at the queried node and ends at the goal
    let path: Vec<_> = g.get_optimal_path(n2).unwrap().collect();
    assert_eq!(path, vec![n2, n1, goal]);

    let path: Vec<_> = g.get_optimal_path(goal).unwrap().collect();
    assert_eq!(path, vec![goal]);
  }

  #[test]
  fn test_rrt_star_tree_orphans() {
    let mut g = RrtStarTree::new([1.5, 1.5].into());
    let goal = g.get_goal_idx();

    let (n1, _) = g.add_node(
      Node::with_cost([2.0, 2.0].into(), 0.1),
      goal,
      EuclideanTrajectory::new(),
    );
    let (n2, _) = g.add_node(
      Node::with_cost([3.0, 3.0].into(), 0.2),
      n1,
      EuclideanTrajectory::new(),
    );
    let (n3, _) = g.add_node(
      Node::with_cost([-2.0, -2.0].into(), 0.5),
      goal,
      EuclideanTrajectory::new(),
    );

    // Orphaning n1 must also orphan its descendant n2, but not its sibling n3
    g.add_orphan(n1);
    assert!(g.is_orphan(n1));
    assert!(g.is_orphan(n2));
    assert!(!g.is_orphan(n3));
    assert!(!g.is_orphan(goal));
    assert_eq!(g.orphans().count(), 2);
    assert!(g.get_optimal_path(n2).is_none());

    // Clearing removes the orphan nodes and their edges from the graph
    g.clear_orphans();
    assert_eq!(g.orphans().count(), 0);
    assert_eq!(g.node_count(), 2);
    assert_eq!(g.parent(n3), Some(goal));
    assert_eq!(g.children(goal).collect::<Vec<_>>(), vec![n3]);
  }
}