1use crate::DiGraph;
2
3use super::min_weight::{MinWeight, Score};
4use std::collections::hash_map::Entry::Occupied;
5use std::collections::hash_map::Entry::Vacant;
6use std::convert::identity;
7use std::{
8 collections::{BinaryHeap, HashMap},
9 hash::Hash,
10 ops::Add,
11};
12
/// Result of an A* search: the predecessor map discovered while searching from
/// `start` toward `target`.
///
/// Call [`MinPathStrict::path`] to reconstruct the node sequence.
#[derive(Debug)]
pub struct MinPathStrict<NId>
where
    NId: Eq + Hash + Clone,
{
    /// Maps a reached node to the node it was reached from on its best known route.
    path: HashMap<NId, NId>,
    /// Node the search started from.
    start: NId,
    /// Node the search was asked to reach.
    target: NId,
}

impl<NId> MinPathStrict<NId>
where
    NId: Eq + Hash + Clone,
{
    /// Reconstructs the route from `start` to `target`, both endpoints included.
    ///
    /// Returns `[start]` when `start == target`. Returns an empty vector when
    /// `target` cannot be traced back to `start` through the predecessor map,
    /// i.e. the target was never reached. A malformed map containing a cycle
    /// that does not pass through `start` also yields an empty vector instead
    /// of looping forever.
    pub fn path(&self) -> Vec<NId> {
        let mut route = vec![self.target.clone()];
        let mut step = &self.target;

        while *step != self.start {
            match self.path.get(step) {
                // A simple chain ending at `start` uses each map entry at most once,
                // so it can never hold more than `self.path.len() + 1` nodes; going
                // beyond that means the walk is stuck in a cycle.
                Some(prev) if route.len() <= self.path.len() => {
                    route.push(prev.clone());
                    step = prev;
                }
                // No predecessor recorded (unreachable target) or cycle detected.
                _ => return Vec::new(),
            }
        }

        route.reverse();
        route
    }
}
47
48#[derive(Debug)]
49pub struct AStarPath<'a, NId, NL, EL>
50where
51 NId: Eq + Hash + Clone,
52{
53 graph: &'a DiGraph<NId, NL, EL>,
54}
55
56impl<'a, NId, NL, EL> AStarPath<'a, NId, NL, EL>
57where
58 NId: Eq + Hash + Clone,
59{
60 pub fn on_edge_custom<H, E, ScoreV>(
61 &self,
62 start: NId,
63 target: NId,
64 heuristic: H,
65 edge_w: E,
66 ) -> MinPathStrict<NId>
67 where
68 H: Fn(&NId) -> ScoreV,
69 E: Fn(EL) -> ScoreV,
70 ScoreV: Ord + Add<Output = ScoreV> + Clone,
71 EL: Clone,
72 {
73 let mut traverse: BinaryHeap<MinWeight<NId, ScoreV>> = BinaryHeap::new();
74 let mut path: HashMap<NId, NId> = HashMap::new();
75 let mut scores: HashMap<&NId, Score<ScoreV>> =
76 HashMap::from_iter(self.graph.nodes.keys().map(|k| (k, Score::Inf)));
77 let mut est_scores: HashMap<&NId, Score<ScoreV>> = HashMap::new();
78
79 scores.insert(&start, Score::Zero);
80 traverse.push(MinWeight(&start, Score::Value(heuristic(&start))));
81
82 while let Some(MinWeight(current, curr_est_score)) = traverse.pop() {
83 if current == &target {
84 return MinPathStrict {
85 path,
86 start,
87 target,
88 };
89 }
90
91 match est_scores.entry(current) {
92 Occupied(mut entry) => {
93 if *entry.get() <= curr_est_score {
95 continue;
96 }
97 entry.insert(curr_est_score);
98 }
99 Vacant(entry) => {
100 entry.insert(curr_est_score);
101 }
102 }
103
104 if let Some(ss) = self.graph.edges.get(current) {
105 let current_score = scores.get(current).unwrap().clone();
106 for (to, el) in ss {
107 let next_score = scores.get(to).unwrap().clone();
108 let tentative_score = current_score.clone() + Score::Value(edge_w(el.clone()));
109 if tentative_score < next_score {
110 path.insert(to.clone(), current.clone());
111 scores.insert(to, tentative_score.clone());
112 traverse.push(MinWeight(
113 to,
114 tentative_score + Score::Value(heuristic(&to)),
115 ))
116 }
117 }
118 }
119 }
120
121 MinPathStrict {
122 path,
123 start,
124 target,
125 }
126 }
127}
128
129impl<'a, NId, NL, EL> AStarPath<'a, NId, NL, EL>
130where
131 NId: Eq + Hash + Clone,
132 EL: Ord + Add<Output = EL> + Clone,
133{
134 pub fn on_edge<H>(&self, start: NId, target: NId, heuristic: H) -> MinPathStrict<NId>
135 where
136 H: Fn(&NId) -> EL,
137 {
138 self.on_edge_custom(start, target, heuristic, identity)
139 }
140}
141
142impl<'a, NId, NL, EL> AStarPath<'a, NId, NL, EL>
143where
144 NId: Eq + Hash + Clone,
145{
146 pub fn new(graph: &'a DiGraph<NId, NL, EL>) -> Self {
147 Self { graph }
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::AStarPath;
    use crate::analyzer::dijkstra::DijkstraPath;
    use crate::analyzer::dijkstra::MinPathProcessor;
    use crate::DiGraph;
    use crate::EmptyPayload;
    use crate::{digraph, extend_edges, extend_nodes};
    use std::convert::identity;

    #[test]
    fn simple_test() {
        let graph = digraph!((_,_,usize) => [1,2,3,4,5,6,7,8,9,10,11,] => {
            1 => [(2,1),(3,1)];
            2 => (4,2);
            3 => (5,3);
            [4,5] => (6,1);
            5 => (11,4);
            6 => [(7,1),(1,1)];
            7 => [(8,1),(9,2),(10,3)];
            [8,9,10] => (11,1)

        });

        let astar = AStarPath::new(&graph);

        let astar_res = astar.on_edge(1, 11, |_| 0).path();
        let dijkstra_res = DijkstraPath::new(&graph).on_edge(1).trail(&11).unwrap();
        assert_eq!(astar_res, dijkstra_res);

        // `on_edge` is documented as `on_edge_custom` with an identity weight.
        let custom_res = astar.on_edge_custom(1, 11, |_| 0, identity).path();
        assert_eq!(custom_res, astar_res);
    }

    /// With a zero heuristic A* degenerates to Dijkstra; every target in this
    /// graph has a unique shortest path, so both must agree exactly.
    #[test]
    fn zero_heuristic_matches_dijkstra_for_every_target() {
        let graph = digraph!((_,_,usize) => [1,2,3,4,5,6,7,8,9,10,11,] => {
            1 => [(2,1),(3,1)];
            2 => (4,2);
            3 => (5,3);
            [4,5] => (6,1);
            5 => (11,4);
            6 => [(7,1),(1,1)];
            7 => [(8,1),(9,2),(10,3)];
            [8,9,10] => (11,1)

        });

        let astar = AStarPath::new(&graph);

        for target in 2..=11 {
            let astar_res = astar.on_edge(1, target, |_| 0).path();
            let dijkstra_res = DijkstraPath::new(&graph)
                .on_edge(1)
                .trail(&target)
                .unwrap();
            assert_eq!(astar_res, dijkstra_res, "paths differ for target {}", target);
        }
    }
}