// a_star_traitbased/lib.rs

use std::cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd, Reverse};
use std::collections::{BinaryHeap, HashSet};
use std::rc::Rc;

/// Describes the graph that [`AStar::run`] searches.
///
/// Positions are `(usize, usize)` pairs. An implementor decides which positions
/// neighbour each other, what a single step costs, and how far a position is
/// estimated to be from the target.
pub trait PathGenerator {
    /// Returns every position reachable in one step from `from_position`.
    fn generate_paths(&self, from_position: (usize, usize)) -> Vec<(usize, usize)>;

    /// Returns the cost of stepping from `current_position` directly to `next_position`.
    ///
    /// The search adds these step costs up along a path.
    fn calculate_cost(
        &self,
        current_position: (usize, usize),
        next_position: (usize, usize),
    ) -> usize;

    /// Returns an estimate of the remaining cost from `position` to `target`.
    ///
    /// An axis of `target` that is `None` is unconstrained: any value on that axis
    /// counts as reaching the target. The search adds this estimate to the
    /// accumulated cost to decide which position to look at next.
    fn calculate_heuristic_cost(
        &self,
        position: (usize, usize),
        target: (Option<usize>, Option<usize>),
    ) -> usize;
}

/// Outcome of asking the search for its next node.
///
/// `Ok` carries the produced value; `Finished` carries nothing and tells the
/// caller to stop.
enum NextNodeResult<T> {
    /// A value was produced.
    Ok(T),
    /// Nothing was produced; the caller ends the search here.
    Finished,
}

23pub struct AStar {
24    target: (Option<usize>, Option<usize>),
25    que: Vec<Node>,
26    closed_nodes: Vec<Rc<Node>>,
27}
28
29impl AStar {
30    fn new(target: (Option<usize>, Option<usize>)) -> Self {
31        Self {
32            target,
33            que: Vec::new(),
34            closed_nodes: Vec::new(),
35        }
36    }
37
38    pub fn run<T: PathGenerator>(
39        from_struct: &T,
40        start: (usize, usize),
41        target: (Option<usize>, Option<usize>),
42    ) -> Option<Vec<(usize, usize)>> {
43        // PathGenerator is used to build possible paths
44        let mut inst = Self::new(target);
45        let exposed_struct = from_struct;
46        inst.que.push(Node::new(
47            start,
48            exposed_struct.calculate_heuristic_cost(start, target),
49        ));
50        loop {
51            if inst.que.is_empty() {
52                return None; // no elements left therefor no fast way out
53            }
54            inst.que.sort();
55            let top = Rc::new(inst.que.remove(0));
56            let possible_paths = exposed_struct.generate_paths(top.position);
57            if !possible_paths.is_empty() {
58                for possible_path in possible_paths {
59                    if inst.pull_from_closed_by_position(possible_path).is_some() {
60                        continue;
61                    }
62                    match inst.create_new_node(
63                        Rc::clone(&top),
64                        possible_path,
65                        exposed_struct.calculate_cost(top.position, possible_path),
66                        exposed_struct.calculate_heuristic_cost(possible_path, inst.target),
67                    ) {
68                        NextNodeResult::Ok(node) => inst.que.push(node),
69                        NextNodeResult::Finished => {
70                            let mut path = inst.reconstruct_path(Rc::clone(&top));
71                            path.insert(0, possible_path);
72                            return Some(path);
73                        }
74                    }
75                }
76            }
77            inst.closed_nodes.push(Rc::clone(&top));
78        }
79    }
80
81    fn create_new_node(
82        &self,
83        old_node: Rc<Node>,
84        new_position: (usize, usize),
85        cost: usize,
86        heuristic_cost: usize,
87    ) -> NextNodeResult<Node> {
88        if self.target_is_reached(&new_position) {
89            return NextNodeResult::Finished;
90        }
91        let new_cost = cost + old_node.cost;
92        NextNodeResult::Ok(Node {
93            position: new_position,
94            comes_from: Some(old_node),
95            cost: new_cost,
96            total_cost: heuristic_cost + new_cost,
97        })
98    }
99
100    fn target_is_reached(&self, position: &(usize, usize)) -> bool {
101        if self.target.0.is_some() && self.target.0.unwrap() != position.0 {
102            return false;
103        }
104        if self.target.1.is_some() && self.target.1.unwrap() != position.1 {
105            return false;
106        }
107        true
108    }
109
110    fn reconstruct_path(&self, opt: Rc<Node>) -> Vec<(usize, usize)> {
111        let mut fastest_path = vec![opt.position];
112        let mut comes_from = opt.comes_from.as_ref();
113        loop {
114            if let Some(node) = comes_from {
115                fastest_path.push(node.position);
116                comes_from = node.comes_from.as_ref();
117            } else {
118                return fastest_path;
119            }
120        }
121    }
122
123    fn pull_from_closed_by_position(&self, position: (usize, usize)) -> Option<&Rc<Node>> {
124        self.closed_nodes.iter().find(|closed_node| closed_node.position == position)
125    }
126}
127
/// A search node: a position plus the cost bookkeeping attached to it.
///
/// Equality and ordering look only at `total_cost`, so two nodes at different
/// positions compare equal whenever their total costs match.
#[derive(Eq, Debug)]
struct Node {
    /// Position this node stands for.
    position: (usize, usize),
    /// Accumulated step cost from the start node (0 for the start itself).
    cost: usize,
    /// Priority of the node: `cost` plus the heuristic estimate.
    total_cost: usize,
    /// Predecessor on the path; `None` for a start node.
    comes_from: Option<Rc<Node>>,
}

impl Node {
    /// Creates a start node: zero accumulated cost and no predecessor.
    fn new(position: (usize, usize), total_cost: usize) -> Self {
        let cost = 0;
        let comes_from = None;
        Self {
            position,
            cost,
            total_cost,
            comes_from,
        }
    }
}

impl PartialEq for Node {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Ord for Node {
    fn cmp(&self, other: &Self) -> Ordering {
        Ord::cmp(&self.total_cost, &other.total_cost)
    }
}

impl PartialOrd for Node {
    // The operators `<`, `<=`, `>` and `>=` use the default methods, which all
    // route through this comparison on `total_cost`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn testrun() {
        /// Absolute difference of two unsigned values, without underflow.
        fn calc_usize_diff(x: usize, y: usize) -> usize {
            if x > y {
                x - y
            } else {
                y - x
            }
        }

        /// Unbounded grid fixture in which every position except `blocks` is walkable.
        struct Map {
            blocks: Vec<(usize, usize)>,
        }

        impl Map {
            /// Returns `possible_path` unless it is one of the blocked positions.
            fn path_is_possible(&self, possible_path: (usize, usize)) -> Option<(usize, usize)> {
                if self.blocks.contains(&possible_path) {
                    None
                } else {
                    Some(possible_path)
                }
            }
        }

        impl PathGenerator for Map {
            /// Neighbours: one step along either axis, or one diagonal step where both
            /// coordinates move in the same direction.
            fn generate_paths(&self, from_position: (usize, usize)) -> Vec<(usize, usize)> {
                let (x, y) = from_position;
                // Each decreasing move is guarded on its own axis, so positions on the
                // x == 0 or y == 0 edge can still step along the other axis.
                let candidates = [
                    x.checked_sub(1).zip(y.checked_sub(1)),
                    y.checked_sub(1).map(|down| (x, down)),
                    x.checked_sub(1).map(|left| (left, y)),
                    Some((x + 1, y + 1)),
                    Some((x, y + 1)),
                    Some((x + 1, y)),
                ];
                candidates
                    .iter()
                    .flatten()
                    .filter_map(|&candidate| self.path_is_possible(candidate))
                    .collect()
            }

            /// Every move costs the same.
            fn calculate_cost(
                &self,
                _current_position: (usize, usize),
                _next_position: (usize, usize),
            ) -> usize {
                1
            }

            /// Chebyshev distance over the constrained axes (0 when neither is constrained).
            ///
            /// Each move changes every coordinate by at most one at cost 1, so the larger
            /// per-axis distance never overestimates the remaining cost. A Euclidean
            /// estimate would overestimate diagonal moves, and note that `^` in Rust is
            /// bitwise XOR, so it cannot be used for squaring.
            fn calculate_heuristic_cost(
                &self,
                position: (usize, usize),
                target: (Option<usize>, Option<usize>),
            ) -> usize {
                let dx = target.0.map_or(0, |tx| calc_usize_diff(tx, position.0));
                let dy = target.1.map_or(0, |ty| calc_usize_diff(ty, position.1));
                dx.max(dy)
            }
        }

        // (2, 2) blocks the straight diagonal, forcing a one-step detour.
        let map_fixture = Map {
            blocks: vec![(2, 2)],
        };
        let path = AStar::run(&map_fixture, (0, 0), (Some(3), Some(3)));
        // Paths are returned from the target back to the start.
        assert_eq!(path, Some(vec![(3, 3), (2, 3), (1, 2), (1, 1), (0, 0)]));
    }
}