use crate::graph::graph::Graph;
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
use std::fmt::Debug;
use std::hash::Hash;
use std::ops::Add;
/// Errors returned by [`ShortestPath`] algorithms.
///
/// Each variant carries the offending node ID.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShortestPathError {
    /// The requested start node does not exist in the graph.
    InvalidStartNode(usize),
    /// A node reached during the search lists a neighbor ID that does not exist in the graph.
    InvalidNeighborNode(usize),
}

impl std::fmt::Display for ShortestPathError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ShortestPathError::InvalidStartNode(id) => {
                write!(f, "Invalid start node: node ID {} not found", id)
            }
            ShortestPathError::InvalidNeighborNode(id) => {
                write!(f, "Invalid neighbor node: node ID {} not found", id)
            }
        }
    }
}

impl std::error::Error for ShortestPathError {}
/// Priority-queue entry for Dijkstra's search: a node and the tentative cost at which it was
/// queued.
///
/// The ordering is reversed on `cost`, so `BinaryHeap` (a max-heap) pops the cheapest entry
/// first.
#[derive(Copy, Clone, Eq, PartialEq)]
struct State<W> {
    cost: W,
    node: usize,
}

impl<W: Ord> Ord for State<W> {
    /// Orders by *descending* cost: a lower cost compares greater.
    ///
    /// Ties are broken on `node`, so `cmp` returns `Equal` exactly when the derived `Eq`
    /// reports equality. This consistency is required by the `Ord` contract.
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .cost
            .cmp(&self.cost)
            .then_with(|| self.node.cmp(&other.node))
    }
}

impl<W: Ord> PartialOrd for State<W> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// Single-source shortest-path queries over a weighted graph.
///
/// `W` is the edge-weight / distance type. `N` and `E` are the node and edge payload types
/// of the graph being searched.
pub trait ShortestPath<
    W: Add<Output = W> + Ord + Copy + Default + From<u8> + Send,
    N: Clone + Eq + Hash + Debug,
    E: Clone + Default + Debug,
>
{
    /// Returns the shortest distance from `start` to each node reachable from it, keyed by
    /// node ID.
    ///
    /// # Errors
    ///
    /// Returns a [`ShortestPathError`] when `start`, or a node referenced during the search,
    /// is not part of the graph.
    fn dijkstra(&self, start: usize) -> Result<HashMap<usize, W>, ShortestPathError>;
}
impl<W, N, E> ShortestPath<W, N, E> for Graph<W, N, E>
where
    W: Add<Output = W> + Ord + Copy + Default + From<u8> + Send,
    N: Clone + Eq + Hash + Debug,
    E: Clone + Default + Debug,
{
    /// Computes shortest-path distances from `start` using Dijkstra's algorithm.
    ///
    /// The distance of `start` is `W::default()`, and path costs accumulate edge weights
    /// with `+`. Only nodes reachable from `start` appear in the returned map.
    ///
    /// Dijkstra's algorithm is only correct when every edge weight is non-negative
    /// (`weight >= W::default()`). This precondition is assumed, not checked.
    ///
    /// # Errors
    ///
    /// * [`ShortestPathError::InvalidStartNode`] if `start` is not a node of the graph.
    /// * [`ShortestPathError::InvalidNeighborNode`] if a node reached during the search lists
    ///   a neighbor ID that is not in the graph. Neighbors of unreachable nodes are never
    ///   inspected.
    fn dijkstra(&self, start: usize) -> Result<HashMap<usize, W>, ShortestPathError> {
        use std::collections::hash_map::Entry;

        if !self.nodes.contains(start) {
            return Err(ShortestPathError::InvalidStartNode(start));
        }

        let mut distances = HashMap::new();
        let mut heap = BinaryHeap::new();
        distances.insert(start, W::default());
        heap.push(State {
            cost: W::default(),
            node: start,
        });

        while let Some(State { cost, node }) = heap.pop() {
            // A node may be queued several times (lazy deletion instead of decrease-key).
            // Skip entries that were superseded by a cheaper path found after they were pushed.
            match distances.get(&node) {
                Some(&best) if cost <= best => {}
                _ => continue,
            }

            for &(neighbor, weight) in &self.nodes[node].neighbors {
                // Validate before the neighbor is ever popped and used to index `self.nodes`.
                if !self.nodes.contains(neighbor) {
                    return Err(ShortestPathError::InvalidNeighborNode(neighbor));
                }
                let next_cost = cost + weight;
                let improved = match distances.entry(neighbor) {
                    Entry::Occupied(mut entry) => {
                        if next_cost < *entry.get() {
                            entry.insert(next_cost);
                            true
                        } else {
                            false
                        }
                    }
                    Entry::Vacant(entry) => {
                        entry.insert(next_cost);
                        true
                    }
                };
                if improved {
                    heap.push(State {
                        cost: next_cost,
                        node: neighbor,
                    });
                }
            }
        }

        Ok(distances)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Chain 0 --2-- 1 --3-- 2: distances accumulate along the path.
    #[test]
    fn test_dijkstra_basic() {
        let mut g = Graph::<u32, (), ()>::new(false);
        let ids: Vec<_> = (0..3).map(|_| g.add_node(())).collect();
        g.add_edge(ids[0], ids[1], 2, ()).unwrap();
        g.add_edge(ids[1], ids[2], 3, ()).unwrap();

        let dist = g.dijkstra(ids[0]).unwrap();

        let expected = [0u32, 2, 5];
        for (id, want) in ids.iter().zip(expected.iter()) {
            assert_eq!(dist[id], *want);
        }
    }

    /// A node with no incoming path from the source is absent from the result map.
    #[test]
    fn test_unreachable_node() {
        let mut g = Graph::<u64, (), ()>::new(true);
        let source = g.add_node(());
        let isolated = g.add_node(());

        let dist = g.dijkstra(source).unwrap();

        assert!(!dist.contains_key(&isolated));
    }

    /// Starting from a node ID that was never added is reported as an error.
    #[test]
    fn test_invalid_start_node() {
        let g = Graph::<u32, (), ()>::new(false);
        let outcome = g.dijkstra(999);
        assert!(matches!(
            outcome,
            Err(ShortestPathError::InvalidStartNode(999))
        ));
    }
}