use std::{mem::ManuallyDrop, sync::atomic::Ordering};
use rayon::prelude::*;
use crate::prelude::*;
/// Concurrent disjoint-set (union-find) structure over the node ids `0..len`.
///
/// Slot `i` atomically stores the current parent of node `i`; a node whose
/// slot holds its own id is the root (representative) of its set.
pub struct DisjointSetStruct<NI>(Box<[Atomic<NI>]>)
where
    NI: Idx;

// SAFETY: every access to the shared slots goes through atomic `load` /
// `compare_exchange_weak` calls; the buffer is only handled non-atomically
// once the structure is owned by value.
// NOTE(review): this relies on `Atomic<NI>` itself being safe to use from
// several threads; these impls may be redundant if `Atomic<NI>` is already
// `Send + Sync` for every `NI: Idx` — confirm.
unsafe impl<NI> Sync for DisjointSetStruct<NI> where NI: Idx {}
unsafe impl<NI> Send for DisjointSetStruct<NI> where NI: Idx {}
impl<NI: Idx> UnionFind<NI> for DisjointSetStruct<NI> {
    /// Merges the sets containing `id1` and `id2`.
    ///
    /// Each attempt locates both roots and re-points the root with the larger
    /// id at the root with the smaller id. The compare-and-swap only succeeds
    /// while that larger id is still a root. On failure (another thread moved
    /// it, or the weak CAS failed spuriously) the roots are looked up again,
    /// starting from the roots seen in the failed attempt.
    fn union(&self, mut id1: NI, mut id2: NI) {
        loop {
            let root1 = self.find(id1);
            let root2 = self.find(id2);
            if root1 == root2 {
                // Already in the same set: nothing to link.
                return;
            }
            // Hang the larger root under the smaller one, so every stored
            // parent id is <= the id of the node holding it.
            let (child, new_parent) = if root1 < root2 {
                (root2, root1)
            } else {
                (root1, root2)
            };
            if self.update_parent(child, child, new_parent).is_ok() {
                return;
            }
            id1 = child;
            id2 = new_parent;
        }
    }

    /// Returns the root reached by following parent links from `id`.
    ///
    /// Along the way, every visited node is pointed at its grandparent (path
    /// splitting). These shortcuts are best-effort: a failed (possibly
    /// spurious) CAS is ignored and the walk simply continues. Under
    /// concurrent `union` calls the result is only a snapshot.
    fn find(&self, id: NI) -> NI {
        let mut node = id;
        let mut up = self.parent(node);
        while node != up {
            let up_up = self.parent(up);
            let _ = self.update_parent(node, up, up_up);
            node = up;
            up = up_up;
        }
        node
    }

    /// Number of nodes tracked by the structure.
    fn len(&self) -> usize {
        self.0.len()
    }

    /// Runs `find` on every node in parallel to shorten parent chains.
    ///
    /// Concurrent `find`s can race with each other, so afterwards some nodes
    /// may still be more than one hop away from their root.
    fn compress(&self) {
        (0..self.len()).into_par_iter().for_each(|raw| {
            let _ = self.find(NI::new(raw));
        });
    }
}
impl<NI: Idx> DisjointSetStruct<NI> {
pub fn new(size: usize) -> Self {
let mut v = Vec::with_capacity(size);
(0..size)
.into_par_iter()
.map(|i| Atomic::new(NI::new(i)))
.collect_into_vec(&mut v);
Self(v.into_boxed_slice())
}
fn parent(&self, i: NI) -> NI {
self.0[i.index()].load(Ordering::SeqCst)
}
fn update_parent(&self, id: NI, current: NI, new: NI) -> Result<NI, NI> {
self.0[id.index()].compare_exchange_weak(current, new, Ordering::SeqCst, Ordering::Relaxed)
}
}
impl<NI: Idx> Components<NI> for DisjointSetStruct<NI> {
    /// Returns the representative (root) of the set containing `node`.
    fn component(&self, node: NI) -> NI {
        self.find(node)
    }

    /// Consumes the structure and returns the component id of every node.
    ///
    /// Entry `i` is the root of node `i`'s set, which is the value `component(i)`
    /// returns. The atomic slot buffer is reinterpreted in place as a `Vec<NI>`
    /// without copying. It is then resolved in one sequential pass, so every entry
    /// is a root even if `compress` was never called or left multi-hop chains.
    fn to_vec(self) -> Vec<NI> {
        // Keep the boxed slots from being freed here: ownership of the
        // allocation moves to the Vec<NI> built below.
        let mut slots = ManuallyDrop::new(self.0.into_vec());
        let (ptr, len, cap) = (slots.as_mut_ptr(), slots.len(), slots.capacity());

        // SAFETY: `ptr`/`len`/`cap` describe a live Vec<Atomic<NI>> allocation
        // that is never dropped (ManuallyDrop), so it is handed over exactly once.
        // Reinterpreting it as Vec<NI> requires `Atomic<NI>` to have the same
        // size and alignment as `NI`. NOTE(review): this holds for a
        // `#[repr(transparent)]` wrapper; confirm for the prelude's `Atomic`.
        let mut components: Vec<NI> = unsafe { Vec::from_raw_parts(ptr.cast::<NI>(), len, cap) };

        // Resolve every entry to its root. Links always point to an id <= the
        // node's own id: `union` hangs the larger root under the smaller, and
        // `find` only shortcuts to a grandparent. So when nodes are visited in
        // increasing order, a node's parent is already resolved and the inner
        // loop takes at most one step.
        for node in 0..components.len() {
            let mut root = components[node];
            while components[root.index()] != root {
                root = components[root.index()];
            }
            components[node] = root;
        }
        components
    }
}
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Barrier};

    use super::*;

    /// Repeatedly attaching node 9's set to a smaller node moves its
    /// representative down to that node.
    #[test]
    fn test_union() {
        let dss = DisjointSetStruct::new(10);
        assert_eq!(dss.find(9), 9);
        for &(from, to) in &[(9, 7), (7, 4), (4, 2), (2, 0)] {
            dss.union(from, to);
            assert_eq!(dss.find(9), to);
        }
    }

    /// Builds the chains 4 -> 3 -> 2 -> 1 -> 0 and 9 -> 8 -> 7 -> 6 -> 5, then
    /// joins them; afterwards every node must resolve to 0.
    #[test]
    fn test_union_with_path_halving() {
        let dss = DisjointSetStruct::new(10);
        let links = [(4, 3), (3, 2), (2, 1), (1, 0), (9, 8), (8, 7), (7, 6), (6, 5)];
        for &(a, b) in &links {
            dss.union(a, b);
        }
        assert_eq!(dss.find(4), 0);
        assert_eq!(dss.find(9), 5);

        dss.union(5, 4);
        for node in 0..dss.len() {
            assert_eq!(dss.find(node), 0);
        }
    }

    /// Two threads perform the same unions concurrently. Nodes 0..=500 and
    /// 501..=999 must end up as two separate sets.
    #[test]
    fn test_union_parallel() {
        fn workload(barrier: &Barrier, dss: &DisjointSetStruct<u64>) {
            barrier.wait();
            for i in 0..500 {
                dss.union(i, i + 1);
            }
            barrier.wait();
            for i in 501..999 {
                dss.union(i, i + 1);
            }
        }

        let barrier = Arc::new(Barrier::new(2));
        let dss = Arc::new(DisjointSetStruct::new(1000));

        let handles: Vec<_> = (0..2)
            .map(|_| {
                let barrier = Arc::clone(&barrier);
                let dss = Arc::clone(&dss);
                std::thread::spawn(move || workload(&barrier, &dss))
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }

        for i in 0..500 {
            assert_eq!(dss.find(i), dss.find(i + 1));
        }
        assert_ne!(dss.find(500), dss.find(501));
        for i in 501..999 {
            assert_eq!(dss.find(i), dss.find(i + 1));
        }
    }
}