1pub use petgraph::stable_graph::{EdgeIndex, NodeIndex};
4
5use std::collections::{HashSet, VecDeque};
8use std::marker::PhantomData;
9use std::ops::{Index, IndexMut};
10
11use nalgebra::SVector;
12use num_traits::{Float, Zero};
13use petgraph::stable_graph::{
14 DefaultIx, EdgeIndices, EdgeReference, Neighbors, NodeIndices, StableDiGraph,
15 WalkNeighbors,
16};
17use petgraph::visit::EdgeRef;
18use petgraph::Direction;
19use serde::{de::DeserializeOwned, Deserialize, Serialize};
20
21use crate::scalar::Scalar;
22use crate::trajectories::{FullTrajRefOwned, Trajectory};
23
/// A single vertex of an [`RrtStarTree`]: a state vector together with a cost value.
///
/// The explicit `serde(bound(...))` replaces serde's inferred bounds so that
/// serializing requires only `X: Serialize` and deserializing only
/// `X: DeserializeOwned`.
#[derive(PartialEq, Clone, Debug, Serialize, Deserialize)]
#[serde(bound(
    serialize = "X: Serialize",
    deserialize = "X: DeserializeOwned"
))]
pub struct Node<X: Scalar, const N: usize> {
    /// The node's state vector; exposed read-only through [`Node::state`].
    state: SVector<X, N>,
    /// Cost value of this node. `Node::new` initialises it to infinity, and the
    /// goal node created by `RrtStarTree::new` starts at zero.
    /// NOTE(review): presumably the cost-to-goal through the tree — confirm.
    pub cost: X,
}
36
37impl<X: Scalar + Float, const N: usize> Node<X, N> {
38 pub fn new(state: SVector<X, N>) -> Self {
39 Self {
40 state,
41 cost: X::infinity(),
42 }
43 }
44}
45
46impl<X: Scalar, const N: usize> Node<X, N> {
47 fn with_cost(state: SVector<X, N>, cost: X) -> Self {
48 Self { state, cost }
49 }
50
51 pub fn state(&self) -> &SVector<X, N> {
52 &self.state
53 }
54}
55
/// Iterator over the indices of every node stored in an [`RrtStarTree`];
/// returned by [`RrtStarTree::all_nodes`].
pub struct NodeIter<'a, X: Scalar, const N: usize> {
    /// Underlying petgraph node-index iterator.
    nodes: NodeIndices<'a, Node<X, N>>,
}
60
61impl<'a, X: Scalar, const N: usize> NodeIter<'a, X, N> {
62 fn new<T>(graph: &'a StableDiGraph<Node<X, N>, T>) -> Self
63 where
64 T: Trajectory<X, N>,
65 {
66 Self {
67 nodes: graph.node_indices(),
68 }
69 }
70}
71
72impl<'a, X: Scalar, const N: usize> Iterator for NodeIter<'a, X, N> {
73 type Item = NodeIndex;
74
75 fn next(&mut self) -> Option<Self::Item> {
76 self.nodes.next()
77 }
78}
79
/// Iterator over the indices of every edge stored in an [`RrtStarTree`];
/// returned by [`RrtStarTree::all_edges`].
pub struct EdgeIter<'a, X, T, const N: usize>
where
    T: Trajectory<X, N>,
{
    /// Underlying petgraph edge-index iterator.
    edges: EdgeIndices<'a, T>,
    /// `X` only appears in the `T: Trajectory<X, N>` bound, and a where-clause
    /// does not count as a use of a type parameter, so a marker field is needed.
    phantom_x: PhantomData<X>,
}
88
89impl<'a, X, T, const N: usize> EdgeIter<'a, X, T, N>
90where
91 X: Scalar,
92 T: Trajectory<X, N>,
93{
94 fn new(graph: &'a StableDiGraph<Node<X, N>, T>) -> Self {
95 Self {
96 edges: graph.edge_indices(),
97 phantom_x: PhantomData,
98 }
99 }
100}
101
102impl<'a, X, T, const N: usize> Iterator for EdgeIter<'a, X, T, N>
103where
104 T: Trajectory<X, N>,
105{
106 type Item = EdgeIndex;
107
108 fn next(&mut self) -> Option<Self::Item> {
109 self.edges.next()
110 }
111}
112
/// Iterator that follows parent links from a starting node, yielding the starting
/// node and then each ancestor until a node without a parent has been yielded.
/// Created by [`RrtStarTree::get_optimal_path`].
pub struct OptimalPathIter<'a, X, T, const N: usize>
where
    X: Scalar,
    T: Trajectory<X, N>,
{
    /// Tree being traversed.
    graph: &'a RrtStarTree<X, T, N>,
    /// Node to yield on the next call to `next`; `None` once the walk has ended.
    next_node: Option<NodeIndex>,
}
122
123impl<'a, X, T, const N: usize> OptimalPathIter<'a, X, T, N>
124where
125 X: Scalar,
126 T: Trajectory<X, N>,
127{
128 fn new(graph: &'a RrtStarTree<X, T, N>, node: NodeIndex) -> Self {
129 Self {
130 graph,
131 next_node: Some(node),
132 }
133 }
134
135 pub fn detach(self) -> OptimalPathWalker {
136 OptimalPathWalker::new(self.next_node)
137 }
138}
139
140impl<'a, X, T, const N: usize> Iterator for OptimalPathIter<'a, X, T, N>
141where
142 X: Scalar,
143 T: Trajectory<X, N>,
144{
145 type Item = NodeIndex;
146
147 fn next(&mut self) -> Option<Self::Item> {
148 let node = self.next_node?;
149 match self.graph.parent(node) {
150 Some(parent) => self.next_node = Some(parent),
151 None => self.next_node = None,
152 }
153
154 Some(node)
155 }
156}
157
/// Borrow-free counterpart of [`OptimalPathIter`], obtained via
/// [`OptimalPathIter::detach`]. It stores only the next node index and takes the
/// tree as an argument on every step, so it holds no borrow of the tree between
/// steps.
pub struct OptimalPathWalker {
    /// Node to yield on the next step; `None` once the walk has ended.
    next_node: Option<NodeIndex>,
}
162
163impl OptimalPathWalker {
164 fn new(node: Option<NodeIndex>) -> Self {
165 Self { next_node: node }
166 }
167
168 pub fn next<X, T, const N: usize>(
169 &mut self,
170 g: &RrtStarTree<X, T, N>,
171 ) -> Option<NodeIndex>
172 where
173 X: Scalar,
174 T: Trajectory<X, N>,
175 {
176 let node = self.next_node?;
177 match g.parent(node) {
178 Some(parent) => self.next_node = Some(parent),
179 None => self.next_node = None,
180 }
181
182 Some(node)
183 }
184}
185
/// A tree rooted at a goal node, stored in a petgraph `StableDiGraph`.
///
/// Each edge points from a child node to its parent and carries the trajectory
/// `T` between them. The explicit `serde(bound(...))` limits the (de)serialization
/// requirements to `X` and `T` themselves.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound(
    serialize = "X: Serialize, T: Serialize",
    deserialize = "X: DeserializeOwned, T: DeserializeOwned",
))]
pub struct RrtStarTree<X, T, const N: usize>
where
    X: Scalar,
    T: Trajectory<X, N>,
{
    /// Index of the goal (root) node, inserted by `RrtStarTree::new`.
    goal_idx: NodeIndex,
    /// Underlying graph: node weights are `Node`s, edge weights are trajectories.
    graph: StableDiGraph<Node<X, N>, T>,
    /// Nodes currently marked as orphaned by `add_orphan`; `clear_orphans`
    /// deletes them from the graph.
    orphans: HashSet<NodeIndex>,
}
201
202impl<X, T, const N: usize> RrtStarTree<X, T, N>
203where
204 X: Scalar + Zero,
205 T: Trajectory<X, N>,
206{
207 pub fn new(goal: SVector<X, N>) -> Self {
209 let mut graph = StableDiGraph::new();
210 let goal_node = Node::with_cost(goal, X::zero());
211 let goal_idx = graph.add_node(goal_node);
212
213 let orphans = HashSet::new();
214
215 Self {
216 goal_idx,
217 graph,
218 orphans,
219 }
220 }
221}
222
223impl<X, T, const N: usize> RrtStarTree<X, T, N>
224where
225 X: Scalar,
226 T: Trajectory<X, N>,
227{
228 pub fn node_count(&self) -> usize {
230 self.graph.node_count()
231 }
232
233 pub fn get_goal_idx(&self) -> NodeIndex {
235 self.goal_idx
236 }
237
238 pub fn get_goal(&self) -> &Node<X, N> {
240 self.get_node(self.goal_idx)
241 }
242
243 pub fn all_nodes(&self) -> NodeIter<X, N> {
245 NodeIter::new(&self.graph)
246 }
247
248 pub fn all_edges(&self) -> EdgeIter<X, T, N> {
250 EdgeIter::new(&self.graph)
251 }
252
253 pub fn parent(&self, node: NodeIndex) -> Option<NodeIndex> {
256 Some(self.parent_edge(node)?.target())
257 }
258
259 fn parent_edge(&self, node: NodeIndex) -> Option<EdgeReference<T>> {
261 self.graph.edges_directed(node, Direction::Outgoing).next()
262 }
263
264 pub fn is_parent(&self, node: NodeIndex, parent: NodeIndex) -> bool {
266 self.graph.find_edge(node, parent).is_some()
267 }
268
269 pub fn children(&self, node: NodeIndex) -> Neighbors<T, DefaultIx> {
272 self.graph.neighbors_directed(node, Direction::Incoming)
273 }
274
275 pub fn children_walker(&self, node: NodeIndex) -> WalkNeighbors<DefaultIx> {
278 self
279 .graph
280 .neighbors_directed(node, Direction::Incoming)
281 .detach()
282 }
283
284 pub fn is_child(&self, node: NodeIndex, child: NodeIndex) -> bool {
286 self.is_parent(child, node)
287 }
288
289 pub fn add_orphan(&mut self, node: NodeIndex) {
294 let mut queue = VecDeque::new();
296 queue.push_back(node);
297
298 while let Some(node) = queue.pop_front() {
300 if self.orphans.insert(node) {
302 let mut children = self.children_walker(node);
304 while let Some(child_idx) = children.next_node(&self.graph) {
305 queue.push_back(child_idx);
306 }
307 }
308 }
309 }
310
311 pub fn remove_orphan(&mut self, node: NodeIndex) {
313 self.orphans.remove(&node);
314 }
315
316 pub fn is_orphan(&self, node: NodeIndex) -> bool {
318 self.orphans.contains(&node)
319 }
320
321 pub fn orphans(&self) -> impl Iterator<Item = NodeIndex> + '_ {
323 self.orphans.iter().map(|&x| x)
324 }
325
326 pub fn clear_orphans(&mut self) {
328 let orphans: Vec<_> = self.orphans().collect();
329 for orphan_idx in orphans {
330 self.graph.remove_node(orphan_idx);
331 }
332 self.orphans.clear();
333 }
334
335 pub fn add_node(
337 &mut self,
338 node: Node<X, N>,
339 parent: NodeIndex,
340 trajectory: T,
341 ) -> (NodeIndex, EdgeIndex) {
342 let node_idx = self.graph.add_node(node);
343 let edge_idx = self.update_edge(node_idx, parent, trajectory);
344 (node_idx, edge_idx)
345 }
346
347 pub fn update_edge(
349 &mut self,
350 node: NodeIndex,
351 new_parent: NodeIndex,
352 new_trajectory: T,
353 ) -> EdgeIndex {
354 self.remove_any_parents(node);
355 self.graph.update_edge(node, new_parent, new_trajectory)
356 }
357
358 fn remove_any_parents(&mut self, node: NodeIndex) -> bool {
361 let edges = self
362 .graph
363 .edges_directed(node, Direction::Outgoing)
364 .map(|edge_ref| edge_ref.id());
365 let edges: Vec<_> = edges.collect();
366
367 let removed = edges.len() > 0;
368 for edge_idx in edges {
369 self.graph.remove_edge(edge_idx);
370 }
371 removed
372 }
373
374 pub fn get_optimal_path(
377 &self,
378 node: NodeIndex,
379 ) -> Option<OptimalPathIter<X, T, N>> {
380 match self.is_orphan(node) {
381 true => None,
382 false => Some(OptimalPathIter::new(self, node)),
383 }
384 }
385
386 pub fn get_node(&self, idx: NodeIndex) -> &Node<X, N> {
390 self.graph.index(idx)
391 }
392
393 pub fn get_node_mut(&mut self, idx: NodeIndex) -> &mut Node<X, N> {
397 self.graph.index_mut(idx)
398 }
399
400 pub fn get_edge(&self, idx: EdgeIndex) -> &T {
404 self.graph.index(idx)
405 }
406
407 pub fn get_endpoints(&self, idx: EdgeIndex) -> (NodeIndex, NodeIndex) {
409 self.graph.edge_endpoints(idx).unwrap()
410 }
411
412 pub fn get_trajectory(&self, idx: EdgeIndex) -> FullTrajRefOwned<X, T, N> {
416 let (start_idx, end_idx) = self.get_endpoints(idx);
417 let start = self.get_node(start_idx).state();
418 let end = self.get_node(end_idx).state();
419 let traj_data = self.get_edge(idx);
420 FullTrajRefOwned::new(start, end, traj_data)
421 }
422}
423
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trajectories::EuclideanTrajectory;

    /// Builds a tree whose goal sits at (1.5, 1.5), with two leaves attached
    /// directly to the goal: `near` at (2, 2) and `far` at (-2, -2), inserted in
    /// that order. Evaluates to `(tree, goal, near, far)`.
    ///
    /// A macro rather than a helper fn so the trajectory type never has to be
    /// spelled out; inference resolves it at each use site.
    macro_rules! two_leaf_tree {
        () => {{
            let mut tree = RrtStarTree::new([1.5, 1.5].into());
            let goal = tree.get_goal_idx();

            let near = tree
                .graph
                .add_node(Node::with_cost([2.0, 2.0].into(), 0.1));
            tree.update_edge(near, goal, EuclideanTrajectory::new());

            let far = tree
                .graph
                .add_node(Node::with_cost([-2.0, -2.0].into(), 0.5));
            tree.update_edge(far, goal, EuclideanTrajectory::new());

            (tree, goal, near, far)
        }};
    }

    #[test]
    fn test_rrt_star_tree_parent() {
        let (tree, goal, near, far) = two_leaf_tree!();

        // The goal is the root, so it has no parent; both leaves hang off it.
        assert_eq!(tree.parent(goal), None);
        assert_eq!(tree.parent(near), Some(goal));
        assert_eq!(tree.parent(far), Some(goal));
    }

    #[test]
    fn test_rrt_star_tree_children() {
        let (tree, goal, near, far) = two_leaf_tree!();

        // Children come back most-recently-attached first.
        let goal_children: Vec<_> = tree.children(goal).collect();
        assert_eq!(goal_children, vec![far, near]);

        assert!(tree.children(near).next().is_none());
        assert!(tree.children(far).next().is_none());
    }

    #[test]
    fn test_rrt_star_tree_children_walker() {
        let (tree, goal, near, far) = two_leaf_tree!();

        let mut walker = tree.children_walker(goal);
        let goal_children: Vec<_> =
            std::iter::from_fn(|| walker.next_node(&tree.graph)).collect();
        assert_eq!(goal_children, vec![far, near]);

        assert_eq!(tree.children_walker(near).next_node(&tree.graph), None);
        assert_eq!(tree.children_walker(far).next_node(&tree.graph), None);
    }
}