// disjoint_sets/array.rs

1use std::cell::Cell;
2use std::fmt::{self, Debug};
3
4use super::ElementType;
5
6/// Vector-based union-find representing a set of disjoint sets.
7#[derive(Clone)]
8#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9pub struct UnionFind<Element: ElementType = usize> {
10    elements: Vec<Cell<Element>>,
11    ranks: Vec<u8>,
12}
13// Invariant: self.elements.len() == self.ranks.len()
14
15impl<Element: Debug + ElementType> Debug for UnionFind<Element> {
16    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
17        write!(formatter, "UnionFind({:?})", self.elements)
18    }
19}
20
21impl<Element: ElementType> Default for UnionFind<Element> {
22    fn default() -> Self {
23        UnionFind::new(0)
24    }
25}
26
27impl<Element: ElementType> UnionFind<Element> {
28    /// Creates a new union-find of `size` elements.
29    ///
30    /// # Panics
31    ///
32    /// If `size` elements would overflow the element type `Element`.
33    pub fn new(size: usize) -> Self {
34        UnionFind {
35            elements: (0..size).map(|i| {
36                let e = Element::from_usize(i).expect("UnionFind::new: overflow");
37                Cell::new(e)
38            }).collect(),
39            ranks: vec![0; size],
40        }
41    }
42
43    /// The number of elements in all the sets.
44    pub fn len(&self) -> usize {
45        self.elements.len()
46    }
47
48    /// Is the union-find devoid of elements?
49    ///
50    /// It is possible to create an empty `UnionFind` and then add
51    /// elements with [`alloc`](#method.alloc).
52    pub fn is_empty(&self) -> bool {
53        self.elements.is_empty()
54    }
55
56    /// Creates a new element in a singleton set.
57    ///
58    /// # Panics
59    ///
60    /// If allocating another element would overflow the element type
61    /// `Element`.
62    pub fn alloc(&mut self) -> Element {
63        let result = Element::from_usize(self.elements.len())
64                       .expect("UnionFind::alloc: overflow");
65        self.elements.push(Cell::new(result));
66        self.ranks.push(0);
67        result
68    }
69
70    /// Joins the sets of the two given elements.
71    ///
72    /// Returns whether anything changed. That is, if the sets were
73    /// different, it returns `true`, but if they were already the same
74    /// then it returns `false`.
75    pub fn union(&mut self, a: Element, b: Element) -> bool {
76        let a = self.find(a);
77        let b = self.find(b);
78
79        if a == b { return false; }
80
81        let rank_a = self.rank(a);
82        let rank_b = self.rank(b);
83
84        if rank_a > rank_b {
85            self.set_parent(b, a);
86        } else if rank_b > rank_a {
87            self.set_parent(a, b);
88        } else {
89            self.set_parent(a, b);
90            self.increment_rank(b);
91        }
92
93        true
94    }
95
96    /// Finds the representative element for the given element’s set.
97    pub fn find(&self, mut element: Element) -> Element {
98        let mut parent = self.parent(element);
99
100        while element != parent {
101            let grandparent = self.parent(parent);
102            self.set_parent(element, grandparent);
103            element = parent;
104            parent = grandparent;
105        }
106
107        element
108    }
109
110    /// Determines whether two elements are in the same set.
111    pub fn equiv(&self, a: Element, b: Element) -> bool {
112        self.find(a) == self.find(b)
113    }
114
115    /// Forces all laziness, so that each element points directly to its
116    /// set’s representative.
117    pub fn force(&self) {
118        for i in 0 .. self.len() {
119            let element = Element::from_usize(i).unwrap();
120            let root = self.find(element);
121            self.set_parent(element, root);
122        }
123    }
124
125    /// Returns a vector of set representatives.
126    pub fn to_vec(&self) -> Vec<Element> {
127        self.force();
128        self.elements.iter().map(Cell::get).collect()
129    }
130
131    // HELPERS
132
133    fn rank(&self, element: Element) -> u8 {
134        self.ranks[element.to_usize()]
135    }
136
137    fn increment_rank(&mut self, element: Element) {
138        let i = element.to_usize();
139        self.ranks[i] = self.ranks[i].saturating_add(1);
140    }
141
142    fn parent(&self, element: Element) -> Element {
143        self.elements[element.to_usize()].get()
144    }
145
146    fn set_parent(&self, element: Element, parent: Element) {
147        self.elements[element.to_usize()].set(parent);
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn len() {
        assert_eq!(5, UnionFind::<u32>::new(5).len());
    }

    #[test]
    fn default_is_empty() {
        let uf: UnionFind<u32> = UnionFind::default();
        assert!(uf.is_empty());
        assert_eq!(0, uf.len());
    }

    #[test]
    fn alloc() {
        let mut uf = UnionFind::<u32>::default();
        assert_eq!(0, uf.alloc());
        assert_eq!(1, uf.alloc());
        assert_eq!(2, uf.len());
        assert!(!uf.is_empty());
        assert!(!uf.equiv(0, 1));
        assert!(uf.union(0, 1));
        assert!(uf.equiv(0, 1));
    }

    #[test]
    fn union() {
        let mut uf = UnionFind::<u32>::new(8);
        assert!(!uf.equiv(0, 1));
        assert!(uf.union(0, 1));
        assert!(uf.equiv(0, 1));
        // Joining an already-joined pair reports that nothing changed.
        assert!(!uf.union(1, 0));
    }

    #[test]
    fn unions() {
        let mut uf = UnionFind::<usize>::new(8);
        assert!(uf.union(0, 1));
        assert!(uf.union(1, 2));
        assert!(uf.union(4, 3));
        assert!(uf.union(3, 2));
        assert!(! uf.union(0, 3));

        assert!(uf.equiv(0, 1));
        assert!(uf.equiv(0, 2));
        assert!(uf.equiv(0, 3));
        assert!(uf.equiv(0, 4));
        assert!(!uf.equiv(0, 5));

        uf.union(5, 3);
        assert!(uf.equiv(0, 5));

        uf.union(6, 7);
        assert!(uf.equiv(6, 7));
        assert!(!uf.equiv(5, 7));

        uf.union(0, 7);
        assert!(uf.equiv(5, 7));
    }

    #[test]
    fn to_vec_gives_representatives() {
        let mut uf = UnionFind::<usize>::new(6);
        uf.union(0, 1);
        uf.union(2, 3);
        uf.union(1, 3);
        uf.union(4, 5);

        let reps = uf.to_vec();
        assert_eq!(6, reps.len());
        for (i, &rep) in reps.iter().enumerate() {
            assert_eq!(rep, uf.find(i));
            // A representative is its own representative.
            assert_eq!(rep, reps[rep]);
        }
        assert_eq!(reps[0], reps[3]);
        assert_eq!(reps[4], reps[5]);
        assert_ne!(reps[0], reps[4]);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serde_round_trip() {
        extern crate serde_json;

        let mut uf0: UnionFind<usize> = UnionFind::new(8);
        uf0.union(0, 1);
        uf0.union(2, 3);
        assert!( uf0.equiv(0, 1));
        assert!(!uf0.equiv(1, 2));
        assert!( uf0.equiv(2, 3));

        let json = serde_json::to_string(&uf0).unwrap();
        let uf1: UnionFind<usize> = serde_json::from_str(&json).unwrap();
        assert!( uf1.equiv(0, 1));
        assert!(!uf1.equiv(1, 2));
        assert!( uf1.equiv(2, 3));
    }
}