use super::SobelKernel;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// Sobel kernel backed by hand-written AVX2/FMA intrinsics.
///
/// On `x86_64` its `SobelKernel` implementation runs the SIMD routine defined
/// in this file; on every other architecture it forwards to
/// `super::ScalarSobelKernel`.
pub struct Avx2SobelKernel;
impl SobelKernel for Avx2SobelKernel {
    /// Computes the Sobel gradient magnitude and direction code for columns
    /// `x_start..x_end` of row `y`.
    ///
    /// `src`, `mag_out` and `dir_out` are addressed as row-major buffers with
    /// `width` elements per row; results are written at index `y * width + x`.
    /// Each pixel reads its full 3×3 neighbourhood from `src`.
    ///
    /// On `x86_64` the work is done by the AVX2/FMA routine in this file; on
    /// every other architecture the call is forwarded to
    /// `super::ScalarSobelKernel`.
    ///
    /// # Panics
    ///
    /// On `x86_64` this panics instead of invoking undefined behaviour when
    ///
    /// * the running CPU does not support AVX2 and FMA,
    /// * `y == 0` or `x_start == 0` (the stencil would reach before the buffer),
    /// * `(y + 1) * width + x_end` is not a valid index into `src`, or
    /// * `y * width + x_end` exceeds the length of `mag_out` or `dir_out`.
    fn process_row_slice(
        &self,
        src: &[f32],
        mag_out: &mut [f32],
        dir_out: &mut [u8],
        width: usize,
        x_start: usize,
        x_end: usize,
        y: usize,
    ) {
        #[cfg(target_arch = "x86_64")]
        {
            // The SIMD routine is compiled with `avx2,fma` enabled. Executing it
            // on a CPU without those extensions is undefined behaviour, and this
            // method is safe, so the check has to happen here.
            assert!(
                is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma"),
                "Avx2SobelKernel requires a CPU with AVX2 and FMA support"
            );

            // Nothing to do; also keeps the bounds checks below meaningful.
            if x_start >= x_end {
                return;
            }

            assert!(
                y >= 1 && x_start >= 1,
                "Sobel stencil needs a one-pixel border: y = {y}, x_start = {x_start}"
            );
            // Highest `src` index touched: bottom-right neighbour of the last pixel.
            let last_read = y
                .checked_add(1)
                .and_then(|row| row.checked_mul(width))
                .and_then(|base| base.checked_add(x_end))
                .expect("pixel index overflows usize");
            assert!(
                last_read < src.len(),
                "Sobel stencil reads past the end of `src` ({last_read} >= {})",
                src.len()
            );
            // Cannot overflow: it is strictly smaller than `last_read`.
            let out_end = y * width + x_end;
            assert!(
                out_end <= mag_out.len() && out_end <= dir_out.len(),
                "output buffers too small for row {y}, x_end {x_end}"
            );

            // SAFETY: AVX2 and FMA availability was verified above, and the
            // assertions guarantee that every `src` read lies in
            // `(y - 1) * width + x_start - 1 ..= (y + 1) * width + x_end` and every
            // write lies in `y * width + x_start .. y * width + x_end`, all of
            // which are inside the respective slices.
            unsafe {
                avx2_process_row_slice(src, mag_out, dir_out, width, x_start, x_end, y);
            }
        }
        #[cfg(not(target_arch = "x86_64"))]
        super::ScalarSobelKernel.process_row_slice(
            src, mag_out, dir_out, width, x_start, x_end, y,
        );
    }
}
/// AVX2 + FMA Sobel kernel for columns `x_start..x_end` of row `y`.
///
/// `src` is read as a row-major image with `width` elements per row. For each
/// pixel the 3×3 Sobel gradients `gx` and `gy` are computed, the magnitude
/// `sqrt(gx² + gy²)` is stored in `mag_out[y * width + x]`, and a direction
/// code is stored in `dir_out[y * width + x]`:
///
/// * `0` when `|gy| <= |gx| * tan(22.5°)`,
/// * `2` otherwise when `|gy| >= |gx| * tan(67.5°)`,
/// * `1` otherwise when `gx >= 0` and `gy >= 0` agree,
/// * `3` otherwise.
///
/// Pixels are processed eight at a time; the fewer-than-eight leftovers go
/// through a scalar loop that evaluates exactly the same expressions in the
/// same order (including the fused multiply-add), so the value produced for a
/// pixel does not depend on where the range boundaries fall.
///
/// # Safety
///
/// The caller must guarantee that
///
/// * the running CPU supports AVX2 and FMA,
/// * `y >= 1` and `x_start >= 1`,
/// * `(y + 1) * width + x_end < src.len()`, and
/// * `y * width + x_end <= mag_out.len()` and `<= dir_out.len()`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn avx2_process_row_slice(
    src: &[f32],
    mag_out: &mut [f32],
    dir_out: &mut [u8],
    width: usize,
    x_start: usize,
    x_end: usize,
    y: usize,
) {
    // tan(22.5°) and tan(67.5°): the boundaries between the direction sectors.
    const TAN_22_5: f32 = 0.414_213_56;
    const TAN_67_5: f32 = 2.414_213_56;

    // Empty range: return before forming any row pointers.
    if x_start >= x_end {
        return;
    }

    debug_assert!(y >= 1 && x_start >= 1, "stencil needs a one-pixel border");
    debug_assert!((y + 1) * width + x_end < src.len(), "stencil reads past `src`");
    debug_assert!(
        y * width + x_end <= mag_out.len() && y * width + x_end <= dir_out.len(),
        "output buffers too small"
    );

    // SAFETY: the caller upholds the contract in the `# Safety` section, which
    // keeps every pointer offset, load and store below inside its slice and
    // guarantees the AVX2/FMA instructions are available.
    unsafe {
        let two = _mm256_set1_ps(2.0);
        let tan_22 = _mm256_set1_ps(TAN_22_5);
        let tan_67 = _mm256_set1_ps(TAN_67_5);
        // Clears the IEEE sign bit, i.e. computes |v| lane-wise.
        let abs_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFF_FFFF));
        let zero_ps = _mm256_setzero_ps();
        let code_1 = _mm256_set1_epi32(1);
        let code_2 = _mm256_set1_epi32(2);
        let code_3 = _mm256_set1_epi32(3);

        let row_offset = y * width;
        let p_up = src.as_ptr().add(row_offset - width);
        let p_mid = src.as_ptr().add(row_offset);
        let p_down = src.as_ptr().add(row_offset + width);
        let mag_row = mag_out.as_mut_ptr().add(row_offset);
        let dir_row = dir_out.as_mut_ptr().add(row_offset);

        let mut x = x_start;

        // Vector body: eight pixels per iteration.
        while x + 8 <= x_end {
            let tl = _mm256_loadu_ps(p_up.add(x - 1));
            let tm = _mm256_loadu_ps(p_up.add(x));
            let tr = _mm256_loadu_ps(p_up.add(x + 1));
            let ml = _mm256_loadu_ps(p_mid.add(x - 1));
            let mr = _mm256_loadu_ps(p_mid.add(x + 1));
            let bl = _mm256_loadu_ps(p_down.add(x - 1));
            let bm = _mm256_loadu_ps(p_down.add(x));
            let br = _mm256_loadu_ps(p_down.add(x + 1));

            // gx = (tr - tl) + 2 * (mr - ml) + (br - bl)
            let gx = _mm256_add_ps(
                _mm256_add_ps(
                    _mm256_sub_ps(tr, tl),
                    _mm256_mul_ps(_mm256_sub_ps(mr, ml), two),
                ),
                _mm256_sub_ps(br, bl),
            );
            // gy = (bl - tl) + 2 * (bm - tm) + (br - tr)
            let gy = _mm256_add_ps(
                _mm256_add_ps(
                    _mm256_sub_ps(bl, tl),
                    _mm256_mul_ps(_mm256_sub_ps(bm, tm), two),
                ),
                _mm256_sub_ps(br, tr),
            );

            let mag_sq = _mm256_fmadd_ps(gy, gy, _mm256_mul_ps(gx, gx));
            _mm256_storeu_ps(mag_row.add(x), _mm256_sqrt_ps(mag_sq));

            let ax = _mm256_and_ps(gx, abs_mask);
            let ay = _mm256_and_ps(gy, abs_mask);
            let mask_0 = _mm256_cmp_ps(ay, _mm256_mul_ps(ax, tan_22), _CMP_LE_OQ);
            let mask_90 = _mm256_cmp_ps(ay, _mm256_mul_ps(ax, tan_67), _CMP_GE_OQ);

            // All-ones where `(gx >= 0) != (gy >= 0)`; the ordered compare treats
            // -0.0 as non-negative, exactly like the scalar `>=` below.
            let gx_nonneg = _mm256_cmp_ps(gx, zero_ps, _CMP_GE_OQ);
            let gy_nonneg = _mm256_cmp_ps(gy, zero_ps, _CMP_GE_OQ);
            let signs_differ = _mm256_castps_si256(_mm256_xor_ps(gx_nonneg, gy_nonneg));

            // Build the code from lowest to highest priority so that the
            // `mask_0` test wins over `mask_90`, mirroring the scalar if/else
            // chain (this also yields 0 when both gradients are zero).
            let diagonal = _mm256_blendv_epi8(code_1, code_3, signs_differ);
            let steep_or_diagonal =
                _mm256_blendv_epi8(diagonal, code_2, _mm256_castps_si256(mask_90));
            let dir = _mm256_andnot_si256(_mm256_castps_si256(mask_0), steep_or_diagonal);

            // Narrow 8 × i32 -> 8 × u8. `packus_epi32` works per 128-bit lane, so
            // qwords 0 and 2 hold pixels 0..4 and 4..8; 0x88 gathers them into the
            // low 128 bits before the final 16 -> 8 bit pack.
            let packed16 = _mm256_packus_epi32(dir, dir);
            let gathered = _mm256_permute4x64_epi64(packed16, 0x88);
            let low128 = _mm256_castsi256_si128(gathered);
            let packed8 = _mm_packus_epi16(low128, low128);
            _mm_storel_epi64(dir_row.add(x).cast::<__m128i>(), packed8);

            x += 8;
        }

        // Scalar remainder: same arithmetic, one pixel at a time.
        while x < x_end {
            let tl = *p_up.add(x - 1);
            let tm = *p_up.add(x);
            let tr = *p_up.add(x + 1);
            let ml = *p_mid.add(x - 1);
            let mr = *p_mid.add(x + 1);
            let bl = *p_down.add(x - 1);
            let bm = *p_down.add(x);
            let br = *p_down.add(x + 1);

            let gx = (tr - tl) + (mr - ml) * 2.0 + (br - bl);
            let gy = (bl - tl) + (bm - tm) * 2.0 + (br - tr);
            // Fused like `_mm256_fmadd_ps` above so both paths round identically.
            *mag_row.add(x) = gy.mul_add(gy, gx * gx).sqrt();

            let ax = gx.abs();
            let ay = gy.abs();
            *dir_row.add(x) = if ay <= ax * TAN_22_5 {
                0u8
            } else if ay >= ax * TAN_67_5 {
                2u8
            } else if (gx >= 0.0) == (gy >= 0.0) {
                1u8
            } else {
                3u8
            };

            x += 1;
        }
    }
}