use std::cmp::Ordering;
use std::collections::BinaryHeap;
use crate::algorithms::paths::dijkstra::DijkstraMode;
use crate::core::graph::EdgeId;
use crate::core::{Graph, IgraphError, IgraphResult, VertexId};
/// Entry in the A* open set (priority queue).
///
/// Fields: `.0` is the estimated total cost `f = g + h` of a path through the vertex,
/// `.1` is a monotonically increasing insertion sequence number used as a tie-breaker,
/// and `.2` is the vertex itself.
#[derive(Copy, Clone)]
struct Frontier(f64, u64, VertexId);

// Equality is defined through `Ord::cmp` so that `==` and `cmp` always agree, as the
// `Eq`/`Ord` contract requires. Comparing the `f64` field with IEEE `==` would not:
// it reports `0.0 == -0.0` (which `total_cmp` orders as distinct), and a NaN is `!=`
// itself even though `total_cmp` reports it `Equal` to itself.
impl PartialEq for Frontier {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Frontier {}

// Reversed ordering so that `BinaryHeap` (a max-heap) pops the *smallest* estimate
// first. Ties go to the smaller sequence number (FIFO among equal estimates), then to
// the smaller vertex id.
impl Ord for Frontier {
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .0
            .total_cmp(&self.0)
            .then(other.1.cmp(&self.1))
            .then(other.2.cmp(&self.2))
    }
}

impl PartialOrd for Frontier {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// Checks that an optional per-edge weight vector is usable by A*.
///
/// `None` is always accepted. Otherwise the slice must contain exactly one entry per
/// edge (`graph.ecount()`), and no entry may be NaN or negative. Entries are inspected
/// in edge-id order and the first offending edge is reported. Infinite weights pass.
///
/// # Errors
/// `IgraphError::InvalidArgument` describing the first problem found.
fn validate_weights(graph: &Graph, weights: Option<&[f64]>) -> IgraphResult<()> {
    match weights {
        None => Ok(()),
        Some(ws) => {
            let edge_count = graph.ecount();
            if ws.len() != edge_count {
                return Err(IgraphError::InvalidArgument(format!(
                    "weights vector size ({}) differs from edge count ({})",
                    ws.len(),
                    edge_count
                )));
            }
            // NaN is tested before sign because `NaN < 0.0` is false and would slip through.
            ws.iter().enumerate().try_for_each(|(e, &v)| {
                if v.is_nan() {
                    Err(IgraphError::InvalidArgument(format!(
                        "weight at edge {e} is NaN"
                    )))
                } else if v < 0.0 {
                    Err(IgraphError::InvalidArgument(format!(
                        "weight at edge {e} is negative ({v}); A* requires non-negative weights"
                    )))
                } else {
                    Ok(())
                }
            })
        }
    }
}
/// Returns the edges incident to `v` that a search step may follow under `mode`.
///
/// For undirected graphs `mode` is irrelevant and `Graph::incident` is used. For
/// directed graphs `Out` uses `Graph::incident`, `In` uses `Graph::incident_in`, and
/// `All` concatenates both lists (in that order).
///
/// # Errors
/// Propagates any error from the underlying incidence queries.
fn incident_for_mode(graph: &Graph, v: VertexId, mode: DijkstraMode) -> IgraphResult<Vec<EdgeId>> {
    match (graph.is_directed(), mode) {
        (false, _) | (true, DijkstraMode::Out) => graph.incident(v),
        (true, DijkstraMode::In) => graph.incident_in(v),
        (true, DijkstraMode::All) => {
            let mut edges = graph.incident(v)?;
            edges.extend(graph.incident_in(v)?);
            Ok(edges)
        }
    }
}
/// Finds a shortest path from `from` to `to` with the A* search algorithm.
///
/// Vertices are expanded in order of `dist(from, v) + heuristic(v, to)`. Ties are
/// broken in insertion order. The result is a shortest path whenever `heuristic` is
/// *admissible*, i.e. it never overestimates the remaining distance to `to`.
///
/// An admissible heuristic need not be *consistent*. In that case a shorter path to an
/// already-expanded vertex can be discovered later; such a vertex is reopened and
/// expanded again. With a consistent heuristic (for example the constant-zero one,
/// which reduces this to Dijkstra's algorithm), every vertex is expanded at most once.
///
/// # Arguments
/// * `graph` - the graph to search.
/// * `from`, `to` - source and target vertex ids; both must be `< graph.vcount()`.
/// * `weights` - optional per-edge weights indexed by edge id. `None` means every edge
///   has weight 1. Edges whose weight is `+inf` are treated as absent.
/// * `mode` - for directed graphs: `Out` follows `Graph::incident`, `In` follows
///   `Graph::incident_in` (walking edges against their direction), and `All` follows
///   both. Ignored for undirected graphs.
/// * `heuristic` - remaining-distance estimate, called as `heuristic(v, to)`.
///
/// # Returns
/// * `Ok(Some((vertices, edges)))` when `to` is reachable. `vertices` starts at `from`
///   and ends at `to`; `edges[i]` joins `vertices[i]` and `vertices[i + 1]`.
/// * `Ok(Some((vec![from], vec![])))` when `from == to`.
/// * `Ok(None)` when `to` is unreachable.
///
/// # Errors
/// * `IgraphError::VertexOutOfRange` if `from` or `to` is not a valid vertex id.
/// * `IgraphError::InvalidArgument` if `weights` has the wrong length or contains a
///   NaN or negative entry, or if `heuristic` returns NaN or a negative value for any
///   vertex it is evaluated on.
/// * Any error propagated from the graph's incidence or edge-endpoint lookups.
pub fn a_star_path<H: Fn(VertexId, VertexId) -> f64>(
    graph: &Graph,
    from: VertexId,
    to: VertexId,
    weights: Option<&[f64]>,
    mode: DijkstraMode,
    heuristic: H,
) -> IgraphResult<Option<(Vec<VertexId>, Vec<EdgeId>)>> {
    let n = graph.vcount();
    if from >= n {
        return Err(IgraphError::VertexOutOfRange { id: from, n });
    }
    if to >= n {
        return Err(IgraphError::VertexOutOfRange { id: to, n });
    }
    validate_weights(graph, weights)?;
    if from == to {
        return Ok(Some((vec![from], Vec::new())));
    }

    let n_us = n as usize;
    // g-scores: length of the best path found so far from `from` to each vertex.
    let mut dist = vec![f64::INFINITY; n_us];
    // Last edge on that best path, used to rebuild the route once `to` is reached.
    let mut parent_eid: Vec<Option<EdgeId>> = vec![None; n_us];
    // Vertices already expanded with their current `dist`; heap entries for them are stale.
    let mut closed = vec![false; n_us];
    // Insertion counter: equal estimates are popped in FIFO order.
    let mut seq: u64 = 0;
    let mut heap: BinaryHeap<Frontier> = BinaryHeap::new();

    dist[from as usize] = 0.0;
    let h0 = checked_heuristic(&heuristic, from, to)?;
    heap.push(Frontier(h0, seq, from));

    while let Some(Frontier(_, _, u)) = heap.pop() {
        let ui = u as usize;
        if closed[ui] {
            continue;
        }
        closed[ui] = true;
        if u == to {
            return reconstruct_path(graph, &parent_eid, to).map(Some);
        }
        // dist[u] cannot change while u is expanded: that would need a self-loop of
        // negative weight, which `validate_weights` rejects.
        let dist_u = dist[ui];
        for eid in incident_for_mode(graph, u, mode)? {
            let w = weights.map_or(1.0, |ws| ws[eid as usize]);
            // Only +inf reaches this check (NaN/negative were rejected); skip such edges.
            if !w.is_finite() {
                continue;
            }
            let v = graph.edge_other(eid, u)?;
            let vi = v as usize;
            let altdist = dist_u + w;
            let curdist = dist[vi];
            // Relax on a strictly shorter path, or on first contact with an unexpanded
            // vertex that has no finite distance yet.
            // - A closed vertex can only improve when the heuristic is inconsistent. It
            //   is then reopened so the shorter distance propagates, which keeps the
            //   result optimal for any admissible heuristic.
            // - Closed vertices are never relaxed non-strictly, so the search terminates.
            if altdist < curdist || (!curdist.is_finite() && !closed[vi]) {
                dist[vi] = altdist;
                parent_eid[vi] = Some(eid);
                closed[vi] = false;
                let h = checked_heuristic(&heuristic, v, to)?;
                seq += 1;
                heap.push(Frontier(altdist + h, seq, v));
            }
        }
    }
    Ok(None)
}

/// Evaluates `heuristic(v, to)` and rejects estimates that A* cannot use.
///
/// # Errors
/// `IgraphError::InvalidArgument` if the estimate is NaN or negative. Infinite
/// estimates are accepted.
fn checked_heuristic<H: Fn(VertexId, VertexId) -> f64>(
    heuristic: &H,
    v: VertexId,
    to: VertexId,
) -> IgraphResult<f64> {
    let h = heuristic(v, to);
    if h.is_nan() || h < 0.0 {
        return Err(IgraphError::InvalidArgument(format!(
            "heuristic returned invalid estimate ({h}); must be non-negative and not NaN"
        )));
    }
    Ok(h)
}

/// Follows `parent_eid` back from `to` until a vertex with no parent edge (the source).
/// Returns the vertices and edges in forward (source-to-target) order.
///
/// # Errors
/// Propagates any error from `Graph::edge_other`.
fn reconstruct_path(
    graph: &Graph,
    parent_eid: &[Option<EdgeId>],
    to: VertexId,
) -> IgraphResult<(Vec<VertexId>, Vec<EdgeId>)> {
    let mut vs = vec![to];
    let mut es = Vec::new();
    let mut cur = to;
    while let Some(eid) = parent_eid[cur as usize] {
        es.push(eid);
        cur = graph.edge_other(eid, cur)?;
        vs.push(cur);
    }
    vs.reverse();
    es.reverse();
    Ok((vs, es))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Constant-zero heuristic: admissible and consistent, so A* behaves like Dijkstra.
    fn null_h(_: VertexId, _: VertexId) -> f64 {
        0.0
    }

    #[test]
    fn unit_weights_match_bfs_chain() {
        let mut g = Graph::with_vertices(4);
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        g.add_edge(2, 3).unwrap();
        let (vs, es) = a_star_path(&g, 0, 3, None, DijkstraMode::Out, null_h)
            .unwrap()
            .unwrap();
        assert_eq!(vs, vec![0, 1, 2, 3]);
        assert_eq!(es, vec![0, 1, 2]);
    }

    #[test]
    fn weighted_triangle_with_shortcut() {
        let mut g = Graph::with_vertices(3);
        g.add_edge(0, 1).unwrap();
        g.add_edge(0, 2).unwrap();
        g.add_edge(1, 2).unwrap();
        let weights = [1.0, 4.0, 2.0];
        let (vs, es) = a_star_path(&g, 0, 2, Some(&weights), DijkstraMode::Out, null_h)
            .unwrap()
            .unwrap();
        assert_eq!(vs, vec![0, 1, 2]);
        assert_eq!(es, vec![0, 2]);
    }

    #[test]
    fn unreachable_target_returns_none() {
        let mut g = Graph::with_vertices(3);
        g.add_edge(0, 1).unwrap();
        assert_eq!(
            a_star_path(&g, 0, 2, None, DijkstraMode::Out, null_h).unwrap(),
            None
        );
    }

    #[test]
    fn from_equals_to_singleton_chain() {
        let g = Graph::with_vertices(3);
        let (vs, es) = a_star_path(&g, 1, 1, None, DijkstraMode::Out, null_h)
            .unwrap()
            .unwrap();
        assert_eq!(vs, vec![1]);
        assert!(es.is_empty());
    }

    #[test]
    fn admissible_heuristic_finds_same_path_as_null() {
        let mut g = Graph::with_vertices(4);
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        g.add_edge(0, 3).unwrap();
        g.add_edge(3, 2).unwrap();
        let h = |v: VertexId, target: VertexId| -> f64 {
            if v == target {
                0.0
            } else {
                1.0
            }
        };
        let (vs, _) = a_star_path(&g, 0, 2, None, DijkstraMode::Out, h)
            .unwrap()
            .unwrap();
        assert_eq!(vs.len(), 3);
        assert_eq!(vs[0], 0);
        assert_eq!(vs[2], 2);
    }

    #[test]
    fn inconsistent_admissible_heuristic_still_finds_shortest_path() {
        // Edges: 0-1 (1), 0-2 (2), 1-3 (3), 2-3 (1), 3-4 (3).
        // Shortest 0 -> 4 is 0-2-3-4 with length 6; 0-1-3-4 has length 7.
        let mut g = Graph::with_vertices(5);
        g.add_edge(0, 1).unwrap();
        g.add_edge(0, 2).unwrap();
        g.add_edge(1, 3).unwrap();
        g.add_edge(2, 3).unwrap();
        g.add_edge(3, 4).unwrap();
        let weights = [1.0, 2.0, 3.0, 1.0, 3.0];
        // Admissible: h(2) = 4 equals the true remaining distance from 2.
        // Inconsistent: h(2) = 4 exceeds w(0, 2) + h(0) = 2.
        // So vertex 3 is first expanded via 1 and must be reopened once 2 is expanded.
        let h = |v: VertexId, _t: VertexId| if v == 2 { 4.0 } else { 0.0 };
        let (vs, es) = a_star_path(&g, 0, 4, Some(&weights), DijkstraMode::Out, h)
            .unwrap()
            .unwrap();
        assert_eq!(vs, vec![0, 2, 3, 4]);
        assert_eq!(es, vec![1, 3, 4]);
    }

    #[test]
    fn directed_in_mode_walks_reverse_edges() {
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        assert_eq!(
            a_star_path(&g, 2, 0, None, DijkstraMode::Out, null_h).unwrap(),
            None
        );
        let (vs, es) = a_star_path(&g, 2, 0, None, DijkstraMode::In, null_h)
            .unwrap()
            .unwrap();
        assert_eq!(vs, vec![2, 1, 0]);
        assert_eq!(es, vec![1, 0]);
    }

    #[test]
    fn directed_all_mode_ignores_direction() {
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        let (vs, es) = a_star_path(&g, 2, 0, None, DijkstraMode::All, null_h)
            .unwrap()
            .unwrap();
        assert_eq!(vs, vec![2, 1, 0]);
        assert_eq!(es, vec![1, 0]);
    }

    #[test]
    fn negative_weight_errors() {
        let mut g = Graph::with_vertices(2);
        g.add_edge(0, 1).unwrap();
        let weights = [-1.0_f64];
        assert!(a_star_path(&g, 0, 1, Some(&weights), DijkstraMode::Out, null_h).is_err());
    }

    #[test]
    fn nan_weight_errors() {
        let mut g = Graph::with_vertices(2);
        g.add_edge(0, 1).unwrap();
        let weights = [f64::NAN];
        assert!(a_star_path(&g, 0, 1, Some(&weights), DijkstraMode::Out, null_h).is_err());
    }

    #[test]
    fn weights_size_mismatch_errors() {
        let mut g = Graph::with_vertices(2);
        g.add_edge(0, 1).unwrap();
        let weights = [1.0_f64, 2.0];
        assert!(a_star_path(&g, 0, 1, Some(&weights), DijkstraMode::Out, null_h).is_err());
    }

    #[test]
    fn out_of_range_source_errors() {
        let g = Graph::with_vertices(2);
        assert!(a_star_path(&g, 99, 0, None, DijkstraMode::Out, null_h).is_err());
        assert!(a_star_path(&g, 0, 99, None, DijkstraMode::Out, null_h).is_err());
    }

    #[test]
    fn negative_heuristic_errors() {
        let mut g = Graph::with_vertices(2);
        g.add_edge(0, 1).unwrap();
        let bad_h = |_v: VertexId, _t: VertexId| -1.0_f64;
        assert!(a_star_path(&g, 0, 1, None, DijkstraMode::Out, bad_h).is_err());
    }

    #[test]
    fn infinity_weight_skipped() {
        let mut g = Graph::with_vertices(3);
        g.add_edge(0, 1).unwrap();
        g.add_edge(0, 2).unwrap();
        g.add_edge(2, 1).unwrap();
        let weights = [f64::INFINITY, 1.0, 1.0];
        let (vs, _) = a_star_path(&g, 0, 1, Some(&weights), DijkstraMode::Out, null_h)
            .unwrap()
            .unwrap();
        assert_eq!(vs, vec![0, 2, 1]);
    }
}