Skip to main content

ringkernel_graph/algorithms/
union_find.rs

1//! Union-Find (Disjoint Set) data structure.
2//!
3//! Union-Find efficiently tracks connected components in undirected graphs.
4//! Supports:
5//! - `find(x)`: Find representative of x's component
6//! - `union(x, y)`: Merge components containing x and y
7//!
8//! Uses path compression and union by rank for near O(1) amortized operations.
9
10use crate::models::{ComponentId, NodeId};
11use crate::Result;
12
13/// Union-Find data structure with path compression and union by rank.
14#[derive(Debug, Clone)]
15pub struct UnionFind {
16    /// Parent pointers (parent[i] = parent of node i, or i if root).
17    parent: Vec<u32>,
18    /// Rank (tree height upper bound) for union by rank.
19    rank: Vec<u32>,
20    /// Number of components.
21    num_components: usize,
22}
23
24impl UnionFind {
25    /// Create new Union-Find with n singleton sets.
26    pub fn new(n: usize) -> Self {
27        Self {
28            parent: (0..n as u32).collect(),
29            rank: vec![0; n],
30            num_components: n,
31        }
32    }
33
34    /// Number of nodes.
35    pub fn len(&self) -> usize {
36        self.parent.len()
37    }
38
39    /// Check if empty.
40    pub fn is_empty(&self) -> bool {
41        self.parent.is_empty()
42    }
43
44    /// Number of disjoint components.
45    pub fn num_components(&self) -> usize {
46        self.num_components
47    }
48
49    /// Find representative of node's component with path compression.
50    pub fn find(&mut self, x: NodeId) -> NodeId {
51        let mut root = x.0;
52
53        // Find root
54        while self.parent[root as usize] != root {
55            root = self.parent[root as usize];
56        }
57
58        // Path compression: point all nodes on path directly to root
59        let mut node = x.0;
60        while self.parent[node as usize] != root {
61            let next = self.parent[node as usize];
62            self.parent[node as usize] = root;
63            node = next;
64        }
65
66        NodeId(root)
67    }
68
69    /// Union two components by rank.
70    ///
71    /// Returns true if a merge occurred (x and y were in different components).
72    pub fn union(&mut self, x: NodeId, y: NodeId) -> bool {
73        let root_x = self.find(x);
74        let root_y = self.find(y);
75
76        if root_x == root_y {
77            return false; // Already in same component
78        }
79
80        // Union by rank: attach smaller tree under larger tree
81        let rx = self.rank[root_x.0 as usize];
82        let ry = self.rank[root_y.0 as usize];
83
84        if rx < ry {
85            self.parent[root_x.0 as usize] = root_y.0;
86        } else if rx > ry {
87            self.parent[root_y.0 as usize] = root_x.0;
88        } else {
89            // Same rank: arbitrarily choose root_x as new root, increment rank
90            self.parent[root_y.0 as usize] = root_x.0;
91            self.rank[root_x.0 as usize] += 1;
92        }
93
94        self.num_components -= 1;
95        true
96    }
97
98    /// Check if two nodes are in the same component.
99    pub fn connected(&mut self, x: NodeId, y: NodeId) -> bool {
100        self.find(x) == self.find(y)
101    }
102
103    /// Get component ID for each node.
104    ///
105    /// Returns a vector where `result[i]` is the component ID of node i.
106    /// Component IDs are assigned 0, 1, 2, ... based on root nodes.
107    pub fn component_ids(&mut self) -> Vec<ComponentId> {
108        let n = self.parent.len();
109        let mut comp_id = vec![ComponentId::UNASSIGNED; n];
110        let mut next_id = 0u32;
111
112        for i in 0..n {
113            let root = self.find(NodeId(i as u32));
114
115            // Assign ID to root if not already assigned
116            if !comp_id[root.0 as usize].is_assigned() {
117                comp_id[root.0 as usize] = ComponentId::new(next_id);
118                next_id += 1;
119            }
120
121            // Copy root's ID to this node
122            comp_id[i] = comp_id[root.0 as usize];
123        }
124
125        comp_id
126    }
127
128    /// Get the size of the component containing node x.
129    pub fn component_size(&mut self, x: NodeId) -> usize {
130        let root = self.find(x);
131        let mut count = 0;
132        for i in 0..self.parent.len() {
133            if self.find(NodeId(i as u32)) == root {
134                count += 1;
135            }
136        }
137        count
138    }
139}
140
141/// Sequential union-find on edge list.
142///
143/// Computes connected components from undirected edges.
144pub fn union_find_sequential(n: usize, edges: &[(NodeId, NodeId)]) -> Result<Vec<ComponentId>> {
145    let mut uf = UnionFind::new(n);
146
147    for &(u, v) in edges {
148        uf.union(u, v);
149    }
150
151    Ok(uf.component_ids())
152}
153
154/// Parallel union-find using Shiloach-Vishkin style algorithm.
155///
156/// This implementation uses atomic operations for thread-safe parallel execution.
157/// The algorithm works in rounds of:
158/// 1. Hook: For each edge, attempt to hook smaller component under larger
159/// 2. Jump: Pointer jumping to flatten trees
160///
161/// Continues until no changes occur (convergence).
162pub fn union_find_parallel(n: usize, edges: &[(NodeId, NodeId)]) -> Result<Vec<ComponentId>> {
163    use std::sync::atomic::{AtomicU32, Ordering};
164
165    if n == 0 {
166        return Ok(vec![]);
167    }
168
169    // Initialize parent array with atomic operations for thread safety
170    let parent: Vec<AtomicU32> = (0..n as u32).map(AtomicU32::new).collect();
171
172    // Shiloach-Vishkin style parallel connected components
173    let mut changed = true;
174    let mut iterations = 0;
175    const MAX_ITERATIONS: usize = 64; // Prevent infinite loops
176
177    while changed && iterations < MAX_ITERATIONS {
178        changed = false;
179        iterations += 1;
180
181        // Phase 1: Hook - process all edges
182        // For each edge (u, v), try to hook smaller root under larger root
183        for &(u, v) in edges {
184            let mut pu = parent[u.0 as usize].load(Ordering::Relaxed);
185            let mut pv = parent[v.0 as usize].load(Ordering::Relaxed);
186
187            // Find roots (with limited iterations to avoid infinite loops)
188            for _ in 0..n {
189                let gpu = parent[pu as usize].load(Ordering::Relaxed);
190                if gpu == pu {
191                    break;
192                }
193                pu = gpu;
194            }
195            for _ in 0..n {
196                let gpv = parent[pv as usize].load(Ordering::Relaxed);
197                if gpv == pv {
198                    break;
199                }
200                pv = gpv;
201            }
202
203            // Hook smaller root under larger root
204            if pu != pv {
205                let (smaller, larger) = if pu < pv { (pu, pv) } else { (pv, pu) };
206                // Atomic compare-and-swap to hook smaller under larger
207                if parent[smaller as usize]
208                    .compare_exchange(smaller, larger, Ordering::AcqRel, Ordering::Relaxed)
209                    .is_ok()
210                {
211                    changed = true;
212                }
213            }
214        }
215
216        // Phase 2: Jump - pointer jumping to flatten trees
217        // Each node points to its grandparent: parent[i] = parent[parent[i]]
218        for i in 0..n {
219            let pi = parent[i].load(Ordering::Relaxed);
220            if pi != i as u32 {
221                let gpi = parent[pi as usize].load(Ordering::Relaxed);
222                if gpi != pi {
223                    // Point to grandparent (path compression)
224                    let _ =
225                        parent[i].compare_exchange(pi, gpi, Ordering::AcqRel, Ordering::Relaxed);
226                    changed = true;
227                }
228            }
229        }
230    }
231
232    // Final pass: ensure all nodes point to their root and assign component IDs
233    let mut final_parent: Vec<u32> = parent.iter().map(|p| p.load(Ordering::Relaxed)).collect();
234
235    // Complete path compression
236    for i in 0..n {
237        let mut root = i as u32;
238        while final_parent[root as usize] != root {
239            root = final_parent[root as usize];
240        }
241        // Compress path
242        let mut node = i as u32;
243        while final_parent[node as usize] != root {
244            let next = final_parent[node as usize];
245            final_parent[node as usize] = root;
246            node = next;
247        }
248    }
249
250    // Assign component IDs
251    let mut comp_id = vec![ComponentId::UNASSIGNED; n];
252    let mut next_id = 0u32;
253
254    for i in 0..n {
255        let root = final_parent[i] as usize;
256
257        if !comp_id[root].is_assigned() {
258            comp_id[root] = ComponentId::new(next_id);
259            next_id += 1;
260        }
261
262        comp_id[i] = comp_id[root];
263    }
264
265    Ok(comp_id)
266}
267
/// Connected components in a GPU-buffer-friendly "parent array" format.
///
/// Returns `(parent, num_components)`:
/// - `parent[i]` is the lowest-index node in node `i`'s component, so every
///   representative `r` satisfies `parent[r] == r`;
/// - `num_components` is the number of distinct components.
///
/// This is intended as the host-side entry point for a future GPU kernel. For
/// now it runs [`union_find_parallel`] on the CPU and converts its labels.
#[cfg(feature = "cuda")]
pub fn union_find_gpu_ready(n: usize, edges: &[(NodeId, NodeId)]) -> Result<(Vec<u32>, usize)> {
    let components = union_find_parallel(n, edges)?;

    // first_node[c] = lowest node index carrying label `c` (u32::MAX = not seen).
    // Labels are dense and there are at most `n` components, so every label is < n.
    // This replaces a per-node rescan of all earlier nodes (O(n^2)) with O(n).
    let mut first_node = vec![u32::MAX; n];
    let mut num_components = 0usize;
    let mut parent = Vec::with_capacity(n);

    for (i, comp) in components.iter().enumerate() {
        let slot = &mut first_node[comp.0 as usize];
        if *slot == u32::MAX {
            // First (lowest-index) node of this component becomes its root.
            *slot = i as u32;
            num_components += 1;
        }
        parent.push(*slot);
    }

    Ok((parent, num_components))
}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Tiny deterministic LCG so "random" graphs are reproducible across runs.
    fn lcg(state: &mut u64) -> u64 {
        *state = state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        *state >> 33
    }

    #[test]
    fn test_singleton_sets() {
        let mut uf = UnionFind::new(5);
        assert_eq!(uf.num_components(), 5);

        // Each node is its own representative
        for i in 0..5 {
            assert_eq!(uf.find(NodeId(i)), NodeId(i));
        }
    }

    #[test]
    fn test_union_basic() {
        let mut uf = UnionFind::new(5);

        assert!(uf.union(NodeId(0), NodeId(1)));
        assert_eq!(uf.num_components(), 4);
        assert!(uf.connected(NodeId(0), NodeId(1)));

        assert!(uf.union(NodeId(2), NodeId(3)));
        assert_eq!(uf.num_components(), 3);

        assert!(uf.union(NodeId(0), NodeId(2)));
        assert_eq!(uf.num_components(), 2);
        assert!(uf.connected(NodeId(0), NodeId(3)));
    }

    #[test]
    fn test_union_same_component() {
        let mut uf = UnionFind::new(3);

        uf.union(NodeId(0), NodeId(1));
        uf.union(NodeId(1), NodeId(2));

        // All three are connected
        assert!(uf.connected(NodeId(0), NodeId(2)));

        // Union within same component returns false
        assert!(!uf.union(NodeId(0), NodeId(2)));
        assert_eq!(uf.num_components(), 1);
    }

    #[test]
    fn test_path_compression() {
        let mut uf = UnionFind::new(10);

        // Create a chain: 0 -> 1 -> 2 -> ... -> 9
        for i in 0..9 {
            uf.union(NodeId(i), NodeId(i + 1));
        }

        // Find on last node should compress path
        let root = uf.find(NodeId(9));

        // After path compression, every node resolves to the same root
        for i in 0..10 {
            assert_eq!(uf.find(NodeId(i)), root);
        }
    }

    #[test]
    fn test_component_ids() {
        let mut uf = UnionFind::new(5);

        uf.union(NodeId(0), NodeId(1));
        uf.union(NodeId(2), NodeId(3));

        let ids = uf.component_ids();

        // 0 and 1 should have same ID
        assert_eq!(ids[0], ids[1]);
        // 2 and 3 should have same ID
        assert_eq!(ids[2], ids[3]);
        // 4 is alone
        assert_ne!(ids[4], ids[0]);
        assert_ne!(ids[4], ids[2]);
        // Three distinct components
        assert_eq!(uf.num_components(), 3);
    }

    #[test]
    fn test_component_ids_dense_in_first_seen_order() {
        let mut uf = UnionFind::new(5);
        uf.union(NodeId(3), NodeId(4));
        uf.union(NodeId(0), NodeId(2));

        // Labels follow each component's lowest node index: {0,2}, {1}, {3,4}.
        let expected: Vec<_> = [0, 1, 0, 2, 2]
            .iter()
            .map(|&c| ComponentId::new(c))
            .collect();
        assert_eq!(uf.component_ids(), expected);
    }

    #[test]
    fn test_component_size() {
        let mut uf = UnionFind::new(6);

        uf.union(NodeId(0), NodeId(1));
        uf.union(NodeId(1), NodeId(2));
        // Component {0, 1, 2} has size 3

        uf.union(NodeId(3), NodeId(4));
        // Component {3, 4} has size 2

        // Node 5 is alone

        assert_eq!(uf.component_size(NodeId(0)), 3);
        assert_eq!(uf.component_size(NodeId(3)), 2);
        assert_eq!(uf.component_size(NodeId(5)), 1);
    }

    #[test]
    fn test_component_size_after_merging_groups() {
        let mut uf = UnionFind::new(7);
        uf.union(NodeId(0), NodeId(1));
        uf.union(NodeId(2), NodeId(3));
        uf.union(NodeId(4), NodeId(5));

        // Merge two 2-node groups, then absorb a third group.
        assert!(uf.union(NodeId(1), NodeId(3)));
        assert_eq!(uf.component_size(NodeId(0)), 4);
        assert!(uf.union(NodeId(5), NodeId(2)));

        for i in 0..6 {
            assert_eq!(uf.component_size(NodeId(i)), 6);
        }
        assert_eq!(uf.component_size(NodeId(6)), 1);

        // A redundant union must not change any size.
        assert!(!uf.union(NodeId(0), NodeId(4)));
        assert_eq!(uf.component_size(NodeId(4)), 6);
        assert_eq!(uf.num_components(), 2);
    }

    #[test]
    fn test_union_find_from_edges() {
        let edges = [
            (NodeId(0), NodeId(1)),
            (NodeId(1), NodeId(2)),
            (NodeId(3), NodeId(4)),
        ];

        let components = union_find_sequential(5, &edges).unwrap();

        // {0, 1, 2} are connected
        assert_eq!(components[0], components[1]);
        assert_eq!(components[1], components[2]);

        // {3, 4} are connected
        assert_eq!(components[3], components[4]);

        // Different groups
        assert_ne!(components[0], components[3]);
    }

    #[test]
    fn test_empty_union_find() {
        let uf = UnionFind::new(0);
        assert!(uf.is_empty());
        assert_eq!(uf.num_components(), 0);
    }

    #[test]
    fn test_parallel_union_find_basic() {
        let edges = [
            (NodeId(0), NodeId(1)),
            (NodeId(1), NodeId(2)),
            (NodeId(3), NodeId(4)),
        ];

        let components = union_find_parallel(5, &edges).unwrap();

        // {0, 1, 2} are connected
        assert_eq!(components[0], components[1]);
        assert_eq!(components[1], components[2]);

        // {3, 4} are connected
        assert_eq!(components[3], components[4]);

        // Different groups
        assert_ne!(components[0], components[3]);
    }

    #[test]
    fn test_parallel_union_find_single_component() {
        // Linear chain connecting all nodes
        let edges: Vec<_> = (0..9).map(|i| (NodeId(i), NodeId(i + 1))).collect();

        let components = union_find_parallel(10, &edges).unwrap();

        // All nodes should be in the same component
        for i in 1..10 {
            assert_eq!(components[0], components[i]);
        }
    }

    #[test]
    fn test_parallel_union_find_no_edges() {
        let components = union_find_parallel(5, &[]).unwrap();

        // Each node is its own component
        for i in 0..5 {
            for j in (i + 1)..5 {
                assert_ne!(components[i], components[j]);
            }
        }
    }

    #[test]
    fn test_parallel_union_find_empty() {
        let components = union_find_parallel(0, &[]).unwrap();
        assert!(components.is_empty());
    }

    #[test]
    fn test_parallel_vs_sequential_consistency() {
        // Test that parallel and sequential give same results
        let edges = [
            (NodeId(0), NodeId(5)),
            (NodeId(1), NodeId(6)),
            (NodeId(2), NodeId(7)),
            (NodeId(5), NodeId(6)),
            (NodeId(3), NodeId(8)),
            (NodeId(4), NodeId(9)),
            (NodeId(8), NodeId(9)),
        ];

        let seq_components = union_find_sequential(10, &edges).unwrap();
        let par_components = union_find_parallel(10, &edges).unwrap();

        // Check connectivity is identical
        for i in 0..10 {
            for j in (i + 1)..10 {
                let seq_same = seq_components[i] == seq_components[j];
                let par_same = par_components[i] == par_components[j];
                assert_eq!(seq_same, par_same, "Mismatch for nodes {} and {}", i, j);
            }
        }
    }

    #[test]
    fn test_parallel_matches_sequential_labels_exactly() {
        const N: u64 = 200;
        let mut state = 0x2545_F491_4F6C_DD1D_u64;
        let edges: Vec<_> = (0..150)
            .map(|_| {
                let u = (lcg(&mut state) % N) as u32;
                let v = (lcg(&mut state) % N) as u32;
                (NodeId(u), NodeId(v))
            })
            .collect();

        let seq = union_find_sequential(N as usize, &edges).unwrap();
        let par = union_find_parallel(N as usize, &edges).unwrap();
        // Both label components by first appearance, so the vectors must be identical.
        assert_eq!(seq, par);
    }

    #[test]
    fn test_parallel_union_find_star_graph() {
        // Star graph: node 0 connected to all others
        let edges: Vec<_> = (1..10).map(|i| (NodeId(0), NodeId(i))).collect();

        let components = union_find_parallel(10, &edges).unwrap();

        // All nodes should be in the same component
        for i in 1..10 {
            assert_eq!(components[0], components[i]);
        }
    }

    #[test]
    fn test_parallel_union_find_large_star_and_reverse_chain() {
        const N: u32 = 1_000;

        // Hub at the lowest index: index-ordered hooking builds long chains here.
        let star: Vec<_> = (1..N).map(|i| (NodeId(0), NodeId(i))).collect();
        let components = union_find_parallel(N as usize, &star).unwrap();
        assert!(components.iter().all(|&c| c == components[0]));

        // Chain whose edges arrive from the high end down.
        let chain: Vec<_> = (1..N).rev().map(|i| (NodeId(i - 1), NodeId(i))).collect();
        let components = union_find_parallel(N as usize, &chain).unwrap();
        assert!(components.iter().all(|&c| c == components[0]));
    }
}