competitive_hpp/
union_find.rs1use num::traits::{NumAssignOps, PrimInt, Unsigned};
2
/// Disjoint-set forest (union–find) over the elements `0..n`.
///
/// Every per-element vector is indexed by the element value converted to
/// `usize`. The methods require `T` to be an unsigned primitive integer, used
/// both for element ids and for counts.
#[derive(Debug, Clone)]
pub struct UnionFind<T> {
    /// Parent link of each element; an element is a root when it is its own parent.
    par: Vec<T>,
    /// Per-element rank used when deciding which root survives a merge.
    rank: Vec<T>,
    /// Set size, read at the root of each set (entries of non-roots are not kept current).
    group: Vec<T>,
    /// Number of disjoint sets currently tracked.
    num_of_groups: T,
}
29
30impl<T> UnionFind<T>
31where
32 T: PrimInt + NumAssignOps + Unsigned,
33{
34 pub fn new(n: T) -> Self {
35 let mut par: Vec<T> = vec![];
36 let un = n.to_u64().unwrap();
37 for i in 0..un {
38 par.push(T::from(i).unwrap())
39 }
40
41 UnionFind {
42 par,
43 rank: vec![T::zero(); n.to_usize().unwrap()],
44 group: vec![T::one(); n.to_usize().unwrap()],
45 num_of_groups: n,
46 }
47 }
48
49 pub fn find(&mut self, x: T) -> T {
50 let ux = x.to_usize().unwrap();
51 if self.par[ux] == x {
52 x
53 } else {
54 let px = self.par[ux];
55 let root = self.find(px);
56 self.par[ux] = root;
58 root
59 }
60 }
61
62 pub fn union(&mut self, x: T, y: T) {
63 let x = self.find(x);
64 let y = self.find(y);
65 if x == y {
66 return;
67 }
68 let ux = x.to_usize().unwrap();
69 let uy = y.to_usize().unwrap();
70
71 if self.rank[ux] < self.rank[uy] {
72 let tmp = self.group[ux];
73 self.group[uy] += tmp;
74 self.par[ux] = y;
75 } else {
76 let tmp = self.group[uy];
77 self.group[ux] += tmp;
78 self.par[uy] = x;
79 }
80 if self.rank[ux] == self.rank[uy] {
81 self.rank[uy] += T::one();
82 }
83 self.num_of_groups -= T::one();
84 }
85
86 pub fn is_same(&mut self, x: T, y: T) -> bool {
87 self.find(x) == self.find(y)
88 }
89
90 pub fn group_size(&mut self, x: T) -> T {
91 let p = self.find(x);
92 self.group[p.to_usize().unwrap()]
93 }
94
95 pub fn rank(&self, x: T) -> T {
96 self.rank[x.to_usize().unwrap()]
97 }
98
99 pub fn num_of_groups(&self) -> T {
100 self.num_of_groups
101 }
102}
103
104#[cfg(test)]
105mod tests {
106 use super::*;
107
108 macro_rules! impl_union_find_tests {
109 ($ty:ty) => {
110 let mut uf = UnionFind::new(5 as $ty);
111
112 assert_eq!(5, uf.num_of_groups());
116
117 uf.union(0, 1);
118 assert_eq!(4, uf.num_of_groups());
119
120 uf.union(2, 3);
121 assert_eq!(3, uf.num_of_groups());
122
123 uf.union(1, 4);
124 assert_eq!(2, uf.num_of_groups());
125
126 assert_eq!(uf.find(0), uf.find(1));
127 assert_ne!(uf.find(0), uf.find(2));
128 assert_ne!(uf.find(0), uf.find(3));
129 assert_eq!(uf.find(0), uf.find(4));
130 assert_ne!(uf.find(1), uf.find(2));
131 assert_ne!(uf.find(1), uf.find(3));
132 assert_eq!(uf.find(1), uf.find(4));
133 assert_eq!(uf.find(2), uf.find(3));
134 assert_ne!(uf.find(2), uf.find(4));
135 assert_ne!(uf.find(3), uf.find(4));
136
137 assert!(uf.is_same(0, 1));
138 assert!(!uf.is_same(0, 2));
139 assert!(!uf.is_same(0, 3));
140 assert!(uf.is_same(0, 4));
141
142 assert_eq!(uf.rank(0), 0);
143 assert_eq!(uf.rank(1), 1);
144 assert_eq!(uf.rank(2), 0);
145 assert_eq!(uf.rank(3), 1);
146 assert_eq!(uf.rank(4), 1);
147
148 assert_eq!(uf.group_size(0), 3);
149 assert_eq!(uf.group_size(1), 3);
150 assert_eq!(uf.group_size(2), 2);
151 assert_eq!(uf.group_size(3), 2);
152 assert_eq!(uf.group_size(4), 3);
153
154 uf.union(0, 2);
155 assert_eq!(1, uf.num_of_groups());
156 };
157 }
158
159 #[test]
160 fn u64_union_find_test() {
161 impl_union_find_tests!(u64);
162 }
163 #[test]
164 fn u32_union_find_test() {
165 impl_union_find_tests!(u32);
166 }
167 #[test]
168 fn usize_union_find_test() {
169 impl_union_find_tests!(usize);
170 }
171}