Skip to main content

proof_engine/pathfinding/
astar.rs

1// src/pathfinding/astar.rs
2// A* and pathfinding variants:
3//   - Generic A* over NodeId graph
4//   - Jump Point Search (JPS) for uniform-cost grid maps
5//   - Hierarchical A* with cluster-level precomputation
6//   - Flow fields for crowd simulation
7//   - Path caching with invalidation
8
9use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
10use std::cmp::Ordering;
11use std::f32;
12
13// ── Vec2 (local, avoids cross-module dep) ────────────────────────────────────
14
/// Minimal 2-D vector with `f32` components, kept local to this module so it
/// does not depend on another math module.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Vec2 {
    /// Horizontal component.
    pub x: f32,
    /// Vertical component.
    pub y: f32,
}

impl Vec2 {
    /// Builds a vector from its two components.
    #[inline]
    pub fn new(x: f32, y: f32) -> Self {
        Vec2 { x, y }
    }

    /// The vector `(0, 0)`.
    #[inline]
    pub fn zero() -> Self {
        Vec2::new(0.0, 0.0)
    }

    /// Euclidean distance between `self` and `o`.
    #[inline]
    pub fn dist(self, o: Self) -> f32 {
        let dx = self.x - o.x;
        let dy = self.y - o.y;
        (dx.powi(2) + dy.powi(2)).sqrt()
    }

    /// Component-wise `self - o`.
    #[inline]
    pub fn sub(self, o: Self) -> Self {
        Vec2::new(self.x - o.x, self.y - o.y)
    }

    /// Component-wise `self + o`.
    #[inline]
    pub fn add(self, o: Self) -> Self {
        Vec2::new(self.x + o.x, self.y + o.y)
    }

    /// Both components multiplied by `s`.
    #[inline]
    pub fn scale(self, s: f32) -> Self {
        Vec2::new(self.x * s, self.y * s)
    }

    /// Euclidean length.
    #[inline]
    pub fn len(self) -> f32 {
        (self.x * self.x + self.y * self.y).sqrt()
    }

    /// Unit vector in the same direction; the zero vector when the length is
    /// below `1e-9`.
    #[inline]
    pub fn norm(self) -> Self {
        let length = self.len();
        if length < 1e-9 {
            return Self::zero();
        }
        self.scale(1.0 / length)
    }
}
31
32// ── Node identifier ───────────────────────────────────────────────────────────
33
/// Identifier of a node in a graph searched by `astar_search`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct NodeId(
    /// Raw numeric id.
    pub u32,
);
36
37// ── Generic A* graph trait ────────────────────────────────────────────────────
38
39/// Trait implemented by any graph that wants generic A*.
40pub trait AStarGraph {
41    type Cost: Copy + PartialOrd + std::ops::Add<Output = Self::Cost>;
42    fn zero_cost() -> Self::Cost;
43    fn max_cost() -> Self::Cost;
44    fn heuristic(&self, from: NodeId, to: NodeId) -> Self::Cost;
45    fn neighbors(&self, node: NodeId) -> Vec<(NodeId, Self::Cost)>;
46}
47
48/// Result of A* search.
49#[derive(Clone, Debug)]
50pub struct AStarResult {
51    pub path: Vec<NodeId>,
52    pub cost: f32,
53}
54
55pub struct AStarNode {
56    pub id:       NodeId,
57    pub position: Vec2,
58    pub walkable: bool,
59}
60
61// ── Priority entry ────────────────────────────────────────────────────────────
62
63#[derive(PartialEq)]
64struct PqEntry<C: PartialOrd> {
65    node: NodeId,
66    f:    C,
67}
68
69impl<C: PartialOrd> Eq for PqEntry<C> {}
70
71impl<C: PartialOrd> PartialOrd for PqEntry<C> {
72    fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) }
73}
74
75impl<C: PartialOrd> Ord for PqEntry<C> {
76    fn cmp(&self, other: &Self) -> Ordering {
77        other.f.partial_cmp(&self.f).unwrap_or(Ordering::Equal)
78    }
79}
80
81/// Run generic A* on any graph implementing AStarGraph.
82pub fn astar_search<G: AStarGraph>(
83    graph: &G,
84    start: NodeId,
85    goal: NodeId,
86) -> Option<AStarResult>
87where
88    G::Cost: std::fmt::Debug,
89    f32: From<G::Cost>,
90{
91    let mut open: BinaryHeap<PqEntry<G::Cost>> = BinaryHeap::new();
92    let mut came_from: HashMap<NodeId, NodeId> = HashMap::new();
93    let mut g_score: HashMap<NodeId, G::Cost> = HashMap::new();
94
95    g_score.insert(start, G::zero_cost());
96    open.push(PqEntry { node: start, f: graph.heuristic(start, goal) });
97
98    while let Some(PqEntry { node: current, .. }) = open.pop() {
99        if current == goal {
100            let path = reconstruct(start, goal, &came_from);
101            let cost = f32::from(*g_score.get(&goal).unwrap_or(&G::zero_cost()));
102            return Some(AStarResult { path, cost });
103        }
104        let cur_g = *g_score.get(&current).unwrap_or(&G::max_cost());
105        for (neighbor, edge_cost) in graph.neighbors(current) {
106            let tentative = cur_g + edge_cost;
107            if tentative < *g_score.get(&neighbor).unwrap_or(&G::max_cost()) {
108                came_from.insert(neighbor, current);
109                g_score.insert(neighbor, tentative);
110                let h = graph.heuristic(neighbor, goal);
111                open.push(PqEntry { node: neighbor, f: tentative + h });
112            }
113        }
114    }
115    None
116}
117
118fn reconstruct(start: NodeId, goal: NodeId, came_from: &HashMap<NodeId, NodeId>) -> Vec<NodeId> {
119    let mut path = Vec::new();
120    let mut cur = goal;
121    while cur != start {
122        path.push(cur);
123        match came_from.get(&cur) {
124            Some(&p) => cur = p,
125            None => break,
126        }
127    }
128    path.push(start);
129    path.reverse();
130    path
131}
132
133// ── Grid map ──────────────────────────────────────────────────────────────────
134
135/// A 2-D uniform grid map for JPS and flow fields.
136#[derive(Clone, Debug)]
137pub struct GridMap {
138    pub width:    usize,
139    pub height:   usize,
140    pub cells:    Vec<bool>,  // true = walkable
141    pub cell_size: f32,
142    pub origin:   Vec2,
143}
144
145impl GridMap {
146    pub fn new(width: usize, height: usize, cell_size: f32, origin: Vec2) -> Self {
147        Self {
148            width, height,
149            cells: vec![true; width * height],
150            cell_size, origin,
151        }
152    }
153
154    #[inline]
155    pub fn idx(&self, x: usize, y: usize) -> usize { y * self.width + x }
156
157    #[inline]
158    pub fn in_bounds(&self, x: i32, y: i32) -> bool {
159        x >= 0 && y >= 0 && (x as usize) < self.width && (y as usize) < self.height
160    }
161
162    #[inline]
163    pub fn walkable(&self, x: i32, y: i32) -> bool {
164        self.in_bounds(x, y) && self.cells[self.idx(x as usize, y as usize)]
165    }
166
167    pub fn set_walkable(&mut self, x: usize, y: usize, w: bool) {
168        let i = self.idx(x, y);
169        self.cells[i] = w;
170    }
171
172    /// Block a rectangular area.
173    pub fn block_rect(&mut self, x: usize, y: usize, w: usize, h: usize) {
174        for ry in y..((y+h).min(self.height)) {
175            for rx in x..((x+w).min(self.width)) {
176                let i = self.idx(rx, ry);
177                self.cells[i] = false;
178            }
179        }
180    }
181
182    pub fn node_id(&self, x: usize, y: usize) -> NodeId {
183        NodeId((y * self.width + x) as u32)
184    }
185
186    pub fn coords(&self, id: NodeId) -> (usize, usize) {
187        let i = id.0 as usize;
188        (i % self.width, i / self.width)
189    }
190
191    pub fn world_pos(&self, x: usize, y: usize) -> Vec2 {
192        Vec2::new(
193            self.origin.x + (x as f32 + 0.5) * self.cell_size,
194            self.origin.y + (y as f32 + 0.5) * self.cell_size,
195        )
196    }
197
198    pub fn grid_coords_for_world(&self, p: Vec2) -> Option<(usize, usize)> {
199        let gx = ((p.x - self.origin.x) / self.cell_size) as i32;
200        let gy = ((p.y - self.origin.y) / self.cell_size) as i32;
201        if self.in_bounds(gx, gy) {
202            Some((gx as usize, gy as usize))
203        } else {
204            None
205        }
206    }
207}
208
209// ── Jump Point Search ─────────────────────────────────────────────────────────
210
211/// JPS pathfinder for uniform-cost grid maps (8-directional movement).
212pub struct JpsPathfinder<'a> {
213    pub grid: &'a GridMap,
214}
215
216impl<'a> JpsPathfinder<'a> {
217    pub fn new(grid: &'a GridMap) -> Self { Self { grid } }
218
219    /// Find a path from `start` to `goal`, both as grid (x,y) coordinates.
220    pub fn find_path(&self, start: (usize, usize), goal: (usize, usize)) -> Option<Vec<(usize, usize)>> {
221        if !self.grid.walkable(start.0 as i32, start.1 as i32) { return None; }
222        if !self.grid.walkable(goal.0 as i32, goal.1 as i32) { return None; }
223        if start == goal { return Some(vec![start]); }
224
225        let mut open: BinaryHeap<JpsEntry> = BinaryHeap::new();
226        let mut came_from: HashMap<(usize,usize), (usize,usize)> = HashMap::new();
227        let mut g: HashMap<(usize,usize), f32> = HashMap::new();
228        let mut closed: HashSet<(usize,usize)> = HashSet::new();
229
230        g.insert(start, 0.0);
231        open.push(JpsEntry { pos: start, f: self.h(start, goal) });
232
233        while let Some(JpsEntry { pos: cur, .. }) = open.pop() {
234            if cur == goal {
235                return Some(self.reconstruct_path(start, goal, &came_from));
236            }
237            if closed.contains(&cur) { continue; }
238            closed.insert(cur);
239
240            let cur_g = *g.get(&cur).unwrap_or(&f32::MAX);
241            let successors = self.identify_successors(cur, goal, &came_from);
242
243            for succ in successors {
244                if closed.contains(&succ) { continue; }
245                let d = self.cost(cur, succ);
246                let ng = cur_g + d;
247                if ng < *g.get(&succ).unwrap_or(&f32::MAX) {
248                    g.insert(succ, ng);
249                    came_from.insert(succ, cur);
250                    open.push(JpsEntry { pos: succ, f: ng + self.h(succ, goal) });
251                }
252            }
253        }
254        None
255    }
256
257    fn h(&self, a: (usize,usize), b: (usize,usize)) -> f32 {
258        let dx = (a.0 as f32 - b.0 as f32).abs();
259        let dy = (a.1 as f32 - b.1 as f32).abs();
260        // Octile distance
261        let (mn, mx) = if dx < dy { (dx, dy) } else { (dy, dx) };
262        mx + mn * (std::f32::consts::SQRT_2 - 1.0)
263    }
264
265    fn cost(&self, a: (usize,usize), b: (usize,usize)) -> f32 {
266        let dx = (a.0 as i32 - b.0 as i32).abs();
267        let dy = (a.1 as i32 - b.1 as i32).abs();
268        if dx + dy == 2 { std::f32::consts::SQRT_2 } else { 1.0 }
269    }
270
271    fn identify_successors(
272        &self,
273        node: (usize,usize),
274        goal: (usize,usize),
275        came_from: &HashMap<(usize,usize),(usize,usize)>,
276    ) -> Vec<(usize,usize)> {
277        let neighbors = self.prune_neighbors(node, came_from);
278        let mut successors = Vec::new();
279        for nb in neighbors {
280            let dx = (nb.0 as i32 - node.0 as i32).signum();
281            let dy = (nb.1 as i32 - node.1 as i32).signum();
282            if let Some(jp) = self.jump(node, (dx, dy), goal) {
283                successors.push(jp);
284            }
285        }
286        successors
287    }
288
289    fn prune_neighbors(
290        &self,
291        node: (usize,usize),
292        came_from: &HashMap<(usize,usize),(usize,usize)>,
293    ) -> Vec<(usize,usize)> {
294        let parent = came_from.get(&node);
295        let (x, y) = (node.0 as i32, node.1 as i32);
296        if parent.is_none() {
297            // Start node: return all walkable neighbors
298            return self.all_neighbors(node);
299        }
300        let parent = parent.unwrap();
301        let dx = (x - parent.0 as i32).signum();
302        let dy = (y - parent.1 as i32).signum();
303        let mut neighbors = Vec::new();
304
305        if dx != 0 && dy != 0 {
306            // Diagonal
307            if self.grid.walkable(x, y + dy)     { neighbors.push((x as usize, (y+dy) as usize)); }
308            if self.grid.walkable(x + dx, y)     { neighbors.push(((x+dx) as usize, y as usize)); }
309            if self.grid.walkable(x + dx, y + dy) { neighbors.push(((x+dx) as usize, (y+dy) as usize)); }
310            if !self.grid.walkable(x - dx, y) && self.grid.walkable(x, y + dy) {
311                neighbors.push((x as usize, (y + dy) as usize));
312            }
313            if !self.grid.walkable(x, y - dy) && self.grid.walkable(x + dx, y) {
314                neighbors.push(((x + dx) as usize, y as usize));
315            }
316        } else if dx != 0 {
317            // Horizontal
318            if self.grid.walkable(x + dx, y) { neighbors.push(((x+dx) as usize, y as usize)); }
319            if !self.grid.walkable(x, y + 1) && self.grid.walkable(x + dx, y + 1) {
320                neighbors.push(((x+dx) as usize, (y+1) as usize));
321            }
322            if !self.grid.walkable(x, y - 1) && self.grid.walkable(x + dx, y - 1) {
323                neighbors.push(((x+dx) as usize, (y-1) as usize));
324            }
325        } else {
326            // Vertical
327            if self.grid.walkable(x, y + dy) { neighbors.push((x as usize, (y+dy) as usize)); }
328            if !self.grid.walkable(x + 1, y) && self.grid.walkable(x + 1, y + dy) {
329                neighbors.push(((x+1) as usize, (y+dy) as usize));
330            }
331            if !self.grid.walkable(x - 1, y) && self.grid.walkable(x - 1, y + dy) {
332                neighbors.push(((x-1) as usize, (y+dy) as usize));
333            }
334        }
335        neighbors.dedup();
336        neighbors
337    }
338
339    fn all_neighbors(&self, node: (usize,usize)) -> Vec<(usize,usize)> {
340        let (x, y) = (node.0 as i32, node.1 as i32);
341        let mut result = Vec::new();
342        for dy in -1i32..=1 {
343            for dx in -1i32..=1 {
344                if dx == 0 && dy == 0 { continue; }
345                if self.grid.walkable(x + dx, y + dy) {
346                    result.push(((x + dx) as usize, (y + dy) as usize));
347                }
348            }
349        }
350        result
351    }
352
353    fn jump(&self, node: (usize,usize), dir: (i32,i32), goal: (usize,usize)) -> Option<(usize,usize)> {
354        let (mut x, mut y) = (node.0 as i32, node.1 as i32);
355        let (dx, dy) = dir;
356        let max_steps = (self.grid.width + self.grid.height) * 2;
357        let mut steps = 0;
358
359        loop {
360            x += dx;
361            y += dy;
362            steps += 1;
363            if steps > max_steps { return None; }
364            if !self.grid.walkable(x, y) { return None; }
365            let cur = (x as usize, y as usize);
366            if cur == goal { return Some(cur); }
367
368            // Check for forced neighbors
369            if self.has_forced_neighbor(cur, dir) { return Some(cur); }
370
371            // Diagonal: recurse on both cardinal directions
372            if dx != 0 && dy != 0 {
373                if self.jump((x as usize, y as usize), (dx, 0), goal).is_some() { return Some(cur); }
374                if self.jump((x as usize, y as usize), (0, dy), goal).is_some() { return Some(cur); }
375            }
376        }
377    }
378
379    fn has_forced_neighbor(&self, node: (usize,usize), dir: (i32,i32)) -> bool {
380        let (x, y) = (node.0 as i32, node.1 as i32);
381        let (dx, dy) = dir;
382        if dx != 0 && dy != 0 {
383            // diagonal forced: blocked adjacent cardinal
384            (!self.grid.walkable(x - dx, y) && self.grid.walkable(x - dx, y + dy))
385            || (!self.grid.walkable(x, y - dy) && self.grid.walkable(x + dx, y - dy))
386        } else if dx != 0 {
387            (!self.grid.walkable(x, y + 1) && self.grid.walkable(x + dx, y + 1))
388            || (!self.grid.walkable(x, y - 1) && self.grid.walkable(x + dx, y - 1))
389        } else {
390            (!self.grid.walkable(x + 1, y) && self.grid.walkable(x + 1, y + dy))
391            || (!self.grid.walkable(x - 1, y) && self.grid.walkable(x - 1, y + dy))
392        }
393    }
394
395    fn reconstruct_path(
396        &self,
397        start: (usize,usize),
398        goal: (usize,usize),
399        came_from: &HashMap<(usize,usize),(usize,usize)>,
400    ) -> Vec<(usize,usize)> {
401        let mut path = Vec::new();
402        let mut cur = goal;
403        while cur != start {
404            path.push(cur);
405            match came_from.get(&cur) {
406                Some(&p) => cur = p,
407                None => break,
408            }
409        }
410        path.push(start);
411        path.reverse();
412        // Expand jump-point path into full grid steps
413        let mut expanded = Vec::new();
414        for i in 0..path.len().saturating_sub(1) {
415            expanded.push(path[i]);
416            let (ax, ay) = (path[i].0 as i32, path[i].1 as i32);
417            let (bx, by) = (path[i+1].0 as i32, path[i+1].1 as i32);
418            let sdx = (bx - ax).signum();
419            let sdy = (by - ay).signum();
420            let mut cx = ax + sdx;
421            let mut cy = ay + sdy;
422            while (cx, cy) != (bx, by) {
423                expanded.push((cx as usize, cy as usize));
424                cx += sdx;
425                cy += sdy;
426            }
427        }
428        if let Some(&last) = path.last() { expanded.push(last); }
429        expanded.dedup();
430        expanded
431    }
432}
433
434#[derive(PartialEq)]
435struct JpsEntry { pos: (usize,usize), f: f32 }
436impl Eq for JpsEntry {}
437impl PartialOrd for JpsEntry {
438    fn partial_cmp(&self, o: &Self) -> Option<Ordering> { Some(self.cmp(o)) }
439}
440impl Ord for JpsEntry {
441    fn cmp(&self, o: &Self) -> Ordering { o.f.partial_cmp(&self.f).unwrap_or(Ordering::Equal) }
442}
443
444// ── Hierarchical A* ───────────────────────────────────────────────────────────
445
/// Identifier of a [`Cluster`] in a [`HierarchicalPathfinder`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ClusterId(
    /// Raw id; `HierarchicalPathfinder::build` numbers clusters sequentially
    /// in creation order.
    pub u32,
);
448
449/// A cluster groups nearby grid cells for hierarchical planning.
450#[derive(Clone, Debug)]
451pub struct Cluster {
452    pub id:            ClusterId,
453    pub x:             usize,   // grid cell offset
454    pub y:             usize,
455    pub width:         usize,
456    pub height:        usize,
457    pub entry_cells:   Vec<(usize, usize)>,   // border cells that connect to other clusters
458    pub neighbors:     Vec<(ClusterId, f32)>, // neighbor cluster + estimated cost
459}
460
461impl Cluster {
462    pub fn contains(&self, cx: usize, cy: usize) -> bool {
463        cx >= self.x && cx < self.x + self.width
464        && cy >= self.y && cy < self.y + self.height
465    }
466    pub fn center_cell(&self) -> (usize, usize) {
467        (self.x + self.width / 2, self.y + self.height / 2)
468    }
469}
470
471/// Hierarchical pathfinder: builds abstract cluster graph then refines.
472pub struct HierarchicalPathfinder {
473    pub clusters:      Vec<Cluster>,
474    pub cluster_map:   Vec<Option<ClusterId>>,  // per grid cell
475    pub grid_width:    usize,
476    pub grid_height:   usize,
477    pub cluster_size:  usize,
478}
479
480impl HierarchicalPathfinder {
481    /// Build cluster graph from a GridMap with given cluster size.
482    pub fn build(grid: &GridMap, cluster_size: usize) -> Self {
483        let cw = (grid.width  + cluster_size - 1) / cluster_size;
484        let ch = (grid.height + cluster_size - 1) / cluster_size;
485        let mut clusters = Vec::with_capacity(cw * ch);
486        let mut cluster_map = vec![None; grid.width * grid.height];
487        let mut id = 0u32;
488
489        for cy in 0..ch {
490            for cx in 0..cw {
491                let ox = cx * cluster_size;
492                let oy = cy * cluster_size;
493                let w  = cluster_size.min(grid.width  - ox);
494                let h  = cluster_size.min(grid.height - oy);
495
496                let mut entry_cells = Vec::new();
497                // Top/bottom border
498                for bx in ox..(ox+w) {
499                    if grid.walkable(bx as i32, oy as i32)        { entry_cells.push((bx, oy)); }
500                    let by = oy + h - 1;
501                    if grid.walkable(bx as i32, by as i32)        { entry_cells.push((bx, by)); }
502                }
503                // Left/right border
504                for by in oy..(oy+h) {
505                    if grid.walkable(ox as i32, by as i32)        { entry_cells.push((ox, by)); }
506                    let bx = ox + w - 1;
507                    if grid.walkable(bx as i32, by as i32)        { entry_cells.push((bx, by)); }
508                }
509                entry_cells.sort();
510                entry_cells.dedup();
511
512                let cid = ClusterId(id);
513                for gx in ox..(ox+w) {
514                    for gy in oy..(oy+h) {
515                        if grid.in_bounds(gx as i32, gy as i32) {
516                            let gi = gy * grid.width + gx;
517                            cluster_map[gi] = Some(cid);
518                        }
519                    }
520                }
521                clusters.push(Cluster {
522                    id: cid, x: ox, y: oy, width: w, height: h,
523                    entry_cells, neighbors: Vec::new(),
524                });
525                id += 1;
526            }
527        }
528
529        // Build neighbor edges between adjacent clusters
530        let mut hpf = HierarchicalPathfinder {
531            clusters, cluster_map,
532            grid_width: grid.width,
533            grid_height: grid.height,
534            cluster_size,
535        };
536        hpf.build_cluster_edges(grid);
537        hpf
538    }
539
540    fn build_cluster_edges(&mut self, grid: &GridMap) {
541        let nc = self.clusters.len();
542        for i in 0..nc {
543            let ci = &self.clusters[i];
544            // Check 4 adjacent cluster positions
545            let (cx, cy) = (ci.x, ci.y);
546            let cs = self.cluster_size;
547            let adj_offsets: [(i32,i32); 4] = [(1,0),(-1,0),(0,1),(0,-1)];
548            let mut neighbors = Vec::new();
549            for (aox, aoy) in adj_offsets {
550                let nx = cx as i32 + aox * cs as i32;
551                let ny = cy as i32 + aoy * cs as i32;
552                if nx < 0 || ny < 0 { continue; }
553                if let Some(j) = self.find_cluster_at(nx as usize, ny as usize) {
554                    if j != i {
555                        let cost = cs as f32; // approximate
556                        neighbors.push((self.clusters[j].id, cost));
557                    }
558                }
559            }
560            // Update neighbors (can't borrow mut + immut simultaneously, so rebuild)
561            let _ = neighbors; // will be set below
562        }
563        // Simplified: link adjacent grid clusters
564        let cw = (grid.width  + self.cluster_size - 1) / self.cluster_size;
565        let ch = (grid.height + self.cluster_size - 1) / self.cluster_size;
566        for cy in 0..ch {
567            for cx in 0..cw {
568                let idx = cy * cw + cx;
569                if idx >= self.clusters.len() { continue; }
570                let mut nbrs = Vec::new();
571                let pairs: [(i32,i32); 4] = [(1,0),(-1,0),(0,1),(0,-1)];
572                for (ddx, ddy) in pairs {
573                    let ncx = cx as i32 + ddx;
574                    let ncy = cy as i32 + ddy;
575                    if ncx < 0 || ncy < 0 || ncx >= cw as i32 || ncy >= ch as i32 { continue; }
576                    let nidx = (ncy as usize) * cw + (ncx as usize);
577                    if nidx < self.clusters.len() {
578                        let nid = self.clusters[nidx].id;
579                        nbrs.push((nid, self.cluster_size as f32));
580                    }
581                }
582                self.clusters[idx].neighbors = nbrs;
583            }
584        }
585    }
586
587    fn find_cluster_at(&self, x: usize, y: usize) -> Option<usize> {
588        self.clusters.iter().position(|c| c.contains(x, y))
589    }
590
591    pub fn cluster_for_cell(&self, x: usize, y: usize) -> Option<ClusterId> {
592        if x >= self.grid_width || y >= self.grid_height { return None; }
593        self.cluster_map[y * self.grid_width + x]
594    }
595
596    /// High-level path: returns sequence of ClusterIds.
597    pub fn abstract_path(&self, start_cell: (usize,usize), goal_cell: (usize,usize)) -> Vec<ClusterId> {
598        let sc = match self.cluster_for_cell(start_cell.0, start_cell.1) { Some(c) => c, None => return Vec::new() };
599        let gc = match self.cluster_for_cell(goal_cell.0, goal_cell.1) { Some(c) => c, None => return Vec::new() };
600        if sc == gc { return vec![sc]; }
601
602        let mut open: BinaryHeap<ClusterEntry> = BinaryHeap::new();
603        let mut came_from: HashMap<ClusterId, ClusterId> = HashMap::new();
604        let mut g: HashMap<ClusterId, f32> = HashMap::new();
605
606        g.insert(sc, 0.0);
607        open.push(ClusterEntry { id: sc, f: self.cluster_heuristic(sc, gc) });
608
609        while let Some(ClusterEntry { id: cur, .. }) = open.pop() {
610            if cur == gc {
611                let mut path = Vec::new();
612                let mut c = gc;
613                while c != sc {
614                    path.push(c);
615                    c = *came_from.get(&c).unwrap_or(&sc);
616                }
617                path.push(sc);
618                path.reverse();
619                return path;
620            }
621            let cur_g = *g.get(&cur).unwrap_or(&f32::MAX);
622            if let Some(cluster) = self.clusters.iter().find(|c| c.id == cur) {
623                for &(nid, edge_cost) in &cluster.neighbors {
624                    let ng = cur_g + edge_cost;
625                    if ng < *g.get(&nid).unwrap_or(&f32::MAX) {
626                        g.insert(nid, ng);
627                        came_from.insert(nid, cur);
628                        let h = self.cluster_heuristic(nid, gc);
629                        open.push(ClusterEntry { id: nid, f: ng + h });
630                    }
631                }
632            }
633        }
634        Vec::new()
635    }
636
637    fn cluster_heuristic(&self, a: ClusterId, b: ClusterId) -> f32 {
638        let ca = self.clusters.iter().find(|c| c.id == a).map(|c| c.center_cell()).unwrap_or((0,0));
639        let cb = self.clusters.iter().find(|c| c.id == b).map(|c| c.center_cell()).unwrap_or((0,0));
640        let dx = (ca.0 as f32 - cb.0 as f32).abs();
641        let dy = (ca.1 as f32 - cb.1 as f32).abs();
642        dx.max(dy)
643    }
644}
645
646#[derive(PartialEq)]
647struct ClusterEntry { id: ClusterId, f: f32 }
648impl Eq for ClusterEntry {}
649impl PartialOrd for ClusterEntry {
650    fn partial_cmp(&self, o: &Self) -> Option<Ordering> { Some(self.cmp(o)) }
651}
652impl Ord for ClusterEntry {
653    fn cmp(&self, o: &Self) -> Ordering { o.f.partial_cmp(&self.f).unwrap_or(Ordering::Equal) }
654}
655
656// ── Flow Field ────────────────────────────────────────────────────────────────
657
658/// Flow direction per cell: an 8-directional flow vector.
659#[derive(Clone, Copy, Debug, Default)]
660pub struct FlowVector {
661    pub dx: i8,   // -1, 0, +1
662    pub dy: i8,
663}
664
665impl FlowVector {
666    pub fn as_vec2(self) -> Vec2 {
667        Vec2::new(self.dx as f32, self.dy as f32).norm()
668    }
669    pub fn is_valid(self) -> bool { self.dx != 0 || self.dy != 0 }
670}
671
672/// A flow field: precomputed for a single goal, steers any number of agents.
673#[derive(Clone, Debug)]
674pub struct FlowField {
675    pub width:   usize,
676    pub height:  usize,
677    pub flow:    Vec<FlowVector>,
678    pub cost:    Vec<f32>,          // integration field (distance to goal)
679    pub goal:    (usize, usize),
680}
681
682/// Flow field grid: builds and stores flow fields.
683pub struct FlowFieldGrid<'a> {
684    pub grid: &'a GridMap,
685}
686
687impl<'a> FlowFieldGrid<'a> {
688    pub fn new(grid: &'a GridMap) -> Self { Self { grid } }
689
690    /// Build a flow field toward `goal` using Dijkstra integration.
691    pub fn build(&self, goal: (usize, usize)) -> FlowField {
692        let w = self.grid.width;
693        let h = self.grid.height;
694        let inf = f32::MAX / 2.0;
695        let mut cost = vec![inf; w * h];
696        let mut flow = vec![FlowVector::default(); w * h];
697
698        if !self.grid.walkable(goal.0 as i32, goal.1 as i32) {
699            return FlowField { width: w, height: h, flow, cost, goal };
700        }
701
702        let gi = goal.1 * w + goal.0;
703        cost[gi] = 0.0;
704
705        // BFS/Dijkstra integration field
706        let mut queue: VecDeque<(usize,usize)> = VecDeque::new();
707        queue.push_back(goal);
708
709        while let Some((cx, cy)) = queue.pop_front() {
710            let cur_cost = cost[cy * w + cx];
711            for (dx, dy) in DIRS_8 {
712                let nx = cx as i32 + dx;
713                let ny = cy as i32 + dy;
714                if !self.grid.walkable(nx, ny) { continue; }
715                let (nxi, nyi) = (nx as usize, ny as usize);
716                let ni = nyi * w + nxi;
717                let step_cost = if dx != 0 && dy != 0 { std::f32::consts::SQRT_2 } else { 1.0 };
718                let nc = cur_cost + step_cost;
719                if nc < cost[ni] {
720                    cost[ni] = nc;
721                    queue.push_back((nxi, nyi));
722                }
723            }
724        }
725
726        // Build flow vectors: each cell points toward the lowest-cost neighbor
727        for cy in 0..h {
728            for cx in 0..w {
729                if !self.grid.walkable(cx as i32, cy as i32) { continue; }
730                let ci = cy * w + cx;
731                if cost[ci] >= inf { continue; }
732
733                let mut best_cost = cost[ci];
734                let mut best_dx = 0i8;
735                let mut best_dy = 0i8;
736                for (dx, dy) in DIRS_8 {
737                    let nx = cx as i32 + dx;
738                    let ny = cy as i32 + dy;
739                    if !self.grid.walkable(nx, ny) { continue; }
740                    let ni = ny as usize * w + nx as usize;
741                    if cost[ni] < best_cost {
742                        best_cost = cost[ni];
743                        best_dx = dx as i8;
744                        best_dy = dy as i8;
745                    }
746                }
747                flow[ci] = FlowVector { dx: best_dx, dy: best_dy };
748            }
749        }
750
751        FlowField { width: w, height: h, flow, cost, goal }
752    }
753}
754
/// The 8 neighbour offsets `(dx, dy)`: the four cardinals, then the four
/// diagonals. Flow-field construction scans them in this order and keeps the
/// first strictly cheaper neighbour, so the order decides ties.
const DIRS_8: [(i32, i32); 8] = [
    // cardinals
    (1, 0),
    (-1, 0),
    (0, 1),
    (0, -1),
    // diagonals
    (1, 1),
    (1, -1),
    (-1, 1),
    (-1, -1),
];
758
759impl FlowField {
760    /// Get the flow vector at grid cell (x, y).
761    pub fn get_flow(&self, x: usize, y: usize) -> FlowVector {
762        if x < self.width && y < self.height {
763            self.flow[y * self.width + x]
764        } else {
765            FlowVector::default()
766        }
767    }
768
769    /// Get the integration cost at grid cell (x, y).
770    pub fn get_cost(&self, x: usize, y: usize) -> f32 {
771        if x < self.width && y < self.height {
772            self.cost[y * self.width + x]
773        } else {
774            f32::MAX
775        }
776    }
777
778    /// Sample the flow direction at world position `p` (grid-space integer lookup).
779    pub fn sample(&self, gx: usize, gy: usize) -> Vec2 {
780        self.get_flow(gx, gy).as_vec2()
781    }
782}
783
784// ── Path cache with invalidation ──────────────────────────────────────────────
785
/// A cached path entry.
///
/// `version` records the owning cache's version at insertion time; the entry
/// is treated as stale once the cache's version no longer matches.
#[derive(Clone, Debug)]
pub struct CachedPath {
    /// Start cell of the query.
    pub start: (usize, usize),
    /// Goal cell of the query.
    pub goal: (usize, usize),
    /// Cell-by-cell path; `PathCache::get_or_compute` stores an empty path
    /// when none was found.
    pub path: Vec<(usize, usize)>,
    /// Cache version this entry was written under.
    pub version: u64,
}
794
795/// Cache of computed paths, invalidated when the grid changes.
796pub struct PathCache {
797    entries:       HashMap<((usize,usize),(usize,usize)), CachedPath>,
798    pub version:   u64,
799    capacity:      usize,
800    // LRU tracking via insertion order
801    order:         VecDeque<((usize,usize),(usize,usize))>,
802}
803
804impl PathCache {
805    pub fn new(capacity: usize) -> Self {
806        Self {
807            entries: HashMap::new(),
808            version: 0,
809            capacity,
810            order: VecDeque::new(),
811        }
812    }
813
814    /// Increment version, invalidating all stale cache entries.
815    pub fn invalidate(&mut self) {
816        self.version += 1;
817    }
818
819    /// Clear all entries.
820    pub fn clear(&mut self) {
821        self.entries.clear();
822        self.order.clear();
823    }
824
825    /// Look up a cached path; returns None if not present or stale.
826    pub fn get(&self, start: (usize,usize), goal: (usize,usize)) -> Option<&Vec<(usize,usize)>> {
827        let key = (start, goal);
828        let entry = self.entries.get(&key)?;
829        if entry.version == self.version {
830            Some(&entry.path)
831        } else {
832            None
833        }
834    }
835
836    /// Store a path in the cache, evicting LRU entry if over capacity.
837    pub fn insert(&mut self, start: (usize,usize), goal: (usize,usize), path: Vec<(usize,usize)>) {
838        let key = (start, goal);
839        if self.entries.contains_key(&key) {
840            self.entries.get_mut(&key).unwrap().path = path;
841            self.entries.get_mut(&key).unwrap().version = self.version;
842        } else {
843            if self.entries.len() >= self.capacity {
844                if let Some(evict_key) = self.order.pop_front() {
845                    self.entries.remove(&evict_key);
846                }
847            }
848            self.entries.insert(key, CachedPath { start, goal, path, version: self.version });
849            self.order.push_back(key);
850        }
851    }
852
853    /// Get or compute a path, using JPS if not cached.
854    pub fn get_or_compute<'g>(&mut self, grid: &'g GridMap, start: (usize,usize), goal: (usize,usize)) -> Vec<(usize,usize)> {
855        if let Some(cached) = self.get(start, goal) {
856            return cached.clone();
857        }
858        let jps = JpsPathfinder::new(grid);
859        let path = jps.find_path(start, goal).unwrap_or_default();
860        self.insert(start, goal, path.clone());
861        path
862    }
863
864    pub fn entry_count(&self) -> usize { self.entries.len() }
865}
866
867// ── Simple concrete graph for generic A* ─────────────────────────────────────
868
869/// Simple flat graph with node positions and weighted edges.
870pub struct SimpleGraph {
871    pub nodes:     Vec<Vec2>,
872    pub edges:     Vec<Vec<(NodeId, f32)>>,
873}
874
875impl SimpleGraph {
876    pub fn new() -> Self { Self { nodes: Vec::new(), edges: Vec::new() } }
877
878    pub fn add_node(&mut self, pos: Vec2) -> NodeId {
879        let id = NodeId(self.nodes.len() as u32);
880        self.nodes.push(pos);
881        self.edges.push(Vec::new());
882        id
883    }
884
885    pub fn add_edge(&mut self, a: NodeId, b: NodeId, cost: f32) {
886        let ai = a.0 as usize;
887        let bi = b.0 as usize;
888        if ai < self.edges.len() { self.edges[ai].push((b, cost)); }
889        if bi < self.edges.len() { self.edges[bi].push((a, cost)); }
890    }
891}
892
893impl AStarGraph for SimpleGraph {
894    type Cost = f32;
895    fn zero_cost() -> f32 { 0.0 }
896    fn max_cost() -> f32  { f32::MAX / 2.0 }
897    fn heuristic(&self, from: NodeId, to: NodeId) -> f32 {
898        let a = self.nodes.get(from.0 as usize).copied().unwrap_or(Vec2::zero());
899        let b = self.nodes.get(to.0  as usize).copied().unwrap_or(Vec2::zero());
900        a.dist(b)
901    }
902    fn neighbors(&self, node: NodeId) -> Vec<(NodeId, f32)> {
903        self.edges.get(node.0 as usize).cloned().unwrap_or_default()
904    }
905}
906
907// ── Tests ─────────────────────────────────────────────────────────────────────
908
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_astar_simple() {
        let mut g = SimpleGraph::new();
        let a = g.add_node(Vec2::new(0.0, 0.0));
        let b = g.add_node(Vec2::new(1.0, 0.0));
        let c = g.add_node(Vec2::new(2.0, 0.0));
        g.add_edge(a, b, 1.0);
        g.add_edge(b, c, 1.0);
        let res = astar_search(&g, a, c).unwrap();
        assert_eq!(res.path, vec![a, b, c]);
        assert!((res.cost - 2.0).abs() < 1e-4);
    }

    #[test]
    fn test_simple_graph_heuristic_is_euclidean() {
        let mut g = SimpleGraph::new();
        let a = g.add_node(Vec2::new(0.0, 0.0));
        let b = g.add_node(Vec2::new(3.0, 4.0));
        assert!((g.heuristic(a, b) - 5.0).abs() < 1e-6);
        assert!(g.neighbors(a).is_empty());
    }

    #[test]
    fn test_jps_straight() {
        let grid = GridMap::new(10, 10, 1.0, Vec2::zero());
        let jps = JpsPathfinder::new(&grid);
        let path = jps.find_path((0, 0), (5, 0)).unwrap();
        assert!(!path.is_empty());
        assert_eq!(path[0], (0, 0));
        assert_eq!(*path.last().unwrap(), (5, 0));
    }

    #[test]
    fn test_jps_with_obstacle() {
        let mut grid = GridMap::new(10, 10, 1.0, Vec2::zero());
        // Wall at x = 5 covering rows 0..8; the only gap is at rows 8 and 9.
        for y in 0..8 {
            grid.set_walkable(5, y, false);
        }
        let jps = JpsPathfinder::new(&grid);
        let path = jps.find_path((0, 5), (9, 5)).expect("a path around the wall exists");
        assert_eq!(path[0], (0, 5));
        assert_eq!(*path.last().unwrap(), (9, 5));
        // No waypoint may lie inside the wall.
        assert!(path.iter().all(|&(x, y)| !(x == 5 && y < 8)));
    }

    #[test]
    fn test_flow_field() {
        let grid = GridMap::new(8, 8, 1.0, Vec2::zero());
        let ffg = FlowFieldGrid::new(&grid);
        let ff = ffg.build((7, 7));
        // Cell (0,0) should have valid flow toward goal
        let fv = ff.get_flow(0, 0);
        assert!(fv.is_valid());
    }

    #[test]
    fn test_path_cache() {
        let grid = GridMap::new(10, 10, 1.0, Vec2::zero());
        let mut cache = PathCache::new(16);
        let path = cache.get_or_compute(&grid, (0, 0), (9, 9));
        assert!(!path.is_empty());
        // Second call should hit cache
        let path2 = cache.get_or_compute(&grid, (0, 0), (9, 9));
        assert_eq!(path, path2);
        // After invalidation, cache entry is stale
        cache.invalidate();
        let cached = cache.get((0, 0), (9, 9));
        assert!(cached.is_none());
    }

    #[test]
    fn test_path_cache_evicts_oldest_when_full() {
        let mut cache = PathCache::new(2);
        cache.insert((0, 0), (1, 1), vec![(0, 0), (1, 1)]);
        cache.insert((0, 0), (2, 2), vec![(0, 0), (2, 2)]);
        cache.insert((0, 0), (3, 3), vec![(0, 0), (3, 3)]);
        assert_eq!(cache.entry_count(), 2);
        assert!(cache.get((0, 0), (1, 1)).is_none());
        assert_eq!(cache.get((0, 0), (3, 3)), Some(&vec![(0, 0), (3, 3)]));
    }

    #[test]
    fn test_hierarchical_abstract_path() {
        let grid = GridMap::new(16, 16, 1.0, Vec2::zero());
        let hpf = HierarchicalPathfinder::build(&grid, 4);
        let abstract_path = hpf.abstract_path((0, 0), (15, 15));
        assert!(!abstract_path.is_empty());
    }
}