Skip to main content

proof_engine/graph/
pathfinding.rs

1use glam::{Vec2, Vec3, Vec4};
2use std::collections::{BinaryHeap, HashMap, HashSet};
3use std::cmp::Ordering;
4use super::graph_core::{Graph, NodeId, EdgeId};
5
6#[derive(Debug, Clone)]
7pub struct Path {
8    pub nodes: Vec<NodeId>,
9    pub total_weight: f32,
10}
11
12impl Path {
13    pub fn is_empty(&self) -> bool {
14        self.nodes.is_empty()
15    }
16
17    pub fn len(&self) -> usize {
18        self.nodes.len()
19    }
20}
21
22#[derive(Debug, Clone, PartialEq)]
23struct DijkstraEntry {
24    node: NodeId,
25    cost: f32,
26}
27
28impl Eq for DijkstraEntry {}
29
30impl PartialOrd for DijkstraEntry {
31    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
32        Some(self.cmp(other))
33    }
34}
35
36impl Ord for DijkstraEntry {
37    fn cmp(&self, other: &Self) -> Ordering {
38        other.cost.partial_cmp(&self.cost).unwrap_or(Ordering::Equal)
39    }
40}
41
42/// Dijkstra's shortest path from start to end.
43pub fn dijkstra<N, E>(graph: &Graph<N, E>, start: NodeId, end: NodeId) -> Option<Path> {
44    if !graph.has_node(start) || !graph.has_node(end) {
45        return None;
46    }
47    if start == end {
48        return Some(Path { nodes: vec![start], total_weight: 0.0 });
49    }
50
51    let mut dist: HashMap<NodeId, f32> = HashMap::new();
52    let mut prev: HashMap<NodeId, NodeId> = HashMap::new();
53    let mut heap = BinaryHeap::new();
54
55    dist.insert(start, 0.0);
56    heap.push(DijkstraEntry { node: start, cost: 0.0 });
57
58    while let Some(DijkstraEntry { node, cost }) = heap.pop() {
59        if node == end {
60            // Reconstruct path
61            let mut path = Vec::new();
62            let mut current = end;
63            path.push(current);
64            while let Some(&p) = prev.get(&current) {
65                path.push(p);
66                current = p;
67            }
68            path.reverse();
69            return Some(Path { nodes: path, total_weight: cost });
70        }
71
72        if cost > dist.get(&node).copied().unwrap_or(f32::INFINITY) {
73            continue;
74        }
75
76        for (nbr, eid) in graph.neighbor_edges(node) {
77            let w = graph.edge_weight(eid);
78            let new_cost = cost + w;
79            if new_cost < dist.get(&nbr).copied().unwrap_or(f32::INFINITY) {
80                dist.insert(nbr, new_cost);
81                prev.insert(nbr, node);
82                heap.push(DijkstraEntry { node: nbr, cost: new_cost });
83            }
84        }
85    }
86
87    None
88}
89
90/// A* shortest path with a heuristic function.
91pub fn astar<N, E, H>(graph: &Graph<N, E>, start: NodeId, end: NodeId, heuristic: H) -> Option<Path>
92where
93    H: Fn(NodeId) -> f32,
94{
95    if !graph.has_node(start) || !graph.has_node(end) {
96        return None;
97    }
98    if start == end {
99        return Some(Path { nodes: vec![start], total_weight: 0.0 });
100    }
101
102    let mut g_score: HashMap<NodeId, f32> = HashMap::new();
103    let mut prev: HashMap<NodeId, NodeId> = HashMap::new();
104    let mut heap = BinaryHeap::new();
105    let mut closed = HashSet::new();
106
107    g_score.insert(start, 0.0);
108    heap.push(DijkstraEntry { node: start, cost: heuristic(start) });
109
110    while let Some(DijkstraEntry { node, cost: _ }) = heap.pop() {
111        if node == end {
112            let total = g_score[&end];
113            let mut path = Vec::new();
114            let mut current = end;
115            path.push(current);
116            while let Some(&p) = prev.get(&current) {
117                path.push(p);
118                current = p;
119            }
120            path.reverse();
121            return Some(Path { nodes: path, total_weight: total });
122        }
123
124        if !closed.insert(node) {
125            continue;
126        }
127
128        let current_g = g_score[&node];
129
130        for (nbr, eid) in graph.neighbor_edges(node) {
131            if closed.contains(&nbr) { continue; }
132            let w = graph.edge_weight(eid);
133            let tentative_g = current_g + w;
134            if tentative_g < g_score.get(&nbr).copied().unwrap_or(f32::INFINITY) {
135                g_score.insert(nbr, tentative_g);
136                prev.insert(nbr, node);
137                let f = tentative_g + heuristic(nbr);
138                heap.push(DijkstraEntry { node: nbr, cost: f });
139            }
140        }
141    }
142
143    None
144}
145
146/// Bellman-Ford: single-source shortest paths, handles negative weights.
147/// Returns distances from start to all reachable nodes.
148pub fn bellman_ford<N, E>(graph: &Graph<N, E>, start: NodeId) -> HashMap<NodeId, f32> {
149    let node_ids = graph.node_ids();
150    let mut dist: HashMap<NodeId, f32> = HashMap::new();
151    for &nid in &node_ids {
152        dist.insert(nid, f32::INFINITY);
153    }
154    dist.insert(start, 0.0);
155
156    let n = node_ids.len();
157
158    // Relax edges n-1 times
159    for _ in 0..(n - 1).max(1) {
160        let mut updated = false;
161        for edge in graph.edges() {
162            let du = dist.get(&edge.from).copied().unwrap_or(f32::INFINITY);
163            if du < f32::INFINITY {
164                let new_dist = du + edge.weight;
165                let dv = dist.get(&edge.to).copied().unwrap_or(f32::INFINITY);
166                if new_dist < dv {
167                    dist.insert(edge.to, new_dist);
168                    updated = true;
169                }
170            }
171            // For undirected graphs, relax in both directions
172            if graph.kind == super::graph_core::GraphKind::Undirected {
173                let dv = dist.get(&edge.to).copied().unwrap_or(f32::INFINITY);
174                if dv < f32::INFINITY {
175                    let new_dist = dv + edge.weight;
176                    let du_cur = dist.get(&edge.from).copied().unwrap_or(f32::INFINITY);
177                    if new_dist < du_cur {
178                        dist.insert(edge.from, new_dist);
179                        updated = true;
180                    }
181                }
182            }
183        }
184        if !updated { break; }
185    }
186
187    dist
188}
189
190/// Floyd-Warshall: all-pairs shortest paths.
191pub fn all_pairs_shortest<N, E>(graph: &Graph<N, E>) -> HashMap<(NodeId, NodeId), f32> {
192    let node_ids = graph.node_ids();
193    let n = node_ids.len();
194    let idx: HashMap<NodeId, usize> = node_ids.iter().enumerate().map(|(i, &nid)| (nid, i)).collect();
195
196    let mut dist = vec![vec![f32::INFINITY; n]; n];
197    for i in 0..n {
198        dist[i][i] = 0.0;
199    }
200
201    for edge in graph.edges() {
202        if let (Some(&i), Some(&j)) = (idx.get(&edge.from), idx.get(&edge.to)) {
203            dist[i][j] = dist[i][j].min(edge.weight);
204            if graph.kind == super::graph_core::GraphKind::Undirected {
205                dist[j][i] = dist[j][i].min(edge.weight);
206            }
207        }
208    }
209
210    // Floyd-Warshall relaxation
211    for k in 0..n {
212        for i in 0..n {
213            for j in 0..n {
214                let through_k = dist[i][k] + dist[k][j];
215                if through_k < dist[i][j] {
216                    dist[i][j] = through_k;
217                }
218            }
219        }
220    }
221
222    let mut result = HashMap::new();
223    for i in 0..n {
224        for j in 0..n {
225            if dist[i][j] < f32::INFINITY {
226                result.insert((node_ids[i], node_ids[j]), dist[i][j]);
227            }
228        }
229    }
230    result
231}
232
233/// Converts paths to visual glyph data: a glowing trail along edges.
234pub struct PathVisualizer {
235    pub trail_width: f32,
236    pub trail_color: Vec4,
237    pub glow_intensity: f32,
238}
239
240impl PathVisualizer {
241    pub fn new() -> Self {
242        Self {
243            trail_width: 3.0,
244            trail_color: Vec4::new(0.2, 0.8, 1.0, 1.0),
245            glow_intensity: 1.5,
246        }
247    }
248
249    pub fn with_color(mut self, color: Vec4) -> Self {
250        self.trail_color = color;
251        self
252    }
253
254    pub fn with_width(mut self, width: f32) -> Self {
255        self.trail_width = width;
256        self
257    }
258
259    /// Generate trail segments from a path and node positions.
260    /// Returns Vec of (start_pos, end_pos, color, width) for each edge in the path.
261    pub fn generate_trail<N, E>(
262        &self,
263        path: &Path,
264        graph: &Graph<N, E>,
265    ) -> Vec<TrailSegment> {
266        let mut segments = Vec::new();
267        if path.nodes.len() < 2 { return segments; }
268
269        let total = path.nodes.len() - 1;
270        for i in 0..total {
271            let from = path.nodes[i];
272            let to = path.nodes[i + 1];
273            let p0 = graph.node_position(from);
274            let p1 = graph.node_position(to);
275            let progress = i as f32 / total as f32;
276            // Glow fades along the trail
277            let alpha = self.trail_color.w * (1.0 - progress * 0.5);
278            let color = Vec4::new(
279                self.trail_color.x * self.glow_intensity,
280                self.trail_color.y * self.glow_intensity,
281                self.trail_color.z * self.glow_intensity,
282                alpha,
283            );
284            segments.push(TrailSegment {
285                start: p0,
286                end: p1,
287                color,
288                width: self.trail_width * (1.0 - progress * 0.3),
289            });
290        }
291        segments
292    }
293}
294
295#[derive(Debug, Clone)]
296pub struct TrailSegment {
297    pub start: Vec2,
298    pub end: Vec2,
299    pub color: Vec4,
300    pub width: f32,
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::graph_core::{Graph, GraphKind, NodeId};

    /// Shared undirected fixture with four nodes `a, b, c, d` and the edges
    /// a-b (1), b-c (2), a-c (10), c-d (1).
    fn make_weighted_graph() -> Graph<(), ()> {
        let mut graph = Graph::new(GraphKind::Undirected);
        let a = graph.add_node(());
        let b = graph.add_node(());
        let c = graph.add_node(());
        let d = graph.add_node(());
        graph.add_edge_weighted(a, b, (), 1.0);
        graph.add_edge_weighted(b, c, (), 2.0);
        graph.add_edge_weighted(a, c, (), 10.0);
        graph.add_edge_weighted(c, d, (), 1.0);
        graph
    }

    #[test]
    fn test_dijkstra_shortest_path() {
        let graph = make_weighted_graph();
        let ids = graph.node_ids();
        let (source, target) = (ids[0], ids[3]);
        let route = dijkstra(&graph, source, target).unwrap();
        // Cheapest route: a->b (1) + b->c (2) + c->d (1).
        assert_eq!(route.total_weight, 4.0);
        assert_eq!(route.nodes.len(), 4);
    }

    #[test]
    fn test_dijkstra_same_node() {
        let graph = make_weighted_graph();
        let ids = graph.node_ids();
        let origin = ids[0];
        let route = dijkstra(&graph, origin, origin).unwrap();
        assert_eq!(route.total_weight, 0.0);
        assert_eq!(route.nodes.len(), 1);
    }

    #[test]
    fn test_dijkstra_no_path() {
        let mut graph = Graph::new(GraphKind::Directed);
        let a = graph.add_node(());
        let b = graph.add_node(());
        // The directed graph has no edge at all, so b is unreachable from a.
        let outcome = dijkstra(&graph, a, b);
        assert!(outcome.is_none());
    }

    #[test]
    fn test_astar() {
        let mut graph = Graph::new(GraphKind::Undirected);
        let a = graph.add_node_with_pos((), Vec2::new(0.0, 0.0));
        let b = graph.add_node_with_pos((), Vec2::new(1.0, 0.0));
        let c = graph.add_node_with_pos((), Vec2::new(2.0, 0.0));
        graph.add_edge_weighted(a, b, (), 1.0);
        graph.add_edge_weighted(b, c, (), 1.0);
        graph.add_edge_weighted(a, c, (), 5.0);

        // Snapshot of node positions for the straight-line heuristic.
        let mut layout: HashMap<NodeId, Vec2> = HashMap::new();
        for &id in graph.node_ids().iter() {
            layout.insert(id, graph.node_position(id));
        }
        let goal = layout[&c];

        let straight_line = |id: NodeId| {
            let here = layout.get(&id).copied().unwrap_or(Vec2::ZERO);
            (here - goal).length()
        };
        let route = astar(&graph, a, c, straight_line).unwrap();

        assert_eq!(route.total_weight, 2.0);
        assert_eq!(route.nodes, vec![a, b, c]);
    }

    #[test]
    fn test_bellman_ford() {
        let graph = make_weighted_graph();
        let ids = graph.node_ids();
        let distances = bellman_ford(&graph, ids[0]);
        assert_eq!(*distances.get(&ids[0]).unwrap(), 0.0);
        assert_eq!(*distances.get(&ids[1]).unwrap(), 1.0);
        assert_eq!(*distances.get(&ids[2]).unwrap(), 3.0);
        assert_eq!(*distances.get(&ids[3]).unwrap(), 4.0);
    }

    #[test]
    fn test_floyd_warshall() {
        let graph = make_weighted_graph();
        let ids = graph.node_ids();
        let table = all_pairs_shortest(&graph);
        let (first, last) = (ids[0], ids[3]);
        // Undirected graph: both directions cost the same.
        assert_eq!(*table.get(&(first, last)).unwrap(), 4.0);
        assert_eq!(*table.get(&(last, first)).unwrap(), 4.0);
        assert_eq!(*table.get(&(first, first)).unwrap(), 0.0);
    }

    #[test]
    fn test_path_visualizer() {
        let mut graph = Graph::new(GraphKind::Undirected);
        let a = graph.add_node_with_pos((), Vec2::new(0.0, 0.0));
        let b = graph.add_node_with_pos((), Vec2::new(10.0, 0.0));
        let c = graph.add_node_with_pos((), Vec2::new(20.0, 0.0));
        graph.add_edge(a, b, ());
        graph.add_edge(b, c, ());

        let route = Path { nodes: vec![a, b, c], total_weight: 2.0 };
        let visualizer = PathVisualizer::new();
        let trail = visualizer.generate_trail(&route, &graph);
        // Three nodes give two segments; the first runs from a to b.
        assert_eq!(trail.len(), 2);
        assert_eq!(trail[0].start, Vec2::new(0.0, 0.0));
        assert_eq!(trail[0].end, Vec2::new(10.0, 0.0));
    }

    #[test]
    fn test_dijkstra_directed() {
        let mut graph = Graph::new(GraphKind::Directed);
        let a = graph.add_node(());
        let b = graph.add_node(());
        let c = graph.add_node(());
        graph.add_edge_weighted(a, b, (), 1.0);
        graph.add_edge_weighted(b, c, (), 1.0);
        graph.add_edge_weighted(a, c, (), 5.0);
        // Two unit hops beat the direct edge of weight 5.
        let route = dijkstra(&graph, a, c).unwrap();
        assert_eq!(route.total_weight, 2.0);
    }

    #[test]
    fn test_bellman_ford_directed() {
        let mut graph = Graph::new(GraphKind::Directed);
        let a = graph.add_node(());
        let b = graph.add_node(());
        let c = graph.add_node(());
        graph.add_edge_weighted(a, b, (), 3.0);
        graph.add_edge_weighted(b, c, (), 4.0);
        let distances = bellman_ford(&graph, a);
        assert_eq!(*distances.get(&a).unwrap(), 0.0);
        assert_eq!(*distances.get(&b).unwrap(), 3.0);
        assert_eq!(*distances.get(&c).unwrap(), 7.0);
    }
}