Skip to main content

oxicuda_graphalg/shortest_path/
spfa.rs

1//! Shortest Path Faster Algorithm (queue-based Bellman-Ford variant).
2
3use std::collections::VecDeque;
4
5use crate::error::{GraphalgError, GraphalgResult};
6use crate::repr::weighted_graph::WeightedGraph;
7
8use super::bellman_ford::BellmanFordOutput;
9
10pub fn spfa(g: &WeightedGraph, source: usize) -> GraphalgResult<BellmanFordOutput> {
11    if source >= g.n {
12        return Err(GraphalgError::SourceOutOfRange {
13            node: source,
14            n: g.n,
15        });
16    }
17    let mut dist = vec![f64::INFINITY; g.n];
18    let mut parent = vec![usize::MAX; g.n];
19    let mut in_queue = vec![false; g.n];
20    let mut count = vec![0usize; g.n];
21    dist[source] = 0.0;
22    parent[source] = source;
23    let mut q: VecDeque<usize> = VecDeque::new();
24    q.push_back(source);
25    in_queue[source] = true;
26    while let Some(u) = q.pop_front() {
27        in_queue[u] = false;
28        for &(v, w) in g.neighbors(u)? {
29            let nd = dist[u] + w;
30            if nd < dist[v] {
31                dist[v] = nd;
32                parent[v] = u;
33                if !in_queue[v] {
34                    q.push_back(v);
35                    in_queue[v] = true;
36                    count[v] += 1;
37                    if count[v] > g.n {
38                        return Err(GraphalgError::NegativeCycle);
39                    }
40                }
41            }
42        }
43    }
44    Ok(BellmanFordOutput { dist, parent })
45}
46
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn spfa_simple() {
        let mut g = WeightedGraph::new(4);
        g.add_edge(0, 1, 1.0).expect("ok");
        g.add_edge(1, 2, 2.0).expect("ok");
        g.add_edge(2, 3, 3.0).expect("ok");
        let out = spfa(&g, 0).expect("ok");
        assert!((out.dist[3] - 6.0).abs() < 1e-12);
    }

    #[test]
    fn spfa_negative_cycle() {
        let mut g = WeightedGraph::new(3);
        g.add_edge(0, 1, 1.0).expect("ok");
        g.add_edge(1, 2, -3.0).expect("ok");
        g.add_edge(2, 0, 1.0).expect("ok");
        assert!(matches!(spfa(&g, 0), Err(GraphalgError::NegativeCycle)));
    }

    /// A source index equal to the vertex count must be rejected up front.
    #[test]
    fn spfa_source_out_of_range() {
        let g = WeightedGraph::new(3);
        assert!(matches!(
            spfa(&g, 3),
            Err(GraphalgError::SourceOutOfRange { node: 3, n: 3 })
        ));
    }

    /// Unreachable vertices keep their sentinel distance/parent, and the
    /// source is its own parent at distance zero.
    #[test]
    fn spfa_unreachable_vertex_keeps_sentinels() {
        let mut g = WeightedGraph::new(3);
        g.add_edge(0, 1, 1.0).expect("ok");
        let out = spfa(&g, 0).expect("ok");
        assert!(out.dist[0].abs() < 1e-12);
        assert_eq!(out.parent[0], 0);
        assert!(out.dist[2].is_infinite());
        assert_eq!(out.parent[2], usize::MAX);
    }

    /// A distance set early through an expensive edge must be improved later
    /// through a cheaper two-edge path, updating the parent accordingly.
    #[test]
    fn spfa_takes_cheaper_indirect_path() {
        let mut g = WeightedGraph::new(3);
        g.add_edge(0, 1, 10.0).expect("ok");
        g.add_edge(0, 2, 1.0).expect("ok");
        g.add_edge(2, 1, 2.0).expect("ok");
        let out = spfa(&g, 0).expect("ok");
        assert!((out.dist[1] - 3.0).abs() < 1e-12);
        assert_eq!(out.parent[1], 2);
        assert_eq!(out.parent[2], 0);
    }
}
69}