use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};
/// One path produced by [`yen_k_shortest`].
#[derive(Debug, Clone, PartialEq)]
pub struct YenPath<N> {
    /// Nodes in traversal order, from the source node to the destination node inclusive.
    pub nodes: Vec<N>,
    /// Sum of the edge costs along `nodes`.
    pub cost: f64,
}
/// Priority-queue entry used by `dijkstra_path`: a node together with the
/// tentative cost at which it was queued.
///
/// `BinaryHeap` is a max-heap, so the ordering is reversed on `cost`: the entry
/// with the smallest cost compares greatest and is popped first. Costs are
/// compared with `f64::total_cmp`, which keeps the order total even when a cost
/// is NaN. Equality is derived from the same comparison so that `PartialEq` and
/// `Ord` agree, as the `Ord` contract requires.
#[derive(Clone)]
struct DijkEntry<N> {
    node: N,
    cost: f64,
}

impl<N> PartialEq for DijkEntry<N> {
    fn eq(&self, other: &Self) -> bool {
        // Must match `cmp` below: entries are equal exactly when their costs tie.
        self.cost.total_cmp(&other.cost) == Ordering::Equal
    }
}

impl<N> Eq for DijkEntry<N> {}

impl<N> PartialOrd for DijkEntry<N> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<N> Ord for DijkEntry<N> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed (`other` vs `self`) to turn the max-heap into a min-heap on cost.
        other.cost.total_cmp(&self.cost)
    }
}
/// Finds one cheapest path from `src` to `dest` with Dijkstra's algorithm.
///
/// Nodes in `excluded_nodes` are never entered, and directed edges
/// `(from, to)` in `excluded_edges` are never traversed. Yen's algorithm uses
/// these sets to force spur paths to deviate. `src` itself is not checked
/// against `excluded_nodes`.
///
/// `neighbors` yields `(successor, edge_cost)` pairs for a node. Edge costs are
/// assumed non-negative (Dijkstra's precondition). With negative weights the
/// returned path is not guaranteed to be optimal.
///
/// Returns the node sequence (starting at `src`, ending at `dest`) and its total
/// cost, or `None` when `dest` is unreachable under the exclusions.
fn dijkstra_path<N, FN, I>(
    src: N,
    dest: N,
    neighbors: &mut FN,
    excluded_edges: &HashSet<(N, N)>,
    excluded_nodes: &HashSet<N>,
) -> Option<YenPath<N>>
where
    N: Clone + Eq + std::hash::Hash,
    FN: FnMut(&N) -> I,
    I: IntoIterator<Item = (N, f64)>,
{
    let mut g_scores: HashMap<N, f64> = HashMap::new();
    let mut came_from: HashMap<N, N> = HashMap::new();
    let mut closed: HashSet<N> = HashSet::new();
    let mut heap: BinaryHeap<DijkEntry<N>> = BinaryHeap::new();
    g_scores.insert(src.clone(), 0.0);
    heap.push(DijkEntry { node: src, cost: 0.0 });

    while let Some(current) = heap.pop() {
        if current.node == dest {
            // The first pop of `dest` carries its final, minimal cost. Walk the
            // predecessor links back to `src`, the only node without a predecessor.
            let mut path = vec![dest.clone()];
            let mut cursor = &dest;
            while let Some(prev) = came_from.get(cursor) {
                path.push(prev.clone());
                cursor = prev;
            }
            path.reverse();
            return Some(YenPath {
                nodes: path,
                cost: current.cost,
            });
        }

        // A node may be queued several times; only its first (cheapest) pop is expanded.
        if !closed.insert(current.node.clone()) {
            continue;
        }

        for (nbr, edge_cost) in neighbors(&current.node) {
            // The edge check builds an owned (from, to) key, so skip it when nothing is excluded.
            let blocked = closed.contains(&nbr)
                || excluded_nodes.contains(&nbr)
                || (!excluded_edges.is_empty()
                    && excluded_edges.contains(&(current.node.clone(), nbr.clone())));
            if blocked {
                continue;
            }
            let tentative = current.cost + edge_cost;
            let best_known = g_scores.get(&nbr).copied().unwrap_or(f64::INFINITY);
            if tentative < best_known {
                g_scores.insert(nbr.clone(), tentative);
                came_from.insert(nbr.clone(), current.node.clone());
                heap.push(DijkEntry {
                    node: nbr,
                    cost: tentative,
                });
            }
        }
    }
    None
}
pub fn yen_k_shortest<N, FN, I>(src: N, dest: N, k: usize, mut neighbors: FN) -> Vec<YenPath<N>>
where
N: Clone + Eq + std::hash::Hash,
FN: FnMut(&N) -> I,
I: IntoIterator<Item = (N, f64)>,
{
if k == 0 {
return Vec::new();
}
let first = dijkstra_path(
src.clone(),
dest.clone(),
&mut neighbors,
&HashSet::new(),
&HashSet::new(),
);
let Some(first) = first else {
return Vec::new();
};
let mut accepted: Vec<YenPath<N>> = vec![first];
let mut candidates: BinaryHeap<CandidateEntry<N>> = BinaryHeap::new();
let mut candidate_set: HashSet<Vec<N>> = HashSet::new();
for ki in 1..k {
let prev_path = &accepted[ki - 1].nodes;
for spur_idx in 0..prev_path.len().saturating_sub(1) {
let spur_node = prev_path[spur_idx].clone();
let root_path: Vec<N> = prev_path[..=spur_idx].to_vec();
let mut excluded_edges: HashSet<(N, N)> = HashSet::new();
for accepted_path in &accepted {
if accepted_path.nodes.len() > spur_idx
&& accepted_path.nodes[..=spur_idx] == root_path[..]
{
excluded_edges.insert((
accepted_path.nodes[spur_idx].clone(),
accepted_path.nodes[spur_idx + 1].clone(),
));
}
}
let mut excluded_nodes: HashSet<N> = HashSet::new();
for node in &root_path[..spur_idx] {
excluded_nodes.insert(node.clone());
}
if let Some(spur_path) = dijkstra_path(
spur_node,
dest.clone(),
&mut neighbors,
&excluded_edges,
&excluded_nodes,
) {
let mut full_nodes = root_path.clone();
full_nodes.extend_from_slice(&spur_path.nodes[1..]);
let mut seen = HashSet::new();
if full_nodes.iter().any(|n| !seen.insert(n.clone())) {
continue;
}
if !candidate_set.contains(&full_nodes) {
let cost = path_cost(&full_nodes, &mut neighbors);
candidate_set.insert(full_nodes.clone());
candidates.push(CandidateEntry {
cost,
nodes: full_nodes,
});
}
}
}
if let Some(best) = candidates.pop() {
accepted.push(YenPath {
nodes: best.nodes,
cost: best.cost,
});
} else {
break;
}
}
accepted
}
/// Returns the total cost of walking `nodes` in order.
///
/// For each consecutive pair, `neighbors(from)` is queried once. When it lists
/// several edges to the same successor, the cheapest is used, which matches
/// how Dijkstra relaxes edges. A hop with no matching edge costs
/// `f64::INFINITY`, so an untraversable sequence has an infinite total instead
/// of silently counting the hop as free. An empty or single-node slice costs `0.0`.
fn path_cost<N, FN, I>(nodes: &[N], neighbors: &mut FN) -> f64
where
    N: Clone + Eq + std::hash::Hash,
    FN: FnMut(&N) -> I,
    I: IntoIterator<Item = (N, f64)>,
{
    nodes
        .windows(2)
        .map(|pair| {
            let (from, to) = (&pair[0], &pair[1]);
            neighbors(from)
                .into_iter()
                .filter(|(n, _)| n == to)
                .map(|(_, c)| c)
                .fold(f64::INFINITY, f64::min)
        })
        .sum()
}
/// A candidate path waiting in Yen's priority queue, together with its total cost.
///
/// Ordering is reversed on `cost` so that `BinaryHeap` (a max-heap) pops the
/// cheapest candidate first. Costs are compared with `f64::total_cmp` so the
/// order stays total even when a cost is NaN. Equality is defined by the same
/// cost comparison, ignoring `nodes`, so that `PartialEq` agrees with `Ord` as
/// the `Ord` contract requires. Distinct paths are deduplicated separately by
/// the caller.
struct CandidateEntry<N> {
    cost: f64,
    nodes: Vec<N>,
}

impl<N> PartialEq for CandidateEntry<N> {
    fn eq(&self, other: &Self) -> bool {
        // Must match `cmp` below.
        self.cost.total_cmp(&other.cost) == Ordering::Equal
    }
}

impl<N> Eq for CandidateEntry<N> {}

impl<N> PartialOrd for CandidateEntry<N> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<N> Ord for CandidateEntry<N> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed (`other` vs `self`) to turn the max-heap into a min-heap on cost.
        other.cost.total_cmp(&self.cost)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two equal-cost routes: 0→1→3 and 0→2→3, both costing 3.
    fn diamond_neighbors(node: &usize) -> Vec<(usize, f64)> {
        match *node {
            0 => vec![(1, 1.0), (2, 2.0)],
            1 => vec![(3, 2.0)],
            2 => vec![(3, 1.0)],
            _ => vec![],
        }
    }

    #[test]
    fn finds_two_paths_on_diamond() {
        let paths = yen_k_shortest(0, 3, 3, diamond_neighbors);
        assert_eq!(paths.len(), 2);
        assert_eq!(paths[0].nodes, vec![0, 1, 3]);
        assert!((paths[0].cost - 3.0).abs() < 1e-6);
        assert_eq!(paths[1].nodes, vec![0, 2, 3]);
        assert!((paths[1].cost - 3.0).abs() < 1e-6);
    }

    #[test]
    fn single_path_on_line() {
        let neighbors = |node: &usize| -> Vec<(usize, f64)> {
            match *node {
                0 => vec![(1, 1.0)],
                1 => vec![(2, 1.0)],
                _ => vec![],
            }
        };
        let paths = yen_k_shortest(0, 2, 5, neighbors);
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].nodes, vec![0, 1, 2]);
        assert!((paths[0].cost - 2.0).abs() < 1e-6);
    }

    #[test]
    fn no_path_returns_empty() {
        let neighbors = |_: &usize| -> Vec<(usize, f64)> { vec![] };
        let paths = yen_k_shortest(0, 5, 3, neighbors);
        assert!(paths.is_empty());
    }

    #[test]
    fn k_zero_returns_empty() {
        let paths = yen_k_shortest(0, 3, 0, diamond_neighbors);
        assert!(paths.is_empty());
    }

    #[test]
    fn paths_are_loopless() {
        let neighbors = |node: &usize| -> Vec<(usize, f64)> {
            match *node {
                0 => vec![(1, 1.0), (2, 3.0)],
                1 => vec![(0, 1.0), (2, 1.0)],
                2 => vec![(3, 1.0)],
                _ => vec![],
            }
        };
        let paths = yen_k_shortest(0, 3, 5, neighbors);
        // Pin the exact result so the loop check below cannot pass vacuously.
        assert_eq!(paths.len(), 2);
        assert_eq!(paths[0].nodes, vec![0, 1, 2, 3]);
        assert_eq!(paths[1].nodes, vec![0, 2, 3]);
        for path in &paths {
            let mut seen = HashSet::new();
            assert!(
                path.nodes.iter().all(|n| seen.insert(n)),
                "Path contains loop: {:?}",
                path.nodes
            );
        }
    }

    #[test]
    fn paths_sorted_by_cost() {
        let neighbors = |node: &usize| -> Vec<(usize, f64)> {
            match *node {
                0 => vec![(1, 1.0), (2, 2.0), (3, 5.0)],
                1 => vec![(4, 1.0)],
                2 => vec![(4, 1.0)],
                3 => vec![(4, 1.0)],
                _ => vec![],
            }
        };
        let paths = yen_k_shortest(0, 4, 5, neighbors);
        // Only three simple paths exist; asserting the count keeps the
        // ordering check below from passing vacuously.
        assert_eq!(paths.len(), 3);
        for pair in paths.windows(2) {
            assert!(
                pair[1].cost >= pair[0].cost - 1e-12,
                "Paths not sorted: {} came after {}",
                pair[1].cost,
                pair[0].cost
            );
        }
        let expected = [2.0, 3.0, 6.0];
        for (path, want) in paths.iter().zip(expected.iter()) {
            assert!(
                (path.cost - want).abs() < 1e-6,
                "cost {} != expected {}",
                path.cost,
                want
            );
        }
    }

    #[test]
    fn deviation_after_shared_prefix_includes_root_cost() {
        // Both paths share the root 0→1; the second deviates at spur index 1, so
        // its cost must include the root edge as well as the spur edge 1→3.
        let neighbors = |node: &usize| -> Vec<(usize, f64)> {
            match *node {
                0 => vec![(1, 1.0)],
                1 => vec![(2, 1.0), (3, 5.0)],
                2 => vec![(3, 1.0)],
                _ => vec![],
            }
        };
        let paths = yen_k_shortest(0, 3, 3, neighbors);
        assert_eq!(paths.len(), 2);
        assert_eq!(paths[0].nodes, vec![0, 1, 2, 3]);
        assert!((paths[0].cost - 3.0).abs() < 1e-6);
        assert_eq!(paths[1].nodes, vec![0, 1, 3]);
        assert!((paths[1].cost - 6.0).abs() < 1e-6);
    }
}