Skip to main content

oxicuda_graphalg/shortest_path/
a_star.rs

1//! A* search with admissible heuristic.
2
3use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use crate::error::{GraphalgError, GraphalgResult};
7use crate::repr::weighted_graph::WeightedGraph;
8
/// Result of a successful [`a_star`] search.
#[derive(Debug, Clone, PartialEq)]
pub struct AStarOutput {
    /// Total cost of `path`: the accumulated edge weight from `source` to `target`.
    pub dist: f64,
    /// Node indices along the path, starting at `source` and ending at `target`
    /// (both inclusive).
    pub path: Vec<usize>,
}
14
/// Priority-queue entry for the A* open set.
///
/// `BinaryHeap` is a max-heap, so the ordering is reversed: the entry with the
/// smallest `f` compares as the greatest and is popped first. Ties on `f` are
/// broken in favour of the smaller node index.
#[derive(Debug, Clone, Copy)]
struct Item {
    /// Priority key: `g + heuristic(node)`.
    f: f64,
    /// Path cost from the source to `node` at the time this entry was pushed.
    g: f64,
    /// Node this entry refers to.
    node: usize,
}

impl PartialEq for Item {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `Eq` and `Ord` agree. A derived impl would
        // also compare `g` (which `cmp` ignores) and would not be reflexive for NaN.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Item {}

impl PartialOrd for Item {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Item {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` is a genuine total order, NaN included, which `BinaryHeap`
        // relies on. `partial_cmp(..).unwrap_or(Equal)` is not transitive once
        // a NaN appears. Operands are swapped to turn the max-heap into a min-heap.
        other
            .f
            .total_cmp(&self.f)
            .then_with(|| other.node.cmp(&self.node))
    }
}
36
37/// A* shortest path. `heuristic(node)` must be admissible (never overestimates).
38pub fn a_star<F>(
39    graph: &WeightedGraph,
40    source: usize,
41    target: usize,
42    heuristic: F,
43) -> GraphalgResult<AStarOutput>
44where
45    F: Fn(usize) -> f64,
46{
47    if source >= graph.n || target >= graph.n {
48        return Err(GraphalgError::SourceOutOfRange {
49            node: source.max(target),
50            n: graph.n,
51        });
52    }
53    let mut gscore = vec![f64::INFINITY; graph.n];
54    let mut parent = vec![usize::MAX; graph.n];
55    gscore[source] = 0.0;
56    let mut heap = BinaryHeap::new();
57    heap.push(Item {
58        f: heuristic(source),
59        g: 0.0,
60        node: source,
61    });
62    while let Some(Item {
63        f: _,
64        g: gu,
65        node: u,
66    }) = heap.pop()
67    {
68        if u == target {
69            // Reconstruct
70            let mut path = Vec::new();
71            let mut cur = target;
72            while cur != source {
73                path.push(cur);
74                let p = parent[cur];
75                if p == usize::MAX {
76                    return Err(GraphalgError::NumericalInstability(
77                        "broken parent in A*".to_string(),
78                    ));
79                }
80                cur = p;
81            }
82            path.push(source);
83            path.reverse();
84            return Ok(AStarOutput { dist: gu, path });
85        }
86        if gu > gscore[u] {
87            continue;
88        }
89        for &(v, w) in graph.neighbors(u)? {
90            if w < 0.0 {
91                return Err(GraphalgError::NegativeWeight {
92                    edge: (u, v),
93                    weight: w,
94                });
95            }
96            let cand = gu + w;
97            if cand < gscore[v] {
98                gscore[v] = cand;
99                parent[v] = u;
100                let fv = cand + heuristic(v);
101                heap.push(Item {
102                    f: fv,
103                    g: cand,
104                    node: v,
105                });
106            }
107        }
108    }
109    Err(GraphalgError::NoSolution(format!(
110        "A* failed to reach target {target}"
111    )))
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Four-node test graph. The unique shortest 0 -> 3 route is 0-1-2-3 with
    /// cost 4 (alternatives: 0-2-3 costs 5, 0-1-3 costs 6).
    fn graph_4() -> WeightedGraph {
        let mut g = WeightedGraph::new(4);
        g.add_edge(0, 1, 1.0).expect("ok");
        g.add_edge(0, 2, 4.0).expect("ok");
        g.add_edge(1, 2, 2.0).expect("ok");
        g.add_edge(1, 3, 5.0).expect("ok");
        g.add_edge(2, 3, 1.0).expect("ok");
        g
    }

    #[test]
    fn a_star_zero_heuristic_eq_dijkstra() {
        let g = graph_4();
        let out = a_star(&g, 0, 3, |_| 0.0).expect("ok");
        assert!((out.dist - 4.0).abs() < 1e-12);
        assert_eq!(out.path, vec![0, 1, 2, 3]);
    }

    #[test]
    fn a_star_with_heuristic() {
        let g = graph_4();
        // Exact remaining distance to node 3 for every node: admissible and
        // maximally informed, so it genuinely steers the search.
        let h = [4.0, 3.0, 1.0, 0.0];
        let out = a_star(&g, 0, 3, |v| h[v]).expect("ok");
        assert!((out.dist - 4.0).abs() < 1e-12);
        assert_eq!(out.path, vec![0, 1, 2, 3]);
    }

    #[test]
    fn a_star_source_equals_target() {
        let g = graph_4();
        let out = a_star(&g, 2, 2, |_| 0.0).expect("ok");
        assert!(out.dist.abs() < 1e-12);
        assert_eq!(out.path, vec![2]);
    }

    #[test]
    fn a_star_out_of_range() {
        let g = graph_4();
        assert!(a_star(&g, 0, 4, |_| 0.0).is_err());
        assert!(a_star(&g, 4, 0, |_| 0.0).is_err());
    }

    #[test]
    fn a_star_no_path() {
        let g = WeightedGraph::new(3);
        assert!(a_star(&g, 0, 2, |_| 0.0).is_err());
    }
}