fast-canny 0.1.0

Industrial-grade Zero-Allocation SIMD Canny Edge Detector
Documentation
use super::SobelKernel;

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

/// AVX2 + FMA implementation of the fused Sobel operator
/// (gradient magnitude plus quantized gradient direction in one pass).
///
/// Zero-sized type: constructing, copying or passing it around has no runtime cost.
#[derive(Debug, Clone, Copy, Default)]
pub struct Avx2SobelKernel;

impl SobelKernel for Avx2SobelKernel {
    /// Applies the Sobel operator to pixels `x_start..x_end` of row `y`, writing the
    /// gradient magnitude to `mag_out` and a quantized direction code to `dir_out`,
    /// both at index `y * width + x`.
    ///
    /// On x86_64 this runs the AVX2+FMA kernel. Because this is a safe method on a
    /// public type, every precondition that kernel relies on is verified here instead
    /// of being trusted to the caller: a violation becomes a panic, never undefined
    /// behavior. The checks are O(1) per call, negligible next to the per-pixel work.
    /// On other targets the call is forwarded to `super::ScalarSobelKernel`.
    ///
    /// # Panics
    ///
    /// On x86_64, when `x_start < x_end` and any of these does not hold: the CPU
    /// supports AVX2 and FMA; `x_start >= 1`; `x_end <= width - 1`; `y >= 1`;
    /// `src` contains rows `y - 1 ..= y + 1`; `mag_out` and `dir_out` contain row `y`.
    /// An empty range (`x_start >= x_end`) is a no-op.
    fn process_row_slice(
        &self,
        src:     &[f32],
        mag_out: &mut [f32],
        dir_out: &mut [u8],
        width:   usize,
        x_start: usize,
        x_end:   usize,
        y:       usize,
    ) {
        #[cfg(target_arch = "x86_64")]
        {
            if x_start >= x_end {
                return;
            }

            assert!(
                is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma"),
                "Avx2SobelKernel requires a CPU with AVX2 and FMA support",
            );
            assert!(x_start >= 1, "x_start must be >= 1 (needs a left neighbour)");
            assert!(x_end < width, "x_end must be <= width - 1 (needs a right neighbour)");
            assert!(y >= 1, "y must be >= 1 (needs the row above)");
            // x_end > x_start >= 1 and x_end < width imply width >= 3, so the
            // divisions below are well-defined; comparing row counts avoids overflow.
            assert!(
                y < (src.len() / width).saturating_sub(1),
                "src must contain row y + 1",
            );
            assert!(y < mag_out.len() / width, "mag_out must contain row y");
            assert!(y < dir_out.len() / width, "dir_out must contain row y");

            // SAFETY: AVX2 and FMA were detected at runtime just above, and the
            // assertions establish 1 <= x_start < x_end <= width - 1, y >= 1,
            // (y + 2) * width <= src.len() and (y + 1) * width <= mag_out.len(),
            // dir_out.len(). That is exactly the contract of `avx2_process_row_slice`:
            // all loads touch columns x_start-1 ..= x_end of rows y-1 ..= y+1, and all
            // stores touch columns x_start .. x_end of row y.
            unsafe {
                avx2_process_row_slice(src, mag_out, dir_out, width, x_start, x_end, y);
            }
        }

        // Non-x86_64 targets: forward to the scalar kernel (removed at compile time on x86_64).
        #[cfg(not(target_arch = "x86_64"))]
        super::ScalarSobelKernel.process_row_slice(
            src, mag_out, dir_out, width, x_start, x_end, y,
        );
    }
}

/// AVX2+FMA Sobel kernel for one row slice.
///
/// For every `x` in `x_start..x_end` of row `y`, applies the 3×3 Sobel operator to
/// `src` (row-major, `width` floats per row). It writes
/// `mag_out[y * width + x] = sqrt(gx² + gy²)` and a direction code to
/// `dir_out[y * width + x]`, chosen in this priority order:
///
/// * `0` if the gradient is exactly zero (`gx² + gy² == 0`);
/// * `2` if `|gy| >= |gx| · tan(67.5°)`;
/// * `0` if `|gy| <= |gx| · tan(22.5°)`;
/// * otherwise `1` when `gx` and `gy` have the same sign, `3` when they differ
///   (`-0.0` counts as non-negative).
///
/// Pixels are processed 8 at a time. The remaining `(x_end - x_start) % 8` pixels
/// go through a scalar tail that repeats the vector arithmetic operation for
/// operation, so a pixel's output is bit-identical whichever path computes it.
///
/// # Safety
///
/// Unless `x_start >= x_end` (a no-op), the caller must guarantee:
/// * the running CPU supports AVX2 and FMA;
/// * `1 <= x_start`, `x_end <= width - 1` and `1 <= y`;
/// * `src.len() >= (y + 2) * width` (rows `y - 1 ..= y + 1` exist);
/// * `mag_out.len() >= (y + 1) * width` and `dir_out.len() >= (y + 1) * width`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn avx2_process_row_slice(
    src:     &[f32],
    mag_out: &mut [f32],
    dir_out: &mut [u8],
    width:   usize,
    x_start: usize,
    x_end:   usize,
    y:       usize,
) {
    /// tan(22.5°) = √2 − 1: boundary between the horizontal and diagonal sectors.
    const TAN_22_5: f32 = 0.414_213_56;
    /// tan(67.5°) = √2 + 1: boundary between the diagonal and vertical sectors.
    const TAN_67_5: f32 = 2.414_213_56;

    // Nothing to do; also avoids forming row pointers when `y` may be 0.
    if x_start >= x_end {
        return;
    }

    // SAFETY: the caller upholds the `# Safety` contract: AVX2/FMA are available and
    // every pointer formed or dereferenced below stays inside its slice.
    unsafe {
        let two       = _mm256_set1_ps(2.0);
        let tan_22    = _mm256_set1_ps(TAN_22_5);
        let tan_67    = _mm256_set1_ps(TAN_67_5);
        // AND-ing with this clears the IEEE sign bit, i.e. computes |v|.
        let abs_mask  = _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFF_FFFF));
        let zero_ps   = _mm256_setzero_ps();

        let row_offset = y * width;

        // Raw pointers are only used inside this block; the borrowed slices keep
        // the underlying buffers alive for the whole call.
        let src_ptr = src.as_ptr();
        let mag_ptr = mag_out.as_mut_ptr();
        let dir_ptr = dir_out.as_mut_ptr();

        // Start of the rows above, at and below `y`.
        let p_up   = src_ptr.add((y - 1) * width);
        let p_mid  = src_ptr.add(row_offset);
        let p_down = src_ptr.add((y + 1) * width);

        let mut x = x_start;

        // Vector body: 8 pixels per iteration. `x + 8 <= x_end <= width - 1`
        // keeps the right-hand loads (columns up to x + 8) inside the row.
        while x + 8 <= x_end {
            // ---------- Stage 1: load the 3×3 neighbourhood ----------
            let tl = _mm256_loadu_ps(p_up.add(x - 1));
            let tm = _mm256_loadu_ps(p_up.add(x));
            let tr = _mm256_loadu_ps(p_up.add(x + 1));
            let ml = _mm256_loadu_ps(p_mid.add(x - 1));
            let mr = _mm256_loadu_ps(p_mid.add(x + 1));
            let bl = _mm256_loadu_ps(p_down.add(x - 1));
            let bm = _mm256_loadu_ps(p_down.add(x));
            let br = _mm256_loadu_ps(p_down.add(x + 1));

            // ---------- Stage 2: Sobel gradients ----------
            // gx = ((tr - tl) + (mr - ml) * 2) + (br - bl)
            let gx = _mm256_add_ps(
                _mm256_add_ps(
                    _mm256_sub_ps(tr, tl),
                    _mm256_mul_ps(_mm256_sub_ps(mr, ml), two),
                ),
                _mm256_sub_ps(br, bl),
            );
            // gy = ((bl - tl) + (bm - tm) * 2) + (br - tr)
            let gy = _mm256_add_ps(
                _mm256_add_ps(
                    _mm256_sub_ps(bl, tl),
                    _mm256_mul_ps(_mm256_sub_ps(bm, tm), two),
                ),
                _mm256_sub_ps(br, tr),
            );

            // ---------- Stage 3: magnitude ----------
            // mag² = gy * gy + gx * gx with a single rounding on the final add.
            let mag_sq = _mm256_fmadd_ps(gy, gy, _mm256_mul_ps(gx, gx));
            let mag    = _mm256_sqrt_ps(mag_sq);
            _mm256_storeu_ps(mag_ptr.add(row_offset + x), mag);

            // ---------- Stage 4: branch-free direction quantization ----------
            // Adding +0.0 turns -0.0 into +0.0, so a zero component never reads as
            // negative in the sign comparison below.
            let gx_n = _mm256_add_ps(gx, zero_ps);
            let gy_n = _mm256_add_ps(gy, zero_ps);

            let ax = _mm256_and_ps(gx_n, abs_mask);
            let ay = _mm256_and_ps(gy_n, abs_mask);

            let mask_0  = _mm256_cmp_ps(ay, _mm256_mul_ps(ax, tan_22), _CMP_LE_OQ);
            let mask_90 = _mm256_cmp_ps(ay, _mm256_mul_ps(ax, tan_67), _CMP_GE_OQ);

            // Sign bit of (gx XOR gy) set <=> gx and gy have different signs.
            let xor_sign = _mm256_xor_si256(
                _mm256_castps_si256(gx_n),
                _mm256_castps_si256(gy_n),
            );
            let diff_sign_mask = _mm256_cmpgt_epi32(_mm256_setzero_si256(), xor_sign);

            // Diagonal default: 3 for opposite signs, 1 for equal signs.
            let dir_base = _mm256_blendv_epi8(
                _mm256_set1_epi32(1),
                _mm256_set1_epi32(3),
                diff_sign_mask,
            );
            // Horizontal sector -> 0.
            let dir_step1 = _mm256_andnot_si256(
                _mm256_castps_si256(mask_0), dir_base,
            );
            // Vertical sector -> 2 (takes priority over the horizontal test).
            let dir_step2 = _mm256_blendv_epi8(
                dir_step1,
                _mm256_set1_epi32(2),
                _mm256_castps_si256(mask_90),
            );

            // Zero gradient -> 0 (highest priority).
            let zero_mask = _mm256_cmp_ps(mag_sq, zero_ps, _CMP_EQ_OQ);
            let dir_final = _mm256_blendv_epi8(
                dir_step2,
                _mm256_setzero_si256(),
                _mm256_castps_si256(zero_mask),
            );

            // ---------- Stage 5: pack 8 × i32 codes into 8 bytes ----------
            // packus_epi32 works per 128-bit lane, leaving codes 0..3 in 64-bit quad 0
            // and codes 4..7 in quad 2. Permute 0x88 gathers quads 0 and 2 into the low
            // 128 bits, packus_epi16 narrows them to bytes, and storel writes 8 bytes.
            let dir_16        = _mm256_packus_epi32(dir_final, dir_final);
            let dir_16_merged = _mm256_permute4x64_epi64(dir_16, 0x88);
            let dir_128       = _mm256_castsi256_si128(dir_16_merged);
            let dir_8         = _mm_packus_epi16(dir_128, dir_128);
            _mm_storel_epi64(
                dir_ptr.add(row_offset + x) as *mut __m128i,
                dir_8,
            );

            x += 8;
        }

        // ---------- Scalar tail: the remaining 0..=7 pixels ----------
        // Every operation mirrors the vector body: same operand order, fused
        // multiply-add for the magnitude, same direction priority and sign test.
        // Results are therefore bit-identical to what a SIMD lane would produce,
        // independent of how the row was split into slices.
        // x < x_end <= width - 1 and y >= 1, so all neighbour accesses are in bounds.
        while x < x_end {
            let tl = *p_up.add(x - 1);
            let tm = *p_up.add(x);
            let tr = *p_up.add(x + 1);
            let ml = *p_mid.add(x - 1);
            let mr = *p_mid.add(x + 1);
            let bl = *p_down.add(x - 1);
            let bm = *p_down.add(x);
            let br = *p_down.add(x + 1);

            let gx = ((tr - tl) + (mr - ml) * 2.0) + (br - bl);
            let gy = ((bl - tl) + (bm - tm) * 2.0) + (br - tr);

            let mag_sq = gy.mul_add(gy, gx * gx);
            let idx = row_offset + x;
            *mag_ptr.add(idx) = mag_sq.sqrt();

            let ax = gx.abs();
            let ay = gy.abs();
            *dir_ptr.add(idx) = if mag_sq == 0.0 {
                0u8
            } else if ay >= ax * TAN_67_5 {
                2u8
            } else if ay <= ax * TAN_22_5 {
                0u8
            } else if (gx + 0.0).is_sign_negative() == (gy + 0.0).is_sign_negative() {
                1u8
            } else {
                3u8
            };

            x += 1;
        }
    }
}