// yscv_imgproc/ops/fast.rs

1use yscv_tensor::Tensor;
2
3use super::super::ImgProcError;
4use super::super::shape::hwc_shape;
5
/// A detected keypoint with orientation and scale information.
#[derive(Clone, Debug)]
pub struct Keypoint {
    /// Horizontal position; the FAST detector in this file stores the column index.
    pub x: f32,
    /// Vertical position; the FAST detector in this file stores the row index.
    pub y: f32,
    /// Detector strength; the FAST detector in this file stores the length of
    /// the longest contiguous brighter/darker arc on the test circle.
    pub response: f32,
    /// Orientation of the keypoint; the FAST detector in this file leaves it at `0.0`.
    // NOTE(review): presumably radians (e.g. from an intensity-centroid pass) — confirm with callers.
    pub angle: f32,
    /// Pyramid level the keypoint was found at; the FAST detector in this file uses `0`.
    pub octave: usize,
}
15
/// Bresenham circle of radius 3: 16 pixel offsets `(dx, dy)` around a center.
///
/// Index 0 is the pixel three rows before the center (`dy = -3`); the list
/// then walks the ring through `(3, 0)`, `(0, 3)` and `(-3, 0)`, so indices
/// 0, 4, 8 and 12 are the four axis-aligned ("cardinal") pixels.
#[rustfmt::skip]
const CIRCLE: [(i32, i32); 16] = [
    // from (0, -3) towards (3, 0)
    (0, -3), (1, -3), (2, -2), (3, -1),
    // from (3, 0) towards (0, 3)
    (3, 0), (3, 1), (2, 2), (1, 3),
    // from (0, 3) towards (-3, 0)
    (0, 3), (-1, 3), (-2, 2), (-3, 1),
    // from (-3, 0) back towards (0, -3)
    (-3, 0), (-3, -1), (-2, -2), (-1, -3),
];
35
36/// Precompute circle pixel offsets as flat `isize` offsets from center.
37#[inline]
38fn circle_offsets(w: usize) -> [isize; 16] {
39    let ws = w as isize;
40    let mut offsets = [0isize; 16];
41    for (i, &(dx, dy)) in CIRCLE.iter().enumerate() {
42        offsets[i] = dy as isize * ws + dx as isize;
43    }
44    offsets
45}
46
/// Length of the longest circular run of set bits in a 16-bit ring mask.
///
/// Bit `i` of `mask` corresponds to circle position `i`; position 15 is
/// adjacent to position 0. The result is the actual run length (0 for an
/// empty mask, at most 16) — callers compare it against 9 themselves.
#[inline]
fn contiguous_run_from_mask(mask: u32) -> usize {
    // Place a second copy of the ring in bits 16..31 so that a run crossing
    // the 15 -> 0 seam shows up as one plain run in the 32-bit word.
    let mut remaining = mask | (mask << 16);
    let mut longest = 0u32;
    while remaining != 0 {
        // Drop the gap below the next run, then measure that run.
        remaining >>= remaining.trailing_zeros();
        let run = remaining.trailing_ones();
        longest = longest.max(run);
        if run == 32 {
            // The whole word is set; shifting by 32 would overflow.
            break;
        }
        remaining >>= run;
    }
    // A full ring is counted twice by the doubling; clamp to the ring size.
    longest.min(16) as usize
}
71
72/// FAST-9 corner detection on a single-channel `[H, W, 1]` image.
73///
74/// Examines 16 pixels on a circle of radius 3 around each pixel.
75/// A corner exists if 9 contiguous pixels are all brighter or all darker
76/// than the center by at least `threshold`.
77///
78/// If `non_max` is true, non-maximum suppression is applied in a 3x3 neighbourhood.
79///
80/// Uses SIMD to batch the cardinal early-rejection test (4 pixels at a time),
81/// eliminating ~90% of non-corner pixels with minimal work.
82#[allow(unsafe_code)]
83pub fn fast9_detect(
84    image: &Tensor,
85    threshold: f32,
86    non_max: bool,
87) -> Result<Vec<Keypoint>, ImgProcError> {
88    let (h, w, c) = hwc_shape(image)?;
89    if c != 1 {
90        return Err(ImgProcError::InvalidChannelCount {
91            expected: 1,
92            got: c,
93        });
94    }
95    Ok(fast9_detect_raw(image.data(), h, w, threshold, non_max))
96}
97
98/// FAST-9 corner detection on raw f32 data — no Tensor allocation needed.
99pub fn fast9_detect_raw(
100    data: &[f32],
101    h: usize,
102    w: usize,
103    threshold: f32,
104    non_max: bool,
105) -> Vec<Keypoint> {
106    let offsets = circle_offsets(w);
107
108    // Cardinal directions for early rejection: N(0), E(4), S(8), W(12)
109    let card = [offsets[0], offsets[4], offsets[8], offsets[12]];
110
111    let y_start = 3;
112    let y_end = h.saturating_sub(3);
113    let x_start = 3;
114    let x_end = w.saturating_sub(3);
115
116    // Parallel FAST: each row processed independently, results merged.
117    let n_rows = y_end.saturating_sub(y_start);
118    let row_corners: Vec<Vec<Keypoint>> = {
119        use std::sync::Mutex;
120        let results: Vec<Mutex<Vec<Keypoint>>> =
121            (0..n_rows).map(|_| Mutex::new(Vec::new())).collect();
122
123        use super::u8ops::gcd;
124        gcd::parallel_for(n_rows, |row_idx| {
125            let y = y_start + row_idx;
126            let mut row_kps = Vec::new();
127
128            let row_base = y * w;
129            let mut x = x_start;
130
131            #[cfg(target_arch = "aarch64")]
132            if std::arch::is_aarch64_feature_detected!("neon") {
133                while x + 4 <= x_end {
134                    let pass_mask =
135                        unsafe { fast9_cardinal_check_neon(data, row_base + x, &card, threshold) };
136                    if pass_mask == 0 {
137                        x += 4;
138                        continue;
139                    }
140                    for i in 0..4 {
141                        if (pass_mask >> i) & 1 != 0 {
142                            let cx = x + i;
143                            let idx = row_base + cx;
144                            let max_run =
145                                unsafe { fast9_full_check(data, idx, &offsets, threshold) };
146                            if max_run >= 9 {
147                                row_kps.push(Keypoint {
148                                    x: cx as f32,
149                                    y: y as f32,
150                                    response: max_run as f32,
151                                    angle: 0.0,
152                                    octave: 0,
153                                });
154                            }
155                        }
156                    }
157                    x += 4;
158                }
159            }
160
161            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
162            if std::is_x86_feature_detected!("sse") {
163                while x + 4 <= x_end {
164                    let pass_mask =
165                        unsafe { fast9_cardinal_check_sse(data, row_base + x, &card, threshold) };
166                    if pass_mask == 0 {
167                        x += 4;
168                        continue;
169                    }
170                    for i in 0..4 {
171                        if (pass_mask >> i) & 1 != 0 {
172                            let cx = x + i;
173                            let idx = row_base + cx;
174                            let max_run =
175                                unsafe { fast9_full_check(data, idx, &offsets, threshold) };
176                            if max_run >= 9 {
177                                row_kps.push(Keypoint {
178                                    x: cx as f32,
179                                    y: y as f32,
180                                    response: max_run as f32,
181                                    angle: 0.0,
182                                    octave: 0,
183                                });
184                            }
185                        }
186                    }
187                    x += 4;
188                }
189            }
190
191            while x < x_end {
192                let idx = row_base + x;
193                let center = unsafe { *data.get_unchecked(idx) };
194                let bright_thresh = center + threshold;
195                let dark_thresh = center - threshold;
196                let mut bright_count = 0u32;
197                let mut dark_count = 0u32;
198                for &co in &card {
199                    let v = unsafe { *data.get_unchecked((idx as isize + co) as usize) };
200                    bright_count += (v > bright_thresh) as u32;
201                    dark_count += (v < dark_thresh) as u32;
202                }
203                if bright_count < 3 && dark_count < 3 {
204                    x += 1;
205                    continue;
206                }
207                let max_run = unsafe { fast9_full_check(data, idx, &offsets, threshold) };
208                if max_run >= 9 {
209                    row_kps.push(Keypoint {
210                        x: x as f32,
211                        y: y as f32,
212                        response: max_run as f32,
213                        angle: 0.0,
214                        octave: 0,
215                    });
216                }
217                x += 1;
218            }
219
220            *results[row_idx].lock().expect("mutex poisoned") = row_kps;
221        });
222
223        results
224            .into_iter()
225            .map(|m| m.into_inner().expect("mutex poisoned"))
226            .collect()
227    };
228
229    let mut corners: Vec<Keypoint> = row_corners.into_iter().flatten().collect();
230
231    // Dead code: old sequential loop replaced by parallel version above.
232    #[allow(unreachable_code)]
233    if false {
234        let y = y_start;
235        for _y in y_start..y_end {
236            let row_base = y * w;
237            let mut x = x_start;
238
239            // SIMD batch: check 4 consecutive center pixels at a time
240            // This vectorizes the cardinal early-rejection test
241            #[cfg(target_arch = "aarch64")]
242            if std::arch::is_aarch64_feature_detected!("neon") {
243                while x + 4 <= x_end {
244                    let pass_mask =
245                        unsafe { fast9_cardinal_check_neon(data, row_base + x, &card, threshold) };
246                    // If no pixels passed cardinal check, skip all 4
247                    if pass_mask == 0 {
248                        x += 4;
249                        continue;
250                    }
251                    // Process pixels that passed cardinal check individually
252                    for i in 0..4 {
253                        if (pass_mask >> i) & 1 != 0 {
254                            let cx = x + i;
255                            let idx = row_base + cx;
256                            let max_run =
257                                unsafe { fast9_full_check(data, idx, &offsets, threshold) };
258                            if max_run >= 9 {
259                                corners.push(Keypoint {
260                                    x: cx as f32,
261                                    y: y as f32,
262                                    response: max_run as f32,
263                                    angle: 0.0,
264                                    octave: 0,
265                                });
266                            }
267                        }
268                    }
269                    x += 4;
270                }
271            }
272
273            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
274            if std::is_x86_feature_detected!("sse") {
275                while x + 4 <= x_end {
276                    let pass_mask =
277                        unsafe { fast9_cardinal_check_sse(data, row_base + x, &card, threshold) };
278                    if pass_mask == 0 {
279                        x += 4;
280                        continue;
281                    }
282                    for i in 0..4 {
283                        if (pass_mask >> i) & 1 != 0 {
284                            let cx = x + i;
285                            let idx = row_base + cx;
286                            let max_run =
287                                unsafe { fast9_full_check(data, idx, &offsets, threshold) };
288                            if max_run >= 9 {
289                                corners.push(Keypoint {
290                                    x: cx as f32,
291                                    y: y as f32,
292                                    response: max_run as f32,
293                                    angle: 0.0,
294                                    octave: 0,
295                                });
296                            }
297                        }
298                    }
299                    x += 4;
300                }
301            }
302
303            // Scalar tail for remaining pixels
304            while x < x_end {
305                let idx = row_base + x;
306                let center = unsafe { *data.get_unchecked(idx) };
307                let bright_thresh = center + threshold;
308                let dark_thresh = center - threshold;
309
310                let mut bright_count = 0u32;
311                let mut dark_count = 0u32;
312                for &co in &card {
313                    let v = unsafe { *data.get_unchecked((idx as isize + co) as usize) };
314                    bright_count += (v > bright_thresh) as u32;
315                    dark_count += (v < dark_thresh) as u32;
316                }
317                if bright_count < 3 && dark_count < 3 {
318                    x += 1;
319                    continue;
320                }
321
322                let max_run = unsafe { fast9_full_check(data, idx, &offsets, threshold) };
323                if max_run >= 9 {
324                    corners.push(Keypoint {
325                        x: x as f32,
326                        y: y as f32,
327                        response: max_run as f32,
328                        angle: 0.0,
329                        octave: 0,
330                    });
331                }
332                x += 1;
333            }
334        }
335    } // if false
336
337    if non_max {
338        let mut response_map = vec![0.0f32; h * w];
339        for kp in &corners {
340            let ix = kp.x as usize;
341            let iy = kp.y as usize;
342            response_map[iy * w + ix] = kp.response;
343        }
344        corners.retain(|kp| {
345            let ix = kp.x as usize;
346            let iy = kp.y as usize;
347            for dy in -1i32..=1 {
348                for dx in -1i32..=1 {
349                    if dy == 0 && dx == 0 {
350                        continue;
351                    }
352                    let ny = (iy as i32 + dy) as usize;
353                    let nx = (ix as i32 + dx) as usize;
354                    if ny < h && nx < w && response_map[ny * w + nx] > kp.response {
355                        return false;
356                    }
357                }
358            }
359            true
360        });
361    }
362
363    corners
364}
365
366/// SIMD cardinal check for 4 consecutive center pixels (NEON).
367/// Returns a 4-bit mask: bit i is set if pixel i passes cardinal test.
368#[cfg(target_arch = "aarch64")]
369#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
370#[target_feature(enable = "neon")]
371unsafe fn fast9_cardinal_check_neon(
372    data: &[f32],
373    base_idx: usize,
374    card: &[isize; 4],
375    threshold: f32,
376) -> u32 {
377    use std::arch::aarch64::*;
378
379    let ptr = data.as_ptr();
380    let thresh = vdupq_n_f32(threshold);
381    let neg_thresh = vdupq_n_f32(-threshold);
382    let three = vdupq_n_u32(3);
383
384    // Load 4 consecutive center pixels
385    let centers = vld1q_f32(ptr.add(base_idx));
386    let bright_thresh = vaddq_f32(centers, thresh);
387    let dark_thresh = vaddq_f32(centers, neg_thresh);
388
389    // For each cardinal direction, load the 4 corresponding circle pixels
390    // (consecutive since centers are consecutive)
391    let mut bright_cnt = vdupq_n_u32(0);
392    let mut dark_cnt = vdupq_n_u32(0);
393
394    for &co in card.iter() {
395        let circle_px = vld1q_f32(ptr.add((base_idx as isize + co) as usize));
396        // brighter: circle > bright_thresh
397        let b = vcgtq_f32(circle_px, bright_thresh);
398        bright_cnt = vsubq_u32(bright_cnt, vreinterpretq_u32_f32(vreinterpretq_f32_u32(b)));
399        // darker: circle < dark_thresh
400        let d = vcltq_f32(circle_px, dark_thresh);
401        dark_cnt = vsubq_u32(dark_cnt, vreinterpretq_u32_f32(vreinterpretq_f32_u32(d)));
402    }
403
404    // Check if bright_cnt >= 3 OR dark_cnt >= 3 for each pixel
405    let bright_pass = vcgeq_u32(bright_cnt, three);
406    let dark_pass = vcgeq_u32(dark_cnt, three);
407    let pass = vorrq_u32(bright_pass, dark_pass);
408
409    // Extract to 4-bit mask
410    let mut mask = 0u32;
411    if vgetq_lane_u32(pass, 0) != 0 {
412        mask |= 1;
413    }
414    if vgetq_lane_u32(pass, 1) != 0 {
415        mask |= 2;
416    }
417    if vgetq_lane_u32(pass, 2) != 0 {
418        mask |= 4;
419    }
420    if vgetq_lane_u32(pass, 3) != 0 {
421        mask |= 8;
422    }
423    mask
424}
425
426/// SIMD cardinal check for 4 consecutive center pixels (SSE).
427#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
428#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
429#[target_feature(enable = "sse")]
430unsafe fn fast9_cardinal_check_sse(
431    data: &[f32],
432    base_idx: usize,
433    card: &[isize; 4],
434    threshold: f32,
435) -> u32 {
436    #[cfg(target_arch = "x86")]
437    use std::arch::x86::*;
438    #[cfg(target_arch = "x86_64")]
439    use std::arch::x86_64::*;
440
441    let ptr = data.as_ptr();
442    let thresh = _mm_set1_ps(threshold);
443    let neg_thresh = _mm_set1_ps(-threshold);
444    let _zero = _mm_setzero_ps();
445
446    let centers = _mm_loadu_ps(ptr.add(base_idx));
447    let bright_thresh = _mm_add_ps(centers, thresh);
448    let dark_thresh = _mm_add_ps(centers, neg_thresh);
449
450    // Count cardinals that pass each threshold
451    let mut bright_cnt = _mm_setzero_ps();
452    let mut dark_cnt = _mm_setzero_ps();
453    let one_bits = _mm_set1_ps(1.0);
454
455    for &co in card.iter() {
456        let circle_px = _mm_loadu_ps(ptr.add((base_idx as isize + co) as usize));
457        // cmpgt returns 0xFFFFFFFF for true, so AND with 1.0 gives 1.0 for true
458        let b = _mm_and_ps(_mm_cmpgt_ps(circle_px, bright_thresh), one_bits);
459        bright_cnt = _mm_add_ps(bright_cnt, b);
460        let d = _mm_and_ps(_mm_cmplt_ps(circle_px, dark_thresh), one_bits);
461        dark_cnt = _mm_add_ps(dark_cnt, d);
462    }
463
464    let three = _mm_set1_ps(3.0);
465    let bright_pass = _mm_cmpge_ps(bright_cnt, three);
466    let dark_pass = _mm_cmpge_ps(dark_cnt, three);
467    let pass = _mm_or_ps(bright_pass, dark_pass);
468
469    _mm_movemask_ps(pass) as u32
470}
471
472/// Full FAST-9 check: build bitmasks and find max contiguous run.
473#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
474#[inline]
475unsafe fn fast9_full_check(
476    data: &[f32],
477    idx: usize,
478    offsets: &[isize; 16],
479    threshold: f32,
480) -> usize {
481    let center = *data.get_unchecked(idx);
482    let bright_thresh = center + threshold;
483    let dark_thresh = center - threshold;
484
485    let mut bright_mask = 0u32;
486    let mut dark_mask = 0u32;
487    for i in 0..16 {
488        let v = *data.get_unchecked((idx as isize + offsets[i]) as usize);
489        if v > bright_thresh {
490            bright_mask |= 1 << i;
491        }
492        if v < dark_thresh {
493            dark_mask |= 1 << i;
494        }
495    }
496
497    let bright_run = contiguous_run_from_mask(bright_mask);
498    let dark_run = contiguous_run_from_mask(dark_mask);
499    bright_run.max(dark_run)
500}
501
/// Intensity-centroid orientation of the patch around `(kx, ky)`.
///
/// Sums the first-order moments `dx * v` and `dy * v` over every pixel of a
/// disc of the given `radius` centred on the keypoint, skipping positions
/// that fall outside the `w` x `h` image, and returns `atan2(m01, m10)`.
/// `data` is indexed as `row * w + column`.
pub(crate) fn intensity_centroid_angle(
    data: &[f32],
    w: usize,
    h: usize,
    kx: usize,
    ky: usize,
    radius: i32,
) -> f32 {
    let cols = w as i32;
    let rows = h as i32;
    let center_x = kx as i32;
    let center_y = ky as i32;
    let radius_sq = radius * radius;

    let mut moment_x = 0.0f32;
    let mut moment_y = 0.0f32;

    for dy in -radius..=radius {
        // Half-width of the disc on this row.
        let half_width = ((radius_sq - dy * dy) as f32).sqrt() as i32;
        let py = center_y + dy;
        if !(0..rows).contains(&py) {
            continue;
        }
        let row_start = py as usize * w;
        for dx in -half_width..=half_width {
            let px = center_x + dx;
            if !(0..cols).contains(&px) {
                continue;
            }
            let value = data[row_start + px as usize];
            moment_x += dx as f32 * value;
            moment_y += dy as f32 * value;
        }
    }

    moment_y.atan2(moment_x)
}
528
/// Longest run of consecutive `true` entries in a circular 16-element array,
/// where index 15 is adjacent to index 0. Returns a value in `0..=16`.
#[allow(dead_code)]
fn max_consecutive(flags: &[bool; 16]) -> usize {
    let mut longest = 0usize;
    let mut current = 0usize;
    // Walk the ring twice so a run that crosses the 15 -> 0 seam is counted whole.
    for &set in flags.iter().chain(flags.iter()) {
        current = if set { current + 1 } else { 0 };
        longest = longest.max(current);
    }
    // An all-true ring would otherwise be counted as 32.
    longest.min(16)
}
547
#[cfg(test)]
mod tests {
    use super::*;

    /// Two bright bars forming an L on a dark 30x30 image must yield keypoints.
    #[test]
    fn test_fast9_detects_corner() {
        let mut pixels = vec![0.0f32; 30 * 30];
        // Horizontal bright bar: row 15, columns 10..20.
        let bar_row = 15 * 30;
        pixels[bar_row + 10..bar_row + 20].fill(1.0);
        // Vertical bright bar: column 10, rows 10..20.
        for row in 10..20 {
            pixels[row * 30 + 10] = 1.0;
        }
        let img = Tensor::from_vec(vec![30, 30, 1], pixels).unwrap();
        let kps = fast9_detect(&img, 0.3, false).unwrap();
        assert!(!kps.is_empty(), "should detect corners near the L-shape");
    }

    /// A constant image has no pixel that differs from its circle.
    #[test]
    fn test_fast9_no_corners_on_flat() {
        let img = Tensor::from_vec(vec![20, 20, 1], vec![0.5; 400]).unwrap();
        let kps = fast9_detect(&img, 0.1, false).unwrap();
        assert!(kps.is_empty(), "flat image should produce no corners");
    }

    /// Raising the threshold must never increase the number of detections.
    #[test]
    fn test_fast9_threshold() {
        // Bright dot at (15, 15) surrounded by a dimmer ring on a dark background.
        let mut pixels = vec![0.0f32; 30 * 30];
        pixels[15 * 30 + 15] = 1.0;
        for &(dx, dy) in CIRCLE.iter() {
            let px = (15 + dx) as usize;
            let py = (15 + dy) as usize;
            pixels[py * 30 + px] = 0.6;
        }
        let img = Tensor::from_vec(vec![30, 30, 1], pixels).unwrap();
        let low = fast9_detect(&img, 0.1, false).unwrap();
        let high = fast9_detect(&img, 0.8, false).unwrap();
        assert!(
            high.len() <= low.len(),
            "higher threshold should produce fewer or equal corners: low={} high={}",
            low.len(),
            high.len()
        );
    }
}