Skip to main content

locus_core/simd_ccl_fusion/
mod.rs

1//! SIMD-Accelerated Connected Components Labeling (CCL) with Fused Thresholding.
2//!
3//! This module implements a high-performance segmentation pipeline that defeats the "memory wall"
4//! by fusing adaptive thresholding with Run-Length Encoding (RLE) extraction. It processes images
5//! in 1D segments rather than individual pixels, drastically reducing memory bandwidth requirements
6//! and branch mispredictions.
7
8use crate::image::ImageView;
9use crate::segmentation::{ComponentStats, LabelResult, UnionFind};
10use bumpalo::Bump;
11
/// One horizontal run of foreground pixels, stored in run-length-encoded form.
///
/// The run covers the half-open column range `start_x..end_x` on row `y`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct RleSegment {
    /// Row index (Y coordinate) of the run.
    pub y: u16,
    /// First column covered by the run (inclusive).
    pub start_x: u16,
    /// One past the last column covered by the run (exclusive).
    pub end_x: u16,
    /// Label slot. `new` leaves it at 0. Note that `label_components_lsl` overwrites it with a
    /// zero-based run index, so 0 is also a legitimate value once labeling has started.
    pub label: u32,
}

impl RleSegment {
    /// Builds a run on row `y` spanning `start_x..end_x`, with its label slot set to 0.
    #[must_use]
    pub const fn new(y: u16, start_x: u16, end_x: u16) -> Self {
        let label = 0;
        Self {
            label,
            end_x,
            start_x,
            y,
        }
    }
}
37
/// Architecture-specific SIMD scanline processing.
///
/// Supplies `extract_rle_segments`, the run-extraction step that `label_components_lsl`
/// calls before linking runs across rows. NOTE(review): presumably this selects a SIMD kernel
/// per target and falls back to `process_row_scalar` otherwise — confirm in `simd_scanline.rs`.
pub mod simd_scanline;
40
41/// Baseline scalar implementation of the fused Threshold + RLE extraction kernel.
42/// Extracts runs of black pixels (value < threshold) from the image.
43#[allow(dead_code)]
44#[must_use]
45pub fn extract_rle_segments_scalar(img: &ImageView, threshold_map: &[u8]) -> Vec<RleSegment> {
46    let mut segments = Vec::new();
47    let height = img.height as u16;
48
49    for y in 0..height {
50        let y_usize = y as usize;
51        let src_row = img.get_row(y_usize);
52        let thresh_row = &threshold_map[y_usize * img.width..(y_usize + 1) * img.width];
53
54        process_row_scalar(src_row, thresh_row, y, &mut segments);
55    }
56
57    segments
58}
59
60/// Scalar fallback for a single row.
61pub fn process_row_scalar(
62    src_row: &[u8],
63    thresh_row: &[u8],
64    y: u16,
65    segments: &mut Vec<RleSegment>,
66) {
67    let mut in_segment = false;
68    let mut start_x = 0;
69    let width = src_row.len();
70
71    for (x, (&s, &t)) in src_row.iter().zip(thresh_row.iter()).enumerate() {
72        let is_foreground = s < t;
73
74        if is_foreground && !in_segment {
75            in_segment = true;
76            start_x = x as u16;
77        } else if !is_foreground && in_segment {
78            in_segment = false;
79            segments.push(RleSegment::new(y, start_x, x as u16));
80        }
81    }
82
83    if in_segment {
84        segments.push(RleSegment::new(y, start_x, width as u16));
85    }
86}
87
88/// Performs Light-Speed Labeling (LSL) / Run-based Union-Find on the extracted RLE segments.
89/// Fully resolves equivalences and outputs the 2D label map expected by the rest of the pipeline.
90#[allow(clippy::too_many_lines)]
91pub fn label_components_lsl<'a>(
92    arena: &'a Bump,
93    img: &ImageView,
94    threshold_map: &[u8],
95    use_8_connectivity: bool,
96    min_area: u32,
97) -> LabelResult<'a> {
98    let mut runs = simd_scanline::extract_rle_segments(img, threshold_map);
99
100    // Assign consecutive IDs for Union-Find using the label field
101    for (id, run) in runs.iter_mut().enumerate() {
102        run.label = id as u32;
103    }
104
105    if runs.is_empty() {
106        return LabelResult {
107            labels: arena.alloc_slice_fill_copy(img.width * img.height, 0u32),
108            component_stats: Vec::new(),
109        };
110    }
111
112    let mut uf = UnionFind::new_in(arena, runs.len());
113    let mut curr_row_range = 0..0;
114    let mut i = 0;
115
116    while i < runs.len() {
117        let y = runs[i].y;
118        let start = i;
119        while i < runs.len() && runs[i].y == y {
120            i += 1;
121        }
122        let prev_row_range = curr_row_range;
123        curr_row_range = start..i;
124
125        if y > 0 && !prev_row_range.is_empty() && runs[prev_row_range.start].y == y - 1 {
126            let mut p_idx = prev_row_range.start;
127            for c_idx in curr_row_range.clone() {
128                let curr = &runs[c_idx];
129                if use_8_connectivity {
130                    // 8-connectivity: [start_x, end_x)
131                    // overlap diagonally if prev.end_x >= curr.start_x and prev.start_x <= curr.end_x
132                    while p_idx < prev_row_range.end && runs[p_idx].end_x < curr.start_x {
133                        p_idx += 1;
134                    }
135                    let mut temp_p = p_idx;
136                    while temp_p < prev_row_range.end && runs[temp_p].start_x <= curr.end_x {
137                        uf.union(curr.label, runs[temp_p].label);
138                        temp_p += 1;
139                    }
140                } else {
141                    // 4-connectivity
142                    while p_idx < prev_row_range.end && runs[p_idx].end_x <= curr.start_x {
143                        p_idx += 1;
144                    }
145                    let mut temp_p = p_idx;
146                    while temp_p < prev_row_range.end && runs[temp_p].start_x < curr.end_x {
147                        uf.union(curr.label, runs[temp_p].label);
148                        temp_p += 1;
149                    }
150                }
151            }
152        }
153    }
154
155    // Collect stats per root and assign labels
156    let mut root_to_temp_idx = vec![usize::MAX; runs.len()];
157    let mut temp_stats = Vec::new();
158
159    for run in &runs {
160        let root = uf.find(run.label) as usize;
161        if root_to_temp_idx[root] == usize::MAX {
162            root_to_temp_idx[root] = temp_stats.len();
163            temp_stats.push(ComponentStats {
164                first_pixel_x: run.start_x,
165                first_pixel_y: run.y,
166                ..ComponentStats::default()
167            });
168        }
169        let s_idx = root_to_temp_idx[root];
170        let stats = &mut temp_stats[s_idx];
171        stats.min_x = stats.min_x.min(run.start_x);
172        stats.max_x = stats.max_x.max(run.end_x - 1);
173        stats.min_y = stats.min_y.min(run.y);
174        stats.max_y = stats.max_y.max(run.y);
175        stats.pixel_count += u32::from(run.end_x - run.start_x);
176        // Accumulate spatial moments using closed-form per-run sums.
177        // Run covers x in [a, b) exclusive (end_x is exclusive).
178        let a = u64::from(run.start_x);
179        let b = u64::from(run.end_x);
180        let yu = u64::from(run.y);
181        // SAFETY: `a - 1` and `2*a - 1` are always multiplied by `a`, so their
182        // value is irrelevant when a = 0. saturating_sub avoids u64 underflow in
183        // debug builds (release wraps, but the `* a` factor zeros the term anyway).
184        stats.m10 += b * (b - 1) / 2 - a * a.saturating_sub(1) / 2;
185        stats.m01 += yu * (b - a);
186        stats.m20 +=
187            (b - 1) * b * (2 * b - 1) / 6 - a.saturating_sub(1) * a * (2 * a).saturating_sub(1) / 6;
188        stats.m02 += yu * yu * (b - a);
189        stats.m11 += yu * (b * (b - 1) / 2 - a * a.saturating_sub(1) / 2);
190    }
191
192    let mut component_stats = Vec::with_capacity(temp_stats.len());
193    let mut root_to_final_label = vec![0u32; runs.len()];
194    let mut next_label = 1u32;
195
196    for (root, root_to_temp) in root_to_temp_idx.iter().enumerate() {
197        if *root_to_temp != usize::MAX {
198            let s = temp_stats[*root_to_temp];
199            if s.pixel_count >= min_area {
200                component_stats.push(s);
201                root_to_final_label[root] = next_label;
202                next_label += 1;
203            }
204        }
205    }
206
207    let labels = arena.alloc_slice_fill_copy(img.width * img.height, 0u32);
208    let width = img.width;
209
210    for run in &runs {
211        let root = uf.find(run.label) as usize;
212        let label = root_to_final_label[root];
213        if label > 0 {
214            let row_off = run.y as usize * width;
215            labels[row_off + run.start_x as usize..row_off + run.end_x as usize].fill(label);
216        }
217    }
218
219    LabelResult {
220        labels,
221        component_stats,
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `width * height` image of background value 200 with the listed `(x, y)`
    /// pixels set to 50. Also returns a flat threshold map of 128, so exactly the listed
    /// pixels are foreground.
    fn synthetic(width: usize, height: usize, foreground: &[(usize, usize)]) -> (Vec<u8>, Vec<u8>) {
        let mut pixels = vec![200u8; width * height];
        for &(x, y) in foreground {
            pixels[y * width + x] = 50;
        }
        (pixels, vec![128u8; width * height])
    }

    #[test]
    fn test_label_components_lsl_red_phase() {
        let arena = Bump::new();
        let width = 8;
        let height = 4;

        // Image with two separated 2x2 blocks:
        // row 0: . x x . . x x .
        // row 1: . x x . . x x .
        // row 2: . . . . . . . .
        // row 3: . . . . . . . .

        let mut pixels = vec![200u8; width * height];
        pixels[1] = 50;
        pixels[2] = 50;
        pixels[5] = 50;
        pixels[6] = 50;
        pixels[8 + 1] = 50;
        pixels[8 + 2] = 50;
        pixels[8 + 5] = 50;
        pixels[8 + 6] = 50;

        let threshold_map = vec![128u8; width * height];
        let img = ImageView::new(&pixels, width, height, width).expect("Valid test image");

        let result = label_components_lsl(&arena, &img, &threshold_map, true, 1);

        assert_eq!(result.component_stats.len(), 2);

        // Label numbering order is not asserted here; only that the two blocks received
        // distinct, non-zero labels.
        let mut found_labels = std::collections::HashSet::new();
        found_labels.insert(result.labels[1]);
        found_labels.insert(result.labels[5]);

        assert_eq!(found_labels.len(), 2);
        assert!(!found_labels.contains(&0));

        // Verify stats
        assert_eq!(result.component_stats[0].pixel_count, 4);
        assert_eq!(result.component_stats[1].pixel_count, 4);
    }

    #[test]
    fn test_label_components_lsl_diagonal_connectivity() {
        let arena = Bump::new();
        let width = 8;
        let height = 2;
        // row 0: x . . . . . . .
        // row 1: . x . . . . . .
        let (pixels, threshold_map) = synthetic(width, height, &[(0, 0), (1, 1)]);
        let img = ImageView::new(&pixels, width, height, width).expect("Valid test image");

        // 8-connectivity: the diagonal neighbours form a single component.
        let eight = label_components_lsl(&arena, &img, &threshold_map, true, 1);
        assert_eq!(eight.component_stats.len(), 1);
        assert_eq!(eight.component_stats[0].pixel_count, 2);
        assert_ne!(eight.labels[0], 0);
        assert_eq!(eight.labels[0], eight.labels[width + 1]);

        // 4-connectivity: diagonal contact does not connect them.
        let four = label_components_lsl(&arena, &img, &threshold_map, false, 1);
        assert_eq!(four.component_stats.len(), 2);
        assert_ne!(four.labels[0], 0);
        assert_ne!(four.labels[width + 1], 0);
        assert_ne!(four.labels[0], four.labels[width + 1]);
    }

    #[test]
    fn test_label_components_lsl_min_area_filters_small_components() {
        let arena = Bump::new();
        let width = 8;
        let height = 2;
        // row 0: x x x . . . x .
        // row 1: x . . . . . . .
        let (pixels, threshold_map) =
            synthetic(width, height, &[(0, 0), (1, 0), (2, 0), (6, 0), (0, 1)]);
        let img = ImageView::new(&pixels, width, height, width).expect("Valid test image");

        let result = label_components_lsl(&arena, &img, &threshold_map, false, 2);

        // The isolated pixel at (6, 0) is below min_area and is dropped entirely.
        assert_eq!(result.component_stats.len(), 1);
        assert_eq!(result.component_stats[0].pixel_count, 4);
        assert_eq!(result.labels[6], 0);
        // The sole survivor is numbered 1.
        assert_eq!(result.labels[0], 1);
        assert_eq!(result.labels[2], 1);
        assert_eq!(result.labels[width], 1);
    }

    #[test]
    fn test_extract_rle_segments_scalar() {
        let width = 8;
        let height = 2;
        // Image with two black segments on the first row and one on the second row.
        // Black means value < threshold (here 50 vs 128).
        let pixels = vec![
            200, 50, 50, 200, 50, 200, 200, 200, // Row 0: RLE at [1, 3) and [4, 5)
            50, 50, 50, 50, 200, 200, 200, 200, // Row 1: RLE at [0, 4)
        ];
        let threshold_map = vec![128u8; width * height];

        let img = ImageView::new(&pixels, width, height, width).expect("Valid test image");
        let segments = extract_rle_segments_scalar(&img, &threshold_map);

        assert_eq!(segments.len(), 3);
        assert_eq!(segments[0], RleSegment::new(0, 1, 3));
        assert_eq!(segments[1], RleSegment::new(0, 4, 5));
        assert_eq!(segments[2], RleSegment::new(1, 0, 4));
    }

    #[test]
    fn test_extract_rle_segments_scalar_run_touching_right_edge() {
        let width = 8;
        let height = 1;
        // A run that is still open at the last column must be closed at `width`.
        let (pixels, threshold_map) = synthetic(width, height, &[(6, 0), (7, 0)]);
        let img = ImageView::new(&pixels, width, height, width).expect("Valid test image");

        let segments = extract_rle_segments_scalar(&img, &threshold_map);

        assert_eq!(segments, vec![RleSegment::new(0, 6, 8)]);
    }
}