// fast_canny/kernel/mod.rs

/// Hardware-independent interface for the fused Sobel operator.
///
/// Implementations must be `Send + Sync` so that one kernel object can be
/// shared between threads.
///
/// # Preconditions
///
/// These are obligations on the *caller* of
/// [`process_row_slice`](SobelKernel::process_row_slice):
///
/// - `src` holds a `width × height` image of `f32` samples.
/// - `x_start >= 1`, `x_end <= width - 1`, `y >= 1` and `y <= height - 2`,
///   i.e. every processed pixel has a complete 3×3 neighbourhood.
/// - `mag_out` and `dir_out` are writable buffers of the same size as `src`.
///
/// The method is a safe function, so an implementation should not rely on
/// these preconditions for memory safety without checking them itself.
pub trait SobelKernel: Send + Sync {
    /// Processes the pixels of row `y` whose column lies in `[x_start, x_end)`.
    ///
    /// Takes slices rather than raw pointers so that callers do not need
    /// `unsafe` code to invoke a kernel.
    ///
    /// * `src` – input image samples.
    /// * `mag_out` – output buffer for the per-pixel gradient magnitude.
    /// * `dir_out` – output buffer for the per-pixel direction code.
    /// * `width` – number of samples per image row.
    /// * `x_start`, `x_end` – half-open column range to process.
    /// * `y` – index of the row to process.
    // The argument list mirrors the row-oriented kernel contract; bundling it
    // into a struct would change the interface for every implementation.
    #[allow(clippy::too_many_arguments)]
    fn process_row_slice(
        &self,
        src: &[f32],
        mag_out: &mut [f32],
        dir_out: &mut [u8],
        width: usize,
        x_start: usize,
        x_end: usize,
        y: usize,
    );
}
23
24#[cfg(target_arch = "x86_64")]
25pub mod avx2;
26
27#[cfg(target_arch = "aarch64")]
28pub mod aarch64;
29
30/// 返回当前平台最优的 SobelKernel 实现。
31pub fn detect() -> Box<dyn SobelKernel> {
32    #[cfg(target_arch = "x86_64")]
33    {
34        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
35            log::info!("[kernel::detect] selected: AVX2+FMA");
36            return Box::new(avx2::Avx2SobelKernel);
37        }
38        log::warn!("[kernel::detect] AVX2/FMA not available, falling back to scalar");
39    }
40
41    #[cfg(target_arch = "aarch64")]
42    {
43        log::info!("[kernel::detect] selected: AArch64 NEON");
44        return Box::new(aarch64::NeonSobelKernel);
45    }
46
47    #[allow(unreachable_code)]
48    {
49        log::warn!("[kernel::detect] no SIMD available, using scalar fallback");
50        Box::new(ScalarSobelKernel)
51    }
52}
53
54// ── 标量后备实现 ──────────────────────────────────────────────────
55pub(crate) struct ScalarSobelKernel;
56
57impl SobelKernel for ScalarSobelKernel {
58    fn process_row_slice(
59        &self,
60        src:     &[f32],
61        mag_out: &mut [f32],
62        dir_out: &mut [u8],
63        width:   usize,
64        x_start: usize,
65        x_end:   usize,
66        y:       usize,
67    ) {
68        debug_assert!(x_start >= 1 && x_end <= width - 1);
69        debug_assert!(y >= 1 && y < src.len() / width);
70
71        let row = y * width;
72        for x in x_start..x_end {
73            let idx = row + x;
74            // 3×3 邻域索引,边界由调用方保证合法
75            let tl = src[idx - width - 1];
76            let tm = src[idx - width];
77            let tr = src[idx - width + 1];
78            let ml = src[idx - 1];
79            let mr = src[idx + 1];
80            let bl = src[idx + width - 1];
81            let bm = src[idx + width];
82            let br = src[idx + width + 1];
83
84            let gx = -tl + tr - 2.0 * ml + 2.0 * mr - bl + br;
85            let gy = -tl - 2.0 * tm - tr + bl + 2.0 * bm + br;
86
87            mag_out[idx] = (gx * gx + gy * gy).sqrt();
88
89            let ax = gx.abs();
90            let ay = gy.abs();
91            dir_out[idx] = if ay <= ax * 0.414_213_56 {
92                0u8
93            } else if ay >= ax * 2.414_213_56 {
94                2u8
95            } else if (gx >= 0.0) == (gy >= 0.0) {
96                1u8
97            } else {
98                3u8
99            };
100        }
101    }
102}