evoc_rs/utils/disjoint_set.rs
1//! Disjoint set helpers
2
3///////////////////////////
4// DisjointSet structure //
5///////////////////////////
6
/// Union-Find with union by rank and path halving.
///
/// Standard disjoint set forest supporting near-constant-time `find` and
/// `union` operations (amortised inverse Ackermann). Used internally for
/// connected component tracking during clustering.
///
/// Elements are the indices `0..n` fixed at construction. Every method that
/// takes an element panics if that index is out of range.
///
/// `PartialEq` is deliberately not derived: two forests can describe the same
/// partition with different parent pointers, so a field-wise comparison would
/// be misleading.
#[derive(Debug, Clone)]
pub struct DisjointSet {
    /// Parent pointers; `parent[x] == x` indicates a root.
    parent: Vec<usize>,
    /// Upper bound on subtree depth, used to keep merges balanced.
    rank: Vec<usize>,
}

impl DisjointSet {
    /// Create a new disjoint set with `n` elements, each in its own singleton
    /// set.
    ///
    /// ### Params
    ///
    /// * `n` - Number of elements (indexed `0..n`)
    ///
    /// ### Returns
    ///
    /// A forest in which every element is its own root with rank 0
    pub fn new(n: usize) -> Self {
        Self {
            parent: (0..n).collect(),
            rank: vec![0; n],
        }
    }

    /// Find the representative (root) of the set containing `x`.
    ///
    /// Applies path halving: each traversed node is repointed to its
    /// grandparent, flattening the tree over successive calls. This is why
    /// the method needs `&mut self` even though it is logically a query.
    ///
    /// ### Params
    ///
    /// * `x` - Element to look up
    ///
    /// ### Returns
    ///
    /// Root index of the component containing `x`
    ///
    /// ### Panics
    ///
    /// Panics if `x` is not a valid element index (`x >= n`).
    pub fn find(&mut self, mut x: usize) -> usize {
        loop {
            let parent = self.parent[x];
            if parent == x {
                return x;
            }
            // Path halving: bypass the parent by pointing `x` straight at its
            // grandparent, then continue the walk from the grandparent.
            let grandparent = self.parent[parent];
            self.parent[x] = grandparent;
            x = grandparent;
        }
    }

    /// Merge the sets containing `x` and `y` by rank.
    ///
    /// The shorter tree is attached under the taller one to keep depth
    /// bounded. When both trees have the same rank, the root of `y`'s tree is
    /// attached under the root of `x`'s tree and that root's rank grows by
    /// one.
    ///
    /// ### Params
    ///
    /// * `x` - First element
    /// * `y` - Second element
    ///
    /// ### Returns
    ///
    /// `true` if `x` and `y` were in different sets (i.e. a merge actually
    /// happened), `false` if they were already connected.
    ///
    /// ### Panics
    ///
    /// Panics if `x` or `y` is not a valid element index.
    pub fn union(&mut self, x: usize, y: usize) -> bool {
        let rx = self.find(x);
        let ry = self.find(y);
        if rx == ry {
            return false;
        }

        // Choose which root survives; on a rank tie `rx` stays the root.
        let (keep, absorb) = if self.rank[rx] < self.rank[ry] {
            (ry, rx)
        } else {
            (rx, ry)
        };
        self.parent[absorb] = keep;

        // Only joining two trees of equal rank can deepen the result.
        if self.rank[keep] == self.rank[absorb] {
            self.rank[keep] += 1;
        }
        true
    }

    /// Check whether `x` and `y` belong to the same set.
    ///
    /// Both lookups go through [`DisjointSet::find`], so this also shortens
    /// the paths it walks.
    ///
    /// ### Params
    ///
    /// * `x` - First element
    /// * `y` - Second element
    ///
    /// ### Returns
    ///
    /// `true` if both elements share the same root
    ///
    /// ### Panics
    ///
    /// Panics if `x` or `y` is not a valid element index.
    pub fn connected(&mut self, x: usize, y: usize) -> bool {
        let root_x = self.find(x);
        let root_y = self.find(y);
        root_x == root_y
    }
}
99
100//////////////////////
101// SizedDisjointSet //
102//////////////////////
103
/// Union-Find that tracks component sizes.
///
/// Same structure as [`DisjointSet`] but uses union by size instead of rank,
/// and exposes component sizes. Used during linkage tree construction where
/// merge sizes feed into distance/weight calculations.
///
/// Elements are the indices `0..n` fixed at construction. Every method that
/// takes an element panics if that index is out of range.
///
/// `PartialEq` is deliberately not derived: two forests can describe the same
/// partition with different parent pointers, so a field-wise comparison would
/// be misleading.
#[derive(Debug, Clone)]
pub struct SizedDisjointSet {
    /// Parent pointers; `parent[x] == x` indicates a root.
    parent: Vec<usize>,
    /// Component size; only valid at root nodes.
    size: Vec<usize>,
}

impl SizedDisjointSet {
    /// Create with `n` singleton elements, each of size 1.
    ///
    /// ### Params
    ///
    /// * `n` - Number of elements (indexed `0..n`)
    ///
    /// ### Returns
    ///
    /// A forest in which every element is its own root with size 1
    pub fn new(n: usize) -> Self {
        Self {
            parent: (0..n).collect(),
            size: vec![1; n],
        }
    }

    /// Find the representative of the set containing `x`, with path halving.
    ///
    /// Each traversed node is repointed to its grandparent, which is why the
    /// method needs `&mut self` even though it is logically a query.
    ///
    /// ### Params
    ///
    /// * `x` - Element to look up
    ///
    /// ### Returns
    ///
    /// Root index of the component containing `x`
    ///
    /// ### Panics
    ///
    /// Panics if `x` is not a valid element index (`x >= n`).
    pub fn find(&mut self, mut x: usize) -> usize {
        loop {
            let parent = self.parent[x];
            if parent == x {
                return x;
            }
            // Path halving: bypass the parent by pointing `x` straight at its
            // grandparent, then continue the walk from the grandparent.
            let grandparent = self.parent[parent];
            self.parent[x] = grandparent;
            x = grandparent;
        }
    }

    /// Merge the sets containing `x` and `y` by size.
    ///
    /// The smaller component is attached under the larger one. When both
    /// components have the same size, the root of `y`'s component is attached
    /// under the root of `x`'s component.
    ///
    /// ### Params
    ///
    /// * `x` - First element
    /// * `y` - Second element
    ///
    /// ### Returns
    ///
    /// `true` if a merge occurred, `false` if already in the same set.
    ///
    /// ### Panics
    ///
    /// Panics if `x` or `y` is not a valid element index.
    pub fn union(&mut self, x: usize, y: usize) -> bool {
        let rx = self.find(x);
        let ry = self.find(y);
        if rx == ry {
            return false;
        }

        // Choose which root survives; on a size tie `rx` stays the root.
        let (keep, absorb) = if self.size[rx] < self.size[ry] {
            (ry, rx)
        } else {
            (rx, ry)
        };
        self.parent[absorb] = keep;
        // The absorbed root's entry is left as-is; sizes are only read at roots.
        self.size[keep] += self.size[absorb];
        true
    }

    /// Return the size of the component containing `x`.
    ///
    /// ### Params
    ///
    /// * `x` - Element to look up
    ///
    /// ### Returns
    ///
    /// Number of elements in the component
    ///
    /// ### Panics
    ///
    /// Panics if `x` is not a valid element index.
    pub fn component_size(&mut self, x: usize) -> usize {
        let root = self.find(x);
        self.size[root]
    }
}
195
196///////////
197// Tests //
198///////////
199
#[cfg(test)]
mod tests {
    use super::*;

    /// A single union links two singletons; repeating it changes nothing.
    #[test]
    fn test_disjoint_set_basic() {
        let mut forest = DisjointSet::new(5);

        assert!(!forest.connected(0, 1));
        assert!(forest.union(0, 1));
        assert!(forest.connected(0, 1));
        // The pair is already merged, so a second union must report `false`.
        assert!(!forest.union(0, 1));
    }

    /// Linking 0-1-2-3-4 as a chain puts every pair in one component.
    #[test]
    fn test_disjoint_set_chain() {
        let mut forest = DisjointSet::new(5);
        for &(a, b) in &[(0, 1), (1, 2), (2, 3), (3, 4)] {
            forest.union(a, b);
        }

        let all_pairs = (0..5).flat_map(|i| (0..5).map(move |j| (i, j)));
        for (i, j) in all_pairs {
            assert!(forest.connected(i, j));
        }
    }

    /// Two separate chains stay disconnected from each other.
    #[test]
    fn test_disjoint_set_two_components() {
        let mut forest = DisjointSet::new(6);
        for &(a, b) in &[(0, 1), (1, 2), (3, 4), (4, 5)] {
            forest.union(a, b);
        }

        assert!(forest.connected(0, 2));
        assert!(forest.connected(3, 5));
        assert!(!forest.connected(0, 3));
    }

    /// Component size grows by one with each newly absorbed singleton.
    #[test]
    fn test_sized_disjoint_set() {
        let mut forest = SizedDisjointSet::new(5);

        assert_eq!(forest.component_size(0), 1);

        forest.union(0, 1);
        for member in 0..2 {
            assert_eq!(forest.component_size(member), 2);
        }

        forest.union(0, 2);
        assert_eq!(forest.component_size(2), 3);
    }

    /// Merging two pairs yields one component of size four, seen from either side.
    #[test]
    fn test_sized_disjoint_set_full_merge() {
        let mut forest = SizedDisjointSet::new(4);
        for &(a, b) in &[(0, 1), (2, 3), (0, 3)] {
            forest.union(a, b);
        }

        for &member in &[0, 3] {
            assert_eq!(forest.component_size(member), 4);
        }
    }
}
263}