rs_graph/search/
astar.rs

1/*
2 * Copyright (c) 2018, 2021 Frank Fischer <frank-fischer@shadow-soft.de>
3 *
4 * This program is free software: you can redistribute it and/or
5 * modify it under the terms of the GNU General Public License as
6 * published by the Free Software Foundation, either version 3 of the
7 * License, or (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful, but
10 * WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12 * General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program.  If not, see  <http://www.gnu.org/licenses/>
16 */
17
18//! A* search.
19//!
//! This module implements an A*-search for finding shortest paths
21//! from some node to all other nodes. Each node may be assigned a
22//! potential (or "heuristic value") estimating the distance to the target
23//! node. The potential $h\colon V \to \mathbb{R}$ must satisfy
24//! \\[ w(u,v) - h(u) + h(v) \ge 0, (u,v) \in E \\]
25//! where $w\colon E \to \mathbb{R}$ are the weights (or lengths) of the edges.
26//! (The relation must hold for both directions in case the graph is
27//! undirected).
28//!
29//! If $s \in V$ is the start node and $t$ some destination node, then
30//! $h(u) - h(t)$ is a lower bound on the distance from $u$ to $t$ for all nodes $u \in V$.
//! Hence, if the shortest path to some specific destination node $t$ should be
//! found, the canonical choice for $h$ is such that $h(t) = 0$ and $h(u)$ is a
//! lower bound on the distance from $u$ to $t$.
34//!
35//! # Example
36//!
37//! ```
38//! use rs_graph::traits::*;
39//! use rs_graph::search::astar;
40//! use rs_graph::string::{from_ascii, Data};
41//! use rs_graph::LinkedListGraph;
42//!
43//! let Data {
44//!     graph: g,
45//!     weights,
46//!     nodes,
47//! } = from_ascii::<LinkedListGraph>(
48//!     r"
49//!     *--1--*--1--*--1--*--1--*--1--*--1--*--1--*
50//!     |     |     |     |     |     |     |     |
51//!     1     1     1     1     1     1     1     1
52//!     |     |     |     |     |     |     |     |
53//!     *--1--*--2--*--1--*--2--e--1--f--1--t--1--*
54//!     |     |     |     |     |     |     |     |
55//!     1     1     1     1     1     2     1     1
56//!     |     |     |     |     |     |     |     |
57//!     *--1--*--1--*--2--c--1--d--1--*--2--*--1--*
58//!     |     |     |     |     |     |     |     |
59//!     1     1     1     1     1     1     1     1
60//!     |     |     |     |     |     |     |     |
61//!     *--1--s--1--a--1--b--2--*--1--*--1--*--1--*
62//!     |     |     |     |     |     |     |     |
63//!     1     1     1     1     1     1     1     1
64//!     |     |     |     |     |     |     |     |
65//!     *--1--*--1--*--1--*--1--*--1--*--1--*--1--*
66//!     ",
67//! )
68//! .unwrap();
69//!
70//! let s = g.id2node(nodes[&'s']);
71//! let t = g.id2node(nodes[&'t']);
72//!
73//! // nodes are numbered row-wise -> get node coordinates
74//! let coords = |u| ((g.node_id(u) % 8) as isize, (g.node_id(u) / 8) as isize);
75//!
76//! let (xs, ys) = coords(s);
77//! let (xt, yt) = coords(t);
78//!
//! // manhattan distance heuristic
80//! let manh_heur = |u| {
81//!     let (x, y) = coords(u);
82//!     ((x - xt).abs() + (y - yt).abs()) as usize
83//! };
84//!
85//! // verify that we do not go in the "wrong" direction
86//! for (v, _, _) in astar::start(g.neighbors(), s, |e| weights[e.index()], manh_heur) {
87//!     let (x, y) = coords(v);
88//!     assert!(x >= xs && x <= xt && y >= yt && y <= ys);
89//!     if v == t {
90//!         break;
91//!     }
92//! }
93//!
94//! // obtain the shortest path directly
95//! let (path, dist) = astar::find_undirected_path(&g, s, t, |e| weights[e.index()], manh_heur).unwrap();
96//!
97//! assert_eq!(dist, 7);
98//!
99//! let mut pathnodes = vec![s];
100//! for e in path {
101//!     let uv = g.enodes(e);
102//!     if uv.0 == *pathnodes.last().unwrap() {
103//!         pathnodes.push(uv.1);
104//!     } else {
105//!         pathnodes.push(uv.0);
106//!     }
107//! }
108//! assert_eq!(pathnodes, "sabcdeft".chars().map(|c| g.id2node(nodes[&c])).collect::<Vec<_>>());
109//! ```
110
111use crate::adjacencies::{Adjacencies, Neighbors, OutEdges};
112use crate::collections::BinHeap;
113use crate::collections::{ItemMap, ItemPriQueue};
114use crate::search::path_from_incomings;
115use crate::traits::{Digraph, Graph};
116
117use num_traits::Zero;
118
119use std::cmp::Ordering;
120use std::collections::HashMap;
121use std::hash::Hash;
122use std::marker::PhantomData;
123use std::ops::{Add, Sub};
124
125/// A* search iterator.
126pub struct AStar<'a, A, D, W, M, P, H, Accum>
127where
128    A: Adjacencies<'a>,
129    M: ItemMap<A::Node, Option<P::Item>>,
130    P: ItemPriQueue<A::Node, Data<A::Edge, D, H::Result>>,
131    D: Copy,
132    W: Fn(A::Edge) -> D,
133    H: AStarHeuristic<A::Node>,
134    H::Result: Copy,
135    Accum: Accumulator<D>,
136{
137    adj: A,
138    nodes: M,
139    pqueue: P,
140    weights: W,
141    heur: H,
142    phantom: PhantomData<&'a (D, Accum)>,
143}
144
/// Information kept for a node in the priority queue during the search.
///
/// Entries are ordered by `lower + distance`, i.e. by the length of the
/// currently best known path plus the heuristic value of the node.
#[derive(Clone)]
pub struct Data<E, D, H> {
    /// incoming edge on currently best path
    pub incoming_edge: E,
    /// currently best known distance
    pub distance: D,
    /// the heuristic value (lower bound) of this node
    lower: H,
}

impl<E, D, H> Data<E, D, H> {
    /// Priority of the entry: heuristic value plus current distance.
    fn priority(&self) -> D
    where
        D: Clone,
        H: Add<D, Output = D> + Clone,
    {
        self.lower.clone() + self.distance.clone()
    }
}

/// Entries are equal if and only if their distances are equal.
///
/// NOTE(review): this ignores `lower` (and the edge), whereas `partial_cmp`
/// orders by `lower + distance`; two entries can be `==` while `partial_cmp`
/// does not return `Some(Ordering::Equal)`. Making both consistent requires
/// stricter bounds on this impl.
impl<E, D, H> PartialEq for Data<E, D, H>
where
    D: PartialEq,
{
    fn eq(&self, other: &Self) -> bool {
        self.distance == other.distance
    }
}

/// Entries are ordered by heuristic value plus distance.
impl<E, D, H> PartialOrd for Data<E, D, H>
where
    D: PartialOrd + Clone,
    H: Add<D, Output = D> + Clone,
{
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let mine = self.priority();
        let theirs = other.priority();
        mine.partial_cmp(&theirs)
    }
}
174
/// A heuristic providing a node potential.
///
/// The potential $h$ must satisfy $w(u,v) - h(u) + h(v) \ge 0$ for every edge
/// $(u,v) \in E$. Consequently, $h(u) - h(t)$ is a lower bound on the distance
/// from $u$ to the destination node $t$; usually one chooses $h(t) = 0$.
///
/// Every function or closure `Fn(N) -> H` with `H: Copy + Default` is a
/// heuristic.
pub trait AStarHeuristic<N> {
    /// The type of the heuristic values.
    type Result: Copy + Default;

    /// Return the heuristic value (potential) of node `u`.
    fn call(&self, u: N) -> Self::Result;
}

impl<F, N, H> AStarHeuristic<N> for F
where
    F: Fn(N) -> H,
    H: Copy + Default,
{
    type Result = H;

    fn call(&self, u: N) -> H {
        self(u)
    }
}
198
/// A binary operation combining the distance of a node with an edge weight.
///
/// For a shortest-path search this is the sum ([`SumAccumulator`]); Prim's
/// algorithm instead uses just the edge weight and ignores the "distance".
pub trait Accumulator<T> {
    /// Combine the distance `dist` of a node with the `weight` of an edge
    /// leaving it.
    fn accum(dist: T, weight: T) -> T;
}

/// Accumulator returning the sum of distance and edge weight.
pub struct SumAccumulator;

impl<T> Accumulator<T> for SumAccumulator
where
    T: Add<Output = T>,
{
    fn accum(dist: T, weight: T) -> T {
        Add::add(dist, weight)
    }
}
218
219/// Default map type to be used in an A* search.
220///
221/// - `A` is the graph type information
222/// - `D` is the type of distance values
223/// - `H` is the type of heuristic values
224pub type DefaultMap<'a, A, D, H> = HashMap<
225    <A as Adjacencies<'a>>::Node,
226    Option<
227        <BinHeap<<A as Adjacencies<'a>>::Node, Data<<A as Adjacencies<'a>>::Edge, D, H>> as ItemPriQueue<
228            <A as Adjacencies<'a>>::Node,
229            Data<<A as Adjacencies<'a>>::Edge, D, H>,
230        >>::Item,
231    >,
232>;
233
234/// Default priority queue type to be used in an A* search.
235///
236/// - `A` is the graph type information
237/// - `D` is the type of distance values
238/// - `H` is the type of heuristic values
239pub type DefaultPriQueue<'a, A, D, H> = BinHeap<<A as Adjacencies<'a>>::Node, Data<<A as Adjacencies<'a>>::Edge, D, H>>;
240
241/// The default data structures to be used in an A* search.
242pub type DefaultData<ID, N, I, D> = (HashMap<N, I>, BinHeap<N, D, ID>);
243
244/// The A*-iterator with default types.
245pub type AStarDefault<'a, A, D, W, H> = AStar<
246    'a,
247    A,
248    D,
249    W,
250    DefaultMap<'a, A, D, <H as AStarHeuristic<<A as Adjacencies<'a>>::Node>>::Result>,
251    DefaultPriQueue<'a, A, D, <H as AStarHeuristic<<A as Adjacencies<'a>>::Node>>::Result>,
252    H,
253    SumAccumulator,
254>;
255
256/// Start and return an A*-iterator using default data structures.
257///
258/// This is a convenience wrapper around [`start_with_data`] using the default
259/// data structures [`DefaultData`].
260///
261/// # Parameter
262/// - `adj`: adjacency information for the graph
263/// - `src`: the source node at which the search should start.
264/// - `weights`: the weight function for each edge
265/// - `heur`: the heuristic used in the search
266pub fn start<'a, A, D, W, H>(adj: A, src: A::Node, weights: W, heur: H) -> AStarDefault<'a, A, D, W, H>
267where
268    A: Adjacencies<'a>,
269    A::Node: Hash,
270    D: Copy + PartialOrd + Zero,
271    W: Fn(A::Edge) -> D,
272    H: AStarHeuristic<A::Node>,
273    H::Result: Add<D, Output = D>,
274{
275    start_with_data(adj, src, weights, heur, DefaultData::default())
276}
277
278/// Start and return an A*-iterator with custom data structures.
279///
280/// The returned iterator traverses the edges in the order of an A*-search. The
281/// iterator returns the next node, its incoming edge and the distance to the
282/// start node.
283///
284/// The heuristic is a assigning a potential to each node. The potential of all
285/// nodes must be so that $w(u,v) - h(u) + h(v) \ge 0$ for all edges $(u,v) \in
286/// E$. If $t$ is the destination node of the path then $h(u) - h(t)$ is a lower
287/// bound on the distance from $u$ to $t$ for each node $u \in V$ (in this case
288/// one usually chooses $h(t) = 0$). The value returned by the heuristic must be
289/// compatible with the distance type, i.e., is must be possible to compute the
290/// sum of both.
291///
292/// Note that the start node is *not* returned by the iterator.
293///
294/// The algorithm requires a pair `(M, P)` with `M` implementing [`ItemMap<Node,
295/// Item>`][crate::collections::ItemMap], and `P` implementing
296/// [`ItemPriQueue<Node, D>`][crate::collections::ItemStack] as internal data
297/// structures. The map is used to store information about the last edge on a
298/// shortest path for each reachable node. The priority queue is used the handle
299/// the nodes in the correct order. The data structures can be reused for
300/// multiple searches.
301///
302/// This function uses the default data structures [`DefaultData`].
303///
304/// # Parameter
305/// - `adj`: adjacency information for the graph
306/// - `src`: the source node at which the search should start.
307/// - `weights`: the weight function for each edge
308/// - `heur`: the heuristic used in the search
309/// - `data`: the custom data structures
310pub fn start_with_data<'a, A, D, W, H, M, P>(
311    adj: A,
312    src: A::Node,
313    weights: W,
314    heur: H,
315    data: (M, P),
316) -> AStar<'a, A, D, W, M, P, H, SumAccumulator>
317where
318    A: Adjacencies<'a>,
319    D: Copy + PartialOrd + Zero,
320    W: Fn(A::Edge) -> D,
321    H: AStarHeuristic<A::Node>,
322    H::Result: Add<D, Output = D>,
323    M: ItemMap<A::Node, Option<P::Item>>,
324    P: ItemPriQueue<A::Node, Data<A::Edge, D, H::Result>>,
325{
326    start_generic(adj, src, weights, heur, data)
327}
328
329/// Start and return an A*-iterator with a custom accumulator and custom data structures.
330///
331/// This function differs from [`start_with_data`] in the additional type
332/// parameter `Accum`. The type parameter is the accumulation function for
333/// combining the length to the previous node with the weight of the current
334/// edge. It is usually just the sum ([`SumAccumulator`]). One possible use is
335/// the Prim's algorithm for the minimum spanning tree problem (see
336/// [`mst::prim`](crate::mst::prim())).
337pub fn start_generic<'a, A, D, W, H, Accum, M, P>(
338    adj: A,
339    src: A::Node,
340    weights: W,
341    heur: H,
342    data: (M, P),
343) -> AStar<'a, A, D, W, M, P, H, Accum>
344where
345    A: Adjacencies<'a>,
346    D: Copy + PartialOrd + PartialOrd + Zero,
347    W: Fn(A::Edge) -> D,
348    H: AStarHeuristic<A::Node>,
349    H::Result: Add<D, Output = D>,
350    M: ItemMap<A::Node, Option<P::Item>>,
351    P: ItemPriQueue<A::Node, Data<A::Edge, D, H::Result>>,
352    Accum: Accumulator<D>,
353{
354    let (mut nodes, mut pqueue) = data;
355    pqueue.clear();
356    nodes.clear();
357    nodes.insert(src, None);
358
359    // insert neighbors of source
360    for (e, v) in adj.neighs(src) {
361        let dist = Accum::accum(D::zero(), (weights)(e));
362        match nodes.get_mut(v) {
363            Some(Some(vitem)) => {
364                // node is known but unhandled
365                let (olddist, lower) = {
366                    let data = pqueue.value(vitem);
367                    (data.distance, data.lower)
368                };
369                if dist < olddist {
370                    pqueue.decrease_key(
371                        vitem,
372                        Data {
373                            incoming_edge: e,
374                            distance: dist,
375                            lower,
376                        },
377                    );
378                }
379            }
380            None => {
381                // node is unknown
382                let item = pqueue.push(
383                    v,
384                    Data {
385                        incoming_edge: e,
386                        distance: dist,
387                        lower: heur.call(v),
388                    },
389                );
390                nodes.insert(v, Some(item));
391            }
392            _ => (), // node has been handled
393        };
394    }
395
396    AStar {
397        adj,
398        nodes,
399        pqueue,
400        weights,
401        heur,
402        phantom: PhantomData,
403    }
404}
405
406impl<'a, A, D, W, M, P, H, Accum> Iterator for AStar<'a, A, D, W, M, P, H, Accum>
407where
408    A: Adjacencies<'a>,
409    D: Copy + PartialOrd + Add<D, Output = D> + Sub<D, Output = D>,
410    W: Fn(A::Edge) -> D,
411    M: ItemMap<A::Node, Option<P::Item>>,
412    P: ItemPriQueue<A::Node, Data<A::Edge, D, H::Result>>,
413    H: AStarHeuristic<A::Node>,
414    H::Result: Add<D, Output = D>,
415    Accum: Accumulator<D>,
416{
417    type Item = (A::Node, A::Edge, D);
418
419    fn next(&mut self) -> Option<Self::Item> {
420        if let Some((u, data)) = self.pqueue.pop_min() {
421            // node is not in the queue anymore, forget its item
422            self.nodes.insert_or_replace(u, None);
423            let (d, incoming_edge) = (data.distance, data.incoming_edge);
424            for (e, v) in self.adj.neighs(u) {
425                let dist = Accum::accum(d, (self.weights)(e));
426                match self.nodes.get_mut(v) {
427                    Some(Some(vitem)) => {
428                        // node is known but unhandled
429                        let (olddist, lower) = {
430                            let data = self.pqueue.value(vitem);
431                            (data.distance, data.lower)
432                        };
433                        if dist < olddist {
434                            self.pqueue.decrease_key(
435                                vitem,
436                                Data {
437                                    incoming_edge: e,
438                                    distance: dist,
439                                    lower,
440                                },
441                            );
442                        }
443                    }
444                    None => {
445                        // node is unknown
446                        let item = self.pqueue.push(
447                            v,
448                            Data {
449                                incoming_edge: e,
450                                distance: dist,
451                                lower: self.heur.call(v),
452                            },
453                        );
454                        self.nodes.insert(v, Some(item));
455                    }
456                    _ => (), // node has been handled
457                };
458            }
459            Some((u, incoming_edge, d))
460        } else {
461            None
462        }
463    }
464}
465
466impl<'a, A, D, W, M, P, H, Accum> AStar<'a, A, D, W, M, P, H, Accum>
467where
468    A: Adjacencies<'a>,
469    D: Copy + PartialOrd + Add<D, Output = D> + Sub<D, Output = D>,
470    W: Fn(A::Edge) -> D,
471    M: ItemMap<A::Node, Option<P::Item>>,
472    P: ItemPriQueue<A::Node, Data<A::Edge, D, H::Result>>,
473    H: AStarHeuristic<A::Node>,
474    H::Result: Add<D, Output = D>,
475    Accum: Accumulator<D>,
476{
477    /// Run the search completely.
478    ///
479    /// Note that this method may run forever on an infinite graph.
480    pub fn run(&mut self) {
481        while self.next().is_some() {}
482    }
483
484    /// Return the data structures used during the algorithm
485    pub fn into_data(self) -> (M, P) {
486        (self.nodes, self.pqueue)
487    }
488}
489
490/// Start an A*-search on a undirected graph.
491///
492/// Each edge can be traversed in both directions with the same weight.
493///
494/// This is a convenience wrapper to start the search on an undirected graph
495/// with the default data structures.
496///
497/// # Parameter
498/// - `g`: the graph
499/// - `weights`: the (non-negative) edge weights
500/// - `src`: the source node
501/// - `heur`: the lower bound heuristic
502pub fn start_undirected<'a, G, D, W, H>(
503    g: &'a G,
504    src: G::Node<'a>,
505    weights: W,
506    heur: H,
507) -> AStarDefault<'a, Neighbors<'a, G>, D, W, H>
508where
509    G: Graph,
510    G::Node<'a>: Hash,
511    D: Copy + PartialOrd + Zero,
512    W: Fn(G::Edge<'a>) -> D,
513    H: AStarHeuristic<G::Node<'a>>,
514    H::Result: Add<D, Output = D>,
515{
516    start(Neighbors(g), src, weights, heur)
517}
518
519/// Run an A*-search on an undirected graph and return the path.
520///
521/// Each edge can be traversed in both directions with the same weight.
522///
523/// This is a convenience wrapper to run the search on an undirected graph with
524/// the default data structures and return the resulting path from `src` to
525/// `snk`.
526///
527/// # Parameter
528/// - `g`: the graph
529/// - `weights`: the (non-negative) edge weights
530/// - `src`: the source node
531/// - `snk`: the sink node
532/// - `heur`: the lower bound heuristic
533///
534/// The function returns the edges on the path and its length.
535pub fn find_undirected_path<'a, G, D, W, H>(
536    g: &'a G,
537    src: G::Node<'a>,
538    snk: G::Node<'a>,
539    weights: W,
540    heur: H,
541) -> Option<(Vec<G::Edge<'a>>, D)>
542where
543    G: Graph,
544    G::Node<'a>: Hash,
545    D: 'a + Copy + PartialOrd + Zero + Add<D, Output = D> + Sub<D, Output = D>,
546    W: Fn(G::Edge<'a>) -> D,
547    H: AStarHeuristic<G::Node<'a>>,
548    H::Result: Add<D, Output = D>,
549{
550    if src == snk {
551        return Some((vec![], D::zero()));
552    }
553    // run search until sink node has been found
554    let mut incoming_edges = HashMap::new();
555    for (u, e, d) in start_undirected(g, src, weights, heur) {
556        incoming_edges.insert(u, e);
557        if u == snk {
558            let mut path = path_from_incomings(snk, |u| {
559                incoming_edges
560                    .get(&u)
561                    .map(|&e| (e, g.enodes(e)))
562                    .map(|(e, (v, w))| (e, if v == u { w } else { v }))
563            })
564            .collect::<Vec<_>>();
565            path.reverse();
566            return Some((path, d));
567        }
568    }
569    None
570}
571
572/// Start an A*-search on a directed graph.
573///
574/// This is a convenience wrapper to start the search on an directed graph
575/// with the default data structures.
576///
577/// # Parameter
578/// - `g`: the graph
579/// - `weights`: the (non-negative) edge weights
580/// - `src`: the source node
581/// - `heur`: the lower bound heuristic
582pub fn start_directed<'a, G, D, W, H>(
583    g: &'a G,
584    src: G::Node<'a>,
585    weights: W,
586    heur: H,
587) -> AStarDefault<'a, OutEdges<'a, G>, D, W, H>
588where
589    G: Digraph,
590    G::Node<'a>: Hash,
591    D: Copy + PartialOrd + Zero,
592    W: Fn(G::Edge<'a>) -> D,
593    H: AStarHeuristic<G::Node<'a>>,
594    H::Result: Add<D, Output = D>,
595{
596    start(OutEdges(g), src, weights, heur)
597}
598
599/// Run an A*-search on a directed graph and return the path.
600///
601/// This is a convenience wrapper to run the search on an directed graph with
602/// the default data structures and return the resulting path from `src` to
603/// `snk`.
604///
605/// # Parameter
606/// - `g`: the graph
607/// - `weights`: the (non-negative) edge weights
608/// - `src`: the source node
609/// - `snk`: the sink node
610/// - `heur`: the lower bound heuristic
611///
612/// The function returns the edges on the path and its length.
613pub fn find_directed_path<'a, G, D, W, H>(
614    g: &'a G,
615    src: G::Node<'a>,
616    snk: G::Node<'a>,
617    weights: W,
618    heur: H,
619) -> Option<(Vec<G::Edge<'a>>, D)>
620where
621    G: Digraph,
622    G::Node<'a>: Hash,
623    D: 'a + Copy + PartialOrd + Zero + Add<D, Output = D> + Sub<D, Output = D>,
624    W: Fn(G::Edge<'a>) -> D,
625    H: AStarHeuristic<G::Node<'a>>,
626    H::Result: Add<D, Output = D>,
627{
628    if src == snk {
629        return Some((vec![], D::zero()));
630    }
631    // run search until sink node has been found
632    let mut incoming_edges = HashMap::new();
633    for (u, e, d) in start_directed(g, src, weights, heur) {
634        incoming_edges.insert(u, e);
635        if u == snk {
636            let mut path =
637                path_from_incomings(snk, |u| incoming_edges.get(&u).map(|&e| (e, g.src(e)))).collect::<Vec<_>>();
638            path.reverse();
639            return Some((path, d));
640        }
641    }
642    None
643}