1use crate::{
19 algorithm::{MinimumCostBound, Path, QueueLength},
20 domain::{ClosedSet, ClosedStatus},
21 error::ThisError,
22};
23use std::{
24 cmp::{Ordering, Reverse},
25 collections::BinaryHeap,
26};
27
28#[derive(Debug)]
30pub struct Tree<Closed, Node, Cost> {
31 pub closed_set: Closed,
34 pub queue: TreeFrontierQueue<Cost>,
37 pub arena: Vec<Node>,
39}
40
41impl<Closed, Node: TreeNode> Tree<Closed, Node, Node::Cost> {
42 pub fn new(closed_set: Closed) -> Self
43 where
44 Node: TreeNode,
45 Node::Cost: Ord,
46 {
47 Self {
48 closed_set,
49 queue: Default::default(),
50 arena: Default::default(),
51 }
52 }
53
54 pub fn push_node(&mut self, node: Node) -> Result<(), TreeError>
55 where
56 Node: TreeNode,
57 Closed: ClosedSet<Node::State, usize>,
58 Node::Cost: Ord,
59 {
60 if let ClosedStatus::Closed(prior) = self.closed_set.status(node.state()) {
61 if let Some(prior) = self.arena.get(*prior) {
62 if prior.cost() <= node.cost() {
63 return Ok(());
66 }
67 } else {
68 return Err(TreeError::BrokenReference(*prior));
71 }
72 }
73
74 let node_id = self.arena.len();
75 let evaluation = node.queue_evaluation();
76 let bias = node.queue_bias();
77 self.arena.push(node);
78 self.queue.push(Reverse(TreeQueueTicket {
79 node_id,
80 bias,
81 evaluation,
82 }));
83 Ok(())
84 }
85}
86
/// A node that can be stored in a search tree's arena.
pub trait TreeNode {
    /// The search state this node represents.
    type State;
    /// The action type recorded on the link from a parent to this node.
    type Action;
    /// The cost type used for path costs and queue priorities.
    type Cost;

    /// Borrows the state represented by this node.
    fn state(&self) -> &Self::State;

    /// Returns `None` for a root node; otherwise the arena index of the
    /// parent node together with the action linking the parent to this node.
    fn parent(&self) -> Option<(usize, &Self::Action)>;

    /// The accumulated cost of this node. It is used when deciding whether
    /// to keep a node whose state was already closed, and as the total cost
    /// of a retraced path.
    fn cost(&self) -> Self::Cost;

    /// Primary priority key for the frontier queue; smaller values come out
    /// of the queue first.
    fn queue_evaluation(&self) -> Self::Cost;

    /// Optional secondary key that breaks ties between equal evaluations;
    /// `None` expresses no tie-breaking preference.
    fn queue_bias(&self) -> Option<Self::Cost>;
}
120
/// Frontier-queue entry that points at a node in a search tree's arena.
///
/// Tickets are ordered by `evaluation`. When evaluations are equal and *both*
/// tickets carry a `bias`, the biases break the tie. `node_id` never takes
/// part in comparisons or equality.
///
/// NOTE: a missing `bias` compares equal to any bias, so a queue that mixes
/// tickets with and without a bias is not totally ordered (equality is not
/// transitive across such tickets).
#[derive(Debug, Clone, Copy)]
pub struct TreeQueueTicket<Cost> {
    /// Primary priority key.
    pub evaluation: Cost,
    /// Optional tie-breaker, consulted only when evaluations are equal.
    pub bias: Option<Cost>,
    /// Index of the ticket's node in the search tree's arena.
    pub node_id: usize,
}

/// Min-priority frontier: wrapping tickets in [`Reverse`] makes the max-heap
/// [`BinaryHeap`] pop the smallest ticket first.
pub type TreeFrontierQueue<Cost> = BinaryHeap<Reverse<TreeQueueTicket<Cost>>>;

impl<Cost: PartialEq> PartialEq for TreeQueueTicket<Cost> {
    /// Two tickets are equal exactly when `partial_cmp`/`cmp` report `Equal`:
    /// same evaluation, and same bias whenever both tickets have one.
    fn eq(&self, other: &Self) -> bool {
        // The bias must be considered here as well; comparing only the
        // evaluation would make `a == b` while `a.cmp(&b) != Equal`, which
        // breaks the consistency std requires between `PartialEq` and
        // `PartialOrd`/`Ord`.
        self.evaluation == other.evaluation
            && match (&self.bias, &other.bias) {
                (Some(l), Some(r)) => l == r,
                _ => true,
            }
    }
}

impl<Cost: Eq> Eq for TreeQueueTicket<Cost> {}

impl<Cost: PartialOrd> PartialOrd for TreeQueueTicket<Cost> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        match self.evaluation.partial_cmp(&other.evaluation) {
            // Equal evaluations: fall back on the biases, but only when both exist.
            Some(Ordering::Equal) => match (&self.bias, &other.bias) {
                (Some(l), Some(r)) => l.partial_cmp(r),
                _ => Some(Ordering::Equal),
            },
            value => value,
        }
    }
}

impl<Cost: Ord> Ord for TreeQueueTicket<Cost> {
    fn cmp(&self, other: &Self) -> Ordering {
        self.evaluation
            .cmp(&other.evaluation)
            .then_with(|| match (&self.bias, &other.bias) {
                (Some(l), Some(r)) => l.cmp(r),
                _ => Ordering::Equal,
            })
    }
}
172
173pub trait NodeContainer<N: TreeNode> {
174 fn get_node(&self, index: usize) -> Result<&N, TreeError>;
175 fn retrace(&self, index: usize) -> Result<Path<N::State, N::Action, N::Cost>, TreeError>;
176}
177
178impl<N: TreeNode> NodeContainer<N> for Vec<N>
179where
180 N::State: Clone,
181 N::Action: Clone,
182{
183 fn get_node(&self, index: usize) -> Result<&N, TreeError> {
184 self.get(index)
185 .ok_or_else(|| TreeError::BrokenReference(index))
186 }
187
188 fn retrace(&self, node_id: usize) -> Result<Path<N::State, N::Action, N::Cost>, TreeError> {
189 let total_cost = self.get_node(node_id)?.cost();
190 let mut initial_node_id = node_id;
191 let mut next_node_id = Some(node_id);
192 let mut sequence = Vec::new();
193 while let Some(current_node_id) = next_node_id {
194 initial_node_id = current_node_id;
195 let node = self.get_node(current_node_id)?;
196 next_node_id = if let Some((parent_id, action)) = node.parent() {
197 sequence.push((action.clone(), node.state().clone()));
198 Some(parent_id)
199 } else {
200 None
201 };
202 }
203
204 sequence.reverse();
205
206 let initial_state = self.get_node(initial_node_id)?.state().clone();
207 Ok(Path {
208 initial_state,
209 sequence,
210 total_cost,
211 })
212 }
213}
214
215#[derive(ThisError, Debug)]
216pub enum TreeError {
217 #[error(
218 "A node [{0}] is referenced but does not exist in the search memory. \
219 This is a critical implementation error, please report this to the mapf developers."
220 )]
221 BrokenReference(usize),
222}
223
224impl<Closed, Node: TreeNode> QueueLength for Tree<Closed, Node, Node::Cost> {
225 fn queue_length(&self) -> usize {
226 self.queue.len()
227 }
228}
229
230impl<Closed, Node: TreeNode> MinimumCostBound for Tree<Closed, Node, Node::Cost>
231where
232 Node::Cost: Clone,
233{
234 type Cost = Node::Cost;
235 fn minimum_cost_bound(&self) -> Option<Self::Cost> {
236 self.queue.peek().map(|n| n.0.evaluation.clone())
237 }
238}
239
/// Consuming iterator that yields the elements of a [`BinaryHeap`] in heap
/// order, i.e. greatest first according to `T: Ord`.
///
/// Each call to `next` pops the heap, so draining `n` elements costs
/// O(n log n). This mirrors the (currently unstable) standard-library
/// `BinaryHeap::into_iter_sorted`.
#[derive(Debug, Clone)]
pub struct BinaryHeapIntoIterSorted<T> {
    inner: BinaryHeap<T>,
}

impl<T: Ord> Iterator for BinaryHeapIntoIterSorted<T> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        self.inner.pop()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let exact = self.inner.len();
        (exact, Some(exact))
    }
}

// `size_hint` is exact, so callers may rely on `len()`.
impl<T: Ord> ExactSizeIterator for BinaryHeapIntoIterSorted<T> {}

// Once the heap is empty, `pop` keeps returning `None`.
impl<T: Ord> std::iter::FusedIterator for BinaryHeapIntoIterSorted<T> {}

/// Extension trait that turns a heap into a [`BinaryHeapIntoIterSorted`].
pub trait IntoIterSorted<T> {
    /// Consumes the heap and returns an iterator yielding its elements
    /// greatest-first.
    fn binary_heap_into_iter_sorted(self) -> BinaryHeapIntoIterSorted<T>;
}

impl<T: Ord> IntoIterSorted<T> for BinaryHeap<T> {
    fn binary_heap_into_iter_sorted(self) -> BinaryHeapIntoIterSorted<T> {
        BinaryHeapIntoIterSorted { inner: self }
    }
}