/// Disjoint-set forest over the dense index range `0..size`, using path
/// halving in `find` and union by rank in `union`.
struct UnionFind {
    /// `parent[i]` is the parent of node `i`; a node is a root when it is its
    /// own parent.
    parent: Vec<u32>,
    /// Union-by-rank value per node; only consulted for roots.
    rank: Vec<u8>,
}

impl UnionFind {
    /// Creates `size` singleton sets, one per index, all with rank 0.
    fn new(size: usize) -> Self {
        let parent: Vec<u32> = (0..size as u32).collect();
        let rank = vec![0; size];
        Self { parent, rank }
    }

    /// Returns the root of the set containing `x`, shortening the path as it
    /// walks up.
    ///
    /// Panics if `x` is not a valid index.
    fn find(&mut self, x: u32) -> u32 {
        let mut node = x as usize;
        loop {
            let up = self.parent[node] as usize;
            if up == node {
                return node as u32;
            }
            // Path halving: re-point `node` at its grandparent, then continue
            // from that grandparent.
            let grand = self.parent[up];
            self.parent[node] = grand;
            node = grand as usize;
        }
    }

    /// Merges the sets containing `a` and `b`; does nothing when they already
    /// share a root.
    fn union(&mut self, a: u32, b: u32) {
        let (ra, rb) = (self.find(a), self.find(b));
        if ra == rb {
            return;
        }
        let (ia, ib) = (ra as usize, rb as usize);
        if self.rank[ia] < self.rank[ib] {
            // Lower-ranked root is attached under the higher-ranked one.
            self.parent[ia] = rb;
        } else {
            // `ra` stays the root when it outranks `rb` or ties with it; a tie
            // bumps its rank.
            self.parent[ib] = ra;
            if self.rank[ia] == self.rank[ib] {
                self.rank[ia] += 1;
            }
        }
    }
}
43
44pub fn label_components(mask: &[bool], width: u32, height: u32, connectivity: u8) -> Vec<u32> {
49 let len = (width * height) as usize;
50 assert_eq!(mask.len(), len);
51
52 let mut labels = vec![0u32; len];
53 let mut uf = UnionFind::new(len + 1); let mut next_label = 0u32;
55
56 for y in 0..height {
58 for x in 0..width {
59 let idx = (y * width + x) as usize;
60 if !mask[idx] {
61 continue;
62 }
63
64 let mut neighbor_labels: Vec<u32> = Vec::with_capacity(4);
65
66 if y > 0 {
70 let up = ((y - 1) * width + x) as usize;
72 if labels[up] > 0 {
73 neighbor_labels.push(labels[up]);
74 }
75
76 if connectivity == 8 {
77 if x > 0 {
79 let ul = ((y - 1) * width + (x - 1)) as usize;
80 if labels[ul] > 0 {
81 neighbor_labels.push(labels[ul]);
82 }
83 }
84 if x + 1 < width {
86 let ur = ((y - 1) * width + (x + 1)) as usize;
87 if labels[ur] > 0 {
88 neighbor_labels.push(labels[ur]);
89 }
90 }
91 }
92 }
93 if x > 0 {
95 let left = (y * width + (x - 1)) as usize;
96 if labels[left] > 0 {
97 neighbor_labels.push(labels[left]);
98 }
99 }
100
101 if neighbor_labels.is_empty() {
102 next_label += 1;
103 labels[idx] = next_label;
104 } else {
105 let min_label = *neighbor_labels.iter().min().unwrap();
106 labels[idx] = min_label;
107 for &nl in &neighbor_labels {
108 if nl != min_label {
109 uf.union(min_label, nl);
110 }
111 }
112 }
113 }
114 }
115
116 let mut remap = std::collections::HashMap::new();
118 let mut next_id = 0u32;
119
120 for label in labels.iter_mut() {
121 if *label == 0 {
122 continue;
123 }
124 let root = uf.find(*label);
125 let id = *remap.entry(root).or_insert_with(|| {
126 next_id += 1;
127 next_id
128 });
129 *label = id;
130 }
131
132 labels
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Counts the distinct component ids (non-zero labels) in `labels`.
    fn region_count(labels: &[u32]) -> usize {
        labels
            .iter()
            .filter(|&&l| l > 0)
            .copied()
            .collect::<HashSet<u32>>()
            .len()
    }

    #[test]
    fn single_pixel() {
        let mask = vec![
            false, false, false,
            false, true, false,
            false, false, false,
        ];
        let labels = label_components(&mask, 3, 3, 8);
        assert_eq!(labels[4], 1);
        assert_eq!(labels.iter().filter(|&&l| l > 0).count(), 1);
    }

    #[test]
    fn two_disconnected_regions() {
        let mask = vec![
            true, true, false, false, false,
            true, true, false, false, false,
            false, false, false, false, false,
            false, false, false, true, true,
            false, false, false, true, true,
        ];
        let labels = label_components(&mask, 5, 5, 8);
        assert_eq!(region_count(&labels), 2);
    }

    #[test]
    fn diagonal_connected_with_8_connectivity() {
        let mask = vec![
            true, false, false,
            false, true, false,
            false, false, true,
        ];
        let labels = label_components(&mask, 3, 3, 8);
        assert_eq!(region_count(&labels), 1);
    }

    #[test]
    fn diagonal_disconnected_with_4_connectivity() {
        let mask = vec![
            true, false, false,
            false, true, false,
            false, false, true,
        ];
        let labels = label_components(&mask, 3, 3, 4);
        assert_eq!(region_count(&labels), 3);
    }

    #[test]
    fn l_shape_is_single_region() {
        let mask = vec![
            true, false,
            true, false,
            true, true,
        ];
        let labels = label_components(&mask, 2, 3, 4);
        assert_eq!(region_count(&labels), 1);
    }

    #[test]
    fn anti_diagonal_uses_up_right_neighbour() {
        let mask = vec![
            false, false, true,
            false, true, false,
            true, false, false,
        ];
        assert_eq!(region_count(&label_components(&mask, 3, 3, 8)), 1);
        assert_eq!(region_count(&label_components(&mask, 3, 3, 4)), 3);
    }

    #[test]
    fn u_shape_merges_provisional_labels() {
        // The two arms receive different provisional labels on the first row
        // and are only joined by the bottom row, exercising the merge path.
        let mask = vec![
            true, false, true,
            true, false, true,
            true, true, true,
        ];
        let labels = label_components(&mask, 3, 3, 4);
        assert_eq!(labels, vec![1, 0, 1, 1, 0, 1, 1, 1, 1]);
    }

    #[test]
    fn v_shape_merges_through_both_diagonals() {
        let mask = vec![
            true, false, true,
            false, true, false,
        ];
        assert_eq!(label_components(&mask, 3, 2, 8), vec![1, 0, 1, 0, 1, 0]);
        assert_eq!(label_components(&mask, 3, 2, 4), vec![1, 0, 2, 0, 3, 0]);
    }

    #[test]
    fn ids_are_consecutive_in_raster_order_after_merging() {
        // Provisional labels 1 and 2 merge, so the third region must be
        // compacted to id 2 rather than keeping a gap.
        let mask = vec![
            true, false, true, false, true,
            true, true, true, false, true,
        ];
        let labels = label_components(&mask, 5, 2, 4);
        assert_eq!(labels, vec![1, 0, 1, 0, 2, 1, 1, 1, 0, 2]);
    }

    #[test]
    fn empty_and_all_background_images() {
        assert!(label_components(&[], 0, 0, 4).is_empty());
        assert!(label_components(&[], 0, 7, 8).is_empty());
        assert_eq!(label_components(&[false; 6], 3, 2, 8), vec![0; 6]);
    }

    #[test]
    #[should_panic]
    fn mismatched_mask_length_panics() {
        let _ = label_components(&[true, false, true], 2, 2, 4);
    }
}