Skip to main content

locus_core/
segmentation.rs

1//! Connected components labeling (CCL) using Union-Find.
2//!
3//! This module implements the second stage of the pipeline, identifying and grouping
4//! dark pixels that could potentially form fiducial markers.
5//!
6//! It provides:
7//! - **Standard CCL**: Efficient binary component labeling.
8//! - **Threshold-Model CCL**: Advanced connectivity based on local adaptive thresholds.
9
10#![allow(unsafe_code)]
11
12use bumpalo::Bump;
13use bumpalo::collections::Vec as BumpVec;
14use multiversion::multiversion;
15use rayon::prelude::*;
16
17/// A disjoint-set forest (Union-Find) with path compression and rank optimization.
18pub struct UnionFind<'a> {
19    parent: &'a mut [u32],
20    rank: &'a mut [u8],
21}
22
23impl<'a> UnionFind<'a> {
24    /// Create a new UnionFind structure backed by the provided arena.
25    pub fn new_in(arena: &'a Bump, size: usize) -> Self {
26        let parent = arena.alloc_slice_fill_with(size, |i| i as u32);
27        let rank = arena.alloc_slice_fill_copy(size, 0u8);
28        Self { parent, rank }
29    }
30
31    /// Find the representative (root) of the set containing `i`.
32    #[inline]
33    pub fn find(&mut self, i: u32) -> u32 {
34        let mut root = i;
35        while self.parent[root as usize] != root {
36            self.parent[root as usize] = self.parent[self.parent[root as usize] as usize];
37            root = self.parent[root as usize];
38        }
39        root
40    }
41
42    /// Unite the sets containing `i` and `j`.
43    #[inline]
44    pub fn union(&mut self, i: u32, j: u32) {
45        let root_i = self.find(i);
46        let root_j = self.find(j);
47        if root_i != root_j {
48            match self.rank[root_i as usize].cmp(&self.rank[root_j as usize]) {
49                std::cmp::Ordering::Less => self.parent[root_i as usize] = root_j,
50                std::cmp::Ordering::Greater => self.parent[root_j as usize] = root_i,
51                std::cmp::Ordering::Equal => {
52                    self.parent[root_i as usize] = root_j;
53                    self.rank[root_j as usize] += 1;
54                },
55            }
56        }
57    }
58}
59
/// Axis-aligned bounding box and pixel statistics of one connected component.
///
/// All coordinates are inclusive pixel indices stored as `u16`.
#[derive(Clone, Copy, Debug)]
pub struct ComponentStats {
    /// Smallest column index covered by the component.
    pub min_x: u16,
    /// Largest column index covered by the component.
    pub max_x: u16,
    /// Smallest row index covered by the component.
    pub min_y: u16,
    /// Largest row index covered by the component.
    pub max_y: u16,
    /// Number of pixels belonging to the component.
    pub pixel_count: u32,
    /// Column of the first pixel met when scanning the component in row-major
    /// order (used as a starting point for boundary tracing).
    pub first_pixel_x: u16,
    /// Row of the first pixel met when scanning the component in row-major
    /// order (used as a starting point for boundary tracing).
    pub first_pixel_y: u16,
}

impl Default for ComponentStats {
    /// An empty accumulator: minima start at `u16::MAX` and maxima at `0`, so
    /// the first `min`/`max` update adopts the observed coordinate.
    fn default() -> Self {
        let unset_min = u16::MAX;
        ComponentStats {
            min_x: unset_min,
            min_y: unset_min,
            max_x: 0,
            max_y: 0,
            first_pixel_x: 0,
            first_pixel_y: 0,
            pixel_count: 0,
        }
    }
}
92
/// Result of connected component labeling.
///
/// Produced by `label_components_with_stats` and
/// `label_components_threshold_model`; the label buffer is allocated in the
/// arena passed to those functions, hence the lifetime `'a`.
pub struct LabelResult<'a> {
    /// Flat, row-major array of `width * height` pixel labels.
    /// `0` means the pixel belongs to no (kept) component; a value `k > 0`
    /// refers to the component described by `component_stats[k - 1]`.
    pub labels: &'a [u32],
    /// Statistics for each component (indexed by label - 1).
    pub component_stats: Vec<ComponentStats>,
}
100
/// One horizontal run of consecutive selected pixels within a single image row.
///
/// "Selected" means value `0` in `label_components_with_stats` and "dark"
/// (`grayscale - threshold < -margin`) in `label_components_threshold_model`.
#[derive(Clone, Copy, Debug)]
struct Run {
    /// Row index of the run.
    y: u32,
    /// Column of the first pixel of the run (inclusive).
    x_start: u32,
    /// Column of the last pixel of the run (inclusive).
    x_end: u32,
    /// Index of this run in the flattened, row-major list of all runs; used as
    /// the run's element in the union-find structure. `0` until assigned.
    id: u32,
}
109
110/// Label connected components in a binary image.
111pub fn label_components<'a>(
112    arena: &'a Bump,
113    binary: &[u8],
114    width: usize,
115    height: usize,
116    use_8_connectivity: bool,
117) -> &'a [u32] {
118    label_components_with_stats(arena, binary, width, height, use_8_connectivity).labels
119}
120
121/// Label components and compute bounding box stats for each.
122#[allow(clippy::too_many_lines)]
123pub fn label_components_with_stats<'a>(
124    arena: &'a Bump,
125    binary: &[u8],
126    width: usize,
127    height: usize,
128    use_8_connectivity: bool,
129) -> LabelResult<'a> {
130    // Pass 1: Extract runs - Optimized with Rayon parallel processing
131    let all_runs: Vec<Vec<Run>> = binary
132        .par_chunks(width)
133        .enumerate()
134        .map(|(y, row)| {
135            let mut row_runs = Vec::with_capacity(width / 4 + 1);
136            let mut x = 0;
137            while x < width {
138                // Find start of background run (0)
139                if let Some(pos) = row[x..].iter().position(|&p| p == 0) {
140                    let start = x + pos;
141                    // Find end of run
142                    let len = row[start..].iter().take_while(|&&p| p == 0).count();
143                    row_runs.push(Run {
144                        y: y as u32,
145                        x_start: start as u32,
146                        x_end: (start + len - 1) as u32,
147                        id: 0,
148                    });
149                    x = start + len;
150                } else {
151                    break;
152                }
153            }
154            row_runs
155        })
156        .collect();
157
158    let total_runs: usize = all_runs.iter().map(std::vec::Vec::len).sum();
159    let mut runs: BumpVec<Run> = BumpVec::with_capacity_in(total_runs, arena);
160    for (id, mut run) in all_runs.into_iter().flatten().enumerate() {
161        run.id = id as u32;
162        runs.push(run);
163    }
164
165    if runs.is_empty() {
166        return LabelResult {
167            labels: arena.alloc_slice_fill_copy(width * height, 0u32),
168            component_stats: Vec::new(),
169        };
170    }
171
172    let mut uf = UnionFind::new_in(arena, runs.len());
173    let mut curr_row_range = 0..0; // Initialize curr_row_range
174    let mut i = 0;
175
176    // Pass 2: Link runs between adjacent rows using two-pointer linear scan
177    while i < runs.len() {
178        let y = runs[i].y;
179
180        // Identify the range of runs in the current row
181        let start = i;
182        while i < runs.len() && runs[i].y == y {
183            i += 1;
184        }
185        let prev_row_range = curr_row_range; // Now correctly uses the previously assigned curr_row_range
186        curr_row_range = start..i;
187
188        if y > 0 && !prev_row_range.is_empty() && runs[prev_row_range.start].y == y - 1 {
189            let mut p_idx = prev_row_range.start;
190            for c_idx in curr_row_range.clone() {
191                let curr = &runs[c_idx];
192
193                if use_8_connectivity {
194                    // 8-connectivity check: overlap if [xs1, xe1] and [xs2-1, xe2+1] intersect
195                    while p_idx < prev_row_range.end && runs[p_idx].x_end + 1 < curr.x_start {
196                        p_idx += 1;
197                    }
198                    let mut temp_p = p_idx;
199                    while temp_p < prev_row_range.end && runs[temp_p].x_start <= curr.x_end + 1 {
200                        uf.union(curr.id, runs[temp_p].id);
201                        temp_p += 1;
202                    }
203                } else {
204                    // 4-connectivity check: overlap if [xs1, xe1] and [xs2, xe2] intersect
205                    while p_idx < prev_row_range.end && runs[p_idx].x_end < curr.x_start {
206                        p_idx += 1;
207                    }
208                    let mut temp_p = p_idx;
209                    while temp_p < prev_row_range.end && runs[temp_p].x_start <= curr.x_end {
210                        uf.union(curr.id, runs[temp_p].id);
211                        temp_p += 1;
212                    }
213                }
214            }
215        }
216    }
217
218    // Pass 3: Collect stats per root and assign labels
219    let mut root_to_label: Vec<u32> = vec![0; runs.len()];
220    let mut component_stats: Vec<ComponentStats> = Vec::new();
221    let mut next_label = 1u32;
222
223    // Pre-resolve roots to avoid repeated find() in Pass 4
224    let mut run_roots = Vec::with_capacity(runs.len());
225
226    for run in &runs {
227        let root = uf.find(run.id) as usize;
228        run_roots.push(root);
229        if root_to_label[root] == 0 {
230            root_to_label[root] = next_label;
231            next_label += 1;
232            let new_stat = ComponentStats {
233                first_pixel_x: run.x_start as u16,
234                first_pixel_y: run.y as u16,
235                ..Default::default()
236            };
237            component_stats.push(new_stat);
238        }
239        let label_idx = (root_to_label[root] - 1) as usize;
240        let stats = &mut component_stats[label_idx];
241        stats.min_x = stats.min_x.min(run.x_start as u16);
242        stats.max_x = stats.max_x.max(run.x_end as u16);
243        stats.min_y = stats.min_y.min(run.y as u16);
244        stats.max_y = stats.max_y.max(run.y as u16);
245        stats.pixel_count += run.x_end - run.x_start + 1;
246    }
247
248    // Pass 4: Assign labels to pixels - Optimized with slice fill
249    let labels = arena.alloc_slice_fill_copy(width * height, 0u32);
250    for (run, root) in runs.iter().zip(run_roots) {
251        let label = root_to_label[root];
252        let row_off = run.y as usize * width;
253        labels[row_off + run.x_start as usize..=row_off + run.x_end as usize].fill(label);
254    }
255
256    LabelResult {
257        labels,
258        component_stats,
259    }
260}
261
262fn parse_mask_into_runs(mask: u32, bits: usize, x_offset: usize, y: u32, row_runs: &mut Vec<Run>) {
263    let mut mask = mask;
264    while mask != 0 {
265        let start = mask.trailing_zeros() as usize;
266        if start >= bits {
267            break;
268        }
269
270        // Find end of run: first 0 after the 1s
271        let inverted_mask = !mask >> start;
272        let run_len = inverted_mask.trailing_zeros() as usize;
273        let end = (start + run_len - 1).min(bits - 1);
274
275        if let Some(last) = row_runs.last_mut()
276            && last.x_end == (x_offset + start) as u32 - 1
277        {
278            last.x_end = (x_offset + end) as u32;
279        } else {
280            row_runs.push(Run {
281                y,
282                x_start: (x_offset + start) as u32,
283                x_end: (x_offset + end) as u32,
284                id: 0,
285            });
286        }
287
288        // Clear the processed bits
289        if start + run_len >= 32 {
290            mask = 0;
291        } else {
292            mask &= !((1 << (start + run_len)) - 1);
293        }
294    }
295}
296
/// Extract dark-pixel runs from the 16-pixel-chunked prefix of one row using AVX2.
///
/// A pixel `x` is selected when `row_gs[x] - row_th[x] < -margin` (computed in
/// `i16`). This is the same test as the scalar loop in
/// `label_components_threshold_model`. Each 16-pixel chunk's selection bitmask
/// is passed to `parse_mask_into_runs`, which appends runs to `row_runs` or
/// extends the last one.
///
/// Returns the first column NOT processed, i.e. the largest multiple of 16 that
/// is `<= width`. The caller must handle `returned..width` itself.
///
/// # Safety
/// - The CPU must support AVX2 (the function is compiled with
///   `#[target_feature(enable = "avx2")]`).
/// - `row_gs` and `row_th` must each hold at least `width` bytes. The loads read
///   16 bytes at `x` for every `x + 16 <= width` through raw pointers, so shorter
///   slices would be read out of bounds.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn extract_runs_row_avx2(
    row_gs: &[u8],
    row_th: &[u8],
    width: usize,
    y: u32,
    margin: i16,
    row_runs: &mut Vec<Run>,
) -> usize {
    use std::arch::x86_64::{
        __m128i, _mm_loadu_si128, _mm_movemask_epi8, _mm256_castsi256_si128, _mm256_cmpgt_epi16,
        _mm256_cvtepu8_epi16, _mm256_extracti128_si256, _mm256_set1_epi16, _mm256_sub_epi16,
    };
    // `-margin` broadcast to all sixteen i16 lanes.
    let m_vec = _mm256_set1_epi16(-margin);
    let mut x = 0;
    while x + 16 <= width {
        // Unaligned 16-byte loads of the grayscale and threshold chunks.
        let (gs_low, th_low) = unsafe {
            let gs_ptr = row_gs.as_ptr().add(x);
            let th_ptr = row_th.as_ptr().add(x);
            #[allow(clippy::cast_ptr_alignment)]
            (
                _mm_loadu_si128(gs_ptr.cast::<__m128i>()),
                _mm_loadu_si128(th_ptr.cast::<__m128i>()),
            )
        };
        let (mask_low, mask_high) = {
            // Zero-extend u8 -> i16 so the subtraction below cannot wrap.
            let gs_16 = _mm256_cvtepu8_epi16(gs_low);
            let th_16 = _mm256_cvtepu8_epi16(th_low);
            let diff = _mm256_sub_epi16(gs_16, th_16);
            // A lane is all-ones where -margin > diff, i.e. diff < -margin.
            let cmp = _mm256_cmpgt_epi16(m_vec, diff);
            // movemask_epi8 yields two identical bits per i16 lane: pixels 0..8
            // come from the low 128 bits, pixels 8..16 from the high 128 bits.
            (
                _mm_movemask_epi8(_mm256_castsi256_si128(cmp)),
                _mm_movemask_epi8(_mm256_extracti128_si256(cmp, 1)),
            )
        };
        // Keep one (even) bit per lane to build a 16-bit, one-bit-per-pixel mask.
        let mut final_mask = 0u32;
        for i in 0..8 {
            if (mask_low >> (i * 2)) & 1 != 0 {
                final_mask |= 1 << i;
            }
            if (mask_high >> (i * 2)) & 1 != 0 {
                final_mask |= 1 << (i + 8);
            }
        }
        if final_mask != 0 {
            parse_mask_into_runs(final_mask, 16, x, y, row_runs);
        }
        x += 16;
    }
    x
}
349
/// Extract dark-pixel runs from the 8-pixel-chunked prefix of one row using NEON.
///
/// A pixel `x` is selected when `row_gs[x] - row_th[x] < -margin` (computed in
/// `i16`). This is the same test as the scalar loop in
/// `label_components_threshold_model`. Each 8-pixel chunk's selection bitmask is
/// passed to `parse_mask_into_runs`, which appends runs to `row_runs` or extends
/// the last one.
///
/// Returns the first column NOT processed, i.e. the largest multiple of 8 that is
/// `<= width`. The caller must handle `returned..width` itself.
///
/// # Safety
/// - The CPU must support NEON (the function is compiled with
///   `#[target_feature(enable = "neon")]`).
/// - `row_gs` and `row_th` must each hold at least `width` bytes. The loads read
///   8 bytes at `x` for every `x + 8 <= width` through raw pointers, so shorter
///   slices would be read out of bounds.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn extract_runs_row_neon(
    row_gs: &[u8],
    row_th: &[u8],
    width: usize,
    y: u32,
    margin: i16,
    row_runs: &mut Vec<Run>,
) -> usize {
    use std::arch::aarch64::*;
    // `-margin` broadcast to all eight i16 lanes.
    let m_vec = vdupq_n_s16(-margin);
    let mut x = 0;
    while x + 8 <= width {
        // 8-byte loads of the grayscale and threshold chunks.
        let (gs_8, th_8) = unsafe {
            let gs_ptr = row_gs.as_ptr().add(x);
            let th_ptr = row_th.as_ptr().add(x);
            (vld1_u8(gs_ptr), vld1_u8(th_ptr))
        };
        let mask_res = {
            // Zero-extend u8 -> u16 and reinterpret as i16 (values 0..=255 fit).
            let gs_16 = vreinterpretq_s16_u16(vmovl_u8(gs_8));
            let th_16 = vreinterpretq_s16_u16(vmovl_u8(th_8));
            let diff = vsubq_s16(gs_16, th_16);
            // Lane is all-ones where diff < -margin.
            vcltq_s16(diff, m_vec)
        };
        let mut final_mask = 0u32;
        // Reinterpret the compare result as eight u16 lanes (each 0 or 0xFFFF).
        // NOTE(review): assumes array element i maps to lane i (little-endian aarch64).
        let res_u16: [u16; 8] = std::mem::transmute(mask_res);
        for (i, &val) in res_u16.iter().enumerate() {
            if val != 0 {
                final_mask |= 1 << i;
            }
        }
        if final_mask != 0 {
            parse_mask_into_runs(final_mask, 8, x, y, row_runs);
        }
        x += 8;
    }
    x
}
389
390/// Threshold-model-aware connected component labeling.
391#[multiversion(targets(
392    "x86_64+avx2+bmi1+bmi2+popcnt+lzcnt",
393    "x86_64+avx512f+avx512bw+avx512dq+avx512vl",
394    "aarch64+neon"
395))]
396#[allow(clippy::too_many_lines)]
397#[allow(clippy::cast_sign_loss)]
398#[allow(clippy::cast_possible_wrap, clippy::too_many_arguments)]
399#[tracing::instrument(skip_all, name = "pipeline::segmentation")]
400pub fn label_components_threshold_model<'a>(
401    arena: &'a Bump,
402    grayscale: &[u8],
403    grayscale_stride: usize,
404    threshold_map: &[u8],
405    width: usize,
406    height: usize,
407    use_8_connectivity: bool,
408    min_area: u32,
409    margin: i16,
410) -> LabelResult<'a> {
411    // Pass 1: Extract runs of "consistently dark" pixels
412    let all_runs: Vec<Vec<Run>> = (0..height)
413        .into_par_iter()
414        .map(|y| {
415            let row_gs = &grayscale[y * grayscale_stride..y * grayscale_stride + width];
416            let row_th = &threshold_map[y * width..(y + 1) * width];
417            let mut row_runs = Vec::with_capacity(width / 4 + 1);
418            let mut x = 0;
419
420            #[cfg(target_arch = "x86_64")]
421            if std::is_x86_feature_detected!("avx2") {
422                x = unsafe {
423                    extract_runs_row_avx2(row_gs, row_th, width, y as u32, margin, &mut row_runs)
424                };
425            }
426
427            #[cfg(target_arch = "aarch64")]
428            {
429                // NEON is always available on aarch64
430                x = unsafe {
431                    extract_runs_row_neon(row_gs, row_th, width, y as u32, margin, &mut row_runs)
432                };
433            }
434            // Scalar tail
435            while x < width {
436                let gs = row_gs[x];
437                let th = row_th[x];
438                if i16::from(gs) - i16::from(th) < -margin {
439                    let start = x;
440                    x += 1;
441                    while x < width && i16::from(row_gs[x]) - i16::from(row_th[x]) < -margin {
442                        x += 1;
443                    }
444                    row_runs.push(Run {
445                        y: y as u32,
446                        x_start: start as u32,
447                        x_end: (x - 1) as u32,
448                        id: 0,
449                    });
450                } else {
451                    x += 1;
452                }
453            }
454            row_runs
455        })
456        .collect();
457
458    let total_runs: usize = all_runs.iter().map(std::vec::Vec::len).sum();
459    let mut runs = BumpVec::with_capacity_in(total_runs, arena);
460    for (id, mut run) in all_runs.into_iter().flatten().enumerate() {
461        run.id = id as u32;
462        runs.push(run);
463    }
464
465    if runs.is_empty() {
466        return LabelResult {
467            labels: arena.alloc_slice_fill_copy(width * height, 0u32),
468            component_stats: Vec::new(),
469        };
470    }
471
472    // Pass 2-4 are the same as label_components_with_stats
473    let mut uf = UnionFind::new_in(arena, runs.len());
474    let mut curr_row_range = 0..0;
475    let mut i = 0;
476
477    while i < runs.len() {
478        let y = runs[i].y;
479        let start = i;
480        while i < runs.len() && runs[i].y == y {
481            i += 1;
482        }
483        let prev_row_range = curr_row_range;
484        curr_row_range = start..i;
485
486        if y > 0 && !prev_row_range.is_empty() && runs[prev_row_range.start].y == y - 1 {
487            let mut p_idx = prev_row_range.start;
488            for c_idx in curr_row_range.clone() {
489                let curr = &runs[c_idx];
490                if use_8_connectivity {
491                    while p_idx < prev_row_range.end && runs[p_idx].x_end + 1 < curr.x_start {
492                        p_idx += 1;
493                    }
494                    let mut temp_p = p_idx;
495                    while temp_p < prev_row_range.end && runs[temp_p].x_start <= curr.x_end + 1 {
496                        uf.union(curr.id, runs[temp_p].id);
497                        temp_p += 1;
498                    }
499                } else {
500                    while p_idx < prev_row_range.end && runs[p_idx].x_end < curr.x_start {
501                        p_idx += 1;
502                    }
503                    let mut temp_p = p_idx;
504                    while temp_p < prev_row_range.end && runs[temp_p].x_start <= curr.x_end {
505                        uf.union(curr.id, runs[temp_p].id);
506                        temp_p += 1;
507                    }
508                }
509            }
510        }
511    }
512
513    // Pass 3: Aggregate stats for ALL potential components
514    let mut root_to_temp_idx = vec![usize::MAX; runs.len()];
515    let mut temp_stats = Vec::new();
516
517    for run in &runs {
518        let root = uf.find(run.id) as usize;
519        if root_to_temp_idx[root] == usize::MAX {
520            root_to_temp_idx[root] = temp_stats.len();
521            temp_stats.push(ComponentStats {
522                first_pixel_x: run.x_start as u16,
523                first_pixel_y: run.y as u16,
524                ..Default::default()
525            });
526        }
527        let s_idx = root_to_temp_idx[root];
528        let stats = &mut temp_stats[s_idx];
529        stats.min_x = stats.min_x.min(run.x_start as u16);
530        stats.max_x = stats.max_x.max(run.x_end as u16);
531        stats.min_y = stats.min_y.min(run.y as u16);
532        stats.max_y = stats.max_y.max(run.y as u16);
533        stats.pixel_count += run.x_end - run.x_start + 1;
534    }
535
536    // Pass 4: Filter by area and assign final labels
537    let mut component_stats = Vec::with_capacity(temp_stats.len());
538    let mut root_to_final_label = vec![0u32; runs.len()];
539    let mut next_label = 1u32;
540
541    for root in 0..runs.len() {
542        let s_idx = root_to_temp_idx[root];
543        if s_idx != usize::MAX {
544            let s = temp_stats[s_idx];
545            if s.pixel_count >= min_area {
546                component_stats.push(s);
547                root_to_final_label[root] = next_label;
548                next_label += 1;
549            }
550        }
551    }
552
553    // Pass 5: Parallel label writing for surviving components
554    let labels = arena.alloc_slice_fill_copy(width * height, 0u32);
555    let mut runs_by_y: Vec<Vec<(usize, usize, u32)>> = vec![Vec::new(); height];
556
557    for run in &runs {
558        let root = uf.find(run.id) as usize;
559        let label = root_to_final_label[root];
560        if label > 0 {
561            runs_by_y[run.y as usize].push((run.x_start as usize, run.x_end as usize, label));
562        }
563    }
564
565    labels
566        .par_chunks_exact_mut(width)
567        .enumerate()
568        .for_each(|(y, row)| {
569            for &(x_start, x_end, label) in &runs_by_y[y] {
570                row[x_start..=x_end].fill(label);
571            }
572        });
573
574    LabelResult {
575        labels,
576        component_stats,
577    }
578}
579
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::cast_sign_loss)]
mod tests {
    use super::*;
    use bumpalo::Bump;
    use proptest::prelude::prop;
    use proptest::proptest;

    #[test]
    fn test_union_find() {
        let arena = Bump::new();
        let mut uf = UnionFind::new_in(&arena, 10);

        uf.union(1, 2);
        uf.union(2, 3);
        uf.union(5, 6);

        assert_eq!(uf.find(1), uf.find(3));
        assert_eq!(uf.find(1), uf.find(2));
        assert_ne!(uf.find(1), uf.find(5));

        uf.union(3, 5);
        assert_eq!(uf.find(1), uf.find(6));
    }

    #[test]
    fn test_label_components_simple() {
        let arena = Bump::new();
        // 6x6 image with two separate 2x2 squares that are NOT 8-connected.
        // 0 = background (black), 255 = foreground (white)
        // Tag detector looks for black components (0)
        let binary = [
            0, 0, 255, 255, 255, 255, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
            255, 255, 0, 0, 255, 255, 255, 255, 0, 0, 255, 255, 255, 255, 255, 255, 255,
        ];
        let width = 6;
        let height = 6;

        let result = label_components_with_stats(&arena, &binary, width, height, false);

        assert_eq!(result.component_stats.len(), 2);

        // Component 1 (top-left)
        let s1 = result.component_stats[0];
        assert_eq!(s1.pixel_count, 4);
        assert_eq!(s1.min_x, 0);
        assert_eq!(s1.max_x, 1);
        assert_eq!(s1.min_y, 0);
        assert_eq!(s1.max_y, 1);

        // Component 2 (middle-rightish)
        let s2 = result.component_stats[1];
        assert_eq!(s2.pixel_count, 4);
        assert_eq!(s2.min_x, 3);
        assert_eq!(s2.max_x, 4);
        assert_eq!(s2.min_y, 3);
        assert_eq!(s2.max_y, 4);
    }

    #[test]
    fn test_segmentation_with_decimation() {
        let arena = Bump::new();
        let width = 32;
        let height = 32;
        let mut binary = vec![255u8; width * height];
        // Draw a 10x10 black square at (10,10)
        for y in 10..20 {
            for x in 10..20 {
                binary[y * width + x] = 0;
            }
        }

        let img =
            crate::image::ImageView::new(&binary, width, height, width).expect("valid creation");

        // Decimate by 2 -> 16x16
        let mut decimated_data = vec![0u8; 16 * 16];
        let decimated_img = img
            .decimate_to(2, &mut decimated_data)
            .expect("decimation failed");
        // In decimated image, square should be roughly at (5,5) with size 5x5
        let result = label_components_with_stats(&arena, decimated_img.data, 16, 16, true);

        assert_eq!(result.component_stats.len(), 1);
        let s = result.component_stats[0];
        assert_eq!(s.pixel_count, 25);
        assert_eq!(s.min_x, 5);
        assert_eq!(s.max_x, 9);
        assert_eq!(s.min_y, 5);
        assert_eq!(s.max_y, 9);
    }

    proptest! {
        #[test]
        fn prop_union_find_reflexivity(size in 1..1000usize) {
            let arena = Bump::new();
            let mut uf = UnionFind::new_in(&arena, size);
            for i in 0..size as u32 {
                assert_eq!(uf.find(i), i);
            }
        }

        #[test]
        fn prop_union_find_transitivity(size in 1..1000usize, pairs in prop::collection::vec((0..1000u32, 0..1000u32), 0..100)) {
            let arena = Bump::new();
            let mut uf = UnionFind::new_in(&arena, size);

            for (a, b) in pairs {
                // Map arbitrary indices into `0..size` (size >= 1, so `%` is defined),
                // so the generated `size` is actually exercised.
                let a = a % size as u32;
                let b = b % size as u32;
                uf.union(a, b);
                assert_eq!(uf.find(a), uf.find(b));
            }
        }

        #[test]
        fn prop_label_components_no_panic(
            width in 1..64usize,
            height in 1..64usize,
            data in prop::collection::vec(0..=1u8, 64 * 64)
        ) {
            let arena = Bump::new();
            let binary: Vec<u8> = data.iter().map(|&b| if b == 0 { 0 } else { 255 }).collect();
            let real_width = width.min(64);
            let real_height = height.min(64);
            let slice = &binary[..real_width * real_height];

            let result = label_components_with_stats(&arena, slice, real_width, real_height, true);

            for stat in result.component_stats {
                assert!(stat.pixel_count > 0);
                assert!(stat.max_x < real_width as u16);
                assert!(stat.max_y < real_height as u16);
                assert!(stat.min_x <= stat.max_x);
                assert!(stat.min_y <= stat.max_y);
            }
        }
    }

    // ========================================================================
    // SEGMENTATION ROBUSTNESS TESTS
    // ========================================================================

    use crate::config::TagFamily;
    use crate::image::ImageView;
    use crate::test_utils::{TestImageParams, generate_test_image_with_params};
    use crate::threshold::ThresholdEngine;

    /// Helper: Generate a binarized tag image at the given size.
    fn generate_binarized_tag(tag_size: usize, canvas_size: usize) -> (Vec<u8>, [[f64; 2]; 4]) {
        let params = TestImageParams {
            family: TagFamily::AprilTag36h11,
            id: 0,
            tag_size,
            canvas_size,
            ..Default::default()
        };

        let (data, corners) = generate_test_image_with_params(&params);
        let img = ImageView::new(&data, canvas_size, canvas_size, canvas_size).unwrap();

        let arena = Bump::new();
        let engine = ThresholdEngine::new();
        let stats = engine.compute_tile_stats(&arena, &img);
        let mut binary = vec![0u8; canvas_size * canvas_size];
        engine.apply_threshold(&arena, &img, &stats, &mut binary);

        (binary, corners)
    }

    /// Test segmentation at varying tag sizes.
    #[test]
    fn test_segmentation_at_varying_tag_sizes() {
        let canvas_size = 640;
        let tag_sizes = [32, 64, 100, 200, 300];

        for tag_size in tag_sizes {
            let arena = Bump::new();
            let (binary, corners) = generate_binarized_tag(tag_size, canvas_size);

            let result =
                label_components_with_stats(&arena, &binary, canvas_size, canvas_size, true);

            assert!(
                !result.component_stats.is_empty(),
                "Tag size {tag_size}: No components found"
            );

            let largest = result
                .component_stats
                .iter()
                .max_by_key(|s| s.pixel_count)
                .unwrap();

            let expected_min_x = corners[0][0] as u16;
            let expected_max_x = corners[1][0] as u16;
            let tolerance = 5;

            assert!(
                (i32::from(largest.min_x) - i32::from(expected_min_x)).abs() <= tolerance,
                "Tag size {tag_size}: min_x mismatch"
            );
            assert!(
                (i32::from(largest.max_x) - i32::from(expected_max_x)).abs() <= tolerance,
                "Tag size {tag_size}: max_x mismatch"
            );

            println!(
                "Tag size {:>3}px: {} components, largest has {} px",
                tag_size,
                result.component_stats.len(),
                largest.pixel_count
            );
        }
    }

    /// Test component pixel counts are reasonable for clean binarization.
    #[test]
    fn test_segmentation_component_accuracy() {
        let canvas_size = 320;
        let tag_size = 120;

        let arena = Bump::new();
        let (binary, corners) = generate_binarized_tag(tag_size, canvas_size);

        let result = label_components_with_stats(&arena, &binary, canvas_size, canvas_size, true);

        let largest = result
            .component_stats
            .iter()
            .max_by_key(|s| s.pixel_count)
            .unwrap();

        let expected_min = (tag_size * tag_size / 3) as u32;
        let expected_max = (tag_size * tag_size) as u32;

        assert!(largest.pixel_count >= expected_min);
        assert!(largest.pixel_count <= expected_max);

        let gt_width = (corners[1][0] - corners[0][0]).abs() as i32;
        let gt_height = (corners[2][1] - corners[0][1]).abs() as i32;
        let bbox_width = i32::from(largest.max_x - largest.min_x);
        let bbox_height = i32::from(largest.max_y - largest.min_y);

        assert!((bbox_width - gt_width).abs() <= 2);
        assert!((bbox_height - gt_height).abs() <= 2);

        println!(
            "Component accuracy: {} pixels, bbox={}x{} (GT: {}x{})",
            largest.pixel_count, bbox_width, bbox_height, gt_width, gt_height
        );
    }

    /// Test segmentation with noisy binary boundaries.
    #[test]
    fn test_segmentation_noisy_boundaries() {
        let canvas_size = 320;
        let tag_size = 120;

        let arena = Bump::new();
        let (mut binary, _corners) = generate_binarized_tag(tag_size, canvas_size);

        let noise_rate = 0.05;

        // Deterministic xorshift32 noise source, so any failure is reproducible
        // (an unseeded global RNG would make this test flaky and undebuggable).
        let mut rng_state: u32 = 0x2545_F491;
        let mut next_unit = || {
            rng_state ^= rng_state << 13;
            rng_state ^= rng_state >> 17;
            rng_state ^= rng_state << 5;
            f64::from(rng_state) / f64::from(u32::MAX)
        };

        for y in 1..(canvas_size - 1) {
            for x in 1..(canvas_size - 1) {
                let idx = y * canvas_size + x;
                let current = binary[idx];
                let left = binary[idx - 1];
                let right = binary[idx + 1];
                let up = binary[idx - canvas_size];
                let down = binary[idx + canvas_size];

                let is_edge =
                    current != left || current != right || current != up || current != down;
                if is_edge && next_unit() < noise_rate {
                    binary[idx] = if current == 0 { 255 } else { 0 };
                }
            }
        }

        let result = label_components_with_stats(&arena, &binary, canvas_size, canvas_size, true);

        assert!(!result.component_stats.is_empty());

        let largest = result
            .component_stats
            .iter()
            .max_by_key(|s| s.pixel_count)
            .unwrap();

        let min_expected = (tag_size * tag_size / 4) as u32;
        assert!(largest.pixel_count >= min_expected);

        println!(
            "Noisy segmentation: {} components, largest has {} px",
            result.component_stats.len(),
            largest.pixel_count
        );
    }

    #[test]
    fn test_segmentation_correctness_large_image() {
        let arena = Bump::new();
        let width = 3840; // 4K width
        let height = 2160; // 4K height
        let mut binary = vec![255u8; width * height];

        // Create multiple separate components
        // 1. A square at the top left
        for y in 100..200 {
            for x in 100..200 {
                binary[y * width + x] = 0;
            }
        }

        // 2. A long horizontal strip in the middle
        for x in 500..3000 {
            binary[1000 * width + x] = 0;
            binary[1001 * width + x] = 0;
        }

        // 3. Horizontal stripes at the bottom
        for y in 1800..2000 {
            if y % 4 == 0 {
                for x in 1800..2000 {
                    binary[y * width + x] = 0;
                }
            }
        }

        // 4. Noise (avoiding the square area)
        for y in 0..height {
            if y % 10 == 0 {
                for x in 0..width {
                    if (!(100..200).contains(&x) || !(100..200).contains(&y)) && (x + y) % 31 == 0 {
                        binary[y * width + x] = 0;
                    }
                }
            }
        }

        // Run segmentation
        let start = std::time::Instant::now();
        let result = label_components_with_stats(&arena, &binary, width, height, true);
        let duration = start.elapsed();

        // Basic verification of component counts
        assert!(result.component_stats.len() > 1000);

        println!(
            "Found {} components on 4K image in {:?}",
            result.component_stats.len(),
            duration
        );
    }
}