pulsar_utils/
disjoint_set.rs

1// Copyright (C) 2024 Ethan Uppal. All rights reserved.
2use std::{collections::HashMap, fmt::Debug, hash::Hash, iter::Map};
3
/// Marker trait for values that can be used as nodes of [`DisjointSets`].
///
/// It declares no methods of its own: it only bundles the `Clone`, `Eq` and
/// `Hash` supertraits, so a type opts in with an empty `impl` block.
pub trait NodeTrait: Clone + Eq + Hash {}
5
/// Per-node bookkeeping: for a node `x` with data `(p, r)`, `p` is `x`'s
/// parent and `r` is `x`'s rank.
#[derive(Clone)]
pub struct NodeData<T> {
    /// The node that `x` currently points to.
    parent: T,
    /// The rank recorded for `x`.
    rank: usize
}
13
14/// A collection of disjoint sets over cheaply-cloned objects.
15pub struct DisjointSets<T: NodeTrait> {
16    nodes: HashMap<T, NodeData<T>>
17}
18
19impl<T: NodeTrait> DisjointSets<T> {
20    /// Constructs an empty disjoint set.
21    pub fn new() -> Self {
22        Self {
23            nodes: HashMap::new()
24        }
25    }
26
27    /// Adds a disjoint singleton `{v}` if `v` has not already been added.`
28    pub fn add(&mut self, v: T) {
29        if self.nodes.contains_key(&v) {
30            return;
31        }
32        let node_data = NodeData {
33            parent: v.clone(),
34            rank: 0
35        };
36        self.nodes.insert(v, node_data);
37    }
38
39    /// Finds the canonical representative of the set to which `v` belongs, if
40    /// `v` in fact has been added via a call to [`DisjointSets::add`].
41    pub fn find(&mut self, v: T) -> Option<T> {
42        let p = self.nodes.get(&v)?.parent.clone();
43        if v == p {
44            return Some(v);
45        }
46        let root = self.find(p)?;
47        self.nodes.get_mut(&v)?.parent = root.clone();
48        Some(root)
49    }
50
51    /// Merges the sets to which `a` and `b` belong to, returning their new
52    /// canonical representative. If `by_rank` is `true`, the union-by-rank
53    /// optimization is used, acheiving near-linear time complexity.
54    /// Otherwise, the canonical representative of `b` is chosen as the new
55    /// canonical representative, which leads to log-linear complexity.
56    pub fn union(&mut self, a: T, b: T, by_rank: bool) -> Option<T> {
57        let a = self.find(a)?;
58        let b = self.find(b)?;
59        if a != b {
60            if by_rank {
61                // Union-by-rank
62                let rank_a = self.nodes.get(&a)?.rank;
63                let rank_b = self.nodes.get(&b)?.rank;
64                if rank_a > rank_b {
65                    self.nodes.get_mut(&b)?.parent = a.clone();
66                } else {
67                    self.nodes.get_mut(&a)?.parent = b.clone();
68                    if rank_a == rank_b {
69                        self.nodes.get_mut(&b)?.rank += 1;
70                    }
71                }
72            } else {
73                // Use `b` as new parent
74                self.nodes.get_mut(&a)?.parent = b.clone();
75            }
76        }
77        Some(a)
78    }
79
80    /// Optimizes `find` and `union` access for all nodes.
81    pub fn collapse(&mut self) {
82        let keys = self.nodes.keys().cloned().collect::<Vec<_>>();
83        for key in keys {
84            self.find(key.clone());
85        }
86    }
87}
88
89impl<T: NodeTrait> Debug for DisjointSets<T>
90where
91    T: Debug
92{
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        for (i, (node, data)) in self.nodes.iter().enumerate() {
95            if i > 0 {
96                writeln!(f)?;
97            }
98            write!(f, "{:?} -> {:?}", node, data.parent)?;
99        }
100        Ok(())
101    }
102}
103
104impl<'a, T: NodeTrait> IntoIterator for &'a DisjointSets<T> {
105    type Item = (&'a T, &'a T);
106    type IntoIter = Map<
107        std::collections::hash_map::Iter<'a, T, NodeData<T>>,
108        fn((&'a T, &'a NodeData<T>)) -> (&'a T, &'a T)
109    >;
110
111    fn into_iter(self) -> Self::IntoIter {
112        self.nodes.iter().map(|(node, data)| (node, &data.parent))
113    }
114}