pineapple_core/cv/connected.rs

// Copyright (c) 2025, Tom Ouellette
// Licensed under the BSD 3-Clause License
3
4use std::cmp::Ordering;
5
/// A union-find (disjoint-set) structure for finding and merging connected components
///
/// Merges use union by rank and lookups use path compression, keeping the trees
/// shallow.
#[derive(Debug, Clone)]
pub struct UnionFind {
    /// `parent[i]` is the parent of element `i`; a root is its own parent
    parent: Vec<usize>,
    /// Rank of each element; only compared between roots when merging
    rank: Vec<usize>,
}

impl UnionFind {
    /// Initialize a new union-find object with `n` elements in `n` singleton sets
    pub fn new(n: usize) -> Self {
        UnionFind {
            parent: (0..n).collect(),
            rank: vec![1; n],
        }
    }

    /// Find the root of the set containing `x`, compressing the path along the way
    ///
    /// # Panics
    ///
    /// Panics if `x` is not less than the `n` passed to [`UnionFind::new`].
    pub fn find(&mut self, x: usize) -> usize {
        // Walk up to the root iteratively so call depth never grows with path length
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }

        // Path compression: re-point every node on the walked path at the root
        let mut node = x;
        while node != root {
            let next = self.parent[node];
            self.parent[node] = root;
            node = next;
        }

        root
    }

    /// Merge the sets containing `x` and `y`
    ///
    /// # Panics
    ///
    /// Panics if `x` or `y` is out of range.
    pub fn union(&mut self, x: usize, y: usize) {
        let root_x = self.find(x);
        let root_y = self.find(y);

        if root_x == root_y {
            return;
        }

        // Attach the lower-rank root beneath the higher-rank one; on a tie the
        // root of `x` is kept and its rank grows by one
        match self.rank[root_x].cmp(&self.rank[root_y]) {
            Ordering::Greater => self.parent[root_y] = root_x,
            Ordering::Less => self.parent[root_x] = root_y,
            Ordering::Equal => {
                self.parent[root_y] = root_x;
                self.rank[root_x] += 1;
            }
        }
    }

    /// Check if `x` and `y` belong to the same set
    pub fn connected(&mut self, x: usize, y: usize) -> bool {
        self.find(x) == self.find(y)
    }
}

/// Two-pass 8-connected component labeling on mask buffers
///
/// Every non-zero pixel is foreground and every zero pixel is background. Pixel
/// values are otherwise ignored, so touching regions with different non-zero
/// values are merged into a single component.
///
/// # Arguments
///
/// * `width` - Width of mask
/// * `height` - Height of mask
/// * `buffer` - A row-major mask buffer with at least `width * height` elements
///
/// # Returns
///
/// A row-major buffer of `width * height` labels. Background pixels are `0`.
/// Components are numbered consecutively from `1` in the raster order (row by
/// row, left to right) of each component's first pixel.
///
/// # Panics
///
/// Panics if `buffer` holds fewer than `width * height` elements.
///
/// # Examples
///
/// ```
/// use pineapple_core::cv::connected_components;
///
/// let width = 3;
/// let height = 3;
///
/// // All non-zero pixels touch (diagonals count), so they form one component
/// let buffer_one: Vec<u32> = vec![10, 10, 0, 10, 0, 20, 0, 20, 20];
/// let labels_one = connected_components(width, height, &buffer_one);
/// assert_eq!(labels_one, [1, 1, 0, 1, 0, 1, 0, 1, 1]);
///
/// let buffer_two: Vec<u32> = vec![10, 10, 10, 0, 0, 0, 20, 20, 20];
/// let labels_two = connected_components(width, height, &buffer_two);
/// assert_eq!(labels_two, [1, 1, 1, 0, 0, 0, 2, 2, 2]);
/// ```
pub fn connected_components(width: u32, height: u32, buffer: &[u32]) -> Vec<u32> {
    let width = width as usize;
    let height = height as usize;
    let size = width * height;

    assert!(
        buffer.len() >= size,
        "mask buffer has {} elements but width * height is {}",
        buffer.len(),
        size
    );

    let mut labels = vec![0u32; size];
    let mut next_label: u32 = 1;

    // Provisional labels start at 1 and can reach `size` (e.g. a lone foreground
    // pixel in a 1x1 mask), so the union-find needs `size + 1` slots. Slot 0 stands
    // for background and is never merged.
    let mut uf = UnionFind::new(size + 1);

    // Assign provisional labels and record label equivalences (1st pass)
    for y in 0..height {
        for x in 0..width {
            let idx = y * width + x;
            if buffer[idx] == 0 {
                // Ignore background pixels
                continue;
            }

            // Already-visited 8-connected neighbors: left, top, top-left, top-right
            let candidates = [
                if x > 0 { Some(idx - 1) } else { None },
                if y > 0 { Some(idx - width) } else { None },
                if x > 0 && y > 0 { Some(idx - width - 1) } else { None },
                if y > 0 && x + 1 < width { Some(idx - width + 1) } else { None },
            ];

            // At most four neighbors, so a fixed array avoids a per-pixel allocation
            let mut found = [0u32; 4];
            let mut count = 0;
            for &n in candidates.iter().flatten() {
                if buffer[n] > 0 {
                    found[count] = labels[n];
                    count += 1;
                }
            }
            let neighbors = &found[..count];

            match neighbors.iter().min() {
                None => {
                    labels[idx] = next_label;
                    next_label += 1;
                }
                Some(&min_label) => {
                    labels[idx] = min_label;

                    // All neighboring labels belong to the same component
                    for &label in neighbors {
                        uf.union(min_label as usize, label as usize);
                    }
                }
            }
        }
    }

    // Resolve equivalences and renumber components consecutively (2nd pass).
    // Union-find roots are arbitrary provisional labels, so `compact` maps each
    // root to the next unused component id the first time it is seen.
    let mut compact = vec![0u32; next_label as usize];
    let mut num_components = 0u32;
    for label in labels.iter_mut().filter(|label| **label != 0) {
        let root = uf.find(*label as usize);
        if compact[root] == 0 {
            num_components += 1;
            compact[root] = num_components;
        }
        *label = compact[root];
    }

    labels
}
142
#[cfg(test)]
mod test {

    use super::*;

    /// Build a 3x3 mask where the listed `(index, value)` pixels are set and all others are zero
    fn mask_3x3(pixels: &[(usize, u32)]) -> (u32, u32, [u32; 9]) {
        let mut buffer = [0u32; 9];
        for &(index, value) in pixels {
            buffer[index] = value;
        }
        (3, 3, buffer)
    }

    /// Four isolated corner pixels (the two bottom corners share a value but do not touch)
    fn four_regions() -> (u32, u32, [u32; 9]) {
        mask_3x3(&[(0, 1), (2, 2), (6, 3), (8, 3)])
    }

    /// The two top corners plus a full bottom row
    fn three_regions() -> (u32, u32, [u32; 9]) {
        mask_3x3(&[(0, 1), (2, 2), (6, 3), (7, 3), (8, 3)])
    }

    /// The three regions above plus a center pixel that is 8-adjacent to all of them
    fn touching_regions() -> (u32, u32, [u32; 9]) {
        mask_3x3(&[(0, 1), (2, 2), (4, 4), (6, 3), (7, 3), (8, 3)])
    }

    /// Label a mask and return its sorted, de-duplicated set of labels
    fn distinct_labels(mask: (u32, u32, [u32; 9])) -> Vec<u32> {
        let (w, h, buffer) = mask;
        let mut found = connected_components(w, h, &buffer);
        found.sort();
        found.dedup();
        found
    }

    #[test]
    fn test_four_regions() {
        assert_eq!(distinct_labels(four_regions()), vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_three_regions() {
        assert_eq!(distinct_labels(three_regions()), vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_middle_regions() {
        assert_eq!(distinct_labels(touching_regions()), vec![0, 1]);
    }
}