Skip to main content

god_graph/algorithms/
mst.rs

//! Minimum spanning tree (MST) algorithms.
//!
//! Provides Kruskal's and Prim's algorithms.
4
5use crate::edge::EdgeIndex;
6use crate::graph::traits::{GraphBase, GraphQuery};
7use crate::graph::Graph;
8use crate::node::NodeIndex;
9use std::cmp::Ordering;
10use std::collections::HashMap;
11
/// Disjoint-set (union-find) structure used by Kruskal's algorithm.
///
/// Elements are the integers `0..n`. Sets are merged by rank and lookups
/// compress the traversed path.
struct UnionFind {
    /// `parent[i]` is the parent of element `i`; a root is its own parent.
    parent: Vec<usize>,
    /// Rank of each element, consulted only for roots when merging.
    rank: Vec<usize>,
}

impl UnionFind {
    /// Creates `n` singleton sets, one per element in `0..n`.
    fn new(n: usize) -> Self {
        let parent: Vec<usize> = (0..n).collect();
        let rank = vec![0; n];
        Self { parent, rank }
    }

    /// Returns the representative (root) of the set containing `x`.
    ///
    /// Every element visited on the way to the root is re-pointed directly
    /// at the root.
    ///
    /// # Panics
    /// Panics if `x` is not in `0..n`.
    fn find(&mut self, x: usize) -> usize {
        // First pass: walk up to the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }

        // Second pass: point every node on the path straight at the root.
        let mut node = x;
        while node != root {
            let next = self.parent[node];
            self.parent[node] = root;
            node = next;
        }

        root
    }

    /// Merges the sets containing `x` and `y`.
    ///
    /// Returns `true` if two distinct sets were merged, `false` if `x` and
    /// `y` were already in the same set.
    fn union(&mut self, x: usize, y: usize) -> bool {
        let root_x = self.find(x);
        let root_y = self.find(y);
        if root_x == root_y {
            return false;
        }

        match self.rank[root_x].cmp(&self.rank[root_y]) {
            // Hang the lower-ranked root under the higher-ranked one.
            Ordering::Less => self.parent[root_x] = root_y,
            Ordering::Greater => self.parent[root_y] = root_x,
            // Equal ranks: `root_x` becomes the root and its rank grows.
            Ordering::Equal => {
                self.parent[root_y] = root_x;
                self.rank[root_x] += 1;
            }
        }
        true
    }
}
50
51/// Kruskal 算法
52///
53/// 计算无向加权图的最小生成树
54/// 返回构成 MST 的边索引列表
55///
56/// # 复杂度
57/// - 时间:O(E log E) - 主要来自边排序
58/// - 空间:O(V) - 并查集存储
59///
60/// # 示例
61/// ```rust,no_run
62/// use god_gragh::graph::Graph;
63/// use god_gragh::graph::traits::GraphOps;
64/// use god_gragh::algorithms::mst::kruskal;
65///
66/// let mut graph = Graph::<&str, f64>::undirected();
67/// let a = graph.add_node("A").unwrap();
68/// let b = graph.add_node("B").unwrap();
69/// let c = graph.add_node("C").unwrap();
70/// let d = graph.add_node("D").unwrap();
71/// graph.add_edge(a, b, 1.0).unwrap();
72/// graph.add_edge(a, c, 4.0).unwrap();
73/// graph.add_edge(b, c, 2.0).unwrap();
74/// graph.add_edge(b, d, 5.0).unwrap();
75/// graph.add_edge(c, d, 3.0).unwrap();
76///
77/// let mst_edges = kruskal(&graph);
78/// assert_eq!(mst_edges.len(), 3); // n-1 条边
79/// ```
80pub fn kruskal<T>(graph: &Graph<T, f64>) -> Vec<EdgeIndex> {
81    let n = graph.node_count();
82    if n == 0 {
83        return vec![];
84    }
85
86    // 收集所有边及其权重
87    let mut edges: Vec<(EdgeIndex, f64)> = graph
88        .edges()
89        .map(|edge| (edge.index(), *edge.data()))
90        .collect();
91
92    // 按权重排序
93    edges.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
94
95    // 使用并查集构建 MST
96    let mut uf = UnionFind::new(n);
97    let mut mst = Vec::with_capacity(n - 1);
98
99    for (edge_idx, _weight) in edges {
100        if let Ok((source, target)) = graph.edge_endpoints(edge_idx) {
101            if uf.union(source.index(), target.index()) {
102                mst.push(edge_idx);
103                if mst.len() == n - 1 {
104                    break;
105                }
106            }
107        }
108    }
109
110    mst
111}
112
113/// Prim 算法
114///
115/// 计算无向加权图的最小生成树
116/// 返回构成 MST 的边索引列表
117///
118/// # 复杂度
119/// - 时间:O((V + E) log V) - 使用二叉堆
120/// - 空间:O(V) - 距离数组和堆
121///
122/// # 示例
123/// ```rust,no_run
124/// use god_gragh::graph::Graph;
125/// use god_gragh::graph::traits::GraphOps;
126/// use god_gragh::algorithms::mst::prim;
127///
128/// let mut graph = Graph::<&str, f64>::undirected();
129/// let a = graph.add_node("A").unwrap();
130/// let b = graph.add_node("B").unwrap();
131/// let c = graph.add_node("C").unwrap();
132/// let d = graph.add_node("D").unwrap();
133/// graph.add_edge(a, b, 1.0).unwrap();
134/// graph.add_edge(a, c, 4.0).unwrap();
135/// graph.add_edge(b, c, 2.0).unwrap();
136/// graph.add_edge(b, d, 5.0).unwrap();
137/// graph.add_edge(c, d, 3.0).unwrap();
138///
139/// let mst_edges = prim(&graph);
140/// assert_eq!(mst_edges.len(), 3); // n-1 条边
141/// ```
142pub fn prim<T>(graph: &Graph<T, f64>) -> Vec<EdgeIndex> {
143    use std::collections::BinaryHeap;
144
145    #[derive(Debug)]
146    struct EdgeCandidate {
147        weight: f64,
148        target_idx: usize,
149        edge_idx: EdgeIndex,
150    }
151
152    impl PartialEq for EdgeCandidate {
153        fn eq(&self, other: &Self) -> bool {
154            self.weight == other.weight
155        }
156    }
157
158    impl Eq for EdgeCandidate {}
159
160    impl PartialOrd for EdgeCandidate {
161        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
162            Some(self.cmp(other))
163        }
164    }
165
166    impl Ord for EdgeCandidate {
167        fn cmp(&self, other: &Self) -> Ordering {
168            other.weight.total_cmp(&self.weight)
169        }
170    }
171
172    let n = graph.node_count();
173    if n == 0 {
174        return vec![];
175    }
176
177    let mut in_mst = vec![false; n];
178    let mut mst = Vec::with_capacity(n - 1);
179    let mut heap = BinaryHeap::new();
180
181    // 收集所有有效节点及其索引映射
182    let node_indices: Vec<NodeIndex> = graph.nodes().map(|n| n.index()).collect();
183    let index_to_pos: HashMap<usize, usize> = node_indices
184        .iter()
185        .enumerate()
186        .map(|(i, ni)| (ni.index(), i))
187        .collect();
188
189    // 从第一个节点开始
190    let start_pos = 0;
191    in_mst[start_pos] = true;
192
193    // 将起始节点的边加入堆
194    let start_node = node_indices[start_pos];
195    for edge in graph.incident_edges(start_node) {
196        if let Ok((source, target)) = graph.edge_endpoints(edge) {
197            let neighbor = if source == start_node { target } else { source };
198            if let Some(&pos) = index_to_pos.get(&neighbor.index()) {
199                if !in_mst[pos] {
200                    if let Ok(weight) = graph.get_edge(edge) {
201                        heap.push(EdgeCandidate {
202                            weight: *weight,
203                            target_idx: pos,
204                            edge_idx: edge,
205                        });
206                    }
207                }
208            }
209        }
210    }
211
212    while let Some(EdgeCandidate {
213        target_idx,
214        edge_idx,
215        weight: _,
216    }) = heap.pop()
217    {
218        if in_mst[target_idx] {
219            continue;
220        }
221
222        in_mst[target_idx] = true;
223        mst.push(edge_idx);
224
225        if mst.len() == n - 1 {
226            break;
227        }
228
229        // 将新加入节点的边加入堆
230        let target_node = node_indices[target_idx];
231        for edge in graph.incident_edges(target_node) {
232            if let Ok((source, target)) = graph.edge_endpoints(edge) {
233                let neighbor = if source == target_node {
234                    target
235                } else {
236                    source
237                };
238                if let Some(&pos) = index_to_pos.get(&neighbor.index()) {
239                    if !in_mst[pos] {
240                        if let Ok(weight) = graph.get_edge(edge) {
241                            heap.push(EdgeCandidate {
242                                weight: *weight,
243                                target_idx: pos,
244                                edge_idx: edge,
245                            });
246                        }
247                    }
248                }
249            }
250        }
251    }
252
253    mst
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::builders::GraphBuilder;

    /// Both tests use a graph with 4 nodes, so a spanning tree has 3 edges.
    const EXPECTED_MST_EDGES: usize = 3;

    /// Kruskal on a small 4-node, 5-edge undirected graph yields 3 edges.
    #[test]
    fn test_kruskal_basic() {
        let labels = vec!["A", "B", "C", "D"];
        let weighted_edges = vec![
            (0, 1, 1.0),
            (0, 2, 4.0),
            (1, 2, 2.0),
            (1, 3, 5.0),
            (2, 3, 3.0),
        ];

        let graph = GraphBuilder::undirected()
            .with_nodes(labels)
            .with_edges(weighted_edges)
            .build()
            .unwrap();

        let tree = kruskal(&graph);
        assert_eq!(tree.len(), EXPECTED_MST_EDGES);
    }

    /// Prim on the same 4-node, 5-edge undirected graph yields 3 edges.
    #[test]
    fn test_prim_basic() {
        let labels = vec!["A", "B", "C", "D"];
        let weighted_edges = vec![
            (0, 1, 1.0),
            (0, 2, 4.0),
            (1, 2, 2.0),
            (1, 3, 5.0),
            (2, 3, 3.0),
        ];

        let graph = GraphBuilder::undirected()
            .with_nodes(labels)
            .with_edges(weighted_edges)
            .build()
            .unwrap();

        let tree = prim(&graph);
        assert_eq!(tree.len(), EXPECTED_MST_EDGES);
    }
}