1use crate::Id;
2use std::fmt::Debug;
3
4#[derive(Debug, Clone, Default)]
5#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
6pub struct UnionFind {
7 parents: Vec<Id>,
8}
9
10impl UnionFind {
11 pub fn make_set(&mut self) -> Id {
12 let id = Id::from(self.parents.len());
13 self.parents.push(id);
14 id
15 }
16
17 pub fn size(&self) -> usize {
18 self.parents.len()
19 }
20
21 fn parent(&self, query: Id) -> Id {
22 self.parents[usize::from(query)]
23 }
24
25 fn parent_mut(&mut self, query: Id) -> &mut Id {
26 &mut self.parents[usize::from(query)]
27 }
28
29 pub fn find(&self, mut current: Id) -> Id {
30 while current != self.parent(current) {
31 current = self.parent(current)
32 }
33 current
34 }
35
36 pub fn find_mut(&mut self, mut current: Id) -> Id {
37 while current != self.parent(current) {
38 let grandparent = self.parent(self.parent(current));
39 *self.parent_mut(current) = grandparent;
40 current = grandparent;
41 }
42 current
43 }
44
45 pub fn union(&mut self, root1: Id, root2: Id) -> Id {
47 *self.parent_mut(root2) = root1;
48 root1
49 }
50}
51
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts a sequence of indices into the corresponding `Id`s.
    fn ids(indices: impl IntoIterator<Item = usize>) -> Vec<Id> {
        indices.into_iter().map(Id::from).collect()
    }

    #[test]
    fn union_find() {
        let n = 10;
        let id = Id::from;

        let mut uf = UnionFind::default();
        (0..n).for_each(|_| {
            uf.make_set();
        });

        // Initially every id is its own parent, i.e. a singleton set.
        assert_eq!(uf.parents, ids(0..n));

        // First set: 0 absorbs 1, 2 and 3 (in that order).
        for &member in &[1, 2, 3] {
            uf.union(id(0), id(member));
        }

        // Second set: 6 absorbs 7, 8 and 9 (in that order).
        for &member in &[7, 8, 9] {
            uf.union(id(6), id(member));
        }

        // Run find_mut on every id; only its effect on `parents` is checked.
        for i in 0..n {
            uf.find_mut(id(i));
        }

        // Ids 4 and 5 were never unioned and remain their own roots.
        //                   0  1  2  3  4  5  6  7  8  9
        let expected = vec![0, 0, 0, 0, 4, 5, 6, 6, 6, 6];
        assert_eq!(uf.parents, ids(expected));
    }
}
92}