use petgraph::algo::dijkstra;
use petgraph::visit::EdgeRef;
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
use std::hash::Hash;
use crate::base::{DiGraph, EdgeWeight, Graph, Node};
use crate::error::{GraphError, Result};
/// A route through a graph together with its accumulated edge weight.
#[derive(Debug, Clone)]
pub struct Path<N, E>
where
    N: Node + std::fmt::Debug,
    E: EdgeWeight,
{
    /// Nodes in travel order, from the source to the target.
    pub nodes: Vec<N>,
    /// Total weight of the route.
    pub total_weight: E,
}
/// Outcome of a successful A* search.
#[derive(Debug, Clone)]
pub struct AStarResult<N, E>
where
    N: Node + std::fmt::Debug,
    E: EdgeWeight,
{
    /// Nodes from the source to the target, both inclusive.
    pub path: Vec<N>,
    /// Accumulated edge weight along `path`.
    pub cost: E,
}
/// One entry in the A* open set (a `BinaryHeap`).
#[derive(Clone)]
struct AStarState<N, E>
where
    N: Node + std::fmt::Debug,
    E: EdgeWeight,
{
    /// Graph node this entry refers to.
    node: N,
    /// Accumulated edge weight from the source to `node` along `path` (g-score).
    cost: E,
    /// Caller-supplied estimate of the remaining cost from `node` (h-score).
    heuristic: E,
    /// Nodes from the source up to and including `node`.
    path: Vec<N>,
}
/// Equality looks only at the node; cost, heuristic and path are ignored.
///
/// NOTE(review): this is coarser than the `Ord` impl below (which compares costs), so
/// `a == b` does not imply `a.cmp(&b) == Equal`. `BinaryHeap` only uses `Ord`, so the
/// heap is unaffected.
impl<N, E> PartialEq for AStarState<N, E>
where
    N: Node + std::fmt::Debug,
    E: EdgeWeight,
{
    fn eq(&self, rhs: &Self) -> bool {
        let same_node = self.node == rhs.node;
        same_node
    }
}
/// Marker impl required by `Ord`; equality is decided by the node alone.
impl<N, E> Eq for AStarState<N, E>
where
    N: Node + std::fmt::Debug,
    E: EdgeWeight,
{
}
/// Heap ordering for A*.
///
/// `BinaryHeap` is a max-heap, so the comparison is reversed. The state with the smallest
/// `cost + heuristic` compares as the greatest and is popped first. On a tie, the state with
/// the smaller `cost` wins. Incomparable values (e.g. NaN) are treated as equal.
impl<N, E> Ord for AStarState<N, E>
where
    N: Node + std::fmt::Debug,
    E: EdgeWeight + std::ops::Add<Output = E> + Copy + PartialOrd,
{
    fn cmp(&self, rhs: &Self) -> Ordering {
        let mine = self.cost + self.heuristic;
        let theirs = rhs.cost + rhs.heuristic;
        match theirs.partial_cmp(&mine).unwrap_or(Ordering::Equal) {
            Ordering::Equal => rhs.cost.partial_cmp(&self.cost).unwrap_or(Ordering::Equal),
            decided => decided,
        }
    }
}
/// Always returns `Some`, delegating to the total order defined by `Ord::cmp`.
impl<N, E> PartialOrd for AStarState<N, E>
where
    N: Node + std::fmt::Debug,
    E: EdgeWeight + std::ops::Add<Output = E> + Copy + PartialOrd,
{
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, rhs))
    }
}
/// Finds a minimum-weight path from `source` to `target` in an undirected [`Graph`].
///
/// Distances come from petgraph's Dijkstra, which stops once `target` is settled. The node
/// sequence is then rebuilt by walking backwards from `target` along "tight" edges: edges
/// `from -> current` with `dist(from) + w == dist(current)`.
///
/// # Returns
/// * `Ok(Some(path))` with the nodes from `source` to `target` and the total weight.
/// * `Ok(None)` if `target` is unreachable.
///
/// # Errors
/// * [`GraphError::InvalidGraph`] if `source` or `target` is not in the graph.
/// * [`GraphError::AlgorithmError`] if no tight-edge route back to `source` can be found.
#[deprecated(
    note = "Use `dijkstra_path` for future compatibility. This function will return PathResult in v1.0"
)]
#[allow(dead_code)]
pub fn shortest_path<N, E, Ix>(
    graph: &Graph<N, E, Ix>,
    source: &N,
    target: &N,
) -> Result<Option<Path<N, E>>>
where
    N: Node + std::fmt::Debug,
    E: EdgeWeight
        + scirs2_core::numeric::Zero
        + scirs2_core::numeric::One
        + std::ops::Add<Output = E>
        + PartialOrd
        + std::marker::Copy
        + std::fmt::Debug
        + std::default::Default,
    Ix: petgraph::graph::IndexType,
{
    if !graph.has_node(source) {
        return Err(GraphError::InvalidGraph(format!(
            "Source node {source:?} not found"
        )));
    }
    if !graph.has_node(target) {
        return Err(GraphError::InvalidGraph(format!(
            "Target node {target:?} not found"
        )));
    }
    let inner = graph.inner();
    let source_idx = inner
        .node_indices()
        .find(|&idx| inner[idx] == *source)
        .ok_or_else(|| GraphError::InvalidGraph(format!("Source node {source:?} not found")))?;
    let target_idx = inner
        .node_indices()
        .find(|&idx| inner[idx] == *target)
        .ok_or_else(|| GraphError::InvalidGraph(format!("Target node {target:?} not found")))?;

    let results = dijkstra(inner, source_idx, Some(target_idx), |e| *e.weight());
    let Some(&total_weight) = results.get(&target_idx) else {
        return Ok(None);
    };

    // Walk back from the target, always taking the cheapest unvisited tight predecessor.
    // Zero-weight edges can make tight edges form cycles. The visited set prevents endless
    // ping-ponging, and dead ends are backtracked instead of failing.
    let mut visited = vec![false; inner.node_count()];
    visited[target_idx.index()] = true;
    let mut trail = vec![target_idx];
    while let Some(&current) = trail.last() {
        if current == source_idx {
            break;
        }
        let current_dist = results[&current];
        let cheapest_pred = inner
            .edges_directed(current, petgraph::Direction::Incoming)
            .filter_map(|e| {
                let from = e.source();
                if visited[from.index()] {
                    return None;
                }
                let from_dist = *results.get(&from)?;
                (from_dist + *e.weight() == current_dist).then_some((from, from_dist))
            })
            .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
        match cheapest_pred {
            Some((from, _)) => {
                visited[from.index()] = true;
                trail.push(from);
            }
            // Every tight predecessor was already visited: back up one step.
            None => {
                trail.pop();
            }
        }
    }
    if trail.last() != Some(&source_idx) {
        return Err(GraphError::AlgorithmError(
            "Failed to reconstruct path".to_string(),
        ));
    }
    let nodes = trail.iter().rev().map(|&idx| inner[idx].clone()).collect();
    Ok(Some(Path {
        nodes,
        total_weight,
    }))
}
/// Finds a minimum-weight path from `source` to `target` in an undirected [`Graph`].
///
/// Returns `Ok(None)` when `target` is unreachable and an error when either endpoint is
/// missing. Currently delegates to the deprecated [`shortest_path`].
#[allow(dead_code, deprecated)]
pub fn dijkstra_path<N, E, Ix>(
    graph: &Graph<N, E, Ix>,
    source: &N,
    target: &N,
) -> Result<Option<Path<N, E>>>
where
    N: Node + std::fmt::Debug,
    E: EdgeWeight
        + scirs2_core::numeric::Zero
        + scirs2_core::numeric::One
        + std::ops::Add<Output = E>
        + PartialOrd
        + std::marker::Copy
        + std::fmt::Debug
        + std::default::Default,
    Ix: petgraph::graph::IndexType,
{
    shortest_path(graph, source, target)
}
/// Finds a minimum-weight path from `source` to `target` in a directed [`DiGraph`].
///
/// Distances come from petgraph's Dijkstra, which stops once `target` is settled. The node
/// sequence is then rebuilt by walking backwards from `target` along incoming "tight" edges:
/// edges `from -> current` with `dist(from) + w == dist(current)`.
///
/// # Returns
/// * `Ok(Some(path))` with the nodes from `source` to `target` and the total weight.
/// * `Ok(None)` if `target` is unreachable.
///
/// # Errors
/// * [`GraphError::InvalidGraph`] if `source` or `target` is not in the graph.
/// * [`GraphError::AlgorithmError`] if no tight-edge route back to `source` can be found.
#[deprecated(
    note = "Use `dijkstra_path_digraph` for future compatibility. This function will return PathResult in v1.0"
)]
#[allow(dead_code)]
pub fn shortest_path_digraph<N, E, Ix>(
    graph: &DiGraph<N, E, Ix>,
    source: &N,
    target: &N,
) -> Result<Option<Path<N, E>>>
where
    N: Node + std::fmt::Debug,
    E: EdgeWeight
        + scirs2_core::numeric::Zero
        + scirs2_core::numeric::One
        + std::ops::Add<Output = E>
        + PartialOrd
        + std::marker::Copy
        + std::fmt::Debug
        + std::default::Default,
    Ix: petgraph::graph::IndexType,
{
    if !graph.has_node(source) {
        return Err(GraphError::InvalidGraph(format!(
            "Source node {source:?} not found"
        )));
    }
    if !graph.has_node(target) {
        return Err(GraphError::InvalidGraph(format!(
            "Target node {target:?} not found"
        )));
    }
    let inner = graph.inner();
    let source_idx = inner
        .node_indices()
        .find(|&idx| inner[idx] == *source)
        .ok_or_else(|| GraphError::InvalidGraph(format!("Source node {source:?} not found")))?;
    let target_idx = inner
        .node_indices()
        .find(|&idx| inner[idx] == *target)
        .ok_or_else(|| GraphError::InvalidGraph(format!("Target node {target:?} not found")))?;

    let results = dijkstra(inner, source_idx, Some(target_idx), |e| *e.weight());
    let Some(&total_weight) = results.get(&target_idx) else {
        return Ok(None);
    };

    // Walk back from the target, always taking the cheapest unvisited tight predecessor.
    // Zero-weight edges can make tight edges form cycles. The visited set prevents endless
    // ping-ponging, and dead ends are backtracked instead of failing.
    let mut visited = vec![false; inner.node_count()];
    visited[target_idx.index()] = true;
    let mut trail = vec![target_idx];
    while let Some(&current) = trail.last() {
        if current == source_idx {
            break;
        }
        let current_dist = results[&current];
        let cheapest_pred = inner
            .edges_directed(current, petgraph::Direction::Incoming)
            .filter_map(|e| {
                let from = e.source();
                if visited[from.index()] {
                    return None;
                }
                let from_dist = *results.get(&from)?;
                (from_dist + *e.weight() == current_dist).then_some((from, from_dist))
            })
            .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
        match cheapest_pred {
            Some((from, _)) => {
                visited[from.index()] = true;
                trail.push(from);
            }
            // Every tight predecessor was already visited: back up one step.
            None => {
                trail.pop();
            }
        }
    }
    if trail.last() != Some(&source_idx) {
        return Err(GraphError::AlgorithmError(
            "Failed to reconstruct path".to_string(),
        ));
    }
    let nodes = trail.iter().rev().map(|&idx| inner[idx].clone()).collect();
    Ok(Some(Path {
        nodes,
        total_weight,
    }))
}
/// Finds a minimum-weight path from `source` to `target` in a directed [`DiGraph`].
///
/// Returns `Ok(None)` when `target` is unreachable and an error when either endpoint is
/// missing. Currently delegates to the deprecated [`shortest_path_digraph`].
#[allow(dead_code, deprecated)]
pub fn dijkstra_path_digraph<N, E, Ix>(
    graph: &DiGraph<N, E, Ix>,
    source: &N,
    target: &N,
) -> Result<Option<Path<N, E>>>
where
    N: Node + std::fmt::Debug,
    E: EdgeWeight
        + scirs2_core::numeric::Zero
        + scirs2_core::numeric::One
        + std::ops::Add<Output = E>
        + PartialOrd
        + std::marker::Copy
        + std::fmt::Debug
        + std::default::Default,
    Ix: petgraph::graph::IndexType,
{
    shortest_path_digraph(graph, source, target)
}
/// All-pairs shortest-path distances for an undirected [`Graph`] (Floyd–Warshall).
///
/// Row and column indices are the node indices of the underlying petgraph graph. Every edge
/// is usable in both directions. Unreachable pairs stay `f64::INFINITY`, and the diagonal
/// is 0 unless a negative self-loop lowers it.
///
/// When several edges join the same pair, the cheapest one is used. A non-negative self-loop
/// never overrides the zero distance from a node to itself.
///
/// Runs in O(n³) time and O(n²) memory for `n` nodes.
#[allow(dead_code)]
pub fn floyd_warshall<N, E, Ix>(
    graph: &Graph<N, E, Ix>,
) -> Result<scirs2_core::ndarray::Array2<f64>>
where
    N: Node + std::fmt::Debug,
    E: EdgeWeight + Into<f64> + scirs2_core::numeric::Zero + Copy,
    Ix: petgraph::graph::IndexType,
{
    let n = graph.node_count();
    if n == 0 {
        return Ok(scirs2_core::ndarray::Array2::zeros((0, 0)));
    }
    let mut dist = scirs2_core::ndarray::Array2::from_elem((n, n), f64::INFINITY);
    for i in 0..n {
        dist[[i, i]] = 0.0;
    }
    for edge in graph.inner().edge_references() {
        let i = edge.source().index();
        let j = edge.target().index();
        let weight: f64 = (*edge.weight()).into();
        // Keep the minimum over parallel edges and self-loops instead of "last edge wins".
        // The matrix stays symmetric here, so checking one direction suffices.
        if weight < dist[[i, j]] {
            dist[[i, j]] = weight;
            dist[[j, i]] = weight;
        }
    }
    for k in 0..n {
        for i in 0..n {
            // Nothing routed through k can improve row i if i cannot reach k.
            if dist[[i, k]] == f64::INFINITY {
                continue;
            }
            for j in 0..n {
                let alt = dist[[i, k]] + dist[[k, j]];
                if alt < dist[[i, j]] {
                    dist[[i, j]] = alt;
                }
            }
        }
    }
    Ok(dist)
}
/// All-pairs shortest-path distances for a directed [`DiGraph`] (Floyd–Warshall).
///
/// `dist[[i, j]]` is the cheapest route from node index `i` to node index `j` following edge
/// directions. Unreachable pairs stay `f64::INFINITY`, and the diagonal is 0 unless a
/// negative self-loop lowers it.
///
/// When several edges share a direction between two nodes, the cheapest one is used. A
/// non-negative self-loop never overrides the zero distance from a node to itself.
///
/// Runs in O(n³) time and O(n²) memory for `n` nodes.
#[allow(dead_code)]
pub fn floyd_warshall_digraph<N, E, Ix>(
    graph: &DiGraph<N, E, Ix>,
) -> Result<scirs2_core::ndarray::Array2<f64>>
where
    N: Node + std::fmt::Debug,
    E: EdgeWeight + Into<f64> + scirs2_core::numeric::Zero + Copy,
    Ix: petgraph::graph::IndexType,
{
    let n = graph.node_count();
    if n == 0 {
        return Ok(scirs2_core::ndarray::Array2::zeros((0, 0)));
    }
    let mut dist = scirs2_core::ndarray::Array2::from_elem((n, n), f64::INFINITY);
    for i in 0..n {
        dist[[i, i]] = 0.0;
    }
    for edge in graph.inner().edge_references() {
        let i = edge.source().index();
        let j = edge.target().index();
        let weight: f64 = (*edge.weight()).into();
        // Keep the minimum over parallel edges and self-loops instead of "last edge wins".
        if weight < dist[[i, j]] {
            dist[[i, j]] = weight;
        }
    }
    for k in 0..n {
        for i in 0..n {
            // Nothing routed through k can improve row i if i cannot reach k.
            if dist[[i, k]] == f64::INFINITY {
                continue;
            }
            for j in 0..n {
                let alt = dist[[i, k]] + dist[[k, j]];
                if alt < dist[[i, j]] {
                    dist[[i, j]] = alt;
                }
            }
        }
    }
    Ok(dist)
}
/// A* search from `source` to `target` in an undirected [`Graph`].
///
/// `heuristic(node)` estimates the remaining cost from `node` to `target`. The search returns
/// the first time `target` leaves the open set, so the result is optimal only when the
/// heuristic never overestimates (is admissible).
///
/// # Errors
/// * `GraphError::node_not_found` if `source` or `target` is not in the graph.
/// * [`GraphError::NoPath`] if `target` is unreachable from `source`.
#[allow(dead_code)]
pub fn astar_search<N, E, Ix, H>(
    graph: &Graph<N, E, Ix>,
    source: &N,
    target: &N,
    heuristic: H,
) -> Result<AStarResult<N, E>>
where
    N: Node + std::fmt::Debug + Clone + Hash + Eq,
    E: EdgeWeight
        + Clone
        + std::ops::Add<Output = E>
        + scirs2_core::numeric::Zero
        + PartialOrd
        + Copy,
    Ix: petgraph::graph::IndexType,
    H: Fn(&N) -> E,
{
    if !graph.contains_node(source) || !graph.contains_node(target) {
        return Err(GraphError::node_not_found("node"));
    }
    let mut open_set = BinaryHeap::new();
    // Cheapest known cost from `source` to each discovered node.
    let mut g_score: HashMap<N, E> = HashMap::new();
    g_score.insert(source.clone(), E::zero());
    open_set.push(AStarState {
        node: source.clone(),
        cost: E::zero(),
        heuristic: heuristic(source),
        path: vec![source.clone()],
    });
    while let Some(current_state) = open_set.pop() {
        let current = &current_state.node;
        if current == target {
            return Ok(AStarResult {
                path: current_state.path,
                cost: current_state.cost,
            });
        }
        // A node is re-pushed whenever a cheaper route is found. Older, costlier entries for
        // it are stale: their `path` no longer matches the best cost, so skip them.
        if g_score
            .get(current)
            .is_some_and(|&best| current_state.cost > best)
        {
            continue;
        }
        let current_g = current_state.cost;
        let Ok(neighbors) = graph.neighbors(current) else {
            continue;
        };
        for neighbor in neighbors {
            let Ok(edge_weight) = graph.edge_weight(current, &neighbor) else {
                continue;
            };
            let tentative_g = current_g + edge_weight;
            if g_score
                .get(&neighbor)
                .is_none_or(|&known| tentative_g < known)
            {
                g_score.insert(neighbor.clone(), tentative_g);
                let mut new_path = current_state.path.clone();
                new_path.push(neighbor.clone());
                let estimate = heuristic(&neighbor);
                open_set.push(AStarState {
                    node: neighbor,
                    cost: tentative_g,
                    heuristic: estimate,
                    path: new_path,
                });
            }
        }
    }
    Err(GraphError::NoPath {
        src_node: format!("{source:?}"),
        target: format!("{target:?}"),
        nodes: 0,
        edges: 0,
    })
}
/// A* search from `source` to `target` in a directed [`DiGraph`], following outgoing edges
/// via `successors`.
///
/// `heuristic(node)` estimates the remaining cost from `node` to `target`. The search returns
/// the first time `target` leaves the open set, so the result is optimal only when the
/// heuristic never overestimates (is admissible).
///
/// # Errors
/// * `GraphError::node_not_found` if `source` or `target` is not in the graph.
/// * [`GraphError::NoPath`] if `target` is unreachable from `source`.
#[allow(dead_code)]
pub fn astar_search_digraph<N, E, Ix, H>(
    graph: &DiGraph<N, E, Ix>,
    source: &N,
    target: &N,
    heuristic: H,
) -> Result<AStarResult<N, E>>
where
    N: Node + std::fmt::Debug + Clone + Hash + Eq,
    E: EdgeWeight
        + Clone
        + std::ops::Add<Output = E>
        + scirs2_core::numeric::Zero
        + PartialOrd
        + Copy,
    Ix: petgraph::graph::IndexType,
    H: Fn(&N) -> E,
{
    if !graph.contains_node(source) || !graph.contains_node(target) {
        return Err(GraphError::node_not_found("node"));
    }
    let mut open_set = BinaryHeap::new();
    // Cheapest known cost from `source` to each discovered node.
    let mut g_score: HashMap<N, E> = HashMap::new();
    g_score.insert(source.clone(), E::zero());
    open_set.push(AStarState {
        node: source.clone(),
        cost: E::zero(),
        heuristic: heuristic(source),
        path: vec![source.clone()],
    });
    while let Some(current_state) = open_set.pop() {
        let current = &current_state.node;
        if current == target {
            return Ok(AStarResult {
                path: current_state.path,
                cost: current_state.cost,
            });
        }
        // A node is re-pushed whenever a cheaper route is found. Older, costlier entries for
        // it are stale: their `path` no longer matches the best cost, so skip them.
        if g_score
            .get(current)
            .is_some_and(|&best| current_state.cost > best)
        {
            continue;
        }
        let current_g = current_state.cost;
        let Ok(successors) = graph.successors(current) else {
            continue;
        };
        for neighbor in successors {
            let Ok(edge_weight) = graph.edge_weight(current, &neighbor) else {
                continue;
            };
            let tentative_g = current_g + edge_weight;
            if g_score
                .get(&neighbor)
                .is_none_or(|&known| tentative_g < known)
            {
                g_score.insert(neighbor.clone(), tentative_g);
                let mut new_path = current_state.path.clone();
                new_path.push(neighbor.clone());
                let estimate = heuristic(&neighbor);
                open_set.push(AStarState {
                    node: neighbor,
                    cost: tentative_g,
                    heuristic: estimate,
                    path: new_path,
                });
            }
        }
    }
    Err(GraphError::NoPath {
        src_node: format!("{source:?}"),
        target: format!("{target:?}"),
        nodes: 0,
        edges: 0,
    })
}
#[allow(dead_code)]
pub fn k_shortest_paths<N, E, Ix>(
graph: &Graph<N, E, Ix>,
source: &N,
target: &N,
k: usize,
) -> Result<Vec<(f64, Vec<N>)>>
where
N: Node + std::fmt::Debug + Clone + Hash + Eq + Ord,
E: EdgeWeight
+ Into<f64>
+ Clone
+ scirs2_core::numeric::Zero
+ scirs2_core::numeric::One
+ std::ops::Add<Output = E>
+ PartialOrd
+ std::marker::Copy
+ std::fmt::Debug
+ std::default::Default,
Ix: petgraph::graph::IndexType,
{
if k == 0 {
return Ok(vec![]);
}
if !graph.contains_node(source) || !graph.contains_node(target) {
return Err(GraphError::node_not_found("node"));
}
let mut paths = Vec::new();
let mut candidates = std::collections::BinaryHeap::new();
match dijkstra_path(graph, source, target) {
Ok(Some(path)) => {
let weight: f64 = path.total_weight.into();
paths.push((weight, path.nodes));
}
Ok(None) => return Ok(vec![]), Err(e) => return Err(e),
}
for i in 0..k - 1 {
if i >= paths.len() {
break;
}
let (_, prev_path) = &paths[i];
for j in 0..prev_path.len() - 1 {
let spur_node = &prev_path[j];
let root_path = &prev_path[..=j];
let mut removed_edges = Vec::new();
for (_, path) in &paths {
if path.len() > j && &path[..=j] == root_path && j + 1 < path.len() {
removed_edges.push((path[j].clone(), path[j + 1].clone()));
}
}
if let Ok((spur_weight, spur_path)) =
shortest_path_avoiding_edges(graph, spur_node, target, &removed_edges, root_path)
{
let mut total_weight = spur_weight;
for idx in 0..j {
if let Ok(edge_weight) = graph.edge_weight(&prev_path[idx], &prev_path[idx + 1])
{
let weight: f64 = edge_weight.into();
total_weight += weight;
}
}
let mut complete_path = root_path[..j].to_vec();
complete_path.extend(spur_path);
candidates.push((
std::cmp::Reverse(ordered_float::OrderedFloat(total_weight)),
complete_path.clone(),
));
}
}
}
while paths.len() < k && !candidates.is_empty() {
let (std::cmp::Reverse(ordered_float::OrderedFloat(weight)), path) =
candidates.pop().expect("Operation failed");
let is_duplicate = paths.iter().any(|(_, p)| p == &path);
if !is_duplicate {
paths.push((weight, path));
}
}
Ok(paths)
}
/// Dijkstra search from `source` to `target`, with edge weights converted to `f64`, under two
/// restrictions. Used for the spur paths of [`k_shortest_paths`].
///
/// * `avoided_edges`: ordered `(from, to)` pairs that must not be traversed in that direction.
/// * `excluded_nodes`: nodes that may not be entered; `source` and `target` are exempt.
///
/// Returns the distance and the node sequence from `source` to `target`.
///
/// # Errors
/// [`GraphError::NoPath`] when `target` cannot be reached under these restrictions.
#[allow(dead_code)]
fn shortest_path_avoiding_edges<N, E, Ix>(
    graph: &Graph<N, E, Ix>,
    source: &N,
    target: &N,
    avoided_edges: &[(N, N)],
    excluded_nodes: &[N],
) -> Result<(f64, Vec<N>)>
where
    N: Node + std::fmt::Debug + Clone + Hash + Eq + Ord,
    E: EdgeWeight + Into<f64>,
    Ix: petgraph::graph::IndexType,
{
    use std::cmp::Reverse;
    use std::collections::HashSet;

    // Hash-set views give O(1) membership tests per neighbour, with no per-check cloning.
    let avoided: HashSet<(&N, &N)> = avoided_edges.iter().map(|(from, to)| (from, to)).collect();
    let excluded: HashSet<&N> = excluded_nodes.iter().collect();

    let mut distances: HashMap<N, f64> = HashMap::new();
    let mut previous: HashMap<N, N> = HashMap::new();
    let mut heap = BinaryHeap::new();
    distances.insert(source.clone(), 0.0);
    heap.push((Reverse(ordered_float::OrderedFloat(0.0)), source.clone()));
    while let Some((Reverse(ordered_float::OrderedFloat(dist)), node)) = heap.pop() {
        if &node == target {
            // Follow predecessor links back to the source, then flip into forward order.
            let mut path = vec![target.clone()];
            let mut current = target;
            while let Some(prev) = previous.get(current) {
                path.push(prev.clone());
                current = prev;
            }
            path.reverse();
            return Ok((dist, path));
        }
        // Skip heap entries superseded by a shorter distance found later.
        if distances.get(&node).is_none_or(|&d| dist > d) {
            continue;
        }
        let Ok(neighbors) = graph.neighbors(&node) else {
            continue;
        };
        for neighbor in neighbors {
            if avoided.contains(&(&node, &neighbor)) {
                continue;
            }
            if &neighbor != source && &neighbor != target && excluded.contains(&neighbor) {
                continue;
            }
            let Ok(edge_weight) = graph.edge_weight(&node, &neighbor) else {
                continue;
            };
            let weight: f64 = edge_weight.into();
            let new_dist = dist + weight;
            if new_dist < distances.get(&neighbor).copied().unwrap_or(f64::INFINITY) {
                distances.insert(neighbor.clone(), new_dist);
                previous.insert(neighbor.clone(), node.clone());
                heap.push((Reverse(ordered_float::OrderedFloat(new_dist)), neighbor));
            }
        }
    }
    Err(GraphError::NoPath {
        src_node: format!("{source:?}"),
        target: format!("{target:?}"),
        nodes: 0,
        edges: 0,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[allow(deprecated)]
    fn test_shortest_path() {
        let mut graph: Graph<i32, f64> = Graph::new();
        let edges = [
            (1, 2, 4.0),
            (1, 3, 2.0),
            (2, 3, 1.0),
            (2, 4, 5.0),
            (3, 4, 8.0),
        ];
        for (u, v, w) in edges {
            graph.add_edge(u, v, w).expect("Operation failed");
        }
        let found = shortest_path(&graph, &1, &4)
            .expect("Operation failed")
            .expect("Test: operation failed");
        // 1 -> 3 -> 2 -> 4 costs 2 + 1 + 5 = 8, beating 1 -> 2 -> 4 at 9.
        assert_eq!(found.total_weight, 8.0);
        assert_eq!(found.nodes, vec![1, 3, 2, 4]);
    }

    #[test]
    fn test_floyd_warshall() {
        let mut graph: Graph<i32, f64> = Graph::new();
        for (u, v, w) in [(0, 1, 1.0), (1, 2, 2.0), (2, 0, 3.0)] {
            graph.add_edge(u, v, w).expect("Operation failed");
        }
        let distances = floyd_warshall(&graph).expect("Operation failed");
        assert_eq!(distances[[0, 0]], 0.0);
        assert_eq!(distances[[0, 1]], 1.0);
        // The direct edge (3) ties with 0 -> 1 -> 2 (1 + 2).
        assert_eq!(distances[[0, 2]], 3.0);
        // Undirected: the 0-1 edge works in both directions.
        assert_eq!(distances[[1, 0]], 1.0);
    }

    #[test]
    fn test_astar_search() {
        let mut graph: Graph<(i32, i32), f64> = Graph::new();
        // A unit square: (0,0) - (0,1) - (1,1) - (1,0) - back to (0,0).
        let square = [
            ((0, 0), (0, 1)),
            ((0, 1), (1, 1)),
            ((1, 1), (1, 0)),
            ((1, 0), (0, 0)),
        ];
        for (u, v) in square {
            graph.add_edge(u, v, 1.0).expect("Operation failed");
        }
        // Manhattan distance to the goal (1, 1).
        let heuristic = |&(x, y): &(i32, i32)| -> f64 { ((1 - x).abs() + (1 - y).abs()) as f64 };
        let result = astar_search(&graph, &(0, 0), &(1, 1), heuristic)
            .expect("Test: operation failed");
        assert_eq!(result.cost, 2.0);
        assert_eq!(result.path.len(), 3);
    }

    #[test]
    fn test_k_shortest_paths() {
        let mut graph: Graph<char, f64> = Graph::new();
        let edges = [
            ('A', 'B', 2.0),
            ('B', 'D', 2.0),
            ('A', 'C', 1.0),
            ('C', 'D', 4.0),
            ('B', 'C', 1.0),
        ];
        for (u, v, w) in edges {
            graph.add_edge(u, v, w).expect("Operation failed");
        }
        let paths = k_shortest_paths(&graph, &'A', &'D', 3).expect("Operation failed");
        assert!(paths.len() >= 2);
        assert_eq!(paths[0].0, 4.0);
        assert_eq!(paths[0].1, vec!['A', 'B', 'D']);
    }
}