use scirs2_core::numeric::Complex64;
use std::arch::aarch64::*;
/// In-place 4-point butterfly on `a`.
///
/// Expects `twiddles = [w1, w2, w3]`; rows 2 and 3 of the 4x4 transform
/// matrix use the derived powers `w4 = w2^2`, `w6 = w2^3` and `w9 = w3^3`
/// (assumes the twiddles are successive powers of a common root —
/// confirm at call sites).
#[inline(always)]
pub fn radix4_butterfly_scalar(a: &mut [Complex64; 4], twiddles: &[Complex64; 3]) {
    let [s0, s1, s2, s3] = *a;
    let [w1, w2, w3] = *twiddles;
    // Higher twiddle powers needed by rows 2 and 3.
    let w4 = w2 * w2;
    let w6 = w4 * w2;
    let w9 = w3 * w3 * w3;
    a[0] = s0 + s1 + s2 + s3;
    a[1] = s0 + w1 * s1 + w2 * s2 + w3 * s3;
    a[2] = s0 + w2 * s1 + w4 * s2 + w6 * s3;
    a[3] = s0 + w3 * s1 + w6 * s2 + w9 * s3;
}
/// In-place 8-point butterfly on `a`.
///
/// Builds the power table `w = [1, twiddles[0], .., twiddles[6]]` and
/// evaluates `a[k] = sum_n a_in[n] * w[(n*k) % 8]` — a direct 8x8
/// transform, assuming `twiddles[k]` is the (k+1)-th power of a common
/// 8th root (confirm at call sites).
#[inline(always)]
pub fn radix8_butterfly_scalar(a: &mut [Complex64; 8], twiddles: &[Complex64; 7]) {
    // Power table indexed by (n*k) mod 8; exponent 0 is the unit.
    let mut w = [Complex64::new(1.0, 0.0); 8];
    w[1..].copy_from_slice(twiddles);
    // Snapshot the input so the in-place writes below don't corrupt it.
    let src = *a;
    for (k, out) in a.iter_mut().enumerate() {
        let mut acc = Complex64::new(0.0, 0.0);
        for (n, &val) in src.iter().enumerate() {
            acc += val * w[(n * k) % 8];
        }
        *out = acc;
    }
}
/// Complex multiply of two numbers packed as `[re, im]` in a `float64x2_t`.
///
/// For `a = (ar, ai)` and `b = (br, bi)` the result lanes are
/// `[ar*br - ai*bi, ai*br + ar*bi]` — the standard complex product
/// (commutative, so argument order does not matter).
///
/// # Safety
/// The `neon` target feature must be available on the running CPU
/// (it is the baseline on aarch64).
#[target_feature(enable = "neon")]
#[inline]
unsafe fn cmul_f64(a: float64x2_t, b: float64x2_t) -> float64x2_t {
    // Broadcast b's real and imaginary parts across both lanes.
    let b_re = vdupq_laneq_f64::<0>(b);
    let b_im = vdupq_laneq_f64::<1>(b);
    // Rotate a by one lane: a_swap = [ai, ar].
    let a_swap = vextq_f64::<1>(a, a);
    // ac = [ar*br, ai*br].
    let ac = vmulq_f64(a, b_re);
    // re_out = ac - a_swap*b_im -> lane 0 holds ar*br - ai*bi.
    let re_out = vfmsq_f64(ac, a_swap, b_im);
    // im_vec = ac + a_swap*b_im -> lane 1 holds ai*br + ar*bi.
    let im_vec = vfmaq_f64(ac, a_swap, b_im);
    // Keep lane 0 of re_out and lane 1 of im_vec: [re, im].
    vcombine_f64(vget_low_f64(re_out), vget_high_f64(im_vec))
}
/// 4-point butterfly on raw pointers using NEON 128-bit vectors.
///
/// Mirrors `radix4_butterfly_scalar`: `twiddles` supplies `[w1, w2, w3]`
/// and the higher powers `w4 = w2^2`, `w6 = w2^3`, `w9 = w3^3` are derived
/// with `cmul_f64` before applying the 4x4 matrix. Note the additions for
/// each output are tree-grouped here (`(a+b)+(c+d)`) while the scalar path
/// sums left-to-right, so results may differ in the last ulp.
///
/// # Safety
/// - `a` must be valid for reads and writes of 4 `Complex64` values.
/// - `twiddles` must be valid for reads of 3 `Complex64` values.
/// - NOTE(review): assumes `Complex64` is laid out as two consecutive
///   `f64`s (re, im) — confirm the repr in `scirs2_core`.
/// - The `neon` target feature must be available (aarch64 baseline).
#[target_feature(enable = "neon")]
pub unsafe fn radix4_butterfly_neon(a: *mut Complex64, twiddles: *const Complex64) {
    // Reinterpret as f64 pairs; each vld1q_f64 loads one [re, im] element.
    let p = a as *const f64;
    let tp = twiddles as *const f64;
    let x0 = vld1q_f64(p);
    let x1 = vld1q_f64(p.add(2));
    let x2 = vld1q_f64(p.add(4));
    let x3 = vld1q_f64(p.add(6));
    let w1 = vld1q_f64(tp);
    let w2 = vld1q_f64(tp.add(2));
    let w3 = vld1q_f64(tp.add(4));
    // Derived powers, matching the scalar path.
    let w4 = cmul_f64(w2, w2);
    let w6 = cmul_f64(w4, w2);
    let w3sq = cmul_f64(w3, w3);
    let w9 = cmul_f64(w3sq, w3);
    // Row 0: plain sum (all coefficients are w^0 = 1).
    let out0 = vaddq_f64(vaddq_f64(x0, x1), vaddq_f64(x2, x3));
    // Row 1: coefficients [1, w1, w2, w3].
    let out1 = vaddq_f64(
        vaddq_f64(x0, cmul_f64(w1, x1)),
        vaddq_f64(cmul_f64(w2, x2), cmul_f64(w3, x3)),
    );
    // Row 2: coefficients [1, w2, w4, w6].
    let out2 = vaddq_f64(
        vaddq_f64(x0, cmul_f64(w2, x1)),
        vaddq_f64(cmul_f64(w4, x2), cmul_f64(w6, x3)),
    );
    // Row 3: coefficients [1, w3, w6, w9].
    let out3 = vaddq_f64(
        vaddq_f64(x0, cmul_f64(w3, x1)),
        vaddq_f64(cmul_f64(w6, x2), cmul_f64(w9, x3)),
    );
    // Write results back in place.
    let q = a as *mut f64;
    vst1q_f64(q, out0);
    vst1q_f64(q.add(2), out1);
    vst1q_f64(q.add(4), out2);
    vst1q_f64(q.add(6), out3);
}
/// 8-point butterfly on raw pointers using NEON.
///
/// Mirrors `radix8_butterfly_scalar`: a direct 8x8 transform driven by a
/// power table indexed with `(n*k) % 8`, where `twiddles[k]` is taken to
/// be the (k+1)-th power of a common 8th root. Summation order matches
/// the scalar loop exactly.
///
/// # Safety
/// - `a` must be valid for reads and writes of 8 `Complex64` values.
/// - `twiddles` must be valid for reads of 7 `Complex64` values.
/// - NOTE(review): assumes `Complex64` is laid out as two consecutive
///   `f64`s (re, im) — confirm the repr in `scirs2_core`.
/// - The `neon` target feature must be available (aarch64 baseline).
#[target_feature(enable = "neon")]
pub unsafe fn radix8_butterfly_neon(a: *mut Complex64, twiddles: *const Complex64) {
    let p = a as *const f64;
    let tp = twiddles as *const f64;
    // One [re, im] vector per input element.
    let x = [
        vld1q_f64(p),
        vld1q_f64(p.add(2)),
        vld1q_f64(p.add(4)),
        vld1q_f64(p.add(6)),
        vld1q_f64(p.add(8)),
        vld1q_f64(p.add(10)),
        vld1q_f64(p.add(12)),
        vld1q_f64(p.add(14)),
    ];
    // Unit twiddle w^0 = 1 + 0i, loaded from a stack temporary (the
    // temporary array outlives the load within this statement).
    let w0 = {
        let re: f64 = 1.0;
        let im: f64 = 0.0;
        vld1q_f64([re, im].as_ptr())
    };
    // Power table w[0..8] = [w^0, twiddles[0], .., twiddles[6]].
    let w = [
        w0,
        vld1q_f64(tp),
        vld1q_f64(tp.add(2)),
        vld1q_f64(tp.add(4)),
        vld1q_f64(tp.add(6)),
        vld1q_f64(tp.add(8)),
        vld1q_f64(tp.add(10)),
        vld1q_f64(tp.add(12)),
    ];
    // Accumulate into a scratch array so the in-place writes below don't
    // clobber inputs still being read.
    let mut out = [vdupq_n_f64(0.0); 8];
    for k in 0..8 {
        let mut sum = vdupq_n_f64(0.0);
        for n in 0..8 {
            let idx = (n * k) % 8;
            sum = vaddq_f64(sum, cmul_f64(w[idx], x[n]));
        }
        out[k] = sum;
    }
    // Write results back in place.
    let q = a as *mut f64;
    for k in 0..8 {
        vst1q_f64(q.add(k * 2), out[k]);
    }
}
/// Reports whether the NEON SIMD kernels in this module can be used.
///
/// Always `true`: Advanced SIMD (NEON) is a mandatory part of the AArch64
/// base architecture, and this module only compiles for aarch64 (it
/// imports `std::arch::aarch64`), so no runtime detection is needed.
#[inline]
pub fn is_neon_available() -> bool {
    true
}
/// Applies the 4-point butterfly to `a` in place, selecting the best
/// available implementation (currently always the NEON kernel, since
/// NEON is the aarch64 baseline).
pub fn radix4_butterfly_dispatch(a: &mut [Complex64; 4], twiddles: &[Complex64; 3]) {
    let data = a.as_mut_ptr();
    let tw = twiddles.as_ptr();
    // SAFETY: the array references guarantee `data` covers 4 valid
    // Complex64 values and `tw` covers 3; NEON is mandatory on aarch64.
    unsafe { radix4_butterfly_neon(data, tw) }
}
/// Applies the 8-point butterfly to `a` in place, selecting the best
/// available implementation (currently always the NEON kernel, since
/// NEON is the aarch64 baseline).
pub fn radix8_butterfly_dispatch(a: &mut [Complex64; 8], twiddles: &[Complex64; 7]) {
    let data = a.as_mut_ptr();
    let tw = twiddles.as_ptr();
    // SAFETY: the array references guarantee `data` covers 8 valid
    // Complex64 values and `tw` covers 7; NEON is mandatory on aarch64.
    unsafe { radix8_butterfly_neon(data, tw) }
}
/// Scalable Vector Extension (SVE) entry points.
///
/// True SVE kernels are not implemented yet; these wrappers currently
/// delegate to the NEON butterflies, which compute the same transform.
pub mod sve {
    use super::{radix4_butterfly_neon, radix8_butterfly_neon};
    use scirs2_core::numeric::Complex64;

    /// Runtime check for SVE support on the executing CPU.
    pub fn is_sve_available() -> bool {
        std::arch::is_aarch64_feature_detected!("sve")
    }

    /// 4-point butterfly (SVE placeholder — runs the NEON kernel).
    pub fn radix4_butterfly_sve(a: &mut [Complex64; 4], twiddles: &[Complex64; 3]) {
        // SAFETY: array references guarantee 4 and 3 valid Complex64
        // values respectively; NEON is mandatory on aarch64.
        unsafe { radix4_butterfly_neon(a.as_mut_ptr(), twiddles.as_ptr()) }
    }

    /// 8-point butterfly (SVE placeholder — runs the NEON kernel).
    pub fn radix8_butterfly_sve(a: &mut [Complex64; 8], twiddles: &[Complex64; 7]) {
        // SAFETY: array references guarantee 8 and 7 valid Complex64
        // values respectively; NEON is mandatory on aarch64.
        unsafe { radix8_butterfly_neon(a.as_mut_ptr(), twiddles.as_ptr()) }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Largest elementwise distance between two complex slices.
    fn max_err(a: &[Complex64], b: &[Complex64]) -> f64 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).norm())
            .fold(0.0_f64, f64::max)
    }

    /// Radix-4 scalar path against hand-computed 4-point DFT values of
    /// [1, 2, 3, 4] with the exact 4th-root twiddles (-i, -1, i).
    #[test]
    fn test_scalar_radix4_matches_known() {
        let input = [
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(3.0, 0.0),
            Complex64::new(4.0, 0.0),
        ];
        let twiddles = [
            Complex64::new(0.0, -1.0),
            Complex64::new(-1.0, 0.0),
            Complex64::new(0.0, 1.0),
        ];
        let mut data = input;
        radix4_butterfly_scalar(&mut data, &twiddles);
        assert!((data[0] - Complex64::new(10.0, 0.0)).norm() < 1e-12);
        assert!((data[1] - Complex64::new(-2.0, 2.0)).norm() < 1e-12);
        assert!((data[2] - Complex64::new(-2.0, 0.0)).norm() < 1e-12);
        assert!((data[3] - Complex64::new(-2.0, -2.0)).norm() < 1e-12);
    }

    /// Radix-8 scalar path against a naively computed O(n^2) DFT.
    #[test]
    fn test_scalar_radix8_matches_direct_dft() {
        let input: [Complex64; 8] = std::array::from_fn(|k| Complex64::new(k as f64 + 1.0, 0.0));
        // twiddles[k] = exp(-2*pi*i*(k+1)/8), i.e. successive powers of
        // the primitive 8th root.
        let twiddles: [Complex64; 7] = std::array::from_fn(|k| {
            let angle = -2.0 * PI * (k + 1) as f64 / 8.0;
            Complex64::new(angle.cos(), angle.sin())
        });
        let mut data = input;
        radix8_butterfly_scalar(&mut data, &twiddles);
        let mut reference = [Complex64::new(0.0, 0.0); 8];
        for k in 0..8 {
            for n in 0..8 {
                let angle = -2.0 * PI * (n * k) as f64 / 8.0;
                reference[k] += input[n] * Complex64::new(angle.cos(), angle.sin());
            }
        }
        let err = max_err(&data, &reference);
        assert!(err < 1e-10, "radix8 scalar err={err}");
    }

    /// NEON radix-4 kernel agrees with the scalar reference on
    /// complex-valued input.
    #[test]
    fn test_neon_radix4_matches_scalar() {
        let twiddles = [
            Complex64::new(0.0, -1.0),
            Complex64::new(-1.0, 0.0),
            Complex64::new(0.0, 1.0),
        ];
        let input = [
            Complex64::new(1.0, 2.0),
            Complex64::new(3.0, 4.0),
            Complex64::new(5.0, 6.0),
            Complex64::new(7.0, 8.0),
        ];
        let mut ref_data = input;
        radix4_butterfly_scalar(&mut ref_data, &twiddles);
        let mut neon_data = input;
        unsafe { radix4_butterfly_neon(neon_data.as_mut_ptr(), twiddles.as_ptr()) };
        let err = max_err(&ref_data, &neon_data);
        assert!(
            err < 1e-12,
            "NEON radix-4 diverges from scalar by {err}\n scalar={ref_data:?}\n neon={neon_data:?}"
        );
    }

    /// NEON radix-8 kernel agrees with the scalar reference on
    /// non-trivial complex input.
    #[test]
    fn test_neon_radix8_matches_scalar() {
        let input: [Complex64; 8] = std::array::from_fn(|k| {
            let t = k as f64 * 0.5;
            Complex64::new(t.sin() + 1.0, t.cos() - 0.5)
        });
        let twiddles: [Complex64; 7] = std::array::from_fn(|k| {
            let angle = -2.0 * PI * (k + 1) as f64 / 8.0;
            Complex64::new(angle.cos(), angle.sin())
        });
        let mut ref_data = input;
        radix8_butterfly_scalar(&mut ref_data, &twiddles);
        let mut neon_data = input;
        unsafe { radix8_butterfly_neon(neon_data.as_mut_ptr(), twiddles.as_ptr()) };
        let err = max_err(&ref_data, &neon_data);
        assert!(err < 1e-12, "NEON radix-8 diverges from scalar by {err}");
    }

    /// Safe dispatch wrapper produces the same radix-4 result as the
    /// scalar reference.
    #[test]
    fn test_dispatch_radix4_agrees_with_scalar() {
        let twiddles = [
            Complex64::new(0.0, -1.0),
            Complex64::new(-1.0, 0.0),
            Complex64::new(0.0, 1.0),
        ];
        let input = [
            Complex64::new(2.0, -1.0),
            Complex64::new(0.5, 3.0),
            Complex64::new(-1.0, 1.0),
            Complex64::new(4.0, -2.0),
        ];
        let mut ref_data = input;
        radix4_butterfly_scalar(&mut ref_data, &twiddles);
        let mut dispatch_data = input;
        radix4_butterfly_dispatch(&mut dispatch_data, &twiddles);
        let err = max_err(&ref_data, &dispatch_data);
        assert!(err < 1e-12, "dispatch vs scalar radix-4 err={err}");
    }

    /// Safe dispatch wrapper produces the same radix-8 result as the
    /// scalar reference.
    #[test]
    fn test_dispatch_radix8_agrees_with_scalar() {
        let input: [Complex64; 8] =
            std::array::from_fn(|k| Complex64::new(k as f64 * 0.7 - 1.0, k as f64 * 0.3));
        let twiddles: [Complex64; 7] = std::array::from_fn(|k| {
            let angle = -2.0 * PI * (k + 1) as f64 / 8.0;
            Complex64::new(angle.cos(), angle.sin())
        });
        let mut ref_data = input;
        radix8_butterfly_scalar(&mut ref_data, &twiddles);
        let mut dispatch_data = input;
        radix8_butterfly_dispatch(&mut dispatch_data, &twiddles);
        let err = max_err(&ref_data, &dispatch_data);
        assert!(err < 1e-12, "dispatch radix-8 vs scalar err={err}");
    }
}