algorithms_edu/algo/graph/shortest_path/
floyd_warshall.rs

//! This module contains an implementation of the Floyd-Warshall algorithm, which finds the
//! shortest paths between all pairs of nodes in a graph. It also demonstrates how to detect
//! negative cycles and how to reconstruct a shortest path.
//!
//! - Time Complexity: $O(V^3)$
//! - Space Complexity: $O(V^2)$
//!
//! # Resources
//!
//! - [W. Fiset's video](https://www.youtube.com/watch?v=4NQ3HnhyNfQ)
10
11use crate::algo::graph::WeightedAdjacencyMatrix;
12
/// Reasons why [`FloydWarshall::path`] cannot produce a shortest path.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum ShortestPathError {
    /// The distance is unbounded below: a negative cycle lies on some walk between the
    /// start and end nodes.
    NegativeCycle,
    /// The end node cannot be reached from the start node.
    Unreachable,
}

impl std::fmt::Display for ShortestPathError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            ShortestPathError::NegativeCycle => {
                "shortest path is undefined because it can traverse a negative cycle"
            }
            ShortestPathError::Unreachable => "end node is unreachable from start node",
        };
        f.write_str(message)
    }
}

// Lets callers propagate this error with `?` into `Box<dyn std::error::Error>`.
impl std::error::Error for ShortestPathError {}
18
19pub struct FloydWarshall {
20    dp: Vec<Vec<f64>>,
21    next: Vec<Vec<Option<usize>>>,
22}
23
24impl FloydWarshall {
25    pub fn new(graph: &WeightedAdjacencyMatrix) -> Self {
26        let n = graph.node_count();
27        // Copy input matrix and setup 'next' matrix for path reconstruction.
28        let mut dp = graph.inner.clone();
29        let mut next = vec![vec![None; n]; n];
30        for i in 0..n {
31            for j in 0..n {
32                if graph[i][j] != f64::INFINITY {
33                    next[i][j] = Some(j);
34                }
35            }
36        }
37
38        // Compute all pairs shortest paths.
39        for k in 0..n {
40            for i in 0..n {
41                for j in 0..n {
42                    if dp[i][k] + dp[k][j] < dp[i][j] {
43                        dp[i][j] = dp[i][k] + dp[k][j];
44                        next[i][j] = next[i][k];
45                    }
46                }
47            }
48        }
49
50        // Identify negative cycles by propagating the value 'f64::NEG_INFINITY'
51        // to every edge that is part of or reaches into a negative cycle.
52        for k in 0..n {
53            for i in 0..n {
54                for j in 0..n {
55                    if dp[i][k] + dp[k][j] < dp[i][j] {
56                        dp[i][j] = f64::NEG_INFINITY;
57                        next[i][j] = None;
58                    }
59                }
60            }
61        }
62
63        Self { dp, next }
64    }
65
66    pub fn distance(&self, start: usize, end: usize) -> f64 {
67        self.dp[start][end]
68    }
69
70    /// Reconstructs the shortest path (of nodes) from `start` to `end` inclusive.
71    pub fn path(&self, start: usize, end: usize) -> Result<Vec<usize>, ShortestPathError> {
72        let mut path = Vec::new();
73        if self.dp[start][end] == f64::INFINITY {
74            return Err(ShortestPathError::Unreachable);
75        };
76        let mut prev = start;
77        while let Some(at) = self.next[prev][end] {
78            path.push(prev);
79            if at == end {
80                // produce `[i]` instead of `[i, i]` when constructing the path from `i` to itself
81                if at != prev {
82                    path.push(at);
83                }
84                return Ok(path);
85            }
86            prev = at;
87        }
88        // if `None` is encountered it must be a negative cycle
89        Err(ShortestPathError::NegativeCycle)
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algo::graph::WeightedAdjacencyList;

    /// Exercises a 7-node directed graph in which `4 -> 5 -> 4` is a negative cycle
    /// (total weight -1) reachable from nodes 0, 1, 2 and 6, while node 3 has no edges.
    #[test]
    fn test_floyd_warshall() {
        let graph: WeightedAdjacencyMatrix = WeightedAdjacencyList::new_directed(
            7,
            &[
                (0, 1, 2.),
                (0, 2, 5.),
                (0, 6, 10.),
                (1, 2, 2.),
                (1, 4, 11.),
                (2, 6, 2.),
                (6, 5, 11.),
                (4, 5, 1.),
                (5, 4, -2.),
            ],
        )
        .into();
        let solver = FloydWarshall::new(&graph);

        // (start, end, expected distance)
        let expected_distances = [
            (0, 0, 0.0),
            (0, 1, 2.0),
            (0, 2, 4.0),
            (0, 3, f64::INFINITY),
            (0, 4, f64::NEG_INFINITY),
            (0, 5, f64::NEG_INFINITY),
            (0, 6, 6.0),
            (1, 0, f64::INFINITY),
            (1, 1, 0.0),
            (1, 2, 2.0),
            (1, 3, f64::INFINITY),
        ];
        for &(start, end, expected) in expected_distances.iter() {
            assert_eq!(
                solver.distance(start, end),
                expected,
                "distance({}, {})",
                start,
                end
            );
        }

        // (start, end, expected path)
        let expected_paths = vec![
            (0, 0, Ok(vec![0])),
            (0, 1, Ok(vec![0, 1])),
            (0, 2, Ok(vec![0, 1, 2])),
            (0, 3, Err(ShortestPathError::Unreachable)),
            (0, 4, Err(ShortestPathError::NegativeCycle)),
            (0, 5, Err(ShortestPathError::NegativeCycle)),
            (0, 6, Ok(vec![0, 1, 2, 6])),
            (1, 0, Err(ShortestPathError::Unreachable)),
            (1, 1, Ok(vec![1])),
            (1, 2, Ok(vec![1, 2])),
            (1, 3, Err(ShortestPathError::Unreachable)),
            (1, 4, Err(ShortestPathError::NegativeCycle)),
            (1, 5, Err(ShortestPathError::NegativeCycle)),
            (1, 6, Ok(vec![1, 2, 6])),
            (2, 0, Err(ShortestPathError::Unreachable)),
        ];
        for (start, end, expected) in expected_paths {
            assert_eq!(solver.path(start, end), expected, "path({}, {})", start, end);
        }
    }
}