jp_partition/
lib.rs

1// Copyright 2018-2019 Joe Neeman.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8//
9// See the LICENSE-APACHE or LICENSE-MIT files at the top-level directory
10// of this distribution.
11
//! This crate provides an implementation of the disjoint-sets (union-find) algorithm that is built
//! on top of ordered maps and a multimap. (The reason for this unusual implementation is that once
//! the multimap is fully persistent, this data structure will be persistent as well.)
15
16#[macro_use]
17extern crate serde_derive;
18
19use multimap::MMap;
20use std::collections::btree_map::Entry;
21use std::collections::BTreeMap as Map;
22
/// A partition of a set of `T`s into disjoint parts (a disjoint-sets / union-find structure).
///
/// Each part is stored as a tree: every element except the part's representative has a parent
/// link, and the representative is the root reached by following those links.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Partition<T: Copy + Ord> {
    /// Has one entry for every element of the partition. The value is the rank used by
    /// `merge_reps` to decide which of two representatives becomes the parent of the other.
    ranks: Map<T, usize>,
    /// Maps each non-representative element to its parent. Representatives have no entry here.
    parent_map: Map<T, T>,
    /// The inverse of `parent_map`: maps an element to all of its direct children.
    child_map: MMap<T, T>,
}
29
30impl<T: Copy + Ord> Default for Partition<T> {
31    fn default() -> Partition<T> {
32        Partition::new()
33    }
34}
35
36impl<T: Copy + Ord> Partition<T> {
37    pub fn new() -> Partition<T> {
38        Partition {
39            ranks: Map::new(),
40            parent_map: Map::new(),
41            child_map: MMap::new(),
42        }
43    }
44
45    /// Panics if the new element already exists.
46    pub fn insert(&mut self, elt: T) {
47        match self.ranks.entry(elt) {
48            Entry::Occupied(_) => panic!("tried to insert an element twice"),
49            Entry::Vacant(e) => e.insert(0),
50        };
51    }
52
53    /// Is the given element the representative of its component?
54    pub fn is_rep(&self, elt: &T) -> bool {
55        !self.parent_map.contains_key(elt)
56    }
57
58    /// Returns true if there was a merge to be done (i.e. they didn't already belong to the same
59    /// part).
60    pub fn merge(&mut self, elt1: T, elt2: T) -> bool {
61        let rep1 = self.representative_mut(elt1);
62        let rep2 = self.representative_mut(elt2);
63        if rep1 != rep2 {
64            self.merge_reps(rep1, rep2);
65            true
66        } else {
67            false
68        }
69    }
70
71    // Panics unless the two given elements are representatives of their components.
72    fn merge_reps(&mut self, rep1: T, rep2: T) {
73        assert!(self.is_rep(&rep1) && self.is_rep(&rep2));
74        let rank1 = self.ranks[&rep1];
75        let rank2 = self.ranks[&rep2];
76        if rank1 <= rank2 {
77            self.parent_map.insert(rep1, rep2);
78            self.child_map.insert(rep2, rep1);
79            if rank1 == rank2 {
80                self.ranks.insert(rep2, rank2 + 1);
81            }
82        } else {
83            self.parent_map.insert(rep2, rep1);
84            self.child_map.insert(rep1, rep2);
85        }
86    }
87
88    pub fn representative_mut(&mut self, elt: T) -> T {
89        let rep = self.representative(elt);
90        // Reparent the element to the representative.
91        if let Some(orig_parent_ref) = self.parent_map.get_mut(&elt) {
92            if *orig_parent_ref != rep {
93                self.child_map.remove(&*orig_parent_ref, &elt);
94                self.child_map.insert(rep, elt);
95                *orig_parent_ref = rep;
96            }
97        }
98        rep
99    }
100
101    pub fn representative(&self, elt: T) -> T {
102        debug_assert!(self.contains(elt));
103        let mut ret = elt;
104        while let Some(parent) = self.parent_map.get(&ret) {
105            ret = *parent;
106        }
107        ret
108    }
109
110    pub fn same_part_mut(&mut self, elt1: T, elt2: T) -> bool {
111        self.representative_mut(elt1) == self.representative_mut(elt2)
112    }
113
114    pub fn same_part(&self, elt1: T, elt2: T) -> bool {
115        self.representative(elt1) == self.representative(elt2)
116    }
117
118    pub fn contains(&self, elt: T) -> bool {
119        self.ranks.contains_key(&elt)
120    }
121
122    pub fn remove_part(&mut self, elt: T) {
123        let elts = self.iter_part(elt).collect::<Vec<_>>();
124        for e in elts {
125            self.parent_map.remove(&e);
126            self.ranks.remove(&e);
127            self.child_map.remove_all(&e);
128        }
129    }
130
131    pub fn iter_part<'a>(&'a self, elt: T) -> impl Iterator<Item = T> + 'a {
132        PartIter::new(self, self.representative(elt))
133    }
134
135    pub fn iter_parts<'a>(&'a self) -> impl Iterator<Item = impl Iterator<Item = T> + 'a> + 'a {
136        self.ranks
137            .keys()
138            // For each representative of a part...
139            .filter(move |elt| self.is_rep(elt))
140            // ...return an iterator over that part.
141            .map(move |r| self.iter_part(*r))
142    }
143}
144
145impl<T: Copy + Ord, PI: IntoIterator<Item = T>> std::iter::FromIterator<PI> for Partition<T> {
146    fn from_iter<I>(iter: I) -> Self
147    where
148        I: IntoIterator<Item = PI>,
149    {
150        let mut ret = Partition::new();
151
152        for part in iter.into_iter() {
153            let mut part_iter = part.into_iter();
154            if let Some(rep) = part_iter.next() {
155                // Declare the first element in the part as its representative; all other elements
156                // will have the representative as their direct parent.
157                ret.ranks.insert(rep, 1);
158                for child in part_iter {
159                    ret.ranks.insert(child, 0);
160                    ret.parent_map.insert(child, rep);
161                    ret.child_map.insert(rep, child);
162                }
163            }
164        }
165        ret
166    }
167}
168
/// An iterator over the elements of a single part of a `Partition`.
///
/// Each element is yielded before any of the elements below it in the part's tree.
pub struct PartIter<'a, T: Copy + Ord> {
    // The partition whose `child_map` is consulted to find the children of each yielded element.
    partition: &'a Partition<T>,
    // We can traverse a component as though it were a tree, by following the child links. In order
    // to keep track of the iteration we store a stack, each element of which contains an iterator
    // over nodes at a certain level of the tree. The bottom of the stack is an iterator yielding
    // only the root; every other entry is of the type returned by MMap::get (with the items
    // cloned). We currently have no way to name that type, hence the Box.
    stack: Vec<Box<dyn Iterator<Item = T> + 'a>>,
}
177
178impl<'a, T: Copy + Ord> PartIter<'a, T> {
179    fn new(partition: &'a Partition<T>, root: T) -> PartIter<'a, T> {
180        PartIter {
181            partition,
182            stack: vec![Box::new(Some(root).into_iter())],
183        }
184    }
185}
186
187impl<'a, T: Copy + Ord> Iterator for PartIter<'a, T> {
188    type Item = T;
189
190    fn next(&mut self) -> Option<Self::Item> {
191        while let Some(iter) = self.stack.last_mut() {
192            if let Some(item) = iter.next() {
193                self.stack
194                    .push(Box::new(self.partition.child_map.get(&item).cloned()));
195                return Some(item);
196            } else {
197                self.stack.pop();
198            }
199        }
200        None
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    // TODO: think about how to use proptest for testing this
    #[test]
    fn partition() {
        // Returns the members of the part containing `elt` in ascending order, so that the
        // comparisons below don't depend on the iteration order.
        fn sorted_part(partition: &Partition<u32>, elt: u32) -> Vec<u32> {
            let mut members: Vec<u32> = partition.iter_part(elt).collect();
            members.sort();
            members
        }

        // Start with five singleton parts.
        let mut partition: Partition<u32> = Partition::new();
        for elt in 0..5 {
            partition.insert(elt);
        }
        assert_eq!(partition.iter_parts().count(), 5);

        // Merging the same pair a second time must not change anything.
        partition.merge(0, 4);
        assert_eq!(partition.iter_parts().count(), 4);
        partition.merge(0, 4);
        assert_eq!(partition.iter_parts().count(), 4);
        assert!(partition.same_part(0, 4));
        for &elt in &[0, 4] {
            assert_eq!(sorted_part(&partition, elt), vec![0, 4]);
        }

        partition.merge(1, 2);
        assert_eq!(partition.iter_parts().count(), 3);
        assert!(partition.same_part(1, 2));
        for &elt in &[1, 2] {
            assert_eq!(sorted_part(&partition, elt), vec![1, 2]);
        }

        // Merging via non-representative elements joins the two existing pairs.
        partition.merge(2, 4);
        assert_eq!(partition.iter_parts().count(), 2);
        for &elt in &[0, 1, 2, 4] {
            assert_eq!(sorted_part(&partition, elt), vec![0, 1, 2, 4]);
        }

        // Removing the big part leaves only the untouched singleton behind.
        partition.remove_part(1);
        assert_eq!(partition.iter_parts().count(), 1);
        assert_eq!(sorted_part(&partition, 3), vec![3]);
    }
}