competitive_hpp/
union_find.rs

1use num::traits::{NumAssignOps, PrimInt, Unsigned};
2
3/// # UnionFind
4///
5/// Example:
6/// ```
7/// use competitive_hpp::prelude::*;
8/// // 0 ━━━━━ 1 ━━━━━ 4
9/// //
10/// // 2 ━━━━━ 3
11/// let mut uf = UnionFind::new(5usize);
12/// uf.union(0, 1);
13/// uf.union(2, 3);
14/// uf.union(1, 4);
15///
16/// assert_eq!(uf.find(0), uf.find(1));
17/// assert_eq!(2, uf.num_of_groups());
18/// assert!(uf.is_same(0, 1));
19/// assert_eq!(uf.rank(2), 0);
20/// assert_eq!(uf.group_size(0), 3);
21/// ```
22#[derive(Clone, Debug)]
23pub struct UnionFind<T> {
24    par: Vec<T>,
25    rank: Vec<T>,
26    group: Vec<T>,
27    num_of_groups: T,
28}
29
30impl<T> UnionFind<T>
31where
32    T: PrimInt + NumAssignOps + Unsigned,
33{
34    pub fn new(n: T) -> Self {
35        let mut par: Vec<T> = vec![];
36        let un = n.to_u64().unwrap();
37        for i in 0..un {
38            par.push(T::from(i).unwrap())
39        }
40
41        UnionFind {
42            par,
43            rank: vec![T::zero(); n.to_usize().unwrap()],
44            group: vec![T::one(); n.to_usize().unwrap()],
45            num_of_groups: n,
46        }
47    }
48
49    pub fn find(&mut self, x: T) -> T {
50        let ux = x.to_usize().unwrap();
51        if self.par[ux] == x {
52            x
53        } else {
54            let px = self.par[ux];
55            let root = self.find(px);
56            // reattach edges
57            self.par[ux] = root;
58            root
59        }
60    }
61
62    pub fn union(&mut self, x: T, y: T) {
63        let x = self.find(x);
64        let y = self.find(y);
65        if x == y {
66            return;
67        }
68        let ux = x.to_usize().unwrap();
69        let uy = y.to_usize().unwrap();
70
71        if self.rank[ux] < self.rank[uy] {
72            let tmp = self.group[ux];
73            self.group[uy] += tmp;
74            self.par[ux] = y;
75        } else {
76            let tmp = self.group[uy];
77            self.group[ux] += tmp;
78            self.par[uy] = x;
79        }
80        if self.rank[ux] == self.rank[uy] {
81            self.rank[uy] += T::one();
82        }
83        self.num_of_groups -= T::one();
84    }
85
86    pub fn is_same(&mut self, x: T, y: T) -> bool {
87        self.find(x) == self.find(y)
88    }
89
90    pub fn group_size(&mut self, x: T) -> T {
91        let p = self.find(x);
92        self.group[p.to_usize().unwrap()]
93    }
94
95    pub fn rank(&self, x: T) -> T {
96        self.rank[x.to_usize().unwrap()]
97    }
98
99    pub fn num_of_groups(&self) -> T {
100        self.num_of_groups
101    }
102}
103
104#[cfg(test)]
105mod tests {
106    use super::*;
107
108    macro_rules! impl_union_find_tests {
109        ($ty:ty) => {
110            let mut uf = UnionFind::new(5 as $ty);
111
112            // 0 ━━━━━ 1 ━━━━━ 4
113            //
114            // 2 ━━━━━ 3
115            assert_eq!(5, uf.num_of_groups());
116
117            uf.union(0, 1);
118            assert_eq!(4, uf.num_of_groups());
119
120            uf.union(2, 3);
121            assert_eq!(3, uf.num_of_groups());
122
123            uf.union(1, 4);
124            assert_eq!(2, uf.num_of_groups());
125
126            assert_eq!(uf.find(0), uf.find(1));
127            assert_ne!(uf.find(0), uf.find(2));
128            assert_ne!(uf.find(0), uf.find(3));
129            assert_eq!(uf.find(0), uf.find(4));
130            assert_ne!(uf.find(1), uf.find(2));
131            assert_ne!(uf.find(1), uf.find(3));
132            assert_eq!(uf.find(1), uf.find(4));
133            assert_eq!(uf.find(2), uf.find(3));
134            assert_ne!(uf.find(2), uf.find(4));
135            assert_ne!(uf.find(3), uf.find(4));
136
137            assert!(uf.is_same(0, 1));
138            assert!(!uf.is_same(0, 2));
139            assert!(!uf.is_same(0, 3));
140            assert!(uf.is_same(0, 4));
141
142            assert_eq!(uf.rank(0), 0);
143            assert_eq!(uf.rank(1), 1);
144            assert_eq!(uf.rank(2), 0);
145            assert_eq!(uf.rank(3), 1);
146            assert_eq!(uf.rank(4), 1);
147
148            assert_eq!(uf.group_size(0), 3);
149            assert_eq!(uf.group_size(1), 3);
150            assert_eq!(uf.group_size(2), 2);
151            assert_eq!(uf.group_size(3), 2);
152            assert_eq!(uf.group_size(4), 3);
153
154            uf.union(0, 2);
155            assert_eq!(1, uf.num_of_groups());
156        };
157    }
158
159    #[test]
160    fn u64_union_find_test() {
161        impl_union_find_tests!(u64);
162    }
163    #[test]
164    fn u32_union_find_test() {
165        impl_union_find_tests!(u32);
166    }
167    #[test]
168    fn usize_union_find_test() {
169        impl_union_find_tests!(usize);
170    }
171}