use crate::{
algorithm::{MinimumCostBound, Path, QueueLength},
domain::{ClosedSet, ClosedStatus},
error::ThisError,
};
use std::{
cmp::{Ordering, Reverse},
collections::BinaryHeap,
};
/// Search memory for a tree-shaped search: the nodes created so far, the
/// frontier of nodes waiting to be processed, and a closed set used to look
/// up previously recorded nodes by state.
#[derive(Debug)]
pub struct Tree<Closed, Node, Cost> {
/// Closed set that `push_node` queries with a node's state; a `Closed`
/// status carries the index into `arena` of the previously recorded node.
pub closed_set: Closed,
/// Frontier of tickets. Tickets are wrapped in `Reverse`, so the
/// max-oriented `BinaryHeap` surfaces the smallest ticket first.
pub queue: TreeFrontierQueue<Cost>,
/// Storage for every node accepted by `push_node`; a node's id is its
/// index in this vector.
pub arena: Vec<Node>,
}
impl<Closed, Node: TreeNode> Tree<Closed, Node, Node::Cost> {
pub fn new(closed_set: Closed) -> Self
where
Node: TreeNode,
Node::Cost: Ord,
{
Self {
closed_set,
queue: Default::default(),
arena: Default::default(),
}
}
pub fn push_node(&mut self, node: Node) -> Result<(), TreeError>
where
Node: TreeNode,
Closed: ClosedSet<Node::State, usize>,
Node::Cost: Ord,
{
if let ClosedStatus::Closed(prior) = self.closed_set.status(node.state()) {
if let Some(prior) = self.arena.get(*prior) {
if prior.cost() <= node.cost() {
return Ok(());
}
} else {
return Err(TreeError::BrokenReference(*prior));
}
}
let node_id = self.arena.len();
let evaluation = node.queue_evaluation();
let bias = node.queue_bias();
self.arena.push(node);
self.queue.push(Reverse(TreeQueueTicket {
node_id,
bias,
evaluation,
}));
Ok(())
}
}
/// Interface that a node must provide to be stored in a [`Tree`].
pub trait TreeNode {
/// State type; `Tree::push_node` passes a node's state to the closed set.
type State;
/// Action type stored alongside the link to a node's parent.
type Action;
/// Cost type used both for node costs and for queue ordering values.
type Cost;
/// The state held by this node.
fn state(&self) -> &Self::State;
/// The index of this node's parent together with the action recorded on
/// this node, or `None` when the node has no parent. `retrace` treats a
/// parentless node as the start of the path.
fn parent(&self) -> Option<(usize, &Self::Action)>;
/// The cost of this node. `Tree::push_node` compares it against the cost of
/// a previously closed node, and `retrace` reports it as the path's total
/// cost.
fn cost(&self) -> Self::Cost;
/// Primary value used to order this node in the frontier queue; smaller
/// values are popped first.
fn queue_evaluation(&self) -> Self::Cost;
/// Optional secondary value, compared only between tickets whose
/// evaluations are equal and which both carry a bias.
fn queue_bias(&self) -> Option<Self::Cost>;
}
/// Entry of the frontier queue: the ordering keys of a node plus the index of
/// that node in the arena.
///
/// Tickets are ordered by `evaluation` first. When the evaluations tie and
/// *both* tickets carry a `bias`, the biases decide; otherwise the tickets are
/// considered equal. `node_id` never takes part in comparisons.
///
/// Equality follows exactly the same rule as the ordering, so `a == b` holds
/// precisely when `a.cmp(&b)` is `Ordering::Equal`, as the standard library
/// requires of `Eq`/`Ord` implementations.
///
/// NOTE(review): because a missing bias ties with every bias, a queue that
/// mixes biased and unbiased tickets with the same evaluation is not ordered
/// by a strict total order. Prefer providing a bias for all nodes or for none.
#[derive(Debug, Clone, Copy)]
pub struct TreeQueueTicket<Cost> {
    /// Primary ordering key.
    pub evaluation: Cost,
    /// Optional tie-breaker, consulted only when both sides have one.
    pub bias: Option<Cost>,
    /// Index of the node in the arena that this ticket stands for.
    pub node_id: usize,
}

/// Frontier queue. `Reverse` turns the max-oriented `BinaryHeap` into a queue
/// that yields the smallest ticket first.
pub type TreeFrontierQueue<Cost> = BinaryHeap<Reverse<TreeQueueTicket<Cost>>>;

impl<Cost: PartialEq> PartialEq for TreeQueueTicket<Cost> {
    /// Two tickets are equal when their evaluations are equal and their
    /// biases, if both are present, are equal as well.
    fn eq(&self, other: &Self) -> bool {
        if self.evaluation != other.evaluation {
            return false;
        }
        // Mirror the tie-breaking rule of `cmp`/`partial_cmp` so that equality
        // and ordering can never disagree.
        match (&self.bias, &other.bias) {
            (Some(l), Some(r)) => l == r,
            _ => true,
        }
    }
}

impl<Cost: Eq> Eq for TreeQueueTicket<Cost> {}

impl<Cost: PartialOrd> PartialOrd for TreeQueueTicket<Cost> {
    /// Compares by evaluation, then by bias when both tickets have one.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        match self.evaluation.partial_cmp(&other.evaluation) {
            Some(Ordering::Equal) => match (&self.bias, &other.bias) {
                (Some(l), Some(r)) => l.partial_cmp(r),
                _ => Some(Ordering::Equal),
            },
            decided => decided,
        }
    }
}

impl<Cost: Ord> Ord for TreeQueueTicket<Cost> {
    /// Compares by evaluation, then by bias when both tickets have one.
    fn cmp(&self, other: &Self) -> Ordering {
        self.evaluation
            .cmp(&other.evaluation)
            .then_with(|| match (&self.bias, &other.bias) {
                (Some(l), Some(r)) => l.cmp(r),
                _ => Ordering::Equal,
            })
    }
}
/// Indexed storage of tree nodes that supports checked lookup and
/// reconstruction of the path leading to a node.
pub trait NodeContainer<N: TreeNode> {
/// Looks up the node stored at `index`, reporting a [`TreeError`] instead
/// of panicking when the lookup fails.
fn get_node(&self, index: usize) -> Result<&N, TreeError>;
/// Builds the [`Path`] that ends at the node stored at `index`.
fn retrace(&self, index: usize) -> Result<Path<N::State, N::Action, N::Cost>, TreeError>;
}
impl<N: TreeNode> NodeContainer<N> for Vec<N>
where
    N::State: Clone,
    N::Action: Clone,
{
    /// Returns the node at `index`, or [`TreeError::BrokenReference`] when the
    /// index is out of bounds.
    fn get_node(&self, index: usize) -> Result<&N, TreeError> {
        match self.get(index) {
            Some(node) => Ok(node),
            None => Err(TreeError::BrokenReference(index)),
        }
    }

    /// Follows parent links from `node_id` back to a node without a parent and
    /// returns the resulting path in forward order.
    ///
    /// The path's `initial_state` is the state of the parentless node, each
    /// sequence element pairs a node's recorded action with that node's state,
    /// and `total_cost` is the cost of the node at `node_id`.
    ///
    /// # Errors
    ///
    /// Returns [`TreeError::BrokenReference`] if `node_id` or any parent index
    /// encountered on the way does not exist in the vector.
    fn retrace(&self, node_id: usize) -> Result<Path<N::State, N::Action, N::Cost>, TreeError> {
        let total_cost = self.get_node(node_id)?.cost();

        // Steps are collected from the goal backwards and flipped afterwards.
        let mut steps = Vec::new();
        let mut cursor = node_id;
        let root = loop {
            let node = self.get_node(cursor)?;
            match node.parent() {
                Some((parent_id, action)) => {
                    steps.push((action.clone(), node.state().clone()));
                    cursor = parent_id;
                }
                None => break node,
            }
        };
        steps.reverse();

        Ok(Path {
            initial_state: root.state().clone(),
            sequence: steps,
            total_cost,
        })
    }
}
/// Errors reported by the tree and by [`NodeContainer`] operations.
#[derive(ThisError, Debug)]
pub enum TreeError {
/// A node index was looked up but no node is stored at that index; the
/// payload is the offending index.
#[error(
"A node [{0}] is referenced but does not exist in the search memory. \
This is a critical implementation error, please report this to the mapf developers."
)]
BrokenReference(usize),
}
/// Reports how many tickets are currently waiting in the frontier queue.
impl<Closed, Node: TreeNode> QueueLength for Tree<Closed, Node, Node::Cost> {
/// Number of tickets in `self.queue`.
fn queue_length(&self) -> usize {
self.queue.len()
}
}
impl<Closed, Node: TreeNode> MinimumCostBound for Tree<Closed, Node, Node::Cost>
where
Node::Cost: Clone,
{
type Cost = Node::Cost;
fn minimum_cost_bound(&self) -> Option<Self::Cost> {
self.queue.peek().map(|n| n.0.evaluation.clone())
}
}
/// Owning iterator that drains a [`BinaryHeap`] in heap order, i.e. greatest
/// element first (smallest first when the elements are wrapped in
/// `std::cmp::Reverse`).
///
/// Created by [`IntoIterSorted::binary_heap_into_iter_sorted`].
#[derive(Debug, Clone)]
pub struct BinaryHeapIntoIterSorted<T> {
    inner: BinaryHeap<T>,
}

impl<T: Ord> Iterator for BinaryHeapIntoIterSorted<T> {
    type Item = T;

    /// Pops the greatest remaining element of the heap.
    fn next(&mut self) -> Option<Self::Item> {
        self.inner.pop()
    }

    /// The number of remaining items is always known exactly: it is the
    /// current length of the heap.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.inner.len();
        (remaining, Some(remaining))
    }
}

// `size_hint` is exact, so callers may rely on `len()`.
impl<T: Ord> ExactSizeIterator for BinaryHeapIntoIterSorted<T> {}

// Popping an empty heap keeps returning `None`, so the iterator is fused.
impl<T: Ord> std::iter::FusedIterator for BinaryHeapIntoIterSorted<T> {}

/// Conversion of a collection into an iterator that yields its elements in
/// sorted order.
pub trait IntoIterSorted<T> {
    /// Consumes the collection and returns the sorted iterator.
    fn binary_heap_into_iter_sorted(self) -> BinaryHeapIntoIterSorted<T>;
}

impl<T: Ord> IntoIterSorted<T> for BinaryHeap<T> {
    /// Wraps the heap without copying; elements are produced lazily by
    /// popping the heap.
    fn binary_heap_into_iter_sorted(self) -> BinaryHeapIntoIterSorted<T> {
        BinaryHeapIntoIterSorted { inner: self }
    }
}