1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
// Copyright (C) 2024 Ethan Uppal. All rights reserved.
use std::{collections::HashMap, fmt::Debug, hash::Hash, iter::Map};
/// Marker trait for values that can be stored as nodes in [`DisjointSets`].
///
/// Nodes are used as hash-map keys, compared for equality to detect roots,
/// and cloned whenever a parent link is read or rewritten, so they should be
/// cheap to clone.
pub trait NodeTrait: Clone + Eq + Hash {}
/// Per-node bookkeeping for [`DisjointSets`].
///
/// For a node `x`, when the node data is `(p, r)`, `x`'s parent is `p` and
/// `x`'s rank is `r`. A node is the canonical representative of its set
/// exactly when it is its own parent.
#[derive(Clone, Debug)]
pub struct NodeData<T> {
    /// The node's parent link; equal to the node itself for a root.
    parent: T,
    /// Rank used by union-by-rank; new nodes start at 0 and only roots have
    /// their rank incremented.
    rank: usize
}
/// A collection of disjoint sets over cheaply-cloned objects.
///
/// Every added node is mapped to its [`NodeData`]. The parent links form a
/// forest whose roots (nodes that are their own parent) are the canonical
/// representatives of the sets.
pub struct DisjointSets<T: NodeTrait> {
    /// All nodes added so far, keyed by the node itself.
    nodes: HashMap<T, NodeData<T>>
}
impl<T: NodeTrait> DisjointSets<T> {
    /// Constructs an empty collection of disjoint sets.
    pub fn new() -> Self {
        Self {
            nodes: HashMap::new()
        }
    }

    /// Adds a disjoint singleton `{v}` if `v` has not already been added.
    ///
    /// Re-adding an existing node is a no-op: its current parent and rank are
    /// left untouched.
    pub fn add(&mut self, v: T) {
        // The entry API hashes `v` once; the closure runs only for new nodes,
        // which start out as their own parent (a singleton root).
        self.nodes.entry(v).or_insert_with_key(|node| NodeData {
            parent: node.clone(),
            rank: 0
        });
    }

    /// Finds the canonical representative of the set to which `v` belongs, if
    /// `v` in fact has been added via a call to [`DisjointSets::add`];
    /// otherwise returns `None`.
    ///
    /// Every node on the path from `v` to the representative is re-pointed
    /// directly at the representative (path compression). The walk is
    /// iterative, so arbitrarily long parent chains (e.g. built by repeated
    /// non-ranked unions) cannot overflow the stack.
    pub fn find(&mut self, v: T) -> Option<T> {
        // Pass 1: follow parent links until reaching a node that is its own
        // parent; that node is the representative.
        let mut root = v.clone();
        loop {
            let parent = &self.nodes.get(&root)?.parent;
            if *parent == root {
                break;
            }
            root = parent.clone();
        }

        // Pass 2: walk the same path again, pointing each node at `root`.
        let mut current = v;
        while current != root {
            let data = self.nodes.get_mut(&current)?;
            current = std::mem::replace(&mut data.parent, root.clone());
        }
        Some(root)
    }

    /// Merges the sets to which `a` and `b` belong, returning the canonical
    /// representative of the merged set, or `None` if either `a` or `b` has
    /// not been added. If they already share a set, that set's representative
    /// is returned and nothing changes.
    ///
    /// If `by_rank` is `true`, the union-by-rank optimization is used: the
    /// root with the larger rank becomes the new representative, and ties go
    /// to `b`'s root, whose rank is then incremented. Together with the path
    /// compression in [`DisjointSets::find`], this achieves near-linear time
    /// complexity. Otherwise, the canonical representative of `b` is always
    /// chosen as the new canonical representative, which leads to log-linear
    /// complexity.
    pub fn union(&mut self, a: T, b: T, by_rank: bool) -> Option<T> {
        let root_a = self.find(a)?;
        let root_b = self.find(b)?;
        if root_a == root_b {
            return Some(root_a);
        }

        // Decide which root is attached beneath the other.
        let (child, parent) = if by_rank {
            let rank_a = self.nodes.get(&root_a)?.rank;
            let rank_b = self.nodes.get(&root_b)?.rank;
            if rank_a > rank_b {
                (root_b, root_a)
            } else {
                if rank_a == rank_b {
                    // Standard union-by-rank: a tie bumps the surviving root.
                    self.nodes.get_mut(&root_b)?.rank += 1;
                }
                (root_a, root_b)
            }
        } else {
            (root_a, root_b)
        };

        self.nodes.get_mut(&child)?.parent = parent.clone();
        // `parent` is the root of the merged set, so it is the value to return.
        Some(parent)
    }

    /// Applies path compression to every node. Afterwards, each node's parent
    /// is its set's canonical representative, so subsequent `find` calls are
    /// a single lookup until the next `union`.
    pub fn collapse(&mut self) {
        // Snapshot the keys first: `find` needs `&mut self`, so we cannot keep
        // an iterator over `self.nodes` alive while calling it.
        let keys: Vec<T> = self.nodes.keys().cloned().collect();
        for key in keys {
            self.find(key);
        }
    }
}

impl<T: NodeTrait> Default for DisjointSets<T> {
    /// Equivalent to [`DisjointSets::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl<T> Debug for DisjointSets<T>
where
    T: NodeTrait + Debug
{
    /// Writes one `node -> parent` line per added node, in the underlying
    /// map's iteration order, with no trailing newline.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for (index, (node, data)) in self.nodes.iter().enumerate() {
            // Separate entries with newlines rather than terminating them.
            if index != 0 {
                writeln!(f)?;
            }
            write!(f, "{:?} -> {:?}", node, data.parent)?;
        }
        Ok(())
    }
}
impl<'a, T: NodeTrait> IntoIterator for &'a DisjointSets<T> {
    type Item = (&'a T, &'a T);
    type IntoIter = Map<
        std::collections::hash_map::Iter<'a, T, NodeData<T>>,
        fn((&'a T, &'a NodeData<T>)) -> (&'a T, &'a T)
    >;

    /// Yields every added node paired with its current parent link (not
    /// necessarily its root), in the underlying map's iteration order.
    fn into_iter(self) -> Self::IntoIter {
        // Bind the projection with an explicit fn-pointer type so it matches
        // the `IntoIter` associated type exactly.
        let node_and_parent: fn((&'a T, &'a NodeData<T>)) -> (&'a T, &'a T) =
            |(node, data)| (node, &data.parent);
        self.nodes.iter().map(node_and_parent)
    }
}