Skip to main content

oxihuman_core/
disjoint_set.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan) / SPDX-License-Identifier: Apache-2.0
2#![allow(dead_code)]
3
4//! Union-Find (disjoint set) with path compression and union by rank.
5
/// A disjoint set forest with path compression (path halving) and union by rank.
///
/// Elements are the indices `0..n` fixed at construction. Every set is
/// identified by a root element; [`DisjointSet::find`] returns that root.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct DisjointSet {
    /// `parent[i]` is the parent of `i`; a root points to itself.
    parent: Vec<usize>,
    /// Rank (upper bound on tree height); only meaningful at roots.
    rank: Vec<u32>,
    /// `size[r]` is the number of elements in the set rooted at `r`;
    /// only meaningful at roots. Kept so `set_size` is O(α(n)) instead of O(n).
    size: Vec<usize>,
    /// Number of distinct sets.
    count: usize,
}

#[allow(dead_code)]
impl DisjointSet {
    /// Creates `n` singleton sets `{0}, {1}, ..., {n - 1}`.
    pub fn new(n: usize) -> Self {
        Self {
            parent: (0..n).collect(),
            rank: vec![0; n],
            size: vec![1; n],
            count: n,
        }
    }

    /// Returns the root (representative) of the set containing `x`.
    ///
    /// Uses path halving: each visited node is re-pointed at its grandparent,
    /// which flattens the tree for later queries.
    ///
    /// # Panics
    /// Panics if `x >= self.element_count()`.
    pub fn find(&mut self, mut x: usize) -> usize {
        while self.parent[x] != x {
            self.parent[x] = self.parent[self.parent[x]]; // path halving
            x = self.parent[x];
        }
        x
    }

    /// Merges the sets containing `a` and `b`.
    ///
    /// Returns `true` if two distinct sets were merged, `false` if `a` and `b`
    /// were already in the same set.
    ///
    /// # Panics
    /// Panics if `a` or `b` is `>= self.element_count()`.
    pub fn union(&mut self, a: usize, b: usize) -> bool {
        let ra = self.find(a);
        let rb = self.find(b);
        if ra == rb {
            return false;
        }
        // Hang the lower-rank tree under the higher-rank one to keep trees shallow;
        // on a tie `ra` becomes the root and its rank grows by one.
        let (root, child) = match self.rank[ra].cmp(&self.rank[rb]) {
            std::cmp::Ordering::Less => (rb, ra),
            std::cmp::Ordering::Greater => (ra, rb),
            std::cmp::Ordering::Equal => {
                self.rank[ra] += 1;
                (ra, rb)
            }
        };
        self.parent[child] = root;
        self.size[root] += self.size[child];
        self.count -= 1;
        true
    }

    /// Returns `true` if `a` and `b` belong to the same set.
    ///
    /// # Panics
    /// Panics if `a` or `b` is `>= self.element_count()`.
    pub fn connected(&mut self, a: usize, b: usize) -> bool {
        self.find(a) == self.find(b)
    }

    /// Number of distinct sets currently in the forest.
    pub fn set_count(&self) -> usize {
        self.count
    }

    /// Total number of elements (the `n` passed to [`DisjointSet::new`]).
    pub fn element_count(&self) -> usize {
        self.parent.len()
    }

    /// Size of the set containing `x`.
    ///
    /// Runs in near-constant amortized time using the sizes tracked at roots.
    ///
    /// # Panics
    /// Panics if `x >= self.element_count()`.
    pub fn set_size(&mut self, x: usize) -> usize {
        let root = self.find(x);
        self.size[root]
    }

    /// Returns all roots (one representative per set), in ascending order.
    pub fn roots(&mut self) -> Vec<usize> {
        // A node is a root exactly when it is its own parent; there is one per set.
        let mut roots = Vec::with_capacity(self.count);
        roots.extend((0..self.parent.len()).filter(|&i| self.parent[i] == i));
        roots
    }
}
88
89#[allow(dead_code)]
90pub fn new_disjoint_set(n: usize) -> DisjointSet {
91    DisjointSet::new(n)
92}
93
94#[allow(dead_code)]
95pub fn ds_find(ds: &mut DisjointSet, x: usize) -> usize {
96    ds.find(x)
97}
98
99#[allow(dead_code)]
100pub fn ds_union(ds: &mut DisjointSet, x: usize, y: usize) -> bool {
101    ds.union(x, y)
102}
103
104#[allow(dead_code)]
105pub fn ds_connected(ds: &mut DisjointSet, x: usize, y: usize) -> bool {
106    ds.connected(x, y)
107}
108
109#[allow(dead_code)]
110pub fn ds_same(ds: &mut DisjointSet, x: usize, y: usize) -> bool {
111    ds.connected(x, y)
112}
113
114#[allow(dead_code)]
115pub fn ds_component_count(ds: &DisjointSet) -> usize {
116    ds.set_count()
117}
118
119#[allow(dead_code)]
120pub fn ds_size(ds: &DisjointSet) -> usize {
121    ds.element_count()
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_initial_separate() {
        let mut ds = DisjointSet::new(5);
        assert!(!ds.connected(0, 1));
        assert_eq!(ds.set_count(), 5);
    }

    #[test]
    fn test_union() {
        let mut ds = DisjointSet::new(5);
        assert!(ds.union(0, 1));
        assert!(ds.connected(0, 1));
        assert_eq!(ds.set_count(), 4);
    }

    #[test]
    fn test_transitive() {
        let mut ds = DisjointSet::new(5);
        ds.union(0, 1);
        ds.union(1, 2);
        assert!(ds.connected(0, 2));
    }

    #[test]
    fn test_no_dup_union() {
        let mut ds = DisjointSet::new(3);
        assert!(ds.union(0, 1));
        assert!(!ds.union(0, 1)); // already same set
        assert_eq!(ds.set_count(), 2); // a failed union must not decrement the count
    }

    #[test]
    fn test_set_size() {
        let mut ds = DisjointSet::new(5);
        ds.union(0, 1);
        ds.union(0, 2);
        assert_eq!(ds.set_size(0), 3);
        assert_eq!(ds.set_size(3), 1);
    }

    #[test]
    fn test_set_size_after_merging_groups() {
        let mut ds = DisjointSet::new(6);
        ds.union(0, 1);
        ds.union(2, 3);
        ds.union(1, 3);
        for i in 0..4 {
            assert_eq!(ds.set_size(i), 4);
        }
        assert_eq!(ds.set_size(5), 1);
    }

    #[test]
    fn test_roots() {
        let mut ds = DisjointSet::new(4);
        ds.union(0, 1);
        ds.union(2, 3);
        let roots = ds.roots();
        assert_eq!(roots.len(), 2);
    }

    #[test]
    fn test_roots_are_own_representatives() {
        let mut ds = DisjointSet::new(6);
        ds.union(0, 1);
        ds.union(2, 3);
        let roots = ds.roots();
        assert_eq!(roots.len(), ds.set_count());
        for r in roots {
            assert_eq!(ds.find(r), r);
        }
    }

    #[test]
    fn test_all_union() {
        let mut ds = DisjointSet::new(4);
        ds.union(0, 1);
        ds.union(2, 3);
        ds.union(0, 2);
        assert_eq!(ds.set_count(), 1);
    }

    #[test]
    fn test_element_count() {
        let ds = DisjointSet::new(10);
        assert_eq!(ds.element_count(), 10);
    }

    #[test]
    fn test_empty() {
        let mut ds = DisjointSet::new(0);
        assert_eq!(ds.set_count(), 0);
        assert_eq!(ds.element_count(), 0);
        assert!(ds.roots().is_empty());
    }

    #[test]
    fn test_find_self() {
        let mut ds = DisjointSet::new(3);
        assert_eq!(ds.find(2), 2);
    }

    #[test]
    fn test_large_chain() {
        let mut ds = DisjointSet::new(100);
        for i in 0..99 {
            ds.union(i, i + 1);
        }
        assert_eq!(ds.set_count(), 1);
        assert!(ds.connected(0, 99));
        assert_eq!(ds.set_size(50), 100);
    }

    #[test]
    fn test_free_function_wrappers() {
        let mut ds = new_disjoint_set(4);
        assert!(ds_union(&mut ds, 0, 1));
        assert!(ds_connected(&mut ds, 0, 1));
        assert!(ds_same(&mut ds, 1, 0));
        assert!(!ds_same(&mut ds, 0, 2));
        let r0 = ds_find(&mut ds, 0);
        let r1 = ds_find(&mut ds, 1);
        assert_eq!(r0, r1);
        assert_eq!(ds_component_count(&ds), 3);
        assert_eq!(ds_size(&ds), 4);
    }
}
206}