algorithms_edu/algo/graph/shortest_path/
floyd_warshall.rs1use crate::algo::graph::WeightedAdjacencyMatrix;
12
/// Reasons why [`FloydWarshall::path`] cannot produce a finite shortest path.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum ShortestPathError {
    /// The shortest-path weight is unbounded below: some walk between the two
    /// nodes can pass through a negative-weight cycle.
    NegativeCycle,
    /// The end node cannot be reached from the start node.
    Unreachable,
}

impl std::fmt::Display for ShortestPathError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            ShortestPathError::NegativeCycle => {
                "shortest path is unbounded because of a negative-weight cycle"
            }
            ShortestPathError::Unreachable => "end node is unreachable from start node",
        };
        f.write_str(msg)
    }
}

// Lets callers propagate this error with `?` into `Box<dyn std::error::Error>`.
impl std::error::Error for ShortestPathError {}
18
/// All-pairs shortest-path results computed by the Floyd–Warshall algorithm.
///
/// Build one with [`FloydWarshall::new`], then query it with
/// [`FloydWarshall::distance`] and [`FloydWarshall::path`].
#[derive(Debug, Clone)]
pub struct FloydWarshall {
    /// `dp[i][j]` is the shortest `i -> j` weight. `f64::INFINITY` marks an
    /// unreachable pair and `f64::NEG_INFINITY` a pair affected by a
    /// negative-weight cycle.
    dp: Vec<Vec<f64>>,
    /// `next[i][j]` is the node that follows `i` on a shortest `i -> j` path,
    /// or `None` when no finite shortest path exists.
    next: Vec<Vec<Option<usize>>>,
}
23
24impl FloydWarshall {
25 pub fn new(graph: &WeightedAdjacencyMatrix) -> Self {
26 let n = graph.node_count();
27 let mut dp = graph.inner.clone();
29 let mut next = vec![vec![None; n]; n];
30 for i in 0..n {
31 for j in 0..n {
32 if graph[i][j] != f64::INFINITY {
33 next[i][j] = Some(j);
34 }
35 }
36 }
37
38 for k in 0..n {
40 for i in 0..n {
41 for j in 0..n {
42 if dp[i][k] + dp[k][j] < dp[i][j] {
43 dp[i][j] = dp[i][k] + dp[k][j];
44 next[i][j] = next[i][k];
45 }
46 }
47 }
48 }
49
50 for k in 0..n {
53 for i in 0..n {
54 for j in 0..n {
55 if dp[i][k] + dp[k][j] < dp[i][j] {
56 dp[i][j] = f64::NEG_INFINITY;
57 next[i][j] = None;
58 }
59 }
60 }
61 }
62
63 Self { dp, next }
64 }
65
66 pub fn distance(&self, start: usize, end: usize) -> f64 {
67 self.dp[start][end]
68 }
69
70 pub fn path(&self, start: usize, end: usize) -> Result<Vec<usize>, ShortestPathError> {
72 let mut path = Vec::new();
73 if self.dp[start][end] == f64::INFINITY {
74 return Err(ShortestPathError::Unreachable);
75 };
76 let mut prev = start;
77 while let Some(at) = self.next[prev][end] {
78 path.push(prev);
79 if at == end {
80 if at != prev {
82 path.push(at);
83 }
84 return Ok(path);
85 }
86 prev = at;
87 }
88 Err(ShortestPathError::NegativeCycle)
90 }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algo::graph::WeightedAdjacencyList;

    /// Directed 7-node sample graph.
    ///
    /// Node 3 has no edges. Nodes 4 and 5 form a cycle of total weight -1
    /// that is reachable from nodes 0, 1, 2 and 6.
    fn sample_graph() -> WeightedAdjacencyMatrix {
        WeightedAdjacencyList::new_directed(
            7,
            &[
                (0, 1, 2.),
                (0, 2, 5.),
                (0, 6, 10.),
                (1, 2, 2.),
                (1, 4, 11.),
                (2, 6, 2.),
                (6, 5, 11.),
                (4, 5, 1.),
                (5, 4, -2.),
            ],
        )
        .into()
    }

    #[test]
    fn test_floyd_warshall() {
        let fw = FloydWarshall::new(&sample_graph());

        // (from, to, expected distance)
        let expected_distances: &[(usize, usize, f64)] = &[
            (0, 0, 0.0),
            (0, 1, 2.0),
            (0, 2, 4.0),
            (0, 3, f64::INFINITY),
            (0, 4, f64::NEG_INFINITY),
            (0, 5, f64::NEG_INFINITY),
            (0, 6, 6.0),
            (1, 0, f64::INFINITY),
            (1, 1, 0.0),
            (1, 2, 2.0),
            (1, 3, f64::INFINITY),
        ];
        for &(from, to, want) in expected_distances {
            assert_eq!(fw.distance(from, to), want, "distance({}, {})", from, to);
        }

        // (from, to, expected path reconstruction)
        let expected_paths: Vec<(usize, usize, Result<Vec<usize>, ShortestPathError>)> = vec![
            (0, 0, Ok(vec![0])),
            (0, 1, Ok(vec![0, 1])),
            (0, 2, Ok(vec![0, 1, 2])),
            (0, 3, Err(ShortestPathError::Unreachable)),
            (0, 4, Err(ShortestPathError::NegativeCycle)),
            (0, 5, Err(ShortestPathError::NegativeCycle)),
            (0, 6, Ok(vec![0, 1, 2, 6])),
            (1, 0, Err(ShortestPathError::Unreachable)),
            (1, 1, Ok(vec![1])),
            (1, 2, Ok(vec![1, 2])),
            (1, 3, Err(ShortestPathError::Unreachable)),
            (1, 4, Err(ShortestPathError::NegativeCycle)),
            (1, 5, Err(ShortestPathError::NegativeCycle)),
            (1, 6, Ok(vec![1, 2, 6])),
            (2, 0, Err(ShortestPathError::Unreachable)),
        ];
        for (from, to, want) in expected_paths {
            assert_eq!(fw.path(from, to), want, "path({}, {})", from, to);
        }
    }
}