1pub use petgraph::stable_graph::{EdgeIndex, NodeIndex};
4
5use std::collections::HashSet;
8use std::marker::PhantomData;
9use std::ops::{Index, IndexMut};
10
11use nalgebra::SVector;
12use petgraph::stable_graph::{
13 DefaultIx, EdgeIndices, EdgeReference, Neighbors, NodeIndices, StableDiGraph,
14 WalkNeighbors,
15};
16use petgraph::visit::EdgeRef;
17use petgraph::Direction;
18use serde::{de::DeserializeOwned, Deserialize, Serialize};
19
20use crate::scalar::Scalar;
21use crate::trajectories::{FullTrajRefOwned, Trajectory};
22
23pub struct NodeIter<'a, X, const N: usize> {
25 nodes: NodeIndices<'a, SVector<X, N>>,
26}
27
28impl<'a, X, const N: usize> NodeIter<'a, X, N> {
29 fn new<T>(graph: &'a StableDiGraph<SVector<X, N>, T>) -> Self
30 where
31 T: Trajectory<X, N>,
32 {
33 Self {
34 nodes: graph.node_indices(),
35 }
36 }
37}
38
39impl<'a, X, const N: usize> Iterator for NodeIter<'a, X, N> {
40 type Item = NodeIndex;
41
42 fn next(&mut self) -> Option<Self::Item> {
43 self.nodes.next()
44 }
45}
46
47pub struct EdgeIter<'a, X, T, const N: usize>
49where
50 T: Trajectory<X, N>,
51{
52 edges: EdgeIndices<'a, T>,
53 phantom_x: PhantomData<X>,
54}
55
56impl<'a, X, T, const N: usize> EdgeIter<'a, X, T, N>
57where
58 X: Scalar,
59 T: Trajectory<X, N>,
60{
61 fn new(graph: &'a StableDiGraph<SVector<X, N>, T>) -> Self {
62 Self {
63 edges: graph.edge_indices(),
64 phantom_x: PhantomData,
65 }
66 }
67}
68
69impl<'a, X, T, const N: usize> Iterator for EdgeIter<'a, X, T, N>
70where
71 T: Trajectory<X, N>,
72{
73 type Item = EdgeIndex;
74
75 fn next(&mut self) -> Option<Self::Item> {
76 self.edges.next()
77 }
78}
79
80pub struct OptimalPathIter<'a, X, T, const N: usize>
82where
83 X: Scalar,
84 T: Trajectory<X, N>,
85{
86 graph: &'a RrtTree<X, T, N>,
87 next_node: Option<NodeIndex>,
88}
89
90impl<'a, X, T, const N: usize> OptimalPathIter<'a, X, T, N>
91where
92 X: Scalar,
93 T: Trajectory<X, N>,
94{
95 fn new(graph: &'a RrtTree<X, T, N>, node: NodeIndex) -> Self {
96 Self {
97 graph,
98 next_node: Some(node),
99 }
100 }
101
102 pub fn detach(self) -> OptimalPathWalker {
103 OptimalPathWalker::new(self.next_node)
104 }
105}
106
107impl<'a, X, T, const N: usize> Iterator for OptimalPathIter<'a, X, T, N>
108where
109 X: Scalar,
110 T: Trajectory<X, N>,
111{
112 type Item = NodeIndex;
113
114 fn next(&mut self) -> Option<Self::Item> {
115 let node = self.next_node?;
116 match self.graph.parent(node) {
117 Some(parent) => self.next_node = Some(parent),
118 None => self.next_node = None,
119 }
120
121 Some(node)
122 }
123}
124
125pub struct OptimalPathWalker {
127 next_node: Option<NodeIndex>,
128}
129
130impl OptimalPathWalker {
131 fn new(node: Option<NodeIndex>) -> Self {
132 Self { next_node: node }
133 }
134
135 pub fn next<X, T, const N: usize>(
136 &mut self,
137 g: &RrtTree<X, T, N>,
138 ) -> Option<NodeIndex>
139 where
140 X: Scalar,
141 T: Trajectory<X, N>,
142 {
143 let node = self.next_node?;
144 match g.parent(node) {
145 Some(parent) => self.next_node = Some(parent),
146 None => self.next_node = None,
147 }
148
149 Some(node)
150 }
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize)]
155#[serde(bound(
156 serialize = "X: Serialize, T: Serialize",
157 deserialize = "X: DeserializeOwned, T: DeserializeOwned",
158))]
159pub struct RrtTree<X, T, const N: usize>
160where
161 X: Scalar,
162 T: Trajectory<X, N>,
163{
164 goal_idx: NodeIndex,
165 graph: StableDiGraph<SVector<X, N>, T>,
166 orphans: HashSet<NodeIndex>,
167 #[serde(skip)]
168 phantom_x: PhantomData<X>,
169}
170
171impl<X, T, const N: usize> RrtTree<X, T, N>
172where
173 X: Scalar,
174 T: Trajectory<X, N>,
175{
176 pub fn new(goal: SVector<X, N>) -> Self {
178 let mut graph = StableDiGraph::new();
179 let goal_idx = graph.add_node(goal);
180
181 let orphans = HashSet::new();
182
183 Self {
184 goal_idx,
185 graph,
186 orphans,
187 phantom_x: PhantomData,
188 }
189 }
190
191 pub fn node_count(&self) -> usize {
193 self.graph.node_count()
194 }
195
196 pub fn get_goal_idx(&self) -> NodeIndex {
198 self.goal_idx
199 }
200
201 pub fn get_goal(&self) -> &SVector<X, N> {
203 self.get_node(self.goal_idx)
204 }
205
206 pub fn all_nodes(&self) -> NodeIter<X, N> {
208 NodeIter::new(&self.graph)
209 }
210
211 pub fn all_edges(&self) -> EdgeIter<X, T, N> {
213 EdgeIter::new(&self.graph)
214 }
215
216 pub fn parent(&self, node: NodeIndex) -> Option<NodeIndex> {
219 Some(self.parent_edge(node)?.target())
220 }
221
222 fn parent_edge(&self, node: NodeIndex) -> Option<EdgeReference<T>> {
224 self.graph.edges_directed(node, Direction::Outgoing).next()
225 }
226
227 pub fn is_parent(&self, node: NodeIndex, parent: NodeIndex) -> bool {
229 self.graph.find_edge(node, parent).is_some()
230 }
231
232 pub fn children(&self, node: NodeIndex) -> Neighbors<T, DefaultIx> {
235 self.graph.neighbors_directed(node, Direction::Incoming)
236 }
237
238 pub fn children_walker(&self, node: NodeIndex) -> WalkNeighbors<DefaultIx> {
241 self
242 .graph
243 .neighbors_directed(node, Direction::Incoming)
244 .detach()
245 }
246
247 pub fn is_child(&self, node: NodeIndex, child: NodeIndex) -> bool {
249 self.is_parent(child, node)
250 }
251
252 pub fn add_orphan(&mut self, node: NodeIndex) {
257 self.orphans.insert(node);
258
259 let mut children = self.children_walker(node);
260 while let Some(child_idx) = children.next_node(&self.graph) {
261 if !self.is_orphan(child_idx) {
265 self.add_orphan(child_idx);
266 }
267 }
268 }
269
270 pub fn remove_orphan(&mut self, node: NodeIndex) {
272 self.orphans.remove(&node);
273 }
274
275 pub fn is_orphan(&self, node: NodeIndex) -> bool {
277 self.orphans.contains(&node)
278 }
279
280 pub fn orphans(&self) -> impl Iterator<Item = NodeIndex> + '_ {
282 self.orphans.iter().map(|&x| x)
283 }
284
285 pub fn clear_orphans(&mut self) {
287 let orphans: Vec<_> = self.orphans().collect();
288 for orphan_idx in orphans {
289 self.graph.remove_node(orphan_idx);
290 }
291 self.orphans.clear();
292 }
293
294 pub fn add_node(
296 &mut self,
297 node: SVector<X, N>,
298 parent: NodeIndex,
299 trajectory: T,
300 ) -> (NodeIndex, EdgeIndex) {
301 let node_idx = self.graph.add_node(node);
302 let edge_idx = self.update_edge(node_idx, parent, trajectory);
303 (node_idx, edge_idx)
304 }
305
306 pub fn update_edge(
308 &mut self,
309 node: NodeIndex,
310 new_parent: NodeIndex,
311 new_trajectory: T,
312 ) -> EdgeIndex {
313 self.remove_any_parents(node);
314 self.graph.update_edge(node, new_parent, new_trajectory)
315 }
316
317 fn remove_any_parents(&mut self, node: NodeIndex) -> bool {
320 let edges = self
321 .graph
322 .edges_directed(node, Direction::Outgoing)
323 .map(|edge_ref| edge_ref.id());
324 let edges: Vec<_> = edges.collect();
325
326 let removed = edges.len() > 0;
327 for edge_idx in edges {
328 self.graph.remove_edge(edge_idx);
329 }
330 removed
331 }
332
333 pub fn get_optimal_path(
336 &self,
337 node: NodeIndex,
338 ) -> Option<OptimalPathIter<X, T, N>> {
339 match self.is_orphan(node) {
340 true => None,
341 false => Some(OptimalPathIter::new(self, node)),
342 }
343 }
344
345 pub fn get_node(&self, idx: NodeIndex) -> &SVector<X, N> {
349 self.graph.index(idx)
350 }
351
352 pub fn get_node_mut(&mut self, idx: NodeIndex) -> &mut SVector<X, N> {
356 self.graph.index_mut(idx)
357 }
358
359 pub fn get_edge(&self, idx: EdgeIndex) -> &T {
363 self.graph.index(idx)
364 }
365
366 pub fn get_endpoints(&self, idx: EdgeIndex) -> (NodeIndex, NodeIndex) {
370 self.graph.edge_endpoints(idx).unwrap()
371 }
372
373 pub fn get_trajectory(&self, idx: EdgeIndex) -> FullTrajRefOwned<X, T, N> {
377 let (start_idx, end_idx) = self.get_endpoints(idx);
378 let start = self.get_node(start_idx);
379 let end = self.get_node(end_idx);
380 let traj_data = self.get_edge(idx);
381 FullTrajRefOwned::new(start, end, traj_data)
382 }
383}
384
#[cfg(test)]
mod tests {

    use crate::trajectories::EuclideanTrajectory;

    use super::*;

    /// Two leaves attached to the goal report it as their parent; the goal
    /// itself has none.
    #[test]
    fn test_rrt_tree_parent() {
        let mut g = RrtTree::new([1.5, 1.5].into());
        let goal = g.get_goal_idx();

        let (n1, _) = g.add_node([2.0, 2.0].into(), goal, EuclideanTrajectory::new());
        let (n2, _) = g.add_node([-2.0, -2.0].into(), goal, EuclideanTrajectory::new());

        assert_eq!(g.parent(goal), None);
        assert_eq!(g.parent(n1), Some(goal));
        assert_eq!(g.parent(n2), Some(goal));
    }

    /// The goal lists both leaves as children (most recently attached first);
    /// the leaves have no children.
    #[test]
    fn test_rrt_tree_children() {
        let mut g = RrtTree::new([1.5, 1.5].into());
        let goal = g.get_goal_idx();

        let (n1, _) = g.add_node([2.0, 2.0].into(), goal, EuclideanTrajectory::new());
        let (n2, _) = g.add_node([-2.0, -2.0].into(), goal, EuclideanTrajectory::new());

        let goal_children: Vec<_> = g.children(goal).collect();
        assert_eq!(goal_children, vec![n2, n1]);

        for leaf in [n1, n2] {
            assert_eq!(g.children(leaf).next(), None);
        }
    }

    /// The detached walker visits the same children, in the same order, as the
    /// borrowing iterator.
    #[test]
    fn test_rrt_tree_children_walker() {
        let mut g = RrtTree::new([1.5, 1.5].into());
        let goal = g.get_goal_idx();

        let (n1, _) = g.add_node([2.0, 2.0].into(), goal, EuclideanTrajectory::new());
        let (n2, _) = g.add_node([-2.0, -2.0].into(), goal, EuclideanTrajectory::new());

        let mut walker = g.children_walker(goal);
        let mut visited = Vec::new();
        while let Some(child) = walker.next_node(&g.graph) {
            visited.push(child);
        }
        assert_eq!(visited, vec![n2, n1]);

        for leaf in [n1, n2] {
            let mut walker = g.children_walker(leaf);
            assert_eq!(walker.next_node(&g.graph), None);
        }
    }
}