use super::SobelKernel;
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
/// Sobel kernel backed by AArch64 NEON intrinsics.
///
/// The type is gated on `aarch64` because its implementation calls the
/// `#[cfg(target_arch = "aarch64")]` function `neon_process_row`; leaving the
/// struct and impl un-gated (as before) broke compilation on every other
/// architecture.
#[cfg(target_arch = "aarch64")]
pub struct NeonSobelKernel;

#[cfg(target_arch = "aarch64")]
impl SobelKernel for NeonSobelKernel {
    /// Computes the Sobel gradient magnitude and quantized direction for
    /// row `y`, columns `x_start..x_end`, writing results into `mag_out`
    /// and `dir_out` at the same `y * width + x` indices as `src`.
    fn process_row_slice(
        &self,
        src: &[f32],
        mag_out: &mut [f32],
        dir_out: &mut [u8],
        width: usize,
        x_start: usize,
        x_end: usize,
        y: usize,
    ) {
        // SAFETY: delegates to the NEON routine, which reads the full 3x3
        // neighbourhood of every processed pixel. This relies on the caller
        // upholding 1 <= y < height-1 and 1 <= x_start <= x_end <= width-1,
        // with all three slices at least `height * width` elements long.
        // NOTE(review): those bounds are not checked here — confirm callers.
        unsafe {
            neon_process_row(
                src.as_ptr(),
                mag_out.as_mut_ptr(),
                dir_out.as_mut_ptr(),
                width,
                x_start,
                x_end,
                y,
            );
        }
    }
}
/// Core NEON Sobel row pass: for each pixel in `x_start..x_end` of row `y`,
/// writes the gradient magnitude `sqrt(gx^2 + gy^2)` to `mag_out` and a
/// direction quantized into four sectors (0 = horizontal, 1 = 45° diagonal,
/// 2 = vertical, 3 = 135° diagonal) to `dir_out`.
///
/// # Safety
/// Caller must guarantee `1 <= y < height - 1` and
/// `1 <= x_start <= x_end <= width - 1` (the 3x3 neighbourhood of every
/// processed pixel must be in bounds), and that `src`, `mag_out`, `dir_out`
/// each point to at least `height * width` elements.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn neon_process_row(
    src: *const f32,
    mag_out: *mut f32,
    dir_out: *mut u8,
    width: usize,
    x_start: usize,
    x_end: usize,
    y: usize,
) {
    unsafe {
        // Weight of the centre taps in the 3x3 Sobel kernels.
        let two = vdupq_n_f32(2.0);
        // tan(22.5°) and tan(67.5°): boundaries between the four quantized
        // direction sectors.
        let tan_22 = vdupq_n_f32(0.414_213_56_f32);
        let tan_67 = vdupq_n_f32(2.414_213_56_f32);
        let zero = vdupq_n_f32(0.0);
        let row_offset = y * width;
        // NOTE(review): `y - 1` underflows if y == 0 — relies on the
        // documented caller contract above.
        let p_up = src.add((y - 1) * width);
        let p_mid = src.add(row_offset);
        let p_down = src.add((y + 1) * width);
        let mut x = x_start;
        // Vector path: four output pixels per iteration. The widest load is
        // `tr`/`mr`/`br` at x+1..=x+4, in bounds since x + 4 <= x_end <= width - 1.
        while x + 4 <= x_end {
            let tl = vld1q_f32(p_up.add(x - 1));
            let tm = vld1q_f32(p_up.add(x));
            let tr = vld1q_f32(p_up.add(x + 1));
            let ml = vld1q_f32(p_mid.add(x - 1));
            let mr = vld1q_f32(p_mid.add(x + 1));
            let bl = vld1q_f32(p_down.add(x - 1));
            let bm = vld1q_f32(p_down.add(x));
            let br = vld1q_f32(p_down.add(x + 1));
            // gx = (tr - tl) + 2*(mr - ml) + (br - bl)
            let gx = vaddq_f32(
                vaddq_f32(vsubq_f32(tr, tl), vmulq_f32(vsubq_f32(mr, ml), two)),
                vsubq_f32(br, bl),
            );
            // gy = (bl - tl) + 2*(bm - tm) + (br - tr)
            let gy = vaddq_f32(
                vaddq_f32(vsubq_f32(bl, tl), vmulq_f32(vsubq_f32(bm, tm), two)),
                vsubq_f32(br, tr),
            );
            let mag_sq = vfmaq_f32(vmulq_f32(gx, gx), gy, gy);
            vst1q_f32(mag_out.add(row_offset + x), vsqrtq_f32(mag_sq));
            // Adding +0.0 canonicalizes -0.0 to +0.0 so the sign-bit test
            // below agrees with the scalar path's `>= 0.0` comparisons
            // (IEEE: -0.0 + 0.0 == +0.0 in round-to-nearest).
            let gx_n = vaddq_f32(gx, zero);
            let gy_n = vaddq_f32(gy, zero);
            // |gx|, |gy| via FABS (clears the sign bit) — replaces the old
            // reinterpret-and-mask round trip through 0x7FFF_FFFF.
            let ax_f = vabsq_f32(gx_n);
            let ay_f = vabsq_f32(gy_n);
            // Sector tests: |gy| <= |gx|*tan22.5 => 0 (horizontal);
            // |gy| >= |gx|*tan67.5 => 2 (vertical); else a diagonal picked
            // by whether gx and gy share a sign.
            let mask_0 = vcleq_f32(ay_f, vmulq_f32(ax_f, tan_22));
            let mask_90 = vcgeq_f32(ay_f, vmulq_f32(ax_f, tan_67));
            // Top bit of (gx ^ gy) is 1 iff the signs differ.
            let xor_sign = veorq_u32(vreinterpretq_u32_f32(gx_n), vreinterpretq_u32_f32(gy_n));
            let diff_sign = vshrq_n_u32(xor_sign, 31);
            // Opposite signs -> 135° diagonal (3); same signs -> 45° (1).
            let dir_base = vbslq_u32(
                vceqq_u32(diff_sign, vdupq_n_u32(1)),
                vdupq_n_u32(3),
                vdupq_n_u32(1),
            );
            let dir_step1 = vbslq_u32(mask_0, vdupq_n_u32(0), dir_base);
            let dir_step2 = vbslq_u32(mask_90, vdupq_n_u32(2), dir_step1);
            // Zero-magnitude pixels get direction 0, matching the scalar
            // path (ay <= ax * tan22.5 holds when gx == gy == 0).
            let zero_mask = vceqq_f32(mag_sq, zero);
            let dir_final = vbslq_u32(zero_mask, vdupq_n_u32(0), dir_step2);
            // Spill the four direction lanes and store byte-wise. The old
            // code narrowed to u8x8 and did `vst1_lane_u32` through
            // `dir_out + row_offset + x` cast to `*mut u32`; that pointer is
            // u32-aligned only when (row_offset + x) % 4 == 0, so the typed
            // 4-byte store was UB on most iterations.
            let mut lanes = [0u32; 4];
            vst1q_u32(lanes.as_mut_ptr(), dir_final);
            for (k, &d) in lanes.iter().enumerate() {
                *dir_out.add(row_offset + x + k) = d as u8;
            }
            x += 4;
        }
        // Scalar tail for the remaining (fewer than four) pixels; must stay
        // numerically consistent with the vector path above.
        while x < x_end {
            let idx = row_offset + x;
            let w = width as isize;
            let i = idx as isize;
            // Neighbourhood accessor: p!(off) reads src[idx + off].
            macro_rules! p {
                ($off:expr) => {
                    *src.offset(i + $off)
                };
            }
            let gx = p!(-w - 1) * -1.0
                + p!(-w + 1) * 1.0
                + p!(-1) * -2.0
                + p!(1) * 2.0
                + p!(w - 1) * -1.0
                + p!(w + 1) * 1.0;
            let gy = p!(-w - 1) * -1.0
                + p!(-w) * -2.0
                + p!(-w + 1) * -1.0
                + p!(w - 1) * 1.0
                + p!(w) * 2.0
                + p!(w + 1) * 1.0;
            *mag_out.add(idx) = (gx * gx + gy * gy).sqrt();
            let ax = gx.abs();
            let ay = gy.abs();
            // Same four-sector quantization as the vector path.
            let dir = if ay <= ax * 0.414_213_56 {
                0u8
            } else if ay >= ax * 2.414_213_56 {
                2u8
            } else if (gx >= 0.0) == (gy >= 0.0) {
                1u8
            } else {
                3u8
            };
            *dir_out.add(idx) = dir;
            x += 1;
        }
    }
}