Skip to main content

rustsim_pathfinding/
yen.rs

1//! Yen's K-shortest loopless paths algorithm.
2//!
3//! Finds up to K distinct loopless paths between a source and destination node,
4//! ordered by ascending cost. Uses Dijkstra as the inner shortest-path routine.
5//!
6//! The implementation follows Yen (1971) with standard spur-node iteration.
7//!
8//! # Example
9//!
10//! ```
11//! use rustsim_pathfinding::yen::{yen_k_shortest, YenPath};
12//!
13//! // Diamond graph: 0→1 (cost 1), 0→2 (cost 2), 1→3 (cost 2), 2→3 (cost 1)
14//! let neighbors = |node: &usize| -> Vec<(usize, f64)> {
15//!     match *node {
16//!         0 => vec![(1, 1.0), (2, 2.0)],
17//!         1 => vec![(3, 2.0)],
18//!         2 => vec![(3, 1.0)],
19//!         _ => vec![],
20//!     }
21//! };
22//!
23//! let paths = yen_k_shortest(0, 3, 3, neighbors);
24//! assert_eq!(paths.len(), 2);
25//! assert_eq!(paths[0].nodes, vec![0, 1, 3]);
26//! assert_eq!(paths[1].nodes, vec![0, 2, 3]);
27//! ```
28//!
29//! # References
30//!
31//! Yen, J. Y. (1971). "Finding the K Shortest Loopless Paths in a Network."
32//! Management Science, 17(11), 712–716.
33
34use std::cmp::Ordering;
35use std::collections::{BinaryHeap, HashMap, HashSet};
36
/// A single path with its node sequence and total cost.
///
/// Produced by [`yen_k_shortest`]: `nodes` starts at the source node and ends
/// at the destination node.
#[derive(Debug, Clone, PartialEq)]
pub struct YenPath<N> {
    /// Ordered node sequence from origin to destination (inclusive).
    pub nodes: Vec<N>,
    /// Total path cost (sum of the edge costs along `nodes`).
    pub cost: f64,
}
45
/// Priority-queue entry for the inner Dijkstra search.
///
/// `BinaryHeap` is a max-heap, so the ordering on `cost` is reversed to make
/// the cheapest entry pop first. Costs are compared with `f64::total_cmp`, so
/// a NaN cost cannot silently compare "equal" to everything and corrupt the
/// heap; a (positive) NaN simply sorts as the most expensive entry.
///
/// Equality is defined on `cost` as well, keeping `PartialEq` consistent with
/// `Ord` as the `Ord` contract requires. `node` is payload only.
#[derive(Clone)]
struct DijkEntry<N> {
    node: N,
    cost: f64,
}

impl<N> PartialEq for DijkEntry<N> {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl<N> Eq for DijkEntry<N> {}

impl<N> PartialOrd for DijkEntry<N> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<N> Ord for DijkEntry<N> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed: lower cost == "greater" priority in the max-heap.
        other.cost.total_cmp(&self.cost)
    }
}
74
75/// Dijkstra shortest path from `src` to `dest`, respecting excluded edges and
76/// nodes.
77fn dijkstra_path<N, FN, I>(
78    src: N,
79    dest: N,
80    neighbors: &mut FN,
81    excluded_edges: &HashSet<(N, N)>,
82    excluded_nodes: &HashSet<N>,
83) -> Option<YenPath<N>>
84where
85    N: Clone + Eq + std::hash::Hash,
86    FN: FnMut(&N) -> I,
87    I: IntoIterator<Item = (N, f64)>,
88{
89    let mut g_scores: HashMap<N, f64> = HashMap::new();
90    let mut came_from: HashMap<N, N> = HashMap::new();
91    let mut closed: HashSet<N> = HashSet::new();
92    let mut heap: BinaryHeap<DijkEntry<N>> = BinaryHeap::new();
93
94    g_scores.insert(src.clone(), 0.0);
95    heap.push(DijkEntry {
96        node: src.clone(),
97        cost: 0.0,
98    });
99
100    while let Some(current) = heap.pop() {
101        if current.node == dest {
102            // Reconstruct path.
103            let mut path = Vec::new();
104            let mut cur = dest.clone();
105            loop {
106                path.push(cur.clone());
107                match came_from.get(&cur) {
108                    Some(prev) => cur = prev.clone(),
109                    None => break,
110                }
111            }
112            path.reverse();
113            return Some(YenPath {
114                nodes: path,
115                cost: current.cost,
116            });
117        }
118
119        if !closed.insert(current.node.clone()) {
120            continue;
121        }
122
123        let g = current.cost;
124        for (nbr, edge_cost) in neighbors(&current.node) {
125            if closed.contains(&nbr) {
126                continue;
127            }
128            if excluded_nodes.contains(&nbr) {
129                continue;
130            }
131            if excluded_edges.contains(&(current.node.clone(), nbr.clone())) {
132                continue;
133            }
134            let tentative = g + edge_cost;
135            let prev = g_scores.get(&nbr).copied().unwrap_or(f64::INFINITY);
136            if tentative < prev {
137                g_scores.insert(nbr.clone(), tentative);
138                came_from.insert(nbr.clone(), current.node.clone());
139                heap.push(DijkEntry {
140                    node: nbr,
141                    cost: tentative,
142                });
143            }
144        }
145    }
146
147    None
148}
149
150/// Find up to `k` shortest loopless paths from `src` to `dest`.
151///
152/// Uses Yen's algorithm with Dijkstra as the inner shortest-path routine.
153///
154/// # Arguments
155///
156/// - `src` - source node
157/// - `dest` - destination node
158/// - `k` - maximum number of paths to find
159/// - `neighbors` - closure returning (neighbor, edge_cost) pairs for a node
160///
161/// Returns paths sorted by ascending cost. May return fewer than `k` paths
162/// if fewer distinct loopless paths exist.
163pub fn yen_k_shortest<N, FN, I>(src: N, dest: N, k: usize, mut neighbors: FN) -> Vec<YenPath<N>>
164where
165    N: Clone + Eq + std::hash::Hash,
166    FN: FnMut(&N) -> I,
167    I: IntoIterator<Item = (N, f64)>,
168{
169    if k == 0 {
170        return Vec::new();
171    }
172
173    // Find the shortest path first.
174    let first = dijkstra_path(
175        src.clone(),
176        dest.clone(),
177        &mut neighbors,
178        &HashSet::new(),
179        &HashSet::new(),
180    );
181    let Some(first) = first else {
182        return Vec::new();
183    };
184
185    let mut accepted: Vec<YenPath<N>> = vec![first];
186
187    // Candidate heap: (cost, path_nodes).
188    // We wrap in a struct for ordering.
189    let mut candidates: BinaryHeap<CandidateEntry<N>> = BinaryHeap::new();
190    let mut candidate_set: HashSet<Vec<N>> = HashSet::new();
191
192    for ki in 1..k {
193        let prev_path = &accepted[ki - 1].nodes;
194
195        // Spur from each node along the previous path (except dest).
196        for spur_idx in 0..prev_path.len().saturating_sub(1) {
197            let spur_node = prev_path[spur_idx].clone();
198            let root_path: Vec<N> = prev_path[..=spur_idx].to_vec();
199
200            // Exclude edges at spur_node that share the same root prefix.
201            let mut excluded_edges: HashSet<(N, N)> = HashSet::new();
202            for accepted_path in &accepted {
203                if accepted_path.nodes.len() > spur_idx
204                    && accepted_path.nodes[..=spur_idx] == root_path[..]
205                {
206                    excluded_edges.insert((
207                        accepted_path.nodes[spur_idx].clone(),
208                        accepted_path.nodes[spur_idx + 1].clone(),
209                    ));
210                }
211            }
212
213            // Exclude nodes in root (except spur node) to guarantee loopless.
214            let mut excluded_nodes: HashSet<N> = HashSet::new();
215            for node in &root_path[..spur_idx] {
216                excluded_nodes.insert(node.clone());
217            }
218
219            if let Some(spur_path) = dijkstra_path(
220                spur_node,
221                dest.clone(),
222                &mut neighbors,
223                &excluded_edges,
224                &excluded_nodes,
225            ) {
226                // Concatenate root + spur (spur starts at spur_node).
227                let mut full_nodes = root_path.clone();
228                full_nodes.extend_from_slice(&spur_path.nodes[1..]);
229
230                // Check for loops.
231                let mut seen = HashSet::new();
232                if full_nodes.iter().any(|n| !seen.insert(n.clone())) {
233                    continue;
234                }
235
236                if !candidate_set.contains(&full_nodes) {
237                    // Compute total cost.
238                    let cost = path_cost(&full_nodes, &mut neighbors);
239                    candidate_set.insert(full_nodes.clone());
240                    candidates.push(CandidateEntry {
241                        cost,
242                        nodes: full_nodes,
243                    });
244                }
245            }
246        }
247
248        // Pop the cheapest candidate.
249        if let Some(best) = candidates.pop() {
250            accepted.push(YenPath {
251                nodes: best.nodes,
252                cost: best.cost,
253            });
254        } else {
255            break;
256        }
257    }
258
259    accepted
260}
261
/// Compute the total cost of a path by summing edge costs from the neighbor
/// function.
///
/// When `neighbors` reports several parallel edges `from -> to`, the cheapest
/// one is used. This matches the edge the inner Dijkstra search ends up
/// using, since it only relaxes on a strictly smaller cost.
///
/// If no edge `from -> to` is reported, that hop costs `f64::INFINITY`, so a
/// path that does not exist in the graph can never look cheap. A path with
/// fewer than two nodes costs `0.0`.
fn path_cost<N, FN, I>(nodes: &[N], neighbors: &mut FN) -> f64
where
    N: Clone + Eq + std::hash::Hash,
    FN: FnMut(&N) -> I,
    I: IntoIterator<Item = (N, f64)>,
{
    let mut total = 0.0;
    for pair in nodes.windows(2) {
        let (from, to) = (&pair[0], &pair[1]);
        // Cheapest edge from `from` to `to`; INFINITY if there is none.
        let edge_cost = neighbors(from)
            .into_iter()
            .filter(|(n, _)| n == to)
            .map(|(_, c)| c)
            .fold(f64::INFINITY, f64::min);
        total += edge_cost;
    }
    total
}
284
/// Entry in Yen's candidate heap: a full source-to-destination path and its
/// total cost.
///
/// `BinaryHeap` is a max-heap, so the ordering on `cost` is reversed to pop
/// the cheapest candidate first. Costs are compared with `f64::total_cmp`, so
/// a NaN cost cannot compare "equal" to everything and corrupt the heap.
///
/// Equality is defined on `cost` only, keeping `PartialEq` consistent with
/// `Ord`. Two different paths of equal cost therefore compare equal here;
/// de-duplication of identical paths is done separately by the caller
/// (`candidate_set` in `yen_k_shortest`).
struct CandidateEntry<N> {
    cost: f64,
    nodes: Vec<N>,
}

impl<N> PartialEq for CandidateEntry<N> {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl<N> Eq for CandidateEntry<N> {}

impl<N> PartialOrd for CandidateEntry<N> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<N> Ord for CandidateEntry<N> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Min-heap: reverse the cost comparison.
        other.cost.total_cmp(&self.cost)
    }
}
313
#[cfg(test)]
mod tests {
    use super::*;

    /// Diamond graph with two equal-cost (3.0) routes from 0 to 3:
    /// 0→1 (1), 0→2 (2), 1→3 (2), 2→3 (1).
    fn diamond_neighbors(node: &usize) -> Vec<(usize, f64)> {
        match *node {
            0 => vec![(1, 1.0), (2, 2.0)],
            1 => vec![(3, 2.0)],
            2 => vec![(3, 1.0)],
            _ => vec![],
        }
    }

    /// Float comparison using the tolerance shared by these tests.
    fn approx(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-6
    }

    #[test]
    fn finds_two_paths_on_diamond() {
        let paths = yen_k_shortest(0, 3, 3, diamond_neighbors);
        assert_eq!(paths.len(), 2);

        let expected: [(Vec<usize>, f64); 2] = [(vec![0, 1, 3], 3.0), (vec![0, 2, 3], 3.0)];
        for (path, (nodes, cost)) in paths.iter().zip(expected.iter()) {
            assert_eq!(path.nodes, *nodes);
            assert!(approx(path.cost, *cost));
        }
    }

    #[test]
    fn single_path_on_line() {
        // Straight line 0→1→2, unit costs.
        let line = |n: &usize| -> Vec<(usize, f64)> {
            if *n < 2 {
                vec![(*n + 1, 1.0)]
            } else {
                vec![]
            }
        };
        let paths = yen_k_shortest(0, 2, 5, line);
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].nodes, vec![0, 1, 2]);
        assert!(approx(paths[0].cost, 2.0));
    }

    #[test]
    fn no_path_returns_empty() {
        let isolated = |_: &usize| -> Vec<(usize, f64)> { Vec::new() };
        assert!(yen_k_shortest(0, 5, 3, isolated).is_empty());
    }

    #[test]
    fn k_zero_returns_empty() {
        assert!(yen_k_shortest(0, 3, 0, diamond_neighbors).is_empty());
    }

    #[test]
    fn paths_are_loopless() {
        // Routes 0→1→2→3 and 0→2→3, plus a back edge 1→0 that could close a cycle.
        let neighbors = |node: &usize| -> Vec<(usize, f64)> {
            match *node {
                0 => vec![(1, 1.0), (2, 3.0)],
                1 => vec![(0, 1.0), (2, 1.0)],
                2 => vec![(3, 1.0)],
                _ => vec![],
            }
        };
        for path in yen_k_shortest(0, 3, 5, neighbors) {
            let distinct: HashSet<&usize> = path.nodes.iter().collect();
            assert!(
                distinct.len() == path.nodes.len(),
                "Path contains loop: {:?}",
                path.nodes
            );
        }
    }

    #[test]
    fn paths_sorted_by_cost() {
        // Three two-hop routes 0→{1,2,3}→4 with distinct first-hop costs.
        let neighbors = |node: &usize| -> Vec<(usize, f64)> {
            match *node {
                0 => vec![(1, 1.0), (2, 2.0), (3, 5.0)],
                1 | 2 | 3 => vec![(4, 1.0)],
                _ => vec![],
            }
        };
        let paths = yen_k_shortest(0, 4, 5, neighbors);
        for (i, pair) in paths.windows(2).enumerate() {
            assert!(
                pair[1].cost >= pair[0].cost - 1e-12,
                "Paths not sorted: cost[{}]={} < cost[{}]={}",
                i + 1,
                pair[1].cost,
                i,
                pair[0].cost
            );
        }
    }
}