disjoint_sets/
async.rs

1use std::fmt::{self, Debug};
2use std::marker::{Send, Sync};
3use std::sync::atomic::{AtomicUsize, Ordering};
4
5#[cfg(feature = "serde")]
6use serde::{Serialize, Serializer, Deserialize, Deserializer};
7
8/// Lock-free, concurrent union-find representing a set of disjoint sets.
9///
10/// # Warning
11///
12/// I don’t yet have good reason to believe that this is correct.
13#[derive(Clone)]
14#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
15pub struct AUnionFind(Box<[Entry]>);
16
/// The record kept for one element: its parent link and its rank.
struct Entry {
    /// Index of this element's parent. An element whose `id` equals its own
    /// index is the representative of its set.
    id: AtomicUsize,
    /// Rank consulted by `union` when choosing which root to link under the
    /// other.
    rank: AtomicUsize,
}
21
22unsafe impl Send for AUnionFind {}
23unsafe impl Sync for AUnionFind {}
24
25impl Clone for Entry {
26    fn clone(&self) -> Self {
27        Entry::new(self.id.load(Ordering::SeqCst),
28                   self.rank.load(Ordering::SeqCst))
29    }
30}
31
32impl Debug for AUnionFind {
33    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
34        write!(formatter, "AUnionFind(")?;
35        formatter.debug_list()
36            .entries(self.0.iter().map(|entry| &entry.id)).finish()?;
37        write!(formatter, ")")
38    }
39}
40
41impl Default for AUnionFind {
42    fn default() -> Self {
43        AUnionFind::new(0)
44    }
45}
46
47impl Entry {
48    fn new(id: usize, rank: usize) -> Self {
49        Entry {
50            id:   AtomicUsize::new(id),
51            rank: AtomicUsize::new(rank),
52        }
53    }
54}
55
56impl AUnionFind {
57    /// Creates a new asynchronous union-find of `size` elements.
58    pub fn new(size: usize) -> Self {
59        AUnionFind((0..size)
60            .map(|i| Entry::new(i, 0))
61            .collect::<Vec<_>>()
62            .into_boxed_slice())
63    }
64
65    /// The number of elements in all the sets.
66    pub fn len(&self) -> usize {
67        self.0.len()
68    }
69
70    /// Is the union-find devoid of elements?
71    ///
72    /// It is possible to create an empty `AUnionFind`, but unlike with
73    /// [`UnionFind`](struct.UnionFind.html) it is not possible to add
74    /// elements.
75    pub fn is_empty(&self) -> bool {
76        self.0.is_empty()
77    }
78
79    /// Joins the sets of the two given elements.
80    ///
81    /// Returns whether anything changed. That is, if the sets were
82    /// different, it returns `true`, but if they were already the same
83    /// then it returns `false`.
84    pub fn union(&self, mut a: usize, mut b: usize) -> bool {
85        loop {
86            a = self.find(a);
87            b = self.find(b);
88
89            if a == b { return false; }
90
91            let rank_a = self.rank(a);
92            let rank_b = self.rank(b);
93
94            if rank_a > rank_b {
95                if self.change_parent(b, b, a) { return true; }
96            } else if rank_b > rank_a {
97                if self.change_parent(a, a, b) { return true; }
98            } else if self.change_parent(a, a, b) {
99                self.increment_rank(b);
100                return true;
101            }
102        }
103    }
104
105    /// Finds the representative element for the given element’s set.
106    pub fn find(&self, mut element: usize) -> usize {
107        let mut parent = self.parent(element);
108
109        while element != parent {
110            let grandparent = self.parent(parent);
111            self.change_parent(element, parent, grandparent);
112            element = parent;
113            parent = grandparent;
114        }
115
116        element
117    }
118
119    /// Determines whether two elements are in the same set.
120    pub fn equiv(&self, mut a: usize, mut b: usize) -> bool {
121        loop {
122            a = self.find(a);
123            b = self.find(b);
124
125            if a == b { return true; }
126            if self.parent(a) == a { return false; }
127        }
128    }
129
130    /// Forces all laziness, so that each element points directly to its
131    /// set’s representative.
132    pub fn force(&self) {
133        for i in 0 .. self.len() {
134            loop {
135                let parent = self.parent(i);
136                if i == parent {
137                    break
138                } else {
139                    let root = self.find(parent);
140                    if parent == root || self.change_parent(i, parent, root) {
141                        break;
142                    }
143                }
144            }
145        }
146    }
147
148    /// Returns a vector of set representatives.
149    pub fn to_vec(&self) -> Vec<usize> {
150        self.force();
151        self.0.iter().map(|entry| entry.id.load(Ordering::SeqCst)).collect()
152    }
153
154    // HELPERS
155
156    fn rank(&self, element: usize) -> usize {
157        self.0[element].rank.load(Ordering::SeqCst)
158    }
159
160    fn increment_rank(&self, element: usize) {
161        self.0[element].rank.fetch_add(1, Ordering::SeqCst);
162    }
163
164    fn parent(&self, element: usize) -> usize {
165        self.0[element].id.load(Ordering::SeqCst)
166    }
167
168    fn change_parent(&self,
169                     element: usize,
170                     old_parent: usize,
171                     new_parent: usize)
172                     -> bool {
173        self.0[element].id.compare_and_swap(old_parent,
174                                            new_parent,
175                                            Ordering::SeqCst)
176            == old_parent
177    }
178}
179
#[cfg(feature = "serde")]
impl Serialize for Entry {
    /// Writes the entry as a two-field struct named `"Entry"`: first `"id"`,
    /// then `"rank"`, each holding the value stored at the moment it is
    /// read.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        use serde::ser::SerializeStruct;

        let mut state = serializer.serialize_struct("Entry", 2)?;

        let id = self.id.load(Ordering::Relaxed);
        state.serialize_field("id", &id)?;

        let rank = self.rank.load(Ordering::Relaxed);
        state.serialize_field("rank", &rank)?;

        state.end()
    }
}
193
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for Entry {
    /// Reads an entry from a struct named `"Entry"` with the fields `"id"`
    /// and `"rank"`.
    ///
    /// Two encodings are accepted: a sequence (`id` first, then `rank`) and
    /// a map keyed by field name. In the map form a repeated or missing
    /// field is an error.
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        use serde::de::{self, Visitor, SeqAccess, MapAccess};

        const FIELDS: &'static [&'static str] = &["id", "rank"];

        // Key of a single field in the map encoding.
        #[derive(Deserialize)]
        #[serde(field_identifier, rename_all = "lowercase")]
        enum Field { Id, Rank, }

        struct EntryVisitor;

        impl<'de> Visitor<'de> for EntryVisitor {
            type Value = Entry;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("struct Entry")
            }

            // Sequence encoding: element 0 is the id, element 1 the rank.
            fn visit_seq<V: SeqAccess<'de>>(self, mut seq: V) -> Result<Self::Value, V::Error> {
                let id = match seq.next_element()? {
                    Some(value) => value,
                    None => return Err(de::Error::invalid_length(0, &self)),
                };
                let rank = match seq.next_element()? {
                    Some(value) => value,
                    None => return Err(de::Error::invalid_length(1, &self)),
                };

                Ok(Entry::new(id, rank))
            }

            // Map encoding: fields may arrive in either order.
            fn visit_map<V: MapAccess<'de>>(self, mut map: V) -> Result<Self::Value, V::Error> {
                let mut id = None;
                let mut rank = None;

                while let Some(key) = map.next_key()? {
                    match key {
                        Field::Id if id.is_some() => {
                            return Err(de::Error::duplicate_field("id"));
                        }
                        Field::Id => {
                            id = Some(map.next_value()?);
                        }
                        Field::Rank if rank.is_some() => {
                            return Err(de::Error::duplicate_field("rank"));
                        }
                        Field::Rank => {
                            rank = Some(map.next_value()?);
                        }
                    }
                }

                // A missing `id` is reported before a missing `rank`.
                match (id, rank) {
                    (Some(id), Some(rank)) => Ok(Entry::new(id, rank)),
                    (None, _) => Err(de::Error::missing_field("id")),
                    (_, None) => Err(de::Error::missing_field("rank")),
                }
            }
        }

        deserializer.deserialize_struct("Entry", FIELDS, EntryVisitor)
    }
}
252
#[cfg(test)]
mod tests {
    use super::*;

    // A new structure reports the size it was created with.
    #[test]
    fn len() {
        let uf = AUnionFind::new(5);
        assert_eq!(5, uf.len());
    }

    // Two elements become equivalent only once they have been unioned.
    #[test]
    fn union() {
        let uf = AUnionFind::new(8);
        let (a, b) = (0, 1);

        assert!(!uf.equiv(a, b));
        uf.union(a, b);
        assert!(uf.equiv(a, b));
    }

    // A longer sequence of unions, checking both the return values and the
    // resulting equivalences.
    #[test]
    fn unions() {
        let uf = AUnionFind::new(8);

        for &(a, b) in &[(0, 1), (1, 2), (4, 3), (3, 2)] {
            assert!(uf.union(a, b));
        }
        assert!(! uf.union(0, 3));

        for other in 1..5 {
            assert!(uf.equiv(0, other));
        }
        assert!(!uf.equiv(0, 5));

        assert!(uf.union(5, 3));
        assert!(uf.equiv(0, 5));

        assert!(uf.union(6, 7));
        assert!(uf.equiv(6, 7));
        assert!(!uf.equiv(5, 7));

        assert!(uf.union(0, 7));
        assert!(uf.equiv(5, 7));
    }

    // `union` returns `false` once the two elements already share a set.
    #[test]
    fn changed() {
        let uf = AUnionFind::new(8);

        for &(a, b) in &[(2, 3), (0, 1), (1, 3)] {
            assert!(uf.union(a, b));
        }
        assert!(!uf.union(0, 2))
    }

    // This assumes that for equal-ranked roots, the first argument
    // to union is pointed to the second.
    #[test]
    fn to_vec() {
        let uf = AUnionFind::new(6);
        assert_eq!(uf.to_vec(), vec![0, 1, 2, 3, 4, 5]);

        let steps: Vec<((usize, usize), Vec<usize>)> = vec![
            ((0, 1), vec![1, 1, 2, 3, 4, 5]),
            ((2, 3), vec![1, 1, 3, 3, 4, 5]),
            ((1, 3), vec![3, 3, 3, 3, 4, 5]),
        ];

        for ((a, b), expected) in steps {
            uf.union(a, b);
            assert_eq!(uf.to_vec(), expected);
        }
    }

    // Serializing to JSON and back preserves the set structure.
    #[cfg(feature = "serde")]
    #[test]
    fn serde_round_trip() {
        extern crate serde_json;

        fn check_sets(uf: &AUnionFind) {
            assert!( uf.equiv(0, 1));
            assert!(!uf.equiv(1, 2));
            assert!( uf.equiv(2, 3));
        }

        let uf0 = AUnionFind::new(8);
        uf0.union(0, 1);
        uf0.union(2, 3);
        check_sets(&uf0);

        let json = serde_json::to_string(&uf0).unwrap();
        let uf1: AUnionFind = serde_json::from_str(&json).unwrap();
        check_sets(&uf1);
    }
}