bsp_pathfinding/astar/
mod.rs

1use core::slice;
2use std::{
3    collections::{BinaryHeap, HashSet},
4    ops::{Deref, DerefMut, RangeBounds},
5};
6
7use glam::Vec2;
8use slotmap::{secondary::Entry, Key, SecondaryMap};
9use smallvec::{Drain, SmallVec};
10
11use crate::{BSPTree, NodeIndex, Portal, PortalRef, Portals, TOLERANCE};
12
13#[derive(Debug, Clone, Copy, PartialEq)]
14pub struct WayPoint {
15    point: Vec2,
16    node: NodeIndex,
17    portal: Option<PortalRef>,
18}
19
20impl Deref for WayPoint {
21    type Target = Vec2;
22
23    fn deref(&self) -> &Self::Target {
24        &self.point
25    }
26}
27
28impl WayPoint {
29    pub fn new(point: Vec2, node: NodeIndex, portal: Option<PortalRef>) -> Self {
30        Self {
31            point,
32            node,
33            portal,
34        }
35    }
36
37    /// Get the way point's point.
38    pub fn point(&self) -> Vec2 {
39        self.point
40    }
41
42    /// Get the way point's portal.
43    pub fn portal(&self) -> Option<PortalRef> {
44        self.portal
45    }
46}
47
48#[derive(Debug, Clone, Default)]
49pub struct Path {
50    points: SmallVec<[WayPoint; 8]>,
51}
52
53impl<'a> IntoIterator for &'a Path {
54    type Item = &'a WayPoint;
55
56    type IntoIter = slice::Iter<'a, WayPoint>;
57
58    fn into_iter(self) -> Self::IntoIter {
59        self.points.iter()
60    }
61}
62
63impl Path {
64    pub fn new() -> Self {
65        Self::default()
66    }
67
68    pub fn from_points(points: impl Into<SmallVec<[WayPoint; 8]>>) -> Self {
69        Self {
70            points: points.into(),
71        }
72    }
73
74    /// Get a reference to the path's points.
75    pub fn points(&self) -> &[WayPoint] {
76        self.points.as_ref()
77    }
78
79    pub fn push(&mut self, value: WayPoint) {
80        self.points.push(value)
81    }
82
83    pub fn append(&mut self, other: &mut Self) {
84        self.points.append(&mut other.points)
85    }
86
87    /// Creates a path using the euclidian path
88    pub fn euclidian(start: Vec2, end: Vec2) -> Path {
89        Path::from_points(vec![
90            WayPoint::new(start, NodeIndex::null(), None),
91            WayPoint::new(end, NodeIndex::null(), None),
92        ])
93    }
94
95    pub fn clear(&mut self) {
96        self.points.clear()
97    }
98
99    pub fn drain<R: RangeBounds<usize>>(&mut self, range: R) -> Drain<'_, [WayPoint; 8]> {
100        self.points.drain(range)
101    }
102}
103
104impl Deref for Path {
105    type Target = [WayPoint];
106
107    fn deref(&self) -> &Self::Target {
108        &self.points
109    }
110}
111
112impl DerefMut for Path {
113    fn deref_mut(&mut self) -> &mut Self::Target {
114        &mut self.points
115    }
116}
117
118#[derive(Debug, Copy, Clone, PartialEq)]
119struct Backtrace<'a> {
120    // Index to the portal
121    node: NodeIndex,
122    // The first side of the portal
123    point: Vec2,
124    portal: Option<Portal<'a>>,
125    prev: Option<NodeIndex>,
126    start_cost: f32,
127    total_cost: f32,
128}
129
130impl<'a> Backtrace<'a> {
131    fn start(node: NodeIndex, point: Vec2, heuristic: f32) -> Self {
132        Self {
133            node,
134            point,
135            portal: None,
136            prev: None,
137            start_cost: 0.0,
138            total_cost: heuristic,
139        }
140    }
141
142    fn new(portal: Portal<'a>, point: Vec2, prev: &Backtrace, heuristic: f32) -> Self {
143        let start_cost = prev.start_cost + point.distance(prev.point);
144        Self {
145            node: portal.dst(),
146            portal: Some(portal),
147            point,
148            prev: Some(prev.node),
149            start_cost,
150            total_cost: start_cost + heuristic,
151        }
152    }
153}
154
155// Order by lowest total_cost
156impl<'a> PartialOrd for Backtrace<'a> {
157    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
158        other.total_cost.partial_cmp(&self.total_cost)
159    }
160}
161
162impl<'a> Eq for Backtrace<'a> {}
163
164impl<'a> Ord for Backtrace<'a> {
165    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
166        other
167            .total_cost
168            .partial_cmp(&self.total_cost)
169            .unwrap_or(std::cmp::Ordering::Equal)
170    }
171}
172
/// Parameters for a search performed by [`astar`].
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct SearchInfo {
    /// Radius of the agent. `astar` applies it as the margin on portals,
    /// rejects portals whose face is shorter than twice this value once the
    /// margin is applied, and passes it on to the path post-processing steps.
    pub agent_radius: f32,
}
177
178pub fn astar<'a, F: Fn(Vec2, Vec2) -> f32>(
179    tree: &BSPTree,
180    portals: &Portals,
181    start: Vec2,
182    end: Vec2,
183    heuristic: F,
184    info: SearchInfo,
185    path: &'a mut Option<Path>,
186) -> Option<&'a mut Path> {
187    let mut open = BinaryHeap::new();
188    let start_node = tree.locate(start);
189    let end_node = tree.locate(end);
190
191    // // No path if start or end are covered
192    // if start_node.covered() || end_node.covered() {
193    //     return None;
194    // }
195
196    // Find matching start node
197    // if let Some(p) = path {
198    //     let inc_start = p.iter().position(|p| p.node == start_node.index());
199
200    //     // New end is in the same node as old end
201    //     if let (Some(start_idx), Some(last)) = (inc_start, p.last_mut()) {
202    //         assert_eq!(last.portal, None);
203    //         if last.node == end_node.index() {
204    //             last.point = end;
205    //             p.drain(0..start_idx);
206    //             p[0].point = start;
207
208    //             return path.as_mut();
209    //         }
210    //     }
211    // }
212
213    let start_node = start_node.index();
214    let end_node = end_node.index();
215
216    // Information of how a node was reached
217    let mut backtraces: SecondaryMap<_, Backtrace> = SecondaryMap::new();
218    let start = Backtrace::start(start_node, start, (heuristic)(start, end));
219
220    // Push the fist node
221    open.push(start);
222    backtraces.insert(start_node, start);
223
224    let mut closed = HashSet::new();
225
226    // Expand the node with the lowest total cost
227    while let Some(current) = open.pop() {
228        if closed.contains(&current.node) {
229            continue;
230        }
231
232        // End found
233        // Generate backtrace and terminate
234        if current.node == end_node {
235            let path = path.get_or_insert_with(|| Default::default());
236
237            backtrace(end, current.node, backtraces, path);
238            shorten(tree, portals, path, info.agent_radius);
239            resolve_clip(portals, path, info.agent_radius);
240
241            return Some(path);
242        }
243
244        let end_rel = end - current.point;
245
246        // Add all edges to the open list and update backtraces
247        let portals = portals.get(current.node).filter_map(|portal| {
248            let face = portal.apply_margin(info.agent_radius);
249            if portal.dst() == current.node
250                || face.length() < 2.0 * info.agent_radius
251                || closed.contains(&portal.dst())
252            {
253                return None;
254            }
255
256            assert_eq!(portal.src(), current.node);
257
258            // Distance to each of the nodes
259            let (p1, p2) = face.into_tuple();
260            let p1_dist = (heuristic)(p1, end);
261            let p2_dist = (heuristic)(p2, end);
262
263            let p = if portal.normal().dot(end_rel) > 0.0 {
264                portal.clip(current.point, end, info.agent_radius)
265            } else if p1_dist < p2_dist {
266                p1
267            } else {
268                p2
269            };
270
271            let backtrace = Backtrace::new(portal, p, &current, (heuristic)(p, end));
272
273            // Update backtrace
274            // If the cost to this node is lower than previosuly found,
275            // overwrite with the new backtrace.
276            match backtraces.entry(backtrace.node).unwrap() {
277                Entry::Occupied(mut val) => {
278                    if val.get().total_cost > backtrace.total_cost {
279                        val.insert(backtrace);
280                    } else {
281                        return None;
282                    }
283                }
284                Entry::Vacant(entry) => {
285                    entry.insert(backtrace);
286                }
287            }
288
289            Some(backtrace)
290        });
291
292        // Add the edges
293        open.extend(portals);
294
295        // The current node is now done and won't be revisited
296        assert!(closed.insert(current.node))
297    }
298
299    None
300}
301
302fn backtrace(
303    end: Vec2,
304    mut current: NodeIndex,
305    backtraces: SecondaryMap<NodeIndex, Backtrace>,
306    path: &mut Path,
307) {
308    path.clear();
309    path.push(WayPoint::new(end, current, None));
310    let mut prev = end;
311    loop {
312        // Backtrace backwards
313        let node = backtraces[current];
314
315        if path.len() < 2 || prev.distance_squared(node.point) > TOLERANCE {
316            path.push(WayPoint::new(
317                node.point,
318                node.node,
319                node.portal.as_ref().map(Portal::portal_ref),
320            ));
321        }
322
323        prev = node.point;
324
325        // Continue up the backtrace
326        if let Some(prev) = node.prev {
327            current = prev;
328        } else {
329            break;
330        }
331    }
332
333    path.reverse();
334}
335
336fn resolve_clip(portals: &Portals, path: &mut [WayPoint], margin: f32) {
337    if path.len() < 3 {
338        return;
339    }
340
341    let a = path[0];
342    let c = path[2];
343    let b = &mut path[1];
344
345    if let Some(portal) = b.portal {
346        let portal = portals.from_ref(portal);
347        let [p, q] = portal.face.vertices;
348        if (b.point().distance(p) < margin + TOLERANCE && portal.adjacent[0])
349            || (b.point().distance(q) < margin + TOLERANCE && portal.adjacent[1])
350        {
351            let normal = portal.normal();
352            let a_inc = (a.point - b.point)
353                .normalize_or_zero()
354                .perp_dot(normal)
355                .abs();
356
357            let c_inc = (c.point - b.point)
358                .normalize_or_zero()
359                .perp_dot(normal)
360                .abs();
361
362            b.point += normal * margin * (c_inc - a_inc)
363        }
364    }
365
366    resolve_clip(portals, &mut path[1..], margin)
367}
368
369fn shorten(tree: &BSPTree, portals: &Portals, path: &mut [WayPoint], agent_radius: f32) -> bool {
370    if path.len() < 3 {
371        return true;
372    }
373
374    let a = &path[0];
375    let b = &path[1];
376    let c = &path[2];
377
378    if let Some(portal) = b.portal {
379        let portal = portals.from_ref(portal);
380        // c was directly visible from a
381        if let Some(p) = portal.try_clip(a.point, c.point, agent_radius) {
382            let prev = b.point;
383
384            path[1].point = p;
385
386            // Try to shorten the next strip.
387            // If successful, retry shortening for this strip
388            if shorten(tree, portals, &mut path[1..], agent_radius)
389                && prev.distance_squared(p) > TOLERANCE
390            {
391                shorten(tree, portals, path, agent_radius);
392            }
393
394            return true;
395        }
396    }
397
398    shorten(tree, portals, &mut path[1..], agent_radius)
399}