#![allow(clippy::items_after_statements)]
#![allow(clippy::large_stack_arrays)]
#![allow(clippy::too_many_lines)]
#[cfg(target_arch = "x86_64")]
use super::hand_avx512_twiddles::{
twiddles_16_f32, twiddles_16_f64, twiddles_32_f32, twiddles_32_f64, twiddles_64_f32,
twiddles_64_f64,
};
#[cfg(target_arch = "x86_64")]
use crate::kernel::Complex;
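/// Bit-reversed index order for a 16-point FFT (4-bit reversal of 0..16).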
#[cfg(target_arch = "x86_64")]
const BR16: [usize; 16] = [0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15];
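/// Bit-reversed index order for a 32-point FFT.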
#[cfg(target_arch = "x86_64")]
const BR32: [usize; 32] = [
0, 16, 8, 24, 4, 20, 12, 28, 2, 18, 10, 26, 6, 22, 14, 30, 1, 17, 9, 25, 5, 21, 13, 29, 3, 19,
11, 27, 7, 23, 15, 31,
];
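/// Bit-reversed index order for a 64-point FFT.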
#[cfg(target_arch = "x86_64")]
const BR64: [usize; 64] = [
0, 32, 16, 48, 8, 40, 24, 56, 4, 36, 20, 52, 12, 44, 28, 60, 2, 34, 18, 50, 10, 42, 26, 58, 6,
38, 22, 54, 14, 46, 30, 62, 1, 33, 17, 49, 9, 41, 25, 57, 5, 37, 21, 53, 13, 45, 29, 61, 3, 35,
19, 51, 11, 43, 27, 59, 7, 39, 23, 55, 15, 47, 31, 63,
];
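/// Multiplies the complex value `re + i*im` by the twiddle factor `c + i*s`.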
#[cfg(target_arch = "x86_64")]
#[inline(always)]
fn cmul_f64(re: f64, im: f64, c: f64, s: f64) -> (f64, f64) {
(re * c - im * s, re * s + im * c)
}
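/// `f32` variant of [`cmul_f64`].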
#[cfg(target_arch = "x86_64")]
#[inline(always)]
fn cmul_f32(re: f32, im: f32, c: f32, s: f32) -> (f32, f32) {
(re * c - im * s, re * s + im * c)
}
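/// In-place 16-point complex FFT over `f64`.
///
/// Uses AVX-512 registers for the first butterfly stage, then finishes the
/// remaining stages with scalar arithmetic. Selects the `fwd` twiddle table
/// when `sign < 0` and the `inv` table otherwise.
///
/// # Safety
/// `data` must point to 16 initialized `Complex<f64>` values, and the caller
/// must ensure the CPU supports AVX-512F (see the `dispatch_*` wrappers below).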
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub(crate) unsafe fn hand_avx512_size16_f64(data: *mut Complex<f64>, sign: i32) {
use core::arch::x86_64::*;
let twiddles = if sign < 0 {
&twiddles_16_f64().fwd.0
} else {
&twiddles_16_f64().inv.0
};
let ptr = data.cast::<f64>();
let mut stage = [0.0_f64; 32];
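// Gather the input into the scratch buffer in bit-reversed order.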
for (i, &br_i) in BR16.iter().enumerate() {
unsafe {
let src = ptr.add(br_i * 2);
stage[2 * i] = *src;
stage[2 * i + 1] = *src.add(1);
}
}
let sp = stage.as_ptr();
let (mut z0, mut z1, mut z2, mut z3) = unsafe {
(
_mm512_loadu_pd(sp),
_mm512_loadu_pd(sp.add(8)),
_mm512_loadu_pd(sp.add(16)),
_mm512_loadu_pd(sp.add(24)),
)
};
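// First radix-2 stage in registers: butterfly each adjacent complex pair of every ZMM vector.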
macro_rules! stage1_zmm {
($z:expr) => {{
let even = _mm512_shuffle_f64x2($z, $z, 0x88);
let odd = _mm512_shuffle_f64x2($z, $z, 0xDD);
let sum = _mm512_add_pd(even, odd);
let diff = _mm512_sub_pd(even, odd);
let lo = _mm512_shuffle_f64x2(sum, diff, 0b00_00_00_00);
let hi = _mm512_shuffle_f64x2(sum, diff, 0b01_01_01_01);
_mm512_shuffle_f64x2(lo, hi, 0b10_00_10_00)
}};
}
z0 = stage1_zmm!(z0);
z1 = stage1_zmm!(z1);
z2 = stage1_zmm!(z2);
z3 = stage1_zmm!(z3);
let sp_mut = stage.as_mut_ptr();
unsafe {
_mm512_storeu_pd(sp_mut, z0);
_mm512_storeu_pd(sp_mut.add(8), z1);
_mm512_storeu_pd(sp_mut.add(16), z2);
_mm512_storeu_pd(sp_mut.add(24), z3);
}
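// Remaining stages (span 4, 8, 16) as scalar butterflies with precomputed twiddles.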
let a = stage.as_mut_ptr();
let n = 16usize;
let mut span = 4usize;
while span <= n {
let half = span / 2;
let stride = n / span;
let mut k = 0;
while k < n {
for j in 0..half {
let tw_idx = j * stride;
let c = twiddles[2 * tw_idx];
let s = twiddles[2 * tw_idx + 1];
unsafe {
let u_re = *a.add((k + j) * 2);
let u_im = *a.add((k + j) * 2 + 1);
let v_re_raw = *a.add((k + j + half) * 2);
let v_im_raw = *a.add((k + j + half) * 2 + 1);
let (v_re, v_im) = cmul_f64(v_re_raw, v_im_raw, c, s);
*a.add((k + j) * 2) = u_re + v_re;
*a.add((k + j) * 2 + 1) = u_im + v_im;
*a.add((k + j + half) * 2) = u_re - v_re;
*a.add((k + j + half) * 2 + 1) = u_im - v_im;
}
}
k += span;
}
span *= 2;
}
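// Write the result back to the caller's buffer in natural order.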
let src = stage.as_ptr();
unsafe {
for i in 0..16usize {
*data.add(i) = Complex {
re: *src.add(2 * i),
im: *src.add(2 * i + 1),
};
}
}
#[cfg(test)]
{
use core::sync::atomic::Ordering;
super::hand_avx512_tests::HAND_AVX512_HIT_16_F64.store(true, Ordering::Relaxed);
}
}
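/// In-place 32-point complex FFT over `f64`.
///
/// # Safety
/// Same as [`hand_avx512_size16_f64`], but `data` must point to 32 elements.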
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub(crate) unsafe fn hand_avx512_size32_f64(data: *mut Complex<f64>, sign: i32) {
use core::arch::x86_64::*;
let twiddles = if sign < 0 {
&twiddles_32_f64().fwd.0
} else {
&twiddles_32_f64().inv.0
};
let ptr = data.cast::<f64>();
let mut stage = [0.0_f64; 64];
for (i, &br_i) in BR32.iter().enumerate() {
unsafe {
let src = ptr.add(br_i * 2);
stage[2 * i] = *src;
stage[2 * i + 1] = *src.add(1);
}
}
let sp = stage.as_ptr();
let mut z = [_mm512_setzero_pd(); 8];
for (i, zi) in z.iter_mut().enumerate() {
*zi = unsafe { _mm512_loadu_pd(sp.add(i * 8)) };
}
macro_rules! stage1_zmm {
($z:expr) => {{
let even = _mm512_shuffle_f64x2($z, $z, 0x88);
let odd = _mm512_shuffle_f64x2($z, $z, 0xDD);
let sum = _mm512_add_pd(even, odd);
let diff = _mm512_sub_pd(even, odd);
let lo = _mm512_shuffle_f64x2(sum, diff, 0b00_00_00_00);
let hi = _mm512_shuffle_f64x2(sum, diff, 0b01_01_01_01);
_mm512_shuffle_f64x2(lo, hi, 0b10_00_10_00)
}};
}
for zi in &mut z {
*zi = stage1_zmm!(*zi);
}
let sp_mut = stage.as_mut_ptr();
for (i, zi) in z.iter().enumerate() {
unsafe { _mm512_storeu_pd(sp_mut.add(i * 8), *zi) };
}
let a = stage.as_mut_ptr();
let n = 32usize;
let mut span = 4usize;
while span <= n {
let half = span / 2;
let stride = n / span;
let mut k = 0;
while k < n {
for j in 0..half {
let tw_idx = j * stride;
let c = twiddles[2 * tw_idx];
let s = twiddles[2 * tw_idx + 1];
unsafe {
let u_re = *a.add((k + j) * 2);
let u_im = *a.add((k + j) * 2 + 1);
let v_re_raw = *a.add((k + j + half) * 2);
let v_im_raw = *a.add((k + j + half) * 2 + 1);
let (v_re, v_im) = cmul_f64(v_re_raw, v_im_raw, c, s);
*a.add((k + j) * 2) = u_re + v_re;
*a.add((k + j) * 2 + 1) = u_im + v_im;
*a.add((k + j + half) * 2) = u_re - v_re;
*a.add((k + j + half) * 2 + 1) = u_im - v_im;
}
}
k += span;
}
span *= 2;
}
let src = stage.as_ptr();
unsafe {
for i in 0..32usize {
*data.add(i) = Complex {
re: *src.add(2 * i),
im: *src.add(2 * i + 1),
};
}
}
#[cfg(test)]
{
use core::sync::atomic::Ordering;
super::hand_avx512_tests::HAND_AVX512_HIT_32_F64.store(true, Ordering::Relaxed);
}
}
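/// In-place 64-point complex FFT over `f64`.
///
/// # Safety
/// Same as [`hand_avx512_size16_f64`], but `data` must point to 64 elements.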
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub(crate) unsafe fn hand_avx512_size64_f64(data: *mut Complex<f64>, sign: i32) {
use core::arch::x86_64::*;
let twiddles = if sign < 0 {
&twiddles_64_f64().fwd.0
} else {
&twiddles_64_f64().inv.0
};
let ptr = data.cast::<f64>();
let mut stage = [0.0_f64; 128];
for (i, &br_i) in BR64.iter().enumerate() {
unsafe {
let src = ptr.add(br_i * 2);
stage[2 * i] = *src;
stage[2 * i + 1] = *src.add(1);
}
}
let sp = stage.as_ptr();
let mut z = [_mm512_setzero_pd(); 16];
for (i, zi) in z.iter_mut().enumerate() {
*zi = unsafe { _mm512_loadu_pd(sp.add(i * 8)) };
}
macro_rules! stage1_zmm {
($z:expr) => {{
let even = _mm512_shuffle_f64x2($z, $z, 0x88);
let odd = _mm512_shuffle_f64x2($z, $z, 0xDD);
let sum = _mm512_add_pd(even, odd);
let diff = _mm512_sub_pd(even, odd);
let lo = _mm512_shuffle_f64x2(sum, diff, 0b00_00_00_00);
let hi = _mm512_shuffle_f64x2(sum, diff, 0b01_01_01_01);
_mm512_shuffle_f64x2(lo, hi, 0b10_00_10_00)
}};
}
for zi in &mut z {
*zi = stage1_zmm!(*zi);
}
let sp_mut = stage.as_mut_ptr();
for (i, zi) in z.iter().enumerate() {
unsafe { _mm512_storeu_pd(sp_mut.add(i * 8), *zi) };
}
let a = stage.as_mut_ptr();
let n = 64usize;
let mut span = 4usize;
while span <= n {
let half = span / 2;
let stride = n / span;
let mut k = 0;
while k < n {
for j in 0..half {
let tw_idx = j * stride;
let c = twiddles[2 * tw_idx];
let s = twiddles[2 * tw_idx + 1];
unsafe {
let u_re = *a.add((k + j) * 2);
let u_im = *a.add((k + j) * 2 + 1);
let v_re_raw = *a.add((k + j + half) * 2);
let v_im_raw = *a.add((k + j + half) * 2 + 1);
let (v_re, v_im) = cmul_f64(v_re_raw, v_im_raw, c, s);
*a.add((k + j) * 2) = u_re + v_re;
*a.add((k + j) * 2 + 1) = u_im + v_im;
*a.add((k + j + half) * 2) = u_re - v_re;
*a.add((k + j + half) * 2 + 1) = u_im - v_im;
}
}
k += span;
}
span *= 2;
}
let src = stage.as_ptr();
unsafe {
for i in 0..64usize {
*data.add(i) = Complex {
re: *src.add(2 * i),
im: *src.add(2 * i + 1),
};
}
}
#[cfg(test)]
{
use core::sync::atomic::Ordering;
super::hand_avx512_tests::HAND_AVX512_HIT_64_F64.store(true, Ordering::Relaxed);
}
}
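/// In-place 16-point complex FFT over `f32`.
///
/// Uses AVX-512 registers for the first butterfly stage, then finishes the
/// remaining stages with scalar arithmetic. Selects the `fwd` twiddle table
/// when `sign < 0` and the `inv` table otherwise.
///
/// # Safety
/// `data` must point to 16 initialized `Complex<f32>` values, and the caller
/// must ensure the CPU supports AVX-512F.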
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub(crate) unsafe fn hand_avx512_size16_f32(data: *mut Complex<f32>, sign: i32) {
use core::arch::x86_64::*;
let twiddles = if sign < 0 {
&twiddles_16_f32().fwd.0
} else {
&twiddles_16_f32().inv.0
};
let ptr = data.cast::<f32>();
let mut stage = [0.0_f32; 32];
for (i, &br_i) in BR16.iter().enumerate() {
unsafe {
let src = ptr.add(br_i * 2);
stage[2 * i] = *src;
stage[2 * i + 1] = *src.add(1);
}
}
let sp = stage.as_ptr();
let (mut z0, mut z1) = unsafe { (_mm512_loadu_ps(sp), _mm512_loadu_ps(sp.add(16))) };
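// For f32, each 128-bit lane holds two complex values, so the first radix-2
// stage uses in-lane `shuffle_ps` butterflies.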
macro_rules! stage1_zmm_f32 {
($z:expr) => {{
let a_dup = _mm512_shuffle_ps($z, $z, 0b01_00_01_00);
let b_dup = _mm512_shuffle_ps($z, $z, 0b11_10_11_10);
let sum = _mm512_add_ps(a_dup, b_dup);
let diff = _mm512_sub_ps(a_dup, b_dup);
_mm512_shuffle_ps(sum, diff, 0b01_00_01_00)
}};
}
z0 = stage1_zmm_f32!(z0);
z1 = stage1_zmm_f32!(z1);
let sp_mut = stage.as_mut_ptr();
unsafe {
_mm512_storeu_ps(sp_mut, z0);
_mm512_storeu_ps(sp_mut.add(16), z1);
}
let a = stage.as_mut_ptr();
let n = 16usize;
let mut span = 4usize;
while span <= n {
let half = span / 2;
let stride = n / span;
let mut k = 0;
while k < n {
for j in 0..half {
let tw_idx = j * stride;
let c = twiddles[2 * tw_idx];
let s = twiddles[2 * tw_idx + 1];
unsafe {
let u_re = *a.add((k + j) * 2);
let u_im = *a.add((k + j) * 2 + 1);
let v_re_raw = *a.add((k + j + half) * 2);
let v_im_raw = *a.add((k + j + half) * 2 + 1);
let (v_re, v_im) = cmul_f32(v_re_raw, v_im_raw, c, s);
*a.add((k + j) * 2) = u_re + v_re;
*a.add((k + j) * 2 + 1) = u_im + v_im;
*a.add((k + j + half) * 2) = u_re - v_re;
*a.add((k + j + half) * 2 + 1) = u_im - v_im;
}
}
k += span;
}
span *= 2;
}
let src = stage.as_ptr();
unsafe {
for i in 0..16usize {
*data.add(i) = Complex {
re: *src.add(2 * i),
im: *src.add(2 * i + 1),
};
}
}
#[cfg(test)]
{
use core::sync::atomic::Ordering;
super::hand_avx512_tests::HAND_AVX512_HIT_16_F32.store(true, Ordering::Relaxed);
}
}
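/// In-place 32-point complex FFT over `f32`.
///
/// # Safety
/// Same as [`hand_avx512_size16_f32`], but `data` must point to 32 elements.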
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub(crate) unsafe fn hand_avx512_size32_f32(data: *mut Complex<f32>, sign: i32) {
use core::arch::x86_64::*;
let twiddles = if sign < 0 {
&twiddles_32_f32().fwd.0
} else {
&twiddles_32_f32().inv.0
};
let ptr = data.cast::<f32>();
let mut stage = [0.0_f32; 64];
for (i, &br_i) in BR32.iter().enumerate() {
unsafe {
let src = ptr.add(br_i * 2);
stage[2 * i] = *src;
stage[2 * i + 1] = *src.add(1);
}
}
let sp = stage.as_ptr();
let mut z = [_mm512_setzero_ps(); 4];
for (i, zi) in z.iter_mut().enumerate() {
*zi = unsafe { _mm512_loadu_ps(sp.add(i * 16)) };
}
macro_rules! stage1_zmm_f32 {
($z:expr) => {{
let a_dup = _mm512_shuffle_ps($z, $z, 0b01_00_01_00);
let b_dup = _mm512_shuffle_ps($z, $z, 0b11_10_11_10);
let sum = _mm512_add_ps(a_dup, b_dup);
let diff = _mm512_sub_ps(a_dup, b_dup);
_mm512_shuffle_ps(sum, diff, 0b01_00_01_00)
}};
}
for zi in &mut z {
*zi = stage1_zmm_f32!(*zi);
}
let sp_mut = stage.as_mut_ptr();
for (i, zi) in z.iter().enumerate() {
unsafe { _mm512_storeu_ps(sp_mut.add(i * 16), *zi) };
}
let a = stage.as_mut_ptr();
let n = 32usize;
let mut span = 4usize;
while span <= n {
let half = span / 2;
let stride = n / span;
let mut k = 0;
while k < n {
for j in 0..half {
let tw_idx = j * stride;
let c = twiddles[2 * tw_idx];
let s = twiddles[2 * tw_idx + 1];
unsafe {
let u_re = *a.add((k + j) * 2);
let u_im = *a.add((k + j) * 2 + 1);
let v_re_raw = *a.add((k + j + half) * 2);
let v_im_raw = *a.add((k + j + half) * 2 + 1);
let (v_re, v_im) = cmul_f32(v_re_raw, v_im_raw, c, s);
*a.add((k + j) * 2) = u_re + v_re;
*a.add((k + j) * 2 + 1) = u_im + v_im;
*a.add((k + j + half) * 2) = u_re - v_re;
*a.add((k + j + half) * 2 + 1) = u_im - v_im;
}
}
k += span;
}
span *= 2;
}
let src = stage.as_ptr();
unsafe {
for i in 0..32usize {
*data.add(i) = Complex {
re: *src.add(2 * i),
im: *src.add(2 * i + 1),
};
}
}
#[cfg(test)]
{
use core::sync::atomic::Ordering;
super::hand_avx512_tests::HAND_AVX512_HIT_32_F32.store(true, Ordering::Relaxed);
}
}
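/// In-place 64-point complex FFT over `f32`.
///
/// # Safety
/// Same as [`hand_avx512_size16_f32`], but `data` must point to 64 elements.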
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub(crate) unsafe fn hand_avx512_size64_f32(data: *mut Complex<f32>, sign: i32) {
use core::arch::x86_64::*;
let twiddles = if sign < 0 {
&twiddles_64_f32().fwd.0
} else {
&twiddles_64_f32().inv.0
};
let ptr = data.cast::<f32>();
let mut stage = [0.0_f32; 128];
for (i, &br_i) in BR64.iter().enumerate() {
unsafe {
let src = ptr.add(br_i * 2);
stage[2 * i] = *src;
stage[2 * i + 1] = *src.add(1);
}
}
let sp = stage.as_ptr();
let mut z = [_mm512_setzero_ps(); 8];
for (i, zi) in z.iter_mut().enumerate() {
*zi = unsafe { _mm512_loadu_ps(sp.add(i * 16)) };
}
macro_rules! stage1_zmm_f32 {
($z:expr) => {{
let a_dup = _mm512_shuffle_ps($z, $z, 0b01_00_01_00);
let b_dup = _mm512_shuffle_ps($z, $z, 0b11_10_11_10);
let sum = _mm512_add_ps(a_dup, b_dup);
let diff = _mm512_sub_ps(a_dup, b_dup);
_mm512_shuffle_ps(sum, diff, 0b01_00_01_00)
}};
}
for zi in &mut z {
*zi = stage1_zmm_f32!(*zi);
}
let sp_mut = stage.as_mut_ptr();
for (i, zi) in z.iter().enumerate() {
unsafe { _mm512_storeu_ps(sp_mut.add(i * 16), *zi) };
}
let a = stage.as_mut_ptr();
let n = 64usize;
let mut span = 4usize;
while span <= n {
let half = span / 2;
let stride = n / span;
let mut k = 0;
while k < n {
for j in 0..half {
let tw_idx = j * stride;
let c = twiddles[2 * tw_idx];
let s = twiddles[2 * tw_idx + 1];
unsafe {
let u_re = *a.add((k + j) * 2);
let u_im = *a.add((k + j) * 2 + 1);
let v_re_raw = *a.add((k + j + half) * 2);
let v_im_raw = *a.add((k + j + half) * 2 + 1);
let (v_re, v_im) = cmul_f32(v_re_raw, v_im_raw, c, s);
*a.add((k + j) * 2) = u_re + v_re;
*a.add((k + j) * 2 + 1) = u_im + v_im;
*a.add((k + j + half) * 2) = u_re - v_re;
*a.add((k + j + half) * 2 + 1) = u_im - v_im;
}
}
k += span;
}
span *= 2;
}
let src = stage.as_ptr();
unsafe {
for i in 0..64usize {
*data.add(i) = Complex {
re: *src.add(2 * i),
im: *src.add(2 * i + 1),
};
}
}
#[cfg(test)]
{
use core::sync::atomic::Ordering;
super::hand_avx512_tests::HAND_AVX512_HIT_64_F32.store(true, Ordering::Relaxed);
}
}
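/// Runs the 16-point `f64` kernel when AVX-512F is detected at runtime,
/// falling back to `notw_16` otherwise.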
#[cfg(target_arch = "x86_64")]
#[inline]
pub fn dispatch_hand_avx512_size16_f64(data: &mut [Complex<f64>], sign: i32) {
if is_x86_feature_detected!("avx512f") {
unsafe {
hand_avx512_size16_f64(data.as_mut_ptr(), sign);
}
} else {
super::notw_16(data, sign);
}
}
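/// 32-point `f64` dispatcher; falls back to `notw_32` without AVX-512F.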
#[cfg(target_arch = "x86_64")]
#[inline]
pub fn dispatch_hand_avx512_size32_f64(data: &mut [Complex<f64>], sign: i32) {
if is_x86_feature_detected!("avx512f") {
unsafe {
hand_avx512_size32_f64(data.as_mut_ptr(), sign);
}
} else {
super::notw_32(data, sign);
}
}
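/// 64-point `f64` dispatcher; falls back to `notw_64` without AVX-512F.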
#[cfg(target_arch = "x86_64")]
#[inline]
pub fn dispatch_hand_avx512_size64_f64(data: &mut [Complex<f64>], sign: i32) {
if is_x86_feature_detected!("avx512f") {
unsafe {
hand_avx512_size64_f64(data.as_mut_ptr(), sign);
}
} else {
super::notw_64(data, sign);
}
}
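/// 16-point `f32` dispatcher; falls back to `notw_16` without AVX-512F.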
#[cfg(target_arch = "x86_64")]
#[inline]
pub fn dispatch_hand_avx512_size16_f32(data: &mut [Complex<f32>], sign: i32) {
if is_x86_feature_detected!("avx512f") {
unsafe {
hand_avx512_size16_f32(data.as_mut_ptr(), sign);
}
} else {
super::notw_16(data, sign);
}
}
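/// 32-point `f32` dispatcher; falls back to `notw_32` without AVX-512F.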
#[cfg(target_arch = "x86_64")]
#[inline]
pub fn dispatch_hand_avx512_size32_f32(data: &mut [Complex<f32>], sign: i32) {
if is_x86_feature_detected!("avx512f") {
unsafe {
hand_avx512_size32_f32(data.as_mut_ptr(), sign);
}
} else {
super::notw_32(data, sign);
}
}
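/// 64-point `f32` dispatcher; falls back to `notw_64` without AVX-512F.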
#[cfg(target_arch = "x86_64")]
#[inline]
pub fn dispatch_hand_avx512_size64_f32(data: &mut [Complex<f32>], sign: i32) {
if is_x86_feature_detected!("avx512f") {
unsafe {
hand_avx512_size64_f32(data.as_mut_ptr(), sign);
}
} else {
super::notw_64(data, sign);
}
}