use std::collections::HashMap;
/// A disjoint-set (union–find) forest over arbitrary hashable elements.
///
/// Combines union by rank with path compression, giving nearly constant
/// (inverse-Ackermann) amortized time per operation.
///
/// Elements must be registered with [`DisjointSet::make_set`] before they take
/// part in [`DisjointSet::union`], [`DisjointSet::find`], or
/// [`DisjointSet::connected`]; those methods treat unknown elements as absent
/// rather than panicking.
#[derive(Debug, Clone)]
pub struct DisjointSet<T> {
    /// Maps every element to its parent; a root is its own parent.
    parent: HashMap<T, T>,
    /// Union-by-rank upper bound on tree height. Only read and updated for
    /// roots; entries of nodes that have become children are never consulted.
    rank: HashMap<T, usize>,
    /// Number of distinct sets currently in the forest.
    num_sets: usize,
}

impl<T: Clone + std::hash::Hash + Eq> DisjointSet<T> {
    /// Creates an empty forest.
    pub fn new() -> Self {
        Self {
            parent: HashMap::new(),
            rank: HashMap::new(),
            num_sets: 0,
        }
    }

    /// Creates an empty forest whose internal maps can hold at least
    /// `capacity` elements before reallocating.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            parent: HashMap::with_capacity(capacity),
            rank: HashMap::with_capacity(capacity),
            num_sets: 0,
        }
    }

    /// Adds `x` as a new singleton set. Does nothing if `x` is already present.
    pub fn make_set(&mut self, x: T) {
        // A single hash lookup decides both "already present?" and where to insert.
        if let std::collections::hash_map::Entry::Vacant(slot) = self.parent.entry(x) {
            let key = slot.key().clone();
            slot.insert(key.clone());
            self.rank.insert(key, 0);
            self.num_sets += 1;
        }
    }

    /// Returns the representative (root) of the set containing `x`, or `None`
    /// if `x` was never added.
    ///
    /// Every node on the path from `x` to the root is re-pointed directly at
    /// the root (path compression), which is why this takes `&mut self`.
    pub fn find(&mut self, x: &T) -> Option<T> {
        // First pass: follow parent links from `x` up to the root.
        let mut root = self.parent.get(x)?.clone();
        loop {
            let parent = &self.parent[&root];
            if *parent == root {
                break;
            }
            root = parent.clone();
        }

        // Second pass: re-point each node on the path at the root. `insert`
        // returns the node's previous parent, which is the next node to visit.
        let mut node = x.clone();
        while node != root {
            node = self
                .parent
                .insert(node, root.clone())
                .expect("every node on a find path is in the parent map");
        }
        Some(root)
    }

    /// Merges the sets containing `x` and `y`.
    ///
    /// Returns `true` if two distinct sets were merged, and `false` if `x` and
    /// `y` were already in the same set or if either element is absent.
    pub fn union(&mut self, x: T, y: T) -> bool {
        let root_x = match self.find(&x) {
            Some(root) => root,
            None => return false,
        };
        let root_y = match self.find(&y) {
            Some(root) => root,
            None => return false,
        };
        if root_x == root_y {
            return false;
        }

        // Union by rank: hang the lower-ranked tree under the higher-ranked
        // one, so the rank only grows when both ranks are equal.
        let rank_x = self.rank[&root_x];
        let rank_y = self.rank[&root_y];
        match rank_x.cmp(&rank_y) {
            std::cmp::Ordering::Less => {
                self.parent.insert(root_x, root_y);
            }
            std::cmp::Ordering::Greater => {
                self.parent.insert(root_y, root_x);
            }
            std::cmp::Ordering::Equal => {
                self.parent.insert(root_y, root_x.clone());
                self.rank.insert(root_x, rank_x + 1);
            }
        }
        self.num_sets -= 1;
        true
    }

    /// Returns `true` if both elements are present and belong to the same set.
    pub fn connected(&mut self, x: &T, y: &T) -> bool {
        match (self.find(x), self.find(y)) {
            (Some(root_x), Some(root_y)) => root_x == root_y,
            _ => false,
        }
    }

    /// Returns `true` if `x` has been added with [`DisjointSet::make_set`].
    pub fn contains(&self, x: &T) -> bool {
        self.parent.contains_key(x)
    }

    /// Returns the number of distinct sets.
    pub fn num_sets(&self) -> usize {
        self.num_sets
    }

    /// Returns the total number of elements across all sets.
    pub fn size(&self) -> usize {
        self.parent.len()
    }

    /// Returns `true` if no elements have been added.
    pub fn is_empty(&self) -> bool {
        self.parent.is_empty()
    }

    /// Returns every element in the same set as `x` (including `x` itself),
    /// in unspecified order, or `None` if `x` is absent.
    pub fn get_set_members(&mut self, x: &T) -> Option<Vec<T>> {
        let target_root = self.find(x)?;
        // Snapshot the keys first: `find` needs `&mut self`, so the map cannot
        // be iterated while it is being called.
        let elements: Vec<T> = self.parent.keys().cloned().collect();
        let members = elements
            .into_iter()
            .filter(|element| self.find(element).as_ref() == Some(&target_root))
            .collect();
        Some(members)
    }

    /// Groups every element by its set, returning one `Vec` per set. Neither
    /// the sets nor the elements within them are in any particular order.
    pub fn get_all_sets(&mut self) -> Vec<Vec<T>> {
        // Snapshot the keys because `find` needs `&mut self`.
        let elements: Vec<T> = self.parent.keys().cloned().collect();
        // Exactly one map entry is created per set.
        let mut sets_map: HashMap<T, Vec<T>> = HashMap::with_capacity(self.num_sets);
        for element in elements {
            if let Some(root) = self.find(&element) {
                sets_map.entry(root).or_default().push(element);
            }
        }
        sets_map.into_values().collect()
    }

    /// Removes all elements and sets. The allocated capacity of the internal
    /// maps is kept.
    pub fn clear(&mut self) {
        self.parent.clear();
        self.rank.clear();
        self.num_sets = 0;
    }
}

impl<T: Clone + std::hash::Hash + Eq> Default for DisjointSet<T> {
    /// Equivalent to [`DisjointSet::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_operations() {
        let mut ds = DisjointSet::new();
        assert_eq!(ds.size(), 0);
        assert_eq!(ds.num_sets(), 0);
        assert!(ds.is_empty());
        ds.make_set(1);
        ds.make_set(2);
        ds.make_set(3);
        assert_eq!(ds.size(), 3);
        assert_eq!(ds.num_sets(), 3);
        assert!(!ds.is_empty());
        assert!(ds.contains(&1));
        assert!(ds.contains(&2));
        assert!(ds.contains(&3));
        assert!(!ds.contains(&4));
    }

    #[test]
    fn test_union_find() {
        let mut ds = DisjointSet::new();
        ds.make_set(1);
        ds.make_set(2);
        ds.make_set(3);
        ds.make_set(4);
        assert!(!ds.connected(&1, &2));
        assert!(!ds.connected(&3, &4));
        assert!(ds.union(1, 2));
        assert_eq!(ds.num_sets(), 3);
        assert!(ds.connected(&1, &2));
        assert!(!ds.connected(&1, &3));
        assert!(ds.union(3, 4));
        assert_eq!(ds.num_sets(), 2);
        assert!(ds.connected(&3, &4));
        assert!(!ds.connected(&1, &3));
        assert!(ds.union(1, 3));
        assert_eq!(ds.num_sets(), 1);
        assert!(ds.connected(&1, &3));
        assert!(ds.connected(&2, &4));
        assert!(!ds.union(1, 2));
        assert_eq!(ds.num_sets(), 1);
    }

    #[test]
    fn test_path_compression() {
        let mut ds = DisjointSet::new();
        ds.make_set(1);
        ds.make_set(2);
        ds.make_set(3);
        ds.make_set(4);
        ds.union(1, 2);
        ds.union(2, 3);
        ds.union(3, 4);
        let root1 = ds.find(&1).expect("Operation failed");
        let root2 = ds.find(&2).expect("Operation failed");
        let root3 = ds.find(&3).expect("Operation failed");
        let root4 = ds.find(&4).expect("Operation failed");
        assert_eq!(root1, root2);
        assert_eq!(root2, root3);
        assert_eq!(root3, root4);
    }

    #[test]
    fn test_find_compresses_path() {
        // Union by rank keeps trees shallow, so build a deep chain by hand
        // (4 -> 3 -> 2 -> 1) to check that `find` actually compresses it.
        let mut ds = DisjointSet::new();
        for i in 1..=4 {
            ds.make_set(i);
        }
        ds.parent.insert(2, 1);
        ds.parent.insert(3, 2);
        ds.parent.insert(4, 3);
        assert_eq!(ds.find(&4), Some(1));
        assert_eq!(ds.parent[&4], 1);
        assert_eq!(ds.parent[&3], 1);
        assert_eq!(ds.parent[&2], 1);
        assert_eq!(ds.parent[&1], 1);
    }

    #[test]
    fn test_get_set_members() {
        let mut ds = DisjointSet::new();
        ds.make_set(1);
        ds.make_set(2);
        ds.make_set(3);
        ds.make_set(4);
        ds.make_set(5);
        ds.union(1, 2);
        ds.union(3, 4);
        let members1 = ds.get_set_members(&1).expect("Operation failed");
        assert_eq!(members1.len(), 2);
        assert!(members1.contains(&1));
        assert!(members1.contains(&2));
        let members3 = ds.get_set_members(&3).expect("Operation failed");
        assert_eq!(members3.len(), 2);
        assert!(members3.contains(&3));
        assert!(members3.contains(&4));
        // A singleton set contains only the element itself.
        assert_eq!(ds.get_set_members(&5), Some(vec![5]));
        assert!(ds.get_set_members(&6).is_none());
    }

    #[test]
    fn test_get_all_sets() {
        let mut ds = DisjointSet::new();
        ds.make_set(1);
        ds.make_set(2);
        ds.make_set(3);
        ds.make_set(4);
        ds.make_set(5);
        ds.union(1, 2);
        ds.union(3, 4);
        let all_sets = ds.get_all_sets();
        assert_eq!(all_sets.len(), 3);
        let mut set_sizes: Vec<usize> = all_sets.iter().map(|s| s.len()).collect();
        set_sizes.sort();
        assert_eq!(set_sizes, vec![1, 2, 2]);
    }

    #[test]
    fn test_edge_cases() {
        let mut ds = DisjointSet::new();
        assert!(!ds.union(1, 2));
        assert!(ds.find(&1).is_none());
        assert!(!ds.connected(&1, &2));
        ds.make_set(1);
        ds.make_set(1);
        assert_eq!(ds.size(), 1);
        assert_eq!(ds.num_sets(), 1);
        // Joining an element with itself is a no-op.
        assert!(!ds.union(1, 1));
        assert_eq!(ds.num_sets(), 1);
        assert!(ds.connected(&1, &1));
    }

    #[test]
    fn test_union_with_missing_element() {
        let mut ds = DisjointSet::new();
        ds.make_set(1);
        assert!(!ds.union(1, 2));
        assert!(!ds.union(2, 1));
        assert_eq!(ds.num_sets(), 1);
        assert_eq!(ds.size(), 1);
        assert!(!ds.contains(&2));
    }

    #[test]
    fn test_clear() {
        let mut ds = DisjointSet::new();
        ds.make_set(1);
        ds.make_set(2);
        ds.union(1, 2);
        assert_eq!(ds.size(), 2);
        assert_eq!(ds.num_sets(), 1);
        ds.clear();
        assert_eq!(ds.size(), 0);
        assert_eq!(ds.num_sets(), 0);
        assert!(ds.is_empty());
    }

    #[test]
    fn test_with_strings() {
        let mut ds = DisjointSet::new();
        ds.make_set("alice".to_string());
        ds.make_set("bob".to_string());
        ds.make_set("charlie".to_string());
        ds.union("alice".to_string(), "bob".to_string());
        assert!(ds.connected(&"alice".to_string(), &"bob".to_string()));
        assert!(!ds.connected(&"alice".to_string(), &"charlie".to_string()));
    }

    #[test]
    fn test_large_dataset() {
        let mut ds = DisjointSet::with_capacity(1000);
        for i in 0..1000 {
            ds.make_set(i);
        }
        assert_eq!(ds.size(), 1000);
        assert_eq!(ds.num_sets(), 1000);
        for i in (0..1000).step_by(2) {
            ds.union(i, i + 1);
        }
        assert_eq!(ds.num_sets(), 500);
        assert!(ds.connected(&0, &1));
        assert!(ds.connected(&998, &999));
        assert!(!ds.connected(&0, &2));
    }
}