use num_traits::{Float, Zero};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
use std::fmt::Debug;
use std::hash::Hash;
use crate::error::{GraphError, Result};
use crate::graph::Graph;
/// Priority-queue entry for Dijkstra's algorithm: a vertex together with the
/// tentative path cost at which it was queued.
///
/// Comparisons look only at `cost`, and the ordering is *reversed*, so the
/// max-heap [`BinaryHeap`] pops the cheapest entry first. The vertex takes no
/// part in equality or ordering.
#[derive(Copy, Clone, Debug)]
struct State<V, W> {
    vertex: V,
    cost: W,
}

// Equality is derived from `cmp` so that `a == b` holds exactly when
// `a.cmp(&b) == Ordering::Equal`, as the `Ord`/`PartialEq` contract requires.
impl<V, W: PartialOrd> PartialEq for State<V, W> {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl<V, W: PartialOrd> Eq for State<V, W> {}

impl<V, W: PartialOrd> PartialOrd for State<V, W> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<V, W: PartialOrd> Ord for State<V, W> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands are swapped (`other` vs `self`) to turn BinaryHeap's
        // max-heap into a min-heap on cost. Incomparable costs (e.g. NaN)
        // compare as Equal instead of panicking; that is not a true total
        // order, so callers should keep costs comparable.
        other
            .cost
            .partial_cmp(&self.cost)
            .unwrap_or(Ordering::Equal)
    }
}
/// Computes single-source shortest-path distances using Dijkstra's algorithm.
///
/// The returned map has an entry for every vertex yielded by `graph.vertices()`:
/// `Some(d)` is the cost of the cheapest path from `source`, and `None` means the
/// vertex is unreachable from `source`. The source maps to `Some(W::zero())`.
///
/// # Errors
///
/// - [`GraphError::VertexNotFound`] if `source` is not a vertex of `graph`.
/// - [`GraphError::InvalidInput`] if an edge leaving a vertex reached from
///   `source` has a negative or NaN weight. Edges leaving unreachable vertices
///   are never examined, so they are not validated.
pub fn shortest_paths<V, W>(graph: &Graph<V, W>, source: &V) -> Result<HashMap<V, Option<W>>>
where
    V: Hash + Eq + Copy + Debug,
    W: Float + Zero + Copy + Debug,
{
    if !graph.has_vertex(source) {
        return Err(GraphError::VertexNotFound);
    }

    // Pre-seed every vertex so unreachable ones are reported as `None`
    // rather than being absent from the result.
    let mut distances = HashMap::new();
    for v in graph.vertices() {
        distances.insert(*v, if v == source { Some(W::zero()) } else { None });
    }

    let mut heap = BinaryHeap::new();
    heap.push(State {
        vertex: *source,
        cost: W::zero(),
    });

    while let Some(State { vertex, cost }) = heap.pop() {
        // A vertex can be queued several times; skip stale entries that were
        // superseded by a cheaper distance after they were pushed.
        if let Some(Some(best)) = distances.get(&vertex) {
            if *best < cost {
                continue;
            }
        }
        // NOTE(review): a failed neighbour lookup is skipped silently (unchanged
        // behaviour) — confirm this can only happen for vertices with no edges.
        if let Ok(neighbors) = graph.neighbors(&vertex) {
            for (neighbor, edge_cost) in neighbors {
                // NaN must be rejected explicitly: `NaN < 0` is false, so it would
                // slip past a plain sign check and poison every derived distance.
                if edge_cost.is_nan() || edge_cost < W::zero() {
                    return Err(GraphError::invalid_input(
                        "Dijkstra's algorithm requires non-negative weights",
                    ));
                }
                let next = State {
                    vertex: *neighbor,
                    cost: cost + edge_cost,
                };
                // One hash lookup: a missing or `None` slot is always an improvement.
                let slot = distances.entry(*neighbor).or_insert(None);
                let improves = match *slot {
                    None => true,
                    Some(best) => next.cost < best,
                };
                if improves {
                    *slot = Some(next.cost);
                    heap.push(next);
                }
            }
        }
    }
    Ok(distances)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dijkstra_simple_path() {
        let mut graph = Graph::new();
        graph.add_edge(0, 1, 1.0);
        graph.add_edge(1, 2, 2.0);
        graph.add_edge(0, 2, 4.0);
        let distances = shortest_paths(&graph, &0).unwrap();
        assert_eq!(distances[&0], Some(0.0));
        assert_eq!(distances[&1], Some(1.0));
        assert_eq!(distances[&2], Some(3.0));
    }

    #[test]
    fn test_dijkstra_unreachable_vertices() {
        let mut graph = Graph::new();
        graph.add_edge(0, 1, 1.0);
        graph.add_vertex(2);
        let distances = shortest_paths(&graph, &0).unwrap();
        assert_eq!(distances[&0], Some(0.0));
        assert_eq!(distances[&1], Some(1.0));
        assert_eq!(distances[&2], None);
    }

    #[test]
    fn test_dijkstra_negative_weights() {
        let mut graph = Graph::new();
        graph.add_edge(0, 1, -1.0);
        assert!(matches!(
            shortest_paths(&graph, &0),
            Err(GraphError::InvalidInput(_))
        ));
    }

    #[test]
    fn test_dijkstra_vertex_not_found() {
        let graph: Graph<i32, f64> = Graph::new();
        assert!(matches!(
            shortest_paths(&graph, &0),
            Err(GraphError::VertexNotFound)
        ));
    }

    #[test]
    fn test_dijkstra_source_missing_from_nonempty_graph() {
        let mut graph = Graph::new();
        graph.add_edge(0, 1, 1.0);
        assert!(matches!(
            shortest_paths(&graph, &5),
            Err(GraphError::VertexNotFound)
        ));
    }

    #[test]
    fn test_dijkstra_isolated_source() {
        let mut graph: Graph<i32, f64> = Graph::new();
        graph.add_vertex(0);
        let distances = shortest_paths(&graph, &0).unwrap();
        assert_eq!(distances.len(), 1);
        assert_eq!(distances[&0], Some(0.0));
    }

    #[test]
    fn test_dijkstra_zero_weight_edges() {
        let mut graph = Graph::new();
        graph.add_edge(0, 1, 0.0);
        graph.add_edge(1, 2, 0.0);
        let distances = shortest_paths(&graph, &0).unwrap();
        assert_eq!(distances[&0], Some(0.0));
        assert_eq!(distances[&1], Some(0.0));
        assert_eq!(distances[&2], Some(0.0));
    }

    #[test]
    fn test_dijkstra_multiple_paths() {
        let mut graph = Graph::new();
        graph.add_edge(0, 1, 4.0);
        graph.add_edge(0, 2, 2.0);
        graph.add_edge(1, 3, 3.0);
        graph.add_edge(2, 1, 1.0);
        graph.add_edge(2, 3, 5.0);
        let distances = shortest_paths(&graph, &0).unwrap();
        assert_eq!(distances[&0], Some(0.0));
        assert_eq!(distances[&1], Some(3.0));
        assert_eq!(distances[&2], Some(2.0));
        assert_eq!(distances[&3], Some(6.0));
    }

    #[test]
    fn test_dijkstra_undirected_graph() {
        let mut graph = Graph::new_undirected();
        graph.add_edge(0, 1, 1.0);
        graph.add_edge(1, 2, 2.0);
        let distances = shortest_paths(&graph, &0).unwrap();
        assert_eq!(distances[&0], Some(0.0));
        assert_eq!(distances[&1], Some(1.0));
        assert_eq!(distances[&2], Some(3.0));
        let distances = shortest_paths(&graph, &2).unwrap();
        assert_eq!(distances[&0], Some(3.0));
        assert_eq!(distances[&1], Some(2.0));
        assert_eq!(distances[&2], Some(0.0));
    }

    #[test]
    fn test_dijkstra_cycle() {
        let mut graph = Graph::new();
        graph.add_edge(0, 1, 1.0);
        graph.add_edge(1, 2, 2.0);
        graph.add_edge(2, 0, 3.0);
        let distances = shortest_paths(&graph, &0).unwrap();
        assert_eq!(distances[&0], Some(0.0));
        assert_eq!(distances[&1], Some(1.0));
        assert_eq!(distances[&2], Some(3.0));
    }

    #[test]
    fn test_dijkstra_self_loop() {
        let mut graph = Graph::new();
        graph.add_edge(0, 0, 1.0);
        graph.add_edge(0, 1, 2.0);
        let distances = shortest_paths(&graph, &0).unwrap();
        assert_eq!(distances[&0], Some(0.0));
        assert_eq!(distances[&1], Some(2.0));
    }

    #[test]
    fn test_dijkstra_parallel_edges() {
        let mut graph = Graph::new();
        graph.add_edge(0, 1, 2.0);
        graph.add_edge(0, 1, 1.0);
        let distances = shortest_paths(&graph, &0).unwrap();
        assert_eq!(distances[&0], Some(0.0));
        assert_eq!(distances[&1], Some(1.0));
    }

    #[test]
    fn test_dijkstra_large_graph() {
        let mut graph = Graph::new();
        for i in 0..999 {
            graph.add_edge(i, i + 1, 1.0);
        }
        let distances = shortest_paths(&graph, &0).unwrap();
        assert_eq!(distances[&0], Some(0.0));
        assert_eq!(distances[&500], Some(500.0));
        assert_eq!(distances[&999], Some(999.0));
    }
}