disjoint_hash_set/
lib.rs

//! A disjoint set / union-find data structure suitable for incremental
//! tracking of connected components identified by their hash.
3//!
4//! The total number of components does not need to be known in advance.
5//! Connections between components and the components themselves can be added
6//! as they are discovered.
7//!
8//! Employs rank-based set joins and path compression resulting in the
9//! asymptotically optimal time complexity associated with union-find
10//! algorithms.
11//!
12//! ## Examples
13//! ```
14//! use disjoint_hash_set::DisjointHashSet;
15//!
16//! let mut djhs = DisjointHashSet::new();
17//! djhs.link("hello", "hi");
18//! djhs.link("hello", "👋");
19//! assert!(djhs.is_linked("hi", "👋"));
20//!
21//! // `DisjointHashSet` can be built from an iterator of edges
22//! let djhs = vec![("a", "b"), ("a", "c"), ("d", "e"), ("f", "f")]
23//!     .into_iter()
24//!     .collect::<DisjointHashSet<_>>();
25//!
26//! // Consume djhs to iterate over each disjoint set
27//! let sets = djhs.sets(); // looks like [{"a", "b", "c"}, {"d", "e"}, {"f"}]
28//! assert_eq!(sets.count(), 3);
29//! ```
30
31use std::{
32    borrow::Borrow,
33    collections::{HashMap, HashSet},
34    hash::Hash,
35};
36
/// A disjoint-set (union-find) forest whose members are identified by
/// hashable values.
///
/// Values are added explicitly with `insert` or implicitly by `link`. Two
/// values belong to the same set once they have been linked, directly or
/// through a chain of other links.
#[derive(Debug, Clone)]
pub struct DisjointHashSet<K> {
    /// Maps every inserted value to the index of its node in `data`.
    ids: HashMap<K, PointerId>,
    /// One forest node per inserted value, indexed by `PointerId`.
    data: Vec<ParentPointer>,
}

impl<K> Default for DisjointHashSet<K> {
    /// Creates an empty `DisjointHashSet`.
    fn default() -> Self {
        Self { ids: HashMap::new(), data: Vec::new() }
    }
}

impl<K: Eq + Hash> DisjointHashSet<K> {
    /// Creates an empty `DisjointHashSet`.
    ///
    /// # Example
    /// ```
    /// use disjoint_hash_set::DisjointHashSet;
    ///
    /// let djhs: DisjointHashSet<&str> = DisjointHashSet::new();
    /// assert!(!djhs.contains("a"));
    /// ```
    pub fn new() -> Self {
        Self::default()
    }

    /// Check if the value has already been inserted.
    ///
    /// # Example
    /// ```
    /// use disjoint_hash_set::DisjointHashSet;
    ///
    /// let mut djhs = DisjointHashSet::new();
    /// assert!(!djhs.contains(&"a"));
    /// djhs.insert(&"a");
    /// assert!(djhs.contains(&"a"));
    /// ```
    pub fn contains<T: Borrow<K>>(&self, val: T) -> bool {
        self.id(val.borrow()).is_some()
    }

    /// Insert the value as a new disjoint set with a single member. Returns
    /// true if the value was not already present.
    ///
    /// # Example
    /// ```
    /// use disjoint_hash_set::DisjointHashSet;
    ///
    /// let mut djhs = DisjointHashSet::new();
    /// assert!(djhs.insert(&"a"));
    /// assert!(!djhs.insert(&"a"));
    /// ```
    pub fn insert(&mut self, val: K) -> bool {
        let nodes_before = self.data.len();
        self.id_or_insert(val);
        // `id_or_insert` appends exactly one node when the value is new and
        // none when it is already known.
        self.data.len() > nodes_before
    }

    /// Checks if the two keys are members of the same set.
    /// This will not implicitly add values that were not already present;
    /// a value that was never inserted is linked to nothing.
    ///
    /// Takes `&mut self` because looking up a set's root compresses the
    /// path that was walked.
    ///
    /// # Example
    /// ```
    /// use disjoint_hash_set::DisjointHashSet;
    ///
    /// let mut djhs = DisjointHashSet::new();
    /// djhs.link("a", "b");
    /// djhs.link("a", "c");
    /// assert!(djhs.is_linked("b", "c"));
    /// assert!(!djhs.is_linked("a", "d"));
    /// ```
    pub fn is_linked<T: Borrow<K>>(&mut self, val1: T, val2: T) -> bool {
        match (self.id(val1.borrow()), self.id(val2.borrow())) {
            (Some(id1), Some(id2)) => self.find(id1) == self.find(id2),
            _ => false,
        }
    }

    /// Link the respective sets of the two provided values. This will insert
    /// non-existent values in the process.
    ///
    /// # Example
    /// ```
    /// use disjoint_hash_set::DisjointHashSet;
    ///
    /// let mut djhs = DisjointHashSet::new();
    ///
    /// djhs.link("a", "b");
    /// assert!(djhs.contains("a"));
    /// assert!(djhs.contains("b"));
    /// assert!(djhs.is_linked("a", "b"));
    /// ```
    pub fn link(&mut self, val1: K, val2: K) {
        let id1 = self.id_or_insert(val1);
        let id2 = self.id_or_insert(val2);
        let (root1, root2) = (self.find(id1), self.find(id2));

        if root1 == root2 {
            // Already members of the same set.
            return;
        }

        // Union by rank: the root with the lower rank is attached below the
        // other one; on a tie the surviving root's rank grows by one.
        let (rank1, rank2) = (self.get(root1).rank, self.get(root2).rank);
        if rank1 < rank2 {
            self.get_mut(root1).parent = root2;
        } else {
            self.get_mut(root2).parent = root1;
            if rank1 == rank2 {
                self.get_mut(root1).rank += 1;
            }
        }
    }

    /// Consumes the DisjointHashSet and returns an iterator of HashSets for
    /// each disjoint set. The order of the yielded sets is unspecified.
    ///
    /// # Example
    /// ```
    /// use disjoint_hash_set::DisjointHashSet;
    /// use std::{collections::HashSet, iter::FromIterator};
    ///
    /// let edges = vec![("a", "a"), ("b", "c"), ("d", "e"), ("e", "f")];
    /// let mut sets = DisjointHashSet::from_iter(edges).sets().collect::<Vec<_>>();
    /// sets.sort_by(|set_a, set_b| set_a.len().cmp(&set_b.len()));
    ///
    /// let expected_sets: Vec<HashSet<&str>> = vec![
    ///     HashSet::from_iter(vec!["a"]),
    ///     HashSet::from_iter(vec!["b", "c"]),
    ///     HashSet::from_iter(vec!["d", "e", "f"]),
    /// ];
    ///
    /// assert_eq!(sets, expected_sets);
    /// ```
    pub fn sets(mut self) -> impl Iterator<Item = HashSet<K>> {
        // Resolve every node's root first: `find` needs the whole of `self`,
        // which is no longer available once `self.ids` is moved out below.
        let roots: Vec<PointerId> =
            (0..self.data.len()).map(|index| self.find(PointerId(index))).collect();

        let mut sets: HashMap<PointerId, HashSet<K>> = HashMap::new();
        for (val, id) in self.ids {
            sets.entry(roots[id.0]).or_default().insert(val);
        }

        sets.into_iter().map(|(_, set)| set)
    }

    /// Returns the root of the tree containing `id` and re-parents every node
    /// visited on the way directly onto that root (path compression).
    fn find(&mut self, id: PointerId) -> PointerId {
        // First pass: climb until reaching the node that is its own parent.
        let mut root = id;
        loop {
            let parent = self.get(root).parent;
            if parent == root {
                break;
            }
            root = parent;
        }

        // Second pass: point every node on the walked path at the root.
        let mut node = id;
        while node != root {
            let next = self.get(node).parent;
            self.get_mut(node).parent = root;
            node = next;
        }

        root
    }

    /// Looks up the node id assigned to `value`, if it has been inserted.
    fn id(&self, value: &K) -> Option<PointerId> {
        self.ids.get(value).copied()
    }

    /// Returns the node id of `value`, first registering it as a new
    /// single-member set if it is not known yet.
    fn id_or_insert(&mut self, value: K) -> PointerId {
        // Borrow the node storage separately so the entry closure can append
        // to it while `self.ids` is mutably borrowed; the key is hashed once.
        let data = &mut self.data;
        *self.ids.entry(value).or_insert_with(|| {
            let id = PointerId(data.len());
            // A fresh node is its own parent, i.e. the root of its own tree.
            data.push(ParentPointer { parent: id, rank: 0 });
            id
        })
    }

    /// Shared access to the node stored for `id`.
    fn get(&self, id: PointerId) -> &ParentPointer {
        &self.data[id.0]
    }

    /// Exclusive access to the node stored for `id`.
    fn get_mut(&mut self, id: PointerId) -> &mut ParentPointer {
        &mut self.data[id.0]
    }
}

/// A node of the union-find forest.
#[derive(Debug, Clone)]
struct ParentPointer {
    /// Id of this node's parent; a root node is its own parent.
    parent: PointerId,
    /// Rank compared when two roots are linked; incremented on the surviving
    /// root when both roots had equal rank.
    rank: u8,
}

/// Index of a node inside `DisjointHashSet::data`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct PointerId(usize);
211
212impl<V: Eq + Hash> FromIterator<(V, V)> for DisjointHashSet<V> {
213    fn from_iter<I: IntoIterator<Item = (V, V)>>(links: I) -> Self {
214        let mut djhs = DisjointHashSet::new();
215        links.into_iter().for_each(|(a, b)| djhs.link(a, b));
216        djhs
217    }
218}