use crate::dft::problem::Sign;
use crate::dft::solvers::simd_butterfly::get_twiddles_neon;
use crate::kernel::Complex;
use crate::prelude::*;
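/// Stockham autosort FFT over `f64` complex data, specialized for AArch64
/// NEON. Pairs of radix-2 stages are fused into radix-4 passes that share a
/// single round-trip through memory, with one plain radix-2 tail stage when
/// `log2(n)` is odd. Data ping-pongs between `output` and an internal
/// scratch buffer, so no bit-reversal permutation is required.
///
/// # Safety
///
/// NEON must be available on the executing CPU. `input.len()` must equal
/// `output.len()` and must be a power of two.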
#[target_feature(enable = "neon")]
pub unsafe fn stockham_radix4_neon(
input: &[Complex<f64>],
output: &mut [Complex<f64>],
sign: Sign,
) {
unsafe {
use core::arch::aarch64::*;
let n = input.len();
let log_n = n.trailing_zeros() as usize;
if n <= 4 {
stockham_small_neon(input, output, sign);
return;
}
let half_n = n / 2;
let quarter_n = n / 4;
let mut scratch: Vec<Complex<f64>> = vec![Complex::zero(); n];
let num_fused = log_n / 2;
let has_final = usize::from(log_n % 2 == 1);
let total_writes = num_fused + has_final;
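// Each pass reads `src` and writes `dst`, then the pointers swap. After an
// even number of passes the data is back in the buffer it started in, so
// pick the starting buffer such that the final pass lands in `output`.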
let (mut src_ptr, mut dst_ptr): (*mut f64, *mut f64) = if total_writes.is_multiple_of(2) {
output.copy_from_slice(input);
(
output.as_mut_ptr() as *mut f64,
scratch.as_mut_ptr() as *mut f64,
)
} else {
scratch.copy_from_slice(input);
(
scratch.as_mut_ptr() as *mut f64,
output.as_mut_ptr() as *mut f64,
)
};
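// Lane mask (-1, 1) consumed by `neon_complex_mul`: one fused multiply-add
// then yields (re*wr - im*wi, re*wi + im*wr).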
let sign_pattern = vld1q_f64([-1.0_f64, 1.0].as_ptr());
let twiddles = get_twiddles_neon();
let forward = sign == Sign::Forward;
let mut stage = 0;
let mut m = 1usize;
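// Fused passes: each iteration applies stages `stage` and `stage + 1`
// (two radix-2 stages back to back in registers), growing `m` by 4x.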
while stage + 1 < log_n {
let m1 = m;
let m2 = m * 2;
let m4 = m * 4;
let tw1_base = if forward {
twiddles.forward[twiddles.offsets[stage]..].as_ptr()
} else {
twiddles.inverse[twiddles.offsets[stage]..].as_ptr()
};
let tw2_base = if forward {
twiddles.forward[twiddles.offsets[stage + 1]..].as_ptr()
} else {
twiddles.inverse[twiddles.offsets[stage + 1]..].as_ptr()
};
let src = src_ptr;
let dst = dst_ptr;
let num_groups = half_n / m2;
let prefetch_dist = 16;
for g in 0..num_groups {
let mut j = 0;
while j + 2 <= m1 {
if j + prefetch_dist < m1 {
let k_pf = g * m1 + j + prefetch_dist;
prefetch_read(src.add(k_pf * 2) as *const i8);
prefetch_read(src.add((k_pf + quarter_n) * 2) as *const i8);
}
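// First of two unrolled butterflies: gather four inputs at stride n/4,
// run both radix-2 stages in registers, scatter four outputs.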
let k0 = g * m1 + j;
let k1 = k0 + 1;
let s0_0 = k0;
let s1_0 = k0 + quarter_n;
let s2_0 = k0 + half_n;
let s3_0 = k0 + half_n + quarter_n;
let dst_base0 = g * m4 + j;
let x0_0 = vld1q_f64(src.add(s0_0 * 2));
let x1_0 = vld1q_f64(src.add(s1_0 * 2));
let x2_0 = vld1q_f64(src.add(s2_0 * 2));
let x3_0 = vld1q_f64(src.add(s3_0 * 2));
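// Stage 1: twiddle the upper-half inputs, then add/subtract pairs.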
let tw1_0 = vld1q_f64(tw1_base.add(j) as *const f64);
let tw1_flip_0 = vextq_f64(tw1_0, tw1_0, 1);
let t2_0 = neon_complex_mul(x2_0, tw1_0, tw1_flip_0, sign_pattern);
let t3_0 = neon_complex_mul(x3_0, tw1_0, tw1_flip_0, sign_pattern);
let a0_0 = vaddq_f64(x0_0, t2_0);
let a1_0 = vsubq_f64(x0_0, t2_0);
let a2_0 = vaddq_f64(x1_0, t3_0);
let a3_0 = vsubq_f64(x1_0, t3_0);
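// Stage 2: separate twiddles for the two output sub-blocks; the final
// add/subtract produces the four outputs spaced m1 apart.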
let tw2a_0 = vld1q_f64(tw2_base.add(j) as *const f64);
let tw2a_flip_0 = vextq_f64(tw2a_0, tw2a_0, 1);
let tw2b_0 = vld1q_f64(tw2_base.add(j + m1) as *const f64);
let tw2b_flip_0 = vextq_f64(tw2b_0, tw2b_0, 1);
let b2_0 = neon_complex_mul(a2_0, tw2a_0, tw2a_flip_0, sign_pattern);
let b3_0 = neon_complex_mul(a3_0, tw2b_0, tw2b_flip_0, sign_pattern);
vst1q_f64(dst.add(dst_base0 * 2), vaddq_f64(a0_0, b2_0));
vst1q_f64(dst.add((dst_base0 + m1) * 2), vaddq_f64(a1_0, b3_0));
vst1q_f64(dst.add((dst_base0 + m2) * 2), vsubq_f64(a0_0, b2_0));
vst1q_f64(dst.add((dst_base0 + m2 + m1) * 2), vsubq_f64(a1_0, b3_0));
let s0_1 = k1;
let s1_1 = k1 + quarter_n;
let s2_1 = k1 + half_n;
let s3_1 = k1 + half_n + quarter_n;
let dst_base1 = g * m4 + j + 1;
let x0_1 = vld1q_f64(src.add(s0_1 * 2));
let x1_1 = vld1q_f64(src.add(s1_1 * 2));
let x2_1 = vld1q_f64(src.add(s2_1 * 2));
let x3_1 = vld1q_f64(src.add(s3_1 * 2));
let tw1_1 = vld1q_f64(tw1_base.add(j + 1) as *const f64);
let tw1_flip_1 = vextq_f64(tw1_1, tw1_1, 1);
let t2_1 = neon_complex_mul(x2_1, tw1_1, tw1_flip_1, sign_pattern);
let t3_1 = neon_complex_mul(x3_1, tw1_1, tw1_flip_1, sign_pattern);
let a0_1 = vaddq_f64(x0_1, t2_1);
let a1_1 = vsubq_f64(x0_1, t2_1);
let a2_1 = vaddq_f64(x1_1, t3_1);
let a3_1 = vsubq_f64(x1_1, t3_1);
let tw2a_1 = vld1q_f64(tw2_base.add(j + 1) as *const f64);
let tw2a_flip_1 = vextq_f64(tw2a_1, tw2a_1, 1);
let tw2b_1 = vld1q_f64(tw2_base.add(j + 1 + m1) as *const f64);
let tw2b_flip_1 = vextq_f64(tw2b_1, tw2b_1, 1);
let b2_1 = neon_complex_mul(a2_1, tw2a_1, tw2a_flip_1, sign_pattern);
let b3_1 = neon_complex_mul(a3_1, tw2b_1, tw2b_flip_1, sign_pattern);
vst1q_f64(dst.add(dst_base1 * 2), vaddq_f64(a0_1, b2_1));
vst1q_f64(dst.add((dst_base1 + m1) * 2), vaddq_f64(a1_1, b3_1));
vst1q_f64(dst.add((dst_base1 + m2) * 2), vsubq_f64(a0_1, b2_1));
vst1q_f64(dst.add((dst_base1 + m2 + m1) * 2), vsubq_f64(a1_1, b3_1));
j += 2;
}
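// Remainder butterfly; since m1 is always a power of four here, this is
// only reachable on the first pass, where m1 == 1.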
while j < m1 {
let k = g * m1 + j;
let s0 = k;
let s1 = k + quarter_n;
let s2 = k + half_n;
let s3 = k + half_n + quarter_n;
let dst_base = g * m4 + j;
let x0 = vld1q_f64(src.add(s0 * 2));
let x1 = vld1q_f64(src.add(s1 * 2));
let x2 = vld1q_f64(src.add(s2 * 2));
let x3 = vld1q_f64(src.add(s3 * 2));
let tw1 = vld1q_f64(tw1_base.add(j) as *const f64);
let tw1_flip = vextq_f64(tw1, tw1, 1);
let t2 = neon_complex_mul(x2, tw1, tw1_flip, sign_pattern);
let t3 = neon_complex_mul(x3, tw1, tw1_flip, sign_pattern);
let a0 = vaddq_f64(x0, t2);
let a1 = vsubq_f64(x0, t2);
let a2 = vaddq_f64(x1, t3);
let a3 = vsubq_f64(x1, t3);
let tw2a = vld1q_f64(tw2_base.add(j) as *const f64);
let tw2a_flip = vextq_f64(tw2a, tw2a, 1);
let tw2b = vld1q_f64(tw2_base.add(j + m1) as *const f64);
let tw2b_flip = vextq_f64(tw2b, tw2b, 1);
let b2 = neon_complex_mul(a2, tw2a, tw2a_flip, sign_pattern);
let b3 = neon_complex_mul(a3, tw2b, tw2b_flip, sign_pattern);
vst1q_f64(dst.add(dst_base * 2), vaddq_f64(a0, b2));
vst1q_f64(dst.add((dst_base + m1) * 2), vaddq_f64(a1, b3));
vst1q_f64(dst.add((dst_base + m2) * 2), vsubq_f64(a0, b2));
vst1q_f64(dst.add((dst_base + m2 + m1) * 2), vsubq_f64(a1, b3));
j += 1;
}
}
core::mem::swap(&mut src_ptr, &mut dst_ptr);
stage += 2;
m *= 4;
}
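// Tail: when log2(n) is odd, one plain radix-2 stage remains.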
if stage < log_n {
let m2 = m * 2;
let tw_base = if forward {
twiddles.forward[twiddles.offsets[stage]..].as_ptr()
} else {
twiddles.inverse[twiddles.offsets[stage]..].as_ptr()
};
let src = src_ptr;
let dst = dst_ptr;
let num_groups = half_n / m;
for g in 0..num_groups {
let src_base = g * m;
let dst_base = g * m2;
for j in 0..m {
let src_u = src_base + j;
let src_v = src_u + half_n;
let dst_u = dst_base + j;
let dst_v = dst_u + m;
let u = vld1q_f64(src.add(src_u * 2));
let v = vld1q_f64(src.add(src_v * 2));
let tw = vld1q_f64(tw_base.add(j) as *const f64);
let tw_flip = vextq_f64(tw, tw, 1);
let t = neon_complex_mul(v, tw, tw_flip, sign_pattern);
vst1q_f64(dst.add(dst_u * 2), vaddq_f64(u, t));
vst1q_f64(dst.add(dst_v * 2), vsubq_f64(u, t));
}
}
}
}
}
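/// Complex multiply `v * tw` on one `f64` complex number per vector.
///
/// `tw_flip` must be `tw` with its lanes swapped and `sign_pattern` must be
/// `(-1, 1)`; the fused multiply-add then evaluates
/// `(v_re*tw_re - v_im*tw_im, v_re*tw_im + v_im*tw_re)`.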
#[inline(always)]
unsafe fn neon_complex_mul(
v: core::arch::aarch64::float64x2_t,
tw: core::arch::aarch64::float64x2_t,
tw_flip: core::arch::aarch64::float64x2_t,
sign_pattern: core::arch::aarch64::float64x2_t,
) -> core::arch::aarch64::float64x2_t {
unsafe {
use core::arch::aarch64::*;
let v_re = vdupq_laneq_f64::<0>(v);
let v_im = vdupq_laneq_f64::<1>(v);
let prod1 = vmulq_f64(v_re, tw);
let prod2 = vmulq_f64(v_im, tw_flip);
vfmaq_f64(prod1, prod2, sign_pattern)
}
}
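/// Direct DFT base cases for `n <= 4`; sizes other than 1, 2, and 4 are
/// passed through unchanged.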
#[target_feature(enable = "neon")]
unsafe fn stockham_small_neon(input: &[Complex<f64>], output: &mut [Complex<f64>], sign: Sign) {
unsafe {
use core::arch::aarch64::*;
let n = input.len();
let sign_val = f64::from(sign.value());
match n {
1 => {
output[0] = input[0];
}
2 => {
let x0 = vld1q_f64(input.as_ptr() as *const f64);
let x1 = vld1q_f64((input.as_ptr() as *const f64).add(2));
vst1q_f64(output.as_mut_ptr() as *mut f64, vaddq_f64(x0, x1));
vst1q_f64((output.as_mut_ptr() as *mut f64).add(2), vsubq_f64(x0, x1));
}
4 => {
let x0 = vld1q_f64(input.as_ptr() as *const f64);
let x1 = vld1q_f64((input.as_ptr() as *const f64).add(2));
let x2 = vld1q_f64((input.as_ptr() as *const f64).add(4));
let x3 = vld1q_f64((input.as_ptr() as *const f64).add(6));
let a = vaddq_f64(x0, x2);
let b = vsubq_f64(x0, x2);
let c = vaddq_f64(x1, x3);
let diff = vsubq_f64(x1, x3);
let swapped = vextq_f64(diff, diff, 1);
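// Multiply (x1 - x3) by -i (sign_val < 0) or by +i: swap the re/im
// lanes, then negate one of them.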
let d = if sign_val < 0.0 {
vmulq_f64(swapped, vld1q_f64([1.0, -1.0].as_ptr()))
} else {
vmulq_f64(swapped, vld1q_f64([-1.0, 1.0].as_ptr()))
};
vst1q_f64(output.as_mut_ptr() as *mut f64, vaddq_f64(a, c));
vst1q_f64((output.as_mut_ptr() as *mut f64).add(2), vaddq_f64(b, d));
vst1q_f64((output.as_mut_ptr() as *mut f64).add(4), vsubq_f64(a, c));
vst1q_f64((output.as_mut_ptr() as *mut f64).add(6), vsubq_f64(b, d));
}
_ => {
output.copy_from_slice(input);
}
}
}
}
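/// Single radix-2 Stockham butterfly with a scalar twiddle `w`; used by
/// the reference `stockham_neon` path below.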
#[inline(always)]
#[allow(dead_code)]
unsafe fn stockham_butterfly_neon(
src: *const f64,
dst: *mut f64,
src_u: usize,
src_v: usize,
dst_u: usize,
dst_v: usize,
w: &Complex<f64>,
sign_pattern: core::arch::aarch64::float64x2_t,
) {
unsafe {
use core::arch::aarch64::*;
let u = vld1q_f64(src.add(src_u * 2));
let v = vld1q_f64(src.add(src_v * 2));
let tw_ptr = core::ptr::from_ref(w) as *const f64;
let tw = vld1q_f64(tw_ptr);
let tw_flip = vextq_f64(tw, tw, 1);
let v_re = vdupq_laneq_f64::<0>(v);
let v_im = vdupq_laneq_f64::<1>(v);
let prod1 = vmulq_f64(v_re, tw);
let prod2 = vmulq_f64(v_im, tw_flip);
let t = vfmaq_f64(prod1, prod2, sign_pattern);
vst1q_f64(dst.add(dst_u * 2), vaddq_f64(u, t));
vst1q_f64(dst.add(dst_v * 2), vsubq_f64(u, t));
}
}
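/// Issue a `prfm pldl1keep` hint to pull the cache line at `addr` into L1
/// ahead of a future read; it has no architectural effect on program state.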
#[inline(always)]
unsafe fn prefetch_read(addr: *const i8) {
use core::arch::asm;
unsafe {
asm!(
"prfm pldl1keep, [{0}]",
in(reg) addr,
options(readonly, nostack, preserves_flags)
);
}
}
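/// Reference radix-2 Stockham FFT that generates twiddles on the fly via
/// the recurrence `w = w * w_step`; kept for comparison but not used by
/// the radix-4 hot path above.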
#[target_feature(enable = "neon")]
#[allow(dead_code)]
unsafe fn stockham_neon(input: &[Complex<f64>], output: &mut [Complex<f64>], sign: Sign) {
unsafe {
use core::arch::aarch64::*;
let n = input.len();
let log_n = n.trailing_zeros() as usize;
let sign_val = f64::from(sign.value());
let half_n = n / 2;
let mut scratch: Vec<Complex<f64>> = vec![Complex::zero(); n];
let (mut src_ptr, mut dst_ptr): (*mut Complex<f64>, *mut Complex<f64>) =
if log_n.is_multiple_of(2) {
output.copy_from_slice(input);
(output.as_mut_ptr(), scratch.as_mut_ptr())
} else {
scratch.copy_from_slice(input);
(scratch.as_mut_ptr(), output.as_mut_ptr())
};
let sign_arr = [-1.0_f64, 1.0];
let sign_pattern = vld1q_f64(sign_arr.as_ptr());
let mut m = 1;
for _ in 0..log_n {
let m2 = m * 2;
let num_groups = half_n / m;
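// Per-stage twiddle step exp(i * sign * 2*pi / m2); each butterfly in a
// group advances `w` by one multiplication with `w_step`.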
let angle_step = sign_val * core::f64::consts::TAU / (m2 as f64);
let w_step = Complex::cis(angle_step);
let src = src_ptr as *const f64;
let dst = dst_ptr as *mut f64;
for g in 0..num_groups {
let src_base = g * m;
let dst_base = g * m2;
let mut w = Complex::new(1.0, 0.0);
let mut j = 0;
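// Unrolled x4: compute the next three twiddles up front so the four
// butterflies can issue independently.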
while j + 3 < m {
let w0 = w;
let w1 = w * w_step;
let w2 = w1 * w_step;
let w3 = w2 * w_step;
stockham_butterfly_neon(
src,
dst,
src_base + j,
src_base + j + half_n,
dst_base + j,
dst_base + j + m,
&w0,
sign_pattern,
);
stockham_butterfly_neon(
src,
dst,
src_base + j + 1,
src_base + j + 1 + half_n,
dst_base + j + 1,
dst_base + j + 1 + m,
&w1,
sign_pattern,
);
stockham_butterfly_neon(
src,
dst,
src_base + j + 2,
src_base + j + 2 + half_n,
dst_base + j + 2,
dst_base + j + 2 + m,
&w2,
sign_pattern,
);
stockham_butterfly_neon(
src,
dst,
src_base + j + 3,
src_base + j + 3 + half_n,
dst_base + j + 3,
dst_base + j + 3 + m,
&w3,
sign_pattern,
);
w = w3 * w_step;
j += 4;
}
while j < m {
stockham_butterfly_neon(
src,
dst,
src_base + j,
src_base + j + half_n,
dst_base + j,
dst_base + j + m,
&w,
sign_pattern,
);
w = w * w_step;
j += 1;
}
}
core::mem::swap(&mut src_ptr, &mut dst_ptr);
m *= 2;
}
}
}