use super::{Complex, Float};
use crate::prelude::*;
use core::any::TypeId;
#[cfg(feature = "std")]
use std::sync::Arc;
#[cfg(not(feature = "std"))]
use alloc::sync::Arc;
/// Per-instance cache of twiddle-factor vectors keyed by `(n, k)`:
/// the first `k` forward factors of an order-`n` transform.
pub struct TwiddleCache<T: Float> {
    // Maps (transform size n, factor count k) -> precomputed factors.
    cache: HashMap<(usize, usize), Vec<Complex<T>>>,
}
impl<T: Float> Default for TwiddleCache<T> {
    /// Equivalent to [`TwiddleCache::new`]: an empty cache.
    fn default() -> Self {
        Self::new()
    }
}
impl<T: Float> TwiddleCache<T> {
    /// Create an empty cache; factors are computed lazily on first request.
    #[must_use]
    pub fn new() -> Self {
        Self {
            cache: HashMap::new(),
        }
    }

    /// Return the first `k` twiddle factors for an order-`n` transform,
    /// computing and memoizing them on the first request for `(n, k)`.
    pub fn get(&mut self, n: usize, k: usize) -> &[Complex<T>] {
        let key = (n, k);
        self.cache
            .entry(key)
            .or_insert_with(|| compute_twiddles(n, k))
    }

    /// Discard every memoized factor vector.
    pub fn clear(&mut self) {
        self.cache.clear();
    }
}
/// Compute the first `k` forward twiddle factors `W_n^j = exp(-2πi·j/n)`.
///
/// The angle for index `j` is `(-2π/n)·j`, matching the sign convention of
/// [`twiddle`].
#[must_use]
pub fn compute_twiddles<T: Float>(n: usize, k: usize) -> Vec<Complex<T>> {
    let step = -T::TWO_PI / T::from_usize(n);
    (0..k)
        .map(|j| Complex::cis(step * T::from_usize(j)))
        .collect()
}
/// Single forward twiddle factor `W_n^k = exp(-2πi·k/n)`.
#[allow(dead_code)]
#[inline]
#[must_use]
pub fn twiddle<T: Float>(n: usize, k: usize) -> Complex<T> {
    Complex::cis(-T::TWO_PI * T::from_usize(k) / T::from_usize(n))
}
/// Single inverse twiddle factor `exp(+2πi·k/n)` — the conjugate of the
/// forward factor produced by [`twiddle`].
#[allow(dead_code)]
#[inline]
#[must_use]
pub fn twiddle_inverse<T: Float>(n: usize, k: usize) -> Complex<T> {
    Complex::cis(T::TWO_PI * T::from_usize(k) / T::from_usize(n))
}
/// Transform direction selecting the sign of the twiddle exponent:
/// `Forward` corresponds to `exp(-2πi·k/n)`, `Inverse` to `exp(+2πi·k/n)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TwiddleDirection {
    Forward,
    Inverse,
}
/// Lookup key for the global twiddle cache.
///
/// `type_id` carries the scalar type (`f32`/`f64`) so one key type can be
/// shared by all four maps in the cache without collisions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TwiddleKey {
    pub size: usize,
    pub direction: TwiddleDirection,
    pub type_id: TypeId,
}
/// Array-of-structs twiddle table: `factors[k]` is the k-th factor for the
/// size/direction it was computed with.
pub struct TwiddleTable<T: Float> {
    pub factors: Vec<Complex<T>>,
}
/// Structure-of-arrays twiddle table: the k-th factor is `re[k] + i·im[k]`.
/// The split layout suits SIMD kernels that broadcast re/im planes separately.
pub struct TwiddleTableSoA<T: Float> {
    pub re: Vec<T>,
    pub im: Vec<T>,
}
impl<T: Float> TwiddleTableSoA<T> {
    /// Number of stored factors (length of the `re` plane).
    #[must_use]
    pub fn len(&self) -> usize {
        self.re.len()
    }

    /// True when the table holds no factors.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Process-wide store of computed twiddle tables, one map per
/// (scalar type × layout) combination. Access is guarded by the `RwLock`
/// returned from `global_twiddle_cache`.
struct GlobalTwiddleCache {
    f32_tables: HashMap<TwiddleKey, Arc<TwiddleTable<f32>>>,
    f64_tables: HashMap<TwiddleKey, Arc<TwiddleTable<f64>>>,
    soa_f32_tables: HashMap<TwiddleKey, Arc<TwiddleTableSoA<f32>>>,
    soa_f64_tables: HashMap<TwiddleKey, Arc<TwiddleTableSoA<f64>>>,
}
impl GlobalTwiddleCache {
fn new() -> Self {
Self {
f32_tables: HashMap::new(),
f64_tables: HashMap::new(),
soa_f32_tables: HashMap::new(),
soa_f64_tables: HashMap::new(),
}
}
}
/// Lazily-initialized, process-wide twiddle cache behind a reader/writer lock.
fn global_twiddle_cache() -> &'static RwLock<GlobalTwiddleCache> {
    static CACHE: OnceLock<RwLock<GlobalTwiddleCache>> = OnceLock::new();
    CACHE.get_or_init(|| RwLock::new(GlobalTwiddleCache::new()))
}
/// Fetch (computing and caching on first use) the f64 AoS twiddle table for
/// `size`/`direction`.
///
/// All callers for the same key receive clones of one shared `Arc`: if two
/// threads race on the first computation, the one whose table lands in the
/// cache wins and both return it. (The previous version returned each
/// thread's private allocation on a race, breaking `Arc::ptr_eq` sharing.)
pub fn get_twiddle_table_f64(size: usize, direction: TwiddleDirection) -> Arc<TwiddleTable<f64>> {
    let key = TwiddleKey {
        size,
        direction,
        type_id: TypeId::of::<f64>(),
    };
    // Fast path: shared read lock, cache hit.
    {
        let cache = rwlock_read(global_twiddle_cache());
        if let Some(t) = cache.f64_tables.get(&key) {
            return Arc::clone(t);
        }
    }
    // Compute outside any lock so readers are not blocked during the
    // (potentially large) table construction.
    let table = Arc::new(compute_twiddle_table_f64(size, direction));
    // Insert-or-reuse under the write lock, then return whatever is cached.
    let mut cache = rwlock_write(global_twiddle_cache());
    Arc::clone(cache.f64_tables.entry(key).or_insert(table))
}
/// Fetch (computing and caching on first use) the f32 AoS twiddle table for
/// `size`/`direction`.
///
/// All callers for the same key receive clones of one shared `Arc`: on a
/// first-computation race, the table that lands in the cache is returned by
/// every racing thread, preserving `Arc::ptr_eq` sharing.
pub fn get_twiddle_table_f32(size: usize, direction: TwiddleDirection) -> Arc<TwiddleTable<f32>> {
    let key = TwiddleKey {
        size,
        direction,
        type_id: TypeId::of::<f32>(),
    };
    // Fast path: shared read lock, cache hit.
    {
        let cache = rwlock_read(global_twiddle_cache());
        if let Some(t) = cache.f32_tables.get(&key) {
            return Arc::clone(t);
        }
    }
    // Compute outside any lock so readers are not blocked.
    let table = Arc::new(compute_twiddle_table_f32(size, direction));
    // Insert-or-reuse under the write lock, then return whatever is cached.
    let mut cache = rwlock_write(global_twiddle_cache());
    Arc::clone(cache.f32_tables.entry(key).or_insert(table))
}
/// Drop every cached twiddle table (AoS and SoA, both precisions).
///
/// `Arc` handles already held by callers remain valid; only the cache's own
/// references are released.
pub fn clear_twiddle_cache() {
    let mut guard = rwlock_write(global_twiddle_cache());
    *guard = GlobalTwiddleCache::new();
}
/// Fetch (computing and caching on first use) the f64 SoA twiddle table for
/// `size`/`direction`.
///
/// On a first-computation race, every racing thread returns the `Arc` that
/// actually landed in the cache, so all callers share one allocation.
pub fn get_twiddle_table_soa_f64(
    size: usize,
    direction: TwiddleDirection,
) -> Arc<TwiddleTableSoA<f64>> {
    let key = TwiddleKey {
        size,
        direction,
        type_id: TypeId::of::<f64>(),
    };
    // Fast path: shared read lock, cache hit.
    {
        let cache = rwlock_read(global_twiddle_cache());
        if let Some(t) = cache.soa_f64_tables.get(&key) {
            return Arc::clone(t);
        }
    }
    // Compute outside any lock so readers are not blocked.
    let table = Arc::new(compute_twiddle_table_soa_f64(size, direction));
    // Insert-or-reuse under the write lock, then return whatever is cached.
    let mut cache = rwlock_write(global_twiddle_cache());
    Arc::clone(cache.soa_f64_tables.entry(key).or_insert(table))
}
/// Fetch (computing and caching on first use) the f32 SoA twiddle table for
/// `size`/`direction`.
///
/// On a first-computation race, every racing thread returns the `Arc` that
/// actually landed in the cache, so all callers share one allocation.
pub fn get_twiddle_table_soa_f32(
    size: usize,
    direction: TwiddleDirection,
) -> Arc<TwiddleTableSoA<f32>> {
    let key = TwiddleKey {
        size,
        direction,
        type_id: TypeId::of::<f32>(),
    };
    // Fast path: shared read lock, cache hit.
    {
        let cache = rwlock_read(global_twiddle_cache());
        if let Some(t) = cache.soa_f32_tables.get(&key) {
            return Arc::clone(t);
        }
    }
    // Compute outside any lock so readers are not blocked.
    let table = Arc::new(compute_twiddle_table_soa_f32(size, direction));
    // Insert-or-reuse under the write lock, then return whatever is cached.
    let mut cache = rwlock_write(global_twiddle_cache());
    Arc::clone(cache.soa_f32_tables.entry(key).or_insert(table))
}
/// Build a split-layout (SoA) f64 twiddle table of `size` factors:
/// `re[k] + i·im[k] = exp(sign·2πi·k/size)` with `sign` chosen by direction.
fn compute_twiddle_table_soa_f64(size: usize, direction: TwiddleDirection) -> TwiddleTableSoA<f64> {
    let sign = if direction == TwiddleDirection::Forward {
        -1.0_f64
    } else {
        1.0_f64
    };
    // Capacity rounded up to a multiple of 8; only `size` entries are written.
    let padded_cap = (size + 7) & !7;
    let mut re = Vec::with_capacity(padded_cap);
    let mut im = Vec::with_capacity(padded_cap);
    for idx in 0..size {
        let angle = sign * 2.0 * core::f64::consts::PI * idx as f64 / size as f64;
        re.push(angle.cos());
        im.push(angle.sin());
    }
    TwiddleTableSoA { re, im }
}
/// Build a split-layout (SoA) f32 twiddle table of `size` factors:
/// `re[k] + i·im[k] = exp(sign·2πi·k/size)` with `sign` chosen by direction.
fn compute_twiddle_table_soa_f32(size: usize, direction: TwiddleDirection) -> TwiddleTableSoA<f32> {
    let sign = if direction == TwiddleDirection::Forward {
        -1.0_f32
    } else {
        1.0_f32
    };
    // Capacity rounded up to a multiple of 8; only `size` entries are written.
    let padded_cap = (size + 7) & !7;
    let mut re = Vec::with_capacity(padded_cap);
    let mut im = Vec::with_capacity(padded_cap);
    for idx in 0..size {
        let angle = sign * 2.0 * core::f32::consts::PI * idx as f32 / size as f32;
        re.push(angle.cos());
        im.push(angle.sin());
    }
    TwiddleTableSoA { re, im }
}
/// Build an interleaved (AoS) f64 twiddle table of `size` factors:
/// `factors[k] = exp(sign·2πi·k/size)` with `sign` chosen by direction.
fn compute_twiddle_table_f64(size: usize, direction: TwiddleDirection) -> TwiddleTable<f64> {
    let sign = match direction {
        TwiddleDirection::Forward => -1.0_f64,
        TwiddleDirection::Inverse => 1.0_f64,
    };
    let mut factors = Vec::with_capacity(size);
    for k in 0..size {
        let angle = sign * 2.0 * core::f64::consts::PI * k as f64 / size as f64;
        factors.push(Complex::new(angle.cos(), angle.sin()));
    }
    TwiddleTable { factors }
}
/// Build an interleaved (AoS) f32 twiddle table of `size` factors:
/// `factors[k] = exp(sign·2πi·k/size)` with `sign` chosen by direction.
fn compute_twiddle_table_f32(size: usize, direction: TwiddleDirection) -> TwiddleTable<f32> {
    let sign = match direction {
        TwiddleDirection::Forward => -1.0_f32,
        TwiddleDirection::Inverse => 1.0_f32,
    };
    let mut factors = Vec::with_capacity(size);
    for k in 0..size {
        let angle = sign * 2.0 * core::f32::consts::PI * k as f32 / size as f32;
        factors.push(Complex::new(angle.cos(), angle.sin()));
    }
    TwiddleTable { factors }
}
/// Multiply each `data[i]` by `twiddles[i]` in place, dispatching at runtime
/// to the best available f64 SIMD kernel (AVX2 → SSE2 → scalar on x86_64,
/// NEON on aarch64, scalar elsewhere).
///
/// # Panics
/// Panics if `data.len() != twiddles.len()`.
pub fn twiddle_mul_simd_f64(data: &mut [Complex<f64>], twiddles: &[Complex<f64>]) {
    assert_eq!(
        data.len(),
        twiddles.len(),
        "twiddle_mul_simd_f64: length mismatch"
    );
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // SAFETY: the required CPU features were just verified.
            return unsafe { twiddle_mul_avx2_f64(data, twiddles) };
        }
        if is_x86_feature_detected!("sse2") {
            // SAFETY: SSE2 presence was just verified.
            return unsafe { twiddle_mul_sse2_f64(data, twiddles) };
        }
        return twiddle_mul_scalar_f64(data, twiddles);
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is a mandatory feature on aarch64.
        unsafe { twiddle_mul_neon_f64(data, twiddles) }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    twiddle_mul_scalar_f64(data, twiddles);
}
/// Scalar reference path: element-wise complex multiply of `data` by
/// `twiddles`, in place. Iteration stops at the shorter slice.
pub fn twiddle_mul_scalar_f64(data: &mut [Complex<f64>], twiddles: &[Complex<f64>]) {
    data.iter_mut().zip(twiddles).for_each(|(value, factor)| {
        *value = *value * *factor;
    });
}
/// Scalar SoA path: multiply each `data[i]` by `twiddle_re[i] + i·twiddle_im[i]`
/// in place. Iteration stops at the shortest of the three slices; the length
/// contract is checked only in debug builds.
pub fn twiddle_mul_soa_scalar_f64(
    data: &mut [Complex<f64>],
    twiddle_re: &[f64],
    twiddle_im: &[f64],
) {
    debug_assert_eq!(data.len(), twiddle_re.len(), "SoA re length mismatch");
    debug_assert_eq!(data.len(), twiddle_im.len(), "SoA im length mismatch");
    let factors = twiddle_re.iter().zip(twiddle_im.iter());
    for (value, (&w_re, &w_im)) in data.iter_mut().zip(factors) {
        let (a, b) = (value.re, value.im);
        // (a + bi)(w_re + w_im·i) = (a·w_re - b·w_im) + (a·w_im + b·w_re)i
        value.re = a * w_re - b * w_im;
        value.im = a * w_im + b * w_re;
    }
}
/// AVX2 kernel: complex-multiplies two `Complex<f64>` (one 256-bit vector)
/// per iteration, with a scalar tail for odd lengths.
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA and that
/// `data.len() == twiddles.len()` (the dispatcher asserts this).
/// NOTE(review): the pointer casts assume `Complex<f64>` is laid out as two
/// adjacent doubles (re, im) — confirm the type is `#[repr(C)]`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn twiddle_mul_avx2_f64(data: &mut [Complex<f64>], twiddles: &[Complex<f64>]) {
    use core::arch::x86_64::*;
    // Two complexes (4 doubles) per 256-bit vector.
    let chunks = data.len() / 2; let data_ptr = data.as_mut_ptr() as *mut f64;
    let tw_ptr = twiddles.as_ptr() as *const f64;
    for i in 0..chunks {
        // SAFETY: i * 4 + 3 < 2 * data.len(), so all accesses are in bounds.
        let d = unsafe { _mm256_loadu_pd(data_ptr.add(i * 4)) };
        let t = unsafe { _mm256_loadu_pd(tw_ptr.add(i * 4)) };
        // Duplicate real parts across each 128-bit lane pair.
        let a_re = _mm256_permute_pd(d, 0b0000);
        // Duplicate imaginary parts.
        let a_im = _mm256_permute_pd(d, 0b1111);
        // Swap re/im within each twiddle.
        let t_swap = _mm256_permute_pd(t, 0b0101);
        let prod1 = _mm256_mul_pd(a_re, t);
        let prod2 = _mm256_mul_pd(a_im, t_swap);
        // addsub -> (re·t.re - im·t.im, re·t.im + im·t.re): complex product.
        let result = _mm256_addsub_pd(prod1, prod2);
        unsafe { _mm256_storeu_pd(data_ptr.add(i * 4), result) };
    }
    // Scalar tail for the leftover element when the length is odd.
    let remainder_start = chunks * 2;
    for i in remainder_start..data.len() {
        data[i] = data[i] * twiddles[i];
    }
}
/// SSE2 kernel: one `Complex<f64>` per 128-bit vector; uses only SSE2
/// instructions (addsub is emulated with an xor sign-flip plus add).
///
/// # Safety
/// Caller must ensure SSE2 support and equal slice lengths.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn twiddle_mul_sse2_f64(data: &mut [Complex<f64>], twiddles: &[Complex<f64>]) {
    use core::arch::x86_64::*;
    let data_ptr = data.as_mut_ptr() as *mut f64;
    let tw_ptr = twiddles.as_ptr() as *const f64;
    for i in 0..data.len() {
        // SAFETY: i * 2 + 1 < 2 * data.len(), so all accesses are in bounds.
        let d = unsafe { _mm_loadu_pd(data_ptr.add(i * 2)) };
        let t = unsafe { _mm_loadu_pd(tw_ptr.add(i * 2)) };
        let a_re = _mm_unpacklo_pd(d, d); // [re, re]
        let a_im = _mm_unpackhi_pd(d, d); // [im, im]
        let t_swap = _mm_shuffle_pd(t, t, 0b01); // [t.im, t.re]
        let prod1 = _mm_mul_pd(a_re, t);
        let prod2 = _mm_mul_pd(a_im, t_swap);
        // Negate the low lane of prod2 (xor with -0.0), then add:
        // result = (re·t.re - im·t.im, re·t.im + im·t.re).
        let sign = _mm_set_pd(0.0_f64, -0.0_f64); let prod2_signed = _mm_xor_pd(prod2, sign);
        let result = _mm_add_pd(prod1, prod2_signed);
        unsafe { _mm_storeu_pd(data_ptr.add(i * 2), result) };
    }
}
/// NEON kernel: one `Complex<f64>` per 128-bit vector.
///
/// # Safety
/// Caller must ensure NEON support (always present on aarch64) and equal
/// slice lengths.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn twiddle_mul_neon_f64(data: &mut [Complex<f64>], twiddles: &[Complex<f64>]) {
    use core::arch::aarch64::*;
    let data_ptr = data.as_mut_ptr() as *mut f64;
    let tw_ptr = twiddles.as_ptr() as *const f64;
    for i in 0..data.len() {
        // SAFETY: i * 2 + 1 < 2 * data.len(), so all accesses are in bounds.
        unsafe {
            let d = vld1q_f64(data_ptr.add(i * 2));
            let t = vld1q_f64(tw_ptr.add(i * 2));
            let a_re = vdupq_lane_f64(vget_low_f64(d), 0);  // [re, re]
            let a_im = vdupq_lane_f64(vget_high_f64(d), 0); // [im, im]
            let t_swap = vextq_f64(t, t, 1);                // [t.im, t.re]
            let prod1 = vmulq_f64(a_re, t);
            let prod2 = vmulq_f64(a_im, t_swap);
            // fma with sign [-1, 1] -> (re·t.re - im·t.im, re·t.im + im·t.re).
            let sign = vld1q_f64([(-1.0_f64), 1.0_f64].as_ptr());
            let result = vfmaq_f64(prod1, prod2, sign);
            vst1q_f64(data_ptr.add(i * 2), result);
        }
    }
}
/// SoA variant of [`twiddle_mul_simd_f64`]: multiply `data[i]` by
/// `twiddle_re[i] + i·twiddle_im[i]` in place, dispatching to the best
/// available kernel at runtime.
///
/// # Panics
/// Panics if either twiddle slice length differs from `data.len()`.
pub fn twiddle_mul_soa_simd_f64(data: &mut [Complex<f64>], twiddle_re: &[f64], twiddle_im: &[f64]) {
    assert_eq!(
        data.len(),
        twiddle_re.len(),
        "twiddle_mul_soa_simd_f64: re length mismatch"
    );
    assert_eq!(
        data.len(),
        twiddle_im.len(),
        "twiddle_mul_soa_simd_f64: im length mismatch"
    );
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // SAFETY: the required CPU features were just verified.
            return unsafe { twiddle_mul_soa_avx2_f64(data, twiddle_re, twiddle_im) };
        }
        if is_x86_feature_detected!("sse2") {
            // SAFETY: SSE2 presence was just verified.
            return unsafe { twiddle_mul_soa_sse2_f64(data, twiddle_re, twiddle_im) };
        }
        return twiddle_mul_soa_scalar_f64(data, twiddle_re, twiddle_im);
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is a mandatory feature on aarch64.
        unsafe { twiddle_mul_soa_neon_f64(data, twiddle_re, twiddle_im) }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    twiddle_mul_soa_scalar_f64(data, twiddle_re, twiddle_im);
}
/// AVX2 SoA kernel: two complexes per iteration; the split re/im twiddles
/// are gathered scalar-wise and interleaved before the vector multiply.
///
/// # Safety
/// Caller must verify AVX2+FMA support and that all three slices have equal
/// length (the dispatcher asserts this).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn twiddle_mul_soa_avx2_f64(
    data: &mut [Complex<f64>],
    twiddle_re: &[f64],
    twiddle_im: &[f64],
) {
    use core::arch::x86_64::*;
    let chunks = data.len() / 2;
    let data_ptr = data.as_mut_ptr() as *mut f64;
    let re_ptr = twiddle_re.as_ptr();
    let im_ptr = twiddle_im.as_ptr();
    for i in 0..chunks {
        // SAFETY: i * 2 + 1 < data.len(), so every offset below is in bounds.
        let d = unsafe { _mm256_loadu_pd(data_ptr.add(i * 4)) };
        let tw_re0 = unsafe { *re_ptr.add(i * 2) };
        let tw_re1 = unsafe { *re_ptr.add(i * 2 + 1) };
        let tw_im0 = unsafe { *im_ptr.add(i * 2) };
        let tw_im1 = unsafe { *im_ptr.add(i * 2 + 1) };
        // Interleave to [re0, im0, re1, im1] (set_pd takes args high -> low).
        let t = _mm256_set_pd(tw_im1, tw_re1, tw_im0, tw_re0);
        let t_swap = _mm256_permute_pd(t, 0b0101);
        let d_re = _mm256_permute_pd(d, 0b0000); let d_im = _mm256_permute_pd(d, 0b1111);
        // addsub -> (a.re·w.re - a.im·w.im, a.re·w.im + a.im·w.re).
        let prod1 = _mm256_mul_pd(d_re, t); let prod2 = _mm256_mul_pd(d_im, t_swap); let result = _mm256_addsub_pd(prod1, prod2);
        unsafe { _mm256_storeu_pd(data_ptr.add(i * 4), result) };
    }
    // Scalar tail for the leftover element when the length is odd.
    let remainder_start = chunks * 2;
    for i in remainder_start..data.len() {
        let d_re = data[i].re;
        let d_im = data[i].im;
        let tw_re = twiddle_re[i];
        let tw_im = twiddle_im[i];
        data[i].re = d_re * tw_re - d_im * tw_im;
        data[i].im = d_re * tw_im + d_im * tw_re;
    }
}
/// SSE2 SoA kernel: one complex per iteration with broadcast twiddle planes;
/// real and imaginary results are computed in separate vectors and
/// re-interleaved on store.
///
/// # Safety
/// Caller must verify SSE2 support and equal slice lengths.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn twiddle_mul_soa_sse2_f64(
    data: &mut [Complex<f64>],
    twiddle_re: &[f64],
    twiddle_im: &[f64],
) {
    use core::arch::x86_64::*;
    let data_ptr = data.as_mut_ptr() as *mut f64;
    for i in 0..data.len() {
        // SAFETY: i * 2 + 1 < 2 * data.len(), so all accesses are in bounds.
        let d = unsafe { _mm_loadu_pd(data_ptr.add(i * 2)) };
        let tw_re = twiddle_re[i];
        let tw_im = twiddle_im[i];
        let t_re = _mm_set1_pd(tw_re);
        let t_im = _mm_set1_pd(tw_im);
        let d_re = _mm_unpacklo_pd(d, d); // [re, re]
        let d_im = _mm_unpackhi_pd(d, d); // [im, im]
        let prod_re = _mm_mul_pd(d_re, t_re);
        let prod_im = _mm_mul_pd(d_im, t_im);
        let prod_cross_re = _mm_mul_pd(d_re, t_im);
        let prod_cross_im = _mm_mul_pd(d_im, t_re);
        let res_re = _mm_sub_pd(prod_re, prod_im);
        let res_im = _mm_add_pd(prod_cross_re, prod_cross_im);
        // Pack (re, im) back into interleaved order.
        let result = _mm_unpacklo_pd(res_re, res_im);
        unsafe { _mm_storeu_pd(data_ptr.add(i * 2), result) };
    }
}
/// NEON SoA kernel: one complex per iteration; re/im products are computed
/// in separate vectors and interleaved via `vzip1q_f64` on store.
///
/// # Safety
/// Caller must ensure NEON support and equal slice lengths.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn twiddle_mul_soa_neon_f64(
    data: &mut [Complex<f64>],
    twiddle_re: &[f64],
    twiddle_im: &[f64],
) {
    use core::arch::aarch64::*;
    let data_ptr = data.as_mut_ptr() as *mut f64;
    for i in 0..data.len() {
        // SAFETY: i * 2 + 1 < 2 * data.len(), so all accesses are in bounds.
        unsafe {
            let d = vld1q_f64(data_ptr.add(i * 2));
            let tw_re = twiddle_re[i];
            let tw_im = twiddle_im[i];
            let d_re = vdupq_lane_f64(vget_low_f64(d), 0);  // [re, re]
            let d_im = vdupq_lane_f64(vget_high_f64(d), 0); // [im, im]
            let t_re = vdupq_n_f64(tw_re);
            let t_im = vdupq_n_f64(tw_im);
            let prod_re_re = vmulq_f64(d_re, t_re);
            let prod_im_im = vmulq_f64(d_im, t_im);
            let prod_re_im = vmulq_f64(d_re, t_im);
            let prod_im_re = vmulq_f64(d_im, t_re);
            let res_re = vsubq_f64(prod_re_re, prod_im_im);
            let res_im = vaddq_f64(prod_re_im, prod_im_re);
            // Interleave low lanes: [res_re[0], res_im[0]].
            let result = vzip1q_f64(res_re, res_im);
            vst1q_f64(data_ptr.add(i * 2), result);
        }
    }
}
/// Scalar SoA path (f32): multiply each `data[i]` by
/// `twiddle_re[i] + i·twiddle_im[i]` in place. Iteration stops at the
/// shortest slice; the length contract is checked only in debug builds.
pub fn twiddle_mul_soa_scalar_f32(
    data: &mut [Complex<f32>],
    twiddle_re: &[f32],
    twiddle_im: &[f32],
) {
    debug_assert_eq!(data.len(), twiddle_re.len(), "SoA re length mismatch (f32)");
    debug_assert_eq!(data.len(), twiddle_im.len(), "SoA im length mismatch (f32)");
    let factors = twiddle_re.iter().zip(twiddle_im.iter());
    for (value, (&w_re, &w_im)) in data.iter_mut().zip(factors) {
        let (a, b) = (value.re, value.im);
        // (a + bi)(w_re + w_im·i) = (a·w_re - b·w_im) + (a·w_im + b·w_re)i
        value.re = a * w_re - b * w_im;
        value.im = a * w_im + b * w_re;
    }
}
/// SoA variant of [`twiddle_mul_simd_f32`]: multiply `data[i]` by
/// `twiddle_re[i] + i·twiddle_im[i]` in place, dispatching at runtime.
///
/// # Panics
/// Panics if either twiddle slice length differs from `data.len()`.
pub fn twiddle_mul_soa_simd_f32(data: &mut [Complex<f32>], twiddle_re: &[f32], twiddle_im: &[f32]) {
    assert_eq!(
        data.len(),
        twiddle_re.len(),
        "twiddle_mul_soa_simd_f32: re length mismatch"
    );
    assert_eq!(
        data.len(),
        twiddle_im.len(),
        "twiddle_mul_soa_simd_f32: im length mismatch"
    );
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // SAFETY: the required CPU features were just verified.
            return unsafe { twiddle_mul_soa_avx2_f32(data, twiddle_re, twiddle_im) };
        }
        if is_x86_feature_detected!("sse2") {
            // SAFETY: SSE2 presence was just verified.
            return unsafe { twiddle_mul_soa_sse2_f32(data, twiddle_re, twiddle_im) };
        }
        return twiddle_mul_soa_scalar_f32(data, twiddle_re, twiddle_im);
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is a mandatory feature on aarch64.
        unsafe { twiddle_mul_soa_neon_f32(data, twiddle_re, twiddle_im) }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    twiddle_mul_soa_scalar_f32(data, twiddle_re, twiddle_im);
}
/// AVX2 SoA kernel (f32): four complexes per 256-bit vector; split twiddles
/// are gathered scalar-wise and interleaved before the vector multiply.
///
/// # Safety
/// Caller must verify AVX2+FMA support and that all three slices have equal
/// length (the dispatcher asserts this).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn twiddle_mul_soa_avx2_f32(
    data: &mut [Complex<f32>],
    twiddle_re: &[f32],
    twiddle_im: &[f32],
) {
    use core::arch::x86_64::*;
    let chunks = data.len() / 4;
    let data_ptr = data.as_mut_ptr() as *mut f32;
    let re_ptr = twiddle_re.as_ptr();
    let im_ptr = twiddle_im.as_ptr();
    for i in 0..chunks {
        // SAFETY: i * 4 + 3 < data.len(), so every offset below is in bounds.
        let d = unsafe { _mm256_loadu_ps(data_ptr.add(i * 8)) };
        let r0 = unsafe { *re_ptr.add(i * 4) };
        let r1 = unsafe { *re_ptr.add(i * 4 + 1) };
        let r2 = unsafe { *re_ptr.add(i * 4 + 2) };
        let r3 = unsafe { *re_ptr.add(i * 4 + 3) };
        let i0 = unsafe { *im_ptr.add(i * 4) };
        let i1 = unsafe { *im_ptr.add(i * 4 + 1) };
        let i2 = unsafe { *im_ptr.add(i * 4 + 2) };
        let i3 = unsafe { *im_ptr.add(i * 4 + 3) };
        // Interleave to [r0, i0, r1, i1, ...] (set_ps takes args high -> low).
        let t = _mm256_set_ps(i3, r3, i2, r2, i1, r1, i0, r0);
        // Swap re/im within each complex pair.
        let t_swap = _mm256_permute_ps(t, 0b10_11_00_01);
        let d_re = _mm256_moveldup_ps(d); // duplicate even (real) lanes
        let d_im = _mm256_movehdup_ps(d); // duplicate odd (imag) lanes
        // addsub -> (a.re·w.re - a.im·w.im, a.re·w.im + a.im·w.re) per pair.
        let prod1 = _mm256_mul_ps(d_re, t); let prod2 = _mm256_mul_ps(d_im, t_swap); let result = _mm256_addsub_ps(prod1, prod2);
        unsafe { _mm256_storeu_ps(data_ptr.add(i * 8), result) };
    }
    // Scalar tail for lengths not divisible by 4.
    let remainder_start = chunks * 4;
    for i in remainder_start..data.len() {
        let d_re = data[i].re;
        let d_im = data[i].im;
        let tw_re = twiddle_re[i];
        let tw_im = twiddle_im[i];
        data[i].re = d_re * tw_re - d_im * tw_im;
        data[i].im = d_re * tw_im + d_im * tw_re;
    }
}
/// "SSE2" SoA kernel (f32). NOTE(review): the body is plain scalar code —
/// no intrinsics — and relies on the `target_feature(sse2)` attribute only
/// to let the compiler auto-vectorize. Consider a hand-written SSE kernel
/// if profiles show this path is hot.
///
/// # Safety
/// Caller must ensure `twiddle_re`/`twiddle_im` are at least `data.len()`
/// long (the dispatcher asserts equality); shorter slices would panic on
/// indexing.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn twiddle_mul_soa_sse2_f32(
    data: &mut [Complex<f32>],
    twiddle_re: &[f32],
    twiddle_im: &[f32],
) {
    for i in 0..data.len() {
        let d_re = data[i].re;
        let d_im = data[i].im;
        let tw_re = twiddle_re[i];
        let tw_im = twiddle_im[i];
        data[i].re = d_re * tw_re - d_im * tw_im;
        data[i].im = d_re * tw_im + d_im * tw_re;
    }
}
/// NEON SoA kernel (f32): two complexes per 128-bit vector; split twiddles
/// are expanded to per-lane duplicates through small stack arrays.
///
/// # Safety
/// Caller must ensure NEON support and equal slice lengths.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn twiddle_mul_soa_neon_f32(
    data: &mut [Complex<f32>],
    twiddle_re: &[f32],
    twiddle_im: &[f32],
) {
    use core::arch::aarch64::*;
    let chunks = data.len() / 2;
    let data_ptr = data.as_mut_ptr() as *mut f32;
    for i in 0..chunks {
        // SAFETY: i * 2 + 1 < data.len(), so all accesses are in bounds.
        unsafe {
            let d = vld1q_f32(data_ptr.add(i * 4));
            let r0 = twiddle_re[i * 2];
            let r1 = twiddle_re[i * 2 + 1];
            let im0 = twiddle_im[i * 2];
            let im1 = twiddle_im[i * 2 + 1];
            // Duplicate each twiddle component into both lanes of its complex.
            let tw_re_arr = [r0, r0, r1, r1];
            let tw_im_arr = [im0, im0, im1, im1];
            let t_re = vld1q_f32(tw_re_arr.as_ptr());
            let t_im = vld1q_f32(tw_im_arr.as_ptr());
            let d_re = vtrn1q_f32(d, d); // duplicate even (real) lanes
            let d_im = vtrn2q_f32(d, d); // duplicate odd (imag) lanes
            // res_re = d_re·t_re - d_im·t_im ; res_im = d_re·t_im + d_im·t_re
            let res_re = vmlsq_f32(vmulq_f32(d_re, t_re), d_im, t_im);
            let res_im = vmlaq_f32(vmulq_f32(d_re, t_im), d_im, t_re);
            let result = vtrn1q_f32(res_re, res_im);
            vst1q_f32(data_ptr.add(i * 4), result);
        }
    }
    // Scalar tail for the leftover element when the length is odd.
    let remainder_start = chunks * 2;
    for i in remainder_start..data.len() {
        let d_re = data[i].re;
        let d_im = data[i].im;
        let tw_re = twiddle_re[i];
        let tw_im = twiddle_im[i];
        data[i].re = d_re * tw_re - d_im * tw_im;
        data[i].im = d_re * tw_im + d_im * tw_re;
    }
}
/// Multiply each `data[i]` by `twiddles[i]` in place, dispatching at runtime
/// to the best available f32 SIMD kernel (AVX2 → SSE2 → scalar on x86_64,
/// NEON on aarch64, scalar elsewhere).
///
/// NOTE(review): this path only checks for "sse2" before calling the SSE
/// kernel, so that kernel must restrict itself to SSE2-level instructions —
/// verify it does not rely on SSE3.
///
/// # Panics
/// Panics if `data.len() != twiddles.len()`.
pub fn twiddle_mul_simd_f32(data: &mut [Complex<f32>], twiddles: &[Complex<f32>]) {
    assert_eq!(
        data.len(),
        twiddles.len(),
        "twiddle_mul_simd_f32: length mismatch"
    );
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // SAFETY: the required CPU features were just verified.
            return unsafe { twiddle_mul_avx2_f32(data, twiddles) };
        }
        if is_x86_feature_detected!("sse2") {
            // SAFETY: SSE2 presence was just verified.
            return unsafe { twiddle_mul_sse2_f32(data, twiddles) };
        }
        return twiddle_mul_scalar_f32(data, twiddles);
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is a mandatory feature on aarch64.
        unsafe { twiddle_mul_neon_f32(data, twiddles) }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    twiddle_mul_scalar_f32(data, twiddles);
}
/// Scalar reference path (f32): element-wise complex multiply of `data` by
/// `twiddles`, in place. Iteration stops at the shorter slice.
pub fn twiddle_mul_scalar_f32(data: &mut [Complex<f32>], twiddles: &[Complex<f32>]) {
    data.iter_mut().zip(twiddles).for_each(|(value, factor)| {
        *value = *value * *factor;
    });
}
/// AVX2 kernel (f32): four interleaved complexes (one 256-bit vector) per
/// iteration, with a scalar tail.
///
/// # Safety
/// Caller must ensure AVX2 and FMA support and equal slice lengths.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn twiddle_mul_avx2_f32(data: &mut [Complex<f32>], twiddles: &[Complex<f32>]) {
    use core::arch::x86_64::*;
    // Four complexes (8 floats) per 256-bit vector.
    let chunks = data.len() / 4; let data_ptr = data.as_mut_ptr() as *mut f32;
    let tw_ptr = twiddles.as_ptr() as *const f32;
    for i in 0..chunks {
        // SAFETY: i * 8 + 7 < 2 * data.len(), so all accesses are in bounds.
        let d = unsafe { _mm256_loadu_ps(data_ptr.add(i * 8)) };
        let t = unsafe { _mm256_loadu_ps(tw_ptr.add(i * 8)) };
        let a_re = _mm256_moveldup_ps(d); // duplicate even (real) lanes
        let a_im = _mm256_movehdup_ps(d); // duplicate odd (imag) lanes
        let t_swap = _mm256_permute_ps(t, 0b10_11_00_01); // swap re/im per pair
        let prod1 = _mm256_mul_ps(a_re, t);
        let prod2 = _mm256_mul_ps(a_im, t_swap);
        // addsub -> (re·t.re - im·t.im, re·t.im + im·t.re) per complex.
        let result = _mm256_addsub_ps(prod1, prod2);
        unsafe { _mm256_storeu_ps(data_ptr.add(i * 8), result) };
    }
    // Scalar tail for lengths not divisible by 4.
    let remainder_start = chunks * 4;
    for i in remainder_start..data.len() {
        data[i] = data[i] * twiddles[i];
    }
}
/// SSE kernel (f32): two interleaved complexes (one 128-bit vector) per
/// iteration, with a scalar tail.
///
/// Fix: the previous implementation used `_mm_moveldup_ps`,
/// `_mm_movehdup_ps`, and `_mm_addsub_ps`, which are **SSE3** instructions,
/// while this function only enables — and the dispatcher only detects —
/// SSE2. On an SSE2-only CPU that was undefined behavior (illegal
/// instruction). This version uses only SSE/SSE2 operations: shuffles for
/// lane duplication and an xor sign-flip plus add to emulate addsub,
/// mirroring `twiddle_mul_sse2_f64`.
///
/// # Safety
/// Caller must ensure SSE2 support and equal slice lengths.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn twiddle_mul_sse2_f32(data: &mut [Complex<f32>], twiddles: &[Complex<f32>]) {
    use core::arch::x86_64::*;
    // Two complexes (4 floats) per 128-bit vector.
    let chunks = data.len() / 2;
    let data_ptr = data.as_mut_ptr() as *mut f32;
    let tw_ptr = twiddles.as_ptr() as *const f32;
    for i in 0..chunks {
        // SAFETY: i * 4 + 3 < 2 * data.len(), so all accesses are in bounds.
        let d = unsafe { _mm_loadu_ps(data_ptr.add(i * 4)) };
        let t = unsafe { _mm_loadu_ps(tw_ptr.add(i * 4)) };
        // Duplicate real parts: [d0.re, d0.re, d1.re, d1.re].
        let a_re = _mm_shuffle_ps(d, d, 0b10_10_00_00);
        // Duplicate imaginary parts: [d0.im, d0.im, d1.im, d1.im].
        let a_im = _mm_shuffle_ps(d, d, 0b11_11_01_01);
        // Swap re/im within each complex: [t0.im, t0.re, t1.im, t1.re].
        let t_swap = _mm_shuffle_ps(t, t, 0b10_11_00_01);
        let prod1 = _mm_mul_ps(a_re, t);
        let prod2 = _mm_mul_ps(a_im, t_swap);
        // Emulate SSE3's addsub: negate even lanes of prod2 (xor with -0.0),
        // then add -> (re·t.re - im·t.im, re·t.im + im·t.re) per complex.
        let sign = _mm_set_ps(0.0_f32, -0.0_f32, 0.0_f32, -0.0_f32);
        let prod2_signed = _mm_xor_ps(prod2, sign);
        let result = _mm_add_ps(prod1, prod2_signed);
        unsafe { _mm_storeu_ps(data_ptr.add(i * 4), result) };
    }
    // Scalar tail for the leftover element when the length is odd.
    let remainder_start = chunks * 2;
    for i in remainder_start..data.len() {
        data[i] = data[i] * twiddles[i];
    }
}
/// NEON kernel (f32): two interleaved complexes (one 128-bit vector) per
/// iteration, with a scalar tail.
///
/// # Safety
/// Caller must ensure NEON support and equal slice lengths.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn twiddle_mul_neon_f32(data: &mut [Complex<f32>], twiddles: &[Complex<f32>]) {
    use core::arch::aarch64::*;
    // Two complexes (4 floats) per 128-bit vector.
    let chunks = data.len() / 2; let data_ptr = data.as_mut_ptr() as *mut f32;
    let tw_ptr = twiddles.as_ptr() as *const f32;
    for i in 0..chunks {
        // SAFETY: i * 4 + 3 < 2 * data.len(), so all accesses are in bounds.
        unsafe {
            let d = vld1q_f32(data_ptr.add(i * 4));
            let t = vld1q_f32(tw_ptr.add(i * 4));
            let a_re = vtrn1q_f32(d, d); // duplicate even (real) lanes
            let a_im = vtrn2q_f32(d, d); // duplicate odd (imag) lanes
            let t_swap = vrev64q_f32(t); // swap re/im within each complex
            let prod1 = vmulq_f32(a_re, t);
            let prod2 = vmulq_f32(a_im, t_swap);
            // fma with sign [-1, 1, -1, 1] yields the complex product lanes.
            let sign = vld1q_f32([(-1.0_f32), 1.0_f32, (-1.0_f32), 1.0_f32].as_ptr());
            let result = vfmaq_f32(prod1, prod2, sign);
            vst1q_f32(data_ptr.add(i * 4), result);
        }
    }
    // Scalar tail for the leftover element when the length is odd.
    let remainder_start = chunks * 2;
    for i in remainder_start..data.len() {
        data[i] = data[i] * twiddles[i];
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;
    // Serializes tests that mutate the process-wide twiddle cache so a
    // `clear_twiddle_cache` in one test cannot race another's Arc assertions.
    static CACHE_LOCK: Mutex<()> = Mutex::new(());

    // W_4^0 = 1, W_4^1 = -i, W_4^2 = -1 (forward sign convention).
    #[test]
    fn test_twiddle_w4() {
        let w0: Complex<f64> = twiddle(4, 0);
        assert!((w0.re - 1.0).abs() < 1e-10);
        assert!(w0.im.abs() < 1e-10);
        let w1: Complex<f64> = twiddle(4, 1);
        assert!(w1.re.abs() < 1e-10);
        assert!((w1.im - (-1.0)).abs() < 1e-10);
        let w2: Complex<f64> = twiddle(4, 2);
        assert!((w2.re - (-1.0)).abs() < 1e-10);
        assert!(w2.im.abs() < 1e-10);
    }

    #[test]
    fn test_compute_twiddles() {
        let tw: Vec<Complex<f64>> = compute_twiddles(8, 4);
        assert_eq!(tw.len(), 4);
        assert!((tw[0].re - 1.0).abs() < 1e-10);
        assert!(tw[0].im.abs() < 1e-10);
    }

    // SIMD dispatch must agree with the scalar reference on even lengths.
    #[test]
    fn simd_vs_scalar_parity_f64() {
        let size = 256;
        let twiddles: Vec<Complex<f64>> = (0..size)
            .map(|k| {
                let angle = -2.0 * core::f64::consts::PI * k as f64 / size as f64;
                Complex::new(angle.cos(), angle.sin())
            })
            .collect();
        let input: Vec<Complex<f64>> = (0..size)
            .map(|k| Complex::new(k as f64, -(k as f64)))
            .collect();
        let mut simd_data = input.clone();
        let mut scalar_data = input;
        twiddle_mul_simd_f64(&mut simd_data, &twiddles);
        twiddle_mul_scalar_f64(&mut scalar_data, &twiddles);
        for (s, r) in simd_data.iter().zip(scalar_data.iter()) {
            let diff = (s.re - r.re).abs().max((s.im - r.im).abs());
            assert!(
                diff <= 1e-10 || diff <= 1e-12 * r.norm(),
                "SIMD/scalar f64 mismatch at element: simd={s:?} scalar={r:?}",
            );
        }
    }

    #[test]
    fn simd_vs_scalar_parity_f32() {
        let size = 256;
        let twiddles: Vec<Complex<f32>> = (0..size)
            .map(|k| {
                let angle = -2.0 * core::f32::consts::PI * k as f32 / size as f32;
                Complex::new(angle.cos(), angle.sin())
            })
            .collect();
        let input: Vec<Complex<f32>> = (0..size)
            .map(|k| Complex::new(k as f32, -(k as f32)))
            .collect();
        let mut simd_data = input.clone();
        let mut scalar_data = input;
        twiddle_mul_simd_f32(&mut simd_data, &twiddles);
        twiddle_mul_scalar_f32(&mut scalar_data, &twiddles);
        for (s, r) in simd_data.iter().zip(scalar_data.iter()) {
            let diff = (s.re - r.re).abs().max((s.im - r.im).abs());
            assert!(
                diff <= 1e-5_f32 || diff <= 1e-6_f32 * r.norm(),
                "SIMD/scalar f32 mismatch: simd={s:?} scalar={r:?}",
            );
        }
    }

    // Odd length exercises the kernels' scalar remainder loops.
    #[test]
    fn simd_vs_scalar_parity_f64_odd_length() {
        let size = 7;
        let twiddles: Vec<Complex<f64>> = (0..size)
            .map(|k| {
                let angle = -2.0 * core::f64::consts::PI * k as f64 / size as f64;
                Complex::new(angle.cos(), angle.sin())
            })
            .collect();
        let input: Vec<Complex<f64>> = (0..size)
            .map(|k| Complex::new(k as f64 + 1.0, -(k as f64) - 0.5))
            .collect();
        let mut simd_data = input.clone();
        let mut scalar_data = input;
        twiddle_mul_simd_f64(&mut simd_data, &twiddles);
        twiddle_mul_scalar_f64(&mut scalar_data, &twiddles);
        for (s, r) in simd_data.iter().zip(scalar_data.iter()) {
            let diff = (s.re - r.re).abs().max((s.im - r.im).abs());
            assert!(
                diff <= 1e-10,
                "SIMD/scalar f64 odd-length mismatch: simd={s:?} scalar={r:?}",
            );
        }
    }

    #[test]
    fn twiddle_cache_hit_f64() {
        let _lock = CACHE_LOCK.lock().expect("cache lock");
        clear_twiddle_cache();
        let t1 = get_twiddle_table_f64(256, TwiddleDirection::Forward);
        let t2 = get_twiddle_table_f64(256, TwiddleDirection::Forward);
        assert!(
            Arc::ptr_eq(&t1, &t2),
            "second call should return the cached Arc"
        );
    }

    // Forward/inverse tables are distinct cache entries and are complex
    // conjugates of each other element-wise.
    #[test]
    fn twiddle_cache_direction_separation() {
        let _lock = CACHE_LOCK.lock().expect("cache lock");
        clear_twiddle_cache();
        let fwd = get_twiddle_table_f64(64, TwiddleDirection::Forward);
        let inv = get_twiddle_table_f64(64, TwiddleDirection::Inverse);
        assert!(
            !Arc::ptr_eq(&fwd, &inv),
            "forward and inverse tables should be distinct"
        );
        if fwd.factors.len() > 1 && inv.factors.len() > 1 {
            let f = fwd.factors[1];
            let i = inv.factors[1];
            assert!(
                (f.re - i.re).abs() < 1e-14,
                "real parts should match: {} vs {}",
                f.re,
                i.re
            );
            assert!(
                (f.im + i.im).abs() < 1e-14,
                "imag parts should be negated: {} vs {}",
                f.im,
                i.im
            );
        }
    }

    #[test]
    fn twiddle_cache_invalidate_f64() {
        let _lock = CACHE_LOCK.lock().expect("cache lock");
        let t1 = get_twiddle_table_f64(512, TwiddleDirection::Forward);
        clear_twiddle_cache();
        let t2 = get_twiddle_table_f64(512, TwiddleDirection::Forward);
        assert!(
            !Arc::ptr_eq(&t1, &t2),
            "after clear, should allocate a new table"
        );
    }

    #[test]
    fn twiddle_cache_hit_f32() {
        let _lock = CACHE_LOCK.lock().expect("cache lock");
        clear_twiddle_cache();
        let t1 = get_twiddle_table_f32(128, TwiddleDirection::Forward);
        let t2 = get_twiddle_table_f32(128, TwiddleDirection::Forward);
        assert!(
            Arc::ptr_eq(&t1, &t2),
            "second call should return the cached Arc (f32)"
        );
    }

    // Smoke test: same size in both precisions must not collide in the cache.
    #[test]
    fn twiddle_cache_f32_f64_separation() {
        let _lock = CACHE_LOCK.lock().expect("cache lock");
        clear_twiddle_cache();
        let _f64 = get_twiddle_table_f64(32, TwiddleDirection::Forward);
        let _f32 = get_twiddle_table_f32(32, TwiddleDirection::Forward);
    }

    #[test]
    fn twiddle_table_correctness_f64() {
        let _lock = CACHE_LOCK.lock().expect("cache lock");
        clear_twiddle_cache();
        let n = 8;
        let table = get_twiddle_table_f64(n, TwiddleDirection::Forward);
        assert_eq!(table.factors.len(), n);
        assert!((table.factors[0].re - 1.0).abs() < 1e-14);
        assert!(table.factors[0].im.abs() < 1e-14);
        for (k, w) in table.factors.iter().enumerate() {
            let mag_sq = w.re * w.re + w.im * w.im;
            assert!(
                (mag_sq - 1.0).abs() < 1e-13,
                "W_{n}^{k} should be on unit circle, |w|²={mag_sq}"
            );
        }
    }

    // Bit-pattern distance; tight ULP bounds catch ordering differences
    // between the AoS and SoA kernels.
    fn ulp_distance_f64(a: f64, b: f64) -> u64 {
        let ai = a.to_bits();
        let bi = b.to_bits();
        ai.abs_diff(bi)
    }

    fn check_soa_vs_aos_f64(size: usize) {
        let _lock = CACHE_LOCK.lock().expect("cache lock");
        clear_twiddle_cache();
        let input: Vec<Complex<f64>> = (0..size)
            .map(|k| Complex::new((k as f64).sin(), (k as f64).cos()))
            .collect();
        let aos_table = get_twiddle_table_f64(size, TwiddleDirection::Forward);
        let mut aos_data = input.clone();
        twiddle_mul_simd_f64(&mut aos_data, &aos_table.factors);
        let soa_table = get_twiddle_table_soa_f64(size, TwiddleDirection::Forward);
        let mut soa_data = input;
        twiddle_mul_soa_simd_f64(&mut soa_data, &soa_table.re, &soa_table.im);
        for (idx, (a, s)) in aos_data.iter().zip(soa_data.iter()).enumerate() {
            let re_ulp = ulp_distance_f64(a.re, s.re);
            let im_ulp = ulp_distance_f64(a.im, s.im);
            assert!(
                re_ulp <= 4,
                "SoA vs AoS f64 re mismatch at idx={idx} size={size}: \
AoS={}, SoA={}, ULP={re_ulp}",
                a.re,
                s.re
            );
            assert!(
                im_ulp <= 4,
                "SoA vs AoS f64 im mismatch at idx={idx} size={size}: \
AoS={}, SoA={}, ULP={im_ulp}",
                a.im,
                s.im
            );
        }
    }

    fn ulp_distance_f32(a: f32, b: f32) -> u32 {
        let ai = a.to_bits();
        let bi = b.to_bits();
        ai.abs_diff(bi)
    }

    fn check_soa_vs_aos_f32(size: usize) {
        let _lock = CACHE_LOCK.lock().expect("cache lock");
        clear_twiddle_cache();
        let input: Vec<Complex<f32>> = (0..size)
            .map(|k| Complex::new((k as f32).sin(), (k as f32).cos()))
            .collect();
        let aos_table = get_twiddle_table_f32(size, TwiddleDirection::Forward);
        let mut aos_data = input.clone();
        twiddle_mul_simd_f32(&mut aos_data, &aos_table.factors);
        let soa_table = get_twiddle_table_soa_f32(size, TwiddleDirection::Forward);
        let mut soa_data = input;
        twiddle_mul_soa_simd_f32(&mut soa_data, &soa_table.re, &soa_table.im);
        for (idx, (a, s)) in aos_data.iter().zip(soa_data.iter()).enumerate() {
            let re_ulp = ulp_distance_f32(a.re, s.re);
            let im_ulp = ulp_distance_f32(a.im, s.im);
            assert!(
                re_ulp <= 4,
                "SoA vs AoS f32 re mismatch at idx={idx} size={size}: \
AoS={}, SoA={}, ULP={re_ulp}",
                a.re,
                s.re
            );
            assert!(
                im_ulp <= 4,
                "SoA vs AoS f32 im mismatch at idx={idx} size={size}: \
AoS={}, SoA={}, ULP={im_ulp}",
                a.im,
                s.im
            );
        }
    }

    #[test]
    fn soa_vs_aos_correctness_f64_1024() {
        check_soa_vs_aos_f64(1024);
    }
    #[test]
    fn soa_vs_aos_correctness_f64_4096() {
        check_soa_vs_aos_f64(4096);
    }
    #[test]
    fn soa_vs_aos_correctness_f64_16384() {
        check_soa_vs_aos_f64(16384);
    }
    #[test]
    fn soa_vs_aos_correctness_f64_65536() {
        check_soa_vs_aos_f64(65536);
    }
    #[test]
    fn soa_vs_aos_correctness_f32_1024() {
        check_soa_vs_aos_f32(1024);
    }
    #[test]
    fn soa_vs_aos_correctness_f32_4096() {
        check_soa_vs_aos_f32(4096);
    }
    #[test]
    fn soa_vs_aos_correctness_f32_16384() {
        check_soa_vs_aos_f32(16384);
    }
    #[test]
    fn soa_vs_aos_correctness_f32_65536() {
        check_soa_vs_aos_f32(65536);
    }

    #[test]
    fn soa_cache_hit_f64() {
        let _lock = CACHE_LOCK.lock().expect("cache lock");
        clear_twiddle_cache();
        let t1 = get_twiddle_table_soa_f64(1024, TwiddleDirection::Forward);
        let t2 = get_twiddle_table_soa_f64(1024, TwiddleDirection::Forward);
        assert!(
            Arc::ptr_eq(&t1, &t2),
            "second SoA f64 call should return the cached Arc"
        );
    }

    #[test]
    fn soa_cache_hit_f32() {
        let _lock = CACHE_LOCK.lock().expect("cache lock");
        clear_twiddle_cache();
        let t1 = get_twiddle_table_soa_f32(512, TwiddleDirection::Forward);
        let t2 = get_twiddle_table_soa_f32(512, TwiddleDirection::Forward);
        assert!(
            Arc::ptr_eq(&t1, &t2),
            "second SoA f32 call should return the cached Arc"
        );
    }

    // SoA table must match the AoS table element-wise and lie on the unit circle.
    #[test]
    fn soa_table_correctness_f64() {
        let _lock = CACHE_LOCK.lock().expect("cache lock");
        clear_twiddle_cache();
        let n = 16usize;
        let soa = get_twiddle_table_soa_f64(n, TwiddleDirection::Forward);
        assert_eq!(soa.re.len(), n);
        assert_eq!(soa.im.len(), n);
        assert!((soa.re[0] - 1.0_f64).abs() < 1e-14);
        assert!(soa.im[0].abs() < 1e-14);
        for k in 0..n {
            let mag_sq = soa.re[k] * soa.re[k] + soa.im[k] * soa.im[k];
            assert!(
                (mag_sq - 1.0_f64).abs() < 1e-13,
                "SoA W_{n}^{k} not on unit circle: |w|²={mag_sq}"
            );
        }
        let aos = get_twiddle_table_f64(n, TwiddleDirection::Forward);
        for k in 0..n {
            assert!(
                (soa.re[k] - aos.factors[k].re).abs() < 1e-14,
                "SoA re[{k}]={} != AoS re={}",
                soa.re[k],
                aos.factors[k].re
            );
            assert!(
                (soa.im[k] - aos.factors[k].im).abs() < 1e-14,
                "SoA im[{k}]={} != AoS im={}",
                soa.im[k],
                aos.factors[k].im
            );
        }
    }
}