1#![forbid(clippy::mod_module_files)]
2
3use std::{
4 fmt::{Debug, Display},
5 hash::Hash,
6 marker::PhantomData,
7};
8
9use binary_heap_plus::BinaryHeap;
10use comparator::AStarNodeComparator;
11use cost::AStarCost;
12use extend_map::ExtendFilter;
13use log::trace;
14use num_traits::{Bounded, Zero};
15use reset::Reset;
16use rustc_hash::FxHashMapSeed;
17
18use crate::{closed_lists::AStarClosedList, open_lists::AStarOpenList};
19
20pub mod closed_lists;
21pub mod comparator;
22pub mod cost;
23pub mod open_lists;
24pub mod reset;
25
/// Switch for verbose `trace!` logging inside the search loop.
///
/// It is a `const`, so the disabled logging branches can be optimised away.
const DEBUG_ASTAR: bool = false;
27
/// Identifier of an [`AStarNode`], used as the key of the closed list.
///
/// Identifiers are cloned into search results and backtracking state (`Clone`), compared and
/// hashed as keys (`Eq + Hash`), and printed in diagnostics (`Debug`).
pub trait AStarIdentifier: Clone + Debug + Eq + Hash {}
30
31pub trait AStarNode: Sized + Ord + Debug + Display {
35 type Identifier: AStarIdentifier;
39
40 type EdgeType: Debug;
44
45 type Cost: AStarCost;
46
47 fn identifier(&self) -> &Self::Identifier;
49
50 fn cost(&self) -> Self::Cost;
54
55 fn a_star_lower_bound(&self) -> Self::Cost;
57
58 fn secondary_maximisable_score(&self) -> usize;
62
63 fn predecessor(&self) -> Option<&Self::Identifier>;
65
66 fn predecessor_edge_type(&self) -> Option<Self::EdgeType>;
68
69 fn required_memory() -> usize {
73 std::mem::size_of::<Self>()
74 }
75}
76
77pub trait AStarContext: Reset {
78 type Node: AStarNode;
80
81 fn create_root(&self) -> Self::Node;
83
84 fn generate_successors(&mut self, node: &Self::Node, output: &mut impl Extend<Self::Node>);
86
87 fn is_target(&self, node: &Self::Node) -> bool;
89
90 fn cost_limit(&self) -> Option<<Self::Node as AStarNode>::Cost>;
94
95 fn memory_limit(&self) -> Option<usize>;
99
100 fn is_label_setting(&self) -> bool {
108 true
109 }
110}
111
/// Statistics gathered by [`AStar`] while searching.
#[derive(Default, Debug)]
pub struct AStarPerformanceCounters {
    /// Successors pushed onto the open list. The root is not counted.
    pub opened_nodes: usize,
    /// Popped nodes that the closed list allowed to skip.
    pub suboptimal_opened_nodes: usize,
    /// Nodes closed by the search.
    pub closed_nodes: usize,
}
119
120#[derive(Debug, PartialEq, Eq)]
121pub enum AStarState<NodeIdentifier, Cost> {
122 Empty,
124 Init,
126 Searching,
128 Terminated {
130 result: AStarResult<NodeIdentifier, Cost>,
131 },
132}
133
134#[derive(Debug)]
135pub struct AStar<
136 Context: AStarContext,
137 ClosedList: AStarClosedList<Context::Node> = FxHashMapSeed<
138 <<Context as AStarContext>::Node as AStarNode>::Identifier,
139 <Context as AStarContext>::Node,
140 >,
141 OpenList: AStarOpenList<Context::Node> = BinaryHeap<
142 <Context as AStarContext>::Node,
143 AStarNodeComparator,
144 >,
145> {
146 state: AStarState<<Context::Node as AStarNode>::Identifier, <Context::Node as AStarNode>::Cost>,
147 context: Context,
148 closed_list: ClosedList,
149 open_list: OpenList,
150 performance_counters: AStarPerformanceCounters,
151}
152
153#[derive(Debug)]
154pub struct AStarBuffers<
155 Node: AStarNode,
156 ClosedList: AStarClosedList<Node> = FxHashMapSeed<<Node as AStarNode>::Identifier, Node>,
157 OpenList: AStarOpenList<Node> = BinaryHeap<Node, AStarNodeComparator>,
158> {
159 closed_list: ClosedList,
160 open_list: OpenList,
161 phantom_data: PhantomData<Node>,
162}
163
/// Outcome of a search call.
///
/// Variant order is significant: it defines the derived `Ord`.
#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(tag = "astar_result_type"))]
pub enum AStarResult<NodeIdentifier, Cost> {
    /// A target was reached.
    FoundTarget {
        /// Closed-list identifier of the target. Not serialised with the `serde` feature.
        #[cfg_attr(feature = "serde", serde(skip))]
        identifier: NodeIdentifier,
        /// Cost of the target.
        cost: Cost,
    },

    /// The search stopped because of the context's cost limit, stored here.
    ExceededCostLimit { cost_limit: Cost },

    /// The search stopped because of the context's memory limit.
    ExceededMemoryLimit {
        /// Cost of the node being processed when the limit was hit.
        max_cost: Cost,
    },

    /// The search space was exhausted without reaching a target.
    #[default]
    NoTarget,
}
188
189struct BacktrackingIterator<
190 'a_star,
191 Context: AStarContext,
192 ClosedList: AStarClosedList<Context::Node>,
193 OpenList: AStarOpenList<Context::Node>,
194> {
195 a_star: &'a_star AStar<Context, ClosedList, OpenList>,
196 current: <Context::Node as AStarNode>::Identifier,
197}
198
199struct BacktrackingIteratorWithCost<
200 'a_star,
201 Context: AStarContext,
202 ClosedList: AStarClosedList<Context::Node>,
203 OpenList: AStarOpenList<Context::Node>,
204> {
205 a_star: &'a_star AStar<Context, ClosedList, OpenList>,
206 current: <Context::Node as AStarNode>::Identifier,
207}
208
209impl<
210 Context: AStarContext,
211 ClosedList: AStarClosedList<Context::Node>,
212 OpenList: AStarOpenList<Context::Node>,
213> AStar<Context, ClosedList, OpenList>
214{
215 pub fn new(context: Context) -> Self {
216 Self {
217 state: AStarState::Empty,
218 context,
219 closed_list: ClosedList::new(),
220 open_list: OpenList::new(),
221 performance_counters: Default::default(),
222 }
223 }
224
225 pub fn new_with_buffers(
226 context: Context,
227 mut buffers: AStarBuffers<Context::Node, ClosedList, OpenList>,
228 ) -> Self {
229 buffers.closed_list.reset();
230 buffers.open_list.reset();
231 Self {
232 state: AStarState::Empty,
233 context,
234 closed_list: buffers.closed_list,
235 open_list: buffers.open_list,
236 performance_counters: Default::default(),
237 }
238 }
239
240 pub fn state(
241 &self,
242 ) -> &AStarState<<Context::Node as AStarNode>::Identifier, <Context::Node as AStarNode>::Cost>
243 {
244 &self.state
245 }
246
247 pub fn context(&self) -> &Context {
248 &self.context
249 }
250
251 pub fn context_mut(&mut self) -> &mut Context {
252 &mut self.context
253 }
254
255 pub fn into_context(self) -> Context {
256 self.context
257 }
258
259 pub fn into_buffers(self) -> AStarBuffers<Context::Node, ClosedList, OpenList> {
260 AStarBuffers {
261 closed_list: self.closed_list,
262 open_list: self.open_list,
263 phantom_data: Default::default(),
264 }
265 }
266
267 pub fn closed_node(
268 &self,
269 node_identifier: &<Context::Node as AStarNode>::Identifier,
270 ) -> Option<&Context::Node> {
271 self.closed_list.get(node_identifier)
272 }
273
274 pub fn iter_closed_nodes(&self) -> impl Iterator<Item = &Context::Node> {
275 self.closed_list.iter()
276 }
277
278 pub fn performance_counters(&self) -> &AStarPerformanceCounters {
279 &self.performance_counters
280 }
281
282 pub fn reset(&mut self) {
283 self.state = AStarState::Empty;
284 self.context.reset();
285 self.closed_list.reset();
286 self.open_list.reset();
287 self.performance_counters = Default::default();
288 }
289
290 pub fn initialise(&mut self) {
291 self.initialise_with(|context| context.create_root());
292 }
293
294 pub fn initialise_with(&mut self, node: impl FnOnce(&Context) -> Context::Node) {
295 assert_eq!(self.state, AStarState::Empty);
296
297 self.state = AStarState::Init;
298 self.open_list.push(node(&self.context));
299 }
300
301 pub fn search(
302 &mut self,
303 ) -> AStarResult<<Context::Node as AStarNode>::Identifier, <Context::Node as AStarNode>::Cost>
304 {
305 self.search_until(|context, node| context.is_target(node))
306 }
307
308 pub fn search_until(
309 &mut self,
310 is_target: impl FnMut(&Context, &Context::Node) -> bool,
311 ) -> AStarResult<<Context::Node as AStarNode>::Identifier, <Context::Node as AStarNode>::Cost>
312 {
313 self.search_until_with_target_policy(is_target, false)
314 }
315
316 pub fn search_until_with_target_policy(
317 &mut self,
318 mut is_target: impl FnMut(&Context, &Context::Node) -> bool,
319 abort_on_first_target: bool,
320 ) -> AStarResult<<Context::Node as AStarNode>::Identifier, <Context::Node as AStarNode>::Cost>
321 {
322 assert!(matches!(
323 self.state,
324 AStarState::Init | AStarState::Searching | AStarState::Terminated { .. }
325 ));
326
327 let cost_limit = self
328 .context
329 .cost_limit()
330 .unwrap_or(<Context::Node as AStarNode>::Cost::max_value());
331 let mut applied_cost_limit = false;
332 let memory_limit = self.context.memory_limit().unwrap_or(usize::MAX);
333 let node_count_limit =
335 (memory_limit as f64 / Context::Node::required_memory() as f64 / 2.3).round() as usize;
336
337 if self.open_list.is_empty() {
338 return AStarResult::NoTarget;
339 }
340
341 self.state = AStarState::Searching;
342
343 let mut last_node = None;
344 let mut target_identifier = None;
345 let mut target_cost = <Context::Node as AStarNode>::Cost::max_value();
346 let mut target_secondary_maximisable_score = 0;
347
348 loop {
349 let Some(node) = self.open_list.pop_min() else {
350 if last_node.is_none() {
351 unreachable!("Open list was empty.");
352 };
353 if applied_cost_limit {
354 self.state = AStarState::Terminated {
355 result: AStarResult::ExceededCostLimit { cost_limit },
356 };
357 return AStarResult::ExceededCostLimit { cost_limit };
358 } else if target_identifier.is_some() {
359 break;
362 } else {
363 self.state = AStarState::Terminated {
364 result: AStarResult::NoTarget,
365 };
366 return AStarResult::NoTarget;
367 }
368 };
369
370 if node.cost() + node.a_star_lower_bound() > cost_limit {
373 self.state = AStarState::Terminated {
374 result: AStarResult::ExceededCostLimit { cost_limit },
375 };
376 return AStarResult::ExceededCostLimit { cost_limit };
377 }
378
379 if self.closed_list.len() + self.open_list.len() > node_count_limit {
381 self.state = AStarState::Terminated {
382 result: AStarResult::ExceededMemoryLimit {
383 max_cost: node.cost(),
384 },
385 };
386 return AStarResult::ExceededMemoryLimit {
387 max_cost: node.cost(),
388 };
389 }
390
391 if node.cost() + node.a_star_lower_bound() > target_cost {
393 debug_assert!(!self.context.is_label_setting());
394 break;
395 }
396
397 last_node = Some(node.identifier().clone());
398 let is_target = is_target(&self.context, &node);
399 if DEBUG_ASTAR && is_target {
400 trace!("Node {node} is target");
401 }
402 if is_target {
403 } else {
405 }
407 debug_assert!(!is_target || node.a_star_lower_bound().is_zero());
408
409 if self
410 .closed_list
411 .can_skip_node(&node, self.context.is_label_setting())
412 {
413 self.performance_counters.suboptimal_opened_nodes += 1;
415 let existing_cost = self.closed_list.get(node.identifier()).unwrap().cost();
416 let existing_secondary_maximisable_score = self
417 .closed_list
418 .get(node.identifier())
419 .unwrap()
420 .secondary_maximisable_score();
421
422 if is_target
423 && (node.cost() < target_cost.min(existing_cost)
424 || (node.cost() == target_cost.min(existing_cost)
425 && node.secondary_maximisable_score()
426 > target_secondary_maximisable_score
427 .max(existing_secondary_maximisable_score)))
428 {
429 if DEBUG_ASTAR {
430 trace!("Updating target to {node}");
431 }
432 target_identifier = Some(node.identifier().clone());
434 target_cost = node.cost();
435 target_secondary_maximisable_score = node.secondary_maximisable_score();
436
437 if self.context.is_label_setting() || abort_on_first_target {
438 if DEBUG_ASTAR {
439 trace!("Context is label setting, so we return the first target found");
440 }
441 let previous_visit =
442 self.closed_list.insert(node.identifier().clone(), node);
443 self.performance_counters.closed_nodes += 1;
444 debug_assert!(
445 previous_visit.is_none() || !self.context.is_label_setting(),
446 "Visited node again even though we are label setting:\nprevious: {}",
447 previous_visit.unwrap(),
448 );
449 break;
450 }
451 } else if is_target
452 && (existing_cost < target_cost
453 || (existing_cost == target_cost
454 && node.secondary_maximisable_score()
455 > existing_secondary_maximisable_score))
456 {
457 let node = self.closed_list.get(node.identifier()).unwrap();
458 if DEBUG_ASTAR {
460 trace!("Updating target to {node}");
461 }
462 target_identifier = Some(node.identifier().clone());
464 target_cost = node.cost();
465 target_secondary_maximisable_score = node.secondary_maximisable_score();
466
467 if self.context.is_label_setting() || abort_on_first_target {
468 if DEBUG_ASTAR {
469 trace!("Context is label setting, so we return the first target found");
470 }
471 self.performance_counters.closed_nodes += 1;
472 break;
473 }
474 }
475
476 if DEBUG_ASTAR {
477 trace!("Skipping node {node}");
478 }
479 continue;
480 }
481
482 let open_nodes_without_new_successors = self.open_list.len();
483 self.context.generate_successors(
484 &node,
485 &mut ExtendFilter::new(&mut self.open_list, |node| {
486 let result = node.cost() + node.a_star_lower_bound() <= cost_limit;
487 applied_cost_limit = applied_cost_limit || !result;
488 result
489 }),
490 );
491 self.performance_counters.opened_nodes +=
492 self.open_list.len() - open_nodes_without_new_successors;
493
494 if is_target
495 && (node.cost() < target_cost
496 || (node.cost() == target_cost
497 && node.secondary_maximisable_score() > target_secondary_maximisable_score))
498 {
499 if DEBUG_ASTAR {
500 trace!("Updating target to {node}");
501 }
502 target_identifier = Some(node.identifier().clone());
504 target_cost = node.cost();
505 target_secondary_maximisable_score = node.secondary_maximisable_score();
506
507 if self.context.is_label_setting() || abort_on_first_target {
508 if DEBUG_ASTAR {
509 trace!("Context is label setting, so we return the first target found");
510 }
511 let previous_visit = self.closed_list.insert(node.identifier().clone(), node);
512 self.performance_counters.closed_nodes += 1;
513 debug_assert!(previous_visit.is_none() || !self.context.is_label_setting());
514 break;
515 }
516 }
517
518 let previous_visit = self.closed_list.insert(node.identifier().clone(), node);
519 self.performance_counters.closed_nodes += 1;
520 debug_assert!(
521 previous_visit.is_none() || !self.context.is_label_setting(),
522 "Node was visited previously: {}",
523 previous_visit.unwrap()
524 );
525 }
526
527 let Some(target_identifier) = target_identifier else {
528 debug_assert!(!self.context.is_label_setting());
529 self.state = AStarState::Terminated {
530 result: AStarResult::NoTarget,
531 };
532 return AStarResult::NoTarget;
533 };
534
535 let cost = self.closed_list.get(&target_identifier).unwrap().cost();
536 debug_assert_eq!(
537 cost,
538 target_cost,
539 "Target node has lower cost than target_cost:\nnode: {}\ntarget_cost: {target_cost}",
540 self.closed_list.get(&target_identifier).unwrap(),
541 );
542 self.state = AStarState::Terminated {
543 result: AStarResult::FoundTarget {
544 identifier: target_identifier.clone(),
545 cost,
546 },
547 };
548 AStarResult::FoundTarget {
549 identifier: target_identifier,
550 cost,
551 }
552 }
553
554 pub fn backtrack(
555 &self,
556 ) -> impl use<'_, Context, ClosedList, OpenList>
557 + Iterator<Item = <Context::Node as AStarNode>::EdgeType> {
558 let AStarState::Terminated {
559 result: AStarResult::FoundTarget { identifier, .. },
560 } = &self.state
561 else {
562 panic!("Cannot backtrack since no target was found.")
563 };
564
565 self.backtrack_from(identifier).unwrap()
566 }
567
568 pub fn backtrack_with_costs(
573 &self,
574 ) -> impl use<'_, Context, ClosedList, OpenList>
575 + Iterator<
576 Item = (
577 <Context::Node as AStarNode>::EdgeType,
578 <Context::Node as AStarNode>::Cost,
579 ),
580 > {
581 let AStarState::Terminated {
582 result: AStarResult::FoundTarget { identifier, .. },
583 } = &self.state
584 else {
585 panic!("Cannot backtrack since no target was found.")
586 };
587
588 self.backtrack_with_costs_from(identifier).unwrap()
589 }
590
591 pub fn backtrack_from(
592 &self,
593 identifier: &<Context::Node as AStarNode>::Identifier,
594 ) -> Option<
595 impl use<'_, Context, ClosedList, OpenList>
596 + Iterator<Item = <Context::Node as AStarNode>::EdgeType>,
597 > {
598 if self.closed_list.contains_identifier(identifier) {
599 Some(BacktrackingIterator {
600 a_star: self,
601 current: identifier.clone(),
602 })
603 } else {
604 None
605 }
606 }
607
608 #[allow(clippy::type_complexity)]
613 pub fn backtrack_with_costs_from(
614 &self,
615 identifier: &<Context::Node as AStarNode>::Identifier,
616 ) -> Option<
617 impl use<'_, Context, ClosedList, OpenList>
618 + Iterator<
619 Item = (
620 <Context::Node as AStarNode>::EdgeType,
621 <Context::Node as AStarNode>::Cost,
622 ),
623 >,
624 > {
625 if self.closed_list.contains_identifier(identifier) {
626 Some(BacktrackingIteratorWithCost {
627 a_star: self,
628 current: identifier.clone(),
629 })
630 } else {
631 None
632 }
633 }
634
635 pub fn reconstruct_path(&self) -> Vec<<Context::Node as AStarNode>::EdgeType> {
637 let AStarState::Terminated {
638 result: AStarResult::FoundTarget { .. },
639 } = &self.state
640 else {
641 panic!("Cannot reconstruct path since no target was found.")
642 };
643
644 let mut result = self.backtrack().collect::<Vec<_>>();
645 result.reverse();
646 result
647 }
648}
649
650impl<NodeIdentifier, Cost: Copy> AStarResult<NodeIdentifier, Cost> {
651 pub fn cost(&self) -> Cost {
655 match self {
656 Self::FoundTarget { cost, .. } => *cost,
657 Self::ExceededCostLimit { cost_limit } => *cost_limit,
658 Self::ExceededMemoryLimit { max_cost } => *max_cost,
659 Self::NoTarget => panic!("AStarResult has no costs"),
660 }
661 }
662
663 pub fn without_node_identifier(&self) -> AStarResult<(), Cost> {
664 match *self {
665 Self::FoundTarget { cost, .. } => AStarResult::FoundTarget {
666 identifier: (),
667 cost,
668 },
669 Self::ExceededCostLimit { cost_limit } => AStarResult::ExceededCostLimit { cost_limit },
670 Self::ExceededMemoryLimit { max_cost } => AStarResult::ExceededMemoryLimit { max_cost },
671 Self::NoTarget => AStarResult::NoTarget,
672 }
673 }
674}
675
676impl<NodeIdentifier: Clone, Cost> AStarResult<NodeIdentifier, Cost> {
677 pub fn transform_cost<TargetCost>(
678 &self,
679 transform: impl Fn(&Cost) -> TargetCost,
680 ) -> AStarResult<NodeIdentifier, TargetCost> {
681 match self {
682 AStarResult::FoundTarget { identifier, cost } => AStarResult::FoundTarget {
683 identifier: identifier.clone(),
684 cost: transform(cost),
685 },
686 AStarResult::ExceededCostLimit { cost_limit } => AStarResult::ExceededCostLimit {
687 cost_limit: transform(cost_limit),
688 },
689 AStarResult::ExceededMemoryLimit { max_cost } => AStarResult::ExceededMemoryLimit {
690 max_cost: transform(max_cost),
691 },
692 AStarResult::NoTarget => AStarResult::NoTarget,
693 }
694 }
695}
696
697impl<
698 Context: AStarContext,
699 ClosedList: AStarClosedList<Context::Node>,
700 OpenList: AStarOpenList<Context::Node>,
701> Iterator for BacktrackingIterator<'_, Context, ClosedList, OpenList>
702{
703 type Item = <Context::Node as AStarNode>::EdgeType;
704
705 fn next(&mut self) -> Option<Self::Item> {
706 let current = self.a_star.closed_list.get(&self.current).unwrap();
707
708 if let Some(predecessor) = current.predecessor().cloned() {
709 let predecessor_edge_type = current.predecessor_edge_type().unwrap();
710 self.current = predecessor;
711 Some(predecessor_edge_type)
712 } else {
713 None
714 }
715 }
716}
717
718impl<
719 Context: AStarContext,
720 ClosedList: AStarClosedList<Context::Node>,
721 OpenList: AStarOpenList<Context::Node>,
722> Iterator for BacktrackingIteratorWithCost<'_, Context, ClosedList, OpenList>
723{
724 type Item = (
725 <Context::Node as AStarNode>::EdgeType,
726 <Context::Node as AStarNode>::Cost,
727 );
728
729 fn next(&mut self) -> Option<Self::Item> {
730 let current = self.a_star.closed_list.get(&self.current).unwrap();
731 let cost = current.cost();
732
733 if let Some(predecessor) = current.predecessor().cloned() {
734 let predecessor_edge_type = current.predecessor_edge_type().unwrap();
735 self.current = predecessor;
736 Some((predecessor_edge_type, cost))
737 } else {
738 None
739 }
740 }
741}
742
743impl<Node: AStarNode, ClosedList: AStarClosedList<Node>, OpenList: AStarOpenList<Node>> Default
744 for AStarBuffers<Node, ClosedList, OpenList>
745{
746 fn default() -> Self {
747 Self {
748 closed_list: ClosedList::new(),
749 open_list: OpenList::new(),
750 phantom_data: Default::default(),
751 }
752 }
753}
754
755impl<NodeIdentifier, Cost: Display> Display for AStarResult<NodeIdentifier, Cost> {
756 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
757 match self {
758 AStarResult::FoundTarget { cost, .. } => write!(f, "Reached target with cost {cost}"),
759 AStarResult::ExceededCostLimit { cost_limit } => {
760 write!(f, "Exceeded cost limit of {cost_limit}")
761 }
762 AStarResult::ExceededMemoryLimit { max_cost } => write!(
763 f,
764 "Exceeded memory limit, but reached a maximum cost of {max_cost}"
765 ),
766 AStarResult::NoTarget => write!(f, "Found no target"),
767 }
768 }
769}
770
771impl<T: AStarNode> AStarNode for Box<T> {
772 type Identifier = <T as AStarNode>::Identifier;
773
774 type EdgeType = <T as AStarNode>::EdgeType;
775
776 type Cost = <T as AStarNode>::Cost;
777
778 fn identifier(&self) -> &Self::Identifier {
779 <T as AStarNode>::identifier(self)
780 }
781
782 fn cost(&self) -> Self::Cost {
783 <T as AStarNode>::cost(self)
784 }
785
786 fn a_star_lower_bound(&self) -> Self::Cost {
787 <T as AStarNode>::a_star_lower_bound(self)
788 }
789
790 fn secondary_maximisable_score(&self) -> usize {
791 <T as AStarNode>::secondary_maximisable_score(self)
792 }
793
794 fn predecessor(&self) -> Option<&Self::Identifier> {
795 <T as AStarNode>::predecessor(self)
796 }
797
798 fn predecessor_edge_type(&self) -> Option<Self::EdgeType> {
799 <T as AStarNode>::predecessor_edge_type(self)
800 }
801
802 fn required_memory() -> usize {
803 <T as AStarNode>::required_memory() + std::mem::size_of::<Box<()>>()
804 }
805}