1#![forbid(clippy::mod_module_files)]
2
3use std::{
4 cmp::Ordering,
5 collections::HashMap,
6 fmt::{Debug, Display},
7 hash::Hash,
8};
9
10use binary_heap_plus::BinaryHeap;
11use comparator::AStarNodeComparator;
12use compare::Compare;
13use cost::AStarCost;
14use deterministic_default_hasher::DeterministicDefaultHasher;
15use extend_map::ExtendFilter;
16use num_traits::{Bounded, Zero};
17use reset::Reset;
18
19mod comparator;
20pub mod cost;
21pub mod reset;
22
/// A node of the search space explored by [`AStar`].
///
/// Nodes are stored in the closed list under their [`identifier`](AStarNode::identifier).
/// The search compares `cost() + a_star_lower_bound()` against the cost limit and against the
/// best target cost found so far.
pub trait AStarNode: Sized + Ord + Debug + Display {
    /// Key under which a node is stored in the closed list. A popped node whose identifier is
    /// already closed is treated as a revisit of the same search state.
    type Identifier: Debug + Clone + Eq + Hash;

    /// Label of the edge from a node's predecessor to the node; yielded when backtracking.
    type EdgeType: Debug;

    /// Type used for node costs, lower bounds and the cost limit.
    type Cost: AStarCost;

    /// Returns the identifier of this node.
    fn identifier(&self) -> &Self::Identifier;

    /// Returns the cost of this node. For a target node this is the cost reported in
    /// [`AStarResult::FoundTarget`].
    fn cost(&self) -> Self::Cost;

    /// Returns the A* lower bound (heuristic) added to [`cost`](AStarNode::cost) when ordering
    /// and pruning. It must be zero for target nodes; a debug assertion checks this.
    fn a_star_lower_bound(&self) -> Self::Cost;

    /// Tie-breaker between targets of equal cost seen during one search: the target with the
    /// larger score replaces the previous one.
    fn secondary_maximisable_score(&self) -> usize;

    /// Returns the identifier of the node this node was generated from, or `None` when there is
    /// no predecessor (presumably only for the root — NOTE(review): confirm). Backtracking stops
    /// at the first node returning `None`.
    fn predecessor(&self) -> Option<&Self::Identifier>;

    /// Returns the edge taken from the predecessor to this node. Backtracking unwraps this
    /// whenever [`predecessor`](AStarNode::predecessor) is `Some`, so it must be `Some` then.
    fn predecessor_edge_type(&self) -> Option<Self::EdgeType>;
}
61
/// Problem definition that drives an [`AStar`] search: it creates the root, expands nodes,
/// recognises targets and supplies the limits. Its [`Reset`] implementation is invoked by
/// [`AStar::reset`].
pub trait AStarContext: Reset {
    /// Node type searched over.
    type Node: AStarNode;

    /// Creates the start node; used by [`AStar::initialise`].
    fn create_root(&self) -> Self::Node;

    /// Pushes the successors of `node` into `output`.
    ///
    /// [`AStar`] passes a filtered view of its open list that discards successors whose
    /// `cost() + a_star_lower_bound()` exceeds the cost limit.
    fn generate_successors(&mut self, node: &Self::Node, output: &mut impl Extend<Self::Node>);

    /// Returns `true` if `node` is a target; used by [`AStar::search`].
    fn is_target(&self, node: &Self::Node) -> bool;

    /// Upper limit for `cost() + a_star_lower_bound()` of searched nodes; `None` means no limit.
    fn cost_limit(&self) -> Option<<Self::Node as AStarNode>::Cost>;

    /// Memory budget in bytes; `None` means no limit. [`AStar`] converts it into a maximum
    /// number of stored nodes using `size_of::<Node>()` and a fixed overhead divisor.
    fn memory_limit(&self) -> Option<usize>;

    /// Whether the search is label-setting (the default).
    ///
    /// If `true`, the search stops at the first target it closes, and a node whose identifier
    /// is already closed is skipped. If `false` (label-correcting), such a node is re-expanded
    /// when it compares better than the closed one. The search then continues until a popped
    /// node's `cost() + a_star_lower_bound()` exceeds the best target cost.
    fn is_label_setting(&self) -> bool {
        true
    }
}
96
/// Statistics collected by [`AStar`] while searching; cleared by [`AStar::reset`].
#[derive(Debug, Default)]
pub struct AStarPerformanceCounters {
    /// Number of successors accepted into the open list (after cost-limit filtering).
    pub opened_nodes: usize,
    /// Number of popped nodes whose identifier was already in the closed list.
    pub suboptimal_opened_nodes: usize,
    /// Number of insertions into the closed list.
    pub closed_nodes: usize,
}
104
/// Lifecycle state of an [`AStar`] search.
#[derive(Debug, PartialEq, Eq)]
pub enum AStarState<NodeIdentifier, Cost> {
    /// Freshly created or reset; no start node has been pushed yet.
    Empty,
    /// A start node has been pushed; no search has run yet.
    Init,
    /// A search call is in progress.
    Searching,
    /// A search call stopped with `result`.
    Terminated {
        /// Result recorded when the search stopped.
        result: AStarResult<NodeIdentifier, Cost>,
    },
}
118
/// A* search over the nodes produced by an [`AStarContext`].
///
/// Typical use: [`AStar::new`], [`AStar::initialise`], [`AStar::search`], then
/// [`AStar::backtrack`] to read the edges of the found path.
#[derive(Debug)]
pub struct AStar<Context: AStarContext> {
    /// Lifecycle state, including the result of the last search.
    state: AStarState<<Context::Node as AStarNode>::Identifier, <Context::Node as AStarNode>::Cost>,
    /// Problem definition.
    context: Context,
    /// Expanded nodes keyed by identifier; used for revisit detection and backtracking.
    closed_list: HashMap<
        <Context::Node as AStarNode>::Identifier,
        Context::Node,
        DeterministicDefaultHasher,
    >,
    /// Nodes waiting to be expanded, ordered by `AStarNodeComparator`.
    open_list: BinaryHeap<Context::Node, AStarNodeComparator>,
    /// Statistics gathered during the search.
    performance_counters: AStarPerformanceCounters,
}
131
/// Reusable open- and closed-list storage of an [`AStar`].
///
/// Obtained from [`AStar::into_buffers`] and handed to [`AStar::new_with_buffers`], which
/// clears both lists before reuse.
#[derive(Debug)]
pub struct AStarBuffers<NodeIdentifier, Node> {
    /// Storage for the closed list.
    closed_list: HashMap<NodeIdentifier, Node, DeterministicDefaultHasher>,
    /// Storage for the open list.
    open_list: BinaryHeap<Node, AStarNodeComparator>,
}
137
/// Outcome of a search.
#[derive(Debug, Clone, Ord, PartialOrd, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(tag = "astar_result_type"))]
pub enum AStarResult<NodeIdentifier, Cost> {
    /// A target was found.
    FoundTarget {
        /// Identifier of the target node; not serialised when the `serde` feature is enabled.
        #[cfg_attr(feature = "serde", serde(skip))]
        identifier: NodeIdentifier,
        /// Cost of the target node.
        cost: Cost,
    },

    /// The search stopped because nodes exceeded the cost limit.
    ExceededCostLimit {
        /// The cost limit that was in effect.
        cost_limit: Cost,
    },

    /// The search stopped because the node-count limit derived from the memory limit was hit.
    ExceededMemoryLimit {
        /// Cost of the node that was about to be processed when the limit was hit.
        max_cost: Cost,
    },

    /// No target was found.
    NoTarget,
}
161
/// Walks predecessor links through the closed list. It yields the edge type leading into
/// `current`, then moves to the predecessor, until a node without predecessor is reached.
struct BacktrackingIterator<'a_star, Context: AStarContext> {
    /// Search whose closed list is walked.
    a_star: &'a_star AStar<Context>,
    /// Identifier of the node whose incoming edge is yielded next.
    current: <Context::Node as AStarNode>::Identifier,
}
166
/// Like `BacktrackingIterator`, but yields `(edge type, cost)` pairs. The cost is that of the
/// node the edge leads into.
struct BacktrackingIteratorWithCost<'a_star, Context: AStarContext> {
    /// Search whose closed list is walked.
    a_star: &'a_star AStar<Context>,
    /// Identifier of the node whose incoming edge and cost are yielded next.
    current: <Context::Node as AStarNode>::Identifier,
}
171
172impl<Context: AStarContext> AStar<Context> {
173 pub fn new(context: Context) -> Self {
174 Self {
175 state: AStarState::Empty,
176 context,
177 closed_list: Default::default(),
178 open_list: BinaryHeap::from_vec(Vec::new()),
179 performance_counters: Default::default(),
180 }
181 }
182
183 pub fn new_with_buffers(
184 context: Context,
185 mut buffers: AStarBuffers<<Context::Node as AStarNode>::Identifier, Context::Node>,
186 ) -> Self {
187 buffers.closed_list.clear();
188 buffers.open_list.clear();
189 Self {
190 state: AStarState::Empty,
191 context,
192 closed_list: buffers.closed_list,
193 open_list: buffers.open_list,
194 performance_counters: Default::default(),
195 }
196 }
197
198 pub fn state(
199 &self,
200 ) -> &AStarState<<Context::Node as AStarNode>::Identifier, <Context::Node as AStarNode>::Cost>
201 {
202 &self.state
203 }
204
205 pub fn context(&self) -> &Context {
206 &self.context
207 }
208
209 pub fn into_context(self) -> Context {
210 self.context
211 }
212
213 pub fn into_buffers(
214 self,
215 ) -> AStarBuffers<<Context::Node as AStarNode>::Identifier, Context::Node> {
216 AStarBuffers {
217 closed_list: self.closed_list,
218 open_list: self.open_list,
219 }
220 }
221
222 pub fn closed_node(
223 &self,
224 node_identifier: &<Context::Node as AStarNode>::Identifier,
225 ) -> Option<&Context::Node> {
226 self.closed_list.get(node_identifier)
227 }
228
229 pub fn performance_counters(&self) -> &AStarPerformanceCounters {
230 &self.performance_counters
231 }
232
233 pub fn reset(&mut self) {
234 self.state = AStarState::Empty;
235 self.context.reset();
236 self.closed_list.clear();
237 self.open_list.clear();
238 self.performance_counters = Default::default();
239 }
240
241 pub fn initialise(&mut self) {
242 self.initialise_with(|context| context.create_root());
243 }
244
245 pub fn initialise_with(&mut self, node: impl FnOnce(&Context) -> Context::Node) {
246 assert_eq!(self.state, AStarState::Empty);
247
248 self.state = AStarState::Init;
249 self.open_list.push(node(&self.context));
250 }
251
252 pub fn search(
253 &mut self,
254 ) -> AStarResult<<Context::Node as AStarNode>::Identifier, <Context::Node as AStarNode>::Cost>
255 {
256 self.search_until(|context, node| context.is_target(node))
257 }
258
259 pub fn search_until(
260 &mut self,
261 mut is_target: impl FnMut(&Context, &Context::Node) -> bool,
262 ) -> AStarResult<<Context::Node as AStarNode>::Identifier, <Context::Node as AStarNode>::Cost>
263 {
264 assert!(matches!(
265 self.state,
266 AStarState::Init | AStarState::Searching | AStarState::Terminated { .. }
267 ));
268
269 let cost_limit = self
270 .context
271 .cost_limit()
272 .unwrap_or(<Context::Node as AStarNode>::Cost::max_value());
273 let mut applied_cost_limit = false;
274 let memory_limit = self.context.memory_limit().unwrap_or(usize::MAX);
275 let node_count_limit =
277 (memory_limit as f64 / std::mem::size_of::<Context::Node>() as f64 / 2.3).round()
278 as usize;
279
280 if self.open_list.is_empty() {
281 return AStarResult::NoTarget;
282 }
283
284 self.state = AStarState::Searching;
285
286 let mut last_node = None;
287 let mut target_identifier = None;
288 let mut target_cost = <Context::Node as AStarNode>::Cost::max_value();
289 let mut target_secondary_maximisable_score = 0;
290
291 loop {
292 let Some(node) = self.open_list.pop() else {
293 if last_node.is_none() {
294 unreachable!("Open list was empty.");
295 };
296 if applied_cost_limit {
297 self.state = AStarState::Terminated {
298 result: AStarResult::ExceededCostLimit { cost_limit },
299 };
300 return AStarResult::ExceededCostLimit { cost_limit };
301 } else {
302 self.state = AStarState::Terminated {
303 result: AStarResult::NoTarget,
304 };
305 return AStarResult::NoTarget;
306 }
307 };
308
309 if node.cost() + node.a_star_lower_bound() > cost_limit {
312 self.state = AStarState::Terminated {
313 result: AStarResult::ExceededCostLimit { cost_limit },
314 };
315 return AStarResult::ExceededCostLimit { cost_limit };
316 }
317
318 if self.closed_list.len() + self.open_list.len() > node_count_limit {
320 self.state = AStarState::Terminated {
321 result: AStarResult::ExceededMemoryLimit {
322 max_cost: node.cost(),
323 },
324 };
325 return AStarResult::ExceededMemoryLimit {
326 max_cost: node.cost(),
327 };
328 }
329
330 if node.cost() + node.a_star_lower_bound() > target_cost {
332 debug_assert!(!self.context.is_label_setting());
333 break;
334 }
335
336 last_node = Some(node.identifier().clone());
337
338 if let Some(previous_visit) = self.closed_list.get(node.identifier()) {
339 self.performance_counters.suboptimal_opened_nodes += 1;
340
341 if self.context.is_label_setting() {
342 debug_assert!(
344 previous_visit.cost() + previous_visit.a_star_lower_bound()
345 <= node.cost() + node.a_star_lower_bound(),
346 "{}",
347 {
348 use std::fmt::Write;
349 let mut previous_visit = previous_visit;
350 let mut node = &node;
351 let mut out = String::new();
352
353 writeln!(out, "previous_visit:").unwrap();
354 while let Some(predecessor) = previous_visit.predecessor() {
355 writeln!(out, "{previous_visit}").unwrap();
356 previous_visit = self.closed_list.get(predecessor).unwrap();
357 }
358
359 writeln!(out, "\nnode:").unwrap();
360 while let Some(predecessor) = node.predecessor() {
361 writeln!(out, "{node}").unwrap();
362 node = self.closed_list.get(predecessor).unwrap();
363 }
364
365 out
366 }
367 );
368
369 continue;
370 } else if AStarNodeComparator.compare(&node, previous_visit) != Ordering::Greater {
371 continue;
374 }
375 }
376
377 let open_nodes_without_new_successors = self.open_list.len();
378 self.context.generate_successors(
379 &node,
380 &mut ExtendFilter::new(&mut self.open_list, |node| {
381 let result = node.cost() + node.a_star_lower_bound() <= cost_limit;
382 applied_cost_limit = applied_cost_limit || !result;
383 result
384 }),
385 );
386 self.performance_counters.opened_nodes +=
387 self.open_list.len() - open_nodes_without_new_successors;
388
389 let is_target = is_target(&self.context, &node);
390 debug_assert!(!is_target || node.a_star_lower_bound().is_zero());
391
392 if is_target
393 && (node.cost() < target_cost
394 || (node.cost() == target_cost
395 && node.secondary_maximisable_score() > target_secondary_maximisable_score))
396 {
397 target_identifier = Some(node.identifier().clone());
398 target_cost = node.cost();
399 target_secondary_maximisable_score = node.secondary_maximisable_score();
400
401 if self.context.is_label_setting() {
402 let previous_visit = self.closed_list.insert(node.identifier().clone(), node);
403 self.performance_counters.closed_nodes += 1;
404 debug_assert!(previous_visit.is_none() || !self.context.is_label_setting());
405 break;
406 }
407 }
408
409 let previous_visit = self.closed_list.insert(node.identifier().clone(), node);
410 self.performance_counters.closed_nodes += 1;
411 debug_assert!(previous_visit.is_none() || !self.context.is_label_setting());
412 }
413
414 let Some(target_identifier) = target_identifier else {
415 debug_assert!(!self.context.is_label_setting());
416 self.state = AStarState::Terminated {
417 result: AStarResult::NoTarget,
418 };
419 return AStarResult::NoTarget;
420 };
421
422 let cost = self.closed_list.get(&target_identifier).unwrap().cost();
423 debug_assert_eq!(cost, target_cost);
424 self.state = AStarState::Terminated {
425 result: AStarResult::FoundTarget {
426 identifier: target_identifier.clone(),
427 cost,
428 },
429 };
430 AStarResult::FoundTarget {
431 identifier: target_identifier,
432 cost,
433 }
434 }
435
436 pub fn backtrack(
437 &self,
438 ) -> impl use<'_, Context> + Iterator<Item = <Context::Node as AStarNode>::EdgeType> {
439 let AStarState::Terminated {
440 result: AStarResult::FoundTarget { identifier, .. },
441 } = &self.state
442 else {
443 panic!("Cannot backtrack since no target was found.")
444 };
445
446 self.backtrack_from(identifier).unwrap()
447 }
448
449 pub fn backtrack_with_costs(
454 &self,
455 ) -> impl use<'_, Context>
456 + Iterator<
457 Item = (
458 <Context::Node as AStarNode>::EdgeType,
459 <Context::Node as AStarNode>::Cost,
460 ),
461 > {
462 let AStarState::Terminated {
463 result: AStarResult::FoundTarget { identifier, .. },
464 } = &self.state
465 else {
466 panic!("Cannot backtrack since no target was found.")
467 };
468
469 self.backtrack_with_costs_from(identifier).unwrap()
470 }
471
472 pub fn backtrack_from(
473 &self,
474 identifier: &<Context::Node as AStarNode>::Identifier,
475 ) -> Option<impl use<'_, Context> + Iterator<Item = <Context::Node as AStarNode>::EdgeType>>
476 {
477 if self.closed_list.contains_key(identifier) {
478 Some(BacktrackingIterator {
479 a_star: self,
480 current: identifier.clone(),
481 })
482 } else {
483 None
484 }
485 }
486
487 #[allow(clippy::type_complexity)]
492 pub fn backtrack_with_costs_from(
493 &self,
494 identifier: &<Context::Node as AStarNode>::Identifier,
495 ) -> Option<
496 impl use<'_, Context>
497 + Iterator<
498 Item = (
499 <Context::Node as AStarNode>::EdgeType,
500 <Context::Node as AStarNode>::Cost,
501 ),
502 >,
503 > {
504 if self.closed_list.contains_key(identifier) {
505 Some(BacktrackingIteratorWithCost {
506 a_star: self,
507 current: identifier.clone(),
508 })
509 } else {
510 None
511 }
512 }
513}
514
515impl<NodeIdentifier, Cost: Copy> AStarResult<NodeIdentifier, Cost> {
516 pub fn cost(&self) -> Cost {
520 match self {
521 Self::FoundTarget { cost, .. } => *cost,
522 Self::ExceededCostLimit { cost_limit } => *cost_limit,
523 Self::ExceededMemoryLimit { max_cost } => *max_cost,
524 Self::NoTarget => panic!("AStarResult has no costs"),
525 }
526 }
527
528 pub fn without_node_identifier(&self) -> AStarResult<(), Cost> {
529 match *self {
530 Self::FoundTarget { cost, .. } => AStarResult::FoundTarget {
531 identifier: (),
532 cost,
533 },
534 Self::ExceededCostLimit { cost_limit } => AStarResult::ExceededCostLimit { cost_limit },
535 Self::ExceededMemoryLimit { max_cost } => AStarResult::ExceededMemoryLimit { max_cost },
536 Self::NoTarget => AStarResult::NoTarget,
537 }
538 }
539}
540
541impl<NodeIdentifier: Clone, Cost> AStarResult<NodeIdentifier, Cost> {
542 pub fn transform_cost<TargetCost>(
543 &self,
544 transform: impl Fn(&Cost) -> TargetCost,
545 ) -> AStarResult<NodeIdentifier, TargetCost> {
546 match self {
547 AStarResult::FoundTarget { identifier, cost } => AStarResult::FoundTarget {
548 identifier: identifier.clone(),
549 cost: transform(cost),
550 },
551 AStarResult::ExceededCostLimit { cost_limit } => AStarResult::ExceededCostLimit {
552 cost_limit: transform(cost_limit),
553 },
554 AStarResult::ExceededMemoryLimit { max_cost } => AStarResult::ExceededMemoryLimit {
555 max_cost: transform(max_cost),
556 },
557 AStarResult::NoTarget => AStarResult::NoTarget,
558 }
559 }
560}
561
562impl<Context: AStarContext> Iterator for BacktrackingIterator<'_, Context> {
563 type Item = <Context::Node as AStarNode>::EdgeType;
564
565 fn next(&mut self) -> Option<Self::Item> {
566 let current = self.a_star.closed_list.get(&self.current).unwrap();
567
568 if let Some(predecessor) = current.predecessor().cloned() {
569 let predecessor_edge_type = current.predecessor_edge_type().unwrap();
570 self.current = predecessor;
571 Some(predecessor_edge_type)
572 } else {
573 None
574 }
575 }
576}
577
578impl<Context: AStarContext> Iterator for BacktrackingIteratorWithCost<'_, Context> {
579 type Item = (
580 <Context::Node as AStarNode>::EdgeType,
581 <Context::Node as AStarNode>::Cost,
582 );
583
584 fn next(&mut self) -> Option<Self::Item> {
585 let current = self.a_star.closed_list.get(&self.current).unwrap();
586 let cost = current.cost();
587
588 if let Some(predecessor) = current.predecessor().cloned() {
589 let predecessor_edge_type = current.predecessor_edge_type().unwrap();
590 self.current = predecessor;
591 Some((predecessor_edge_type, cost))
592 } else {
593 None
594 }
595 }
596}
597
598impl<NodeIdentifier, Node: AStarNode> Default for AStarBuffers<NodeIdentifier, Node> {
599 fn default() -> Self {
600 Self {
601 closed_list: Default::default(),
602 open_list: BinaryHeap::from_vec(Vec::new()),
603 }
604 }
605}
606
607impl<NodeIdentifier, Cost: Display> Display for AStarResult<NodeIdentifier, Cost> {
608 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
609 match self {
610 AStarResult::FoundTarget { cost, .. } => write!(f, "Reached target with cost {cost}"),
611 AStarResult::ExceededCostLimit { cost_limit } => {
612 write!(f, "Exceeded cost limit of {cost_limit}")
613 }
614 AStarResult::ExceededMemoryLimit { max_cost } => write!(
615 f,
616 "Exceeded memory limit, but reached a maximum cost of {max_cost}"
617 ),
618 AStarResult::NoTarget => write!(f, "Found no target"),
619 }
620 }
621}
622
623impl<NodeIdentifier, Cost> Default for AStarResult<NodeIdentifier, Cost> {
624 fn default() -> Self {
625 Self::NoTarget
626 }
627}
628
629impl<T: AStarNode> AStarNode for Box<T> {
630 type Identifier = <T as AStarNode>::Identifier;
631
632 type EdgeType = <T as AStarNode>::EdgeType;
633
634 type Cost = <T as AStarNode>::Cost;
635
636 fn identifier(&self) -> &Self::Identifier {
637 <T as AStarNode>::identifier(self)
638 }
639
640 fn cost(&self) -> Self::Cost {
641 <T as AStarNode>::cost(self)
642 }
643
644 fn a_star_lower_bound(&self) -> Self::Cost {
645 <T as AStarNode>::a_star_lower_bound(self)
646 }
647
648 fn secondary_maximisable_score(&self) -> usize {
649 <T as AStarNode>::secondary_maximisable_score(self)
650 }
651
652 fn predecessor(&self) -> Option<&Self::Identifier> {
653 <T as AStarNode>::predecessor(self)
654 }
655
656 fn predecessor_edge_type(&self) -> Option<Self::EdgeType> {
657 <T as AStarNode>::predecessor_edge_type(self)
658 }
659}