Skip to main content

oxicuda_graphalg/mst/
union_find.rs

//! Disjoint-set union (union-find) with path compression and union by rank.

/// A disjoint-set forest over the elements `0..len()`.
///
/// Every element stores a parent link; an element that is its own parent is
/// the root (representative) of its set.
#[derive(Clone, Debug)]
pub struct UnionFind {
    /// Parent link of each element; roots point at themselves.
    parent: Vec<usize>,
    /// Rank of each root. When two sets merge, the lower-ranked root is hung
    /// under the higher-ranked one; ties bump the surviving root's rank.
    rank: Vec<u32>,
    /// Element count of each root's set. Entries for non-root elements are
    /// stale and must not be read.
    size: Vec<usize>,
    /// Current number of disjoint sets.
    num_sets: usize,
}

11impl UnionFind {
12    pub fn new(n: usize) -> Self {
13        Self {
14            parent: (0..n).collect(),
15            rank: vec![0; n],
16            size: vec![1; n],
17            num_sets: n,
18        }
19    }
20
21    pub fn len(&self) -> usize {
22        self.parent.len()
23    }
24
25    pub fn is_empty(&self) -> bool {
26        self.parent.is_empty()
27    }
28
29    pub fn find(&mut self, x: usize) -> usize {
30        let mut r = x;
31        while self.parent[r] != r {
32            r = self.parent[r];
33        }
34        // Path compression
35        let mut cur = x;
36        while self.parent[cur] != r {
37            let next = self.parent[cur];
38            self.parent[cur] = r;
39            cur = next;
40        }
41        r
42    }
43
44    /// Union two sets. Returns true if they were merged.
45    pub fn union(&mut self, a: usize, b: usize) -> bool {
46        let ra = self.find(a);
47        let rb = self.find(b);
48        if ra == rb {
49            return false;
50        }
51        match self.rank[ra].cmp(&self.rank[rb]) {
52            std::cmp::Ordering::Less => {
53                self.parent[ra] = rb;
54                self.size[rb] += self.size[ra];
55            }
56            std::cmp::Ordering::Greater => {
57                self.parent[rb] = ra;
58                self.size[ra] += self.size[rb];
59            }
60            std::cmp::Ordering::Equal => {
61                self.parent[rb] = ra;
62                self.rank[ra] += 1;
63                self.size[ra] += self.size[rb];
64            }
65        }
66        self.num_sets -= 1;
67        true
68    }
69
70    pub fn set_size(&mut self, x: usize) -> usize {
71        let r = self.find(x);
72        self.size[r]
73    }
74
75    pub fn num_sets(&self) -> usize {
76        self.num_sets
77    }
78}
79
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn uf_basic() {
        let mut uf = UnionFind::new(5);
        assert_eq!(uf.num_sets(), 5);
        assert!(uf.union(0, 1));
        assert_eq!(uf.num_sets(), 4);
        assert!(!uf.union(0, 1));
        assert_eq!(uf.find(0), uf.find(1));
    }

    #[test]
    fn uf_set_size() {
        let mut uf = UnionFind::new(4);
        uf.union(0, 1);
        uf.union(2, 3);
        uf.union(1, 2);
        assert_eq!(uf.set_size(0), 4);
    }

    // A zero-element structure is empty and has no sets.
    #[test]
    fn uf_empty() {
        let uf = UnionFind::new(0);
        assert!(uf.is_empty());
        assert_eq!(uf.len(), 0);
        assert_eq!(uf.num_sets(), 0);
    }

    // Exercises the unequal-rank arms of `union`, which the tests above never
    // reach because every merge there is between equal ranks.
    #[test]
    fn uf_unequal_ranks() {
        let mut uf = UnionFind::new(6);
        // {0, 1} with root 0 at rank 1.
        assert!(uf.union(0, 1));
        // Rank-0 root passed first: lower rank is attached under the other root.
        assert!(uf.union(2, 0));
        // Rank-0 root passed second: same outcome from the other direction.
        assert!(uf.union(0, 3));
        assert_eq!(uf.num_sets(), 3);
        assert_eq!(uf.set_size(3), 4);
        // The higher-ranked root survives and its rank is unchanged.
        assert_eq!(uf.find(2), 0);
        assert_eq!(uf.rank[0], 1);
        let root = uf.find(0);
        for x in 1..4 {
            assert_eq!(uf.find(x), root);
        }
        assert_ne!(uf.find(4), root);
        assert_eq!(uf.set_size(4), 1);
        assert_eq!(uf.len(), 6);
        assert!(!uf.is_empty());
    }

    // `find` re-links every node on the visited path directly to the root.
    #[test]
    fn uf_path_compression() {
        let mut uf = UnionFind::new(4);
        uf.union(0, 1);
        uf.union(2, 3);
        uf.union(0, 2);
        // Node 3 is two links away from the root before compression.
        assert_eq!(uf.parent[3], 2);
        assert_eq!(uf.find(3), 0);
        assert_eq!(uf.parent[3], 0);
        assert_eq!(uf.num_sets(), 1);
    }

    // Out-of-range element indices are a caller bug and panic.
    #[test]
    #[should_panic]
    fn uf_find_out_of_range_panics() {
        UnionFind::new(3).find(3);
    }
}