1#[macro_use]
17extern crate serde_derive;
18
19use multimap::MMap;
20use std::collections::btree_map::Entry;
21use std::collections::BTreeMap as Map;
22
23#[derive(Clone, Debug, Deserialize, Serialize)]
24pub struct Partition<T: Copy + Ord> {
25 ranks: Map<T, usize>,
26 parent_map: Map<T, T>,
27 child_map: MMap<T, T>,
28}
29
30impl<T: Copy + Ord> Default for Partition<T> {
31 fn default() -> Partition<T> {
32 Partition::new()
33 }
34}
35
36impl<T: Copy + Ord> Partition<T> {
37 pub fn new() -> Partition<T> {
38 Partition {
39 ranks: Map::new(),
40 parent_map: Map::new(),
41 child_map: MMap::new(),
42 }
43 }
44
45 pub fn insert(&mut self, elt: T) {
47 match self.ranks.entry(elt) {
48 Entry::Occupied(_) => panic!("tried to insert an element twice"),
49 Entry::Vacant(e) => e.insert(0),
50 };
51 }
52
53 pub fn is_rep(&self, elt: &T) -> bool {
55 !self.parent_map.contains_key(elt)
56 }
57
58 pub fn merge(&mut self, elt1: T, elt2: T) -> bool {
61 let rep1 = self.representative_mut(elt1);
62 let rep2 = self.representative_mut(elt2);
63 if rep1 != rep2 {
64 self.merge_reps(rep1, rep2);
65 true
66 } else {
67 false
68 }
69 }
70
71 fn merge_reps(&mut self, rep1: T, rep2: T) {
73 assert!(self.is_rep(&rep1) && self.is_rep(&rep2));
74 let rank1 = self.ranks[&rep1];
75 let rank2 = self.ranks[&rep2];
76 if rank1 <= rank2 {
77 self.parent_map.insert(rep1, rep2);
78 self.child_map.insert(rep2, rep1);
79 if rank1 == rank2 {
80 self.ranks.insert(rep2, rank2 + 1);
81 }
82 } else {
83 self.parent_map.insert(rep2, rep1);
84 self.child_map.insert(rep1, rep2);
85 }
86 }
87
88 pub fn representative_mut(&mut self, elt: T) -> T {
89 let rep = self.representative(elt);
90 if let Some(orig_parent_ref) = self.parent_map.get_mut(&elt) {
92 if *orig_parent_ref != rep {
93 self.child_map.remove(&*orig_parent_ref, &elt);
94 self.child_map.insert(rep, elt);
95 *orig_parent_ref = rep;
96 }
97 }
98 rep
99 }
100
101 pub fn representative(&self, elt: T) -> T {
102 debug_assert!(self.contains(elt));
103 let mut ret = elt;
104 while let Some(parent) = self.parent_map.get(&ret) {
105 ret = *parent;
106 }
107 ret
108 }
109
110 pub fn same_part_mut(&mut self, elt1: T, elt2: T) -> bool {
111 self.representative_mut(elt1) == self.representative_mut(elt2)
112 }
113
114 pub fn same_part(&self, elt1: T, elt2: T) -> bool {
115 self.representative(elt1) == self.representative(elt2)
116 }
117
118 pub fn contains(&self, elt: T) -> bool {
119 self.ranks.contains_key(&elt)
120 }
121
122 pub fn remove_part(&mut self, elt: T) {
123 let elts = self.iter_part(elt).collect::<Vec<_>>();
124 for e in elts {
125 self.parent_map.remove(&e);
126 self.ranks.remove(&e);
127 self.child_map.remove_all(&e);
128 }
129 }
130
131 pub fn iter_part<'a>(&'a self, elt: T) -> impl Iterator<Item = T> + 'a {
132 PartIter::new(self, self.representative(elt))
133 }
134
135 pub fn iter_parts<'a>(&'a self) -> impl Iterator<Item = impl Iterator<Item = T> + 'a> + 'a {
136 self.ranks
137 .keys()
138 .filter(move |elt| self.is_rep(elt))
140 .map(move |r| self.iter_part(*r))
142 }
143}
144
145impl<T: Copy + Ord, PI: IntoIterator<Item = T>> std::iter::FromIterator<PI> for Partition<T> {
146 fn from_iter<I>(iter: I) -> Self
147 where
148 I: IntoIterator<Item = PI>,
149 {
150 let mut ret = Partition::new();
151
152 for part in iter.into_iter() {
153 let mut part_iter = part.into_iter();
154 if let Some(rep) = part_iter.next() {
155 ret.ranks.insert(rep, 1);
158 for child in part_iter {
159 ret.ranks.insert(child, 0);
160 ret.parent_map.insert(child, rep);
161 ret.child_map.insert(rep, child);
162 }
163 }
164 }
165 ret
166 }
167}
168
169pub struct PartIter<'a, T: Copy + Ord> {
170 partition: &'a Partition<T>,
171 stack: Vec<Box<dyn Iterator<Item = T> + 'a>>,
176}
177
178impl<'a, T: Copy + Ord> PartIter<'a, T> {
179 fn new(partition: &'a Partition<T>, root: T) -> PartIter<'a, T> {
180 PartIter {
181 partition,
182 stack: vec![Box::new(Some(root).into_iter())],
183 }
184 }
185}
186
187impl<'a, T: Copy + Ord> Iterator for PartIter<'a, T> {
188 type Item = T;
189
190 fn next(&mut self) -> Option<Self::Item> {
191 while let Some(iter) = self.stack.last_mut() {
192 if let Some(item) = iter.next() {
193 self.stack
194 .push(Box::new(self.partition.child_map.get(&item).cloned()));
195 return Some(item);
196 } else {
197 self.stack.pop();
198 }
199 }
200 None
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two vectors hold the same elements, ignoring order.
    fn assert_same_elements(mut left: Vec<u32>, mut right: Vec<u32>) {
        left.sort();
        right.sort();
        assert_eq!(left, right);
    }

    /// Exercises insertion, merging, part iteration and part removal.
    #[test]
    fn partition() {
        let mut partition = Partition::new();
        for elt in 0..5 {
            partition.insert(elt);
        }

        // Every element starts out in a part of its own.
        assert_eq!(partition.iter_parts().count(), 5);

        partition.merge(0, 4);
        assert_eq!(partition.iter_parts().count(), 4);
        // Merging the same pair again must not change anything.
        partition.merge(0, 4);
        assert_eq!(partition.iter_parts().count(), 4);
        assert!(partition.same_part(0, 4));
        for &elt in &[0, 4] {
            assert_same_elements(partition.iter_part(elt).collect(), vec![0, 4]);
        }

        partition.merge(1, 2);
        assert_eq!(partition.iter_parts().count(), 3);
        assert!(partition.same_part(1, 2));
        for &elt in &[1, 2] {
            assert_same_elements(partition.iter_part(elt).collect(), vec![1, 2]);
        }

        // Joining the two two-element parts leaves only {0, 1, 2, 4} and {3}.
        partition.merge(2, 4);
        assert_eq!(partition.iter_parts().count(), 2);
        for &elt in &[0, 1, 2, 4] {
            assert_same_elements(partition.iter_part(elt).collect(), vec![0, 1, 2, 4]);
        }

        partition.remove_part(1);
        assert_eq!(partition.iter_parts().count(), 1);
        assert_same_elements(partition.iter_part(3).collect(), vec![3]);
    }
}