use portable_atomic::{AtomicU128, Ordering};
/// Concurrent disjoint-set (union-find) forest over the elements `0..size`.
///
/// The structure is mutated only through atomic operations (no locks). Each
/// element's parent and rank share one 128-bit atomic word, so both can be
/// checked and replaced together by a single compare-and-swap.
pub struct DisjointSets {
    // One entry per element: low 64 bits = parent index, high 64 bits = rank.
    // An element whose parent is itself is the root of its set.
    data: Vec<AtomicU128>,
}
/// Selects the low 64 bits of an entry, which hold the parent index.
const PARENT_MASK: u128 = (1u128 << 64) - 1;
/// Selects the high 64 bits of an entry, which hold the rank.
const RANK_MASK: u128 = !PARENT_MASK;
impl DisjointSets {
    /// Creates `size` singleton sets: element `i` starts as its own root with rank 0.
    ///
    /// # Panics
    ///
    /// Panics if `AtomicU128` is not lock-free on the current target.
    pub fn new(size: usize) -> Self {
        assert!(
            AtomicU128::is_lock_free(),
            "AtomicU128 must be lock-free for dset64 performance!"
        );
        // Encoded entry `i` = rank 0 (high half) | parent `i` (low half).
        let data: Vec<AtomicU128> = (0..size).map(|i| AtomicU128::new(i as u128)).collect();
        DisjointSets { data }
    }

    /// Returns the root (representative) of the set containing `id`.
    ///
    /// Performs path halving while walking up: each visited node is re-pointed
    /// at its grandparent with a weak compare-and-swap, and the walk continues
    /// from that grandparent. A failed CAS (contention or a spurious failure)
    /// is harmless; that link simply stays unshortened.
    ///
    /// # Panics
    ///
    /// Panics if `id >= self.size()`.
    #[inline(always)]
    pub fn find(&self, mut id: usize) -> usize {
        loop {
            // A single load per step: the root test and the CAS's expected
            // value are taken from the same snapshot of the entry.
            let value = self.data[id].load(Ordering::Relaxed);
            let parent_id = (value & PARENT_MASK) as usize;
            if parent_id == id {
                return id;
            }
            let grandparent = self.parent(parent_id);
            // Keep the rank bits; replace the parent with the grandparent.
            let new_value = (value & RANK_MASK) | (grandparent as u128);
            if value != new_value {
                // Best effort: failure only means the path is not shortened here.
                let _ = self.data[id].compare_exchange_weak(
                    value,
                    new_value,
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                );
            }
            id = grandparent;
        }
    }

    /// Returns `true` if `id1` and `id2` belong to the same set.
    ///
    /// The two `find` calls are not atomic with respect to each other. After
    /// seeing two different roots, the method therefore re-checks that `id1`'s
    /// root is still a root. If it has since been linked under another node,
    /// the lookup is repeated.
    ///
    /// # Panics
    ///
    /// Panics if either id is `>= self.size()`.
    #[inline]
    pub fn same(&self, mut id1: usize, mut id2: usize) -> bool {
        loop {
            id1 = self.find(id1);
            id2 = self.find(id2);
            if id1 == id2 {
                return true;
            }
            if self.parent(id1) == id1 {
                return false;
            }
        }
    }

    /// Merges the sets containing `id1` and `id2`.
    ///
    /// Returns the root under which the other root was linked. If the ids were
    /// already in the same set, it returns their common root. Concurrent calls
    /// may link the returned node further afterwards.
    ///
    /// Uses union by rank: the lower-rank root is linked under the other. On a
    /// rank tie, the root with the larger index is linked under the smaller one,
    /// and the surviving root's rank is bumped on a best-effort basis.
    ///
    /// # Panics
    ///
    /// Panics if either id is `>= self.size()`.
    #[inline(always)]
    pub fn unite(&self, mut id1: usize, mut id2: usize) -> usize {
        loop {
            id1 = self.find(id1);
            id2 = self.find(id2);
            if id1 == id2 {
                return id1;
            }
            let mut r1 = self.rank(id1);
            let mut r2 = self.rank(id2);
            // Normalize so that `id1` is the root that gets linked under `id2`.
            if r1 > r2 || (r1 == r2 && id1 < id2) {
                std::mem::swap(&mut r1, &mut r2);
                std::mem::swap(&mut id1, &mut id2);
            }
            // Link only if `id1` is still a root with the rank we observed;
            // otherwise a concurrent operation got there first, so retry.
            let old_entry = ((r1 as u128) << 64) | (id1 as u128);
            let new_entry = ((r1 as u128) << 64) | (id2 as u128);
            if self.data[id1]
                .compare_exchange(old_entry, new_entry, Ordering::SeqCst, Ordering::SeqCst)
                .is_err()
            {
                continue;
            }
            if r1 == r2 {
                let old_entry = ((r2 as u128) << 64) | (id2 as u128);
                let new_entry = (((r2 + 1) as u128) << 64) | (id2 as u128);
                // Best effort: if this fails, the rank simply stays lower.
                let _ = self.data[id2].compare_exchange_weak(
                    old_entry,
                    new_entry,
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                );
            }
            return id2;
        }
    }

    /// Number of elements the structure was created with.
    #[inline]
    pub fn size(&self) -> usize {
        self.data.len()
    }

    /// Rank of `id`, stored in the high 64 bits of its entry.
    #[inline(always)]
    fn rank(&self, id: usize) -> u64 {
        // After the shift only 64 significant bits remain, so the cast is lossless.
        (self.data[id].load(Ordering::Relaxed) >> 64) as u64
    }

    /// Parent index of `id`, stored in the low 64 bits of its entry.
    #[inline(always)]
    fn parent(&self, id: usize) -> usize {
        (self.data[id].load(Ordering::Relaxed) & PARENT_MASK) as usize
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn test_basic() {
        let dsets = DisjointSets::new(10);
        assert_eq!(dsets.find(0), 0);
        assert_eq!(dsets.find(5), 5);
        assert!(!dsets.same(0, 5));
        dsets.unite(0, 5);
        assert!(dsets.same(0, 5));
        dsets.unite(5, 7);
        assert!(dsets.same(0, 7));
        assert!(dsets.same(5, 7));
        assert!(!dsets.same(0, 3));
    }

    /// A chain of pairwise unions ends up with every element sharing one root.
    #[test]
    fn test_chain_of_unions() {
        let dsets = DisjointSets::new(5);
        dsets.unite(0, 1);
        dsets.unite(1, 2);
        dsets.unite(2, 3);
        dsets.unite(3, 4);
        let root = dsets.find(0);
        assert_eq!(dsets.find(1), root);
        assert_eq!(dsets.find(2), root);
        assert_eq!(dsets.find(3), root);
        assert_eq!(dsets.find(4), root);
    }

    /// `find` on a node two links below the root re-points it at the root.
    #[test]
    fn test_path_compression() {
        let dsets = DisjointSets::new(4);
        dsets.unite(0, 1);
        dsets.unite(2, 3);
        dsets.unite(0, 2);
        // Linking one two-node tree under another leaves a node at depth 2.
        // Which node that is depends on rank bumps (done with a weak CAS),
        // so locate it instead of hard-coding the shape.
        let deep = (0..4)
            .find(|&x| {
                let p = dsets.parent(x);
                p != x && dsets.parent(p) != p
            })
            .expect("merging two two-node trees must leave a node at depth 2");
        let mut root = deep;
        while dsets.parent(root) != root {
            root = dsets.parent(root);
        }
        // `find` shortens links with a weak CAS that may fail spuriously on
        // some targets, so allow a bounded number of attempts.
        for _ in 0..1_000 {
            assert_eq!(dsets.find(deep), root);
            if dsets.parent(deep) == root {
                break;
            }
        }
        assert_eq!(dsets.parent(deep), root, "find did not compress the path");
        for x in 0..4 {
            assert_eq!(dsets.find(x), root);
        }
    }

    /// Unions issued from several threads produce the expected partition.
    #[test]
    fn test_concurrent_unions() {
        const N: usize = 1_000;
        const THREADS: usize = 4;
        let dsets = Arc::new(DisjointSets::new(N));
        let handles: Vec<_> = (0..THREADS)
            .map(|t| {
                let dsets = Arc::clone(&dsets);
                thread::spawn(move || {
                    // Thread `t` links i with i + 2 for every i congruent to t
                    // modulo THREADS; together the threads join all even
                    // elements and all odd elements.
                    for i in (t..N - 2).step_by(THREADS) {
                        dsets.unite(i, i + 2);
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().expect("worker thread panicked");
        }
        for i in 0..N {
            assert!(dsets.same(i, i % 2), "element {} not joined with {}", i, i % 2);
        }
        assert!(!dsets.same(0, 1));
    }
}