fast-canny 0.1.0

Industrial-grade Zero-Allocation SIMD Canny Edge Detector
Documentation
// =====================================================================
// 模块说明:
// AArch64 NEON 的 Sobel 融合算子,实现 kernel::SobelKernel trait。
// 与 avx2.rs 保持相同的功能接口:Sobel 滤波 + 梯度幅值 + 方向量化。
//
// NEON 寄存器宽度为 128-bit,每次处理 4 个 f32(AVX2 处理 8 个)。
// 方向量化逻辑与 AVX2 版本完全一致,输出值域 {0,1,2,3} 相同。
// =====================================================================

use super::SobelKernel;

#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;

/// AArch64 NEON Sobel 融合算子。零大小类型,无运行时开销。
pub struct NeonSobelKernel;

impl SobelKernel for NeonSobelKernel {
    fn process_row_slice(
        &self,
        src: &[f32],
        mag_out: &mut [f32],
        dir_out: &mut [u8],
        width: usize,
        x_start: usize,
        x_end: usize,
        y: usize,
    ) {
        unsafe {
            neon_process_row(
                src.as_ptr(),
                mag_out.as_mut_ptr(),
                dir_out.as_mut_ptr(),
                width,
                x_start,
                x_end,
                y,
            );
        }
    }
}

#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn neon_process_row(
    src: *const f32,
    mag_out: *mut f32,
    dir_out: *mut u8,
    width: usize,
    x_start: usize,
    x_end: usize,
    y: usize,
) {
    unsafe {
        // ── 常量 ──────────────────────────────────────────────────
        let two = vdupq_n_f32(2.0);
        let tan_22 = vdupq_n_f32(0.414_213_56_f32); // tan(22.5°)
        let tan_67 = vdupq_n_f32(2.414_213_56_f32); // tan(67.5°)
        let zero = vdupq_n_f32(0.0);

        // 符号掩码:清除 IEEE 754 符号位,等效于 fabsf
        // 0x7FFF_FFFF = i32::MAX,转为 f32 位模式
        let abs_mask = vreinterpretq_f32_u32(vdupq_n_u32(0x7FFF_FFFF));

        let row_offset = y * width;
        let p_up = src.add((y - 1) * width);
        let p_mid = src.add(row_offset);
        let p_down = src.add((y + 1) * width);

        let mut x = x_start;

        // ── NEON 主循环:每次处理 4 个像素 ───────────────────────
        while x + 4 <= x_end {
            // ── 阶段 1:加载 3×3 邻居(4 列 × 3 行)─────────────
            // 命名规则:t=top m=mid b=bottom l=left c=center r=right
            let tl = vld1q_f32(p_up.add(x - 1));
            let tm = vld1q_f32(p_up.add(x));
            let tr = vld1q_f32(p_up.add(x + 1));
            let ml = vld1q_f32(p_mid.add(x - 1));
            let mr = vld1q_f32(p_mid.add(x + 1));
            let bl = vld1q_f32(p_down.add(x - 1));
            let bm = vld1q_f32(p_down.add(x));
            let br = vld1q_f32(p_down.add(x + 1));

            // ── 阶段 2:Sobel 卷积 ────────────────────────────────
            // Gx = (tr - tl) + 2*(mr - ml) + (br - bl)
            let gx = vaddq_f32(
                vaddq_f32(vsubq_f32(tr, tl), vmulq_f32(vsubq_f32(mr, ml), two)),
                vsubq_f32(br, bl),
            );
            // Gy = (bl - tl) + 2*(bm - tm) + (br - tr)
            let gy = vaddq_f32(
                vaddq_f32(vsubq_f32(bl, tl), vmulq_f32(vsubq_f32(bm, tm), two)),
                vsubq_f32(br, tr),
            );

            // ── 阶段 3:梯度幅值 = sqrt(gx² + gy²) ──────────────
            // vfmaq_f32(a, b, c) = a + b*c
            let mag_sq = vfmaq_f32(vmulq_f32(gx, gx), gy, gy);
            let mag = vsqrtq_f32(mag_sq);
            vst1q_f32(mag_out.add(row_offset + x), mag);

            // ── 阶段 4:方向量化(无分支 NEON 实现)──────────────
            //
            // 规范化 ±0.0:加 0.0 消除负零的符号位歧义
            let gx_n = vaddq_f32(gx, zero);
            let gy_n = vaddq_f32(gy, zero);

            // 绝对值
            let ax = vandq_u32(vreinterpretq_u32_f32(gx_n), vreinterpretq_u32_f32(abs_mask));
            let ay = vandq_u32(vreinterpretq_u32_f32(gy_n), vreinterpretq_u32_f32(abs_mask));
            let ax_f = vreinterpretq_f32_u32(ax);
            let ay_f = vreinterpretq_f32_u32(ay);

            // mask_0  = (ay <= ax * tan_22)  → 方向 0°
            // mask_90 = (ay >= ax * tan_67)  → 方向 90°
            let mask_0 = vcleq_f32(ay_f, vmulq_f32(ax_f, tan_22));
            let mask_90 = vcgeq_f32(ay_f, vmulq_f32(ax_f, tan_67));

            // 符号异或:gx 和 gy 符号位不同 → 方向 135°(3),相同 → 45°(1)
            let xor_sign = veorq_u32(vreinterpretq_u32_f32(gx_n), vreinterpretq_u32_f32(gy_n));
            // 符号位在最高位(bit31),右移 31 位得到 0 或 1
            let diff_sign = vshrq_n_u32(xor_sign, 31);

            // dir_base = diff_sign ? 3 : 1
            let dir_base = vbslq_u32(
                vceqq_u32(diff_sign, vdupq_n_u32(1)), // diff_sign == 1
                vdupq_n_u32(3),
                vdupq_n_u32(1),
            );

            // mask_0 为真时强制 dir = 0
            let dir_step1 = vbslq_u32(mask_0, vdupq_n_u32(0), dir_base);

            // mask_90 为真时强制 dir = 2(优先级高于 45°/135°,与 AVX2 一致)
            let dir_step2 = vbslq_u32(mask_90, vdupq_n_u32(2), dir_step1);

            // 零梯度 → 方向 0(mag_sq == 0.0)
            let zero_mask = vceqq_f32(mag_sq, zero);
            let dir_final = vbslq_u32(zero_mask, vdupq_n_u32(0), dir_step2);

            // ── 阶段 5:Pack u32×4 → u8×4 写入 dir_out ──────────
            // u32 → u16(saturating narrowing)
            let dir_u16 = vqmovn_u32(dir_final); // uint16x4_t
            // u16 → u8(saturating narrowing)
            let dir_u8 = vqmovn_u16(vcombine_u16(dir_u16, dir_u16)); // uint8x8_t
            // 只写低 4 字节(对应 4 个像素)
            vst1_lane_u32(
                dir_out.add(row_offset + x) as *mut u32,
                vreinterpret_u32_u8(dir_u8),
                0,
            );

            x += 4;
        }

        // ── 标量尾部处理(剩余 < 4 个像素)──────────────────────
        // 与 ScalarSobelKernel 逻辑完全一致,确保边界正确
        while x < x_end {
            let idx = row_offset + x;
            let w = width as isize;
            let i = idx as isize;

            macro_rules! p {
                ($off:expr) => {
                    *src.offset(i + $off)
                };
            }

            let gx = p!(-w - 1) * -1.0
                + p!(-w + 1) * 1.0
                + p!(-1) * -2.0
                + p!(1) * 2.0
                + p!(w - 1) * -1.0
                + p!(w + 1) * 1.0;
            let gy = p!(-w - 1) * -1.0
                + p!(-w) * -2.0
                + p!(-w + 1) * -1.0
                + p!(w - 1) * 1.0
                + p!(w) * 2.0
                + p!(w + 1) * 1.0;

            *mag_out.add(idx) = (gx * gx + gy * gy).sqrt();

            let ax = gx.abs();
            let ay = gy.abs();
            let dir = if ay <= ax * 0.414_213_56 {
                0u8
            } else if ay >= ax * 2.414_213_56 {
                2u8
            } else if (gx >= 0.0) == (gy >= 0.0) {
                1u8
            } else {
                3u8
            };
            *dir_out.add(idx) = dir;

            x += 1;
        }
    }
}