use crate::Id;
use std::cell::Cell;
use std::fmt::Debug;
/// Disjoint-set (union-find) structure over [`Id`]s.
///
/// Every element has a parent link; an element whose parent is itself is the
/// root (representative) of its set. The links are stored in [`Cell`]s so that
/// `find` can rewrite them while only holding a shared `&self` reference.
#[derive(Debug, Clone, Default)]
pub struct UnionFind {
/// Parent link of each element, indexed by the element's `Id` converted to `usize`.
parents: Vec<Cell<Id>>,
}
impl UnionFind {
    /// Adds a new singleton set and returns its id.
    ///
    /// The new id is the current number of elements converted to an `Id`, and
    /// the element starts out as its own parent (i.e. as a root).
    pub fn make_set(&mut self) -> Id {
        let fresh = Id::from(self.parents.len());
        self.parents.push(Cell::new(fresh));
        fresh
    }

    /// Reads the parent link of `query`.
    ///
    /// # Panics
    ///
    /// Panics if `query` does not index an existing element.
    #[inline(always)]
    fn parent(&self, query: Id) -> Id {
        let slot = &self.parents[usize::from(query)];
        slot.get()
    }

    /// Overwrites the parent link of `query` with `new_parent`.
    ///
    /// Takes `&self` because the link is stored in a `Cell`.
    ///
    /// # Panics
    ///
    /// Panics if `query` does not index an existing element.
    #[inline(always)]
    fn set_parent(&self, query: Id, new_parent: Id) {
        let slot = &self.parents[usize::from(query)];
        slot.set(new_parent);
    }

    /// Returns the root (representative) of the set containing `current`.
    ///
    /// While walking towards the root, each visited element is re-pointed at
    /// its grandparent and the walk continues from that grandparent (path
    /// halving), so later lookups along the same path are shorter.
    ///
    /// # Panics
    ///
    /// Panics if `current` does not index an existing element.
    pub fn find(&self, mut current: Id) -> Id {
        let mut parent = self.parent(current);
        while current != parent {
            // Skip one level: link `current` directly to its grandparent.
            let grandparent = self.parent(parent);
            self.set_parent(current, grandparent);
            current = grandparent;
            parent = self.parent(current);
        }
        // `current` is its own parent here, so it is the root.
        current
    }

    /// Merges the sets containing `set1` and `set2`.
    ///
    /// Returns a pair of roots. If both elements already share a root, that
    /// root is returned twice and no root is re-linked. Otherwise the first
    /// component is the smaller of the two roots, which stays the root of the
    /// merged set, and the second is the larger root, which now points at the
    /// first.
    ///
    /// # Panics
    ///
    /// Panics if `set1` or `set2` does not index an existing element.
    pub fn union(&mut self, set1: Id, set2: Id) -> (Id, Id) {
        let root1 = self.find(set1);
        let root2 = self.find(set2);
        if root1 == root2 {
            return (root1, root2);
        }

        // The smaller root always survives as the leader of the merged set.
        let (leader, absorbed) = if root1 > root2 {
            (root2, root1)
        } else {
            (root1, root2)
        };
        self.set_parent(absorbed, leader);
        (leader, absorbed)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use indexmap::{indexmap, indexset, IndexMap, IndexSet};

    impl UnionFind {
        /// Groups every element under the root that `find` reports for it.
        ///
        /// Elements are visited in index order, so both the map entries and
        /// the members of each set appear in ascending insertion order.
        pub fn build_sets(&self) -> IndexMap<Id, IndexSet<Id>> {
            let mut sets: IndexMap<Id, IndexSet<Id>> = IndexMap::default();
            for member in (0..self.parents.len()).map(Id::from) {
                let leader = self.find(member);
                sets.entry(leader).or_default().insert(member);
            }
            sets
        }
    }

    /// Creates a union-find holding `n` singleton sets.
    fn make_union_find(n: usize) -> UnionFind {
        let mut fresh = UnionFind::default();
        (0..n).for_each(|_| {
            fresh.make_set();
        });
        fresh
    }

    #[test]
    fn union_find() {
        /// Shorthand for turning a plain index into an `Id`.
        fn id(u: usize) -> Id {
            u.into()
        }

        let n = 10;
        let mut uf = make_union_find(n);

        // Every element starts as its own leader; asking twice must not change that.
        for elem in (0..n).map(id) {
            assert_eq!(uf.find(elem), elem);
            assert_eq!(uf.find(elem), elem);
        }

        // Before any union, build_sets yields one singleton set per element.
        let singletons = (0..n)
            .map(|i| (id(i), indexset!(id(i))))
            .collect::<IndexMap<_, _>>();
        assert_eq!(uf.build_sets(), singletons);

        // ((set1, set2), expected return value of `union(set1, set2)`), applied in order.
        let merges = [
            ((0, 1), (0, 1)),
            ((1, 2), (0, 2)),
            ((3, 2), (0, 3)),
            ((6, 7), (6, 7)),
            ((8, 9), (8, 9)),
            ((7, 9), (6, 8)),
            // The last two pairs are already in the same set.
            ((1, 3), (0, 0)),
            ((7, 8), (6, 6)),
        ];
        for &((set1, set2), (first, second)) in merges.iter() {
            assert_eq!(uf.union(id(set1), id(set2)), (id(first), id(second)));
        }

        let expected_sets = indexmap!(
            id(0) => indexset!(id(0), id(1), id(2), id(3)),
            id(4) => indexset!(id(4)),
            id(5) => indexset!(id(5)),
            id(6) => indexset!(id(6), id(7), id(8), id(9)),
        );
        assert_eq!(uf.build_sets(), expected_sets);

        // After the lookups above, every element points straight at its root.
        for elem in (0..n).map(id) {
            assert_eq!(uf.parent(elem), uf.find(elem));
        }
    }
}