// digraph_rs/analyzer/astar.rs

1use crate::DiGraph;
2
3use super::min_weight::{MinWeight, Score};
4use std::collections::hash_map::Entry::Occupied;
5use std::collections::hash_map::Entry::Vacant;
6use std::convert::identity;
7use std::{
8    collections::{BinaryHeap, HashMap},
9    hash::Hash,
10    ops::Add,
11};
12
/// Outcome of a single-pair shortest-path search.
///
/// `path` maps each node recorded by the search to the node it was reached from; `start` and
/// `target` are the endpoints that were requested. Use [`MinPathStrict::path`] to turn the
/// parent links into an ordered node sequence.
#[derive(Debug)]
pub struct MinPathStrict<NId>
where
    NId: Eq + Hash + Clone,
{
    path: HashMap<NId, NId>,
    start: NId,
    target: NId,
}

impl<NId> MinPathStrict<NId>
where
    NId: Eq + Hash + Clone,
{
    /// Returns the nodes from `start` to `target` (both inclusive), in travel order.
    ///
    /// The sequence is rebuilt by following parent links backwards from `target`.
    /// Returns `[start]` when `start == target`. Returns an empty vector when the links do not
    /// lead back to `start` (for example when `target` was never reached).
    pub fn path(&self) -> Vec<NId> {
        let mut trail = vec![self.target.clone()];
        let mut step = &self.target;

        while *step != self.start {
            match self.path.get(step) {
                // A well-formed chain uses each parent link at most once, so it needs at most
                // `self.path.len()` hops; the guard stops a malformed (cyclic) map from looping.
                Some(prev) if trail.len() <= self.path.len() => {
                    trail.push(prev.clone());
                    step = prev;
                }
                // No link back towards `start` (unreachable target) or the hop budget is spent.
                _ => return Vec::new(),
            }
        }

        trail.reverse();
        trail
    }
}
47
48#[derive(Debug)]
49pub struct AStarPath<'a, NId, NL, EL>
50where
51    NId: Eq + Hash + Clone,
52{
53    graph: &'a DiGraph<NId, NL, EL>,
54}
55
56impl<'a, NId, NL, EL> AStarPath<'a, NId, NL, EL>
57where
58    NId: Eq + Hash + Clone,
59{
60    pub fn on_edge_custom<H, E, ScoreV>(
61        &self,
62        start: NId,
63        target: NId,
64        heuristic: H,
65        edge_w: E,
66    ) -> MinPathStrict<NId>
67    where
68        H: Fn(&NId) -> ScoreV,
69        E: Fn(EL) -> ScoreV,
70        ScoreV: Ord + Add<Output = ScoreV> + Clone,
71        EL: Clone,
72    {
73        let mut traverse: BinaryHeap<MinWeight<NId, ScoreV>> = BinaryHeap::new();
74        let mut path: HashMap<NId, NId> = HashMap::new();
75        let mut scores: HashMap<&NId, Score<ScoreV>> =
76            HashMap::from_iter(self.graph.nodes.keys().map(|k| (k, Score::Inf)));
77        let mut est_scores: HashMap<&NId, Score<ScoreV>> = HashMap::new();
78
79        scores.insert(&start, Score::Zero);
80        traverse.push(MinWeight(&start, Score::Value(heuristic(&start))));
81
82        while let Some(MinWeight(current, curr_est_score)) = traverse.pop() {
83            if current == &target {
84                return MinPathStrict {
85                    path,
86                    start,
87                    target,
88                };
89            }
90
91            match est_scores.entry(current) {
92                Occupied(mut entry) => {
93                    // If the node has been visited with an equal or lower score, then skip.
94                    if *entry.get() <= curr_est_score {
95                        continue;
96                    }
97                    entry.insert(curr_est_score);
98                }
99                Vacant(entry) => {
100                    entry.insert(curr_est_score);
101                }
102            }
103
104            if let Some(ss) = self.graph.edges.get(current) {
105                let current_score = scores.get(current).unwrap().clone();
106                for (to, el) in ss {
107                    let next_score = scores.get(to).unwrap().clone();
108                    let tentative_score = current_score.clone() + Score::Value(edge_w(el.clone()));
109                    if tentative_score < next_score {
110                        path.insert(to.clone(), current.clone());
111                        scores.insert(to, tentative_score.clone());
112                        traverse.push(MinWeight(
113                            to,
114                            tentative_score + Score::Value(heuristic(&to)),
115                        ))
116                    }
117                }
118            }
119        }
120
121        MinPathStrict {
122            path,
123            start,
124            target,
125        }
126    }
127}
128
129impl<'a, NId, NL, EL> AStarPath<'a, NId, NL, EL>
130where
131    NId: Eq + Hash + Clone,
132    EL: Ord + Add<Output = EL> + Clone,
133{
134    pub fn on_edge<H>(&self, start: NId, target: NId, heuristic: H) -> MinPathStrict<NId>
135    where
136        H: Fn(&NId) -> EL,
137    {
138        self.on_edge_custom(start, target, heuristic, identity)
139    }
140}
141
142impl<'a, NId, NL, EL> AStarPath<'a, NId, NL, EL>
143where
144    NId: Eq + Hash + Clone,
145{
146    pub fn new(graph: &'a DiGraph<NId, NL, EL>) -> Self {
147        Self { graph }
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::AStarPath;
    use crate::analyzer::dijkstra::DijkstraPath;
    use crate::analyzer::dijkstra::MinPathProcessor;
    use crate::DiGraph;
    use crate::EmptyPayload;
    use crate::{digraph, extend_edges, extend_nodes};
    use std::convert::identity;

    /// With a zero heuristic, A* must find the same route from 1 to 11 as Dijkstra.
    #[test]
    fn simple_test() {
        let graph = digraph!((_,_,usize) => [1,2,3,4,5,6,7,8,9,10,11,] => {
           1 => [(2,1),(3,1)];
           2 => (4,2);
           3 => (5,3);
           [4,5] => (6,1);
           5 => (11,4);
           6 => [(7,1),(1,1)];
           7 => [(8,1),(9,2),(10,3)];
           [8,9,10] => (11,1)

        });

        let expected = DijkstraPath::new(&graph).on_edge(1).trail(&11).unwrap();
        let actual = AStarPath::new(&graph).on_edge(1, 11, |_| 0).path();

        assert_eq!(actual, expected);
    }
}