#![allow(clippy::items_after_statements)]
#![allow(clippy::large_stack_arrays)]
use crate::kernel::Complex;
use crate::prelude::*;
#[inline]
/// In-place 64-point FFT over the first 64 elements of `x`.
/// `sign < 0` runs the forward transform, `sign > 0` the inverse.
pub fn notw_64_simd_f64(x: &mut [Complex<f64>], sign: i32) {
    debug_assert!(x.len() >= 64);
    // Reorder into bit-reversed index order, then run the in-place DIT stages.
    bit_reverse_permute_64(x);
    let head = &mut x[..64];
    dit_64_precomputed(head, sign);
}
#[cfg(target_arch = "aarch64")]
/// 64-point in-place radix-2 DIT FFT over bit-reversed input using NEON
/// butterflies and a lazily initialized per-stage twiddle table.
///
/// Only forward twiddles e^(-i*TAU*j/m) are stored; for `sign > 0` (inverse
/// direction) each twiddle is conjugated on the fly before the butterfly.
/// NOTE(review): the direction test is only `sign > 0` vs. not — assumes
/// callers pass sign = ±1; confirm.
fn dit_64_precomputed(data: &mut [Complex<f64>], sign: i32) {
    use crate::prelude::OnceLock;
    #[cfg(not(feature = "std"))]
    use crate::prelude::OnceLockExt;
    use core::arch::aarch64::*;
    // TWIDDLES[s][j] = e^(-i*TAU*j/(2<<s)); built once, shared across calls.
    static TWIDDLES: OnceLock<[[Complex<f64>; 32]; 6]> = OnceLock::new();
    let twiddles = TWIDDLES.get_or_init(|| {
        let mut tw = [[Complex::new(0.0, 0.0); 32]; 6];
        for s in 0..6 {
            let m = 2usize << s; // butterfly span for stage s: 2, 4, ..., 64
            let half_m = m / 2;
            for j in 0..half_m {
                let angle = -core::f64::consts::TAU * (j as f64) / (m as f64);
                tw[s][j] = Complex::cis(angle);
            }
        }
        tw
    });
    // Reinterpret the complex slice as interleaved (re, im) f64 pairs for NEON.
    // NOTE(review): assumes Complex<f64> is laid out as two consecutive f64s —
    // confirm the type is #[repr(C)] or equivalent.
    let ptr = data.as_mut_ptr() as *mut f64;
    let sign_arr = [-1.0_f64, 1.0]; // lane pattern forming (-im, +im) inside the butterfly
    unsafe {
        let sign_pattern = vld1q_f64(sign_arr.as_ptr());
        let mut m = 2usize; // butterfly span, doubles each stage
        for s in 0..6 {
            let half_m = m / 2;
            let tw_stage = &twiddles[s];
            if sign > 0 {
                // Inverse direction: conjugate the stored forward twiddles.
                for k in (0..64).step_by(m) {
                    for j in 0..half_m {
                        let w = Complex::new(tw_stage[j].re, -tw_stage[j].im);
                        neon_butterfly_inline(ptr, k + j, half_m, w, sign_pattern);
                    }
                }
            } else {
                for k in (0..64).step_by(m) {
                    for j in 0..half_m {
                        neon_butterfly_inline(ptr, k + j, half_m, tw_stage[j], sign_pattern);
                    }
                }
            }
            m *= 2;
        }
    }
}
#[cfg(not(target_arch = "aarch64"))]
/// Portable fallback: delegate the 64-point DIT stages to the generic
/// SIMD-butterfly solver, translating the raw `sign` into its `Sign` enum.
fn dit_64_precomputed(data: &mut [Complex<f64>], sign: i32) {
    use crate::dft::problem::Sign;
    use crate::dft::solvers::simd_butterfly::dit_butterflies_f64;
    // Negative sign means forward transform; anything else is backward.
    dit_butterflies_f64(data, if sign < 0 { Sign::Forward } else { Sign::Backward });
}
#[inline]
/// Permute the first 64 elements into 6-bit bit-reversed index order.
/// Each swap is performed once (only when `i < j`), so the permutation is
/// its own inverse applied exactly one time.
fn bit_reverse_permute_64(x: &mut [Complex<f64>]) {
    // Precomputed 6-bit reversal: BIT_REV_64[i] = reverse_bits_6(i).
    const BIT_REV_64: [usize; 64] = [
        0, 32, 16, 48, 8, 40, 24, 56, 4, 36, 20, 52, 12, 44, 28, 60, 2, 34, 18, 50, 10, 42, 26, 58,
        6, 38, 22, 54, 14, 46, 30, 62, 1, 33, 17, 49, 9, 41, 25, 57, 5, 37, 21, 53, 13, 45, 29, 61,
        3, 35, 19, 51, 11, 43, 27, 59, 7, 39, 23, 55, 15, 47, 31, 63,
    ];
    for (i, &j) in BIT_REV_64.iter().enumerate() {
        if i < j {
            x.swap(i, j);
        }
    }
}
#[inline]
/// In-place 128-point FFT over the first 128 elements of `x`.
/// `sign < 0` runs the forward transform, `sign > 0` the inverse.
pub fn notw_128_simd_f64(x: &mut [Complex<f64>], sign: i32) {
    debug_assert!(x.len() >= 128);
    // Bit-reversal reorder first, then the in-place DIT stages.
    bit_reverse_permute_128(x);
    let head = &mut x[..128];
    dit_128_precomputed(head, sign);
}
#[cfg(target_arch = "aarch64")]
/// 128-point in-place radix-2 DIT FFT over bit-reversed input using NEON.
///
/// Only forward twiddles e^(-i*TAU*j/m) are stored; for `sign > 0` (inverse
/// direction) each twiddle is conjugated on the fly.
/// NOTE(review): assumes callers pass sign = ±1 — confirm.
fn dit_128_precomputed(data: &mut [Complex<f64>], sign: i32) {
    use crate::prelude::OnceLock;
    #[cfg(not(feature = "std"))]
    use crate::prelude::OnceLockExt;
    use core::arch::aarch64::*;
    // TWIDDLES[s][j] = e^(-i*TAU*j/(2<<s)) for the seven stages; built once.
    static TWIDDLES: OnceLock<[[Complex<f64>; 64]; 7]> = OnceLock::new();
    let twiddles = TWIDDLES.get_or_init(|| {
        let mut tw = [[Complex::new(0.0, 0.0); 64]; 7];
        for s in 0..7 {
            let m = 2usize << s;
            let half_m = m / 2;
            for j in 0..half_m {
                let angle = -core::f64::consts::TAU * (j as f64) / (m as f64);
                tw[s][j] = Complex::cis(angle);
            }
        }
        tw
    });
    // Interleaved (re, im) view of the complex buffer.
    // NOTE(review): assumes Complex<f64> is two consecutive f64s in memory.
    let ptr = data.as_mut_ptr() as *mut f64;
    let sign_arr = [-1.0_f64, 1.0];
    unsafe {
        let sign_pattern = vld1q_f64(sign_arr.as_ptr());
        let mut m = 2usize; // butterfly span, doubles each stage
        for s in 0..7 {
            let half_m = m / 2;
            let tw_stage = &twiddles[s];
            if half_m >= 4 {
                // Later stages: 4x-unrolled butterfly loop plus a scalar tail.
                if sign > 0 {
                    // Inverse direction: conjugate each stored twiddle.
                    for k in (0..128).step_by(m) {
                        let mut j = 0;
                        while j + 3 < half_m {
                            let w0 = Complex::new(tw_stage[j].re, -tw_stage[j].im);
                            let w1 = Complex::new(tw_stage[j + 1].re, -tw_stage[j + 1].im);
                            let w2 = Complex::new(tw_stage[j + 2].re, -tw_stage[j + 2].im);
                            let w3 = Complex::new(tw_stage[j + 3].re, -tw_stage[j + 3].im);
                            neon_butterfly_inline(ptr, k + j, half_m, w0, sign_pattern);
                            neon_butterfly_inline(ptr, k + j + 1, half_m, w1, sign_pattern);
                            neon_butterfly_inline(ptr, k + j + 2, half_m, w2, sign_pattern);
                            neon_butterfly_inline(ptr, k + j + 3, half_m, w3, sign_pattern);
                            j += 4;
                        }
                        while j < half_m {
                            let w = Complex::new(tw_stage[j].re, -tw_stage[j].im);
                            neon_butterfly_inline(ptr, k + j, half_m, w, sign_pattern);
                            j += 1;
                        }
                    }
                } else {
                    for k in (0..128).step_by(m) {
                        let mut j = 0;
                        while j + 3 < half_m {
                            neon_butterfly_inline(ptr, k + j, half_m, tw_stage[j], sign_pattern);
                            neon_butterfly_inline(
                                ptr,
                                k + j + 1,
                                half_m,
                                tw_stage[j + 1],
                                sign_pattern,
                            );
                            neon_butterfly_inline(
                                ptr,
                                k + j + 2,
                                half_m,
                                tw_stage[j + 2],
                                sign_pattern,
                            );
                            neon_butterfly_inline(
                                ptr,
                                k + j + 3,
                                half_m,
                                tw_stage[j + 3],
                                sign_pattern,
                            );
                            j += 4;
                        }
                        while j < half_m {
                            neon_butterfly_inline(ptr, k + j, half_m, tw_stage[j], sign_pattern);
                            j += 1;
                        }
                    }
                }
            } else {
                // Early stages (half_m < 4): plain butterfly loops.
                if sign > 0 {
                    for k in (0..128).step_by(m) {
                        for j in 0..half_m {
                            let w = Complex::new(tw_stage[j].re, -tw_stage[j].im);
                            neon_butterfly_inline(ptr, k + j, half_m, w, sign_pattern);
                        }
                    }
                } else {
                    for k in (0..128).step_by(m) {
                        for j in 0..half_m {
                            neon_butterfly_inline(ptr, k + j, half_m, tw_stage[j], sign_pattern);
                        }
                    }
                }
            }
            m *= 2;
        }
    }
}
#[cfg(not(target_arch = "aarch64"))]
/// Portable fallback: delegate the 128-point DIT stages to the generic
/// SIMD-butterfly solver, translating the raw `sign` into its `Sign` enum.
fn dit_128_precomputed(data: &mut [Complex<f64>], sign: i32) {
    use crate::dft::problem::Sign;
    use crate::dft::solvers::simd_butterfly::dit_butterflies_f64;
    // Negative sign means forward transform; anything else is backward.
    dit_butterflies_f64(data, if sign < 0 { Sign::Forward } else { Sign::Backward });
}
#[inline]
/// Permute the first 128 elements into 7-bit bit-reversed index order.
/// Swapping only when `i < j` visits each transposition exactly once.
fn bit_reverse_permute_128(x: &mut [Complex<f64>]) {
    // Precomputed 7-bit reversal: BIT_REV_128[i] = reverse_bits_7(i).
    const BIT_REV_128: [usize; 128] = [
        0, 64, 32, 96, 16, 80, 48, 112, 8, 72, 40, 104, 24, 88, 56, 120, 4, 68, 36, 100, 20, 84,
        52, 116, 12, 76, 44, 108, 28, 92, 60, 124, 2, 66, 34, 98, 18, 82, 50, 114, 10, 74, 42, 106,
        26, 90, 58, 122, 6, 70, 38, 102, 22, 86, 54, 118, 14, 78, 46, 110, 30, 94, 62, 126, 1, 65,
        33, 97, 17, 81, 49, 113, 9, 73, 41, 105, 25, 89, 57, 121, 5, 69, 37, 101, 21, 85, 53, 117,
        13, 77, 45, 109, 29, 93, 61, 125, 3, 67, 35, 99, 19, 83, 51, 115, 11, 75, 43, 107, 27, 91,
        59, 123, 7, 71, 39, 103, 23, 87, 55, 119, 15, 79, 47, 111, 31, 95, 63, 127,
    ];
    for (i, &j) in BIT_REV_128.iter().enumerate() {
        if i < j {
            x.swap(i, j);
        }
    }
}
#[inline]
/// In-place 256-point FFT over the first 256 elements of `x`.
/// `sign < 0` runs the forward transform, `sign > 0` the inverse.
pub fn notw_256_simd_f64(x: &mut [Complex<f64>], sign: i32) {
    debug_assert!(x.len() >= 256);
    // Bit-reversal reorder first, then the in-place DIT stages.
    bit_reverse_permute_256(x);
    let head = &mut x[..256];
    dit_256_precomputed(head, sign);
}
#[cfg(target_arch = "aarch64")]
/// Per-stage twiddle tables for the 256-point transform, stored as
/// interleaved [re, im] pairs so they can be loaded directly into NEON lanes.
struct TwiddlesF64_256 {
    forward: [[[f64; 2]; 128]; 8],
    inverse: [[[f64; 2]; 128]; 8],
}
#[cfg(target_arch = "aarch64")]
impl TwiddlesF64_256 {
    /// Builds forward twiddles e^(-i*TAU*j/m) and their conjugates (inverse
    /// direction) for each of the eight stages. Stage `s` has span
    /// `m = 2 << s`; only the first m/2 slots of each row are meaningful.
    fn new() -> Self {
        let mut forward = [[[-0.0_f64; 2]; 128]; 8];
        let mut inverse = [[[-0.0_f64; 2]; 128]; 8];
        for stage in 0..8 {
            let span = 2usize << stage;
            for j in 0..span / 2 {
                let theta = -core::f64::consts::TAU * (j as f64) / (span as f64);
                let (sin_t, cos_t) = theta.sin_cos();
                forward[stage][j] = [cos_t, sin_t];
                inverse[stage][j] = [cos_t, -sin_t];
            }
        }
        Self { forward, inverse }
    }
}
#[cfg(target_arch = "aarch64")]
/// 256-point in-place radix-2 DIT FFT over bit-reversed input using NEON.
///
/// Unlike the 64/128-point variants, direction handling is table-driven:
/// `TwiddlesF64_256` stores both forward and pre-conjugated inverse tables,
/// and `sign` only selects which one to read.
/// NOTE(review): assumes callers pass sign = ±1 — confirm.
fn dit_256_precomputed(data: &mut [Complex<f64>], sign: i32) {
    use crate::prelude::OnceLock;
    #[cfg(not(feature = "std"))]
    use crate::prelude::OnceLockExt;
    use core::arch::aarch64::*;
    static TWIDDLES: OnceLock<TwiddlesF64_256> = OnceLock::new();
    let twiddles = TWIDDLES.get_or_init(TwiddlesF64_256::new);
    // Interleaved (re, im) view of the complex buffer.
    // NOTE(review): assumes Complex<f64> is two consecutive f64s in memory.
    let ptr = data.as_mut_ptr() as *mut f64;
    let sign_arr = [-1.0_f64, 1.0];
    unsafe {
        let sign_pattern = vld1q_f64(sign_arr.as_ptr());
        let tw_table = if sign > 0 {
            &twiddles.inverse
        } else {
            &twiddles.forward
        };
        let mut m = 2usize; // butterfly span, doubles each stage
        for s in 0..8 {
            let half_m = m / 2;
            let tw_stage = &tw_table[s];
            for k in (0..256).step_by(m) {
                let mut j = 0;
                // 4x-unrolled butterflies; the tail loop below also covers the
                // early stages where half_m < 4.
                while j + 3 < half_m {
                    neon_butterfly_fast(ptr, k + j, half_m, tw_stage[j].as_ptr(), sign_pattern);
                    neon_butterfly_fast(
                        ptr,
                        k + j + 1,
                        half_m,
                        tw_stage[j + 1].as_ptr(),
                        sign_pattern,
                    );
                    neon_butterfly_fast(
                        ptr,
                        k + j + 2,
                        half_m,
                        tw_stage[j + 2].as_ptr(),
                        sign_pattern,
                    );
                    neon_butterfly_fast(
                        ptr,
                        k + j + 3,
                        half_m,
                        tw_stage[j + 3].as_ptr(),
                        sign_pattern,
                    );
                    j += 4;
                }
                while j < half_m {
                    neon_butterfly_fast(ptr, k + j, half_m, tw_stage[j].as_ptr(), sign_pattern);
                    j += 1;
                }
            }
            m *= 2;
        }
    }
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
/// One radix-2 DIT butterfly on the interleaved (re, im) f64 view:
/// loads u = data[k_j], v = data[k_j + half_m] and the twiddle w from
/// `tw_ptr`, computes t = v * w, then stores u + t and u - t back.
///
/// # Safety
/// `ptr` must be valid for reads/writes of the two complex slots at indices
/// `k_j` and `k_j + half_m`; `tw_ptr` must point to two readable f64s
/// (re, im); `sign_pattern` must hold [-1.0, 1.0].
unsafe fn neon_butterfly_fast(
    ptr: *mut f64,
    k_j: usize,
    half_m: usize,
    tw_ptr: *const f64,
    sign_pattern: core::arch::aarch64::float64x2_t,
) {
    use core::arch::aarch64::*;
    unsafe {
        let u_ptr = ptr.add(k_j * 2);
        let v_ptr = ptr.add((k_j + half_m) * 2);
        let u = vld1q_f64(u_ptr);
        let v = vld1q_f64(v_ptr);
        let tw = vld1q_f64(tw_ptr); // (w.re, w.im)
        let tw_flip = vextq_f64(tw, tw, 1); // (w.im, w.re)
        let v_re = vdupq_laneq_f64::<0>(v);
        let v_im = vdupq_laneq_f64::<1>(v);
        // Complex multiply t = v * w via two products and a signed FMA:
        //   prod1 = (v.re*w.re, v.re*w.im)
        //   prod2 = (v.im*w.im, v.im*w.re)
        //   t = prod1 + sign_pattern * prod2
        //     = (v.re*w.re - v.im*w.im, v.re*w.im + v.im*w.re)
        let prod1 = vmulq_f64(v_re, tw);
        let prod2 = vmulq_f64(v_im, tw_flip);
        let t = vfmaq_f64(prod1, prod2, sign_pattern);
        let out_u = vaddq_f64(u, t);
        let out_v = vsubq_f64(u, t);
        vst1q_f64(u_ptr, out_u);
        vst1q_f64(v_ptr, out_v);
    }
}
#[cfg(not(target_arch = "aarch64"))]
/// Portable fallback: delegate the 256-point DIT stages to the generic
/// SIMD-butterfly solver, translating the raw `sign` into its `Sign` enum.
fn dit_256_precomputed(data: &mut [Complex<f64>], sign: i32) {
    use crate::dft::problem::Sign;
    use crate::dft::solvers::simd_butterfly::dit_butterflies_f64;
    // Negative sign means forward transform; anything else is backward.
    dit_butterflies_f64(data, if sign < 0 { Sign::Forward } else { Sign::Backward });
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
/// Same radix-2 DIT butterfly as `neon_butterfly_fast`, but the twiddle is
/// passed by value as a `Complex<f64>` and loaded through a pointer cast.
///
/// # Safety
/// `ptr` must be valid for reads/writes of the two complex slots at indices
/// `k_j` and `k_j + half_m`; `sign_pattern` must hold [-1.0, 1.0].
unsafe fn neon_butterfly_inline(
    ptr: *mut f64,
    k_j: usize,
    half_m: usize,
    w: Complex<f64>,
    sign_pattern: core::arch::aarch64::float64x2_t,
) {
    use core::arch::aarch64::*;
    unsafe {
        let u_ptr = ptr.add(k_j * 2);
        let v_ptr = ptr.add((k_j + half_m) * 2);
        let u = vld1q_f64(u_ptr);
        let v = vld1q_f64(v_ptr);
        // NOTE(review): reading w through a *const f64 assumes Complex<f64>
        // is exactly (re, im) as two consecutive f64s — confirm #[repr(C)].
        let tw_ptr = core::ptr::from_ref(&w) as *const f64;
        let tw = vld1q_f64(tw_ptr); // (w.re, w.im)
        let tw_flip = vextq_f64(tw, tw, 1); // (w.im, w.re)
        let v_re = vdupq_laneq_f64::<0>(v);
        let v_im = vdupq_laneq_f64::<1>(v);
        // t = v * w: (v.re*w.re - v.im*w.im, v.re*w.im + v.im*w.re),
        // with the subtraction produced by the [-1, 1] sign_pattern FMA.
        let prod1 = vmulq_f64(v_re, tw);
        let prod2 = vmulq_f64(v_im, tw_flip);
        let t = vfmaq_f64(prod1, prod2, sign_pattern);
        let out_u = vaddq_f64(u, t);
        let out_v = vsubq_f64(u, t);
        vst1q_f64(u_ptr, out_u);
        vst1q_f64(v_ptr, out_v);
    }
}
#[inline]
/// Permute the first 256 elements into 8-bit bit-reversed index order.
fn bit_reverse_permute_256(x: &mut [Complex<f64>]) {
    // 8-bit reversal table, evaluated at compile time (const block requires
    // `while` loops rather than iterators).
    static BIT_REV_TABLE: [u8; 256] = {
        let mut t = [0u8; 256];
        let mut idx = 0;
        while idx < 256 {
            let mut v = idx as u8;
            let mut r = 0u8;
            let mut b = 0;
            while b < 8 {
                r = (r << 1) | (v & 1);
                v >>= 1;
                b += 1;
            }
            t[idx] = r;
            idx += 1;
        }
        t
    };
    for (i, &rev) in BIT_REV_TABLE.iter().enumerate() {
        let j = rev as usize;
        // Swap each pair exactly once.
        if i < j {
            x.swap(i, j);
        }
    }
}
#[inline]
/// In-place 512-point FFT over the first 512 elements of `x`.
/// `sign < 0` runs the forward transform, `sign > 0` the inverse.
pub fn notw_512_simd_f64(x: &mut [Complex<f64>], sign: i32) {
    debug_assert!(x.len() >= 512);
    // Bit-reversal reorder first, then the in-place DIT stages.
    bit_reverse_permute_512(x);
    let head = &mut x[..512];
    dit_512_precomputed(head, sign);
}
#[cfg(target_arch = "aarch64")]
/// 512-point in-place radix-2 DIT FFT over bit-reversed input using NEON.
///
/// Only forward twiddles e^(-i*TAU*j/m) are stored; for `sign > 0` (inverse
/// direction) each twiddle is conjugated on the fly.
/// NOTE(review): assumes callers pass sign = ±1 — confirm.
fn dit_512_precomputed(data: &mut [Complex<f64>], sign: i32) {
    use crate::prelude::OnceLock;
    #[cfg(not(feature = "std"))]
    use crate::prelude::OnceLockExt;
    use core::arch::aarch64::*;
    // TWIDDLES[s][j] = e^(-i*TAU*j/(2<<s)) for the nine stages; built once.
    static TWIDDLES: OnceLock<[[Complex<f64>; 256]; 9]> = OnceLock::new();
    let twiddles = TWIDDLES.get_or_init(|| {
        let mut tw = [[Complex::new(0.0, 0.0); 256]; 9];
        for s in 0..9 {
            let m = 2usize << s;
            let half_m = m / 2;
            for j in 0..half_m {
                let angle = -core::f64::consts::TAU * (j as f64) / (m as f64);
                tw[s][j] = Complex::cis(angle);
            }
        }
        tw
    });
    // Interleaved (re, im) view of the complex buffer.
    // NOTE(review): assumes Complex<f64> is two consecutive f64s in memory.
    let ptr = data.as_mut_ptr() as *mut f64;
    let sign_arr = [-1.0_f64, 1.0];
    unsafe {
        let sign_pattern = vld1q_f64(sign_arr.as_ptr());
        let mut m = 2usize; // butterfly span, doubles each stage
        for s in 0..9 {
            let half_m = m / 2;
            let tw_stage = &twiddles[s];
            if half_m >= 4 {
                // Later stages: 4x-unrolled butterfly loop plus a scalar tail.
                if sign > 0 {
                    // Inverse direction: conjugate each stored twiddle.
                    for k in (0..512).step_by(m) {
                        let mut j = 0;
                        while j + 3 < half_m {
                            let w0 = Complex::new(tw_stage[j].re, -tw_stage[j].im);
                            let w1 = Complex::new(tw_stage[j + 1].re, -tw_stage[j + 1].im);
                            let w2 = Complex::new(tw_stage[j + 2].re, -tw_stage[j + 2].im);
                            let w3 = Complex::new(tw_stage[j + 3].re, -tw_stage[j + 3].im);
                            neon_butterfly_inline(ptr, k + j, half_m, w0, sign_pattern);
                            neon_butterfly_inline(ptr, k + j + 1, half_m, w1, sign_pattern);
                            neon_butterfly_inline(ptr, k + j + 2, half_m, w2, sign_pattern);
                            neon_butterfly_inline(ptr, k + j + 3, half_m, w3, sign_pattern);
                            j += 4;
                        }
                        while j < half_m {
                            let w = Complex::new(tw_stage[j].re, -tw_stage[j].im);
                            neon_butterfly_inline(ptr, k + j, half_m, w, sign_pattern);
                            j += 1;
                        }
                    }
                } else {
                    for k in (0..512).step_by(m) {
                        let mut j = 0;
                        while j + 3 < half_m {
                            neon_butterfly_inline(ptr, k + j, half_m, tw_stage[j], sign_pattern);
                            neon_butterfly_inline(
                                ptr,
                                k + j + 1,
                                half_m,
                                tw_stage[j + 1],
                                sign_pattern,
                            );
                            neon_butterfly_inline(
                                ptr,
                                k + j + 2,
                                half_m,
                                tw_stage[j + 2],
                                sign_pattern,
                            );
                            neon_butterfly_inline(
                                ptr,
                                k + j + 3,
                                half_m,
                                tw_stage[j + 3],
                                sign_pattern,
                            );
                            j += 4;
                        }
                        while j < half_m {
                            neon_butterfly_inline(ptr, k + j, half_m, tw_stage[j], sign_pattern);
                            j += 1;
                        }
                    }
                }
            } else {
                // Early stages (half_m < 4): plain butterfly loops.
                if sign > 0 {
                    for k in (0..512).step_by(m) {
                        for j in 0..half_m {
                            let w = Complex::new(tw_stage[j].re, -tw_stage[j].im);
                            neon_butterfly_inline(ptr, k + j, half_m, w, sign_pattern);
                        }
                    }
                } else {
                    for k in (0..512).step_by(m) {
                        for j in 0..half_m {
                            neon_butterfly_inline(ptr, k + j, half_m, tw_stage[j], sign_pattern);
                        }
                    }
                }
            }
            m *= 2;
        }
    }
}
#[cfg(not(target_arch = "aarch64"))]
/// Portable fallback: delegate the 512-point DIT stages to the generic
/// SIMD-butterfly solver, translating the raw `sign` into its `Sign` enum.
fn dit_512_precomputed(data: &mut [Complex<f64>], sign: i32) {
    use crate::dft::problem::Sign;
    use crate::dft::solvers::simd_butterfly::dit_butterflies_f64;
    // Negative sign means forward transform; anything else is backward.
    dit_butterflies_f64(data, if sign < 0 { Sign::Forward } else { Sign::Backward });
}
#[inline]
/// Permute the first 512 elements into 9-bit bit-reversed index order.
/// A 9-bit reversal is composed from the 8-bit table: the low byte reverses
/// into bits 8..=1 and the top bit drops into bit 0.
fn bit_reverse_permute_512(x: &mut [Complex<f64>]) {
    // 8-bit reversal table, evaluated at compile time.
    static BIT_REV_TABLE: [u8; 256] = {
        let mut t = [0u8; 256];
        let mut idx = 0;
        while idx < 256 {
            let mut v = idx as u8;
            let mut r = 0u8;
            let mut b = 0;
            while b < 8 {
                r = (r << 1) | (v & 1);
                v >>= 1;
                b += 1;
            }
            t[idx] = r;
            idx += 1;
        }
        t
    };
    for i in 0..512 {
        let rev_low = BIT_REV_TABLE[i & 0xFF] as usize;
        let top_bit = (i >> 8) & 0x01;
        let j = (rev_low << 1) | top_bit;
        if i < j {
            x.swap(i, j);
        }
    }
}
#[inline]
/// In-place 1024-point FFT over the first 1024 elements of `x`.
/// `sign < 0` runs the forward transform, `sign > 0` the inverse.
pub fn notw_1024_simd_f64(x: &mut [Complex<f64>], sign: i32) {
    debug_assert!(x.len() >= 1024);
    // Bit-reversal reorder first, then the in-place DIT stages.
    bit_reverse_permute_1024(x);
    let head = &mut x[..1024];
    dit_1024_precomputed(head, sign);
}
#[cfg(target_arch = "aarch64")]
/// Per-stage twiddle tables for the 1024-point transform, stored as
/// interleaved [re, im] pairs so they can be loaded directly into NEON lanes.
struct TwiddlesF64_1024 {
    forward: [[[f64; 2]; 512]; 10],
    inverse: [[[f64; 2]; 512]; 10],
}
#[cfg(target_arch = "aarch64")]
impl TwiddlesF64_1024 {
    /// Builds forward twiddles e^(-i*TAU*j/m) and their conjugates (inverse
    /// direction) for each of the ten stages. Stage `s` has span
    /// `m = 2 << s`; only the first m/2 slots of each row are meaningful.
    fn new() -> Self {
        let mut forward = [[[-0.0_f64; 2]; 512]; 10];
        let mut inverse = [[[-0.0_f64; 2]; 512]; 10];
        for stage in 0..10 {
            let span = 2usize << stage;
            for j in 0..span / 2 {
                let theta = -core::f64::consts::TAU * (j as f64) / (span as f64);
                let (sin_t, cos_t) = theta.sin_cos();
                forward[stage][j] = [cos_t, sin_t];
                inverse[stage][j] = [cos_t, -sin_t];
            }
        }
        Self { forward, inverse }
    }
}
#[cfg(target_arch = "aarch64")]
/// 1024-point in-place radix-2 DIT FFT over bit-reversed input using NEON.
/// Direction is table-driven: `sign > 0` reads the pre-conjugated inverse
/// twiddles, otherwise the forward ones.
/// NOTE(review): assumes callers pass sign = ±1 — confirm.
fn dit_1024_precomputed(data: &mut [Complex<f64>], sign: i32) {
    use crate::prelude::OnceLock;
    #[cfg(not(feature = "std"))]
    use crate::prelude::OnceLockExt;
    use core::arch::aarch64::*;
    static TWIDDLES: OnceLock<TwiddlesF64_1024> = OnceLock::new();
    let twiddles = TWIDDLES.get_or_init(TwiddlesF64_1024::new);
    // Interleaved (re, im) view of the complex buffer.
    // NOTE(review): assumes Complex<f64> is two consecutive f64s in memory.
    let ptr = data.as_mut_ptr() as *mut f64;
    let sign_arr = [-1.0_f64, 1.0];
    unsafe {
        let sign_pattern = vld1q_f64(sign_arr.as_ptr());
        let tw_table = if sign > 0 {
            &twiddles.inverse
        } else {
            &twiddles.forward
        };
        let mut m = 2usize; // butterfly span, doubles each stage
        for s in 0..10 {
            let half_m = m / 2;
            let tw_stage = &tw_table[s];
            for k in (0..1024).step_by(m) {
                let mut j = 0;
                // 4x-unrolled butterflies; the tail loop below also covers the
                // early stages where half_m < 4.
                while j + 3 < half_m {
                    neon_butterfly_fast(ptr, k + j, half_m, tw_stage[j].as_ptr(), sign_pattern);
                    neon_butterfly_fast(
                        ptr,
                        k + j + 1,
                        half_m,
                        tw_stage[j + 1].as_ptr(),
                        sign_pattern,
                    );
                    neon_butterfly_fast(
                        ptr,
                        k + j + 2,
                        half_m,
                        tw_stage[j + 2].as_ptr(),
                        sign_pattern,
                    );
                    neon_butterfly_fast(
                        ptr,
                        k + j + 3,
                        half_m,
                        tw_stage[j + 3].as_ptr(),
                        sign_pattern,
                    );
                    j += 4;
                }
                while j < half_m {
                    neon_butterfly_fast(ptr, k + j, half_m, tw_stage[j].as_ptr(), sign_pattern);
                    j += 1;
                }
            }
            m *= 2;
        }
    }
}
#[cfg(target_arch = "x86_64")]
/// Per-stage twiddle tables for the 1024-point AVX2 transform, stored as
/// interleaved [re, im] pairs for direct 256-bit loads.
struct TwiddlesF64_1024X86 {
    forward: [[[f64; 2]; 512]; 10],
    inverse: [[[f64; 2]; 512]; 10],
}
#[cfg(target_arch = "x86_64")]
impl TwiddlesF64_1024X86 {
    /// Builds forward twiddles e^(-i*TAU*j/m) and their conjugates (inverse
    /// direction) for each of the ten stages. Stage `s` has span
    /// `m = 2 << s`; only the first m/2 slots of each row are meaningful.
    fn new() -> Self {
        let mut forward = [[[-0.0_f64; 2]; 512]; 10];
        let mut inverse = [[[-0.0_f64; 2]; 512]; 10];
        for stage in 0..10 {
            let span = 2usize << stage;
            for j in 0..span / 2 {
                let theta = -core::f64::consts::TAU * (j as f64) / (span as f64);
                let (sin_t, cos_t) = theta.sin_cos();
                forward[stage][j] = [cos_t, sin_t];
                inverse[stage][j] = [cos_t, -sin_t];
            }
        }
        Self { forward, inverse }
    }
}
#[cfg(target_arch = "x86_64")]
/// x86_64 dispatch for the 1024-point DIT stages: use the AVX2+FMA kernel
/// when the CPU supports it, otherwise fall back to the generic solver.
fn dit_1024_precomputed(data: &mut [Complex<f64>], sign: i32) {
    let has_simd = is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma");
    if !has_simd {
        use crate::dft::problem::Sign;
        use crate::dft::solvers::simd_butterfly::dit_butterflies_f64;
        // Negative sign means forward transform; anything else is backward.
        dit_butterflies_f64(data, if sign < 0 { Sign::Forward } else { Sign::Backward });
        return;
    }
    // SAFETY: AVX2 and FMA availability was just verified at runtime.
    unsafe { dit_1024_avx2(data, sign) }
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
/// 1024-point in-place DIT FFT over bit-reversed input using AVX2 + FMA.
///
/// Structure:
/// 1. The first four stages (spans m = 2, 4, 8, 16) are fused into a scalar
///    radix-16 kernel run once per 16-element chunk, using hard-coded
///    twiddles whose imaginary parts are scaled by `sign`.
/// 2. The remaining stages are processed two at a time as a fused radix-4
///    pass: a vectorized body handling two butterflies per 256-bit register,
///    plus a scalar remainder loop.
///
/// NOTE(review): `sign_f = f64::from(sign)` is used as a twiddle scale, which
/// is only correct for sign = ±1 — confirm callers never pass other values.
///
/// # Safety
/// Caller must guarantee AVX2 and FMA are available and `data.len() >= 1024`.
unsafe fn dit_1024_avx2(data: &mut [Complex<f64>], sign: i32) {
    use crate::prelude::OnceLock;
    #[cfg(not(feature = "std"))]
    use crate::prelude::OnceLockExt;
    use core::arch::x86_64::*;
    static TWIDDLES: OnceLock<TwiddlesF64_1024X86> = OnceLock::new();
    let twiddles = TWIDDLES.get_or_init(TwiddlesF64_1024X86::new);
    // Interleaved (re, im) view of the complex buffer.
    // NOTE(review): assumes Complex<f64> is two consecutive f64s in memory.
    let ptr = data.as_mut_ptr() as *mut f64;
    let sign_f = f64::from(sign);
    let tw_table = if sign > 0 {
        &twiddles.inverse
    } else {
        &twiddles.forward
    };
    // Hard-coded twiddles for the fused 8-point (w8_*) and 16-point (w16_*)
    // stages; the sign of the imaginary part selects transform direction.
    let sqrt2_2 = core::f64::consts::FRAC_1_SQRT_2;
    let w8_1 = Complex::new(sqrt2_2, sign_f * sqrt2_2);
    let w8_3 = Complex::new(-sqrt2_2, sign_f * sqrt2_2);
    let c16_1 = (core::f64::consts::PI / 8.0).cos();
    let s16_1 = (core::f64::consts::PI / 8.0).sin();
    let c16_3 = (3.0 * core::f64::consts::PI / 8.0).cos();
    let s16_3 = (3.0 * core::f64::consts::PI / 8.0).sin();
    let w16_1 = Complex::new(c16_1, sign_f * s16_1);
    let w16_2 = Complex::new(sqrt2_2, sign_f * sqrt2_2);
    let w16_3 = Complex::new(c16_3, sign_f * s16_3);
    let w16_5 = Complex::new(-c16_3, sign_f * s16_3);
    let w16_6 = Complex::new(-sqrt2_2, sign_f * sqrt2_2);
    let w16_7 = Complex::new(-c16_1, sign_f * s16_1);
    // Fused radix-16 kernel: stages m = 2, 4, 8, 16 on each 16-point chunk.
    for k in (0..1024).step_by(16) {
        let mut x: [Complex<f64>; 16] = [
            data[k],
            data[k + 1],
            data[k + 2],
            data[k + 3],
            data[k + 4],
            data[k + 5],
            data[k + 6],
            data[k + 7],
            data[k + 8],
            data[k + 9],
            data[k + 10],
            data[k + 11],
            data[k + 12],
            data[k + 13],
            data[k + 14],
            data[k + 15],
        ];
        // Stage m = 2: twiddle is 1.
        for i in (0..16).step_by(2) {
            let u = x[i];
            let v = x[i + 1];
            x[i] = u + v;
            x[i + 1] = u - v;
        }
        // Stage m = 4: twiddles are 1 and sign*i (t1 = v1 rotated by ±90°).
        for i in (0..16).step_by(4) {
            let u0 = x[i];
            let u1 = x[i + 1];
            let v0 = x[i + 2];
            let v1 = x[i + 3];
            let t1 = Complex::new(-sign_f * v1.im, sign_f * v1.re);
            x[i] = u0 + v0;
            x[i + 1] = u1 + t1;
            x[i + 2] = u0 - v0;
            x[i + 3] = u1 - t1;
        }
        // Stage m = 8: eighth-roots; index 2 is again a pure ±i rotation.
        for base in [0, 8] {
            let u0 = x[base];
            let u1 = x[base + 1];
            let u2 = x[base + 2];
            let u3 = x[base + 3];
            let v0 = x[base + 4];
            let v1 = x[base + 5] * w8_1;
            let v2 = Complex::new(-sign_f * x[base + 6].im, sign_f * x[base + 6].re);
            let v3 = x[base + 7] * w8_3;
            x[base] = u0 + v0;
            x[base + 1] = u1 + v1;
            x[base + 2] = u2 + v2;
            x[base + 3] = u3 + v3;
            x[base + 4] = u0 - v0;
            x[base + 5] = u1 - v1;
            x[base + 6] = u2 - v2;
            x[base + 7] = u3 - v3;
        }
        // Stage m = 16: sixteenth-roots, combine halves and write back.
        let t0 = x[8];
        let t1 = x[9] * w16_1;
        let t2 = x[10] * w16_2;
        let t3 = x[11] * w16_3;
        let t4 = Complex::new(-sign_f * x[12].im, sign_f * x[12].re);
        let t5 = x[13] * w16_5;
        let t6 = x[14] * w16_6;
        let t7 = x[15] * w16_7;
        data[k] = x[0] + t0;
        data[k + 1] = x[1] + t1;
        data[k + 2] = x[2] + t2;
        data[k + 3] = x[3] + t3;
        data[k + 4] = x[4] + t4;
        data[k + 5] = x[5] + t5;
        data[k + 6] = x[6] + t6;
        data[k + 7] = x[7] + t7;
        data[k + 8] = x[0] - t0;
        data[k + 9] = x[1] - t1;
        data[k + 10] = x[2] - t2;
        data[k + 11] = x[3] - t3;
        data[k + 12] = x[4] - t4;
        data[k + 13] = x[5] - t5;
        data[k + 14] = x[6] - t6;
        data[k + 15] = x[7] - t7;
    }
    // Remaining stages, fused in pairs: each iteration applies stage s
    // (span m) and stage s + 1 (span 2m) together as a radix-4 pass.
    let mut m = 32usize;
    let mut s = 4;
    while s + 1 < 10 {
        let half_m1 = m / 2;
        let m2 = m * 2;
        let half_m2 = m;
        let tw1_stage = &tw_table[s];
        let tw2_stage = &tw_table[s + 1];
        for k in (0..1024).step_by(m2) {
            let mut j = 0;
            // Vectorized body: two adjacent butterflies per 256-bit register.
            while j + 2 <= half_m1 {
                unsafe {
                    let tw1 = _mm256_loadu_pd(tw1_stage[j].as_ptr());
                    let tw2_a = _mm256_loadu_pd(tw2_stage[j].as_ptr());
                    let tw2_b = _mm256_loadu_pd(tw2_stage[j + half_m1].as_ptr());
                    let x0_ptr = ptr.add((k + j) * 2);
                    let x1_ptr = ptr.add((k + j + half_m1) * 2);
                    let x2_ptr = ptr.add((k + j + half_m2) * 2);
                    let x3_ptr = ptr.add((k + j + half_m2 + half_m1) * 2);
                    let x0 = _mm256_loadu_pd(x0_ptr);
                    let x1 = _mm256_loadu_pd(x1_ptr);
                    let x2 = _mm256_loadu_pd(x2_ptr);
                    let x3 = _mm256_loadu_pd(x3_ptr);
                    // Broadcast re/im parts within each 128-bit lane.
                    let tw1_re = _mm256_permute_pd(tw1, 0b0000);
                    let tw1_im = _mm256_permute_pd(tw1, 0b1111);
                    let tw2a_re = _mm256_permute_pd(tw2_a, 0b0000);
                    let tw2a_im = _mm256_permute_pd(tw2_a, 0b1111);
                    let tw2b_re = _mm256_permute_pd(tw2_b, 0b0000);
                    let tw2b_im = _mm256_permute_pd(tw2_b, 0b1111);
                    // t1 = x1 * tw1 (complex multiply via FMA, re/im blended).
                    let x1_re = _mm256_permute_pd(x1, 0b0000);
                    let x1_im = _mm256_permute_pd(x1, 0b1111);
                    let t1_re = _mm256_fnmadd_pd(x1_im, tw1_im, _mm256_mul_pd(x1_re, tw1_re));
                    let t1_im = _mm256_fmadd_pd(x1_im, tw1_re, _mm256_mul_pd(x1_re, tw1_im));
                    let t1 = _mm256_blend_pd(t1_re, t1_im, 0b1010);
                    // t3 = x3 * tw1 (same stage-s twiddle as t1).
                    let x3_re = _mm256_permute_pd(x3, 0b0000);
                    let x3_im = _mm256_permute_pd(x3, 0b1111);
                    let t3_re = _mm256_fnmadd_pd(x3_im, tw1_im, _mm256_mul_pd(x3_re, tw1_re));
                    let t3_im = _mm256_fmadd_pd(x3_im, tw1_re, _mm256_mul_pd(x3_re, tw1_im));
                    let t3 = _mm256_blend_pd(t3_re, t3_im, 0b1010);
                    // Stage-s butterflies.
                    let a0 = _mm256_add_pd(x0, t1);
                    let a1 = _mm256_sub_pd(x0, t1);
                    let a2 = _mm256_add_pd(x2, t3);
                    let a3 = _mm256_sub_pd(x2, t3);
                    // Stage-(s+1) butterflies with twiddles tw2_a / tw2_b.
                    let a2_re = _mm256_permute_pd(a2, 0b0000);
                    let a2_im = _mm256_permute_pd(a2, 0b1111);
                    let t2a_re = _mm256_fnmadd_pd(a2_im, tw2a_im, _mm256_mul_pd(a2_re, tw2a_re));
                    let t2a_im = _mm256_fmadd_pd(a2_im, tw2a_re, _mm256_mul_pd(a2_re, tw2a_im));
                    let t2a = _mm256_blend_pd(t2a_re, t2a_im, 0b1010);
                    let a3_re = _mm256_permute_pd(a3, 0b0000);
                    let a3_im = _mm256_permute_pd(a3, 0b1111);
                    let t2b_re = _mm256_fnmadd_pd(a3_im, tw2b_im, _mm256_mul_pd(a3_re, tw2b_re));
                    let t2b_im = _mm256_fmadd_pd(a3_im, tw2b_re, _mm256_mul_pd(a3_re, tw2b_im));
                    let t2b = _mm256_blend_pd(t2b_re, t2b_im, 0b1010);
                    _mm256_storeu_pd(x0_ptr, _mm256_add_pd(a0, t2a));
                    _mm256_storeu_pd(x2_ptr, _mm256_sub_pd(a0, t2a));
                    _mm256_storeu_pd(x1_ptr, _mm256_add_pd(a1, t2b));
                    _mm256_storeu_pd(x3_ptr, _mm256_sub_pd(a1, t2b));
                }
                j += 2;
            }
            // Scalar remainder: same radix-4 arithmetic, one j at a time.
            while j < half_m1 {
                let i0 = k + j;
                let i1 = k + j + half_m1;
                let i2 = k + j + half_m2;
                let i3 = k + j + half_m2 + half_m1;
                let tw1 = tw1_stage[j];
                let tw2_a = tw2_stage[j];
                let tw2_b = tw2_stage[j + half_m1];
                let w1 = Complex::new(tw1[0], tw1[1]);
                let w2_a = Complex::new(tw2_a[0], tw2_a[1]);
                let w2_b = Complex::new(tw2_b[0], tw2_b[1]);
                let x0 = data[i0];
                let x1 = data[i1];
                let x2 = data[i2];
                let x3 = data[i3];
                let a0 = x0 + x1 * w1;
                let a1 = x0 - x1 * w1;
                let a2 = x2 + x3 * w1;
                let a3 = x2 - x3 * w1;
                data[i0] = a0 + a2 * w2_a;
                data[i2] = a0 - a2 * w2_a;
                data[i1] = a1 + a3 * w2_b;
                data[i3] = a1 - a3 * w2_b;
                j += 1;
            }
        }
        s += 2;
        m *= 4;
    }
}
#[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
/// Portable fallback: delegate the 1024-point DIT stages to the generic
/// SIMD-butterfly solver, translating the raw `sign` into its `Sign` enum.
fn dit_1024_precomputed(data: &mut [Complex<f64>], sign: i32) {
    use crate::dft::problem::Sign;
    use crate::dft::solvers::simd_butterfly::dit_butterflies_f64;
    // Negative sign means forward transform; anything else is backward.
    dit_butterflies_f64(data, if sign < 0 { Sign::Forward } else { Sign::Backward });
}
#[inline]
/// Permute the first 1024 elements into 10-bit bit-reversed index order.
/// A 10-bit reversal is composed from the 8-bit table: the low byte reverses
/// into bits 9..=2 and the top two bits reverse into bits 1..=0.
fn bit_reverse_permute_1024(x: &mut [Complex<f64>]) {
    // 8-bit reversal table, evaluated at compile time.
    static BIT_REV_TABLE: [u8; 256] = {
        let mut t = [0u8; 256];
        let mut idx = 0;
        while idx < 256 {
            let mut v = idx as u8;
            let mut r = 0u8;
            let mut b = 0;
            while b < 8 {
                r = (r << 1) | (v & 1);
                v >>= 1;
                b += 1;
            }
            t[idx] = r;
            idx += 1;
        }
        t
    };
    // 2-bit reversal: 0b01 <-> 0b10, fixed points 0b00 and 0b11.
    const REV_2BITS: [usize; 4] = [0, 2, 1, 3];
    for i in 0..1024 {
        let rev_low = BIT_REV_TABLE[i & 0xFF] as usize;
        let top_bits = (i >> 8) & 0x03;
        let j = (rev_low << 2) | REV_2BITS[top_bits];
        if i < j {
            x.swap(i, j);
        }
    }
}
#[inline]
/// In-place 4096-point FFT over the first 4096 elements of `x`.
/// `sign < 0` runs the forward transform, `sign > 0` the inverse.
pub fn notw_4096_simd_f64(x: &mut [Complex<f64>], sign: i32) {
    debug_assert!(x.len() >= 4096);
    // Bit-reversal reorder first, then the in-place DIT stages.
    bit_reverse_permute_4096(x);
    let head = &mut x[..4096];
    dit_4096_precomputed(head, sign);
}
#[cfg(target_arch = "aarch64")]
/// Per-stage twiddle tables for the 4096-point transform, stored as
/// interleaved [re, im] pairs. Both directions are precomputed so the
/// transform only selects a table. Boxed because each table holds
/// 12 * 2048 * 2 f64s (~384 KiB), too large to keep inline.
struct TwiddlesF64_4096 {
    forward: Box<[[[f64; 2]; 2048]; 12]>,
    inverse: Box<[[[f64; 2]; 2048]; 12]>,
}
#[cfg(target_arch = "aarch64")]
impl TwiddlesF64_4096 {
    // NOTE(review): Box::new(<large array>) may materialize the array on the
    // stack before moving it to the heap (notably in debug builds), hence the
    // lint allow below — confirm this is acceptable for the crate's targets.
    #[allow(clippy::large_stack_frames)]
    fn new() -> Self {
        let mut forward = Box::new([[[-0.0_f64; 2]; 2048]; 12]);
        let mut inverse = Box::new([[[-0.0_f64; 2]; 2048]; 12]);
        // Stage s has butterfly span m = 2 << s; only the first m/2 slots of
        // each row are meaningful, the rest stay at their initial value.
        for s in 0..12 {
            let m = 2usize << s;
            let half_m = m / 2;
            for j in 0..half_m {
                let angle = -core::f64::consts::TAU * (j as f64) / (m as f64);
                let (sin_a, cos_a) = angle.sin_cos();
                forward[s][j] = [cos_a, sin_a];
                inverse[s][j] = [cos_a, -sin_a];
            }
        }
        Self { forward, inverse }
    }
}
#[cfg(target_arch = "aarch64")]
/// 4096-point in-place radix-2 DIT FFT over bit-reversed input using NEON.
/// Direction is table-driven: `sign > 0` reads the pre-conjugated inverse
/// twiddles, otherwise the forward ones.
/// NOTE(review): assumes callers pass sign = ±1 — confirm.
fn dit_4096_precomputed(data: &mut [Complex<f64>], sign: i32) {
    use crate::prelude::OnceLock;
    #[cfg(not(feature = "std"))]
    use crate::prelude::OnceLockExt;
    use core::arch::aarch64::*;
    static TWIDDLES: OnceLock<TwiddlesF64_4096> = OnceLock::new();
    let twiddles = TWIDDLES.get_or_init(TwiddlesF64_4096::new);
    // Interleaved (re, im) view of the complex buffer.
    // NOTE(review): assumes Complex<f64> is two consecutive f64s in memory.
    let ptr = data.as_mut_ptr() as *mut f64;
    let sign_arr = [-1.0_f64, 1.0];
    unsafe {
        let sign_pattern = vld1q_f64(sign_arr.as_ptr());
        let tw_table = if sign > 0 {
            &twiddles.inverse
        } else {
            &twiddles.forward
        };
        let mut m = 2usize; // butterfly span, doubles each stage
        for s in 0..12 {
            let half_m = m / 2;
            let tw_stage = &tw_table[s];
            for k in (0..4096).step_by(m) {
                let mut j = 0;
                // 8x-unrolled butterflies; the tail loop below also covers the
                // early stages where half_m < 8.
                while j + 7 < half_m {
                    neon_butterfly_fast(ptr, k + j, half_m, tw_stage[j].as_ptr(), sign_pattern);
                    neon_butterfly_fast(
                        ptr,
                        k + j + 1,
                        half_m,
                        tw_stage[j + 1].as_ptr(),
                        sign_pattern,
                    );
                    neon_butterfly_fast(
                        ptr,
                        k + j + 2,
                        half_m,
                        tw_stage[j + 2].as_ptr(),
                        sign_pattern,
                    );
                    neon_butterfly_fast(
                        ptr,
                        k + j + 3,
                        half_m,
                        tw_stage[j + 3].as_ptr(),
                        sign_pattern,
                    );
                    neon_butterfly_fast(
                        ptr,
                        k + j + 4,
                        half_m,
                        tw_stage[j + 4].as_ptr(),
                        sign_pattern,
                    );
                    neon_butterfly_fast(
                        ptr,
                        k + j + 5,
                        half_m,
                        tw_stage[j + 5].as_ptr(),
                        sign_pattern,
                    );
                    neon_butterfly_fast(
                        ptr,
                        k + j + 6,
                        half_m,
                        tw_stage[j + 6].as_ptr(),
                        sign_pattern,
                    );
                    neon_butterfly_fast(
                        ptr,
                        k + j + 7,
                        half_m,
                        tw_stage[j + 7].as_ptr(),
                        sign_pattern,
                    );
                    j += 8;
                }
                while j < half_m {
                    neon_butterfly_fast(ptr, k + j, half_m, tw_stage[j].as_ptr(), sign_pattern);
                    j += 1;
                }
            }
            m *= 2;
        }
    }
}
#[cfg(target_arch = "x86_64")]
/// Per-stage twiddle tables for the 4096-point AVX2 transform, stored as
/// interleaved [re, im] pairs. Both directions are precomputed so the
/// transform only selects a table. Boxed because each table holds
/// 12 * 2048 * 2 f64s (~384 KiB), too large to keep inline.
struct TwiddlesF64_4096X86 {
    forward: Box<[[[f64; 2]; 2048]; 12]>,
    inverse: Box<[[[f64; 2]; 2048]; 12]>,
}
#[cfg(target_arch = "x86_64")]
impl TwiddlesF64_4096X86 {
    // NOTE(review): Box::new(<large array>) may materialize the array on the
    // stack before moving it to the heap (notably in debug builds), hence the
    // lint allow below — confirm this is acceptable for the crate's targets.
    #[allow(clippy::large_stack_frames)]
    fn new() -> Self {
        let mut forward = Box::new([[[-0.0_f64; 2]; 2048]; 12]);
        let mut inverse = Box::new([[[-0.0_f64; 2]; 2048]; 12]);
        // Stage s has butterfly span m = 2 << s; only the first m/2 slots of
        // each row are meaningful, the rest stay at their initial value.
        for s in 0..12 {
            let m = 2usize << s;
            let half_m = m / 2;
            for j in 0..half_m {
                let angle = -core::f64::consts::TAU * (j as f64) / (m as f64);
                let (sin_a, cos_a) = angle.sin_cos();
                forward[s][j] = [cos_a, sin_a];
                inverse[s][j] = [cos_a, -sin_a];
            }
        }
        Self { forward, inverse }
    }
}
#[cfg(target_arch = "x86_64")]
/// x86_64 dispatch for the 4096-point DIT stages: use the AVX2+FMA kernel
/// when the CPU supports it, otherwise fall back to the generic solver.
fn dit_4096_precomputed(data: &mut [Complex<f64>], sign: i32) {
    let has_simd = is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma");
    if !has_simd {
        use crate::dft::problem::Sign;
        use crate::dft::solvers::simd_butterfly::dit_butterflies_f64;
        // Negative sign means forward transform; anything else is backward.
        dit_butterflies_f64(data, if sign < 0 { Sign::Forward } else { Sign::Backward });
        return;
    }
    // SAFETY: AVX2 and FMA availability was just verified at runtime.
    unsafe { dit_4096_avx2(data, sign) }
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
/// 4096-point in-place DIT FFT over bit-reversed input using AVX2 + FMA.
///
/// Same structure as `dit_1024_avx2`:
/// 1. The first four stages (spans m = 2, 4, 8, 16) are fused into a scalar
///    radix-16 kernel run once per 16-element chunk, using hard-coded
///    twiddles whose imaginary parts are scaled by `sign`.
/// 2. The remaining stages are processed two at a time as a fused radix-4
///    pass: a vectorized body handling two butterflies per 256-bit register,
///    plus a scalar remainder loop.
///
/// NOTE(review): `sign_f = f64::from(sign)` is used as a twiddle scale, which
/// is only correct for sign = ±1 — confirm callers never pass other values.
///
/// # Safety
/// Caller must guarantee AVX2 and FMA are available and `data.len() >= 4096`.
unsafe fn dit_4096_avx2(data: &mut [Complex<f64>], sign: i32) {
    use crate::prelude::OnceLock;
    #[cfg(not(feature = "std"))]
    use crate::prelude::OnceLockExt;
    use core::arch::x86_64::*;
    static TWIDDLES: OnceLock<TwiddlesF64_4096X86> = OnceLock::new();
    let twiddles = TWIDDLES.get_or_init(TwiddlesF64_4096X86::new);
    // Interleaved (re, im) view of the complex buffer.
    // NOTE(review): assumes Complex<f64> is two consecutive f64s in memory.
    let ptr = data.as_mut_ptr() as *mut f64;
    let sign_f = f64::from(sign);
    let tw_table = if sign > 0 {
        &twiddles.inverse
    } else {
        &twiddles.forward
    };
    // Hard-coded twiddles for the fused 8-point (w8_*) and 16-point (w16_*)
    // stages; the sign of the imaginary part selects transform direction.
    let sqrt2_2 = core::f64::consts::FRAC_1_SQRT_2;
    let w8_1 = Complex::new(sqrt2_2, sign_f * sqrt2_2);
    let w8_3 = Complex::new(-sqrt2_2, sign_f * sqrt2_2);
    let c16_1 = (core::f64::consts::PI / 8.0).cos();
    let s16_1 = (core::f64::consts::PI / 8.0).sin();
    let c16_3 = (3.0 * core::f64::consts::PI / 8.0).cos();
    let s16_3 = (3.0 * core::f64::consts::PI / 8.0).sin();
    let w16_1 = Complex::new(c16_1, sign_f * s16_1);
    let w16_2 = Complex::new(sqrt2_2, sign_f * sqrt2_2);
    let w16_3 = Complex::new(c16_3, sign_f * s16_3);
    let w16_5 = Complex::new(-c16_3, sign_f * s16_3);
    let w16_6 = Complex::new(-sqrt2_2, sign_f * sqrt2_2);
    let w16_7 = Complex::new(-c16_1, sign_f * s16_1);
    // Fused radix-16 kernel: stages m = 2, 4, 8, 16 on each 16-point chunk.
    for k in (0..4096).step_by(16) {
        let mut x: [Complex<f64>; 16] = [
            data[k],
            data[k + 1],
            data[k + 2],
            data[k + 3],
            data[k + 4],
            data[k + 5],
            data[k + 6],
            data[k + 7],
            data[k + 8],
            data[k + 9],
            data[k + 10],
            data[k + 11],
            data[k + 12],
            data[k + 13],
            data[k + 14],
            data[k + 15],
        ];
        // Stage m = 2: twiddle is 1.
        for i in (0..16).step_by(2) {
            let u = x[i];
            let v = x[i + 1];
            x[i] = u + v;
            x[i + 1] = u - v;
        }
        // Stage m = 4: twiddles are 1 and sign*i (t1 = v1 rotated by ±90°).
        for i in (0..16).step_by(4) {
            let u0 = x[i];
            let u1 = x[i + 1];
            let v0 = x[i + 2];
            let v1 = x[i + 3];
            let t1 = Complex::new(-sign_f * v1.im, sign_f * v1.re);
            x[i] = u0 + v0;
            x[i + 1] = u1 + t1;
            x[i + 2] = u0 - v0;
            x[i + 3] = u1 - t1;
        }
        // Stage m = 8: eighth-roots; index 2 is again a pure ±i rotation.
        for base in [0, 8] {
            let u0 = x[base];
            let u1 = x[base + 1];
            let u2 = x[base + 2];
            let u3 = x[base + 3];
            let v0 = x[base + 4];
            let v1 = x[base + 5] * w8_1;
            let v2 = Complex::new(-sign_f * x[base + 6].im, sign_f * x[base + 6].re);
            let v3 = x[base + 7] * w8_3;
            x[base] = u0 + v0;
            x[base + 1] = u1 + v1;
            x[base + 2] = u2 + v2;
            x[base + 3] = u3 + v3;
            x[base + 4] = u0 - v0;
            x[base + 5] = u1 - v1;
            x[base + 6] = u2 - v2;
            x[base + 7] = u3 - v3;
        }
        // Stage m = 16: sixteenth-roots, combine halves and write back.
        let t0 = x[8];
        let t1 = x[9] * w16_1;
        let t2 = x[10] * w16_2;
        let t3 = x[11] * w16_3;
        let t4 = Complex::new(-sign_f * x[12].im, sign_f * x[12].re);
        let t5 = x[13] * w16_5;
        let t6 = x[14] * w16_6;
        let t7 = x[15] * w16_7;
        data[k] = x[0] + t0;
        data[k + 1] = x[1] + t1;
        data[k + 2] = x[2] + t2;
        data[k + 3] = x[3] + t3;
        data[k + 4] = x[4] + t4;
        data[k + 5] = x[5] + t5;
        data[k + 6] = x[6] + t6;
        data[k + 7] = x[7] + t7;
        data[k + 8] = x[0] - t0;
        data[k + 9] = x[1] - t1;
        data[k + 10] = x[2] - t2;
        data[k + 11] = x[3] - t3;
        data[k + 12] = x[4] - t4;
        data[k + 13] = x[5] - t5;
        data[k + 14] = x[6] - t6;
        data[k + 15] = x[7] - t7;
    }
    // Remaining stages, fused in pairs: each iteration applies stage s
    // (span m) and stage s + 1 (span 2m) together as a radix-4 pass.
    let mut m = 32usize;
    let mut s = 4;
    while s + 1 < 12 {
        let half_m1 = m / 2;
        let m2 = m * 2;
        let half_m2 = m;
        let tw1_stage = &tw_table[s];
        let tw2_stage = &tw_table[s + 1];
        for k in (0..4096).step_by(m2) {
            let mut j = 0;
            // Vectorized body: two adjacent butterflies per 256-bit register.
            while j + 2 <= half_m1 {
                unsafe {
                    let tw1 = _mm256_loadu_pd(tw1_stage[j].as_ptr());
                    let tw2_a = _mm256_loadu_pd(tw2_stage[j].as_ptr());
                    let tw2_b = _mm256_loadu_pd(tw2_stage[j + half_m1].as_ptr());
                    let x0_ptr = ptr.add((k + j) * 2);
                    let x1_ptr = ptr.add((k + j + half_m1) * 2);
                    let x2_ptr = ptr.add((k + j + half_m2) * 2);
                    let x3_ptr = ptr.add((k + j + half_m2 + half_m1) * 2);
                    let x0 = _mm256_loadu_pd(x0_ptr);
                    let x1 = _mm256_loadu_pd(x1_ptr);
                    let x2 = _mm256_loadu_pd(x2_ptr);
                    let x3 = _mm256_loadu_pd(x3_ptr);
                    // Broadcast re/im parts within each 128-bit lane.
                    let tw1_re = _mm256_permute_pd(tw1, 0b0000);
                    let tw1_im = _mm256_permute_pd(tw1, 0b1111);
                    let tw2a_re = _mm256_permute_pd(tw2_a, 0b0000);
                    let tw2a_im = _mm256_permute_pd(tw2_a, 0b1111);
                    let tw2b_re = _mm256_permute_pd(tw2_b, 0b0000);
                    let tw2b_im = _mm256_permute_pd(tw2_b, 0b1111);
                    // t1 = x1 * tw1 (complex multiply via FMA, re/im blended).
                    let x1_re = _mm256_permute_pd(x1, 0b0000);
                    let x1_im = _mm256_permute_pd(x1, 0b1111);
                    let t1_re = _mm256_fnmadd_pd(x1_im, tw1_im, _mm256_mul_pd(x1_re, tw1_re));
                    let t1_im = _mm256_fmadd_pd(x1_im, tw1_re, _mm256_mul_pd(x1_re, tw1_im));
                    let t1 = _mm256_blend_pd(t1_re, t1_im, 0b1010);
                    // t3 = x3 * tw1 (same stage-s twiddle as t1).
                    let x3_re = _mm256_permute_pd(x3, 0b0000);
                    let x3_im = _mm256_permute_pd(x3, 0b1111);
                    let t3_re = _mm256_fnmadd_pd(x3_im, tw1_im, _mm256_mul_pd(x3_re, tw1_re));
                    let t3_im = _mm256_fmadd_pd(x3_im, tw1_re, _mm256_mul_pd(x3_re, tw1_im));
                    let t3 = _mm256_blend_pd(t3_re, t3_im, 0b1010);
                    // Stage-s butterflies.
                    let a0 = _mm256_add_pd(x0, t1);
                    let a1 = _mm256_sub_pd(x0, t1);
                    let a2 = _mm256_add_pd(x2, t3);
                    let a3 = _mm256_sub_pd(x2, t3);
                    // Stage-(s+1) butterflies with twiddles tw2_a / tw2_b.
                    let a2_re = _mm256_permute_pd(a2, 0b0000);
                    let a2_im = _mm256_permute_pd(a2, 0b1111);
                    let t2a_re = _mm256_fnmadd_pd(a2_im, tw2a_im, _mm256_mul_pd(a2_re, tw2a_re));
                    let t2a_im = _mm256_fmadd_pd(a2_im, tw2a_re, _mm256_mul_pd(a2_re, tw2a_im));
                    let t2a = _mm256_blend_pd(t2a_re, t2a_im, 0b1010);
                    let a3_re = _mm256_permute_pd(a3, 0b0000);
                    let a3_im = _mm256_permute_pd(a3, 0b1111);
                    let t2b_re = _mm256_fnmadd_pd(a3_im, tw2b_im, _mm256_mul_pd(a3_re, tw2b_re));
                    let t2b_im = _mm256_fmadd_pd(a3_im, tw2b_re, _mm256_mul_pd(a3_re, tw2b_im));
                    let t2b = _mm256_blend_pd(t2b_re, t2b_im, 0b1010);
                    _mm256_storeu_pd(x0_ptr, _mm256_add_pd(a0, t2a));
                    _mm256_storeu_pd(x2_ptr, _mm256_sub_pd(a0, t2a));
                    _mm256_storeu_pd(x1_ptr, _mm256_add_pd(a1, t2b));
                    _mm256_storeu_pd(x3_ptr, _mm256_sub_pd(a1, t2b));
                }
                j += 2;
            }
            // Scalar remainder: same radix-4 arithmetic, one j at a time.
            while j < half_m1 {
                let i0 = k + j;
                let i1 = k + j + half_m1;
                let i2 = k + j + half_m2;
                let i3 = k + j + half_m2 + half_m1;
                let tw1 = tw1_stage[j];
                let tw2_a = tw2_stage[j];
                let tw2_b = tw2_stage[j + half_m1];
                let w1 = Complex::new(tw1[0], tw1[1]);
                let w2_a = Complex::new(tw2_a[0], tw2_a[1]);
                let w2_b = Complex::new(tw2_b[0], tw2_b[1]);
                let x0 = data[i0];
                let x1 = data[i1];
                let x2 = data[i2];
                let x3 = data[i3];
                let a0 = x0 + x1 * w1;
                let a1 = x0 - x1 * w1;
                let a2 = x2 + x3 * w1;
                let a3 = x2 - x3 * w1;
                data[i0] = a0 + a2 * w2_a;
                data[i2] = a0 - a2 * w2_a;
                data[i1] = a1 + a3 * w2_b;
                data[i3] = a1 - a3 * w2_b;
                j += 1;
            }
        }
        s += 2;
        m *= 4;
    }
}
#[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
/// Portable fallback: delegate the 4096-point DIT stages to the generic
/// SIMD-butterfly solver, translating the raw `sign` into its `Sign` enum.
fn dit_4096_precomputed(data: &mut [Complex<f64>], sign: i32) {
    use crate::dft::problem::Sign;
    use crate::dft::solvers::simd_butterfly::dit_butterflies_f64;
    // Negative sign means forward transform; anything else is backward.
    dit_butterflies_f64(data, if sign < 0 { Sign::Forward } else { Sign::Backward });
}
#[inline]
/// Permute the first 4096 elements into 12-bit bit-reversed index order.
/// A 12-bit reversal is composed from the 8-bit table: the low byte reverses
/// into bits 11..=4 and the top nibble reverses into bits 3..=0.
fn bit_reverse_permute_4096(x: &mut [Complex<f64>]) {
    // 8-bit reversal table, evaluated at compile time.
    static BIT_REV_TABLE: [u8; 256] = {
        let mut t = [0u8; 256];
        let mut idx = 0;
        while idx < 256 {
            let mut v = idx as u8;
            let mut r = 0u8;
            let mut b = 0;
            while b < 8 {
                r = (r << 1) | (v & 1);
                v >>= 1;
                b += 1;
            }
            t[idx] = r;
            idx += 1;
        }
        t
    };
    // 4-bit reversal table: REV_4BITS[n] = reverse_bits_4(n).
    const REV_4BITS: [usize; 16] = [0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15];
    for i in 0..4096 {
        let rev_low = BIT_REV_TABLE[i & 0xFF] as usize;
        let top_nibble = (i >> 8) & 0x0F;
        let j = (rev_low << 4) | REV_4BITS[top_nibble];
        if i < j {
            x.swap(i, j);
        }
    }
}