rshyper_algo/
astar.rs

1/*
2    Appellation: impl_astar <module>
3    Contrib: @FL03
4*/
5//! this module implements the A* search algorithm
6#[doc(inline)]
7pub use self::priority_node::PriorityNode;
8
9mod priority_node;
10
11use crate::error::{Error, Result};
12use crate::traits::{Heuristic, PathFinder, Search, Traversal};
13use alloc::collections::BinaryHeap;
14use core::hash::{BuildHasher, Hash};
15use hashbrown::{DefaultHashBuilder, HashMap, HashSet};
16use rshyper::idx::{HyperIndex, RawIndex, VertexId, VertexSet};
17use rshyper::rel::RawLayout;
18use rshyper::{GraphProps, GraphType, HyperGraph, HyperGraphIter};
19
20type SourceMap<Ix, S = DefaultHashBuilder> = HashMap<VertexId<Ix>, VertexId<Ix>, S>;
21
22type ScoreMap<K, V, S = DefaultHashBuilder> = HashMap<VertexId<K>, V, S>;
23
24/// An A* Search algorithm implementation for hypergraphs
25pub struct AStarSearch<'a, N, E, A, F, H, S = DefaultHashBuilder>
26where
27    A: GraphProps,
28    H: HyperGraph<N, E, A>,
29    F: Heuristic<A::Ix>,
30{
31    pub(crate) graph: &'a H,
32    pub(crate) open_set: VertexSet<A::Ix, S>,
33    pub(crate) closed_set: VertexSet<A::Ix, S>,
34    pub(crate) came_from: SourceMap<A::Ix, S>,
35    pub(crate) g_score: ScoreMap<A::Ix, F::Output, S>,
36    pub(crate) f_score: ScoreMap<A::Ix, F::Output, S>,
37    pub(crate) heuristic: F,
38    _marker: core::marker::PhantomData<(N, E)>,
39}
40
41impl<'a, N, E, A, F, H, S, K, Idx> AStarSearch<'a, N, E, A, F, H, S>
42where
43    A: GraphProps<Ix = Idx, Kind = K>,
44    H: HyperGraph<N, E, A>,
45    F: Heuristic<Idx>,
46    S: BuildHasher,
47    K: GraphType,
48    Idx: RawIndex,
49{
50    /// Create a new A* search instance with the given heuristic function
51    pub fn new(graph: &'a H, heuristic: F) -> Self
52    where
53        S: Default,
54    {
55        Self {
56            heuristic,
57            graph,
58            open_set: VertexSet::default(),
59            closed_set: VertexSet::default(),
60            came_from: SourceMap::default(),
61            g_score: ScoreMap::default(),
62            f_score: ScoreMap::default(),
63            _marker: core::marker::PhantomData::<(N, E)>,
64        }
65    }
66    /// consumes the current instance to create another from the given heuristic function;
67    /// **note:** while the functions may be different, the output type of both must match.
68    pub fn with_heuristic<G>(self, heuristic: G) -> AStarSearch<'a, N, E, A, G, H, S>
69    where
70        G: Heuristic<Idx, Output = F::Output>,
71    {
72        AStarSearch {
73            heuristic,
74            graph: self.graph,
75            open_set: self.open_set,
76            closed_set: self.closed_set,
77            came_from: self.came_from,
78            g_score: self.g_score,
79            f_score: self.f_score,
80            _marker: self._marker,
81        }
82    }
83
84    pub const fn came_from(&self) -> &SourceMap<A::Ix, S> {
85        &self.came_from
86    }
87    /// returns a mutable reference to the map of vertices that have been processed
88    pub const fn came_from_mut(&mut self) -> &mut SourceMap<A::Ix, S> {
89        &mut self.came_from
90    }
91    /// returns an immutable reference to the closed set of vertices
92    pub const fn closed_set(&self) -> &VertexSet<A::Ix, S> {
93        &self.closed_set
94    }
95    /// returns a mutable reference to the closed set of vertices
96    pub const fn closed_set_mut(&mut self) -> &mut VertexSet<A::Ix, S> {
97        &mut self.closed_set
98    }
99    /// returns an immutable reference to the f_score map
100    pub const fn f_score(&self) -> &ScoreMap<A::Ix, F::Output, S> {
101        &self.f_score
102    }
103    /// returns a mutable reference to the f_score map
104    pub const fn f_score_mut(&mut self) -> &mut ScoreMap<A::Ix, F::Output, S> {
105        &mut self.f_score
106    }
107    /// returns an immutable reference to the g_score map
108    pub const fn g_score(&self) -> &ScoreMap<A::Ix, F::Output, S> {
109        &self.g_score
110    }
111    /// returns a mutable reference to the g_score map
112    pub const fn g_score_mut(&mut self) -> &mut ScoreMap<A::Ix, F::Output, S> {
113        &mut self.g_score
114    }
115    /// returns an immutable reference to the heuristic function of the algorithm
116    pub const fn heuristic(&self) -> &F {
117        &self.heuristic
118    }
119    /// returns an immutable reference to the set of vertices that have been visited
120    pub const fn open_set(&self) -> &VertexSet<A::Ix, S> {
121        &self.open_set
122    }
123    /// returns amutable reference to the open set of vertices
124    pub const fn open_set_mut(&mut self) -> &mut VertexSet<A::Ix, S> {
125        &mut self.open_set
126    }
127    /// returns true if the given vertex has a f_score
128    pub fn has_f_score<Q>(&self, vertex: &Q) -> bool
129    where
130        Q: ?Sized + Eq + Hash,
131        Idx: Eq + Hash,
132        VertexId<Idx>: core::borrow::Borrow<Q>,
133    {
134        self.f_score().contains_key(vertex)
135    }
136    /// returns true if the given vertex has a g_score
137    pub fn has_g_score<Q>(&self, vertex: &Q) -> bool
138    where
139        Q: ?Sized + Eq + Hash,
140        Idx: Eq + Hash,
141        VertexId<Idx>: core::borrow::Borrow<Q>,
142    {
143        self.g_score().contains_key(vertex)
144    }
145    /// returns true if the given vertex has been visited
146    pub fn has_visited<Q>(&self, vertex: &Q) -> bool
147    where
148        Q: ?Sized + Eq + Hash,
149        Idx: Eq + Hash,
150        VertexId<Idx>: core::borrow::Borrow<Q>,
151    {
152        self.closed_set().contains(vertex)
153    }
154    /// returns true if the given vertex is in the open set
155    pub fn in_open_set<Q>(&self, vertex: &Q) -> bool
156    where
157        Q: ?Sized + Eq + Hash,
158        Idx: Eq + Hash,
159        VertexId<Idx>: core::borrow::Borrow<Q>,
160    {
161        self.open_set().contains(vertex)
162    }
163    /// moves the vertex from the open set before inserting it into the closed set; this is
164    /// useful for updating the state, marking a node as processed.
165    pub fn move_open_to_closed(&mut self, vertex: &VertexId<Idx>)
166    where
167        Idx: Copy + Eq + Hash,
168    {
169        self.open_set_mut().remove(vertex);
170        self.closed_set_mut().insert(*vertex);
171    }
172    /// reset the state
173    pub fn reset(&mut self) -> &mut Self {
174        self.open_set_mut().clear();
175        self.closed_set_mut().clear();
176        self.came_from_mut().clear();
177        self.g_score_mut().clear();
178        self.f_score_mut().clear();
179        self
180    }
181    /// find a path between two nodes
182    pub fn find_path(
183        &mut self,
184        start: VertexId<Idx>,
185        goal: VertexId<Idx>,
186    ) -> Result<<Self as PathFinder<Idx>>::Path>
187    where
188        Self: PathFinder<Idx>,
189    {
190        PathFinder::find_path(self, start, goal)
191    }
192    /// a convience method to perform a search
193    pub fn search(
194        &mut self,
195        start: VertexId<Idx>,
196    ) -> Result<<Self as Search<VertexId<Idx>>>::Output>
197    where
198        Self: Search<VertexId<Idx>>,
199    {
200        Search::search(self, start)
201    }
202}
203
204impl<'a, N, E, F, A, H, S> PathFinder<A::Ix> for AStarSearch<'a, N, E, A, F, H, S>
205where
206    A: GraphProps,
207    H: HyperGraph<N, E, A>,
208    F: Heuristic<A::Ix, Output = f64>,
209    S: BuildHasher,
210    A::Ix: HyperIndex,
211    for<'b> &'b <H::Edge<E> as RawLayout>::Store: IntoIterator<Item = &'b VertexId<A::Ix>>,
212{
213    type Path = Vec<VertexId<A::Ix>>;
214    /// Find the shortest path between start and goal vertices
215    fn find_path(&mut self, start: VertexId<A::Ix>, goal: VertexId<A::Ix>) -> Result<Self::Path> {
216        // Check if both vertices exist
217        if !self.graph.contains_node(&start) {
218            return Err(rshyper::Error::NodeNotFound.into());
219        }
220        if !self.graph.contains_node(&goal) {
221            return Err(rshyper::Error::NodeNotFound.into());
222        }
223
224        // reset state
225        self.reset();
226        // initialize g_score for start node (0) and infinity for all other nodes
227        self.g_score_mut().insert(start, 0.0);
228
229        // initialize f_score for start node (heuristic only since g=0)
230        let initial_fscore = self.heuristic().compute(start, goal);
231        self.f_score_mut().insert(start, initial_fscore);
232        // add start node to the open set
233        self.open_set_mut().insert(start);
234        // initialize priority queue
235        let mut priority_queue = BinaryHeap::new();
236        // push the start node with its f_score
237        priority_queue.push(PriorityNode {
238            vertex: start,
239            priority: -(initial_fscore as i64),
240        });
241        // track processed nodes to avoid duplicate processing
242        let mut processed = HashSet::new();
243        // process nodes until the queue is empty or we attain the goal
244        while let Some(PriorityNode {
245            vertex: current, ..
246        }) = priority_queue.pop()
247        {
248            // Skip if we've already processed this vertex with a better path
249            // or it's no longer in the open set
250            if processed.contains(&current) || !self.in_open_set(&current) {
251                continue;
252            }
253            // add the current vertex to the processed set
254            processed.insert(current);
255
256            // If we've reached the goal, construct and return the path
257            if current == goal {
258                return Ok(self.reconstruct_path(goal));
259            }
260
261            // Move from open to closed set
262            self.move_open_to_closed(&current);
263
264            // Get all hyperedges containing the current vertex
265            self.graph
266                .find_edges_with_node(&current)
267                .for_each(|edge_id| {
268                    // Get all vertices in this hyperedge
269                    let vertices = self
270                        .graph
271                        .get_edge_domain(edge_id)
272                        .expect("Failed to get edge vertices");
273
274                    // Process each vertex in this hyperedge
275                    for &neighbor in vertices {
276                        // Skip if this is the current vertex or already evaluated
277                        if neighbor == current || self.has_visited(&neighbor) {
278                            continue;
279                        }
280
281                        // Cost to reach neighbor through current vertex
282                        let tentative_g_score = self.g_score[&current] + 1.0;
283
284                        // Check if this path is better than any previous path
285                        let is_better_path = !self.has_g_score(&neighbor)
286                            || tentative_g_score < self.g_score[&neighbor];
287
288                        if is_better_path {
289                            // Update path info
290                            self.came_from_mut().insert(neighbor, current);
291                            self.g_score_mut().insert(neighbor, tentative_g_score);
292
293                            // Update f_score (g_score + heuristic)
294                            let f_score =
295                                tentative_g_score + self.heuristic().compute(neighbor, goal);
296                            self.f_score_mut().insert(neighbor, f_score);
297
298                            // Add to open set if not already there
299                            if !self.in_open_set(&neighbor) {
300                                self.open_set_mut().insert(neighbor);
301                            }
302
303                            // push the neighbor into the priority queue with its f_score (negative for min-heap behavior)
304                            priority_queue.push(PriorityNode {
305                                vertex: neighbor,
306                                priority: -(f_score as i64),
307                            });
308                        }
309                    }
310                });
311        }
312
313        // No path found
314        Err(Error::PathNotFound)
315    }
316
317    // Reconstruct path from came_from map
318    fn reconstruct_path(&self, goal: VertexId<A::Ix>) -> Self::Path {
319        let mut path = vec![goal];
320        let mut current = goal;
321
322        while let Some(&prev) = self.came_from.get(&current) {
323            path.push(prev);
324            current = prev;
325        }
326
327        path.reverse();
328        path
329    }
330}
331
332impl<'a, N, E, F, A, H, S> Traversal<VertexId<A::Ix>> for AStarSearch<'a, N, E, A, F, H, S>
333where
334    A: GraphProps,
335    F: Heuristic<A::Ix, Output = f64>,
336    H: HyperGraph<N, E, A>,
337    S: BuildHasher,
338    A::Ix: Eq + Hash,
339{
340    type Store<U> = HashSet<U, S>;
341
342    fn has_visited(&self, vertex: &VertexId<A::Ix>) -> bool {
343        self.visited().contains(vertex)
344    }
345
346    fn visited(&self) -> &Self::Store<VertexId<A::Ix>> {
347        self.closed_set()
348    }
349}
350
351impl<'a, N, E, F, A, H, S> Search<VertexId<A::Ix>> for AStarSearch<'a, N, E, A, F, H, S>
352where
353    A: GraphProps,
354    F: Heuristic<A::Ix, Output = f64>,
355    H: HyperGraphIter<N, E, A>,
356    S: BuildHasher,
357    A::Ix: HyperIndex,
358    for<'b> &'b <H::Edge<E> as RawLayout>::Store: IntoIterator<Item = &'b VertexId<A::Ix>>,
359{
360    type Output = Vec<VertexId<A::Ix>>;
361
362    fn search(&mut self, start: VertexId<A::Ix>) -> Result<Self::Output> {
363        // For A*, we need a goal vertex to compute the heuristic
364        // This implementation of search will explore the graph and return
365        // all reachable vertices ordered by their distance from start
366        self.reset();
367
368        if !self.graph.contains_node(&start) {
369            return Err(rshyper::Error::NodeNotFound.into());
370        }
371
372        // Using the vertex with the largest ID as a pseudo-goal
373        // This is a hack to make A* behave more like a general search
374        let max_vertex_id = match self.graph.vertices().max() {
375            Some(&id) => id,
376            None => return Ok(vec![]),
377        };
378
379        self.find_path(start, max_vertex_id)
380    }
381}