1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
/// Disjoint-set (union-find) structure over the elements `0..n`.
///
/// Sets are stored as a forest of parent links. [`UnionFind::union`] merges
/// two sets using the rank heuristic and [`UnionFind::find`] compresses the
/// path it walks, so trees stay shallow.
///
/// # Examples
///
/// ```
/// use competitive_hpp::union_find::UnionFind;
///
/// let mut uf = UnionFind::new(5);
/// uf.union(0, 1);
/// uf.union(2, 3);
/// uf.union(1, 4);
///
/// assert_eq!(uf.find(1), 0);
/// assert!(uf.is_same(0, 1));
/// assert_eq!(uf.group_size(0), 3);
/// ```
#[derive(Clone, Debug)]
pub struct UnionFind {
    /// `par[i]` is the parent of `i`; a root is its own parent.
    par: Vec<usize>,
    /// Rank used by `union` to decide which root survives a merge.
    /// Only the value stored at a root is consulted.
    rank: Vec<usize>,
    /// `group[r]` is the number of elements in the set whose root is `r`.
    /// Entries at non-root indices are stale.
    group: Vec<usize>,
}

impl UnionFind {
    /// Creates `n` singleton sets, one for each element in `0..n`.
    pub fn new(n: usize) -> Self {
        UnionFind {
            par: (0..n).collect(),
            rank: vec![0; n],
            group: vec![1; n],
        }
    }

    /// Returns the root (representative) of the set containing `x`.
    ///
    /// Every element visited on the way up is re-pointed directly at the
    /// root, which is why this takes `&mut self`.
    ///
    /// # Panics
    ///
    /// Panics if `x` is not in `0..n`.
    pub fn find(&mut self, x: usize) -> usize {
        // Pass 1: walk up to the root. Done iteratively so that a deep tree
        // cannot exhaust the call stack.
        let mut root = x;
        while self.par[root] != root {
            root = self.par[root];
        }

        // Pass 2: path compression - hang every node on the walked path
        // directly under the root.
        let mut cur = x;
        while self.par[cur] != root {
            let next = self.par[cur];
            self.par[cur] = root;
            cur = next;
        }

        root
    }

    /// Merges the sets containing `x` and `y`. Does nothing if they are
    /// already in the same set.
    ///
    /// The root with the smaller rank is attached under the root with the
    /// larger rank; on a tie the root of `x` survives and its rank grows by
    /// one.
    ///
    /// # Panics
    ///
    /// Panics if `x` or `y` is not in `0..n`.
    pub fn union(&mut self, x: usize, y: usize) {
        let x = self.find(x);
        let y = self.find(y);
        if x == y {
            return;
        }

        let (root, child) = if self.rank[x] < self.rank[y] {
            (y, x)
        } else {
            (x, y)
        };

        // Only a merge of equal-rank trees makes the result taller, and the
        // increase belongs to the surviving root (not to the absorbed one).
        if self.rank[root] == self.rank[child] {
            self.rank[root] += 1;
        }
        self.group[root] += self.group[child];
        self.par[child] = root;
    }

    /// Returns `true` if `x` and `y` belong to the same set.
    ///
    /// # Panics
    ///
    /// Panics if `x` or `y` is not in `0..n`.
    pub fn is_same(&mut self, x: usize, y: usize) -> bool {
        self.find(x) == self.find(y)
    }

    /// Returns the number of elements in the set containing `x`.
    ///
    /// # Panics
    ///
    /// Panics if `x` is not in `0..n`.
    pub fn group_size(&mut self, x: usize) -> usize {
        let root = self.find(x);
        self.group[root]
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn union_find_test() {
        let mut uf = UnionFind::new(5);

        // 0 ━━━━━ 1 ━━━━━ 4
        //
        // 2 ━━━━━ 3

        uf.union(0, 1);
        uf.union(2, 3);
        uf.union(1, 4);

        assert_eq!(uf.find(0), uf.find(1));
        assert_ne!(uf.find(0), uf.find(2));
        assert_ne!(uf.find(0), uf.find(3));
        assert_eq!(uf.find(0), uf.find(4));
        assert_ne!(uf.find(1), uf.find(2));
        assert_ne!(uf.find(1), uf.find(3));
        assert_eq!(uf.find(1), uf.find(4));
        assert_eq!(uf.find(2), uf.find(3));
        assert_ne!(uf.find(2), uf.find(4));
        assert_ne!(uf.find(3), uf.find(4));

        assert!(uf.is_same(0, 1));
        assert!(!uf.is_same(0, 2));
        assert!(!uf.is_same(0, 3));
        assert!(uf.is_same(0, 4));

        // Roots 0 and 2 each won one equal-rank merge; absorbed elements
        // keep their initial rank.
        assert_eq!(uf.rank[0], 1);
        assert_eq!(uf.rank[1], 0);
        assert_eq!(uf.rank[2], 1);
        assert_eq!(uf.rank[3], 0);
        assert_eq!(uf.rank[4], 0);

        assert_eq!(uf.group_size(0), 3);
        assert_eq!(uf.group_size(1), 3);
        assert_eq!(uf.group_size(2), 2);
        assert_eq!(uf.group_size(3), 2);
        assert_eq!(uf.group_size(4), 3);
    }

    #[test]
    fn lower_rank_root_is_attached_under_higher_rank_root() {
        let mut uf = UnionFind::new(3);
        uf.union(0, 1); // root 0, rank 1
        uf.union(2, 0); // singleton 2 (rank 0) must go under 0

        assert_eq!(uf.find(2), 0);
        assert_eq!(uf.rank[0], 1);
        assert_eq!(uf.group_size(2), 3);
    }
}