use super::{Graph, NodeId, Path};
use crate::{Cost, Error, Result};
use petgraph::visit::EdgeRef;
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap};
pub fn shortest_path<N, E>(
graph: &Graph<N, E>,
source: NodeId,
target: NodeId,
edge_cost: impl Fn(&E) -> Cost,
) -> Result<Option<Path>> {
let distances = dijkstra(graph, source, &edge_cost)?;
match distances.get(&target) {
Some(&cost) => {
Ok(Some(Path {
nodes: vec![source, target],
cost,
}))
}
None => Ok(None),
}
}
pub fn dijkstra<N, E>(
graph: &Graph<N, E>,
source: NodeId,
edge_cost: impl Fn(&E) -> Cost,
) -> Result<HashMap<NodeId, Cost>> {
let mut distances: HashMap<NodeId, Cost> = HashMap::new();
let mut heap: BinaryHeap<Reverse<(Cost, NodeId)>> = BinaryHeap::new();
distances.insert(source, 0);
heap.push(Reverse((0, source)));
while let Some(Reverse((cost, node))) = heap.pop() {
if let Some(&best) = distances.get(&node) {
if cost > best {
continue;
}
}
for edge in graph.edges(node) {
let edge_weight = edge_cost(edge.weight());
if edge_weight < 0 {
return Err(Error::invalid_input(
"Dijkstra requires non-negative edge weights"
));
}
let next = edge.target();
let next_cost = cost + edge_weight;
let is_better = distances
.get(&next)
.map_or(true, |&d| next_cost < d);
if is_better {
distances.insert(next, next_cost);
heap.push(Reverse((next_cost, next)));
}
}
}
Ok(distances)
}
pub fn dijkstra_with_paths<N, E>(
graph: &Graph<N, E>,
source: NodeId,
edge_cost: impl Fn(&E) -> Cost,
) -> Result<(HashMap<NodeId, Cost>, HashMap<NodeId, NodeId>)> {
let mut distances: HashMap<NodeId, Cost> = HashMap::new();
let mut predecessors: HashMap<NodeId, NodeId> = HashMap::new();
let mut heap: BinaryHeap<Reverse<(Cost, NodeId)>> = BinaryHeap::new();
distances.insert(source, 0);
heap.push(Reverse((0, source)));
while let Some(Reverse((cost, node))) = heap.pop() {
if let Some(&best) = distances.get(&node) {
if cost > best {
continue;
}
}
for edge in graph.edges(node) {
let edge_weight = edge_cost(edge.weight());
if edge_weight < 0 {
return Err(Error::invalid_input(
"Dijkstra requires non-negative edge weights"
));
}
let next = edge.target();
let next_cost = cost + edge_weight;
let is_better = distances
.get(&next)
.map_or(true, |&d| next_cost < d);
if is_better {
distances.insert(next, next_cost);
predecessors.insert(next, node);
heap.push(Reverse((next_cost, next)));
}
}
}
Ok((distances, predecessors))
}
/// Rebuilds the node sequence from `source` to `target` by walking
/// `predecessors` backwards from `target`.
///
/// `predecessors` maps a node to the node that comes immediately before it on
/// the path; `total_cost` is stored unchanged in the returned path.
///
/// # Returns
///
/// * A [`Path`] whose `nodes` run from `source` to `target` (a single node
///   when `source == target`) and whose `cost` is `total_cost`.
/// * `Path::empty()` when the chain from `target` cannot be followed back to
///   `source` — either a node on the way has no predecessor entry, or the map
///   contains a cycle that never reaches `source`.
pub fn reconstruct_path(
    predecessors: &HashMap<NodeId, NodeId>,
    source: NodeId,
    target: NodeId,
    total_cost: Cost,
) -> Path {
    let mut path = vec![target];
    let mut current = target;

    while current != source {
        // On an acyclic chain every lookup uses a distinct key, so at most
        // `predecessors.len()` lookups can succeed. `path.len()` equals the
        // number of the lookup about to happen; going past the map size means
        // a key is being revisited, i.e. the map has a cycle. Bail out
        // instead of looping (and growing `path`) forever.
        if path.len() > predecessors.len() {
            return Path::empty();
        }

        match predecessors.get(&current) {
            Some(&pred) => {
                path.push(pred);
                current = pred;
            }
            // Broken chain: `current` was not reached from anywhere.
            None => return Path::empty(),
        }
    }

    // Nodes were collected target-first; callers expect source-first.
    path.reverse();
    Path {
        nodes: path,
        cost: total_cost,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use petgraph::graph::DiGraph;

    /// The two-hop route a -> b -> c (1 + 2) must beat the direct edge
    /// a -> c (5).
    #[test]
    fn test_simple_dijkstra() {
        let mut g: DiGraph<(), i64> = DiGraph::new();
        let a = g.add_node(());
        let b = g.add_node(());
        let c = g.add_node(());
        for &(from, to, weight) in &[(a, b, 1), (b, c, 2), (a, c, 5)] {
            g.add_edge(from, to, weight);
        }

        let costs = dijkstra(&g, a, |&w| w).unwrap();

        for &(node, expected) in &[(a, 0), (b, 1), (c, 3)] {
            assert_eq!(costs[&node], expected);
        }
    }

    /// A node with no incoming edges gets no entry in the result map, while
    /// the source is always present at cost 0.
    #[test]
    fn test_unreachable() {
        let mut g: DiGraph<(), i64> = DiGraph::new();
        let start = g.add_node(());
        let isolated = g.add_node(());

        let costs = dijkstra(&g, start, |&w| w).unwrap();

        assert_eq!(costs.get(&start).copied(), Some(0));
        assert!(costs.get(&isolated).is_none());
    }
}