generic_a_star/
lib.rs

1#![forbid(clippy::mod_module_files)]
2
3use std::{
4    cmp::Ordering,
5    collections::HashMap,
6    fmt::{Debug, Display},
7    hash::Hash,
8};
9
10use binary_heap_plus::BinaryHeap;
11use comparator::AStarNodeComparator;
12use compare::Compare;
13use cost::AStarCost;
14use deterministic_default_hasher::DeterministicDefaultHasher;
15use extend_map::ExtendFilter;
16use num_traits::{Bounded, Zero};
17use reset::Reset;
18
19mod comparator;
20pub mod cost;
21pub mod reset;
22
/// A node of the A* graph.
///
/// The node must implement [`Ord`], ordering it by its cost plus A* cost, ascending.
/// The graph defined by the node type must be cycle-free.
pub trait AStarNode: Sized + Ord + Debug + Display {
    /// A unique identifier of the node.
    ///
    /// For example, in case of traditional edit distance, this would be the tuple (i, j) indicating which alignment matrix cell this node belongs to.
    type Identifier: Debug + Clone + Eq + Hash;

    /// The type collecting possible edge types.
    ///
    /// These are used when backtracking a solution.
    type EdgeType: Debug;

    /// The type used for both the cost of a node and its A* lower bound.
    type Cost: AStarCost;

    /// Returns the identifier of this node.
    fn identifier(&self) -> &Self::Identifier;

    /// Returns the cost of this node.
    ///
    /// This is the cost measured from the root node, and does NOT include the A* lower bound.
    fn cost(&self) -> Self::Cost;

    /// Returns the A* lower bound of this node.
    fn a_star_lower_bound(&self) -> Self::Cost;

    /// Returns a score that is used to order nodes of the same cost.
    ///
    /// This score should be maximised, which is done via complete search.
    fn secondary_maximisable_score(&self) -> usize;

    /// Returns the identifier of the predecessor of this node, or `None` if this is a root node.
    fn predecessor(&self) -> Option<&Self::Identifier>;

    /// Returns the edge type used to reach this node from the predecessor, or `None` if this is a root node.
    fn predecessor_edge_type(&self) -> Option<Self::EdgeType>;
}
61
/// The problem definition an A* search runs on.
///
/// The context creates the root node, generates successors, recognises target nodes
/// and supplies the optional cost and memory limits of the search.
pub trait AStarContext: Reset {
    /// The node type used by the A* algorithm.
    type Node: AStarNode;

    /// Create the root node of the A* graph.
    fn create_root(&self) -> Self::Node;

    /// Generate the successors of this node.
    ///
    /// The successors are written into `output`.
    fn generate_successors(&mut self, node: &Self::Node, output: &mut impl Extend<Self::Node>);

    /// Returns true if this node is a target node of the A* graph.
    fn is_target(&self, node: &Self::Node) -> bool;

    /// Returns the maximum cost that the target node is allowed to have.
    ///
    /// If no target is found with this cost or lower, then [`AStarResult::ExceededCostLimit`] is returned.
    fn cost_limit(&self) -> Option<<Self::Node as AStarNode>::Cost>;

    /// An approximate memory limit for the aligner in bytes.
    ///
    /// If it is exceeded, then [`AStarResult::ExceededMemoryLimit`] is returned.
    fn memory_limit(&self) -> Option<usize>;

    /// Returns true if the nodes are generated in a label-setting manner.
    ///
    /// Label setting means that once a node has been closed, it will never be opened at a smaller cost.
    /// On the contrary, if this is set to false, then nodes are allowed to be closed multiple times.
    /// This results in a worse performance, but allows for example to handle negative costs.
    ///
    /// This method returns `true` in its default implementation.
    fn is_label_setting(&self) -> bool {
        true
    }
}
96
/// Counters collected while an [`AStar`] instance is searching.
#[derive(Debug, Default)]
pub struct AStarPerformanceCounters {
    /// Number of nodes that were added to the open list as successors of a visited node.
    pub opened_nodes: usize,
    /// Opened nodes that do not have optimal costs.
    pub suboptimal_opened_nodes: usize,
    /// Number of insertions into the closed list.
    pub closed_nodes: usize,
}
104
/// The life-cycle state of an [`AStar`] instance.
#[derive(Debug, PartialEq, Eq)]
pub enum AStarState<NodeIdentifier, Cost> {
    /// The algorithm was just created or reset.
    Empty,
    /// The algorithm was just initialised.
    Init,
    /// The algorithm is searching for a target node.
    Searching,
    /// The algorithm terminated.
    Terminated {
        /// The result the search terminated with.
        result: AStarResult<NodeIdentifier, Cost>,
    },
}
118
/// An A* search over the graph defined by a [`AStarContext`].
#[derive(Debug)]
pub struct AStar<Context: AStarContext> {
    /// The current life-cycle state, including the last result once terminated.
    state: AStarState<<Context::Node as AStarNode>::Identifier, <Context::Node as AStarNode>::Cost>,
    /// The problem definition that is searched.
    context: Context,
    /// The nodes that have been closed, keyed by their identifier.
    closed_list: HashMap<
        <Context::Node as AStarNode>::Identifier,
        Context::Node,
        DeterministicDefaultHasher,
    >,
    /// The nodes that are waiting to be visited.
    open_list: BinaryHeap<Context::Node, AStarNodeComparator>,
    /// Statistics about the search performed so far.
    performance_counters: AStarPerformanceCounters,
}
131
/// The open and closed list of an [`AStar`] instance, detached from it.
///
/// Obtained via [`AStar::into_buffers`] and consumed by [`AStar::new_with_buffers`].
#[derive(Debug)]
pub struct AStarBuffers<NodeIdentifier, Node> {
    /// Storage for the closed list.
    closed_list: HashMap<NodeIdentifier, Node, DeterministicDefaultHasher>,
    /// Storage for the open list.
    open_list: BinaryHeap<Node, AStarNodeComparator>,
}
137
/// The outcome of an A* search.
#[derive(Debug, Clone, Ord, PartialOrd, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(tag = "astar_result_type"))]
pub enum AStarResult<NodeIdentifier, Cost> {
    /// The algorithm has found a target node.
    FoundTarget {
        /// The identifier of the target node. Skipped by serde.
        #[cfg_attr(feature = "serde", serde(skip))]
        identifier: NodeIdentifier,
        /// The cost of the target node.
        cost: Cost,
    },

    /// The algorithm terminated before finding a target because the cost limit was reached.
    ExceededCostLimit { cost_limit: Cost },

    /// The algorithm terminated before finding a target because the memory limit was reached.
    ExceededMemoryLimit {
        /// The maximum cost reached before reaching the memory limit.
        max_cost: Cost,
    },

    /// The algorithm terminated, but did not find a target.
    NoTarget,
}
161
/// Iterator over the edge types on the path from a closed node back to a root node.
struct BacktrackingIterator<'a_star, Context: AStarContext> {
    /// The search whose closed list is walked.
    a_star: &'a_star AStar<Context>,
    /// The identifier of the node the iterator currently stands on.
    current: <Context::Node as AStarNode>::Identifier,
}
166
/// Iterator over pairs of edge type and node cost on the path from a closed node back to a root node.
struct BacktrackingIteratorWithCost<'a_star, Context: AStarContext> {
    /// The search whose closed list is walked.
    a_star: &'a_star AStar<Context>,
    /// The identifier of the node the iterator currently stands on.
    current: <Context::Node as AStarNode>::Identifier,
}
171
172impl<Context: AStarContext> AStar<Context> {
173    pub fn new(context: Context) -> Self {
174        Self {
175            state: AStarState::Empty,
176            context,
177            closed_list: Default::default(),
178            open_list: BinaryHeap::from_vec(Vec::new()),
179            performance_counters: Default::default(),
180        }
181    }
182
183    pub fn new_with_buffers(
184        context: Context,
185        mut buffers: AStarBuffers<<Context::Node as AStarNode>::Identifier, Context::Node>,
186    ) -> Self {
187        buffers.closed_list.clear();
188        buffers.open_list.clear();
189        Self {
190            state: AStarState::Empty,
191            context,
192            closed_list: buffers.closed_list,
193            open_list: buffers.open_list,
194            performance_counters: Default::default(),
195        }
196    }
197
    /// Returns the current state of the algorithm.
    pub fn state(
        &self,
    ) -> &AStarState<<Context::Node as AStarNode>::Identifier, <Context::Node as AStarNode>::Cost>
    {
        &self.state
    }
204
    /// Returns a reference to the context this instance searches on.
    pub fn context(&self) -> &Context {
        &self.context
    }
208
    /// Consumes this instance and returns its context.
    pub fn into_context(self) -> Context {
        self.context
    }
212
213    pub fn into_buffers(
214        self,
215    ) -> AStarBuffers<<Context::Node as AStarNode>::Identifier, Context::Node> {
216        AStarBuffers {
217            closed_list: self.closed_list,
218            open_list: self.open_list,
219        }
220    }
221
    /// Returns the node stored in the closed list under the given identifier, if there is one.
    pub fn closed_node(
        &self,
        node_identifier: &<Context::Node as AStarNode>::Identifier,
    ) -> Option<&Context::Node> {
        self.closed_list.get(node_identifier)
    }
228
    /// Returns the performance counters collected so far.
    pub fn performance_counters(&self) -> &AStarPerformanceCounters {
        &self.performance_counters
    }
232
233    pub fn reset(&mut self) {
234        self.state = AStarState::Empty;
235        self.context.reset();
236        self.closed_list.clear();
237        self.open_list.clear();
238        self.performance_counters = Default::default();
239    }
240
241    pub fn initialise(&mut self) {
242        self.initialise_with(|context| context.create_root());
243    }
244
245    pub fn initialise_with(&mut self, node: impl FnOnce(&Context) -> Context::Node) {
246        assert_eq!(self.state, AStarState::Empty);
247
248        self.state = AStarState::Init;
249        self.open_list.push(node(&self.context));
250    }
251
252    pub fn search(
253        &mut self,
254    ) -> AStarResult<<Context::Node as AStarNode>::Identifier, <Context::Node as AStarNode>::Cost>
255    {
256        self.search_until(|context, node| context.is_target(node))
257    }
258
259    pub fn search_until(
260        &mut self,
261        mut is_target: impl FnMut(&Context, &Context::Node) -> bool,
262    ) -> AStarResult<<Context::Node as AStarNode>::Identifier, <Context::Node as AStarNode>::Cost>
263    {
264        assert!(matches!(
265            self.state,
266            AStarState::Init | AStarState::Searching | AStarState::Terminated { .. }
267        ));
268
269        let cost_limit = self
270            .context
271            .cost_limit()
272            .unwrap_or(<Context::Node as AStarNode>::Cost::max_value());
273        let mut applied_cost_limit = false;
274        let memory_limit = self.context.memory_limit().unwrap_or(usize::MAX);
275        // The factor of 2.3 is determined empirically.
276        let node_count_limit =
277            (memory_limit as f64 / std::mem::size_of::<Context::Node>() as f64 / 2.3).round()
278                as usize;
279
280        if self.open_list.is_empty() {
281            return AStarResult::NoTarget;
282        }
283
284        self.state = AStarState::Searching;
285
286        let mut last_node = None;
287        let mut target_identifier = None;
288        let mut target_cost = <Context::Node as AStarNode>::Cost::max_value();
289        let mut target_secondary_maximisable_score = 0;
290
291        loop {
292            let Some(node) = self.open_list.pop() else {
293                if last_node.is_none() {
294                    unreachable!("Open list was empty.");
295                };
296                if applied_cost_limit {
297                    self.state = AStarState::Terminated {
298                        result: AStarResult::ExceededCostLimit { cost_limit },
299                    };
300                    return AStarResult::ExceededCostLimit { cost_limit };
301                } else {
302                    self.state = AStarState::Terminated {
303                        result: AStarResult::NoTarget,
304                    };
305                    return AStarResult::NoTarget;
306                }
307            };
308
309            // Check cost limit.
310            // Nodes are ordered by cost plus lower bound.
311            if node.cost() + node.a_star_lower_bound() > cost_limit {
312                self.state = AStarState::Terminated {
313                    result: AStarResult::ExceededCostLimit { cost_limit },
314                };
315                return AStarResult::ExceededCostLimit { cost_limit };
316            }
317
318            // Check memory limit.
319            if self.closed_list.len() + self.open_list.len() > node_count_limit {
320                self.state = AStarState::Terminated {
321                    result: AStarResult::ExceededMemoryLimit {
322                        max_cost: node.cost(),
323                    },
324                };
325                return AStarResult::ExceededMemoryLimit {
326                    max_cost: node.cost(),
327                };
328            }
329
330            // If label-correcting, abort when the first node more expensive than the cheapest target is visited.
331            if node.cost() + node.a_star_lower_bound() > target_cost {
332                debug_assert!(!self.context.is_label_setting());
333                break;
334            }
335
336            last_node = Some(node.identifier().clone());
337
338            if let Some(previous_visit) = self.closed_list.get(node.identifier()) {
339                self.performance_counters.suboptimal_opened_nodes += 1;
340
341                if self.context.is_label_setting() {
342                    // In label-setting mode, if we have already visited the node, we now must be visiting it with a higher or equal cost.
343                    debug_assert!(
344                        previous_visit.cost() + previous_visit.a_star_lower_bound()
345                            <= node.cost() + node.a_star_lower_bound(),
346                        "{}",
347                        {
348                            use std::fmt::Write;
349                            let mut previous_visit = previous_visit;
350                            let mut node = &node;
351                            let mut out = String::new();
352
353                            writeln!(out, "previous_visit:").unwrap();
354                            while let Some(predecessor) = previous_visit.predecessor() {
355                                writeln!(out, "{previous_visit}").unwrap();
356                                previous_visit = self.closed_list.get(predecessor).unwrap();
357                            }
358
359                            writeln!(out, "\nnode:").unwrap();
360                            while let Some(predecessor) = node.predecessor() {
361                                writeln!(out, "{node}").unwrap();
362                                node = self.closed_list.get(predecessor).unwrap();
363                            }
364
365                            out
366                        }
367                    );
368
369                    continue;
370                } else if AStarNodeComparator.compare(&node, previous_visit) != Ordering::Greater {
371                    // If we are label-correcting, we may still find a better node later on.
372                    // Skip if equal or worse.
373                    continue;
374                }
375            }
376
377            let open_nodes_without_new_successors = self.open_list.len();
378            self.context.generate_successors(
379                &node,
380                &mut ExtendFilter::new(&mut self.open_list, |node| {
381                    let result = node.cost() + node.a_star_lower_bound() <= cost_limit;
382                    applied_cost_limit = applied_cost_limit || !result;
383                    result
384                }),
385            );
386            self.performance_counters.opened_nodes +=
387                self.open_list.len() - open_nodes_without_new_successors;
388
389            let is_target = is_target(&self.context, &node);
390            debug_assert!(!is_target || node.a_star_lower_bound().is_zero());
391
392            if is_target
393                && (node.cost() < target_cost
394                    || (node.cost() == target_cost
395                        && node.secondary_maximisable_score() > target_secondary_maximisable_score))
396            {
397                target_identifier = Some(node.identifier().clone());
398                target_cost = node.cost();
399                target_secondary_maximisable_score = node.secondary_maximisable_score();
400
401                if self.context.is_label_setting() {
402                    let previous_visit = self.closed_list.insert(node.identifier().clone(), node);
403                    self.performance_counters.closed_nodes += 1;
404                    debug_assert!(previous_visit.is_none() || !self.context.is_label_setting());
405                    break;
406                }
407            }
408
409            let previous_visit = self.closed_list.insert(node.identifier().clone(), node);
410            self.performance_counters.closed_nodes += 1;
411            debug_assert!(previous_visit.is_none() || !self.context.is_label_setting());
412        }
413
414        let Some(target_identifier) = target_identifier else {
415            debug_assert!(!self.context.is_label_setting());
416            self.state = AStarState::Terminated {
417                result: AStarResult::NoTarget,
418            };
419            return AStarResult::NoTarget;
420        };
421
422        let cost = self.closed_list.get(&target_identifier).unwrap().cost();
423        debug_assert_eq!(cost, target_cost);
424        self.state = AStarState::Terminated {
425            result: AStarResult::FoundTarget {
426                identifier: target_identifier.clone(),
427                cost,
428            },
429        };
430        AStarResult::FoundTarget {
431            identifier: target_identifier,
432            cost,
433        }
434    }
435
436    pub fn backtrack(
437        &self,
438    ) -> impl use<'_, Context> + Iterator<Item = <Context::Node as AStarNode>::EdgeType> {
439        let AStarState::Terminated {
440            result: AStarResult::FoundTarget { identifier, .. },
441        } = &self.state
442        else {
443            panic!("Cannot backtrack since no target was found.")
444        };
445
446        self.backtrack_from(identifier).unwrap()
447    }
448
449    /// Backtrack from the target node to a root node.
450    ///
451    /// The elements of the iterator are a pair of an edge and the cost of the node that is reached by the edge.
452    /// The cost of the first node is never returned.
453    pub fn backtrack_with_costs(
454        &self,
455    ) -> impl use<'_, Context>
456    + Iterator<
457        Item = (
458            <Context::Node as AStarNode>::EdgeType,
459            <Context::Node as AStarNode>::Cost,
460        ),
461    > {
462        let AStarState::Terminated {
463            result: AStarResult::FoundTarget { identifier, .. },
464        } = &self.state
465        else {
466            panic!("Cannot backtrack since no target was found.")
467        };
468
469        self.backtrack_with_costs_from(identifier).unwrap()
470    }
471
472    pub fn backtrack_from(
473        &self,
474        identifier: &<Context::Node as AStarNode>::Identifier,
475    ) -> Option<impl use<'_, Context> + Iterator<Item = <Context::Node as AStarNode>::EdgeType>>
476    {
477        if self.closed_list.contains_key(identifier) {
478            Some(BacktrackingIterator {
479                a_star: self,
480                current: identifier.clone(),
481            })
482        } else {
483            None
484        }
485    }
486
487    /// Backtrack from a node to a root node.
488    ///
489    /// The elements of the iterator are a pair of an edge and the cost of the node that is reached by the edge.
490    /// The cost of the first node is never returned.
491    #[allow(clippy::type_complexity)]
492    pub fn backtrack_with_costs_from(
493        &self,
494        identifier: &<Context::Node as AStarNode>::Identifier,
495    ) -> Option<
496        impl use<'_, Context>
497        + Iterator<
498            Item = (
499                <Context::Node as AStarNode>::EdgeType,
500                <Context::Node as AStarNode>::Cost,
501            ),
502        >,
503    > {
504        if self.closed_list.contains_key(identifier) {
505            Some(BacktrackingIteratorWithCost {
506                a_star: self,
507                current: identifier.clone(),
508            })
509        } else {
510            None
511        }
512    }
513}
514
515impl<NodeIdentifier, Cost: Copy> AStarResult<NodeIdentifier, Cost> {
516    /// Returns the maximum cost of closed nodes reached during alignment.
517    ///
518    /// **Panics** if `self` is [`AStarResult::NoTarget`].
519    pub fn cost(&self) -> Cost {
520        match self {
521            Self::FoundTarget { cost, .. } => *cost,
522            Self::ExceededCostLimit { cost_limit } => *cost_limit,
523            Self::ExceededMemoryLimit { max_cost } => *max_cost,
524            Self::NoTarget => panic!("AStarResult has no costs"),
525        }
526    }
527
528    pub fn without_node_identifier(&self) -> AStarResult<(), Cost> {
529        match *self {
530            Self::FoundTarget { cost, .. } => AStarResult::FoundTarget {
531                identifier: (),
532                cost,
533            },
534            Self::ExceededCostLimit { cost_limit } => AStarResult::ExceededCostLimit { cost_limit },
535            Self::ExceededMemoryLimit { max_cost } => AStarResult::ExceededMemoryLimit { max_cost },
536            Self::NoTarget => AStarResult::NoTarget,
537        }
538    }
539}
540
541impl<NodeIdentifier: Clone, Cost> AStarResult<NodeIdentifier, Cost> {
542    pub fn transform_cost<TargetCost>(
543        &self,
544        transform: impl Fn(&Cost) -> TargetCost,
545    ) -> AStarResult<NodeIdentifier, TargetCost> {
546        match self {
547            AStarResult::FoundTarget { identifier, cost } => AStarResult::FoundTarget {
548                identifier: identifier.clone(),
549                cost: transform(cost),
550            },
551            AStarResult::ExceededCostLimit { cost_limit } => AStarResult::ExceededCostLimit {
552                cost_limit: transform(cost_limit),
553            },
554            AStarResult::ExceededMemoryLimit { max_cost } => AStarResult::ExceededMemoryLimit {
555                max_cost: transform(max_cost),
556            },
557            AStarResult::NoTarget => AStarResult::NoTarget,
558        }
559    }
560}
561
562impl<Context: AStarContext> Iterator for BacktrackingIterator<'_, Context> {
563    type Item = <Context::Node as AStarNode>::EdgeType;
564
565    fn next(&mut self) -> Option<Self::Item> {
566        let current = self.a_star.closed_list.get(&self.current).unwrap();
567
568        if let Some(predecessor) = current.predecessor().cloned() {
569            let predecessor_edge_type = current.predecessor_edge_type().unwrap();
570            self.current = predecessor;
571            Some(predecessor_edge_type)
572        } else {
573            None
574        }
575    }
576}
577
578impl<Context: AStarContext> Iterator for BacktrackingIteratorWithCost<'_, Context> {
579    type Item = (
580        <Context::Node as AStarNode>::EdgeType,
581        <Context::Node as AStarNode>::Cost,
582    );
583
584    fn next(&mut self) -> Option<Self::Item> {
585        let current = self.a_star.closed_list.get(&self.current).unwrap();
586        let cost = current.cost();
587
588        if let Some(predecessor) = current.predecessor().cloned() {
589            let predecessor_edge_type = current.predecessor_edge_type().unwrap();
590            self.current = predecessor;
591            Some((predecessor_edge_type, cost))
592        } else {
593            None
594        }
595    }
596}
597
598impl<NodeIdentifier, Node: AStarNode> Default for AStarBuffers<NodeIdentifier, Node> {
599    fn default() -> Self {
600        Self {
601            closed_list: Default::default(),
602            open_list: BinaryHeap::from_vec(Vec::new()),
603        }
604    }
605}
606
607impl<NodeIdentifier, Cost: Display> Display for AStarResult<NodeIdentifier, Cost> {
608    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
609        match self {
610            AStarResult::FoundTarget { cost, .. } => write!(f, "Reached target with cost {cost}"),
611            AStarResult::ExceededCostLimit { cost_limit } => {
612                write!(f, "Exceeded cost limit of {cost_limit}")
613            }
614            AStarResult::ExceededMemoryLimit { max_cost } => write!(
615                f,
616                "Exceeded memory limit, but reached a maximum cost of {max_cost}"
617            ),
618            AStarResult::NoTarget => write!(f, "Found no target"),
619        }
620    }
621}
622
623impl<NodeIdentifier, Cost> Default for AStarResult<NodeIdentifier, Cost> {
624    fn default() -> Self {
625        Self::NoTarget
626    }
627}
628
629impl<T: AStarNode> AStarNode for Box<T> {
630    type Identifier = <T as AStarNode>::Identifier;
631
632    type EdgeType = <T as AStarNode>::EdgeType;
633
634    type Cost = <T as AStarNode>::Cost;
635
636    fn identifier(&self) -> &Self::Identifier {
637        <T as AStarNode>::identifier(self)
638    }
639
640    fn cost(&self) -> Self::Cost {
641        <T as AStarNode>::cost(self)
642    }
643
644    fn a_star_lower_bound(&self) -> Self::Cost {
645        <T as AStarNode>::a_star_lower_bound(self)
646    }
647
648    fn secondary_maximisable_score(&self) -> usize {
649        <T as AStarNode>::secondary_maximisable_score(self)
650    }
651
652    fn predecessor(&self) -> Option<&Self::Identifier> {
653        <T as AStarNode>::predecessor(self)
654    }
655
656    fn predecessor_edge_type(&self) -> Option<Self::EdgeType> {
657        <T as AStarNode>::predecessor_edge_type(self)
658    }
659}