// ac_lib/graph/unionfind.rs

/// Disjoint-set forest (union–find) over the elements `0..n`.
///
/// Combines union by rank with path compression, giving near-constant
/// amortized time per operation.
///
/// Every method that takes an element index panics if that index is
/// `>= n`.
#[derive(Debug, Clone)]
pub struct UnionFind {
    /// `parent[i]` is the parent of `i`; a root is its own parent.
    parent: Vec<usize>,
    /// Upper bound on the height of the tree rooted at `i` (meaningful only for roots).
    rank: Vec<usize>,
    /// Number of elements in the set rooted at `i` (meaningful only for roots).
    size: Vec<usize>,
}

impl UnionFind {
    /// Creates `size` singleton sets `{0}, {1}, …, {size - 1}`.
    pub fn new(size: usize) -> Self {
        UnionFind {
            parent: (0..size).collect(),
            rank: vec![0; size],
            size: vec![1; size],
        }
    }

    /// Returns the representative (root) of the set containing `x`.
    ///
    /// Every node on the path from `x` to the root is re-pointed directly at
    /// the root, so later queries are faster. The walk is iterative and does
    /// not depend on call-stack depth.
    ///
    /// # Panics
    /// Panics if `x` is out of bounds.
    pub fn find(&mut self, x: usize) -> usize {
        // First pass: climb to the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        // Second pass: compress the path just walked.
        let mut cur = x;
        while self.parent[cur] != root {
            let next = self.parent[cur];
            self.parent[cur] = root;
            cur = next;
        }
        root
    }

    /// Merges the sets containing `x` and `y`.
    ///
    /// Returns `true` if two distinct sets were merged, or `false` if `x` and
    /// `y` were already in the same set.
    ///
    /// # Panics
    /// Panics if `x` or `y` is out of bounds.
    pub fn union(&mut self, x: usize, y: usize) -> bool {
        let (mut root_x, mut root_y) = (self.find(x), self.find(y));
        if root_x == root_y {
            return false;
        }
        // Hang the lower-rank tree under the higher-rank one; on a tie,
        // `x`'s root stays the representative.
        if self.rank[root_x] < self.rank[root_y] {
            std::mem::swap(&mut root_x, &mut root_y);
        }
        self.parent[root_y] = root_x;
        self.size[root_x] += self.size[root_y];
        // Height only grows when two equally tall trees are joined.
        if self.rank[root_x] == self.rank[root_y] {
            self.rank[root_x] += 1;
        }
        true
    }

    /// Returns `true` if `x` and `y` belong to the same set.
    ///
    /// # Panics
    /// Panics if `x` or `y` is out of bounds.
    pub fn connected(&mut self, x: usize, y: usize) -> bool {
        self.find(x) == self.find(y)
    }

    /// Returns the number of elements in the set containing `x`.
    ///
    /// # Panics
    /// Panics if `x` is out of bounds.
    pub fn size(&mut self, x: usize) -> usize {
        let root = self.find(x);
        self.size[root]
    }

    /// Returns the number of disjoint sets.
    ///
    /// A node is a root exactly when it is its own parent. The count is
    /// therefore a read-only scan: it needs no `find` call and no mutable
    /// borrow.
    pub fn count_groups(&self) -> usize {
        self.parent
            .iter()
            .enumerate()
            .filter(|&(i, &p)| i == p)
            .count()
    }
}