1use std::fmt::{self, Debug};
2use std::marker::{Send, Sync};
3use std::sync::atomic::{AtomicUsize, Ordering};
4
5#[cfg(feature = "serde")]
6use serde::{Serialize, Serializer, Deserialize, Deserializer};
7
/// Union-find ("disjoint sets") structure whose operations all take `&self`,
/// so one instance can be shared between threads.
///
/// The elements are the indices `0 .. len()`. Each element owns one `Entry`
/// slot holding its parent index and a rank, both stored in atomics.
///
/// Cloning copies the slots one at a time; if other threads are mutating the
/// structure during the clone, the copy is not an atomic snapshot.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct AUnionFind(Box<[Entry]>);
16
/// Per-element slot of an `AUnionFind`.
struct Entry {
    /// Index of this element's parent; equal to the element's own index when
    /// the element is the root (representative) of its set.
    id: AtomicUsize,
    /// Union-by-rank value. It is only ever incremented after construction,
    /// and `union` reads it only for elements that were roots.
    rank: AtomicUsize,
}
21
22unsafe impl Send for AUnionFind {}
23unsafe impl Sync for AUnionFind {}
24
25impl Clone for Entry {
26 fn clone(&self) -> Self {
27 Entry::new(self.id.load(Ordering::SeqCst),
28 self.rank.load(Ordering::SeqCst))
29 }
30}
31
32impl Debug for AUnionFind {
33 fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
34 write!(formatter, "AUnionFind(")?;
35 formatter.debug_list()
36 .entries(self.0.iter().map(|entry| &entry.id)).finish()?;
37 write!(formatter, ")")
38 }
39}
40
41impl Default for AUnionFind {
42 fn default() -> Self {
43 AUnionFind::new(0)
44 }
45}
46
47impl Entry {
48 fn new(id: usize, rank: usize) -> Self {
49 Entry {
50 id: AtomicUsize::new(id),
51 rank: AtomicUsize::new(rank),
52 }
53 }
54}
55
56impl AUnionFind {
57 pub fn new(size: usize) -> Self {
59 AUnionFind((0..size)
60 .map(|i| Entry::new(i, 0))
61 .collect::<Vec<_>>()
62 .into_boxed_slice())
63 }
64
65 pub fn len(&self) -> usize {
67 self.0.len()
68 }
69
70 pub fn is_empty(&self) -> bool {
76 self.0.is_empty()
77 }
78
79 pub fn union(&self, mut a: usize, mut b: usize) -> bool {
85 loop {
86 a = self.find(a);
87 b = self.find(b);
88
89 if a == b { return false; }
90
91 let rank_a = self.rank(a);
92 let rank_b = self.rank(b);
93
94 if rank_a > rank_b {
95 if self.change_parent(b, b, a) { return true; }
96 } else if rank_b > rank_a {
97 if self.change_parent(a, a, b) { return true; }
98 } else if self.change_parent(a, a, b) {
99 self.increment_rank(b);
100 return true;
101 }
102 }
103 }
104
105 pub fn find(&self, mut element: usize) -> usize {
107 let mut parent = self.parent(element);
108
109 while element != parent {
110 let grandparent = self.parent(parent);
111 self.change_parent(element, parent, grandparent);
112 element = parent;
113 parent = grandparent;
114 }
115
116 element
117 }
118
119 pub fn equiv(&self, mut a: usize, mut b: usize) -> bool {
121 loop {
122 a = self.find(a);
123 b = self.find(b);
124
125 if a == b { return true; }
126 if self.parent(a) == a { return false; }
127 }
128 }
129
130 pub fn force(&self) {
133 for i in 0 .. self.len() {
134 loop {
135 let parent = self.parent(i);
136 if i == parent {
137 break
138 } else {
139 let root = self.find(parent);
140 if parent == root || self.change_parent(i, parent, root) {
141 break;
142 }
143 }
144 }
145 }
146 }
147
148 pub fn to_vec(&self) -> Vec<usize> {
150 self.force();
151 self.0.iter().map(|entry| entry.id.load(Ordering::SeqCst)).collect()
152 }
153
154 fn rank(&self, element: usize) -> usize {
157 self.0[element].rank.load(Ordering::SeqCst)
158 }
159
160 fn increment_rank(&self, element: usize) {
161 self.0[element].rank.fetch_add(1, Ordering::SeqCst);
162 }
163
164 fn parent(&self, element: usize) -> usize {
165 self.0[element].id.load(Ordering::SeqCst)
166 }
167
168 fn change_parent(&self,
169 element: usize,
170 old_parent: usize,
171 new_parent: usize)
172 -> bool {
173 self.0[element].id.compare_and_swap(old_parent,
174 new_parent,
175 Ordering::SeqCst)
176 == old_parent
177 }
178}
179
#[cfg(feature = "serde")]
impl Serialize for Entry {
    /// Writes the slot as a two-field struct named `Entry`: `id`, then `rank`.
    /// Each atomic is read with a relaxed load.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        use serde::ser::SerializeStruct;

        let id = self.id.load(Ordering::Relaxed);
        let rank = self.rank.load(Ordering::Relaxed);

        let mut state = serializer.serialize_struct("Entry", 2)?;
        state.serialize_field("id", &id)?;
        state.serialize_field("rank", &rank)?;
        state.end()
    }
}
193
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for Entry {
    /// Accepts either a two-element sequence `[id, rank]` or a map with the
    /// keys `"id"` and `"rank"`, each appearing exactly once.
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        use serde::de::{self, MapAccess, SeqAccess, Visitor};

        const FIELDS: &[&str] = &["id", "rank"];

        /// Keys recognized when an `Entry` arrives as a map.
        #[derive(Deserialize)]
        #[serde(field_identifier, rename_all = "lowercase")]
        enum Field {
            Id,
            Rank,
        }

        /// Builds an `Entry` from either serde representation.
        struct EntryVisitor;

        impl<'de> Visitor<'de> for EntryVisitor {
            type Value = Entry;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("struct Entry")
            }

            fn visit_seq<V: SeqAccess<'de>>(self, mut seq: V) -> Result<Self::Value, V::Error> {
                let id = match seq.next_element()? {
                    Some(value) => value,
                    None => return Err(de::Error::invalid_length(0, &self)),
                };
                let rank = match seq.next_element()? {
                    Some(value) => value,
                    None => return Err(de::Error::invalid_length(1, &self)),
                };
                Ok(Entry::new(id, rank))
            }

            fn visit_map<V: MapAccess<'de>>(self, mut map: V) -> Result<Self::Value, V::Error> {
                let mut id: Option<usize> = None;
                let mut rank: Option<usize> = None;

                while let Some(field) = map.next_key::<Field>()? {
                    let (slot, name) = match field {
                        Field::Id => (&mut id, "id"),
                        Field::Rank => (&mut rank, "rank"),
                    };
                    // Reject a repeated key before consuming its value.
                    if slot.is_some() {
                        return Err(de::Error::duplicate_field(name));
                    }
                    *slot = Some(map.next_value()?);
                }

                let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
                let rank = rank.ok_or_else(|| de::Error::missing_field("rank"))?;
                Ok(Entry::new(id, rank))
            }
        }

        deserializer.deserialize_struct("Entry", FIELDS, EntryVisitor)
    }
}
252
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn len() {
        let uf = AUnionFind::new(5);
        assert_eq!(uf.len(), 5);
    }

    #[test]
    fn union() {
        let uf = AUnionFind::new(8);
        assert!(!uf.equiv(0, 1));
        uf.union(0, 1);
        assert!(uf.equiv(0, 1));
    }

    #[test]
    fn unions() {
        let uf = AUnionFind::new(8);

        // Build {0, 1, 2, 3, 4}; every link merges two distinct sets.
        for &(a, b) in &[(0, 1), (1, 2), (4, 3), (3, 2)] {
            assert!(uf.union(a, b));
        }
        // 0 and 3 already share a set.
        assert!(!uf.union(0, 3));

        for x in 1..5 {
            assert!(uf.equiv(0, x));
        }
        assert!(!uf.equiv(0, 5));

        assert!(uf.union(5, 3));
        assert!(uf.equiv(0, 5));

        assert!(uf.union(6, 7));
        assert!(uf.equiv(6, 7));
        assert!(!uf.equiv(5, 7));

        assert!(uf.union(0, 7));
        assert!(uf.equiv(5, 7));
    }

    #[test]
    fn changed() {
        let uf = AUnionFind::new(8);
        let merged: Vec<bool> = [(2, 3), (0, 1), (1, 3), (0, 2)]
            .iter()
            .map(|&(a, b)| uf.union(a, b))
            .collect();
        assert_eq!(merged, vec![true, true, true, false]);
    }

    #[test]
    fn to_vec() {
        let uf = AUnionFind::new(6);
        assert_eq!(uf.to_vec(), vec![0, 1, 2, 3, 4, 5]);

        let steps: [((usize, usize), [usize; 6]); 3] = [
            ((0, 1), [1, 1, 2, 3, 4, 5]),
            ((2, 3), [1, 1, 3, 3, 4, 5]),
            ((1, 3), [3, 3, 3, 3, 4, 5]),
        ];
        for &((a, b), expected) in &steps {
            uf.union(a, b);
            assert_eq!(uf.to_vec(), expected.to_vec());
        }
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serde_round_trip() {
        extern crate serde_json;

        fn check(uf: &AUnionFind) {
            assert!(uf.equiv(0, 1));
            assert!(!uf.equiv(1, 2));
            assert!(uf.equiv(2, 3));
        }

        let original = AUnionFind::new(8);
        original.union(0, 1);
        original.union(2, 3);
        check(&original);

        let json = serde_json::to_string(&original).unwrap();
        let restored: AUnionFind = serde_json::from_str(&json).unwrap();
        check(&restored);
    }
}