use scirs2_core::numeric::Complex64;
#[inline]
pub fn is_avx512_available() -> bool {
    // The SIMD kernels below are compiled with
    // `#[target_feature(enable = "avx512f,sse3")]`, so calling them is only
    // sound when *both* features are present at runtime (`_mm_addsub_pd` is
    // SSE3). Every real AVX-512F CPU also implements SSE3, but the safety
    // contract is per-feature, so check both explicitly rather than relying
    // on that hardware implication.
    is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("sse3")
}
/// In-place 4-point DFT, written out as a dense matrix-vector product.
///
/// `twiddles` supplies W, W^2, W^3 for a primitive 4th root of unity W;
/// the higher powers needed by rows 2 and 3 are derived below.
#[inline(always)]
pub fn radix4_butterfly_scalar(a: &mut [Complex64; 4], twiddles: &[Complex64; 3]) {
    let [v0, v1, v2, v3] = *a;
    let (t1, t2, t3) = (twiddles[0], twiddles[1], twiddles[2]);
    let t4 = t2 * t2; // W^4
    let t6 = t4 * t2; // W^6
    let t9 = t3 * t3 * t3; // W^9
    a[0] = v0 + v1 + v2 + v3;
    a[1] = v0 + t1 * v1 + t2 * v2 + t3 * v3;
    a[2] = v0 + t2 * v1 + t4 * v2 + t6 * v3;
    a[3] = v0 + t3 * v1 + t6 * v2 + t9 * v3;
}
/// In-place 8-point DFT via the direct O(N^2) definition.
///
/// Exponents of the primitive 8th root wrap modulo 8 (W^8 == 1), so an
/// 8-entry power table — the identity plus the seven supplied twiddles
/// W^1..W^7 — covers every exponent `(n * k) % 8`.
#[inline(always)]
pub fn radix8_butterfly_scalar(a: &mut [Complex64; 8], twiddles: &[Complex64; 7]) {
    let mut w = [Complex64::new(1.0, 0.0); 8];
    w[1..].copy_from_slice(twiddles);
    let x = *a;
    for (k, slot) in a.iter_mut().enumerate() {
        *slot = x
            .iter()
            .enumerate()
            .fold(Complex64::new(0.0, 0.0), |acc, (n, &xn)| {
                acc + xn * w[(n * k) % 8]
            });
    }
}
use std::arch::x86_64::*;
/// Multiply two complex doubles packed as (re, im) lanes of a `__m128d`.
///
/// Classic shuffle/addsub scheme:
/// `result = (a.re*b.re - a.im*b.im, a.im*b.re + a.re*b.im)`.
///
/// # Safety
/// The caller must guarantee AVX-512F and SSE3 are available at runtime
/// (`_mm_addsub_pd` is SSE3).
#[target_feature(enable = "avx512f,sse3")]
#[inline]
unsafe fn cmul_pd(a: __m128d, b: __m128d) -> __m128d {
    let b_re = _mm_shuffle_pd(b, b, 0b00); // broadcast b.re into both lanes
    let b_im = _mm_shuffle_pd(b, b, 0b11); // broadcast b.im into both lanes
    let a_rev = _mm_shuffle_pd(a, a, 0b01); // (a.im, a.re)
    // addsub: lane0 subtracts, lane1 adds — exactly the complex product.
    _mm_addsub_pd(_mm_mul_pd(a, b_re), _mm_mul_pd(a_rev, b_im))
}
/// SIMD 4-point DFT over `a[0..4]`, the vectorized mirror of
/// `radix4_butterfly_scalar` (same operation order, so same rounding).
///
/// # Safety
/// - `a` must point to 4 valid, writable `Complex64` values.
/// - `twiddles` must point to 3 valid `Complex64` values (W, W^2, W^3).
/// - AVX-512F and SSE3 must be available at runtime.
#[target_feature(enable = "avx512f,sse3")]
pub unsafe fn radix4_butterfly_avx512(a: *mut Complex64, twiddles: *const Complex64) {
    let src = a as *const f64;
    let tw = twiddles as *const f64;
    // Each Complex64 occupies two f64 lanes (re, im) of one __m128d.
    let v0 = _mm_loadu_pd(src);
    let v1 = _mm_loadu_pd(src.add(2));
    let v2 = _mm_loadu_pd(src.add(4));
    let v3 = _mm_loadu_pd(src.add(6));
    let t1 = _mm_loadu_pd(tw);
    let t2 = _mm_loadu_pd(tw.add(2));
    let t3 = _mm_loadu_pd(tw.add(4));
    // Derived powers: W^4 = (W^2)^2, W^6 = W^4*W^2, W^9 = (W^3)^2 * W^3.
    let t4 = cmul_pd(t2, t2);
    let t6 = cmul_pd(t4, t2);
    let t9 = cmul_pd(cmul_pd(t3, t3), t3);
    let y0 = _mm_add_pd(_mm_add_pd(v0, v1), _mm_add_pd(v2, v3));
    let y1 = _mm_add_pd(
        _mm_add_pd(v0, cmul_pd(t1, v1)),
        _mm_add_pd(cmul_pd(t2, v2), cmul_pd(t3, v3)),
    );
    let y2 = _mm_add_pd(
        _mm_add_pd(v0, cmul_pd(t2, v1)),
        _mm_add_pd(cmul_pd(t4, v2), cmul_pd(t6, v3)),
    );
    let y3 = _mm_add_pd(
        _mm_add_pd(v0, cmul_pd(t3, v1)),
        _mm_add_pd(cmul_pd(t6, v2), cmul_pd(t9, v3)),
    );
    let dst = a as *mut f64;
    _mm_storeu_pd(dst, y0);
    _mm_storeu_pd(dst.add(2), y1);
    _mm_storeu_pd(dst.add(4), y2);
    _mm_storeu_pd(dst.add(6), y3);
}
/// Runs two independent radix-4 butterflies back to back.
///
/// # Safety
/// Each `(buffer, twiddle)` pair must satisfy the same requirements as
/// [`radix4_butterfly_avx512`], and AVX-512F/SSE3 must be available.
#[target_feature(enable = "avx512f,sse3")]
pub unsafe fn radix4_butterfly_x2_avx512(
    a0: *mut Complex64,
    tw0: *const Complex64,
    a1: *mut Complex64,
    tw1: *const Complex64,
) {
    for (buf, tw) in [(a0, tw0), (a1, tw1)] {
        radix4_butterfly_avx512(buf, tw);
    }
}
/// SIMD 8-point DFT over `a[0..8]`, the direct O(N^2) counterpart of
/// `radix8_butterfly_scalar` (same accumulation order per output).
///
/// # Safety
/// - `a` must point to 8 valid, writable `Complex64` values.
/// - `twiddles` must point to 7 valid `Complex64` values (W^1..W^7).
/// - AVX-512F and SSE3 must be available at runtime.
#[target_feature(enable = "avx512f,sse3")]
pub unsafe fn radix8_butterfly_avx512(a: *mut Complex64, twiddles: *const Complex64) {
    let src = a as *const f64;
    let tw = twiddles as *const f64;
    let mut x = [_mm_setzero_pd(); 8];
    for n in 0..8 {
        x[n] = _mm_loadu_pd(src.add(n * 2));
    }
    // Power table: w[0] = 1 + 0i (lo lane = re), w[1..8] = W^1..W^7.
    let mut w = [_mm_set_pd(0.0, 1.0); 8];
    for n in 1..8 {
        w[n] = _mm_loadu_pd(tw.add((n - 1) * 2));
    }
    let mut out = [_mm_setzero_pd(); 8];
    for (k, slot) in out.iter_mut().enumerate() {
        let mut acc = _mm_setzero_pd();
        for (n, &xn) in x.iter().enumerate() {
            // Exponents wrap modulo 8 because W^8 == 1.
            acc = _mm_add_pd(acc, cmul_pd(w[(n * k) % 8], xn));
        }
        *slot = acc;
    }
    let dst = a as *mut f64;
    for (k, &y) in out.iter().enumerate() {
        _mm_storeu_pd(dst.add(k * 2), y);
    }
}
/// 4-point butterfly entry point: picks the SIMD kernel when the CPU
/// supports it, otherwise falls back to the portable scalar version.
pub fn radix4_butterfly_dispatch(a: &mut [Complex64; 4], twiddles: &[Complex64; 3]) {
    if !is_avx512_available() {
        radix4_butterfly_scalar(a, twiddles);
        return;
    }
    // SAFETY: `is_avx512_available()` confirmed the kernel's required CPU
    // features, and the fixed-size array references are valid for exactly
    // 4 data and 3 twiddle elements.
    unsafe { radix4_butterfly_avx512(a.as_mut_ptr(), twiddles.as_ptr()) }
}
/// 8-point butterfly entry point: SIMD kernel when available, scalar otherwise.
pub fn radix8_butterfly_dispatch(a: &mut [Complex64; 8], twiddles: &[Complex64; 7]) {
    if !is_avx512_available() {
        radix8_butterfly_scalar(a, twiddles);
        return;
    }
    // SAFETY: `is_avx512_available()` confirmed the kernel's required CPU
    // features, and the fixed-size array references are valid for exactly
    // 8 data and 7 twiddle elements.
    unsafe { radix8_butterfly_avx512(a.as_mut_ptr(), twiddles.as_ptr()) }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Largest element-wise complex distance between two equal-length slices.
    fn max_err(a: &[Complex64], b: &[Complex64]) -> f64 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).norm())
            .fold(0.0_f64, f64::max)
    }

    /// Scalar radix-4 kernel against a hand-computed 4-point DFT of [1,2,3,4].
    #[test]
    fn test_scalar_radix4_matches_known() {
        let input = [
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(3.0, 0.0),
            Complex64::new(4.0, 0.0),
        ];
        // Exact radix-4 twiddles W, W^2, W^3 for W = e^{-2*pi*i/4} = -i.
        let twiddles = [
            Complex64::new(0.0, -1.0),
            Complex64::new(-1.0, 0.0),
            Complex64::new(0.0, 1.0),
        ];
        let mut data = input;
        radix4_butterfly_scalar(&mut data, &twiddles);
        assert!((data[0] - Complex64::new(10.0, 0.0)).norm() < 1e-12);
        assert!((data[1] - Complex64::new(-2.0, 2.0)).norm() < 1e-12);
        assert!((data[2] - Complex64::new(-2.0, 0.0)).norm() < 1e-12);
        assert!((data[3] - Complex64::new(-2.0, -2.0)).norm() < 1e-12);
    }

    /// Scalar radix-8 kernel against a naive O(N^2) DFT computed inline.
    #[test]
    fn test_scalar_radix8_matches_direct_dft() {
        let input: [Complex64; 8] = std::array::from_fn(|k| Complex64::new(k as f64 + 1.0, 0.0));
        // Twiddles W^1..W^7 for W = e^{-2*pi*i/8}.
        let twiddles: [Complex64; 7] = std::array::from_fn(|k| {
            let angle = -2.0 * PI * (k + 1) as f64 / 8.0;
            Complex64::new(angle.cos(), angle.sin())
        });
        let mut data = input;
        radix8_butterfly_scalar(&mut data, &twiddles);
        // Reference: textbook DFT summation, no twiddle table.
        let mut reference = [Complex64::new(0.0, 0.0); 8];
        for k in 0..8 {
            for n in 0..8 {
                let angle = -2.0 * PI * (n * k) as f64 / 8.0;
                reference[k] += input[n] * Complex64::new(angle.cos(), angle.sin());
            }
        }
        let err = max_err(&data, &reference);
        assert!(err < 1e-10, "radix8 scalar err={err}");
    }

    /// SIMD radix-4 kernel must agree with the scalar kernel bit-closely.
    /// Skips the comparison (compile-check only) when the host CPU lacks
    /// the required features.
    #[test]
    fn test_avx512_radix4_matches_scalar() {
        let twiddles = [
            Complex64::new(0.0, -1.0),
            Complex64::new(-1.0, 0.0),
            Complex64::new(0.0, 1.0),
        ];
        let input = [
            Complex64::new(1.0, 2.0),
            Complex64::new(3.0, 4.0),
            Complex64::new(5.0, 6.0),
            Complex64::new(7.0, 8.0),
        ];
        let mut ref_data = input;
        radix4_butterfly_scalar(&mut ref_data, &twiddles);
        if is_avx512_available() {
            let mut avx_data = input;
            unsafe {
                radix4_butterfly_avx512(avx_data.as_mut_ptr(), twiddles.as_ptr());
            }
            let err = max_err(&ref_data, &avx_data);
            assert!(
                err < 1e-12,
                "AVX-512 radix-4 diverges from scalar by {err}: \nscalar={ref_data:?}\navx512={avx_data:?}"
            );
        } else {
            eprintln!("[avx512] AVX-512F not available on this host — compile-check only");
        }
        // Scalar path sanity: no NaN/inf regardless of SIMD availability.
        assert!(ref_data
            .iter()
            .all(|c| c.re.is_finite() && c.im.is_finite()));
    }

    /// SIMD radix-8 kernel must agree with the scalar kernel; skipped on
    /// hosts without the required features.
    #[test]
    fn test_avx512_radix8_matches_scalar() {
        let input: [Complex64; 8] =
            std::array::from_fn(|k| Complex64::new((k as f64 + 1.0) * 0.5, -(k as f64 * 0.3)));
        let twiddles: [Complex64; 7] = std::array::from_fn(|k| {
            let angle = -2.0 * PI * (k + 1) as f64 / 8.0;
            Complex64::new(angle.cos(), angle.sin())
        });
        let mut ref_data = input;
        radix8_butterfly_scalar(&mut ref_data, &twiddles);
        if is_avx512_available() {
            let mut avx_data = input;
            unsafe {
                radix8_butterfly_avx512(avx_data.as_mut_ptr(), twiddles.as_ptr());
            }
            let err = max_err(&ref_data, &avx_data);
            assert!(err < 1e-12, "AVX-512 radix-8 diverges from scalar by {err}");
        } else {
            eprintln!("[avx512] AVX-512F not available on this host — compile-check only");
        }
        // Scalar path sanity: no NaN/inf regardless of SIMD availability.
        assert!(ref_data
            .iter()
            .all(|c| c.re.is_finite() && c.im.is_finite()));
    }

    /// Dispatch wrapper must match the scalar kernel whichever path it takes.
    #[test]
    fn test_dispatch_radix4_agrees_with_scalar() {
        let twiddles = [
            Complex64::new(0.0, -1.0),
            Complex64::new(-1.0, 0.0),
            Complex64::new(0.0, 1.0),
        ];
        let input = [
            Complex64::new(2.0, -1.0),
            Complex64::new(0.5, 3.0),
            Complex64::new(-1.0, 1.0),
            Complex64::new(4.0, -2.0),
        ];
        let mut ref_data = input;
        radix4_butterfly_scalar(&mut ref_data, &twiddles);
        let mut dispatch_data = input;
        radix4_butterfly_dispatch(&mut dispatch_data, &twiddles);
        let err = max_err(&ref_data, &dispatch_data);
        assert!(err < 1e-12, "dispatch vs scalar err={err}");
    }

    /// Dispatch wrapper must match the scalar kernel whichever path it takes.
    #[test]
    fn test_dispatch_radix8_agrees_with_scalar() {
        let input: [Complex64; 8] =
            std::array::from_fn(|k| Complex64::new(k as f64 * 0.7 - 1.0, k as f64 * 0.3));
        let twiddles: [Complex64; 7] = std::array::from_fn(|k| {
            let angle = -2.0 * PI * (k + 1) as f64 / 8.0;
            Complex64::new(angle.cos(), angle.sin())
        });
        let mut ref_data = input;
        radix8_butterfly_scalar(&mut ref_data, &twiddles);
        let mut dispatch_data = input;
        radix8_butterfly_dispatch(&mut dispatch_data, &twiddles);
        let err = max_err(&ref_data, &dispatch_data);
        assert!(err < 1e-12, "dispatch radix-8 vs scalar err={err}");
    }
}