// oxicuda_graphalg/mst/union_find.rs

/// Disjoint-set forest over the elements `0..n`, using union by rank and path
/// compression.
///
/// Each element starts in its own singleton set. [`UnionFind::union`] merges two
/// sets, and [`UnionFind::find`] returns the representative (root) of the set
/// containing an element. With both heuristics combined, a sequence of
/// operations runs in near-constant amortized time per operation
/// (inverse-Ackermann).
///
/// Every method that takes an element index panics if that index is
/// `>= self.len()`.
#[derive(Debug, Clone)]
pub struct UnionFind {
    /// `parent[i]` is the parent of `i` in the forest; a root is its own parent.
    parent: Vec<usize>,
    /// Rank (an upper bound on tree height); only meaningful at roots.
    rank: Vec<u32>,
    /// Number of elements in each set; only meaningful at roots.
    size: Vec<usize>,
    /// Current number of disjoint sets.
    num_sets: usize,
}

impl UnionFind {
    /// Creates `n` singleton sets `{0}, {1}, …, {n - 1}`.
    #[must_use]
    pub fn new(n: usize) -> Self {
        Self {
            parent: (0..n).collect(),
            rank: vec![0; n],
            size: vec![1; n],
            num_sets: n,
        }
    }

    /// Returns the total number of elements (not the number of sets).
    #[must_use]
    pub fn len(&self) -> usize {
        self.parent.len()
    }

    /// Returns `true` if the structure holds no elements.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.parent.is_empty()
    }

    /// Returns the representative (root) of the set containing `x`.
    ///
    /// Performs full path compression. Every node visited on the way from `x`
    /// to the root is re-pointed directly at the root. This mutation is why the
    /// method takes `&mut self`.
    ///
    /// # Panics
    ///
    /// Panics if `x >= self.len()`.
    pub fn find(&mut self, x: usize) -> usize {
        // First pass: walk up to the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        // Second pass: point every node on the path directly at the root.
        let mut cur = x;
        while self.parent[cur] != root {
            let next = self.parent[cur];
            self.parent[cur] = root;
            cur = next;
        }
        root
    }

    /// Returns `true` if `a` and `b` currently belong to the same set.
    ///
    /// # Panics
    ///
    /// Panics if `a` or `b` is `>= self.len()`.
    pub fn connected(&mut self, a: usize, b: usize) -> bool {
        self.find(a) == self.find(b)
    }

    /// Merges the sets containing `a` and `b`.
    ///
    /// Returns `true` if two distinct sets were merged. Returns `false` (and
    /// changes nothing apart from path compression) if `a` and `b` were already
    /// in the same set.
    ///
    /// # Panics
    ///
    /// Panics if `a` or `b` is `>= self.len()`.
    pub fn union(&mut self, a: usize, b: usize) -> bool {
        let mut ra = self.find(a);
        let mut rb = self.find(b);
        if ra == rb {
            return false;
        }
        // Union by rank: ensure `ra` is the root with the larger (or equal) rank,
        // then hang `rb` beneath it. On a tie no swap happens, so `b`'s root goes
        // under `a`'s root, and the surviving root's rank grows by one.
        if self.rank[ra] < self.rank[rb] {
            std::mem::swap(&mut ra, &mut rb);
        }
        self.parent[rb] = ra;
        self.size[ra] += self.size[rb];
        if self.rank[ra] == self.rank[rb] {
            self.rank[ra] += 1;
        }
        self.num_sets -= 1;
        true
    }

    /// Returns the number of elements in the set containing `x`.
    ///
    /// # Panics
    ///
    /// Panics if `x >= self.len()`.
    pub fn set_size(&mut self, x: usize) -> usize {
        let root = self.find(x);
        self.size[root]
    }

    /// Returns the current number of disjoint sets.
    #[must_use]
    pub fn num_sets(&self) -> usize {
        self.num_sets
    }
}
79
#[cfg(test)]
mod tests {
    use super::*;

    /// A successful union reduces the set count; a repeated union is a no-op.
    #[test]
    fn uf_basic() {
        let mut uf = UnionFind::new(5);
        assert_eq!(uf.num_sets(), 5);
        assert!(uf.union(0, 1));
        assert_eq!(uf.num_sets(), 4);
        assert!(!uf.union(0, 1));
        assert_eq!(uf.find(0), uf.find(1));
    }

    /// Set sizes accumulate across chained merges.
    #[test]
    fn uf_set_size() {
        let mut uf = UnionFind::new(4);
        uf.union(0, 1);
        uf.union(2, 3);
        uf.union(1, 2);
        assert_eq!(uf.set_size(0), 4);
    }

    /// An empty structure has no elements and no sets.
    #[test]
    fn uf_empty() {
        let uf = UnionFind::new(0);
        assert!(uf.is_empty());
        assert_eq!(uf.len(), 0);
        assert_eq!(uf.num_sets(), 0);
    }

    /// Before any union, every element is its own singleton root.
    #[test]
    fn uf_singletons() {
        let mut uf = UnionFind::new(3);
        assert_eq!(uf.len(), 3);
        for i in 0..3 {
            assert_eq!(uf.find(i), i);
            assert_eq!(uf.set_size(i), 1);
        }
    }

    /// A long chain of unions ends up as a single set with one shared root.
    #[test]
    fn uf_long_chain() {
        let n = 100;
        let mut uf = UnionFind::new(n);
        for i in 1..n {
            assert!(uf.union(i - 1, i));
        }
        assert_eq!(uf.num_sets(), 1);
        let root = uf.find(0);
        for i in 0..n {
            assert_eq!(uf.find(i), root);
        }
        assert_eq!(uf.set_size(n - 1), n);
    }

    /// Indexing past the end panics rather than silently misbehaving.
    #[test]
    #[should_panic]
    fn uf_out_of_bounds() {
        let mut uf = UnionFind::new(2);
        uf.find(2);
    }
}
102}