Skip to main content

god_graph/algorithms/
shortest_path.rs

//! 最短路径算法模块 (shortest-path algorithms)
//!
//! 包含 Dijkstra、Bellman-Ford、Floyd-Warshall、A* 等算法
//! (provides Dijkstra, Bellman-Ford, Floyd-Warshall and A*).

5use crate::errors::{GraphError, GraphResult};
6use crate::graph::traits::{GraphBase, GraphQuery};
7use crate::graph::Graph;
8use crate::node::NodeIndex;
9use std::cmp::Ordering;
10use std::collections::{BinaryHeap, HashMap};
11
12/// Dijkstra 最短路径算法
13///
14/// 计算从源节点到所有其他节点的最短路径距离
15/// 适用于非负权重的图
16///
17/// # 参数
18/// * `graph` - 图
19/// * `source` - 源节点
20/// * `get_weight` - 获取边权重的闭包
21///
22/// # 返回
23/// HashMap,键为节点索引,值为最短距离
24///
25/// # 错误
26/// * `GraphError::NegativeWeight` - 检测到负权重边,建议使用 Bellman-Ford 算法
27///
28/// # 注意
29/// Dijkstra 算法不适用于负权重图。如果图可能包含负权重,请使用 `bellman_ford` 算法。
30pub fn dijkstra<T, E, F>(
31    graph: &Graph<T, E>,
32    source: NodeIndex,
33    mut get_weight: F,
34) -> GraphResult<HashMap<NodeIndex, f64>>
35where
36    F: FnMut(NodeIndex, NodeIndex, &E) -> f64,
37{
38    // 检测负权重边
39    for edge in graph.edges() {
40        let u = edge.source();
41        let v = edge.target();
42        let weight = get_weight(u, v, edge.data());
43        if weight < 0.0 {
44            return Err(GraphError::NegativeWeight {
45                from: u.index(),
46                to: v.index(),
47                weight,
48            });
49        }
50    }
51
52    // 优先队列项:(节点,距离),使用 Reverse 实现最小堆
53    struct State {
54        node: NodeIndex,
55        distance: f64,
56    }
57
58    impl PartialEq for State {
59        fn eq(&self, other: &Self) -> bool {
60            self.distance == other.distance
61        }
62    }
63
64    impl Eq for State {}
65
66    impl PartialOrd for State {
67        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
68            Some(self.cmp(other))
69        }
70    }
71
72    impl Ord for State {
73        fn cmp(&self, other: &Self) -> Ordering {
74            other.distance.total_cmp(&self.distance)
75        }
76    }
77
78    let mut distances: HashMap<NodeIndex, f64> = HashMap::new();
79    let mut heap = BinaryHeap::new();
80
81    distances.insert(source, 0.0);
82    heap.push(State {
83        node: source,
84        distance: 0.0,
85    });
86
87    while let Some(State { node, distance }) = heap.pop() {
88        // 跳过过期的条目
89        if distance > *distances.get(&node).unwrap_or(&f64::INFINITY) {
90            continue;
91        }
92
93        for neighbor in graph.neighbors(node) {
94            let edge_data = graph.get_edge_by_nodes(node, neighbor)?;
95            let weight = get_weight(node, neighbor, edge_data);
96            let new_distance = distance + weight;
97
98            if new_distance < *distances.get(&neighbor).unwrap_or(&f64::INFINITY) {
99                distances.insert(neighbor, new_distance);
100                heap.push(State {
101                    node: neighbor,
102                    distance: new_distance,
103                });
104            }
105        }
106    }
107
108    Ok(distances)
109}
110
111/// Bellman-Ford 算法
112///
113/// 计算从源节点到所有其他节点的最短路径
114/// 可以处理负权重边,并能检测负权环
115///
116/// # 返回
117/// * `Ok(HashMap)` - 最短距离
118/// * `Err(GraphError::NegativeCycle)` - 检测到负权环
119pub fn bellman_ford<T, E, F>(
120    graph: &Graph<T, E>,
121    source: NodeIndex,
122    mut get_weight: F,
123) -> Result<HashMap<NodeIndex, f64>, GraphError>
124where
125    F: FnMut(NodeIndex, NodeIndex, &E) -> f64,
126{
127    let mut distances: HashMap<NodeIndex, f64> = HashMap::new();
128
129    // 初始化距离
130    for node in graph.nodes() {
131        distances.insert(node.index(), f64::INFINITY);
132    }
133    distances.insert(source, 0.0);
134
135    let n = graph.node_count();
136
137    // 松弛操作,执行 n-1 轮
138    for _ in 0..n - 1 {
139        for edge in graph.edges() {
140            let u = edge.source();
141            let v = edge.target();
142            let w = get_weight(u, v, edge.data());
143
144            if distances.get(&u) != Some(&f64::INFINITY) {
145                let new_dist = distances[&u] + w;
146                if new_dist < *distances.get(&v).unwrap_or(&f64::INFINITY) {
147                    distances.insert(v, new_dist);
148                }
149            }
150        }
151    }
152
153    // 检测负权环
154    for edge in graph.edges() {
155        let u = edge.source();
156        let v = edge.target();
157        let w = get_weight(u, v, edge.data());
158
159        if distances.get(&u) != Some(&f64::INFINITY)
160            && distances[&u] + w < *distances.get(&v).unwrap_or(&f64::INFINITY)
161        {
162            return Err(GraphError::NegativeCycle);
163        }
164    }
165
166    Ok(distances)
167}
168
169/// A* 搜索算法
170///
171/// 使用启发式函数找到从起点到终点的最短路径
172///
173/// # 参数
174/// * `graph` - 图
175/// * `start` - 起始节点
176/// * `goal` - 目标节点
177/// * `get_weight` - 获取边权重的闭包
178/// * `heuristic` - 启发式函数,估计节点到目标的距离
179pub fn astar<T, E, F, H>(
180    graph: &Graph<T, E>,
181    start: NodeIndex,
182    goal: NodeIndex,
183    mut get_weight: F,
184    mut heuristic: H,
185) -> GraphResult<(f64, Vec<NodeIndex>)>
186where
187    F: FnMut(NodeIndex, NodeIndex, &E) -> f64,
188    H: FnMut(NodeIndex) -> f64,
189{
190    #[derive(Debug)]
191    struct State {
192        node: NodeIndex,
193        f_score: f64,
194    }
195
196    impl PartialEq for State {
197        fn eq(&self, other: &Self) -> bool {
198            self.f_score == other.f_score
199        }
200    }
201
202    impl Eq for State {}
203
204    impl PartialOrd for State {
205        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
206            Some(self.cmp(other))
207        }
208    }
209
210    impl Ord for State {
211        fn cmp(&self, other: &Self) -> Ordering {
212            other.f_score.total_cmp(&self.f_score)
213        }
214    }
215
216    let mut g_scores: HashMap<NodeIndex, f64> = HashMap::new();
217    let mut came_from: HashMap<NodeIndex, NodeIndex> = HashMap::new();
218    let mut heap = BinaryHeap::new();
219
220    g_scores.insert(start, 0.0);
221    heap.push(State {
222        node: start,
223        f_score: heuristic(start),
224    });
225
226    while let Some(State { node, .. }) = heap.pop() {
227        if node == goal {
228            // 重构路径
229            let mut path = vec![goal];
230            let mut current = goal;
231            while let Some(&prev) = came_from.get(&current) {
232                path.push(prev);
233                current = prev;
234            }
235            path.reverse();
236            return Ok((*g_scores.get(&goal).unwrap_or(&0.0), path));
237        }
238
239        let current_g = *g_scores.get(&node).unwrap_or(&f64::INFINITY);
240
241        for neighbor in graph.neighbors(node) {
242            let edge_data = graph.get_edge_by_nodes(node, neighbor)?;
243            let weight = get_weight(node, neighbor, edge_data);
244            let tentative_g = current_g + weight;
245
246            if tentative_g < *g_scores.get(&neighbor).unwrap_or(&f64::INFINITY) {
247                came_from.insert(neighbor, node);
248                g_scores.insert(neighbor, tentative_g);
249                let f_score = tentative_g + heuristic(neighbor);
250                heap.push(State {
251                    node: neighbor,
252                    f_score,
253                });
254            }
255        }
256    }
257
258    Err(GraphError::NodeNotFound {
259        index: goal.index(),
260    })
261}
262
263/// Floyd-Warshall 算法
264///
265/// 计算所有节点对之间的最短路径距离
266/// 适用于任意权重的图(可处理负权重)
267/// 时间复杂度:O(V³),空间复杂度:O(V²)
268///
269/// # 参数
270/// * `graph` - 图
271/// * `get_weight` - 获取边权重的闭包
272///
273/// # 返回
274/// * `Ok(HashMap<(NodeIndex, NodeIndex), f64>)` - 所有节点对的最短距离
275/// * `Err(GraphError::NegativeCycle)` - 检测到负权环
276///
277/// # 示例
278/// ```rust
279/// use god_gragh::graph::builders::GraphBuilder;
280/// use god_gragh::algorithms::shortest_path::floyd_warshall;
281///
282/// let graph = GraphBuilder::directed()
283///     .with_nodes(vec!["A", "B", "C"])
284///     .with_edges(vec![(0, 1, 1.0), (1, 2, 2.0), (0, 2, 4.0)])
285///     .build()
286///     .unwrap();
287///
288/// let distances = floyd_warshall(&graph, |_, _, w| *w).unwrap();
289/// ```
290pub fn floyd_warshall<T, E, F>(
291    graph: &Graph<T, E>,
292    mut get_weight: F,
293) -> Result<HashMap<(NodeIndex, NodeIndex), f64>, GraphError>
294where
295    F: FnMut(NodeIndex, NodeIndex, &E) -> f64,
296{
297    let node_indices: Vec<NodeIndex> = graph.nodes().map(|n| n.index()).collect();
298    let n = node_indices.len();
299
300    if n == 0 {
301        return Ok(HashMap::new());
302    }
303
304    // 创建索引到 NodeIndex 的映射
305    let index_to_node = &node_indices;
306    let node_to_index: std::collections::HashMap<usize, usize> = node_indices
307        .iter()
308        .enumerate()
309        .map(|(i, ni)| (ni.index(), i))
310        .collect();
311
312    // 初始化距离矩阵
313    const INF: f64 = f64::INFINITY;
314    let mut dist = vec![vec![INF; n]; n];
315
316    // 设置对角线为 0
317    for (i, row) in dist.iter_mut().enumerate().take(n) {
318        row[i] = 0.0;
319    }
320
321    // 设置直接边的权重
322    for edge in graph.edges() {
323        let u = edge.source();
324        let v = edge.target();
325        if let (Some(&i), Some(&j)) = (node_to_index.get(&u.index()), node_to_index.get(&v.index()))
326        {
327            let weight = get_weight(u, v, edge.data());
328            dist[i][j] = dist[i][j].min(weight);
329        }
330    }
331
332    // Floyd-Warshall 主循环
333    for k in 0..n {
334        for i in 0..n {
335            for j in 0..n {
336                if dist[i][k] != INF && dist[k][j] != INF {
337                    let new_dist = dist[i][k] + dist[k][j];
338                    if new_dist < dist[i][j] {
339                        dist[i][j] = new_dist;
340                    }
341                }
342            }
343        }
344    }
345
346    // 检测负权环:检查对角线是否有负值
347    #[allow(clippy::needless_range_loop)]
348    for i in 0..n {
349        if dist[i][i] < 0.0 {
350            return Err(GraphError::NegativeCycle);
351        }
352    }
353
354    // 构建结果 HashMap
355    let mut result = HashMap::with_capacity(n * n);
356    for (i, row) in dist.iter().enumerate().take(n) {
357        for (j, &value) in row.iter().enumerate().take(n) {
358            if value != INF {
359                result.insert((index_to_node[i], index_to_node[j]), value);
360            }
361        }
362    }
363
364    Ok(result)
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::builders::GraphBuilder;

    /// Builds A -> B (1), A -> C (4), B -> C (2), B -> D (5), C -> D (1).
    fn sample_graph() -> Graph<&'static str, f64> {
        GraphBuilder::directed()
            .with_nodes(vec!["A", "B", "C", "D"])
            .with_edges(vec![
                (0, 1, 1.0),
                (0, 2, 4.0),
                (1, 2, 2.0),
                (1, 3, 5.0),
                (2, 3, 1.0),
            ])
            .build()
            .unwrap()
    }

    #[test]
    fn test_dijkstra_basic() {
        let graph = sample_graph();

        let start = NodeIndex::new(0, 1);
        // Unit weights: the distance is the hop count.
        let distances = dijkstra(&graph, start, |_, _, _| 1.0).unwrap();

        assert_eq!(distances.get(&start), Some(&0.0));
        assert_eq!(distances.get(&NodeIndex::new(1, 1)), Some(&1.0));
        assert_eq!(distances.get(&NodeIndex::new(2, 1)), Some(&1.0));
        assert_eq!(distances.get(&NodeIndex::new(3, 1)), Some(&2.0));
    }

    #[test]
    fn test_dijkstra_weighted() {
        let graph = sample_graph();

        let start = NodeIndex::new(0, 1);
        let distances = dijkstra(&graph, start, |_, _, w| *w).unwrap();

        assert_eq!(distances.get(&start), Some(&0.0));
        assert_eq!(distances.get(&NodeIndex::new(1, 1)), Some(&1.0));
        // A->B->C = 3 beats A->C = 4.
        assert_eq!(distances.get(&NodeIndex::new(2, 1)), Some(&3.0));
        // A->B->C->D = 4.
        assert_eq!(distances.get(&NodeIndex::new(3, 1)), Some(&4.0));
    }

    #[test]
    fn test_dijkstra_negative_weights() {
        let graph = GraphBuilder::directed()
            .with_nodes(vec!["A", "B", "C"])
            .with_edges(vec![(0, 1, 1.0), (1, 2, -2.0), (0, 2, 3.0)])
            .build()
            .unwrap();

        let start = NodeIndex::new(0, 1);
        let result = dijkstra(&graph, start, |_, _, w| *w);

        assert!(matches!(result, Err(GraphError::NegativeWeight { .. })));
    }

    #[test]
    fn test_bellman_ford_basic() {
        let graph = GraphBuilder::directed()
            .with_nodes(vec!["A", "B", "C", "D"])
            .with_edges(vec![
                (0, 1, 4.0),
                (0, 2, 2.0),
                (1, 2, 1.0),
                (1, 3, 5.0),
                (2, 3, 3.0),
            ])
            .build()
            .unwrap();

        let source = NodeIndex::new(0, 1);
        let distances = bellman_ford(&graph, source, |_, _, w| *w).unwrap();

        assert_eq!(distances.get(&source), Some(&0.0));
        assert_eq!(distances.len(), 4);
    }

    #[test]
    fn test_bellman_ford_negative_weights() {
        let graph = GraphBuilder::directed()
            .with_nodes(vec!["A", "B", "C"])
            .with_edges(vec![(0, 1, 1.0), (1, 2, -2.0), (0, 2, 3.0)])
            .build()
            .unwrap();

        let source = NodeIndex::new(0, 1);
        let distances = bellman_ford(&graph, source, |_, _, w| *w).unwrap();

        // A->B->C = 1 + (-2) = -1, which is shorter than A->C = 3
        assert_eq!(distances.get(&NodeIndex::new(2, 1)), Some(&-1.0));
    }

    #[test]
    fn test_bellman_ford_negative_cycle() {
        let graph = GraphBuilder::directed()
            .with_nodes(vec!["A", "B", "C"])
            .with_edges(vec![(0, 1, 1.0), (1, 2, -2.0), (2, 0, -3.0)])
            .build()
            .unwrap();

        let source = NodeIndex::new(0, 1);
        let result = bellman_ford(&graph, source, |_, _, w| *w);

        assert!(matches!(result, Err(GraphError::NegativeCycle)));
    }

    #[test]
    fn test_astar_basic() {
        let graph = sample_graph();

        let start = NodeIndex::new(0, 1);
        let goal = NodeIndex::new(3, 1);

        // A zero heuristic with unit weights reduces A* to Dijkstra on hop counts.
        let (distance, path) = astar(&graph, start, goal, |_, _, _| 1.0, |_| 0.0).unwrap();

        // Both A->B->D and A->C->D are two hops.
        assert_eq!(distance, 2.0);
        assert_eq!(path.len(), 3);
        assert_eq!(path.first(), Some(&start));
        assert_eq!(path.last(), Some(&goal));
    }

    #[test]
    fn test_astar_weighted() {
        let graph = sample_graph();

        let start = NodeIndex::new(0, 1);
        let goal = NodeIndex::new(3, 1);

        let (distance, path) = astar(&graph, start, goal, |_, _, w| *w, |_| 0.0).unwrap();

        // The unique cheapest route is A->B->C->D = 1 + 2 + 1 = 4.
        assert_eq!(distance, 4.0);
        assert_eq!(
            path,
            vec![
                NodeIndex::new(0, 1),
                NodeIndex::new(1, 1),
                NodeIndex::new(2, 1),
                NodeIndex::new(3, 1),
            ]
        );
    }

    #[test]
    fn test_floyd_warshall_basic() {
        let graph = sample_graph();

        let distances = floyd_warshall(&graph, |_, _, w| *w).unwrap();

        // Check the distance between a pair of nodes.
        let nodes: Vec<_> = graph.nodes().collect();
        assert_eq!(
            distances.get(&(nodes[0].index(), nodes[3].index())),
            Some(&4.0)
        ); // A->B->C->D = 1+2+1 = 4
    }

    #[test]
    fn test_floyd_warshall_negative_weights() {
        let graph = GraphBuilder::directed()
            .with_nodes(vec!["A", "B", "C"])
            .with_edges(vec![(0, 1, 1.0), (1, 2, -2.0), (0, 2, 3.0)])
            .build()
            .unwrap();

        let distances = floyd_warshall(&graph, |_, _, w| *w).unwrap();

        let nodes: Vec<_> = graph.nodes().collect();
        // A->B->C = 1 + (-2) = -1
        assert_eq!(
            distances.get(&(nodes[0].index(), nodes[2].index())),
            Some(&-1.0)
        );
    }

    #[test]
    fn test_floyd_warshall_negative_cycle() {
        let graph = GraphBuilder::directed()
            .with_nodes(vec!["A", "B", "C"])
            .with_edges(vec![(0, 1, 1.0), (1, 2, -2.0), (2, 0, -3.0)])
            .build()
            .unwrap();

        let result = floyd_warshall(&graph, |_, _, w| *w);
        assert!(matches!(result, Err(GraphError::NegativeCycle)));
    }

    #[test]
    fn test_floyd_warshall_empty_graph() {
        let graph: Graph<i32, f64> = GraphBuilder::directed().build().unwrap();
        let distances = floyd_warshall(&graph, |_, _, _: &f64| 1.0).unwrap();
        assert!(distances.is_empty());
    }
}