/// Disjoint-set forest over the integer ids `0..size`, using union by rank
/// and path halving.
struct UnionFind {
    /// `parent[i]` is the parent of id `i`; a root is its own parent.
    parent: Vec<u32>,
    /// Union-by-rank bookkeeping; only meaningful for roots.
    rank: Vec<u8>,
}

impl UnionFind {
    /// Creates `size` singleton sets, one for each id in `0..size`.
    ///
    /// # Panics
    ///
    /// Panics if `size` does not fit in a `u32`. Ids are stored as `u32`, and
    /// a silent truncating cast would leave `parent` and `rank` with
    /// different lengths.
    fn new(size: usize) -> Self {
        let n = u32::try_from(size).expect("UnionFind size must fit in u32");
        Self {
            parent: (0..n).collect(),
            rank: vec![0; size],
        }
    }

    /// Returns the root of the set containing `x`.
    ///
    /// Applies path halving while walking up: each visited node is re-pointed
    /// to its grandparent, which shortens the path for later calls.
    ///
    /// # Panics
    ///
    /// Panics if `x` is not a valid id (`x >= size`).
    fn find(&mut self, mut x: u32) -> u32 {
        while self.parent[x as usize] != x {
            self.parent[x as usize] = self.parent[self.parent[x as usize] as usize];
            x = self.parent[x as usize];
        }
        x
    }

    /// Merges the sets containing `a` and `b`; does nothing if they already
    /// share a root.
    ///
    /// The root with the lower rank is attached under the other. On a tie,
    /// `a`'s root becomes the parent and its rank is incremented.
    fn union(&mut self, a: u32, b: u32) {
        let ra = self.find(a);
        let rb = self.find(b);
        if ra == rb {
            return;
        }
        match self.rank[ra as usize].cmp(&self.rank[rb as usize]) {
            std::cmp::Ordering::Less => self.parent[ra as usize] = rb,
            std::cmp::Ordering::Greater => self.parent[rb as usize] = ra,
            std::cmp::Ordering::Equal => {
                self.parent[rb as usize] = ra;
                self.rank[ra as usize] += 1;
            }
        }
    }
}
/// Labels the connected foreground regions of a binary image using a
/// two-pass union-find algorithm.
///
/// `mask` is a row-major `width * height` image in which `true` marks a
/// foreground pixel. The returned vector has the same layout:
/// - background pixels get `0`;
/// - each connected foreground region gets its own id in `1..=k`, where `k`
///   is the number of regions;
/// - ids are assigned in the order regions are first reached by a
///   row-by-row (raster) scan.
///
/// With `connectivity == 8`, diagonal neighbours count as adjacent. Any other
/// value behaves as 4-connectivity (edge neighbours only).
///
/// # Panics
///
/// Panics if `width * height` is not strictly less than `u32::MAX` (labels
/// are `u32`), or if `mask.len() != width * height`.
pub fn label_components(mask: &[bool], width: u32, height: u32, connectivity: u8) -> Vec<u32> {
    let w = width as usize;
    let h = height as usize;
    // Compute the pixel count in `usize` so it cannot overflow `u32`
    // arithmetic. Every provisional label, plus the `len + 1` union-find
    // slots below, must still fit in a `u32`.
    let len = w
        .checked_mul(h)
        .filter(|&n| n < u32::MAX as usize)
        .expect("label_components: width * height must be less than u32::MAX");
    assert_eq!(mask.len(), len);
    let eight_connected = connectivity == 8;

    let mut labels = vec![0u32; len];
    // Provisional labels start at 1 (0 means background). There is at most
    // one new label per pixel, hence `len + 1` union-find slots.
    let mut uf = UnionFind::new(len + 1);
    let mut next_label = 0u32;

    // First pass: assign provisional labels and record which labels touch.
    for y in 0..h {
        for x in 0..w {
            let idx = y * w + x;
            if !mask[idx] {
                continue;
            }
            // Labels of the already-visited neighbours, in the order up,
            // up-left, up-right, left. A 0 entry means the neighbour is
            // background, out of bounds, or not used for this connectivity.
            let mut neighbours = [0u32; 4];
            if y > 0 {
                let up = idx - w;
                neighbours[0] = labels[up];
                if eight_connected {
                    if x > 0 {
                        neighbours[1] = labels[up - 1];
                    }
                    if x + 1 < w {
                        neighbours[2] = labels[up + 1];
                    }
                }
            }
            if x > 0 {
                neighbours[3] = labels[idx - 1];
            }

            match neighbours.iter().copied().filter(|&l| l > 0).min() {
                None => {
                    next_label += 1;
                    labels[idx] = next_label;
                }
                Some(min_label) => {
                    labels[idx] = min_label;
                    // Every labelled neighbour belongs to the same region.
                    for &l in &neighbours {
                        if l > 0 && l != min_label {
                            uf.union(min_label, l);
                        }
                    }
                }
            }
        }
    }

    // Second pass: replace each provisional label with a compact id for its
    // equivalence class. Roots are always in `1..=next_label`, so a dense
    // table works. `remap[root] == 0` means that class has no id yet.
    let mut remap = vec![0u32; next_label as usize + 1];
    let mut next_id = 0u32;
    for label in labels.iter_mut().filter(|l| **l != 0) {
        let root = uf.find(*label) as usize;
        if remap[root] == 0 {
            next_id += 1;
            remap[root] = next_id;
        }
        *label = remap[root];
    }
    labels
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Number of distinct non-zero labels in `labels`.
    fn region_count(labels: &[u32]) -> usize {
        labels
            .iter()
            .filter(|&&l| l > 0)
            .collect::<HashSet<_>>()
            .len()
    }

    #[test]
    fn single_pixel() {
        let mask = vec![
            false, false, false,
            false, true, false,
            false, false, false,
        ];
        let labels = label_components(&mask, 3, 3, 8);
        assert_eq!(labels[4], 1);
        assert_eq!(labels.iter().filter(|&&l| l > 0).count(), 1);
    }

    #[test]
    fn two_disconnected_regions() {
        let mask = vec![
            true, true, false, false, false,
            true, true, false, false, false,
            false, false, false, false, false,
            false, false, false, true, true,
            false, false, false, true, true,
        ];
        let labels = label_components(&mask, 5, 5, 8);
        assert_eq!(region_count(&labels), 2);
        // Ids follow raster order of first appearance.
        assert_eq!(labels[0], 1);
        assert_eq!(labels[24], 2);
    }

    #[test]
    fn diagonal_connected_with_8_connectivity() {
        let mask = vec![
            true, false, false,
            false, true, false,
            false, false, true,
        ];
        let labels = label_components(&mask, 3, 3, 8);
        assert_eq!(region_count(&labels), 1);
    }

    #[test]
    fn diagonal_disconnected_with_4_connectivity() {
        let mask = vec![
            true, false, false,
            false, true, false,
            false, false, true,
        ];
        let labels = label_components(&mask, 3, 3, 4);
        assert_eq!(region_count(&labels), 3);
    }

    #[test]
    fn l_shape_is_single_region() {
        let mask = vec![
            true, false,
            true, false,
            true, true,
        ];
        let labels = label_components(&mask, 2, 3, 4);
        assert_eq!(region_count(&labels), 1);
    }

    /// The two arms get different provisional labels and are only joined at
    /// the bottom row. This exercises the union-find merge in the second pass.
    #[test]
    fn u_shape_requires_label_merge() {
        let mask = vec![
            true, false, true,
            true, false, true,
            true, true, true,
        ];
        let labels = label_components(&mask, 3, 3, 4);
        assert_eq!(labels, vec![1, 0, 1, 1, 0, 1, 1, 1, 1]);
    }

    #[test]
    fn labels_are_compact_and_in_raster_order() {
        let mask = vec![
            true, false, true,
            false, true, false,
            true, false, true,
        ];
        assert_eq!(
            label_components(&mask, 3, 3, 4),
            vec![1, 0, 2, 0, 3, 0, 4, 0, 5]
        );
        assert_eq!(
            label_components(&mask, 3, 3, 8),
            vec![1, 0, 1, 0, 1, 0, 1, 0, 1]
        );
    }

    #[test]
    fn empty_and_background_images() {
        assert!(label_components(&[], 0, 0, 4).is_empty());
        assert!(label_components(&[], 0, 5, 8).is_empty());
        assert_eq!(label_components(&[false; 6], 3, 2, 4), vec![0; 6]);
    }

    #[test]
    #[should_panic]
    fn mismatched_mask_length_panics() {
        let _ = label_components(&[true; 5], 3, 2, 4);
    }
}