// path_finding/
// union_find.rs

/// A disjoint-set (union-find) forest over the nodes `0..node_count`.
///
/// Sets are merged by size (the smaller tree goes under the larger root).
/// Lookups compress paths so that visited nodes point straight at their root.
/// This is why `find`, `connected` and `unify` take `&mut self`.
#[derive(Debug, Clone)]
pub struct UnionFind {
    // Size of the tree rooted at each index. Only entries for current roots
    // are updated; once a node is attached under another root, its entry
    // keeps the value it had at that moment.
    sizes: Vec<usize>,
    // Parent pointer of each node; node `i` is a root when `ids[i] == i`.
    ids: Vec<usize>,
    // Number of disjoint sets currently present.
    components: usize,
}

impl UnionFind {
    /// Creates `node_count` singleton sets, one per node in `0..node_count`.
    ///
    /// `UnionFind::from(0)` yields an empty structure with zero components.
    pub fn from(node_count: usize) -> UnionFind {
        Self {
            sizes: vec![1; node_count],
            // Every node starts as its own root.
            ids: (0..node_count).collect(),
            components: node_count,
        }
    }

    /// Returns the root of the set containing `p`.
    ///
    /// After the call, every node on the path from `p` to the root points
    /// directly at the root (path compression).
    ///
    /// # Panics
    /// Panics if `p` is not a valid node index.
    fn find(&mut self, p: usize) -> usize {
        let mut root = p;
        while root != self.ids[root] {
            root = self.ids[root];
        }

        // Second pass: re-point each node on the path straight at the root.
        let mut node = p;
        while node != root {
            let next = self.ids[node];
            self.ids[node] = root;
            node = next;
        }

        root
    }

    /// Returns `true` if `p` and `q` belong to the same set.
    ///
    /// # Panics
    /// Panics if either index is not a valid node index.
    pub fn connected(&mut self, p: usize, q: usize) -> bool {
        self.find(p) == self.find(q)
    }

    /// Merges the sets containing `p` and `q`. Does nothing if they are
    /// already in the same set.
    ///
    /// The root of the smaller tree is attached under the root of the larger
    /// one. On a tie, `q`'s root is attached under `p`'s root.
    ///
    /// # Panics
    /// Panics if either index is not a valid node index.
    pub fn unify(&mut self, p: usize, q: usize) {
        let p_root = self.find(p);
        let q_root = self.find(q);

        if p_root == q_root {
            return;
        }

        let (child, parent) = if self.sizes[p_root] < self.sizes[q_root] {
            (p_root, q_root)
        } else {
            (q_root, p_root)
        };

        // Both are roots, so the child root links directly to the parent root.
        self.ids[child] = parent;
        self.sizes[parent] += self.sizes[child];
        self.components -= 1;
    }

    /// Returns the size value stored for `id`.
    ///
    /// This equals the number of nodes in `id`'s set only when `id` is a
    /// root. For a node that has been attached under another root, the value
    /// is whatever it was at that moment. Use [`UnionFind::component_size`]
    /// for the true set size of any node.
    ///
    /// # Panics
    /// Panics if `id` is not a valid node index.
    pub fn size(&self, id: usize) -> usize {
        self.sizes[id]
    }

    /// Returns the number of nodes in the set containing `id`.
    ///
    /// # Panics
    /// Panics if `id` is not a valid node index.
    pub fn component_size(&mut self, id: usize) -> usize {
        let root = self.find(id);
        self.sizes[root]
    }

    /// Returns the current parent pointer of `id` (equal to `id` for roots).
    ///
    /// The parent is not necessarily the root until a `find` through `id` has
    /// compressed its path.
    ///
    /// # Panics
    /// Panics if `id` is not a valid node index.
    pub fn parent(&self, id: usize) -> usize {
        self.ids[id]
    }

    /// Returns the number of disjoint sets currently present.
    pub fn components(&self) -> usize {
        self.components
    }
}

#[test]
fn union_find_with_zero_edges_should_succeed() {
    // Building a structure with zero nodes must work and report no components.
    let empty = UnionFind::from(0);
    let expected_components = 0;
    assert_eq!(expected_components, empty.components);
}

#[test]
fn unify_should_decrease_components() {
    let mut uf = UnionFind::from(2);
    assert_eq!(2, uf.components);

    uf.unify(0, 1);

    // Equal-sized trees: node 1's root is hung under node 0, which now holds both.
    let (components, root_size, parent_of_one) = (uf.components, uf.size(0), uf.parent(1));
    assert_eq!(1, components);
    assert_eq!(2, root_size);
    assert_eq!(0, parent_of_one);
}

#[test]
fn test_find() {
    let mut uf = UnionFind::from(5);
    for &(a, b) in &[(0, 1), (1, 2)] {
        uf.unify(a, b);
    }

    // Node 2 was merged into the tree rooted at node 0.
    let root = uf.find(2);
    assert_eq!(0, root);
}

#[test]
fn test_connected() {
    let mut uf = UnionFind::from(5);
    uf.unify(0, 1);
    uf.unify(3, 4);

    // Each explicitly merged pair must report as connected...
    for &(a, b) in &[(0, 1), (3, 4)] {
        assert!(uf.connected(a, b));
    }
    // ...while nodes from the two separate groups must not.
    assert!(!uf.connected(0, 4));
}

#[test]
fn test_unify() {
    let mut uf = UnionFind::from(4);
    uf.unify(0, 1);

    // Only the surviving root's size grows; every other entry keeps its initial 1.
    let expected_sizes = [2, 1, 1, 1];
    for (node, &want) in expected_sizes.iter().enumerate() {
        assert_eq!(want, uf.size(node));
    }
}

#[test]
fn test_unify_multiple_groups() {
    let mut uf = UnionFind::from(6);
    for &(a, b) in &[(0, 1), (1, 2), (3, 4)] {
        uf.unify(a, b);
    }

    // Roots 0 and 3 carry their group sizes; all other entries stay at 1.
    let expected_sizes = [3, 1, 1, 2, 1, 1];
    for (node, &want) in expected_sizes.iter().enumerate() {
        assert_eq!(want, uf.size(node));
    }
}

#[test]
fn test_components() {
    let mut uf = UnionFind::from(5);
    assert_eq!(5, uf.components);

    // Each merge of two distinct sets removes exactly one component.
    let steps = [((0, 1), 4), ((1, 2), 3), ((3, 4), 2)];
    for &((p, q), remaining) in &steps {
        uf.unify(p, q);
        assert_eq!(remaining, uf.components);
    }
}