// patch_tracker/corners_fast9.rs

use image::GrayImage;
use wide::{CmpGt, CmpLt, i16x8};

/// A detected corner: pixel position plus the strength assigned by the detector.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct Corner {
    /// Column of the corner pixel.
    pub x: u32,
    /// Row of the corner pixel.
    pub y: u32,
    /// Corner strength; larger means the corner survives a higher threshold.
    pub score: f32,
}

impl Corner {
    /// Builds a corner at `(x, y)` with the given `score`.
    pub fn new(x: u32, y: u32, score: f32) -> Corner {
        Corner { x, y, score }
    }
}

17pub fn fast_corner_score(image: &GrayImage, threshold: u8, x: u32, y: u32) -> u8 {
18    let mut max = 255u8;
19    let mut min = threshold;
20
21    loop {
22        if max == min {
23            return max;
24        }
25
26        let mean = ((max as u16 + min as u16) / 2u16) as u8;
27        let probe = if max == min + 1 { max } else { mean };
28
29        if is_corner_fast9_scalar(image, probe, x, y) {
30            min = probe;
31        } else {
32            max = probe - 1;
33        }
34    }
35}
36
37#[inline(always)]
38fn load_8u8_to_i16x8(ptr: *const u8) -> i16x8 {
39    let bytes = unsafe { std::ptr::read_unaligned(ptr as *const [u8; 8]) };
40    i16x8::new([
41        bytes[0] as i16,
42        bytes[1] as i16,
43        bytes[2] as i16,
44        bytes[3] as i16,
45        bytes[4] as i16,
46        bytes[5] as i16,
47        bytes[6] as i16,
48        bytes[7] as i16,
49    ])
50}
51
/// Returns `true` when `circle`, treated as a closed ring, contains a run of at least
/// `length` consecutive entries for which `f` returns `true`.
///
/// A run may wrap from the last entry back to the first: the run still open at the end of
/// the array is joined with the run that started at index 0.
fn search_span<F>(circle: &[i16; 16], length: u8, f: F) -> bool
where
    F: Fn(&i16) -> bool,
{
    // Length of the run currently being extended.
    let mut run = 0u8;
    // Length of the run that began at index 0, recorded when the first miss ends it.
    let mut leading_run: Option<u8> = None;

    for sample in circle {
        if f(sample) {
            run += 1;
            if run == length {
                return true;
            }
        } else {
            leading_run = leading_run.or(Some(run));
            run = 0;
        }
    }

    // No run inside the array was long enough; try joining the tail run with the head run.
    run + leading_run.unwrap_or(0) >= length
}

74fn is_corner_fast9_scalar(image: &GrayImage, threshold: u8, x: u32, y: u32) -> bool {
75    let c = image.get_pixel(x, y)[0] as i16;
76    let low_thresh = c - threshold as i16;
77    let high_thresh = c + threshold as i16;
78
79    let p0 = image.get_pixel(x, y - 3)[0] as i16;
80    let p8 = image.get_pixel(x, y + 3)[0] as i16;
81    let p4 = image.get_pixel(x + 3, y)[0] as i16;
82    let p12 = image.get_pixel(x - 3, y)[0] as i16;
83
84    let above = (p12 > high_thresh || p4 > high_thresh) && (p8 > high_thresh || p0 > high_thresh);
85    let below = (p12 < low_thresh || p4 < low_thresh) && (p8 < low_thresh || p0 < low_thresh);
86
87    if !above && !below {
88        return false;
89    }
90
91    let pixels = [
92        p0,
93        image.get_pixel(x + 1, y - 3)[0] as i16,
94        image.get_pixel(x + 2, y - 2)[0] as i16,
95        image.get_pixel(x + 3, y - 1)[0] as i16,
96        p4,
97        image.get_pixel(x + 3, y + 1)[0] as i16,
98        image.get_pixel(x + 2, y + 2)[0] as i16,
99        image.get_pixel(x + 1, y + 3)[0] as i16,
100        p8,
101        image.get_pixel(x - 1, y + 3)[0] as i16,
102        image.get_pixel(x - 2, y + 2)[0] as i16,
103        image.get_pixel(x - 3, y + 1)[0] as i16,
104        p12,
105        image.get_pixel(x - 3, y - 1)[0] as i16,
106        image.get_pixel(x - 2, y - 2)[0] as i16,
107        image.get_pixel(x - 1, y - 3)[0] as i16,
108    ];
109
110    if above && search_span(&pixels, 9, |&p| p > high_thresh) {
111        return true;
112    }
113    if below && search_span(&pixels, 9, |&p| p < low_thresh) {
114        return true;
115    }
116    false
117}
118
119pub fn simd_corners_fast9(image: &GrayImage, threshold: u8) -> Vec<Corner> {
120    let width = image.width() as usize;
121    let width_isize = width as isize;
122    let height = image.height() as usize;
123    let mut corners = Vec::new();
124
125    if width < 7 || height < 7 {
126        return corners;
127    }
128
129    let img_ptr = image.as_raw().as_ptr();
130
131    let ring_offsets: [isize; 16] = [
132        -3 * (width_isize),
133        -3 * (width_isize) + 1,
134        -2 * (width_isize) + 2,
135        -(width_isize) + 3,
136        3,
137        (width_isize) + 3,
138        2 * (width_isize) + 2,
139        3 * (width_isize) + 1,
140        3 * (width_isize),
141        3 * (width_isize) - 1,
142        2 * (width_isize) - 2,
143        (width_isize) - 3,
144        -3,
145        -(width_isize) - 3,
146        -2 * (width_isize) - 2,
147        -3 * (width_isize) - 1,
148    ];
149
150    let t = i16x8::splat(threshold as i16);
151
152    for y in 3..height - 3 {
153        let mut x = 3;
154
155        while x + 7 < width - 3 {
156            let center_ptr = unsafe { img_ptr.add(y * width + x) };
157            let center = load_8u8_to_i16x8(center_ptr);
158            let high_thresh = center + t;
159            let low_thresh = center - t;
160
161            let p0 = load_8u8_to_i16x8(unsafe { center_ptr.offset(ring_offsets[0]) });
162            let p8 = load_8u8_to_i16x8(unsafe { center_ptr.offset(ring_offsets[8]) });
163
164            let above_0 = p0.simd_gt(high_thresh);
165            let below_0 = p0.simd_lt(low_thresh);
166            let above_8 = p8.simd_gt(high_thresh);
167            let below_8 = p8.simd_lt(low_thresh);
168
169            let above_08 = above_0 | above_8;
170            let below_08 = below_0 | below_8;
171
172            if (above_08 | below_08).to_bitmask() == 0 {
173                x += 8;
174                continue;
175            }
176
177            let p4 = load_8u8_to_i16x8(unsafe { center_ptr.offset(ring_offsets[4]) });
178            let p12 = load_8u8_to_i16x8(unsafe { center_ptr.offset(ring_offsets[12]) });
179
180            let above_4 = p4.simd_gt(high_thresh);
181            let below_4 = p4.simd_lt(low_thresh);
182            let above_12 = p12.simd_gt(high_thresh);
183            let below_12 = p12.simd_lt(low_thresh);
184
185            let count_above = ((above_0 | above_8) & (above_4 | above_12))
186                | (above_0 & above_8)
187                | (above_4 & above_12);
188            let count_below = ((below_0 | below_8) & (below_4 | below_12))
189                | (below_0 & below_8)
190                | (below_4 & below_12);
191
192            let pass_quick = count_above | count_below;
193
194            if pass_quick.to_bitmask() == 0 {
195                x += 8;
196                continue;
197            }
198
199            let mut ring = [i16x8::splat(0); 16];
200            ring[0] = p0;
201            ring[4] = p4;
202            ring[8] = p8;
203            ring[12] = p12;
204            for i in [1, 2, 3, 5, 6, 7, 9, 10, 11, 13, 14, 15] {
205                ring[i] = load_8u8_to_i16x8(unsafe { center_ptr.offset(ring_offsets[i]) });
206            }
207
208            let mut above = [i16x8::splat(0); 16];
209            let mut below = [i16x8::splat(0); 16];
210            for i in 0..16 {
211                above[i] = ring[i].simd_gt(high_thresh);
212                below[i] = ring[i].simd_lt(low_thresh);
213            }
214
215            let check_9 = |arr: &[i16x8; 16]| -> u32 {
216                let mut a2 = [i16x8::splat(0); 16];
217                for i in 0..16 {
218                    a2[i] = arr[i] & arr[(i + 1) % 16];
219                }
220
221                let mut a4 = [i16x8::splat(0); 16];
222                for i in 0..16 {
223                    a4[i] = a2[i] & a2[(i + 2) % 16];
224                }
225
226                let mut a8 = [i16x8::splat(0); 16];
227                for i in 0..16 {
228                    a8[i] = a4[i] & a4[(i + 4) % 16];
229                }
230
231                let mut a9 = [i16x8::splat(0); 16];
232                for i in 0..16 {
233                    a9[i] = a8[i] & arr[(i + 8) % 16];
234                }
235
236                let mut final_a = a9[0];
237                for item in a9.iter().skip(1) {
238                    final_a |= item;
239                }
240
241                final_a.to_bitmask()
242            };
243
244            let mask_above = check_9(&above) & pass_quick.to_bitmask();
245            let mask_below = check_9(&below) & pass_quick.to_bitmask();
246            let final_mask = mask_above | mask_below;
247
248            if final_mask != 0 {
249                for i in 0..8 {
250                    if (final_mask & (1 << i)) != 0 {
251                        let cx = (x + i) as u32;
252                        let cy = y as u32;
253                        let score = fast_corner_score(image, threshold, cx, cy);
254                        corners.push(Corner::new(cx, cy, score as f32));
255                    }
256                }
257            }
258
259            x += 8;
260        }
261
262        for cx in x..width - 3 {
263            if is_corner_fast9_scalar(image, threshold, cx as u32, y as u32) {
264                let score = fast_corner_score(image, threshold, cx as u32, y as u32);
265                corners.push(Corner::new(cx as u32, y as u32, score as f32));
266            }
267        }
268    }
269
270    corners
271}