Skip to main content

codemem_engine/consolidation/
union_find.rs

1use std::collections::HashMap;
2
/// Union-Find (disjoint set) data structure for transitive clustering.
///
/// Elements are the indices `0..n` given to [`UnionFind::new`]. Sets are
/// merged with union by rank, and [`UnionFind::find`] compresses the paths
/// it walks.
#[derive(Debug, Clone)]
pub struct UnionFind {
    /// `parent[i]` is the parent of element `i`; a root is its own parent.
    parent: Vec<usize>,
    /// Rank of each root, used to decide which root survives a merge.
    rank: Vec<usize>,
}

impl UnionFind {
    /// Creates `n` singleton sets, one per index in `0..n`.
    pub fn new(n: usize) -> Self {
        Self {
            parent: (0..n).collect(),
            rank: vec![0; n],
        }
    }

    /// Returns the representative (root) of the set containing `x`.
    ///
    /// Every element visited on the way to the root is re-pointed directly
    /// at the root (path compression).
    ///
    /// # Panics
    ///
    /// Panics if `x` is not a valid element index.
    pub fn find(&mut self, x: usize) -> usize {
        // First pass: walk up to the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }

        // Second pass: re-point every node on the path at the root.
        let mut node = x;
        while self.parent[node] != root {
            let next = self.parent[node];
            self.parent[node] = root;
            node = next;
        }

        root
    }

    /// Merges the sets containing `x` and `y`; a no-op if they already share
    /// a root.
    ///
    /// # Panics
    ///
    /// Panics if `x` or `y` is not a valid element index.
    pub fn union(&mut self, x: usize, y: usize) {
        let rx = self.find(x);
        let ry = self.find(y);
        if rx == ry {
            return;
        }
        // Attach the lower-ranked root beneath the higher-ranked one; on a
        // tie `rx` becomes the root and its rank grows by one.
        match self.rank[rx].cmp(&self.rank[ry]) {
            std::cmp::Ordering::Less => self.parent[rx] = ry,
            std::cmp::Ordering::Greater => self.parent[ry] = rx,
            std::cmp::Ordering::Equal => {
                self.parent[ry] = rx;
                self.rank[rx] += 1;
            }
        }
    }

    /// Partitions the indices `0..n` into their current sets.
    ///
    /// The result is deterministic: groups appear in order of their smallest
    /// member, and the members of each group are in ascending order.
    ///
    /// # Panics
    ///
    /// Panics if `n` exceeds the number of elements the structure was
    /// created with.
    pub fn groups(&mut self, n: usize) -> Vec<Vec<usize>> {
        // The map only translates a root into its position in `clusters`;
        // iterating the map itself would yield an arbitrary order.
        let mut slot_of_root: HashMap<usize, usize> = HashMap::new();
        let mut clusters: Vec<Vec<usize>> = Vec::new();
        for i in 0..n {
            let root = self.find(i);
            let slot = *slot_of_root.entry(root).or_insert_with(|| {
                clusters.push(Vec::new());
                clusters.len() - 1
            });
            clusters[slot].push(i);
        }
        clusters
    }
}