// oxicuda_graphalg/shortest_path/a_star.rs
use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use crate::error::{GraphalgError, GraphalgResult};
7use crate::repr::weighted_graph::WeightedGraph;
8
/// Result of a successful A* search.
#[derive(Debug, Clone, PartialEq)]
pub struct AStarOutput {
    /// Total edge weight along `path`.
    pub dist: f64,
    /// Node sequence from the source to the target, inclusive of both
    /// endpoints (a single node when source and target coincide).
    pub path: Vec<usize>,
}
14
/// Entry of the A* open set.
///
/// `f` is the priority (`g + heuristic(node)` at push time) and `g` is the
/// path cost from the source to `node` at push time.
///
/// The ordering is reversed on `f` so that `BinaryHeap` (a max-heap) pops
/// the entry with the smallest `f` first; ties on `f` are broken so that the
/// smaller node index pops first. `g` takes no part in ordering or equality.
#[derive(Debug, Clone, Copy)]
struct Item {
    f: f64,
    g: f64,
    node: usize,
}

// Equality is defined through `cmp` so that `Eq` and `Ord` always agree
// (a derived `PartialEq` would also compare `g`, and is not reflexive for NaN).
impl PartialEq for Item {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Item {}

impl PartialOrd for Item {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Item {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` yields a genuine total order even when a heuristic
        // produces NaN; treating NaN as "equal" to everything would break
        // transitivity, which `BinaryHeap` relies on.
        other
            .f
            .total_cmp(&self.f)
            .then_with(|| other.node.cmp(&self.node))
    }
}
36
37pub fn a_star<F>(
39 graph: &WeightedGraph,
40 source: usize,
41 target: usize,
42 heuristic: F,
43) -> GraphalgResult<AStarOutput>
44where
45 F: Fn(usize) -> f64,
46{
47 if source >= graph.n || target >= graph.n {
48 return Err(GraphalgError::SourceOutOfRange {
49 node: source.max(target),
50 n: graph.n,
51 });
52 }
53 let mut gscore = vec![f64::INFINITY; graph.n];
54 let mut parent = vec![usize::MAX; graph.n];
55 gscore[source] = 0.0;
56 let mut heap = BinaryHeap::new();
57 heap.push(Item {
58 f: heuristic(source),
59 g: 0.0,
60 node: source,
61 });
62 while let Some(Item {
63 f: _,
64 g: gu,
65 node: u,
66 }) = heap.pop()
67 {
68 if u == target {
69 let mut path = Vec::new();
71 let mut cur = target;
72 while cur != source {
73 path.push(cur);
74 let p = parent[cur];
75 if p == usize::MAX {
76 return Err(GraphalgError::NumericalInstability(
77 "broken parent in A*".to_string(),
78 ));
79 }
80 cur = p;
81 }
82 path.push(source);
83 path.reverse();
84 return Ok(AStarOutput { dist: gu, path });
85 }
86 if gu > gscore[u] {
87 continue;
88 }
89 for &(v, w) in graph.neighbors(u)? {
90 if w < 0.0 {
91 return Err(GraphalgError::NegativeWeight {
92 edge: (u, v),
93 weight: w,
94 });
95 }
96 let cand = gu + w;
97 if cand < gscore[v] {
98 gscore[v] = cand;
99 parent[v] = u;
100 let fv = cand + heuristic(v);
101 heap.push(Item {
102 f: fv,
103 g: cand,
104 node: v,
105 });
106 }
107 }
108 }
109 Err(GraphalgError::NoSolution(format!(
110 "A* failed to reach target {target}"
111 )))
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Small graph whose unique shortest 0 -> 3 route is 0-1-2-3 (cost 4);
    /// the alternatives 0-2-3 (cost 5) and 0-1-3 (cost 6) are longer.
    fn graph_4() -> WeightedGraph {
        let mut g = WeightedGraph::new(4);
        g.add_edge(0, 1, 1.0).expect("ok");
        g.add_edge(0, 2, 4.0).expect("ok");
        g.add_edge(1, 2, 2.0).expect("ok");
        g.add_edge(1, 3, 5.0).expect("ok");
        g.add_edge(2, 3, 1.0).expect("ok");
        g
    }

    /// Exact remaining distance to node 3 in `graph_4`, hence admissible.
    fn exact_to_3(v: usize) -> f64 {
        [4.0, 3.0, 1.0, 0.0][v]
    }

    #[test]
    fn a_star_zero_heuristic_eq_dijkstra() {
        let g = graph_4();
        let out = a_star(&g, 0, 3, |_| 0.0).expect("ok");
        assert!((out.dist - 4.0).abs() < 1e-12);
        assert_eq!(out.path, vec![0, 1, 2, 3]);
    }

    #[test]
    fn a_star_with_heuristic() {
        let g = graph_4();
        let out = a_star(&g, 0, 3, exact_to_3).expect("ok");
        assert!((out.dist - 4.0).abs() < 1e-12);
        assert_eq!(out.path, vec![0, 1, 2, 3]);
    }

    #[test]
    fn a_star_source_equals_target() {
        let g = graph_4();
        let out = a_star(&g, 2, 2, |_| 0.0).expect("ok");
        assert!(out.dist.abs() < 1e-12);
        assert_eq!(out.path, vec![2]);
    }

    #[test]
    fn a_star_rejects_out_of_range_nodes() {
        let g = graph_4();
        assert!(matches!(
            a_star(&g, 0, 4, |_| 0.0),
            Err(GraphalgError::SourceOutOfRange { node: 4, .. })
        ));
    }

    #[test]
    fn a_star_no_path() {
        let g = WeightedGraph::new(3);
        assert!(a_star(&g, 0, 2, |_| 0.0).is_err());
    }
}
148}