1use std::cell::Cell;
2use std::fmt::{self, Debug};
3
4use super::ElementType;
5
/// Vector-based union-find (disjoint-set forest) over the elements
/// `0 .. len()`.
///
/// Elements are identified by values of type `Element`, which are converted
/// to and from vector indices via [`ElementType`].
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct UnionFind<Element: ElementType = usize> {
    // `elements[i]` holds the parent of element `i`; an element that is its
    // own parent is the representative (root) of its set. Each entry is a
    // `Cell` so that `find` can re-link parents through a shared `&self`.
    elements: Vec<Cell<Element>>,
    // `ranks[i]` is the union-by-rank rank of element `i`; `union` only
    // consults it for roots.
    ranks: Vec<u8>,
}

13impl<Element: Debug + ElementType> Debug for UnionFind<Element> {
16 fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
17 write!(formatter, "UnionFind({:?})", self.elements)
18 }
19}
20
21impl<Element: ElementType> Default for UnionFind<Element> {
22 fn default() -> Self {
23 UnionFind::new(0)
24 }
25}
26
27impl<Element: ElementType> UnionFind<Element> {
28 pub fn new(size: usize) -> Self {
34 UnionFind {
35 elements: (0..size).map(|i| {
36 let e = Element::from_usize(i).expect("UnionFind::new: overflow");
37 Cell::new(e)
38 }).collect(),
39 ranks: vec![0; size],
40 }
41 }
42
43 pub fn len(&self) -> usize {
45 self.elements.len()
46 }
47
48 pub fn is_empty(&self) -> bool {
53 self.elements.is_empty()
54 }
55
56 pub fn alloc(&mut self) -> Element {
63 let result = Element::from_usize(self.elements.len())
64 .expect("UnionFind::alloc: overflow");
65 self.elements.push(Cell::new(result));
66 self.ranks.push(0);
67 result
68 }
69
70 pub fn union(&mut self, a: Element, b: Element) -> bool {
76 let a = self.find(a);
77 let b = self.find(b);
78
79 if a == b { return false; }
80
81 let rank_a = self.rank(a);
82 let rank_b = self.rank(b);
83
84 if rank_a > rank_b {
85 self.set_parent(b, a);
86 } else if rank_b > rank_a {
87 self.set_parent(a, b);
88 } else {
89 self.set_parent(a, b);
90 self.increment_rank(b);
91 }
92
93 true
94 }
95
96 pub fn find(&self, mut element: Element) -> Element {
98 let mut parent = self.parent(element);
99
100 while element != parent {
101 let grandparent = self.parent(parent);
102 self.set_parent(element, grandparent);
103 element = parent;
104 parent = grandparent;
105 }
106
107 element
108 }
109
110 pub fn equiv(&self, a: Element, b: Element) -> bool {
112 self.find(a) == self.find(b)
113 }
114
115 pub fn force(&self) {
118 for i in 0 .. self.len() {
119 let element = Element::from_usize(i).unwrap();
120 let root = self.find(element);
121 self.set_parent(element, root);
122 }
123 }
124
125 pub fn to_vec(&self) -> Vec<Element> {
127 self.force();
128 self.elements.iter().map(Cell::get).collect()
129 }
130
131 fn rank(&self, element: Element) -> u8 {
134 self.ranks[element.to_usize()]
135 }
136
137 fn increment_rank(&mut self, element: Element) {
138 let i = element.to_usize();
139 self.ranks[i] = self.ranks[i].saturating_add(1);
140 }
141
142 fn parent(&self, element: Element) -> Element {
143 self.elements[element.to_usize()].get()
144 }
145
146 fn set_parent(&self, element: Element, parent: Element) {
147 self.elements[element.to_usize()].set(parent);
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn len() {
        assert_eq!(5, UnionFind::<u32>::new(5).len());
    }

    #[test]
    fn empty() {
        let uf = UnionFind::<usize>::default();
        assert!(uf.is_empty());
        assert_eq!(0, uf.len());
        assert!(UnionFind::<usize>::new(0).to_vec().is_empty());
        assert!(!UnionFind::<usize>::new(1).is_empty());
    }

    #[test]
    fn union() {
        let mut uf = UnionFind::<u32>::new(8);
        assert!(!uf.equiv(0, 1));
        uf.union(0, 1);
        assert!(uf.equiv(0, 1));
    }

    #[test]
    fn union_with_self_is_noop() {
        let mut uf = UnionFind::<usize>::new(3);
        assert!(!uf.union(1, 1));
        assert!(uf.equiv(1, 1));
        assert!(!uf.equiv(0, 1));
        assert!(!uf.equiv(1, 2));
    }

    #[test]
    fn unions() {
        let mut uf = UnionFind::<usize>::new(8);
        assert!(uf.union(0, 1));
        assert!(uf.union(1, 2));
        assert!(uf.union(4, 3));
        assert!(uf.union(3, 2));
        assert!(!uf.union(0, 3));

        assert!(uf.equiv(0, 1));
        assert!(uf.equiv(0, 2));
        assert!(uf.equiv(0, 3));
        assert!(uf.equiv(0, 4));
        assert!(!uf.equiv(0, 5));

        uf.union(5, 3);
        assert!(uf.equiv(0, 5));

        uf.union(6, 7);
        assert!(uf.equiv(6, 7));
        assert!(!uf.equiv(5, 7));

        uf.union(0, 7);
        assert!(uf.equiv(5, 7));
    }

    #[test]
    fn alloc() {
        let mut uf = UnionFind::<usize>::new(2);
        assert_eq!(2, uf.alloc());
        assert_eq!(3, uf.len());
        assert!(!uf.equiv(0, 2));
        assert!(uf.union(0, 2));
        assert!(uf.equiv(0, 2));
        assert!(!uf.equiv(1, 2));
    }

    #[test]
    fn to_vec_maps_to_representatives() {
        let mut uf = UnionFind::<usize>::new(6);
        uf.union(0, 1);
        uf.union(2, 3);
        uf.union(1, 3);

        let reps = uf.to_vec();
        assert_eq!(6, reps.len());

        // Every entry is a root that maps to itself and agrees with `find`.
        for (i, &rep) in reps.iter().enumerate() {
            assert_eq!(rep, reps[rep]);
            assert_eq!(rep, uf.find(i));
        }

        assert_eq!(reps[0], reps[1]);
        assert_eq!(reps[0], reps[2]);
        assert_eq!(reps[0], reps[3]);
        assert_ne!(reps[0], reps[4]);
        assert_ne!(reps[0], reps[5]);
        assert_ne!(reps[4], reps[5]);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serde_round_trip() {
        extern crate serde_json;

        let mut uf0: UnionFind<usize> = UnionFind::new(8);
        uf0.union(0, 1);
        uf0.union(2, 3);
        assert!(uf0.equiv(0, 1));
        assert!(!uf0.equiv(1, 2));
        assert!(uf0.equiv(2, 3));

        let json = serde_json::to_string(&uf0).unwrap();
        let uf1: UnionFind<usize> = serde_json::from_str(&json).unwrap();
        assert!(uf1.equiv(0, 1));
        assert!(!uf1.equiv(1, 2));
        assert!(uf1.equiv(2, 3));
    }
}