Skip to main content

evoc_rs/utils/
disjoint_set.rs

1//! Disjoint set helpers
2
3///////////////////////////
4// DisjointSet structure //
5///////////////////////////
6
/// Union-Find with union by rank and path halving.
///
/// Standard disjoint set forest supporting near-constant-time `find` and
/// `union` operations (amortised inverse Ackermann). Used internally for
/// connected component tracking during clustering.
///
/// Elements are identified by their index in `0..n`, where `n` is the value
/// passed to [`DisjointSet::new`]. Every method that takes an element index
/// panics if that index is outside this range.
#[derive(Debug, Clone)]
pub struct DisjointSet {
    /// Parent pointers; `parent[x] == x` indicates a root.
    parent: Vec<usize>,
    /// Upper bound on subtree depth, used to keep merges balanced.
    rank: Vec<usize>,
}

impl DisjointSet {
    /// Create a new disjoint set with `n` elements, each in its own singleton
    /// set.
    ///
    /// ### Params
    ///
    /// * `n` - Number of elements (indexed `0..n`)
    ///
    /// ### Returns
    ///
    /// A forest in which every element is its own root with rank 0
    pub fn new(n: usize) -> Self {
        Self {
            parent: (0..n).collect(),
            rank: vec![0; n],
        }
    }

    /// Find the representative (root) of the set containing `x`.
    ///
    /// Applies path halving: each traversed node is repointed to its
    /// grandparent, flattening the tree over successive calls.
    ///
    /// ### Params
    ///
    /// * `x` - Element to look up
    ///
    /// ### Returns
    ///
    /// Root index of the component containing `x`
    ///
    /// ### Panics
    ///
    /// Panics if `x` is not a valid element index (`x >= n`).
    pub fn find(&mut self, mut x: usize) -> usize {
        while self.parent[x] != x {
            // Path halving: skip one level by pointing `x` at its
            // grandparent, then continue the walk from there.
            self.parent[x] = self.parent[self.parent[x]];
            x = self.parent[x];
        }
        x
    }

    /// Merge the sets containing `x` and `y` by rank.
    ///
    /// The shorter tree is attached under the taller one to keep depth
    /// bounded.
    ///
    /// ### Params
    ///
    /// * `x` - First element
    /// * `y` - Second element
    ///
    /// ### Returns
    ///
    /// `true` if `x` and `y` were in different sets (i.e. a merge actually
    /// happened), `false` if they were already connected.
    ///
    /// ### Panics
    ///
    /// Panics if `x` or `y` is not a valid element index.
    pub fn union(&mut self, x: usize, y: usize) -> bool {
        let rx = self.find(x);
        let ry = self.find(y);
        if rx == ry {
            return false;
        }

        if self.rank[rx] < self.rank[ry] {
            self.parent[rx] = ry;
        } else if self.rank[rx] > self.rank[ry] {
            self.parent[ry] = rx;
        } else {
            // Equal ranks: pick `rx` as the new root; only in this case can
            // the resulting tree be one level deeper, so bump its rank.
            self.parent[ry] = rx;
            self.rank[rx] += 1;
        }
        true
    }

    /// Check whether `x` and `y` belong to the same set.
    ///
    /// Takes `&mut self` because the underlying `find` calls compress paths
    /// as a side effect.
    ///
    /// ### Params
    ///
    /// * `x` - First element
    /// * `y` - Second element
    ///
    /// ### Returns
    ///
    /// `true` if both elements share the same root
    ///
    /// ### Panics
    ///
    /// Panics if `x` or `y` is not a valid element index.
    pub fn connected(&mut self, x: usize, y: usize) -> bool {
        self.find(x) == self.find(y)
    }
}
99
100//////////////////////
101// SizedDisjointSet //
102//////////////////////
103
/// Union-Find that tracks component sizes.
///
/// Same structure as [`DisjointSet`] but uses union by size instead of rank,
/// and exposes component sizes. Used during linkage tree construction where
/// merge sizes feed into distance/weight calculations.
///
/// Elements are identified by their index in `0..n`, where `n` is the value
/// passed to [`SizedDisjointSet::new`]. Every method that takes an element
/// index panics if that index is outside this range.
///
/// ### Fields
///
/// * `parent` - Parent pointers; `parent[x] == x` indicates a root
/// * `size` - Number of elements in the subtree rooted at each node; only
///   meaningful at root nodes
#[derive(Debug, Clone)]
pub struct SizedDisjointSet {
    /// Parent pointers; `parent[x] == x` indicates a root.
    parent: Vec<usize>,
    /// Component size; only valid at root nodes.
    size: Vec<usize>,
}

impl SizedDisjointSet {
    /// Create with `n` singleton elements, each of size 1.
    ///
    /// ### Params
    ///
    /// * `n` - Number of elements (indexed `0..n`)
    ///
    /// ### Returns
    ///
    /// A forest in which every element is its own root
    pub fn new(n: usize) -> Self {
        Self {
            parent: (0..n).collect(),
            size: vec![1; n],
        }
    }

    /// Find the representative of the set containing `x`, with path halving.
    ///
    /// ### Params
    ///
    /// * `x` - Element to look up
    ///
    /// ### Returns
    ///
    /// Root index of the component containing `x`
    ///
    /// ### Panics
    ///
    /// Panics if `x` is not a valid element index (`x >= n`).
    pub fn find(&mut self, mut x: usize) -> usize {
        while self.parent[x] != x {
            // Path halving: skip one level by pointing `x` at its
            // grandparent, then continue the walk from there.
            self.parent[x] = self.parent[self.parent[x]];
            x = self.parent[x];
        }
        x
    }

    /// Merge the sets containing `x` and `y` by size.
    ///
    /// The smaller component is attached under the larger one. On a tie the
    /// root of `y`'s component is attached under the root of `x`'s.
    ///
    /// ### Params
    ///
    /// * `x` - First element
    /// * `y` - Second element
    ///
    /// ### Returns
    ///
    /// `true` if a merge occurred, `false` if already in the same set.
    ///
    /// ### Panics
    ///
    /// Panics if `x` or `y` is not a valid element index.
    pub fn union(&mut self, x: usize, y: usize) -> bool {
        let rx = self.find(x);
        let ry = self.find(y);
        if rx == ry {
            return false;
        }

        // Only the surviving root's size is updated; the absorbed root's
        // entry becomes stale and is never read again through `find`.
        if self.size[rx] < self.size[ry] {
            self.parent[rx] = ry;
            self.size[ry] += self.size[rx];
        } else {
            self.parent[ry] = rx;
            self.size[rx] += self.size[ry];
        }
        true
    }

    /// Return the size of the component containing `x`.
    ///
    /// Takes `&mut self` because the underlying `find` call compresses paths
    /// as a side effect.
    ///
    /// ### Params
    ///
    /// * `x` - Element to look up
    ///
    /// ### Returns
    ///
    /// Number of elements in the component
    ///
    /// ### Panics
    ///
    /// Panics if `x` is not a valid element index.
    pub fn component_size(&mut self, x: usize) -> usize {
        let r = self.find(x);
        self.size[r]
    }
}
195
196///////////
197// Tests //
198///////////
199
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh pair starts disconnected, one union joins it, and a repeated
    /// union of the same pair reports that nothing changed.
    #[test]
    fn test_disjoint_set_basic() {
        let mut forest = DisjointSet::new(5);

        let joined_before = forest.connected(0, 1);
        assert!(!joined_before);

        let first_merge = forest.union(0, 1);
        assert!(first_merge);
        assert!(forest.connected(0, 1));

        // Already connected, so no merge takes place.
        let second_merge = forest.union(0, 1);
        assert!(!second_merge);
    }

    /// Linking 0-1-2-3-4 in a chain makes every pair connected.
    #[test]
    fn test_disjoint_set_chain() {
        let mut forest = DisjointSet::new(5);
        for (a, b) in [(0, 1), (1, 2), (2, 3), (3, 4)] {
            forest.union(a, b);
        }

        for (a, b) in (0..5).flat_map(|a| (0..5).map(move |b| (a, b))) {
            assert!(forest.connected(a, b));
        }
    }

    /// Two separately built chains stay in distinct components.
    #[test]
    fn test_disjoint_set_two_components() {
        let mut forest = DisjointSet::new(6);
        for (a, b) in [(0, 1), (1, 2), (3, 4), (4, 5)] {
            forest.union(a, b);
        }

        assert!(forest.connected(0, 2));
        assert!(forest.connected(3, 5));
        assert!(!forest.connected(0, 3));
    }

    /// Component sizes grow as elements are merged in one at a time.
    #[test]
    fn test_sized_disjoint_set() {
        let mut forest = SizedDisjointSet::new(5);

        assert_eq!(forest.component_size(0), 1);

        forest.union(0, 1);
        for member in [0, 1] {
            assert_eq!(forest.component_size(member), 2);
        }

        forest.union(0, 2);
        assert_eq!(forest.component_size(2), 3);
    }

    /// Merging two pairs and then bridging them yields one component of 4.
    #[test]
    fn test_sized_disjoint_set_full_merge() {
        let mut forest = SizedDisjointSet::new(4);
        for (a, b) in [(0, 1), (2, 3), (0, 3)] {
            forest.union(a, b);
        }

        for member in [0, 3] {
            assert_eq!(forest.component_size(member), 4);
        }
    }
}