use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
use std::hash::Hash;
/// Rebuilds the route that ends at `end` by following predecessor links.
///
/// `prev` maps a node to the node it was reached from. Starting at `end`, the
/// links are followed until a node with no recorded predecessor is found. The
/// collected nodes are returned in forward order: that final node comes first
/// and `end` comes last. If `end` has no entry in `prev`, the result is just
/// `[end]`.
fn reconstruct_path<Id: Eq + Hash + Clone>(prev: &HashMap<Id, Id>, end: &Id) -> Vec<Id> {
    // Walk backwards by reference and clone each node once while collecting.
    let mut route: Vec<Id> = std::iter::successors(Some(end), |node| prev.get(*node))
        .cloned()
        .collect();
    // The walk produced end -> origin; callers expect origin -> end.
    route.reverse();
    route
}
/// Directed graph with `u64` edge weights, stored as adjacency lists.
///
/// The type itself carries no trait bounds; each `impl` block states the
/// bounds its own methods need.
#[derive(Debug, Clone)]
pub struct Graph<Id> {
    /// Outgoing edges keyed by source node: every entry holds the
    /// `(target, weight)` pairs recorded for that source, in insertion order.
    adj: HashMap<Id, Vec<(Id, u64)>>,
}
impl<Id: Eq + Hash + Clone> Graph<Id> {
#[must_use]
pub fn new() -> Self {
Self {
adj: HashMap::new(),
}
}
pub fn add_edge(&mut self, from: Id, to: Id, weight: u64) {
self.adj.entry(from).or_default().push((to, weight));
}
pub fn bfs(&self, start: &Id) -> Vec<Id> {
let mut visited = HashSet::new();
let mut queue = VecDeque::new();
let mut result = Vec::new();
visited.insert(start.clone());
queue.push_back(start.clone());
while let Some(node) = queue.pop_front() {
result.push(node.clone());
if let Some(neighbors) = self.adj.get(&node) {
for (to, _) in neighbors {
if visited.insert(to.clone()) {
queue.push_back(to.clone());
}
}
}
}
result
}
}
impl<Id: Eq + Hash + Clone + Ord> Graph<Id> {
    /// Computes a cheapest path from `start` to `end` with Dijkstra's algorithm.
    ///
    /// Returns the nodes from `start` to `end` (both included) together with
    /// the summed edge weights, or `None` if `end` is never reached. When
    /// `start == end` the result is `Some((vec![start], 0))`.
    ///
    /// Costs are combined with saturating addition. A candidate is only
    /// accepted when it is strictly below the best known cost (`u64::MAX` for
    /// unseen nodes), so a node whose cost would saturate to `u64::MAX` is
    /// treated as unreachable.
    pub fn dijkstra(&self, start: &Id, end: &Id) -> Option<(Vec<Id>, u64)> {
        let mut dist: HashMap<Id, u64> = HashMap::new();
        let mut prev: HashMap<Id, Id> = HashMap::new();
        // `Reverse` turns the max-heap into a min-heap ordered by (cost, node).
        let mut pq: BinaryHeap<Reverse<(u64, Id)>> = BinaryHeap::new();
        dist.insert(start.clone(), 0);
        pq.push(Reverse((0, start.clone())));
        while let Some(Reverse((d, u))) = pq.pop() {
            if &u == end {
                break;
            }
            // A cheaper route to `u` was found after this entry was queued.
            if d > dist.get(&u).copied().unwrap_or(u64::MAX) {
                continue;
            }
            if let Some(neighbors) = self.adj.get(&u) {
                for (v, w) in neighbors {
                    let alt = d.saturating_add(*w);
                    if alt < dist.get(v).copied().unwrap_or(u64::MAX) {
                        dist.insert(v.clone(), alt);
                        prev.insert(v.clone(), u.clone());
                        pq.push(Reverse((alt, v.clone())));
                    }
                }
            }
        }
        let cost = *dist.get(end)?;
        Some((reconstruct_path(&prev, end), cost))
    }

    /// Searches for a path from `start` to `end` with A*.
    ///
    /// `h(node, end)` estimates the remaining cost from `node` to `end`. The
    /// search returns as soon as `end` is taken from the queue, yielding the
    /// path (both endpoints included) and its accumulated cost, or `None` when
    /// the queue runs dry first. The returned path is a cheapest one provided
    /// `h` never overestimates the true remaining cost.
    ///
    /// Costs and priorities use saturating addition, as in [`Graph::dijkstra`].
    pub fn astar<H: Fn(&Id, &Id) -> u64>(
        &self,
        start: &Id,
        end: &Id,
        h: H,
    ) -> Option<(Vec<Id>, u64)> {
        let mut g_score: HashMap<Id, u64> = HashMap::new();
        let mut prev: HashMap<Id, Id> = HashMap::new();
        // Heap entries are (f, node, g). `f` and the node decide the order; `g`
        // comes last so it never changes tie-breaking between different nodes
        // and only serves to recognise outdated entries.
        let mut pq: BinaryHeap<Reverse<(u64, Id, u64)>> = BinaryHeap::new();
        g_score.insert(start.clone(), 0);
        pq.push(Reverse((h(start, end), start.clone(), 0)));
        while let Some(Reverse((_, u, g))) = pq.pop() {
            if &u == end {
                return Some((reconstruct_path(&prev, end), *g_score.get(end)?));
            }
            // Skip entries queued before a cheaper route to `u` was found;
            // expanding them again could not improve any neighbour.
            if g > g_score.get(&u).copied().unwrap_or(u64::MAX) {
                continue;
            }
            if let Some(neighbors) = self.adj.get(&u) {
                for (v, w) in neighbors {
                    let alt = g.saturating_add(*w);
                    if alt < g_score.get(v).copied().unwrap_or(u64::MAX) {
                        g_score.insert(v.clone(), alt);
                        prev.insert(v.clone(), u.clone());
                        pq.push(Reverse((alt.saturating_add(h(v, end)), v.clone(), alt)));
                    }
                }
            }
        }
        None
    }
}
impl<Id: Eq + Hash + Clone> Default for Graph<Id> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
    use super::*;

    /// BFS from "A" lists every reachable node level by level.
    #[test]
    fn bfs_visits_all() {
        let mut graph: Graph<&str> = Graph::new();
        for (from, to) in [("A", "B"), ("A", "C"), ("B", "D")] {
            graph.add_edge(from, to, 1);
        }
        let order = graph.bfs(&"A");
        assert_eq!(order, vec!["A", "B", "C", "D"]);
    }

    /// The detour A -> B -> C (cost 3) beats the direct A -> C edge (cost 4).
    #[test]
    fn dijkstra_shortest_path() {
        let mut graph: Graph<&str> = Graph::new();
        for (from, to, weight) in [("A", "B", 1), ("A", "C", 4), ("B", "C", 2), ("C", "D", 1)] {
            graph.add_edge(from, to, weight);
        }
        let found = graph.dijkstra(&"A", &"D");
        let (path, cost) = found.expect("Path A->D should exist");
        assert_eq!(path, vec!["A", "B", "C", "D"]);
        assert_eq!(cost, 4);
    }

    /// A target that never appears in the graph yields no path.
    #[test]
    fn dijkstra_unreachable() {
        let mut graph: Graph<&str> = Graph::new();
        graph.add_edge("A", "B", 1);
        let outcome = graph.dijkstra(&"A", &"Z");
        assert!(outcome.is_none());
    }

    /// A* with an absolute-difference heuristic picks the cheap 0 -> 1 -> 2 route.
    #[test]
    fn astar_finds_path() {
        /// Distance between two node ids on the number line.
        fn gap(a: &u64, b: &u64) -> u64 {
            a.abs_diff(*b)
        }

        let mut graph: Graph<u64> = Graph::new();
        for (from, to, weight) in [(0, 1, 1), (1, 2, 1), (0, 3, 10), (3, 2, 1)] {
            graph.add_edge(from, to, weight);
        }
        let found = graph.astar(&0, &2, gap);
        let (path, cost) = found.expect("Path 0->2 should exist");
        assert_eq!(path, vec![0, 1, 2]);
        assert_eq!(cost, 2);
    }
}