/// Disjoint-set (union–find) forest over the elements `0..n`.
///
/// [`DisjointSet::find`] uses path halving and [`DisjointSet::union`] uses union by
/// rank, which together give amortized near-constant time per operation.
#[derive(Debug, Clone, Default)]
pub struct DisjointSet {
    /// `parent[i]` is the parent of `i`; `i` is a root iff `parent[i] == i`.
    parent: Vec<usize>,
    /// Upper bound on the height of the tree rooted at `i` (only maintained for roots).
    rank: Vec<usize>,
}

impl DisjointSet {
    /// Creates `n` singleton sets, one for each element in `0..n`.
    pub fn new(n: usize) -> Self {
        Self {
            parent: (0..n).collect(),
            rank: vec![0; n],
        }
    }

    /// Returns the representative (root) of the set containing `x`.
    ///
    /// Takes `&mut self` because it shortens the path while walking it: every
    /// visited node is re-pointed at its grandparent (path halving).
    ///
    /// # Panics
    ///
    /// Panics if `x` is not a valid element index.
    pub fn find(&mut self, mut x: usize) -> usize {
        while self.parent[x] != x {
            self.parent[x] = self.parent[self.parent[x]];
            x = self.parent[x];
        }
        x
    }

    /// Merges the sets containing `x` and `y`.
    ///
    /// Returns `true` if two distinct sets were merged, or `false` if `x` and `y`
    /// were already in the same set.
    ///
    /// # Panics
    ///
    /// Panics if `x` or `y` is not a valid element index.
    pub fn union(&mut self, x: usize, y: usize) -> bool {
        let rx = self.find(x);
        let ry = self.find(y);
        if rx == ry {
            return false;
        }
        // Hang the lower-rank root under the higher-rank one. On a tie `rx` becomes
        // the root and its rank grows by one, since the merged tree may be taller.
        let (child, root) = if self.rank[rx] < self.rank[ry] {
            (rx, ry)
        } else {
            (ry, rx)
        };
        self.parent[child] = root;
        if self.rank[rx] == self.rank[ry] {
            self.rank[root] += 1;
        }
        true
    }

    /// Returns `true` if `x` and `y` belong to the same set.
    ///
    /// # Panics
    ///
    /// Panics if `x` or `y` is not a valid element index.
    pub fn connected(&mut self, x: usize, y: usize) -> bool {
        self.find(x) == self.find(y)
    }
}
/// Disjoint-set (union–find) forest over `0..n` that also tracks the number of
/// elements in each set.
///
/// [`SizedDisjointSet::find`] uses path halving and [`SizedDisjointSet::union`]
/// uses union by size.
#[derive(Debug, Clone, Default)]
pub struct SizedDisjointSet {
    /// `parent[i]` is the parent of `i`; `i` is a root iff `parent[i] == i`.
    parent: Vec<usize>,
    /// Number of elements in the set rooted at `i` (only maintained for roots).
    size: Vec<usize>,
}

impl SizedDisjointSet {
    /// Creates `n` singleton sets, one for each element in `0..n`.
    pub fn new(n: usize) -> Self {
        Self {
            parent: (0..n).collect(),
            size: vec![1; n],
        }
    }

    /// Returns the representative (root) of the set containing `x`.
    ///
    /// Takes `&mut self` because every visited node is re-pointed at its
    /// grandparent while walking (path halving).
    ///
    /// # Panics
    ///
    /// Panics if `x` is not a valid element index.
    pub fn find(&mut self, mut x: usize) -> usize {
        while self.parent[x] != x {
            self.parent[x] = self.parent[self.parent[x]];
            x = self.parent[x];
        }
        x
    }

    /// Merges the sets containing `x` and `y`.
    ///
    /// Returns `true` if two distinct sets were merged, or `false` if `x` and `y`
    /// were already in the same set.
    ///
    /// # Panics
    ///
    /// Panics if `x` or `y` is not a valid element index.
    pub fn union(&mut self, x: usize, y: usize) -> bool {
        let rx = self.find(x);
        let ry = self.find(y);
        if rx == ry {
            return false;
        }
        // Hang the smaller tree under the larger one; on a tie `rx` stays the root.
        let (child, root) = if self.size[rx] < self.size[ry] {
            (rx, ry)
        } else {
            (ry, rx)
        };
        self.parent[child] = root;
        self.size[root] += self.size[child];
        true
    }

    /// Returns `true` if `x` and `y` belong to the same set.
    ///
    /// # Panics
    ///
    /// Panics if `x` or `y` is not a valid element index.
    pub fn connected(&mut self, x: usize, y: usize) -> bool {
        self.find(x) == self.find(y)
    }

    /// Returns the number of elements in the set containing `x`.
    ///
    /// # Panics
    ///
    /// Panics if `x` is not a valid element index.
    pub fn component_size(&mut self, x: usize) -> usize {
        let root = self.find(x);
        self.size[root]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_disjoint_set_basic() {
        let mut ds = DisjointSet::new(5);
        assert!(!ds.connected(0, 1));
        assert!(ds.union(0, 1));
        assert!(ds.connected(0, 1));
        // Re-joining elements that already share a set is a no-op.
        assert!(!ds.union(0, 1));
    }

    #[test]
    fn test_disjoint_set_self_union() {
        let mut ds = DisjointSet::new(3);
        assert!(ds.connected(2, 2));
        assert!(!ds.union(2, 2));
        assert!(!ds.connected(2, 0));
    }

    #[test]
    fn test_disjoint_set_chain() {
        let mut ds = DisjointSet::new(5);
        for i in 0..4 {
            assert!(ds.union(i, i + 1));
        }
        for i in 0..5 {
            for j in 0..5 {
                assert!(ds.connected(i, j));
            }
        }
    }

    #[test]
    fn test_disjoint_set_two_components() {
        let mut ds = DisjointSet::new(6);
        assert!(ds.union(0, 1));
        assert!(ds.union(1, 2));
        assert!(ds.union(3, 4));
        assert!(ds.union(4, 5));
        assert!(ds.connected(0, 2));
        assert!(ds.connected(3, 5));
        assert!(!ds.connected(0, 3));
    }

    #[test]
    fn test_sized_disjoint_set() {
        let mut ds = SizedDisjointSet::new(5);
        assert_eq!(ds.component_size(0), 1);
        assert!(ds.union(0, 1));
        assert_eq!(ds.component_size(0), 2);
        assert_eq!(ds.component_size(1), 2);
        assert!(ds.union(0, 2));
        assert_eq!(ds.component_size(2), 3);
    }

    #[test]
    fn test_sized_disjoint_set_full_merge() {
        let mut ds = SizedDisjointSet::new(4);
        assert!(ds.union(0, 1));
        assert!(ds.union(2, 3));
        assert!(ds.union(0, 3));
        assert_eq!(ds.component_size(0), 4);
        assert_eq!(ds.component_size(3), 4);
    }

    #[test]
    fn test_sized_disjoint_set_redundant_union_keeps_size() {
        let mut ds = SizedDisjointSet::new(3);
        assert!(ds.union(0, 1));
        // A redundant union must not double-count the component.
        assert!(!ds.union(1, 0));
        assert!(!ds.union(1, 1));
        assert_eq!(ds.component_size(0), 2);
        assert_eq!(ds.component_size(2), 1);
    }
}