use super::{Graph, GraphError, GraphResult, NodeId, Weight};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
/// Result of a single-source shortest-path computation ([`dijkstra`] or
/// [`bellman_ford`]).
#[derive(Debug, Clone)]
pub struct ShortestPathResult {
    /// Best known distance from the source to each recorded node.
    ///
    /// [`dijkstra`] records only nodes reachable from the source, while
    /// [`bellman_ford`] records every node and uses `Weight::INFINITY` for
    /// the unreachable ones.
    pub distances: HashMap<NodeId, Weight>,
    /// Predecessor of each recorded node on its shortest path. `None` marks
    /// the source (and, for [`bellman_ford`], nodes that were never reached).
    pub predecessors: HashMap<NodeId, Option<NodeId>>,
}

impl ShortestPathResult {
    /// Reconstructs the shortest path from the source to `target`.
    ///
    /// Returns the nodes in order from the source to `target` (inclusive), or
    /// `None` if `target` was never recorded, is unreachable (infinite
    /// distance), or its predecessor chain is broken or cyclic.
    pub fn reconstruct_path(&self, target: NodeId) -> Option<Vec<NodeId>> {
        // Unreachable nodes can be recorded with an infinite distance and a
        // `None` predecessor; without this check they would yield the bogus
        // one-node path `[target]`.
        if self.distances.get(&target)?.is_infinite() {
            return None;
        }
        let mut path = Vec::new();
        let mut current = Some(target);
        while let Some(node) = current {
            // A well-formed chain visits each recorded node at most once, so a
            // longer walk means the predecessor map contains a cycle.
            if path.len() >= self.predecessors.len() {
                return None;
            }
            path.push(node);
            current = *self.predecessors.get(&node)?;
        }
        path.reverse();
        Some(path)
    }

    /// Returns the recorded distance to `node`, or `None` if the node was not
    /// recorded. The recorded value may be `Weight::INFINITY` for unreachable
    /// nodes in a [`bellman_ford`] result.
    pub fn distance_to(&self, node: NodeId) -> Option<Weight> {
        self.distances.get(&node).copied()
    }
}
/// Heap entry used by [`dijkstra`] and [`astar`].
///
/// [`BinaryHeap`] is a max-heap, so the ordering below is reversed on
/// `distance`: the entry with the *smallest* distance compares greatest and is
/// popped first. Ties on distance are broken by `node` (the larger id compares
/// greater and pops first), which keeps the pop order deterministic.
#[derive(Debug, Clone, Copy)]
struct PriorityNode {
    node: NodeId,
    /// Priority key: tentative distance (Dijkstra) or f-score (A*).
    distance: Weight,
}

// Equality is defined through `cmp` rather than derived float `==`, so that
// `Eq` and `Ord` agree as the `Ord` contract requires.
impl PartialEq for PriorityNode {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for PriorityNode {}

impl PartialOrd for PriorityNode {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for PriorityNode {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` is a genuine total order on floats (a positive NaN sorts
        // above +inf, so it gets the lowest priority here). The previous
        // `partial_cmp(..).unwrap_or(Equal)` treated NaN as equal to every
        // value, which breaks transitivity and can scramble the heap.
        other
            .distance
            .total_cmp(&self.distance)
            .then_with(|| self.node.cmp(&other.node))
    }
}
pub fn dijkstra(graph: &Graph, source: NodeId) -> GraphResult<ShortestPathResult> {
if !graph.has_node(source) {
return Err(GraphError::NodeNotFound(source));
}
let mut distances = HashMap::new();
let mut predecessors = HashMap::new();
let mut heap = BinaryHeap::new();
distances.insert(source, 0.0);
predecessors.insert(source, None);
heap.push(PriorityNode {
node: source,
distance: 0.0,
});
while let Some(PriorityNode { node, distance }) = heap.pop() {
if let Some(&d) = distances.get(&node) {
if distance > d {
continue;
}
}
let neighbors = graph.neighbors(node)?;
for &(neighbor, edge_id) in neighbors {
let edge = graph.get_edge(edge_id)?;
if edge.weight < 0.0 {
return Err(GraphError::InvalidWeight(
"Dijkstra's algorithm requires non-negative weights".to_string(),
));
}
let new_distance = distance + edge.weight;
let should_update = distances.get(&neighbor).is_none_or(|&d| new_distance < d);
if should_update {
distances.insert(neighbor, new_distance);
predecessors.insert(neighbor, Some(node));
heap.push(PriorityNode {
node: neighbor,
distance: new_distance,
});
}
}
}
Ok(ShortestPathResult {
distances,
predecessors,
})
}
pub fn bellman_ford(graph: &Graph, source: NodeId) -> GraphResult<ShortestPathResult> {
if !graph.has_node(source) {
return Err(GraphError::NodeNotFound(source));
}
let nodes = graph.nodes();
let mut distances = HashMap::new();
let mut predecessors = HashMap::new();
for &node in &nodes {
distances.insert(node, Weight::INFINITY);
predecessors.insert(node, None);
}
distances.insert(source, 0.0);
for _ in 0..nodes.len() - 1 {
let mut updated = false;
for edge in graph.edges() {
let dist_from = distances[&edge.from];
if dist_from.is_finite() {
let new_distance = dist_from + edge.weight;
if new_distance < distances[&edge.to] {
distances.insert(edge.to, new_distance);
predecessors.insert(edge.to, Some(edge.from));
updated = true;
}
}
}
if !updated {
break;
}
}
for edge in graph.edges() {
let dist_from = distances[&edge.from];
if dist_from.is_finite() {
let new_distance = dist_from + edge.weight;
if new_distance < distances[&edge.to] {
return Err(GraphError::NegativeCycle(
"Graph contains a negative cycle".to_string(),
));
}
}
}
Ok(ShortestPathResult {
distances,
predecessors,
})
}
#[derive(Debug, Clone)]
pub struct AllPairsShortestPaths {
pub distances: HashMap<NodeId, HashMap<NodeId, Weight>>,
pub next: HashMap<NodeId, HashMap<NodeId, Option<NodeId>>>,
}
impl AllPairsShortestPaths {
pub fn reconstruct_path(&self, source: NodeId, target: NodeId) -> Option<Vec<NodeId>> {
let dist = self.distances.get(&source)?.get(&target)?;
if dist.is_infinite() {
return None;
}
let mut path = vec![source];
let mut current = source;
while current != target {
let next_opt = self.next.get(¤t)?.get(&target)?;
let next_node = (*next_opt)?;
path.push(next_node);
current = next_node;
}
Some(path)
}
pub fn distance(&self, source: NodeId, target: NodeId) -> Option<Weight> {
self.distances.get(&source)?.get(&target).copied()
}
}
/// Computes shortest paths between every ordered pair of nodes with the
/// Floyd–Warshall algorithm (three nested passes over all nodes).
///
/// Edge weights may be negative. When several edges connect the same ordered
/// pair, the lightest one is used, and a non-negative self-loop never raises a
/// node's distance to itself above `0.0`. `next[&u][&v]` in the result is the
/// first hop from `u` towards `v`, as expected by
/// [`AllPairsShortestPaths::reconstruct_path`].
///
/// # Errors
///
/// - [`GraphError::NegativeCycle`] if some node ends up with a negative
///   distance to itself.
/// - [`GraphError::NodeNotFound`] if an edge starts at a node that is not
///   listed by `graph.nodes()`.
pub fn floyd_warshall(graph: &Graph) -> GraphResult<AllPairsShortestPaths> {
    let nodes = graph.nodes();
    let mut distances: HashMap<NodeId, HashMap<NodeId, Weight>> = HashMap::new();
    let mut next: HashMap<NodeId, HashMap<NodeId, Option<NodeId>>> = HashMap::new();
    for &i in &nodes {
        let mut row_dist = HashMap::new();
        let mut row_next = HashMap::new();
        for &j in &nodes {
            row_dist.insert(j, if i == j { 0.0 } else { Weight::INFINITY });
            row_next.insert(j, None);
        }
        distances.insert(i, row_dist);
        next.insert(i, row_next);
    }

    // NOTE(review): edges are seeded only in their `from -> to` direction;
    // confirm `graph.edges()` covers both directions for undirected graphs.
    for edge in graph.edges() {
        let row_dist = distances
            .get_mut(&edge.from)
            .ok_or(GraphError::NodeNotFound(edge.from))?;
        let current = row_dist
            .get(&edge.to)
            .copied()
            .unwrap_or(Weight::INFINITY);
        // Keep the lightest of parallel edges, and never let a positive
        // self-loop overwrite the 0.0 self-distance.
        if edge.weight < current {
            row_dist.insert(edge.to, edge.weight);
            next.get_mut(&edge.from)
                .ok_or(GraphError::NodeNotFound(edge.from))?
                .insert(edge.to, Some(edge.to));
        }
    }

    for &k in &nodes {
        for &i in &nodes {
            for &j in &nodes {
                let dist_ik = distances[&i][&k];
                let dist_kj = distances[&k][&j];
                if !(dist_ik.is_finite() && dist_kj.is_finite()) {
                    continue;
                }
                let through_k = dist_ik + dist_kj;
                if through_k < distances[&i][&j] {
                    // The path i -> ... -> k -> ... -> j starts with i's first
                    // hop towards k (not k's hop towards j).
                    let first_hop = next[&i][&k];
                    distances
                        .get_mut(&i)
                        .ok_or(GraphError::NodeNotFound(i))?
                        .insert(j, through_k);
                    next.get_mut(&i)
                        .ok_or(GraphError::NodeNotFound(i))?
                        .insert(j, first_hop);
                }
            }
        }
    }

    for &i in &nodes {
        if distances[&i][&i] < 0.0 {
            return Err(GraphError::NegativeCycle(format!(
                "Negative cycle detected involving node {}",
                i
            )));
        }
    }
    Ok(AllPairsShortestPaths { distances, next })
}
/// Heuristic accepted by [`astar`]. It is called as `heuristic(node, target)`,
/// and its result is added to the cost already paid to reach `node`, so it is
/// expected to estimate the remaining cost from `node` to `target`.
pub type HeuristicFn = Box<dyn Fn(NodeId, NodeId) -> Weight>;
/// Finds a cheapest path from `source` to `target` with A* search.
///
/// `heuristic(node, target)` estimates the remaining cost from `node` to
/// `target`. As with A* in general, the returned path is guaranteed optimal
/// only when the heuristic never overestimates. A heuristic that always
/// returns `0.0` makes the search behave like [`dijkstra`] with an early exit.
///
/// Returns the path (from `source` to `target`, inclusive) and its total cost.
///
/// # Errors
///
/// - [`GraphError::NodeNotFound`] if `source` or `target` is not in `graph`.
/// - [`GraphError::InvalidWeight`] if a negative edge weight is encountered.
///   As for [`dijkstra`], negative weights invalidate the early exit, and a
///   reachable negative cycle could keep the search from terminating.
/// - [`GraphError::InvalidPath`] if `target` is unreachable from `source`.
/// - Any error propagated from `graph.neighbors` or `graph.get_edge`.
pub fn astar(
    graph: &Graph,
    source: NodeId,
    target: NodeId,
    heuristic: HeuristicFn,
) -> GraphResult<(Vec<NodeId>, Weight)> {
    if !graph.has_node(source) {
        return Err(GraphError::NodeNotFound(source));
    }
    if !graph.has_node(target) {
        return Err(GraphError::NodeNotFound(target));
    }
    // Cost of the best known path from `source` to each node.
    let mut g_score: HashMap<NodeId, Weight> = HashMap::new();
    // `g_score + heuristic`: the priority each node was most recently pushed with.
    let mut f_score: HashMap<NodeId, Weight> = HashMap::new();
    let mut predecessors: HashMap<NodeId, NodeId> = HashMap::new();
    let mut heap = BinaryHeap::new();

    let source_f = heuristic(source, target);
    g_score.insert(source, 0.0);
    f_score.insert(source, source_f);
    heap.push(PriorityNode {
        node: source,
        distance: source_f,
    });

    while let Some(PriorityNode { node, distance }) = heap.pop() {
        // Skip entries superseded by a later push with a better f-score.
        if f_score.get(&node).is_some_and(|&best| distance > best) {
            continue;
        }
        if node == target {
            let mut path = vec![target];
            let mut current = target;
            while let Some(&pred) = predecessors.get(&current) {
                path.push(pred);
                current = pred;
            }
            path.reverse();
            return Ok((path, g_score[&target]));
        }
        // Safe to read once: with non-negative weights, expanding `node`
        // cannot improve its own g-score.
        let node_g = g_score[&node];
        for &(neighbor, edge_id) in graph.neighbors(node)? {
            let edge = graph.get_edge(edge_id)?;
            if edge.weight < 0.0 {
                return Err(GraphError::InvalidWeight(
                    "A* search requires non-negative weights".to_string(),
                ));
            }
            let tentative_g_score = node_g + edge.weight;
            let should_update = g_score
                .get(&neighbor)
                .is_none_or(|&g| tentative_g_score < g);
            if should_update {
                predecessors.insert(neighbor, node);
                g_score.insert(neighbor, tentative_g_score);
                let f = tentative_g_score + heuristic(neighbor, target);
                f_score.insert(neighbor, f);
                heap.push(PriorityNode {
                    node: neighbor,
                    distance: f,
                });
            }
        }
    }
    Err(GraphError::InvalidPath(format!(
        "No path exists from {} to {}",
        source, target
    )))
}
/// Rebuilds the path that ends at `target` from a predecessor map.
///
/// `predecessors` maps each node to the node before it on its path, with
/// `None` marking the start. Returns the nodes from the start to `target`
/// inclusive, or `None` if `target` is absent, the chain references a node
/// missing from the map, or the chain loops back on itself.
pub fn reconstruct_path(
    predecessors: &HashMap<NodeId, Option<NodeId>>,
    target: NodeId,
) -> Option<Vec<NodeId>> {
    if !predecessors.contains_key(&target) {
        return None;
    }
    let mut path = Vec::new();
    let mut current = Some(target);
    while let Some(node) = current {
        // Each node appears at most once in a well-formed chain, so a walk
        // longer than the map means the chain is cyclic; stop instead of
        // looping (and allocating) forever.
        if path.len() >= predecessors.len() {
            return None;
        }
        path.push(node);
        current = *predecessors.get(&node)?;
    }
    path.reverse();
    Some(path)
}
// Unit tests for the shortest-path algorithms in this module.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a four-node test graph with `Graph::new(true)` (presumably
    /// directed) and edges n0->n1 (1.0), n0->n2 (4.0), n1->n2 (2.0),
    /// n1->n3 (5.0) and n2->n3 (1.0).
    ///
    /// The tests below address the nodes by the literal ids 0..=3, so they
    /// rely on `add_node` handing out ids in that order. From node 0 the
    /// expected shortest distances are 0, 1, 3 and 4, and the expected route
    /// to node 3 is 0 -> 1 -> 2 -> 3.
    fn create_test_graph() -> Graph {
        let mut graph = Graph::new(true);
        let n0 = graph.add_node();
        let n1 = graph.add_node();
        let n2 = graph.add_node();
        let n3 = graph.add_node();
        graph
            .add_edge(n0, n1, 1.0)
            .expect("test: valid edge addition");
        graph
            .add_edge(n0, n2, 4.0)
            .expect("test: valid edge addition");
        graph
            .add_edge(n1, n2, 2.0)
            .expect("test: valid edge addition");
        graph
            .add_edge(n1, n3, 5.0)
            .expect("test: valid edge addition");
        graph
            .add_edge(n2, n3, 1.0)
            .expect("test: valid edge addition");
        graph
    }

    // Distances and the full path from node 0 on the shared test graph.
    #[test]
    fn test_dijkstra() {
        let graph = create_test_graph();
        let result = dijkstra(&graph, 0).expect("test: valid Dijkstra");
        assert_eq!(result.distance_to(0), Some(0.0));
        assert_eq!(result.distance_to(1), Some(1.0));
        assert_eq!(result.distance_to(2), Some(3.0));
        assert_eq!(result.distance_to(3), Some(4.0));
        let path = result
            .reconstruct_path(3)
            .expect("test: valid path reconstruction");
        assert_eq!(path, vec![0, 1, 2, 3]);
    }

    // Dijkstra must refuse a reachable negative edge.
    #[test]
    fn test_dijkstra_negative_weight() {
        let mut graph = Graph::new(true);
        let n0 = graph.add_node();
        let n1 = graph.add_node();
        graph
            .add_edge(n0, n1, -1.0)
            .expect("test: valid edge addition");
        let result = dijkstra(&graph, n0);
        assert!(result.is_err());
    }

    // Bellman-Ford must agree with Dijkstra on a non-negative graph.
    #[test]
    fn test_bellman_ford() {
        let graph = create_test_graph();
        let result = bellman_ford(&graph, 0).expect("test: valid Bellman-Ford");
        assert_eq!(result.distance_to(0), Some(0.0));
        assert_eq!(result.distance_to(1), Some(1.0));
        assert_eq!(result.distance_to(2), Some(3.0));
        assert_eq!(result.distance_to(3), Some(4.0));
    }

    // Negative edges without a cycle: 1.0 + (-2.0) = -1.0 to node 2.
    #[test]
    fn test_bellman_ford_negative_weights() {
        let mut graph = Graph::new(true);
        let n0 = graph.add_node();
        let n1 = graph.add_node();
        let n2 = graph.add_node();
        graph
            .add_edge(n0, n1, 1.0)
            .expect("test: valid edge addition");
        graph
            .add_edge(n1, n2, -2.0)
            .expect("test: valid edge addition");
        let result =
            bellman_ford(&graph, n0).expect("test: valid Bellman-Ford with negative edges");
        assert_eq!(result.distance_to(2), Some(-1.0));
    }

    // The cycle n0 -> n1 -> n0 has total weight 1.0 + (-2.0) = -1.0.
    #[test]
    fn test_bellman_ford_negative_cycle() {
        let mut graph = Graph::new(true);
        let n0 = graph.add_node();
        let n1 = graph.add_node();
        graph
            .add_edge(n0, n1, 1.0)
            .expect("test: valid edge addition");
        graph
            .add_edge(n1, n0, -2.0)
            .expect("test: valid edge addition");
        let result = bellman_ford(&graph, n0);
        assert!(result.is_err());
    }

    // NOTE(review): only the endpoints of the reconstructed path are checked;
    // asserting the full path (vec![0, 1, 2, 3]) would also catch incorrect
    // next-hop bookkeeping.
    #[test]
    fn test_floyd_warshall() {
        let graph = create_test_graph();
        let result = floyd_warshall(&graph).expect("test: valid Floyd-Warshall");
        assert_eq!(result.distance(0, 3), Some(4.0));
        assert_eq!(result.distance(0, 0), Some(0.0));
        let path = result
            .reconstruct_path(0, 3)
            .expect("test: valid path reconstruction");
        assert_eq!(path[0], 0);
        assert_eq!(path[path.len() - 1], 3);
    }

    // A zero heuristic reduces A* to a Dijkstra-style search.
    #[test]
    fn test_astar() {
        let graph = create_test_graph();
        let heuristic = Box::new(|_: NodeId, _: NodeId| 0.0);
        let (path, cost) = astar(&graph, 0, 3, heuristic).expect("test: valid A* search");
        assert_eq!(cost, 4.0);
        assert_eq!(path[0], 0);
        assert_eq!(path[path.len() - 1], 3);
    }

    // No edges at all, so n1 is unreachable from n0.
    #[test]
    fn test_astar_no_path() {
        let mut graph = Graph::new(true);
        let n0 = graph.add_node();
        let n1 = graph.add_node();
        let heuristic = Box::new(|_: NodeId, _: NodeId| 0.0);
        let result = astar(&graph, n0, n1, heuristic);
        assert!(result.is_err());
    }

    // Free-function path rebuild over a hand-written chain 0 -> 1 -> 2.
    #[test]
    fn test_reconstruct_path() {
        let mut predecessors = HashMap::new();
        predecessors.insert(0, None);
        predecessors.insert(1, Some(0));
        predecessors.insert(2, Some(1));
        let path = reconstruct_path(&predecessors, 2).expect("test: valid path reconstruction");
        assert_eq!(path, vec![0, 1, 2]);
    }
}