use num_traits::FromPrimitive;
use std::cell::Cell;
/// A union-find (disjoint-set) structure over the indices `0..len`.
///
/// `leaders[x]` holds the current parent link of `x`; an index whose entry
/// points at itself is a leader (set representative). The links are stored in
/// `Cell`s so lookups can rewrite them through a shared reference.
///
/// `min_uncanonical` bounds the region where lookups may take a shortcut:
/// for every `x < min_uncanonical`, `leaders[x]` is expected to already be a
/// leader.
#[derive(Clone, Debug)]
pub struct UF<I>
where
    I: Copy,
{
    // Parent link of every index.
    leaders: Vec<Cell<I>>,
    // Indices strictly below this value point directly at their leader.
    min_uncanonical: Cell<I>,
}
impl<I> PartialEq for UF<I>
where
    I: Into<usize> + Copy + FromPrimitive + PartialEq,
{
    /// Two union-finds are equal when every index resolves (via `find`) to the
    /// same leader in both. Structures of different sizes are never equal.
    ///
    /// Because `find` compresses paths by rewriting `Cell`s, comparing may
    /// update the internal links of both operands.
    fn eq(&self, other: &Self) -> bool {
        // Different sizes describe partitions of different sets: not equal,
        // rather than a panic (a `PartialEq` impl should not abort).
        if self.leaders.len() != other.leaders.len() {
            return false;
        }
        (0..self.leaders.len()).all(|idx| {
            let i = I::from_usize(idx).unwrap();
            self.find(i).into() == other.find(i).into()
        })
    }
}
/// Marker impl: `PartialEq` for `UF` compares the leader of every index as a
/// `usize`, which is reflexive, symmetric and transitive, so `UF` is `Eq`.
impl<I: Into<usize> + Copy + FromPrimitive + Eq> Eq for UF<I> {}
impl<I> UF<I>
where
    I: Into<usize> + Copy + FromPrimitive,
{
    /// Creates a union-find over `0..size` in which every index is its own
    /// leader (the finest partition, i.e. the reflexive relation).
    ///
    /// # Panics
    ///
    /// Panics if `I::from_usize` fails for some index below `size`.
    pub fn new_reflexive(size: I) -> Self {
        let leaders: Vec<Cell<I>> = (0..size.into())
            .map(|i| Cell::new(I::from_usize(i).unwrap()))
            .collect();
        // Every entry is a root, so the whole range is canonical.
        UF { leaders, min_uncanonical: Cell::new(size) }
    }

    /// Returns the number of tracked indices converted to `I`; every valid
    /// index is strictly less than this value.
    ///
    /// # Panics
    ///
    /// Panics if `len()` is not representable as `I`.
    pub fn max(&self) -> I {
        I::from_usize(self.leaders.len()).unwrap()
    }

    /// Returns the number of tracked indices.
    pub fn len(&self) -> usize {
        self.leaders.len()
    }

    /// Returns `true` if no indices are tracked.
    pub fn is_empty(&self) -> bool {
        self.leaders.is_empty()
    }

    /// Follows parent links from `i` to its leader without compressing the path.
    #[allow(dead_code)] // no caller visible here; simple reference version of `find`
    fn const_find(&self, mut i: I) -> I {
        loop {
            let l = self.leaders[i.into()].get();
            if l.into() == i.into() {
                return l;
            }
            i = l;
        }
    }

    /// Returns the leader of the set containing `i`.
    ///
    /// Afterwards every index on the path from `i` points directly at the
    /// leader (path compression). This takes `&self` because the links live
    /// in `Cell`s.
    ///
    /// # Panics
    ///
    /// Panics if `i` (or a link followed from it) is out of range.
    pub fn find(&self, i: I) -> I {
        // Fast path: below `min_uncanonical` every entry already points at its root.
        if i.into() < self.min_uncanonical.get().into() {
            return self.leaders[i.into()].get();
        }
        let l = self.leaders[i.into()].get();
        if i.into() == l.into() || self.leaders[l.into()].get().into() == l.into() {
            // `i` is a root, or its parent is: nothing to compress.
            return l;
        }
        // First pass: walk from `l` to the root, reversing each parent link so
        // it points back toward `i`. This records the path in place without an
        // auxiliary stack. `leaders[i]` itself is not touched here.
        let mut prev = i;
        let mut this = l;
        loop {
            let next = self.leaders[this.into()].get();
            if this.into() == next.into() {
                break;
            }
            self.leaders[this.into()].set(prev);
            prev = this;
            this = next;
        }
        let root = this;
        // Second pass: follow the reversed links back toward `i`, pointing
        // every visited node directly at the root.
        this = prev;
        while this.into() != i.into() {
            this = self.leaders[this.into()].replace(root);
        }
        self.leaders[i.into()].set(root);
        root
    }

    /// Merges the sets containing `i` and `j`.
    ///
    /// The smaller of the two leaders remains the leader, so in structures
    /// built with `new_reflexive` and `union` a set's leader is its smallest
    /// member.
    pub fn union(&mut self, i: I, j: I) {
        let l_i = self.find(i);
        let l_j = self.find(j);
        let (ri, rj): (usize, usize) = (l_i.into(), l_j.into());
        if ri == rj {
            return; // already in the same set
        }
        let (root, child) = if ri < rj { (l_i, l_j) } else { (l_j, l_i) };
        self.leaders[child.into()].set(root);
        // `child` is the root that was demoted. Only the entries that pointed
        // at it now point at a non-root, and links never point to a larger
        // index, so all of those entries lie above `child`. Everything below
        // `child` stays canonical.
        self.bump_min_uncanonical(child);
    }

    /// Lowers `min_uncanonical` to `i` if `i` is below it.
    fn bump_min_uncanonical(&mut self, i: I) {
        if i.into() < self.min_uncanonical.get().into() {
            self.min_uncanonical.set(i);
        }
    }

    /// Returns `true` if `i` and `j` are in the same set.
    pub fn same_set(&self, i: I, j: I) -> bool {
        self.find(i).into() == self.find(j).into()
    }

    /// Returns the join of two partitions: the finest partition in which two
    /// indices share a set whenever they share one in `a` or in `b`.
    ///
    /// As a side effect `b` is fully path-compressed and marked canonical.
    ///
    /// # Panics
    ///
    /// Panics if `a` and `b` have different lengths.
    pub fn equivalence_union(a: &Self, b: &Self) -> Self {
        assert!(
            a.leaders.len() == b.leaders.len(),
            "Called equivalence_union on two UF of different sizes"
        );
        let mut res = a.clone();
        for idx in 0..b.leaders.len() {
            let i = I::from_usize(idx).unwrap();
            res.union(i, b.find(i));
        }
        // `b.find` ran on every index, so every entry of `b` now points at its root.
        b.mark_canonical();
        res
    }

    /// Declares the whole structure canonical. This is only valid once every
    /// entry points directly at its leader (e.g. after `find` on every index).
    fn mark_canonical(&self) {
        self.min_uncanonical.set(self.max());
    }

    /// Brute-force reference for `equivalence_intersection`: it unions every
    /// pair `(i, j)` that shares a set in both `a` and `b` (O(n²) pairs).
    ///
    /// # Panics
    ///
    /// Panics if `a` and `b` have different lengths.
    #[allow(dead_code)] // used by the test module as an oracle
    fn slow_equivalence_intersection(a: &Self, b: &Self) -> Self {
        assert!(
            a.leaders.len() == b.leaders.len(),
            "Called slow_equivalence_intersection on two UF of different sizes"
        );
        let len = a.leaders.len();
        let mut res = Self::new_reflexive(I::from_usize(len).unwrap());
        for i_idx in 0..len {
            let i = I::from_usize(i_idx).unwrap();
            for j_idx in i_idx + 1..len {
                let j = I::from_usize(j_idx).unwrap();
                if a.same_set(i, j) && b.same_set(i, j) {
                    res.union(i, j);
                }
            }
        }
        res
    }

    /// Returns the meet of two partitions: `i` and `j` share a set in the
    /// result exactly when they share a set in both `a` and `b`.
    ///
    /// Both inputs are first encoded with `as_permutation`, where each set is a
    /// cycle that steps from every member to the next smaller one (and from the
    /// leader to the largest). For each `i`, from the top down, the two cycles
    /// are walked downward merge-style to find the largest member below `i`
    /// that is common to both sets, and `i` is united with any common member
    /// found.
    ///
    /// As a side effect both `a` and `b` are marked canonical.
    ///
    /// # Panics
    ///
    /// Panics if `a` and `b` have different lengths.
    pub fn equivalence_intersection(a: &Self, b: &Self) -> Self {
        assert!(
            a.leaders.len() == b.leaders.len(),
            "Called equivalence_intersection on two UF of different sizes"
        );
        let ap = a.as_permutation();
        let bp = b.as_permutation();
        let mut c = UF::new_reflexive(a.max());
        for i in (0..ap.len()).rev() {
            // `ai`/`bi` are the current positions on i's cycle in `a`/`b`.
            // `anext`/`bnext` are their successors: the next smaller member, or
            // a value >= the position once the cycle wraps at the leader.
            let mut ai = i;
            let mut bi = i;
            let mut anext: usize = ap[ai].into();
            let mut bnext: usize = bp[bi].into();
            loop {
                // Step down in `a` past members larger than b's candidate.
                while anext < ai && anext > bnext {
                    ai = anext;
                    anext = ap[ai].into();
                }
                if anext >= ai || anext == bnext {
                    break;
                }
                // Here anext < bnext: step down in `b` symmetrically.
                while bnext < bi && bnext > anext {
                    bi = bnext;
                    bnext = bp[bi].into();
                }
                if bnext >= bi || anext == bnext {
                    break;
                }
            }
            if anext == bnext {
                // `anext` lies on i's cycle in both permutations, i.e. in i's set in both.
                c.union(I::from_usize(i).unwrap(), I::from_usize(anext).unwrap());
            }
        }
        c
    }

    /// Fully compresses every path, then marks the structure canonical so
    /// later `find`s take the fast path.
    #[allow(dead_code)] // used by the test module
    fn canonicalize(&self) {
        for idx in 0..self.leaders.len() {
            self.find(I::from_usize(idx).unwrap());
        }
        self.mark_canonical();
    }

    /// Encodes the partition as a permutation with one cycle per set.
    ///
    /// Each leader maps to the largest member of its set, and every other
    /// member maps to the next smaller member. Because `find` runs on every
    /// index, `self` is marked canonical afterwards.
    pub fn as_permutation(&self) -> Vec<I> {
        let mut res: Vec<I> = Vec::with_capacity(self.len());
        for idx in 0..self.len() {
            let i = I::from_usize(idx).unwrap();
            let leader: usize = self.find(i).into();
            if leader == idx {
                // First member of a new set: a one-element cycle.
                res.push(i);
            } else {
                // Splice `i` in right after the leader. `i` inherits the
                // leader's old successor, and the leader now steps to `i` (the
                // largest member seen so far).
                let succ = res[leader];
                res.push(succ);
                res[leader] = i;
            }
        }
        self.mark_canonical();
        res
    }

    /// Panics if the representation invariants are violated:
    /// every link points at an index no greater than its own, and every entry
    /// below `min_uncanonical` points at a root (the shortcut `find` relies on).
    #[allow(dead_code)] // debugging aid used by the test module
    fn assert_invariants(&self)
    where
        I: std::fmt::Display + std::fmt::Debug,
    {
        for idx in 0..self.leaders.len() {
            let v: usize = self.leaders[idx].get().into();
            assert!(v <= idx, "leaders[{}] == {}, expected it to be <= {}", idx, v, idx);
        }
        let bound: usize = self.min_uncanonical.get().into();
        for idx in 0..bound {
            let l: usize = self.leaders[idx].get().into();
            let parent: usize = self.leaders[l].get().into();
            assert_eq!(
                parent,
                l,
                "index {} is less than {}, but leaders[{}] == {} is not a root",
                idx,
                self.min_uncanonical.get(),
                idx,
                l
            );
        }
    }

    /// Builds a `UF` directly from raw parent links, with nothing marked
    /// canonical.
    ///
    /// # Safety
    ///
    /// The body performs no unsafe operations; the marker flags that `slice`
    /// is not validated. Callers must supply in-range links that form trees
    /// (`assert_invariants` additionally expects each link to be `<=` its
    /// index). Otherwise later operations may panic or give meaningless results.
    #[allow(dead_code)] // used by the test module
    unsafe fn from_slice(slice: &[I]) -> Self {
        let leaders = slice.iter().copied().map(Cell::new).collect();
        UF { leaders, min_uncanonical: Cell::new(I::from_usize(0).unwrap()) }
    }

    /// Compares the raw representation (every link plus `min_uncanonical`),
    /// unlike `==`, which compares the partitions.
    ///
    /// # Safety
    ///
    /// The body performs no unsafe operations; the marker appears to flag that
    /// the result depends on internal state (how far paths are compressed), so
    /// two equal partitions may compare unequal here.
    #[allow(dead_code)] // used by the test module
    unsafe fn struct_eq(&self, other: &Self) -> bool {
        if self.leaders.len() != other.leaders.len() {
            return false;
        }
        let links_match = self
            .leaders
            .iter()
            .zip(&other.leaders)
            .all(|(x, y)| x.get().into() == y.get().into());
        links_match && self.min_uncanonical.get().into() == other.min_uncanonical.get().into()
    }

    /// Returns an iterator over the leaders, in increasing index order.
    pub fn leaders(&self) -> LeadersIter<'_, I> {
        LeadersIter { uf: self, next_i: I::from_usize(0).unwrap() }
    }
}
/// Iterator over the leaders (set representatives) of a [`UF`], in increasing
/// index order; returned by `UF::leaders`.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct LeadersIter<'a, I: Copy> {
    // The structure being scanned.
    uf: &'a UF<I>,
    // Next index to examine; every index below it has been yielded or skipped.
    next_i: I,
}
impl<'a, I: Copy> Iterator for LeadersIter<'a, I>
where
    I: Into<usize> + Copy + FromPrimitive,
{
    type Item = I;

    /// Yields the next index, scanning upward from `next_i`, that is its own
    /// leader. Every index visited goes through `find`, so scanning also
    /// compresses paths in the underlying structure.
    fn next(&mut self) -> Option<I> {
        let end = self.uf.len();
        while self.next_i.into() < end {
            let candidate = self.next_i;
            let leader = self.uf.find(candidate);
            self.next_i = I::from_usize(candidate.into() + 1).unwrap();
            if candidate.into() == leader.into() {
                return Some(candidate);
            }
        }
        None
    }
}
#[cfg(test)]
mod tests {
    type T = u16;
    use super::UF;
    use num_traits::FromPrimitive;
    use rand::prelude::*;

    /// Fixed-seed RNG so the randomized tests are reproducible.
    fn test_rng() -> StdRng {
        StdRng::seed_from_u64(0x0102030405060708_u64)
    }

    /// Builds the partition of `0..len` into residue classes modulo `m`.
    fn residue_class(len: T, m: T) -> UF<T> {
        let mut res = UF::new_reflexive(len);
        for i_idx in m.into()..len.into() {
            let i = T::from_usize(i_idx).unwrap();
            let j = i - m;
            res.union(i, j);
        }
        res
    }

    /// Asserts that `a` is exactly the partition into residue classes mod `m`.
    fn assert_is_residue_class(m: T, a: &UF<T>) {
        println!("checking if UF is residue_class {}", m);
        let mut v = Vec::with_capacity(a.len());
        for i in 0..a.max() {
            v.push(i % m);
        }
        let b = unsafe { UF::from_slice(&v) };
        a.canonicalize();
        b.canonicalize();
        assert!(unsafe { a.struct_eq(&b) });
    }

    /// Builds a UF of `size` indices with fewer than `max_unions` random unions.
    fn random_uf(size: usize, max_unions: usize, rng: &mut StdRng) -> UF<T> {
        let tsize = T::from_usize(size).unwrap();
        let mut res = UF::new_reflexive(tsize);
        let num_unions = rng.gen_range(0, max_unions);
        for _ in 0..num_unions {
            let i = rng.gen_range(0, tsize);
            let j = rng.gen_range(0, tsize);
            res.union(i, j);
        }
        res
    }

    /// Checks the fast `equivalence_intersection` against the brute-force
    /// version, in both argument orders.
    fn test_intersections(a: &UF<T>, b: &UF<T>) {
        let expected = UF::slow_equivalence_intersection(a, b);
        expected.canonicalize();
        check_intersection(a, b, &expected);
        // Intersection is symmetric, so the fast algorithm must also agree
        // when the arguments are swapped.
        check_intersection(b, a, &expected);
    }

    /// Asserts that `equivalence_intersection(a, b)`, once canonicalized, is
    /// structurally equal to the canonicalized `expected`.
    fn check_intersection(a: &UF<T>, b: &UF<T>, expected: &UF<T>) {
        let actual = UF::equivalence_intersection(a, b);
        actual.canonicalize();
        let same = unsafe { expected.struct_eq(&actual) };
        assert!(same, "a={:?}\nb={:?}\nc1={:?}\nc2={:?}", a, b, expected, actual);
    }

    /// Checks that `leaders()` yields exactly the self-led indices, in order.
    fn do_iterator_test(a: &UF<T>) {
        use std::collections::BTreeSet;
        let mut s = BTreeSet::new();
        for i in 0..a.max() {
            if a.find(i) == i {
                s.insert(i);
            }
        }
        let comp: Vec<T> = s.into_iter().collect();
        let leaders: Vec<T> = a.leaders().collect();
        assert_eq!(comp, leaders, "iterator test failed for {:?}", a);
    }

    #[test]
    fn do_random_intersection_tests() {
        let mut rng = test_rng();
        let ntests = 10000;
        for _ in 0..ntests {
            let size = rng.gen_range(10, 30);
            let a = random_uf(size, 15, &mut rng);
            do_iterator_test(&a);
            let b = random_uf(size, 15, &mut rng);
            do_iterator_test(&b);
            test_intersections(&a, &b);
        }
    }

    /// Checks that intersecting the residue classes mod `a` and mod `b`
    /// (coprime) gives the residue classes mod `a * b`, in both orders.
    fn modular_residue_test(a: T, b: T, size: T) {
        use num_integer::Integer;
        assert_eq!(a.gcd(&b), 1, "a and b are not relatively prime");
        println!("making xa");
        let xa = residue_class(size, a);
        println!("checking xa");
        assert_is_residue_class(a, &xa);
        println!("making xb");
        let xb = residue_class(size, b);
        println!("checking xb");
        assert_is_residue_class(b, &xb);
        let c = a * b;
        println!("quickly making yab");
        let yab = UF::equivalence_intersection(&xa, &xb);
        println!("testing yab");
        assert_is_residue_class(c, &yab);
        println!("quickly making yba");
        let yba = UF::equivalence_intersection(&xb, &xa);
        println!("testing yba");
        assert_is_residue_class(c, &yba);
    }

    #[test]
    fn lots_of_residue_tests() {
        const PRIMES: [T; 10] = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29];
        for i in 0..PRIMES.len() - 1 {
            for j in i + 1..PRIMES.len() {
                modular_residue_test(PRIMES[i], PRIMES[j], PRIMES[i] * PRIMES[j] * 4);
            }
        }
    }

    #[test]
    fn synthetic_find_test() {
        unsafe {
            const T_VALS: [u16; 10] = [0, 0, 2, 1, 3, 4, 5, 6, 7, 8];
            let mut t = UF::from_slice(&T_VALS[..]);
            t.assert_invariants();
            t.union(8, 9);
            const U_VALS: [u16; 10] = [0, 0, 2, 0, 0, 0, 0, 0, 0, 0];
            let u = UF::from_slice(&U_VALS[..]);
            t.assert_invariants();
            assert!(t.struct_eq(&u));
        }
    }

    #[test]
    fn permutation_test() {
        unsafe {
            const T_VALS: [u16; 10] = [0, 0, 2, 3, 2, 5, 6, 7, 7, 7];
            let t = UF::from_slice(&T_VALS[..]);
            t.assert_invariants();
            const U_VALS: [u16; 10] = [1, 0, 4, 3, 2, 5, 6, 9, 7, 8];
            let u = t.as_permutation();
            assert_eq!(&U_VALS[..], &u[..]);
        }
    }
}