ringkernel_graph/algorithms/
union_find.rs1use crate::models::{ComponentId, NodeId};
11use crate::Result;
12
13#[derive(Debug, Clone)]
15pub struct UnionFind {
16 parent: Vec<u32>,
18 rank: Vec<u32>,
20 num_components: usize,
22}
23
24impl UnionFind {
25 pub fn new(n: usize) -> Self {
27 Self {
28 parent: (0..n as u32).collect(),
29 rank: vec![0; n],
30 num_components: n,
31 }
32 }
33
34 pub fn len(&self) -> usize {
36 self.parent.len()
37 }
38
39 pub fn is_empty(&self) -> bool {
41 self.parent.is_empty()
42 }
43
44 pub fn num_components(&self) -> usize {
46 self.num_components
47 }
48
49 pub fn find(&mut self, x: NodeId) -> NodeId {
51 let mut root = x.0;
52
53 while self.parent[root as usize] != root {
55 root = self.parent[root as usize];
56 }
57
58 let mut node = x.0;
60 while self.parent[node as usize] != root {
61 let next = self.parent[node as usize];
62 self.parent[node as usize] = root;
63 node = next;
64 }
65
66 NodeId(root)
67 }
68
69 pub fn union(&mut self, x: NodeId, y: NodeId) -> bool {
73 let root_x = self.find(x);
74 let root_y = self.find(y);
75
76 if root_x == root_y {
77 return false; }
79
80 let rx = self.rank[root_x.0 as usize];
82 let ry = self.rank[root_y.0 as usize];
83
84 if rx < ry {
85 self.parent[root_x.0 as usize] = root_y.0;
86 } else if rx > ry {
87 self.parent[root_y.0 as usize] = root_x.0;
88 } else {
89 self.parent[root_y.0 as usize] = root_x.0;
91 self.rank[root_x.0 as usize] += 1;
92 }
93
94 self.num_components -= 1;
95 true
96 }
97
98 pub fn connected(&mut self, x: NodeId, y: NodeId) -> bool {
100 self.find(x) == self.find(y)
101 }
102
103 pub fn component_ids(&mut self) -> Vec<ComponentId> {
108 let n = self.parent.len();
109 let mut comp_id = vec![ComponentId::UNASSIGNED; n];
110 let mut next_id = 0u32;
111
112 for i in 0..n {
113 let root = self.find(NodeId(i as u32));
114
115 if !comp_id[root.0 as usize].is_assigned() {
117 comp_id[root.0 as usize] = ComponentId::new(next_id);
118 next_id += 1;
119 }
120
121 comp_id[i] = comp_id[root.0 as usize];
123 }
124
125 comp_id
126 }
127
128 pub fn component_size(&mut self, x: NodeId) -> usize {
130 let root = self.find(x);
131 let mut count = 0;
132 for i in 0..self.parent.len() {
133 if self.find(NodeId(i as u32)) == root {
134 count += 1;
135 }
136 }
137 count
138 }
139}
140
141pub fn union_find_sequential(n: usize, edges: &[(NodeId, NodeId)]) -> Result<Vec<ComponentId>> {
145 let mut uf = UnionFind::new(n);
146
147 for &(u, v) in edges {
148 uf.union(u, v);
149 }
150
151 Ok(uf.component_ids())
152}
153
154pub fn union_find_parallel(n: usize, edges: &[(NodeId, NodeId)]) -> Result<Vec<ComponentId>> {
163 use std::sync::atomic::{AtomicU32, Ordering};
164
165 if n == 0 {
166 return Ok(vec![]);
167 }
168
169 let parent: Vec<AtomicU32> = (0..n as u32).map(AtomicU32::new).collect();
171
172 let mut changed = true;
174 let mut iterations = 0;
175 const MAX_ITERATIONS: usize = 64; while changed && iterations < MAX_ITERATIONS {
178 changed = false;
179 iterations += 1;
180
181 for &(u, v) in edges {
184 let mut pu = parent[u.0 as usize].load(Ordering::Relaxed);
185 let mut pv = parent[v.0 as usize].load(Ordering::Relaxed);
186
187 for _ in 0..n {
189 let gpu = parent[pu as usize].load(Ordering::Relaxed);
190 if gpu == pu {
191 break;
192 }
193 pu = gpu;
194 }
195 for _ in 0..n {
196 let gpv = parent[pv as usize].load(Ordering::Relaxed);
197 if gpv == pv {
198 break;
199 }
200 pv = gpv;
201 }
202
203 if pu != pv {
205 let (smaller, larger) = if pu < pv { (pu, pv) } else { (pv, pu) };
206 if parent[smaller as usize]
208 .compare_exchange(smaller, larger, Ordering::AcqRel, Ordering::Relaxed)
209 .is_ok()
210 {
211 changed = true;
212 }
213 }
214 }
215
216 for i in 0..n {
219 let pi = parent[i].load(Ordering::Relaxed);
220 if pi != i as u32 {
221 let gpi = parent[pi as usize].load(Ordering::Relaxed);
222 if gpi != pi {
223 let _ =
225 parent[i].compare_exchange(pi, gpi, Ordering::AcqRel, Ordering::Relaxed);
226 changed = true;
227 }
228 }
229 }
230 }
231
232 let mut final_parent: Vec<u32> = parent.iter().map(|p| p.load(Ordering::Relaxed)).collect();
234
235 for i in 0..n {
237 let mut root = i as u32;
238 while final_parent[root as usize] != root {
239 root = final_parent[root as usize];
240 }
241 let mut node = i as u32;
243 while final_parent[node as usize] != root {
244 let next = final_parent[node as usize];
245 final_parent[node as usize] = root;
246 node = next;
247 }
248 }
249
250 let mut comp_id = vec![ComponentId::UNASSIGNED; n];
252 let mut next_id = 0u32;
253
254 for i in 0..n {
255 let root = final_parent[i] as usize;
256
257 if !comp_id[root].is_assigned() {
258 comp_id[root] = ComponentId::new(next_id);
259 next_id += 1;
260 }
261
262 comp_id[i] = comp_id[root];
263 }
264
265 Ok(comp_id)
266}
267
/// Flattens the component labels from [`union_find_parallel`] into a
/// `(parent, num_components)` pair. This function is only compiled with the
/// `cuda` feature.
///
/// `parent[i]` is the lowest node index in `i`'s component. Every tree in the
/// returned forest therefore has depth at most one, and its root is the
/// component's lowest-indexed node.
///
/// # Errors
///
/// Propagates any error returned by [`union_find_parallel`].
#[cfg(feature = "cuda")]
pub fn union_find_gpu_ready(n: usize, edges: &[(NodeId, NodeId)]) -> Result<(Vec<u32>, usize)> {
    let components = union_find_parallel(n, edges)?;

    // `union_find_parallel` hands out labels densely, in order of each component's
    // lowest-indexed node. A label equal to the number of labels seen so far thus
    // marks the component's first (lowest) member. Remembering that member per
    // label lets each node's representative be looked up directly instead of
    // scanning all earlier nodes, keeping this pass O(n).
    let mut lowest_member: Vec<u32> = Vec::new();
    let mut parent: Vec<u32> = Vec::with_capacity(n);

    for (i, comp) in components.iter().enumerate() {
        let label = comp.0 as usize;
        if label == lowest_member.len() {
            lowest_member.push(i as u32);
        }
        parent.push(lowest_member[label]);
    }

    Ok((parent, lowest_member.len()))
}
301
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_singleton_sets() {
        let mut uf = UnionFind::new(5);
        assert_eq!(uf.num_components(), 5);

        // Every node starts as its own root.
        for i in 0..5 {
            assert_eq!(uf.find(NodeId(i)), NodeId(i));
        }
    }

    #[test]
    fn test_union_basic() {
        let mut uf = UnionFind::new(5);

        assert!(uf.union(NodeId(0), NodeId(1)));
        assert_eq!(uf.num_components(), 4);
        assert!(uf.connected(NodeId(0), NodeId(1)));

        assert!(uf.union(NodeId(2), NodeId(3)));
        assert_eq!(uf.num_components(), 3);

        assert!(uf.union(NodeId(0), NodeId(2)));
        assert_eq!(uf.num_components(), 2);
        assert!(uf.connected(NodeId(0), NodeId(3)));
    }

    #[test]
    fn test_union_same_component() {
        let mut uf = UnionFind::new(3);

        uf.union(NodeId(0), NodeId(1));
        uf.union(NodeId(1), NodeId(2));

        assert!(uf.connected(NodeId(0), NodeId(2)));

        // Already connected: the union is a no-op and reports `false`.
        assert!(!uf.union(NodeId(0), NodeId(2)));
        assert_eq!(uf.num_components(), 1);
    }

    #[test]
    fn test_path_compression() {
        let mut uf = UnionFind::new(10);

        for i in 0..9 {
            uf.union(NodeId(i), NodeId(i + 1));
        }

        let root = uf.find(NodeId(9));

        // All nodes of the chain resolve to the same root.
        for i in 0..10 {
            assert_eq!(uf.find(NodeId(i)), root);
        }
    }

    #[test]
    fn test_component_ids() {
        let mut uf = UnionFind::new(5);

        uf.union(NodeId(0), NodeId(1));
        uf.union(NodeId(2), NodeId(3));

        let ids = uf.component_ids();

        assert_eq!(ids[0], ids[1]);
        assert_eq!(ids[2], ids[3]);
        assert_ne!(ids[4], ids[0]);
        assert_ne!(ids[4], ids[2]);
        assert_eq!(uf.num_components(), 3);
    }

    #[test]
    fn test_component_ids_are_dense_in_first_occurrence_order() {
        let mut uf = UnionFind::new(5);

        uf.union(NodeId(4), NodeId(1));
        uf.union(NodeId(3), NodeId(2));

        let expected: Vec<ComponentId> = [0, 1, 2, 2, 1]
            .iter()
            .map(|&c| ComponentId::new(c))
            .collect();
        assert_eq!(uf.component_ids(), expected);
    }

    #[test]
    fn test_component_size() {
        let mut uf = UnionFind::new(6);

        uf.union(NodeId(0), NodeId(1));
        uf.union(NodeId(1), NodeId(2));
        uf.union(NodeId(3), NodeId(4));

        assert_eq!(uf.component_size(NodeId(0)), 3);
        assert_eq!(uf.component_size(NodeId(3)), 2);
        assert_eq!(uf.component_size(NodeId(5)), 1);
    }

    #[test]
    fn test_component_size_after_merging_components() {
        let mut uf = UnionFind::new(6);

        uf.union(NodeId(0), NodeId(1));
        uf.union(NodeId(2), NodeId(3));
        uf.union(NodeId(1), NodeId(3));

        for i in 0..4 {
            assert_eq!(uf.component_size(NodeId(i)), 4);
        }
        assert_eq!(uf.component_size(NodeId(4)), 1);

        // A redundant union must not change any size.
        assert!(!uf.union(NodeId(0), NodeId(2)));
        assert_eq!(uf.component_size(NodeId(0)), 4);
    }

    #[test]
    fn test_union_find_from_edges() {
        let edges = [
            (NodeId(0), NodeId(1)),
            (NodeId(1), NodeId(2)),
            (NodeId(3), NodeId(4)),
        ];

        let components = union_find_sequential(5, &edges).unwrap();

        assert_eq!(components[0], components[1]);
        assert_eq!(components[1], components[2]);

        assert_eq!(components[3], components[4]);

        assert_ne!(components[0], components[3]);
    }

    #[test]
    fn test_empty_union_find() {
        let uf = UnionFind::new(0);
        assert!(uf.is_empty());
        assert_eq!(uf.num_components(), 0);
    }

    #[test]
    fn test_parallel_union_find_basic() {
        let edges = [
            (NodeId(0), NodeId(1)),
            (NodeId(1), NodeId(2)),
            (NodeId(3), NodeId(4)),
        ];

        let components = union_find_parallel(5, &edges).unwrap();

        assert_eq!(components[0], components[1]);
        assert_eq!(components[1], components[2]);

        assert_eq!(components[3], components[4]);

        assert_ne!(components[0], components[3]);
    }

    #[test]
    fn test_parallel_union_find_single_component() {
        let edges: Vec<_> = (0..9).map(|i| (NodeId(i), NodeId(i + 1))).collect();

        let components = union_find_parallel(10, &edges).unwrap();

        for i in 1..10 {
            assert_eq!(components[0], components[i]);
        }
    }

    #[test]
    fn test_parallel_union_find_no_edges() {
        let components = union_find_parallel(5, &[]).unwrap();

        // With no edges every node is its own component.
        for i in 0..5 {
            for j in (i + 1)..5 {
                assert_ne!(components[i], components[j]);
            }
        }
    }

    #[test]
    fn test_parallel_union_find_empty() {
        let components = union_find_parallel(0, &[]).unwrap();
        assert!(components.is_empty());
    }

    #[test]
    fn test_parallel_vs_sequential_consistency() {
        // Two components: {0, 1, 2?} ... built from interleaved edges.
        let edges = [
            (NodeId(0), NodeId(5)),
            (NodeId(1), NodeId(6)),
            (NodeId(2), NodeId(7)),
            (NodeId(5), NodeId(6)),
            (NodeId(3), NodeId(8)),
            (NodeId(4), NodeId(9)),
            (NodeId(8), NodeId(9)),
        ];

        let seq_components = union_find_sequential(10, &edges).unwrap();
        let par_components = union_find_parallel(10, &edges).unwrap();

        // Both implementations must induce the same partition.
        for i in 0..10 {
            for j in (i + 1)..10 {
                let seq_same = seq_components[i] == seq_components[j];
                let par_same = par_components[i] == par_components[j];
                assert_eq!(seq_same, par_same, "Mismatch for nodes {} and {}", i, j);
            }
        }
    }

    #[test]
    fn test_parallel_matches_sequential_on_long_chains() {
        let n = 2_000u32;
        let forward: Vec<_> = (0..n - 1).map(|i| (NodeId(i), NodeId(i + 1))).collect();
        let backward: Vec<_> = (0..n - 1).rev().map(|i| (NodeId(i + 1), NodeId(i))).collect();
        let by_parity: Vec<_> = (0..n - 2).map(|i| (NodeId(i), NodeId(i + 2))).collect();

        // Both implementations label components in first-occurrence order, so the
        // label vectors must match exactly, not just up to renaming.
        for edges in [forward, backward, by_parity].iter() {
            let seq = union_find_sequential(n as usize, edges).unwrap();
            let par = union_find_parallel(n as usize, edges).unwrap();
            assert_eq!(seq, par);
        }
    }

    #[test]
    fn test_parallel_union_find_star_graph() {
        let edges: Vec<_> = (1..10).map(|i| (NodeId(0), NodeId(i))).collect();

        let components = union_find_parallel(10, &edges).unwrap();

        for i in 1..10 {
            assert_eq!(components[0], components[i]);
        }
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_gpu_ready_parent_points_to_lowest_member() {
        let edges = [(NodeId(4), NodeId(1)), (NodeId(3), NodeId(2))];

        let (parent, num_components) = union_find_gpu_ready(5, &edges).unwrap();

        assert_eq!(parent, vec![0, 1, 2, 2, 1]);
        assert_eq!(num_components, 3);
    }
}