pub use petgraph::stable_graph::{EdgeIndex, NodeIndex};
use std::collections::{HashSet, VecDeque};
use std::marker::PhantomData;
use std::ops::{Index, IndexMut};
use nalgebra::SVector;
use num_traits::{Float, Zero};
use petgraph::stable_graph::{
DefaultIx, EdgeIndices, EdgeReference, Neighbors, NodeIndices, StableDiGraph,
WalkNeighbors,
};
use petgraph::visit::EdgeRef;
use petgraph::Direction;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use crate::scalar::Scalar;
use crate::trajectories::{FullTrajRefOwned, Trajectory};
/// A vertex stored in an [`RrtStarTree`]: an `N`-dimensional state together
/// with a scalar cost.
///
/// The serde bounds are spelled out by hand so that (de)serialization only
/// requires `X: Serialize` / `X: DeserializeOwned`.
#[derive(PartialEq, Clone, Debug, Serialize, Deserialize)]
#[serde(bound(
serialize = "X: Serialize",
deserialize = "X: DeserializeOwned"
))]
pub struct Node<X: Scalar, const N: usize> {
/// State vector of this node; read-only from outside via [`Node::state`].
state: SVector<X, N>,
/// Cost associated with this node. [`Node::new`] initialises it to infinity,
/// while the goal node of an [`RrtStarTree`] is created with a cost of zero.
// NOTE(review): presumably the cost-to-goal of the node — confirm with the planner code.
pub cost: X,
}
impl<X: Scalar + Float, const N: usize> Node<X, N> {
    /// Creates a node at `state` whose cost starts out as positive infinity.
    pub fn new(state: SVector<X, N>) -> Self {
        let cost = X::infinity();
        Self { state, cost }
    }
}
impl<X: Scalar, const N: usize> Node<X, N> {
fn with_cost(state: SVector<X, N>, cost: X) -> Self {
Self { state, cost }
}
pub fn state(&self) -> &SVector<X, N> {
&self.state
}
}
/// Iterator over the indices of the nodes held in a tree's underlying graph.
///
/// Returned by [`RrtStarTree::all_nodes`]; yields [`NodeIndex`] values.
pub struct NodeIter<'a, X: Scalar, const N: usize> {
/// The petgraph node-index iterator this type wraps.
nodes: NodeIndices<'a, Node<X, N>>,
}
impl<'a, X: Scalar, const N: usize> NodeIter<'a, X, N> {
    /// Builds an iterator over every node index of `graph`.
    ///
    /// The edge weight type `T` is only needed to name the graph type; it is
    /// not stored in the iterator.
    fn new<T>(graph: &'a StableDiGraph<Node<X, N>, T>) -> Self
    where
        T: Trajectory<X, N>,
    {
        let nodes = graph.node_indices();
        Self { nodes }
    }
}
impl<'a, X: Scalar, const N: usize> Iterator for NodeIter<'a, X, N> {
    type Item = NodeIndex;

    /// Yields the next node index of the wrapped graph iterator.
    fn next(&mut self) -> Option<Self::Item> {
        self.nodes.next()
    }

    /// Forwards the wrapped iterator's hint so that consumers such as
    /// `collect` keep whatever bounds the inner iterator can provide instead
    /// of falling back to the default `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.nodes.size_hint()
    }
}
/// Iterator over the indices of the edges held in a tree's underlying graph.
///
/// Returned by [`RrtStarTree::all_edges`]; yields [`EdgeIndex`] values.
pub struct EdgeIter<'a, X, T, const N: usize>
where
T: Trajectory<X, N>,
{
/// The petgraph edge-index iterator this type wraps.
edges: EdgeIndices<'a, T>,
/// `X` only appears in the `Trajectory` bound on `T`; this marker makes the
/// type parameter count as used.
phantom_x: PhantomData<X>,
}
impl<'a, X, T, const N: usize> EdgeIter<'a, X, T, N>
where
    X: Scalar,
    T: Trajectory<X, N>,
{
    /// Builds an iterator over every edge index of `graph`.
    fn new(graph: &'a StableDiGraph<Node<X, N>, T>) -> Self {
        let edges = graph.edge_indices();
        Self {
            edges,
            phantom_x: PhantomData,
        }
    }
}
impl<'a, X, T, const N: usize> Iterator for EdgeIter<'a, X, T, N>
where
    T: Trajectory<X, N>,
{
    type Item = EdgeIndex;

    /// Yields the next edge index of the wrapped graph iterator.
    fn next(&mut self) -> Option<Self::Item> {
        self.edges.next()
    }

    /// Forwards the wrapped iterator's hint so that consumers such as
    /// `collect` keep whatever bounds the inner iterator can provide instead
    /// of falling back to the default `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.edges.size_hint()
    }
}
/// Iterator that starts at a node and repeatedly follows parent links through
/// an [`RrtStarTree`], yielding each visited [`NodeIndex`].
///
/// Created by [`RrtStarTree::get_optimal_path`]. Holds a shared borrow of the
/// tree; use [`OptimalPathIter::detach`] to obtain a walker that does not.
pub struct OptimalPathIter<'a, X, T, const N: usize>
where
X: Scalar,
T: Trajectory<X, N>,
{
/// Tree whose parent links are followed.
graph: &'a RrtStarTree<X, T, N>,
/// Node that the next call to `next` will yield; `None` once exhausted.
next_node: Option<NodeIndex>,
}
impl<'a, X, T, const N: usize> OptimalPathIter<'a, X, T, N>
where
    X: Scalar,
    T: Trajectory<X, N>,
{
    /// Creates an iterator over `graph` whose first yielded item is `node`.
    fn new(graph: &'a RrtStarTree<X, T, N>, node: NodeIndex) -> Self {
        let next_node = Some(node);
        Self { graph, next_node }
    }

    /// Consumes the iterator and returns an [`OptimalPathWalker`] positioned
    /// at the same node, which no longer borrows the tree.
    pub fn detach(self) -> OptimalPathWalker {
        let Self { next_node, .. } = self;
        OptimalPathWalker::new(next_node)
    }
}
impl<'a, X, T, const N: usize> Iterator for OptimalPathIter<'a, X, T, N>
where
    X: Scalar,
    T: Trajectory<X, N>,
{
    type Item = NodeIndex;

    /// Yields the current node and advances to its parent.
    ///
    /// The starting node is yielded first. Iteration ends after a node that
    /// has no parent has been yielded.
    fn next(&mut self) -> Option<Self::Item> {
        let node = self.next_node?;
        // `parent` already returns `Option<NodeIndex>`, which is exactly the
        // cursor's next value.
        self.next_node = self.graph.parent(node);
        Some(node)
    }
}
/// Cursor that follows parent links like [`OptimalPathIter`] but does not
/// borrow the tree; the tree is passed to each call of
/// [`OptimalPathWalker::next`] instead.
pub struct OptimalPathWalker {
/// Node that the next call to `next` will return; `None` once exhausted.
next_node: Option<NodeIndex>,
}
impl OptimalPathWalker {
    /// Creates a walker whose first returned node is `node` (or an already
    /// exhausted walker when `node` is `None`).
    fn new(node: Option<NodeIndex>) -> Self {
        Self { next_node: node }
    }

    /// Returns the current node and advances the walker to that node's parent
    /// in `g`.
    ///
    /// Returns `None` once a node without a parent has been returned. The
    /// tree is only borrowed for the duration of the call.
    pub fn next<X, T, const N: usize>(
        &mut self,
        g: &RrtStarTree<X, T, N>,
    ) -> Option<NodeIndex>
    where
        X: Scalar,
        T: Trajectory<X, N>,
    {
        let node = self.next_node?;
        // `parent` already returns `Option<NodeIndex>`, which is exactly the
        // cursor's next value.
        self.next_node = g.parent(node);
        Some(node)
    }
}
/// Tree of [`Node`]s rooted at a goal node, stored in a petgraph
/// `StableDiGraph`.
///
/// Each edge points from a node to its parent and carries trajectory data of
/// type `T`. A separate set tracks nodes that have been marked as orphans.
///
/// The serde bounds are spelled out by hand so that (de)serialization only
/// requires the corresponding bound on `X` and `T`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound(
serialize = "X: Serialize, T: Serialize",
deserialize = "X: DeserializeOwned, T: DeserializeOwned",
))]
pub struct RrtStarTree<X, T, const N: usize>
where
X: Scalar,
T: Trajectory<X, N>,
{
/// Index of the goal node, which is inserted when the tree is created.
goal_idx: NodeIndex,
/// Underlying graph; node weights are [`Node`]s, edge weights are `T`.
graph: StableDiGraph<Node<X, N>, T>,
/// Nodes currently marked as orphans; they remain in `graph` until
/// `clear_orphans` is called.
orphans: HashSet<NodeIndex>,
}
impl<X, T, const N: usize> RrtStarTree<X, T, N>
where
    X: Scalar + Zero,
    T: Trajectory<X, N>,
{
    /// Creates a tree containing only the goal node.
    ///
    /// The goal node is placed at `goal` with a cost of zero, and the orphan
    /// set starts out empty.
    pub fn new(goal: SVector<X, N>) -> Self {
        let mut graph = StableDiGraph::new();
        let goal_idx = graph.add_node(Node::with_cost(goal, X::zero()));
        Self {
            goal_idx,
            graph,
            orphans: HashSet::new(),
        }
    }
}
impl<X, T, const N: usize> RrtStarTree<X, T, N>
where
X: Scalar,
T: Trajectory<X, N>,
{
pub fn node_count(&self) -> usize {
self.graph.node_count()
}
pub fn get_goal_idx(&self) -> NodeIndex {
self.goal_idx
}
pub fn get_goal(&self) -> &Node<X, N> {
self.get_node(self.goal_idx)
}
pub fn all_nodes(&self) -> NodeIter<X, N> {
NodeIter::new(&self.graph)
}
pub fn all_edges(&self) -> EdgeIter<X, T, N> {
EdgeIter::new(&self.graph)
}
pub fn parent(&self, node: NodeIndex) -> Option<NodeIndex> {
Some(self.parent_edge(node)?.target())
}
fn parent_edge(&self, node: NodeIndex) -> Option<EdgeReference<T>> {
self.graph.edges_directed(node, Direction::Outgoing).next()
}
pub fn is_parent(&self, node: NodeIndex, parent: NodeIndex) -> bool {
self.graph.find_edge(node, parent).is_some()
}
pub fn children(&self, node: NodeIndex) -> Neighbors<T, DefaultIx> {
self.graph.neighbors_directed(node, Direction::Incoming)
}
pub fn children_walker(&self, node: NodeIndex) -> WalkNeighbors<DefaultIx> {
self
.graph
.neighbors_directed(node, Direction::Incoming)
.detach()
}
pub fn is_child(&self, node: NodeIndex, child: NodeIndex) -> bool {
self.is_parent(child, node)
}
pub fn add_orphan(&mut self, node: NodeIndex) {
let mut queue = VecDeque::new();
queue.push_back(node);
while let Some(node) = queue.pop_front() {
if self.orphans.insert(node) {
let mut children = self.children_walker(node);
while let Some(child_idx) = children.next_node(&self.graph) {
queue.push_back(child_idx);
}
}
}
}
pub fn remove_orphan(&mut self, node: NodeIndex) {
self.orphans.remove(&node);
}
pub fn is_orphan(&self, node: NodeIndex) -> bool {
self.orphans.contains(&node)
}
pub fn orphans(&self) -> impl Iterator<Item = NodeIndex> + '_ {
self.orphans.iter().map(|&x| x)
}
pub fn clear_orphans(&mut self) {
let orphans: Vec<_> = self.orphans().collect();
for orphan_idx in orphans {
self.graph.remove_node(orphan_idx);
}
self.orphans.clear();
}
pub fn add_node(
&mut self,
node: Node<X, N>,
parent: NodeIndex,
trajectory: T,
) -> (NodeIndex, EdgeIndex) {
let node_idx = self.graph.add_node(node);
let edge_idx = self.update_edge(node_idx, parent, trajectory);
(node_idx, edge_idx)
}
pub fn update_edge(
&mut self,
node: NodeIndex,
new_parent: NodeIndex,
new_trajectory: T,
) -> EdgeIndex {
self.remove_any_parents(node);
self.graph.update_edge(node, new_parent, new_trajectory)
}
fn remove_any_parents(&mut self, node: NodeIndex) -> bool {
let edges = self
.graph
.edges_directed(node, Direction::Outgoing)
.map(|edge_ref| edge_ref.id());
let edges: Vec<_> = edges.collect();
let removed = edges.len() > 0;
for edge_idx in edges {
self.graph.remove_edge(edge_idx);
}
removed
}
pub fn get_optimal_path(
&self,
node: NodeIndex,
) -> Option<OptimalPathIter<X, T, N>> {
match self.is_orphan(node) {
true => None,
false => Some(OptimalPathIter::new(self, node)),
}
}
pub fn get_node(&self, idx: NodeIndex) -> &Node<X, N> {
self.graph.index(idx)
}
pub fn get_node_mut(&mut self, idx: NodeIndex) -> &mut Node<X, N> {
self.graph.index_mut(idx)
}
pub fn get_edge(&self, idx: EdgeIndex) -> &T {
self.graph.index(idx)
}
pub fn get_endpoints(&self, idx: EdgeIndex) -> (NodeIndex, NodeIndex) {
self.graph.edge_endpoints(idx).unwrap()
}
pub fn get_trajectory(&self, idx: EdgeIndex) -> FullTrajRefOwned<X, T, N> {
let (start_idx, end_idx) = self.get_endpoints(idx);
let start = self.get_node(start_idx).state();
let end = self.get_node(end_idx).state();
let traj_data = self.get_edge(idx);
FullTrajRefOwned::new(start, end, traj_data)
}
}
#[cfg(test)]
mod tests {
    use crate::trajectories::EuclideanTrajectory;

    use super::*;

    /// Nodes attached directly to the goal report the goal as their parent,
    /// while the goal itself has no parent.
    #[test]
    fn test_rrt_star_tree_parent() {
        let mut tree = RrtStarTree::new([1.5, 1.5].into());
        let root = tree.get_goal_idx();

        let first = tree.graph.add_node(Node {
            state: [2.0, 2.0].into(),
            cost: 0.1,
        });
        tree.update_edge(first, root, EuclideanTrajectory::new());

        let second = tree.graph.add_node(Node {
            state: [-2.0, -2.0].into(),
            cost: 0.5,
        });
        tree.update_edge(second, root, EuclideanTrajectory::new());

        assert_eq!(tree.parent(root), None);
        assert_eq!(tree.parent(first), Some(root));
        assert_eq!(tree.parent(second), Some(root));
    }

    /// `children` lists the nodes attached to the goal and is empty for
    /// leaves. The assertions expect the most recently attached child first.
    #[test]
    fn test_rrt_star_tree_children() {
        let mut tree = RrtStarTree::new([1.5, 1.5].into());
        let root = tree.get_goal_idx();

        let first = tree.graph.add_node(Node {
            state: [2.0, 2.0].into(),
            cost: 0.1,
        });
        tree.update_edge(first, root, EuclideanTrajectory::new());

        let second = tree.graph.add_node(Node {
            state: [-2.0, -2.0].into(),
            cost: 0.5,
        });
        tree.update_edge(second, root, EuclideanTrajectory::new());

        let mut of_root = tree.children(root);
        assert_eq!(of_root.next(), Some(second));
        assert_eq!(of_root.next(), Some(first));
        assert_eq!(of_root.next(), None);

        let mut of_first = tree.children(first);
        assert_eq!(of_first.next(), None);

        let mut of_second = tree.children(second);
        assert_eq!(of_second.next(), None);
    }

    /// `children_walker` visits the same nodes, in the same order, as
    /// `children` does in the test above.
    #[test]
    fn test_rrt_star_tree_children_walker() {
        let mut tree = RrtStarTree::new([1.5, 1.5].into());
        let root = tree.get_goal_idx();

        let first = tree.graph.add_node(Node {
            state: [2.0, 2.0].into(),
            cost: 0.1,
        });
        tree.update_edge(first, root, EuclideanTrajectory::new());

        let second = tree.graph.add_node(Node {
            state: [-2.0, -2.0].into(),
            cost: 0.5,
        });
        tree.update_edge(second, root, EuclideanTrajectory::new());

        let mut of_root = tree.children_walker(root);
        assert_eq!(of_root.next_node(&tree.graph), Some(second));
        assert_eq!(of_root.next_node(&tree.graph), Some(first));
        assert_eq!(of_root.next_node(&tree.graph), None);

        let mut of_first = tree.children_walker(first);
        assert_eq!(of_first.next_node(&tree.graph), None);

        let mut of_second = tree.children_walker(second);
        assert_eq!(of_second.next_node(&tree.graph), None);
    }
}