god_graph/algorithms/
mst.rs1use crate::edge::EdgeIndex;
6use crate::graph::traits::{GraphBase, GraphQuery};
7use crate::graph::Graph;
8use crate::node::NodeIndex;
9use std::cmp::Ordering;
10use std::collections::HashMap;
11
/// Disjoint-set forest over the elements `0..n`.
///
/// Sets are merged by rank, and `find` re-points every element it visits
/// directly at the set's representative.
struct UnionFind {
    /// `parent[i]` is the parent of element `i`; a representative is its own parent.
    parent: Vec<usize>,
    /// Per-element rank, consulted by `union` to decide which representative survives.
    rank: Vec<usize>,
}

impl UnionFind {
    /// Creates `n` singleton sets, one for each element in `0..n`.
    fn new(n: usize) -> Self {
        let parent: Vec<usize> = (0..n).collect();
        let rank = vec![0; n];
        Self { parent, rank }
    }

    /// Returns the representative of the set containing `x`.
    ///
    /// Every element on the path from `x` to the representative is re-pointed
    /// directly at the representative.
    ///
    /// # Panics
    ///
    /// Panics if `x` is not a valid element index.
    fn find(&mut self, x: usize) -> usize {
        // First pass: walk up to the representative.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }

        // Second pass: compress the path that was just walked.
        let mut cursor = x;
        while cursor != root {
            let next = self.parent[cursor];
            self.parent[cursor] = root;
            cursor = next;
        }

        root
    }

    /// Merges the sets containing `x` and `y`.
    ///
    /// Returns `true` if two distinct sets were merged, and `false` if `x` and
    /// `y` already belonged to the same set.
    fn union(&mut self, x: usize, y: usize) -> bool {
        let (root_x, root_y) = (self.find(x), self.find(y));
        if root_x == root_y {
            return false;
        }

        match self.rank[root_x].cmp(&self.rank[root_y]) {
            // Attach the lower-ranked representative beneath the higher-ranked one.
            Ordering::Less => self.parent[root_x] = root_y,
            Ordering::Greater => self.parent[root_y] = root_x,
            // Equal ranks: keep `root_x` and bump its rank.
            Ordering::Equal => {
                self.parent[root_y] = root_x;
                self.rank[root_x] += 1;
            }
        }
        true
    }
}
50
51pub fn kruskal<T>(graph: &Graph<T, f64>) -> Vec<EdgeIndex> {
81 let n = graph.node_count();
82 if n == 0 {
83 return vec![];
84 }
85
86 let mut edges: Vec<(EdgeIndex, f64)> = graph
88 .edges()
89 .map(|edge| (edge.index(), *edge.data()))
90 .collect();
91
92 edges.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
94
95 let mut uf = UnionFind::new(n);
97 let mut mst = Vec::with_capacity(n - 1);
98
99 for (edge_idx, _weight) in edges {
100 if let Ok((source, target)) = graph.edge_endpoints(edge_idx) {
101 if uf.union(source.index(), target.index()) {
102 mst.push(edge_idx);
103 if mst.len() == n - 1 {
104 break;
105 }
106 }
107 }
108 }
109
110 mst
111}
112
113pub fn prim<T>(graph: &Graph<T, f64>) -> Vec<EdgeIndex> {
143 use std::collections::BinaryHeap;
144
145 #[derive(Debug)]
146 struct EdgeCandidate {
147 weight: f64,
148 target_idx: usize,
149 edge_idx: EdgeIndex,
150 }
151
152 impl PartialEq for EdgeCandidate {
153 fn eq(&self, other: &Self) -> bool {
154 self.weight == other.weight
155 }
156 }
157
158 impl Eq for EdgeCandidate {}
159
160 impl PartialOrd for EdgeCandidate {
161 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
162 Some(self.cmp(other))
163 }
164 }
165
166 impl Ord for EdgeCandidate {
167 fn cmp(&self, other: &Self) -> Ordering {
168 other.weight.total_cmp(&self.weight)
169 }
170 }
171
172 let n = graph.node_count();
173 if n == 0 {
174 return vec![];
175 }
176
177 let mut in_mst = vec![false; n];
178 let mut mst = Vec::with_capacity(n - 1);
179 let mut heap = BinaryHeap::new();
180
181 let node_indices: Vec<NodeIndex> = graph.nodes().map(|n| n.index()).collect();
183 let index_to_pos: HashMap<usize, usize> = node_indices
184 .iter()
185 .enumerate()
186 .map(|(i, ni)| (ni.index(), i))
187 .collect();
188
189 let start_pos = 0;
191 in_mst[start_pos] = true;
192
193 let start_node = node_indices[start_pos];
195 for edge in graph.incident_edges(start_node) {
196 if let Ok((source, target)) = graph.edge_endpoints(edge) {
197 let neighbor = if source == start_node { target } else { source };
198 if let Some(&pos) = index_to_pos.get(&neighbor.index()) {
199 if !in_mst[pos] {
200 if let Ok(weight) = graph.get_edge(edge) {
201 heap.push(EdgeCandidate {
202 weight: *weight,
203 target_idx: pos,
204 edge_idx: edge,
205 });
206 }
207 }
208 }
209 }
210 }
211
212 while let Some(EdgeCandidate {
213 target_idx,
214 edge_idx,
215 weight: _,
216 }) = heap.pop()
217 {
218 if in_mst[target_idx] {
219 continue;
220 }
221
222 in_mst[target_idx] = true;
223 mst.push(edge_idx);
224
225 if mst.len() == n - 1 {
226 break;
227 }
228
229 let target_node = node_indices[target_idx];
231 for edge in graph.incident_edges(target_node) {
232 if let Ok((source, target)) = graph.edge_endpoints(edge) {
233 let neighbor = if source == target_node {
234 target
235 } else {
236 source
237 };
238 if let Some(&pos) = index_to_pos.get(&neighbor.index()) {
239 if !in_mst[pos] {
240 if let Ok(weight) = graph.get_edge(edge) {
241 heap.push(EdgeCandidate {
242 weight: *weight,
243 target_idx: pos,
244 edge_idx: edge,
245 });
246 }
247 }
248 }
249 }
250 }
251 }
252
253 mst
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::builders::GraphBuilder;

    // Both tests build the same undirected graph: four nodes and five weighted
    // edges given as `(source, target, weight)`.

    #[test]
    fn test_kruskal_basic() {
        let nodes = vec!["A", "B", "C", "D"];
        let weighted_edges = vec![
            (0, 1, 1.0),
            (0, 2, 4.0),
            (1, 2, 2.0),
            (1, 3, 5.0),
            (2, 3, 3.0),
        ];
        let g = GraphBuilder::undirected()
            .with_nodes(nodes)
            .with_edges(weighted_edges)
            .build()
            .unwrap();

        // A spanning tree over four nodes has exactly three edges.
        assert_eq!(kruskal(&g).len(), 3);
    }

    #[test]
    fn test_prim_basic() {
        let nodes = vec!["A", "B", "C", "D"];
        let weighted_edges = vec![
            (0, 1, 1.0),
            (0, 2, 4.0),
            (1, 2, 2.0),
            (1, 3, 5.0),
            (2, 3, 3.0),
        ];
        let g = GraphBuilder::undirected()
            .with_nodes(nodes)
            .with_edges(weighted_edges)
            .build()
            .unwrap();

        // A spanning tree over four nodes has exactly three edges.
        assert_eq!(prim(&g).len(), 3);
    }
}