algorithms_edu/data_structures/
union_find.rs1use std::cmp::Ordering::*;
12
13#[derive(Clone)]
15pub struct UnionFind {
16 parents: Vec<usize>,
17 ranks: Vec<usize>,
18}
19
20impl UnionFind {
21 pub fn with_size(size: usize) -> Self {
22 UnionFind {
23 parents: (0..size).collect(),
25 ranks: vec![0; size],
26 }
27 }
28
29 pub fn with_ranks(ranks: Vec<usize>) -> Self {
30 let size = ranks.len();
31 UnionFind {
32 parents: (0..size).collect(),
34 ranks,
35 }
36 }
37
38 pub fn len(&self) -> usize {
39 self.parents.len()
40 }
41
42 pub fn is_empty(&self) -> bool {
43 self.parents.is_empty()
44 }
45
46 pub fn extend(&mut self, size: usize) {
47 let n = self.len();
48 for i in n..n + size {
49 self.parents.push(i);
50 self.ranks.push(0);
51 }
52 }
53
54 pub fn union(&mut self, a: usize, b: usize) -> bool {
56 let rep_a = self.find(a);
57 let rep_b = self.find(b);
58
59 if rep_a == rep_b {
60 return false;
61 }
62
63 let rank_a = self.ranks[rep_a];
64 let rank_b = self.ranks[rep_b];
65
66 match rank_a.cmp(&rank_b) {
67 Greater => self.set_parent(rep_b, rep_a),
68 Less => self.set_parent(rep_a, rep_b),
69 Equal => {
70 self.set_parent(rep_a, rep_b);
71 self.increment_rank(rep_b);
72 }
73 }
74
75 true
76 }
77
78 pub fn find(&mut self, mut element: usize) -> usize {
80 let mut parent = self.parent(element);
81 while element != parent {
82 let next_parent = self.parent(parent);
83 self.set_parent(element, next_parent);
84 element = parent;
85 parent = next_parent;
86 }
87
88 element
89 }
90
91 pub fn in_same_set(&mut self, a: usize, b: usize) -> bool {
92 self.find(a) == self.find(b)
93 }
94
95 fn increment_rank(&mut self, element: usize) {
96 self.ranks[element] = self.ranks[element].saturating_add(1);
97 }
98
99 pub fn parent(&self, element: usize) -> usize {
100 self.parents[element]
101 }
102
103 pub fn set_parent(&mut self, element: usize, parent: usize) {
104 self.parents[element] = parent;
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end scenario: grows sets through chained unions and checks
    /// membership along the way, asserting every `union` result.
    #[test]
    fn test_union_find() {
        let mut uf = UnionFind::with_size(7);
        uf.extend(1);
        assert_eq!(uf.len(), 8);
        assert!(!uf.is_empty());
        assert!(uf.union(0, 1));
        assert!(uf.union(1, 2));
        assert!(uf.union(4, 3));
        assert!(uf.union(3, 2));
        assert!(!uf.union(0, 3));

        assert!(uf.in_same_set(0, 1));
        assert!(uf.in_same_set(0, 2));
        assert!(uf.in_same_set(0, 3));
        assert!(uf.in_same_set(0, 4));
        assert!(!uf.in_same_set(0, 5));

        assert!(uf.union(5, 3));
        assert!(uf.in_same_set(0, 5));

        assert!(uf.union(6, 7));
        assert!(uf.in_same_set(6, 7));
        assert!(!uf.in_same_set(5, 7));

        assert!(uf.union(0, 7));
        assert!(uf.in_same_set(5, 7));
    }

    /// An empty structure can be grown with `extend`; new elements are singletons.
    #[test]
    fn test_empty_then_extend() {
        let mut uf = UnionFind::with_size(0);
        assert!(uf.is_empty());
        assert_eq!(uf.len(), 0);

        uf.extend(3);
        assert_eq!(uf.len(), 3);
        for i in 0..3 {
            assert_eq!(uf.find(i), i);
        }
        assert!(!uf.in_same_set(0, 2));
    }

    /// Unioning an element with itself, or two already-joined elements, is a no-op.
    #[test]
    fn test_redundant_union_returns_false() {
        let mut uf = UnionFind::with_size(3);
        assert!(!uf.union(1, 1));
        assert!(uf.union(0, 2));
        assert!(!uf.union(2, 0));
    }

    /// A seeded higher rank wins the root position regardless of argument order.
    #[test]
    fn test_with_ranks_higher_rank_becomes_root() {
        let mut uf = UnionFind::with_ranks(vec![5, 0]);
        assert!(uf.union(0, 1));
        assert_eq!(uf.find(1), 0);
    }

    /// With equal ranks, the first argument's root is attached under the second's.
    #[test]
    fn test_equal_ranks_attach_first_under_second() {
        let mut uf = UnionFind::with_size(2);
        assert!(uf.union(0, 1));
        assert_eq!(uf.find(0), 1);
    }

    /// `find` re-points each visited node at its grandparent (path splitting).
    #[test]
    fn test_find_applies_path_splitting() {
        let mut uf = UnionFind::with_size(4);
        uf.set_parent(0, 1);
        uf.set_parent(1, 2);
        uf.set_parent(2, 3);

        assert_eq!(uf.find(0), 3);
        assert_eq!(uf.parent(0), 2);
        assert_eq!(uf.parent(1), 3);
        assert_eq!(uf.parent(2), 3);
    }

    /// Querying an element outside `0..len()` panics via slice indexing.
    #[test]
    #[should_panic]
    fn test_find_out_of_range_panics() {
        let mut uf = UnionFind::with_size(2);
        uf.find(2);
    }
}