use std::cell::Cell;
use std::fmt::{self, Debug};
use super::ElementType;
/// A disjoint-set (union-find) forest over the dense elements `0 .. len`.
///
/// `elements[i]` holds the parent of element `i`; an element that is its own
/// parent is the representative (root) of its set. Parents are stored in
/// `Cell`s so that lookups can shorten paths while only borrowing `self`
/// shared. `ranks[i]` is the union-by-rank value of element `i`, which the
/// merge logic only consults for roots.
///
/// Elements are of type `E` (defaulting to `usize`) and are converted to and
/// from vector indices through the `ElementType` trait.
#[derive(Clone)]
pub struct UnionFind<E = usize>
where
    E: ElementType,
{
    /// Parent pointer of every element, indexed by the element's `usize` value.
    elements: Vec<Cell<E>>,
    /// Rank of every element; meaningful only while the element is a root.
    ranks: Vec<u8>,
}
impl<E: Debug + ElementType> Debug for UnionFind<E> {
fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "UnionFind({:?})", self.elements)
}
}
impl<E: ElementType> UnionFind<E> {
    /// Creates a structure holding `size` singleton sets, one for each
    /// element `0 .. size`.
    ///
    /// # Panics
    ///
    /// Panics if some index below `size` cannot be represented as `E`
    /// (`E::from_usize` returns `None`).
    pub fn new(size: usize) -> Self {
        // Every element starts as its own parent (the root of a singleton
        // set) with rank 0.
        let elements: Vec<Cell<E>> = (0..size)
            .map(|i| Cell::new(E::from_usize(i).expect("UnionFind::new: overflow")))
            .collect();
        UnionFind {
            elements,
            ranks: vec![0; size],
        }
    }

    /// Returns the number of elements (not the number of sets).
    pub fn len(&self) -> usize {
        self.elements.len()
    }

    /// Returns `true` if the structure holds no elements.
    pub fn is_empty(&self) -> bool {
        self.elements.is_empty()
    }

    /// Adds a new element in its own singleton set and returns it.
    ///
    /// The new element is numbered with the value `len()` had before the call.
    ///
    /// # Panics
    ///
    /// Panics if that index cannot be represented as `E`.
    pub fn alloc(&mut self) -> E {
        let result = E::from_usize(self.elements.len()).expect("UnionFind::alloc: overflow");
        self.elements.push(Cell::new(result));
        self.ranks.push(0);
        result
    }

    /// Merges the sets containing `a` and `b`; does nothing if they are
    /// already in the same set.
    ///
    /// Uses union by rank: the root of lower rank is attached beneath the root
    /// of higher rank. On a tie, `a`'s root goes beneath `b`'s root, whose
    /// rank is then incremented.
    ///
    /// # Panics
    ///
    /// Panics if `a` or `b` is not an element of this structure.
    pub fn union(&mut self, a: E, b: E) {
        use std::cmp::Ordering;

        let a = self.find(a);
        let b = self.find(b);
        if a == b {
            return;
        }
        match self.rank(a).cmp(&self.rank(b)) {
            Ordering::Greater => self.set_parent(b, a),
            Ordering::Less => self.set_parent(a, b),
            Ordering::Equal => {
                self.set_parent(a, b);
                self.increment_rank(b);
            }
        }
    }

    /// Returns the representative (root) of the set containing `element`.
    ///
    /// Although it borrows `self` shared, this shortens paths by path halving:
    /// each visited element is re-pointed at its grandparent through the
    /// interior-mutable parent cells.
    ///
    /// # Panics
    ///
    /// Panics if `element` is not an element of this structure.
    pub fn find(&self, mut element: E) -> E {
        while element != self.parent(element) {
            // Skip over the parent, then continue from the new parent, so
            // every other element on the path is shortened.
            self.set_parent(element, self.grandparent(element));
            element = self.parent(element);
        }
        element
    }

    /// Returns `true` if `a` and `b` are in the same set.
    ///
    /// # Panics
    ///
    /// Panics if `a` or `b` is not an element of this structure.
    pub fn equiv(&self, a: E, b: E) -> bool {
        self.find(a) == self.find(b)
    }

    /// Compresses all paths, so that afterwards every element's stored parent
    /// is the root of its set.
    ///
    /// Calling `find` alone is not enough for this: with path halving, on a
    /// path of three or more edges the starting element ends up pointing at
    /// its former grandparent rather than the root. The root is therefore
    /// written back explicitly.
    pub fn force(&self) {
        for i in 0..self.len() {
            let element = E::from_usize(i)
                .expect("UnionFind::force: every index below len() was created via from_usize");
            let root = self.find(element);
            self.set_parent(element, root);
        }
    }

    /// Returns, for each element in index order, the root of its set.
    ///
    /// Calls [`force`](Self::force) first (which compresses the stored paths as
    /// a side effect), then copies the parents into a newly allocated vector.
    pub fn as_vec(&self) -> Vec<E> {
        self.force();
        self.elements.iter().map(Cell::get).collect()
    }

    /// Current rank of `element`.
    fn rank(&self, element: E) -> u8 {
        self.ranks[element.to_usize()]
    }

    /// Increments the rank of `element`, panicking if the `u8` would overflow.
    fn increment_rank(&mut self, element: E) {
        let rank = &mut self.ranks[element.to_usize()];
        *rank = rank.checked_add(1).expect("UnionFind: rank overflow");
    }

    /// Currently stored parent of `element`.
    fn parent(&self, element: E) -> E {
        self.elements[element.to_usize()].get()
    }

    /// Overwrites the stored parent of `element`; needs only `&self` because
    /// parents live in `Cell`s.
    fn set_parent(&self, element: E, parent: E) {
        self.elements[element.to_usize()].set(parent);
    }

    /// Parent of the parent of `element`.
    fn grandparent(&self, element: E) -> E {
        self.parent(self.parent(element))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `len` counts elements.
    #[test]
    fn len() {
        assert_eq!(5, UnionFind::<u32>::new(5).len());
    }

    /// A single union joins exactly the two elements' sets.
    #[test]
    fn union() {
        let mut uf = UnionFind::<u32>::new(8);
        assert!(!uf.equiv(0, 1));
        uf.union(0, 1);
        assert!(uf.equiv(0, 1));
    }

    /// Chains of unions merge transitively and leave unrelated sets alone.
    #[test]
    fn unions() {
        let mut uf = UnionFind::<usize>::new(8);
        uf.union(0, 1);
        uf.union(1, 2);
        uf.union(4, 3);
        uf.union(3, 2);
        assert!(uf.equiv(0, 1));
        assert!(uf.equiv(0, 2));
        assert!(uf.equiv(0, 3));
        assert!(uf.equiv(0, 4));
        assert!(!uf.equiv(0, 5));
        uf.union(5, 3);
        assert!(uf.equiv(0, 5));
        uf.union(6, 7);
        assert!(uf.equiv(6, 7));
        assert!(!uf.equiv(5, 7));
        uf.union(0, 7);
        assert!(uf.equiv(5, 7));
    }

    /// `alloc` appends a fresh singleton numbered after the existing elements.
    #[test]
    fn alloc() {
        let mut uf = UnionFind::<u32>::new(2);
        assert_eq!(2, uf.alloc());
        assert_eq!(3, uf.len());
        assert!(!uf.equiv(0, 2));
        assert!(!uf.equiv(1, 2));
        uf.union(2, 0);
        assert!(uf.equiv(0, 2));
    }

    /// Unioning elements that already share a set leaves the forest unchanged.
    #[test]
    fn union_same_set_is_noop() {
        let mut uf = UnionFind::<usize>::new(3);
        uf.union(0, 1);
        let before = uf.as_vec();
        uf.union(1, 0);
        uf.union(2, 2);
        assert_eq!(before, uf.as_vec());
    }

    /// `as_vec` has one entry per element, each lying in that element's set.
    #[test]
    fn as_vec_entries_share_sets() {
        let mut uf = UnionFind::<usize>::new(6);
        uf.union(0, 1);
        uf.union(2, 3);
        uf.union(1, 3);
        let parents = uf.as_vec();
        assert_eq!(6, parents.len());
        for (i, &p) in parents.iter().enumerate() {
            assert!(uf.equiv(i, p));
        }
        // Untouched singletons are their own representatives.
        assert_eq!(4, parents[4]);
        assert_eq!(5, parents[5]);
    }

    /// An empty structure has no elements and an empty `as_vec`.
    #[test]
    fn empty() {
        let uf = UnionFind::<u32>::new(0);
        assert_eq!(0, uf.len());
        assert!(uf.as_vec().is_empty());
    }
}