use crate::models::{ComponentId, NodeId};
use crate::Result;
/// Disjoint-set forest over the nodes `0..n`, using union by rank and path
/// compression.
#[derive(Debug, Clone)]
pub struct UnionFind {
    /// `parent[i]` is node `i`'s parent; node `i` is a root when `parent[i] == i`.
    parent: Vec<u32>,
    /// Per-node rank; only read for roots, to decide which root survives a union.
    rank: Vec<u32>,
    /// Number of disjoint sets currently in the forest.
    num_components: usize,
}
impl UnionFind {
    /// Creates a forest of `n` singleton sets, one per node `0..n`.
    ///
    /// # Panics
    /// Panics if `n > u32::MAX as usize`. Node indices are stored as `u32`, and
    /// truncating `n` would leave `parent`, `rank` and `num_components` out of sync.
    pub fn new(n: usize) -> Self {
        // Without this check `n as u32` below silently wraps, producing a forest
        // whose `parent` is shorter than `rank` and whose component count is wrong.
        assert!(
            n <= u32::MAX as usize,
            "UnionFind::new: {} nodes do not fit in u32 node indices",
            n
        );
        Self {
            parent: (0..n as u32).collect(),
            rank: vec![0; n],
            num_components: n,
        }
    }
    /// Number of nodes tracked by the forest.
    pub fn len(&self) -> usize {
        self.parent.len()
    }
    /// Returns `true` if the forest tracks no nodes.
    pub fn is_empty(&self) -> bool {
        self.parent.is_empty()
    }
    /// Number of disjoint sets currently in the forest.
    pub fn num_components(&self) -> usize {
        self.num_components
    }
    /// Returns the root of the set containing `x`. Every node on the walked path
    /// is re-pointed directly at that root (path compression).
    ///
    /// # Panics
    /// Panics if `x.0 >= self.len()`.
    pub fn find(&mut self, x: NodeId) -> NodeId {
        // First pass: walk up to the root.
        let mut root = x.0;
        while self.parent[root as usize] != root {
            root = self.parent[root as usize];
        }
        // Second pass: point every node on the path straight at the root.
        let mut node = x.0;
        while self.parent[node as usize] != root {
            let next = self.parent[node as usize];
            self.parent[node as usize] = root;
            node = next;
        }
        NodeId(root)
    }
    /// Merges the sets containing `x` and `y`.
    ///
    /// Returns `true` if they were in different sets (and are now merged), and
    /// `false` if they already shared a set. The lower-rank root is attached under
    /// the higher-rank one. On a tie, `y`'s root goes under `x`'s root and that
    /// root's rank grows by one.
    ///
    /// # Panics
    /// Panics if either node is out of range.
    pub fn union(&mut self, x: NodeId, y: NodeId) -> bool {
        let root_x = self.find(x);
        let root_y = self.find(y);
        if root_x == root_y {
            return false;
        }
        let rx = self.rank[root_x.0 as usize];
        let ry = self.rank[root_y.0 as usize];
        if rx < ry {
            self.parent[root_x.0 as usize] = root_y.0;
        } else if rx > ry {
            self.parent[root_y.0 as usize] = root_x.0;
        } else {
            self.parent[root_y.0 as usize] = root_x.0;
            self.rank[root_x.0 as usize] += 1;
        }
        self.num_components -= 1;
        true
    }
    /// Returns `true` if `x` and `y` are in the same set.
    pub fn connected(&mut self, x: NodeId, y: NodeId) -> bool {
        self.find(x) == self.find(y)
    }
    /// Labels every node with a dense component id in `0..num_components()`.
    ///
    /// Ids are handed out in order of each component's lowest node index. As a
    /// side effect, every node's path is compressed.
    pub fn component_ids(&mut self) -> Vec<ComponentId> {
        let n = self.parent.len();
        let mut comp_id = vec![ComponentId::UNASSIGNED; n];
        let mut next_id = 0u32;
        for i in 0..n {
            let root = self.find(NodeId(i as u32));
            // The first node seen for a root is its component's lowest index.
            if !comp_id[root.0 as usize].is_assigned() {
                comp_id[root.0 as usize] = ComponentId::new(next_id);
                next_id += 1;
            }
            comp_id[i] = comp_id[root.0 as usize];
        }
        comp_id
    }
    /// Number of nodes in the set containing `x`.
    ///
    /// The forest does not store set sizes, so this scans every node: O(n) per call.
    pub fn component_size(&mut self, x: NodeId) -> usize {
        let root = self.find(x);
        let mut count = 0;
        for i in 0..self.parent.len() {
            if self.find(NodeId(i as u32)) == root {
                count += 1;
            }
        }
        count
    }
}
/// Labels the connected components of an `n`-node graph by feeding every edge
/// into a sequential [`UnionFind`].
///
/// Labels are dense, start at 0, and are numbered in order of each component's
/// lowest node index.
///
/// # Panics
/// Panics if any edge endpoint is `>= n`.
pub fn union_find_sequential(n: usize, edges: &[(NodeId, NodeId)]) -> Result<Vec<ComponentId>> {
    let mut forest = UnionFind::new(n);
    // Whether each union actually merged two sets is irrelevant here.
    edges.iter().for_each(|&(a, b)| {
        let _ = forest.union(a, b);
    });
    let labels = forest.component_ids();
    Ok(labels)
}
/// Labels the connected components of an `n`-node graph using hook-and-jump
/// rounds over an array of atomic parent pointers.
///
/// Each round has two steps:
/// - It hooks the roots of every edge's endpoints together. The lower-indexed
///   root always points at the higher-indexed one, so parent indices only
///   increase and no cycle can form.
/// - It then performs one pointer-jumping sweep.
///
/// Rounds repeat until nothing changes or 64 rounds have run. A final
/// sequential pass then resolves every node to its root, so the labelling is
/// complete either way.
///
/// Despite the name, every pass here runs on the calling thread. The atomics
/// only make the individual update steps safe to share.
///
/// Labels are dense, start at 0, and are numbered in order of each component's
/// lowest node index. An empty vector is returned when `n == 0`.
///
/// # Panics
/// Panics if `n > 0` and any edge endpoint is `>= n`.
pub fn union_find_parallel(n: usize, edges: &[(NodeId, NodeId)]) -> Result<Vec<ComponentId>> {
    use std::sync::atomic::{AtomicU32, Ordering};

    const MAX_ITERATIONS: usize = 64;

    if n == 0 {
        return Ok(vec![]);
    }

    let parent: Vec<AtomicU32> = (0..n as u32).map(AtomicU32::new).collect();
    let load = |idx: u32| parent[idx as usize].load(Ordering::Relaxed);
    // Starting from `start`'s parent, follow links for at most `n` steps and
    // return the last node reached (its root, since depth never exceeds n).
    let chase = |start: u32| -> u32 {
        let mut cur = load(start);
        for _ in 0..n {
            let up = load(cur);
            if up == cur {
                break;
            }
            cur = up;
        }
        cur
    };

    for _ in 0..MAX_ITERATIONS {
        let mut dirty = false;

        // Hook step: link the two roots of every edge that still spans two trees.
        for &(a, b) in edges {
            let ra = chase(a.0);
            let rb = chase(b.0);
            if ra == rb {
                continue;
            }
            let (lo, hi) = (ra.min(rb), ra.max(rb));
            dirty |= parent[lo as usize]
                .compare_exchange(lo, hi, Ordering::AcqRel, Ordering::Relaxed)
                .is_ok();
        }

        // Jump step: move every non-root one level closer to its root.
        for (idx, slot) in parent.iter().enumerate() {
            let p = slot.load(Ordering::Relaxed);
            if p == idx as u32 {
                continue;
            }
            let gp = load(p);
            if gp != p {
                // Best effort: a failed exchange only means the slot already moved.
                let _ = slot.compare_exchange(p, gp, Ordering::AcqRel, Ordering::Relaxed);
                dirty = true;
            }
        }

        if !dirty {
            break;
        }
    }

    // Snapshot the parents, then point every node straight at its root.
    let mut roots: Vec<u32> = parent.iter().map(|slot| slot.load(Ordering::Relaxed)).collect();
    for start in 0..n as u32 {
        let mut top = start;
        while roots[top as usize] != top {
            top = roots[top as usize];
        }
        let mut cursor = start;
        while roots[cursor as usize] != top {
            cursor = std::mem::replace(&mut roots[cursor as usize], top);
        }
    }

    // Hand out dense labels in order of first appearance of each root.
    let mut labels = vec![ComponentId::UNASSIGNED; n];
    let mut next_label = 0u32;
    for (idx, &root) in roots.iter().enumerate() {
        let root = root as usize;
        if !labels[root].is_assigned() {
            labels[root] = ComponentId::new(next_label);
            next_label += 1;
        }
        labels[idx] = labels[root];
    }
    Ok(labels)
}
/// Computes connected components with [`union_find_parallel`] and flattens them
/// into a GPU-friendly parent array.
///
/// The result is `(parent, num_components)`:
/// - `parent[i]` is the lowest node index in node `i`'s component. That lowest
///   node is its own parent, so the array is already fully compressed.
/// - `num_components` is the number of distinct components.
///
/// # Errors
/// Propagates any error returned by [`union_find_parallel`].
#[cfg(feature = "cuda")]
pub fn union_find_gpu_ready(n: usize, edges: &[(NodeId, NodeId)]) -> Result<(Vec<u32>, usize)> {
    let components = union_find_parallel(n, edges)?;
    // `first_node[c]` remembers the first (lowest-indexed) node labelled with
    // component id `c`. Ids are dense and there are at most `n` components, so
    // ids index this table directly. This replaces a per-node backwards scan
    // that made the function O(n^2).
    let mut first_node: Vec<Option<u32>> = vec![None; components.len()];
    let mut parent = Vec::with_capacity(components.len());
    let mut num_components = 0usize;
    for (i, comp) in components.iter().enumerate() {
        let slot = &mut first_node[comp.0 as usize];
        let representative = match *slot {
            Some(rep) => rep,
            None => {
                num_components += 1;
                *slot = Some(i as u32);
                i as u32
            }
        };
        parent.push(representative);
    }
    Ok((parent, num_components))
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_singleton_sets() {
        let mut uf = UnionFind::new(5);
        assert_eq!(uf.num_components(), 5);
        for i in 0..5 {
            assert_eq!(uf.find(NodeId(i)), NodeId(i));
        }
    }
    #[test]
    fn test_union_basic() {
        let mut uf = UnionFind::new(5);
        assert!(uf.union(NodeId(0), NodeId(1)));
        assert_eq!(uf.num_components(), 4);
        assert!(uf.connected(NodeId(0), NodeId(1)));
        assert!(uf.union(NodeId(2), NodeId(3)));
        assert_eq!(uf.num_components(), 3);
        assert!(uf.union(NodeId(0), NodeId(2)));
        assert_eq!(uf.num_components(), 2);
        assert!(uf.connected(NodeId(0), NodeId(3)));
    }
    #[test]
    fn test_union_same_component() {
        let mut uf = UnionFind::new(3);
        uf.union(NodeId(0), NodeId(1));
        uf.union(NodeId(1), NodeId(2));
        assert!(uf.connected(NodeId(0), NodeId(2)));
        assert!(!uf.union(NodeId(0), NodeId(2)));
        assert_eq!(uf.num_components(), 1);
    }
    /// A node unioned with itself must not change the component count.
    #[test]
    fn test_self_union_is_noop() {
        let mut uf = UnionFind::new(3);
        assert!(!uf.union(NodeId(2), NodeId(2)));
        assert_eq!(uf.num_components(), 3);
    }
    #[test]
    fn test_path_compression() {
        let mut uf = UnionFind::new(10);
        for i in 0..9 {
            uf.union(NodeId(i), NodeId(i + 1));
        }
        let root = uf.find(NodeId(9));
        for i in 0..10 {
            assert_eq!(uf.find(NodeId(i)), root);
        }
    }
    #[test]
    fn test_component_ids() {
        let mut uf = UnionFind::new(5);
        uf.union(NodeId(0), NodeId(1));
        uf.union(NodeId(2), NodeId(3));
        let ids = uf.component_ids();
        assert_eq!(ids[0], ids[1]);
        assert_eq!(ids[2], ids[3]);
        assert_ne!(ids[4], ids[0]);
        assert_ne!(ids[4], ids[2]);
        assert_eq!(uf.num_components(), 3);
    }
    /// Both labelling functions must hand out dense ids in order of each
    /// component's lowest node index; `union_find_gpu_ready` relies on this.
    #[test]
    fn test_component_ids_are_dense_in_first_appearance_order() {
        let edges = [(NodeId(4), NodeId(5)), (NodeId(1), NodeId(3))];
        let expected: Vec<ComponentId> =
            [0, 1, 2, 1, 3, 3].iter().map(|&k| ComponentId::new(k)).collect();
        assert_eq!(union_find_sequential(6, &edges).unwrap(), expected);
        assert_eq!(union_find_parallel(6, &edges).unwrap(), expected);
    }
    #[test]
    fn test_component_size() {
        let mut uf = UnionFind::new(6);
        uf.union(NodeId(0), NodeId(1));
        uf.union(NodeId(1), NodeId(2));
        uf.union(NodeId(3), NodeId(4));
        assert_eq!(uf.component_size(NodeId(0)), 3);
        assert_eq!(uf.component_size(NodeId(3)), 2);
        assert_eq!(uf.component_size(NodeId(5)), 1);
    }
    #[test]
    fn test_union_find_from_edges() {
        let edges = [
            (NodeId(0), NodeId(1)),
            (NodeId(1), NodeId(2)),
            (NodeId(3), NodeId(4)),
        ];
        let components = union_find_sequential(5, &edges).unwrap();
        assert_eq!(components[0], components[1]);
        assert_eq!(components[1], components[2]);
        assert_eq!(components[3], components[4]);
        assert_ne!(components[0], components[3]);
    }
    #[test]
    fn test_empty_union_find() {
        let uf = UnionFind::new(0);
        assert!(uf.is_empty());
        assert_eq!(uf.num_components(), 0);
    }
    #[test]
    fn test_parallel_union_find_basic() {
        let edges = [
            (NodeId(0), NodeId(1)),
            (NodeId(1), NodeId(2)),
            (NodeId(3), NodeId(4)),
        ];
        let components = union_find_parallel(5, &edges).unwrap();
        assert_eq!(components[0], components[1]);
        assert_eq!(components[1], components[2]);
        assert_eq!(components[3], components[4]);
        assert_ne!(components[0], components[3]);
    }
    #[test]
    fn test_parallel_union_find_single_component() {
        let edges: Vec<_> = (0..9).map(|i| (NodeId(i), NodeId(i + 1))).collect();
        let components = union_find_parallel(10, &edges).unwrap();
        for i in 1..10 {
            assert_eq!(components[0], components[i]);
        }
    }
    #[test]
    fn test_parallel_union_find_no_edges() {
        let components = union_find_parallel(5, &[]).unwrap();
        for i in 0..5 {
            for j in (i + 1)..5 {
                assert_ne!(components[i], components[j]);
            }
        }
    }
    #[test]
    fn test_parallel_union_find_empty() {
        let components = union_find_parallel(0, &[]).unwrap();
        assert!(components.is_empty());
    }
    #[test]
    fn test_parallel_vs_sequential_consistency() {
        let edges = [
            (NodeId(0), NodeId(5)),
            (NodeId(1), NodeId(6)),
            (NodeId(2), NodeId(7)),
            (NodeId(5), NodeId(6)),
            (NodeId(3), NodeId(8)),
            (NodeId(4), NodeId(9)),
            (NodeId(8), NodeId(9)),
        ];
        let seq_components = union_find_sequential(10, &edges).unwrap();
        let par_components = union_find_parallel(10, &edges).unwrap();
        for i in 0..10 {
            for j in (i + 1)..10 {
                let seq_same = seq_components[i] == seq_components[j];
                let par_same = par_components[i] == par_components[j];
                assert_eq!(seq_same, par_same, "Mismatch for nodes {} and {}", i, j);
            }
        }
    }
    #[test]
    fn test_parallel_union_find_star_graph() {
        let edges: Vec<_> = (1..10).map(|i| (NodeId(0), NodeId(i))).collect();
        let components = union_find_parallel(10, &edges).unwrap();
        for i in 1..10 {
            assert_eq!(components[0], components[i]);
        }
    }
    /// Each node's GPU parent is its component's lowest node, and the count
    /// matches the number of distinct components (including the isolated node 5).
    #[cfg(feature = "cuda")]
    #[test]
    fn test_union_find_gpu_ready_parents_are_first_nodes() {
        let edges = [
            (NodeId(0), NodeId(1)),
            (NodeId(1), NodeId(2)),
            (NodeId(3), NodeId(4)),
        ];
        let (parent, num_components) = union_find_gpu_ready(6, &edges).unwrap();
        assert_eq!(parent, vec![0, 0, 0, 3, 3, 5]);
        assert_eq!(num_components, 3);
    }
}