use std::collections::HashMap;
use std::hash::Hash;
/// A disjoint-set forest (union–find) over hashable, ordered keys.
///
/// Uses union by rank and path compression. Every tracked item has an entry in
/// `parent` (a root is its own parent) and in `rank`.
#[derive(Debug, Clone)]
pub struct UnionFind<T: Eq + Hash + Clone + Ord> {
    /// Parent pointer for every tracked item; a root maps to itself.
    parent: HashMap<T, T>,
    /// Union-by-rank counter per item; only incremented on roots during `union`.
    rank: HashMap<T, usize>,
}

impl<T: Eq + Hash + Clone + Ord> Default for UnionFind<T> {
    /// Equivalent to [`UnionFind::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl<T: Eq + Hash + Clone + Ord> UnionFind<T> {
    /// Creates an empty structure with no tracked items.
    pub fn new() -> Self {
        Self {
            parent: HashMap::new(),
            rank: HashMap::new(),
        }
    }

    /// Creates an empty structure whose internal maps can hold at least
    /// `capacity` items without reallocating.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            parent: HashMap::with_capacity(capacity),
            rank: HashMap::with_capacity(capacity),
        }
    }

    /// Adds `item` as a new singleton set with rank 0.
    ///
    /// If `item` is already tracked this does nothing, so its existing set
    /// membership is preserved.
    pub fn make_set(&mut self, item: T) {
        if self.parent.contains_key(&item) {
            return;
        }
        self.parent.insert(item.clone(), item.clone());
        self.rank.insert(item, 0);
    }

    /// Returns the representative (root) of the set containing `item`.
    ///
    /// An item that is not yet tracked is first added as a singleton, making it
    /// its own representative. Every node on the path from `item` to the root is
    /// re-pointed directly at the root (path compression).
    pub fn find(&mut self, item: &T) -> T {
        // `.cloned()` drops the borrow of `self.parent` before `make_set` needs `&mut self`.
        let mut root = match self.parent.get(item).cloned() {
            Some(parent) => parent,
            None => {
                self.make_set(item.clone());
                return item.clone();
            }
        };

        // Pass 1: follow parent pointers until reaching a self-parented node.
        loop {
            let next = self
                .parent
                .get(&root)
                .expect("every parent pointer refers to a tracked item");
            if *next == root {
                break;
            }
            root = next.clone();
        }

        // Pass 2: re-point each node on the path at the root. `insert` hands back
        // the old parent, which is the next node to rewrite.
        let mut node = item.clone();
        while node != root {
            node = self
                .parent
                .insert(node, root.clone())
                .expect("every node on the path to the root is tracked");
        }
        root
    }

    /// Merges the sets containing `a` and `b` (tracking either if unknown).
    ///
    /// Returns `true` if two distinct sets were merged, or `false` if `a` and `b`
    /// were already in the same set. The root with the lower rank is attached
    /// beneath the other. On a tie, `a`'s root becomes the new root and its rank
    /// is incremented.
    pub fn union(&mut self, a: &T, b: &T) -> bool {
        let root_a = self.find(a);
        let root_b = self.find(b);
        if root_a == root_b {
            return false;
        }
        let rank_a = self.rank.get(&root_a).copied().unwrap_or(0);
        let rank_b = self.rank.get(&root_b).copied().unwrap_or(0);
        match rank_a.cmp(&rank_b) {
            std::cmp::Ordering::Less => {
                self.parent.insert(root_a, root_b);
            }
            std::cmp::Ordering::Greater => {
                self.parent.insert(root_b, root_a);
            }
            std::cmp::Ordering::Equal => {
                self.parent.insert(root_b, root_a.clone());
                self.rank.insert(root_a, rank_a + 1);
            }
        }
        true
    }

    /// Returns `true` if `a` and `b` are in the same set.
    ///
    /// Like [`find`](Self::find), this starts tracking either item if it is unknown.
    pub fn connected(&mut self, a: &T, b: &T) -> bool {
        self.find(a) == self.find(b)
    }

    /// Partitions every tracked item by its representative.
    ///
    /// Returns a map from each root to all members of its set (the root
    /// included), each member list sorted in ascending order. Calling `find` on
    /// every item also compresses all paths as a side effect.
    pub fn groups(&mut self) -> HashMap<T, Vec<T>> {
        let mut items: Vec<T> = self.parent.keys().cloned().collect();
        // Map keys are unique, so an unstable sort yields the same order as a stable one.
        items.sort_unstable();
        let mut groups: HashMap<T, Vec<T>> = HashMap::new();
        for item in items {
            let root = self.find(&item);
            // Items arrive in ascending order, so each member list is built already sorted.
            groups.entry(root).or_default().push(item);
        }
        groups
    }

    /// Returns the number of tracked items (not the number of disjoint sets).
    pub fn len(&self) -> usize {
        self.parent.len()
    }

    /// Returns `true` if no items are tracked.
    pub fn is_empty(&self) -> bool {
        self.parent.is_empty()
    }
}
/// Convenience wrappers for a union–find keyed by owned `String` ids, so
/// callers can pass plain `&str` values.
pub mod string_uf {
    use super::*;

    /// Builds a structure in which each id from `ids` starts as its own
    /// singleton set. Repeated ids are only added once.
    pub fn from_ids<'a, I>(ids: I) -> UnionFind<String>
    where
        I: IntoIterator<Item = &'a str>,
    {
        let mut forest = UnionFind::new();
        ids.into_iter()
            .map(str::to_owned)
            .for_each(|owned| forest.make_set(owned));
        forest
    }

    /// Returns the representative id of the set containing `id`, tracking
    /// `id` as a singleton first if it is unknown.
    pub fn find(uf: &mut UnionFind<String>, id: &str) -> String {
        let key = String::from(id);
        uf.find(&key)
    }

    /// Merges the sets containing `a` and `b`; returns `true` only if they
    /// were previously separate.
    pub fn union(uf: &mut UnionFind<String>, a: &str, b: &str) -> bool {
        let left = a.to_owned();
        let right = b.to_owned();
        uf.union(&left, &right)
    }
}