// agent_image_diff/cluster.rs

//! Connected component labeling using a two-pass algorithm with union-find.
//!
//! Labels each `true` pixel in the mask with a component ID (1-indexed).
//! Background (`false`) pixels get label 0.

/// Disjoint-set forest over `u32` element ids.
///
/// `parent[i]` is the parent of element `i`; a root is its own parent.
/// `rank[i]` is only meaningful for roots and is used to decide which root
/// becomes the parent when two sets are merged.
struct UnionFind {
    parent: Vec<u32>,
    rank: Vec<u8>,
}

impl UnionFind {
    /// Creates `size` singleton sets with ids `0..size`.
    ///
    /// # Panics
    ///
    /// Panics if `size` is larger than `u32::MAX`, because element ids are
    /// stored as `u32`.
    fn new(size: usize) -> Self {
        // A bare `size as u32` would silently truncate and leave `parent`
        // shorter than `rank`; fail loudly instead.
        assert!(
            size <= u32::MAX as usize,
            "UnionFind size {} does not fit in u32 element ids",
            size
        );
        Self {
            parent: (0..size as u32).collect(),
            rank: vec![0; size],
        }
    }

    /// Returns the root of the set containing `x`.
    ///
    /// While walking up, every visited node is re-pointed at its grandparent,
    /// which shortens the path for later lookups.
    fn find(&mut self, mut x: u32) -> u32 {
        loop {
            let parent = self.parent[x as usize];
            if parent == x {
                return x;
            }
            let grandparent = self.parent[parent as usize];
            self.parent[x as usize] = grandparent;
            x = grandparent;
        }
    }

    /// Merges the sets containing `a` and `b`; does nothing if they are
    /// already in the same set.
    fn union(&mut self, a: u32, b: u32) {
        let ra = self.find(a);
        let rb = self.find(b);
        if ra == rb {
            return;
        }
        // Attach the lower-rank root beneath the higher-rank one. On a tie
        // `ra` stays the root and its rank grows by one.
        match self.rank[ra as usize].cmp(&self.rank[rb as usize]) {
            std::cmp::Ordering::Less => self.parent[ra as usize] = rb,
            std::cmp::Ordering::Greater => self.parent[rb as usize] = ra,
            std::cmp::Ordering::Equal => {
                self.parent[rb as usize] = ra;
                self.rank[ra as usize] += 1;
            }
        }
    }
}

44/// Label connected components in a boolean diff mask.
45///
46/// Returns a Vec<u32> of the same length as mask, where 0 = background
47/// and 1+ = component label.
48pub fn label_components(mask: &[bool], width: u32, height: u32, connectivity: u8) -> Vec<u32> {
49    let len = (width * height) as usize;
50    assert_eq!(mask.len(), len);
51
52    let mut labels = vec![0u32; len];
53    let mut uf = UnionFind::new(len + 1); // +1 since labels are 1-indexed
54    let mut next_label = 0u32;
55
56    // Pass 1: assign provisional labels
57    for y in 0..height {
58        for x in 0..width {
59            let idx = (y * width + x) as usize;
60            if !mask[idx] {
61                continue;
62            }
63
64            let mut neighbor_labels: Vec<u32> = Vec::with_capacity(4);
65
66            // Check already-visited neighbors based on connectivity
67            // For 4-connectivity: up, left
68            // For 8-connectivity: up-left, up, up-right, left
69            if y > 0 {
70                // Up
71                let up = ((y - 1) * width + x) as usize;
72                if labels[up] > 0 {
73                    neighbor_labels.push(labels[up]);
74                }
75
76                if connectivity == 8 {
77                    // Up-left
78                    if x > 0 {
79                        let ul = ((y - 1) * width + (x - 1)) as usize;
80                        if labels[ul] > 0 {
81                            neighbor_labels.push(labels[ul]);
82                        }
83                    }
84                    // Up-right
85                    if x + 1 < width {
86                        let ur = ((y - 1) * width + (x + 1)) as usize;
87                        if labels[ur] > 0 {
88                            neighbor_labels.push(labels[ur]);
89                        }
90                    }
91                }
92            }
93            // Left
94            if x > 0 {
95                let left = (y * width + (x - 1)) as usize;
96                if labels[left] > 0 {
97                    neighbor_labels.push(labels[left]);
98                }
99            }
100
101            if neighbor_labels.is_empty() {
102                next_label += 1;
103                labels[idx] = next_label;
104            } else {
105                let min_label = *neighbor_labels.iter().min().unwrap();
106                labels[idx] = min_label;
107                for &nl in &neighbor_labels {
108                    if nl != min_label {
109                        uf.union(min_label, nl);
110                    }
111                }
112            }
113        }
114    }
115
116    // Pass 2: replace each label with its root, renumber contiguously
117    let mut remap = std::collections::HashMap::new();
118    let mut next_id = 0u32;
119
120    for label in labels.iter_mut() {
121        if *label == 0 {
122            continue;
123        }
124        let root = uf.find(*label);
125        let id = *remap.entry(root).or_insert_with(|| {
126            next_id += 1;
127            next_id
128        });
129        *label = id;
130    }
131
132    labels
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Number of distinct non-background labels in `labels`.
    fn component_count(labels: &[u32]) -> usize {
        labels
            .iter()
            .copied()
            .filter(|&label| label > 0)
            .collect::<HashSet<u32>>()
            .len()
    }

    #[test]
    fn single_pixel() {
        let mask = vec![
            false, false, false,
            false, true, false,
            false, false, false,
        ];
        let labels = label_components(&mask, 3, 3, 8);
        assert_eq!(labels[4], 1);
        assert_eq!(labels.iter().filter(|&&l| l > 0).count(), 1);
    }

    #[test]
    fn two_disconnected_regions() {
        let mask = vec![
            true, true, false, false, false,
            true, true, false, false, false,
            false, false, false, false, false,
            false, false, false, true, true,
            false, false, false, true, true,
        ];
        let labels = label_components(&mask, 5, 5, 8);
        assert_eq!(component_count(&labels), 2);
    }

    #[test]
    fn diagonal_connected_with_8_connectivity() {
        let mask = vec![
            true, false, false,
            false, true, false,
            false, false, true,
        ];
        let labels = label_components(&mask, 3, 3, 8);
        // All three should be the same component with 8-connectivity
        assert_eq!(component_count(&labels), 1);
    }

    #[test]
    fn diagonal_disconnected_with_4_connectivity() {
        let mask = vec![
            true, false, false,
            false, true, false,
            false, false, true,
        ];
        let labels = label_components(&mask, 3, 3, 4);
        // Each pixel should be its own component with 4-connectivity
        assert_eq!(component_count(&labels), 3);
    }

    #[test]
    fn l_shape_is_single_region() {
        let mask = vec![
            true, false,
            true, false,
            true, true,
        ];
        let labels = label_components(&mask, 2, 3, 4);
        assert_eq!(component_count(&labels), 1);
    }

    #[test]
    fn u_shape_merges_two_provisional_labels() {
        // The two arms receive different provisional labels on the first row
        // and are only joined on the second row, so this exercises the
        // union-find merge path.
        let mask = vec![
            true, false, true,
            true, true, true,
        ];
        let labels = label_components(&mask, 3, 2, 4);
        assert_eq!(labels, vec![1, 0, 1, 1, 1, 1]);
    }

    #[test]
    fn labels_are_contiguous_after_merge() {
        // Merging the U-shape frees up a provisional label; the separate
        // column on the right must still be numbered 2, not 3.
        let mask = vec![
            true, false, true, false, true,
            true, true, true, false, true,
        ];
        let labels = label_components(&mask, 5, 2, 4);
        assert_eq!(labels, vec![1, 0, 1, 0, 2, 1, 1, 1, 0, 2]);
    }

    #[test]
    fn background_pixels_stay_zero() {
        let mask = vec![
            true, false,
            false, true,
        ];
        let labels = label_components(&mask, 2, 2, 8);
        assert_eq!(labels, vec![1, 0, 0, 1]);
    }

    #[test]
    fn empty_mask_yields_no_labels() {
        let labels = label_components(&[], 0, 0, 8);
        assert!(labels.is_empty());
    }
}