use super::{Complex, Float};
use crate::prelude::*;
use core::any::TypeId;
#[cfg(feature = "std")]
use std::sync::Arc;
#[cfg(not(feature = "std"))]
use alloc::sync::Arc;
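/// Per-instance cache of twiddle tables keyed by `(n, k)`, where `n` is the
/// transform size and `k` is the number of factors to precompute.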
pub struct TwiddleCache<T: Float> {
cache: HashMap<(usize, usize), Vec<Complex<T>>>,
}
impl<T: Float> Default for TwiddleCache<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Float> TwiddleCache<T> {
#[must_use]
pub fn new() -> Self {
Self {
cache: HashMap::new(),
}
}
pub fn get(&mut self, n: usize, k: usize) -> &[Complex<T>] {
self.cache
.entry((n, k))
.or_insert_with(|| compute_twiddles(n, k))
}
pub fn clear(&mut self) {
self.cache.clear();
}
}
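// Minimal usage sketch for `TwiddleCache`: a second `get` with the same key
// reuses the table computed by the first. Assumes only what this file already
// uses (`Complex { re, im }` fields and `Complex::cis`).
#[cfg(test)]
#[allow(dead_code)]
fn _twiddle_cache_usage_sketch() {
    let mut cache = TwiddleCache::<f64>::new();
    let tw = cache.get(8, 4); // e^{-2*pi*i*j/8} for j = 0..4
    assert_eq!(tw.len(), 4);
    // The j = 0 twiddle is always 1 + 0i.
    assert!((tw[0].re - 1.0).abs() < 1e-12);
    assert!(tw[0].im.abs() < 1e-12);
}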
#[must_use]
pub fn compute_twiddles<T: Float>(n: usize, k: usize) -> Vec<Complex<T>> {
let mut result = Vec::with_capacity(k);
let theta_base = -T::TWO_PI / T::from_usize(n);
for j in 0..k {
let theta = theta_base * T::from_usize(j);
result.push(Complex::cis(theta));
}
result
}
#[allow(dead_code)]
#[inline]
#[must_use]
pub fn twiddle<T: Float>(n: usize, k: usize) -> Complex<T> {
let theta = -T::TWO_PI * T::from_usize(k) / T::from_usize(n);
Complex::cis(theta)
}
#[allow(dead_code)]
#[inline]
#[must_use]
pub fn twiddle_inverse<T: Float>(n: usize, k: usize) -> Complex<T> {
let theta = T::TWO_PI * T::from_usize(k) / T::from_usize(n);
Complex::cis(theta)
}
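// Relationship sketch: `twiddle_inverse` is the complex conjugate of
// `twiddle`, since cis(-theta) = conj(cis(theta)).
#[cfg(test)]
#[allow(dead_code)]
fn _twiddle_conjugate_sketch() {
    let fwd: Complex<f64> = twiddle(16, 3);
    let inv: Complex<f64> = twiddle_inverse(16, 3);
    assert!((fwd.re - inv.re).abs() < 1e-12);
    assert!((fwd.im + inv.im).abs() < 1e-12);
}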
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TwiddleDirection {
Forward,
Inverse,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TwiddleKey {
pub size: usize,
pub direction: TwiddleDirection,
pub type_id: TypeId,
}
pub struct TwiddleTable<T: Float> {
pub factors: Vec<Complex<T>>,
}
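/// Structure-of-arrays twiddle table: the real and imaginary parts live in
/// separate parallel vectors instead of interleaved `Complex` values.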
pub struct TwiddleTableSoA<T: Float> {
pub re: Vec<T>,
pub im: Vec<T>,
}
impl<T: Float> TwiddleTableSoA<T> {
#[must_use]
pub fn len(&self) -> usize {
self.re.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.re.is_empty()
}
}
struct GlobalTwiddleCache {
f32_tables: HashMap<TwiddleKey, Arc<TwiddleTable<f32>>>,
f64_tables: HashMap<TwiddleKey, Arc<TwiddleTable<f64>>>,
soa_f32_tables: HashMap<TwiddleKey, Arc<TwiddleTableSoA<f32>>>,
soa_f64_tables: HashMap<TwiddleKey, Arc<TwiddleTableSoA<f64>>>,
}
impl GlobalTwiddleCache {
fn new() -> Self {
Self {
f32_tables: HashMap::new(),
f64_tables: HashMap::new(),
soa_f32_tables: HashMap::new(),
soa_f64_tables: HashMap::new(),
}
}
}
fn global_twiddle_cache() -> &'static RwLock<GlobalTwiddleCache> {
static CACHE: OnceLock<RwLock<GlobalTwiddleCache>> = OnceLock::new();
CACHE.get_or_init(|| RwLock::new(GlobalTwiddleCache::new()))
}
pub fn get_twiddle_table_f64(size: usize, direction: TwiddleDirection) -> Arc<TwiddleTable<f64>> {
let key = TwiddleKey {
size,
direction,
type_id: TypeId::of::<f64>(),
};
{
let cache = rwlock_read(global_twiddle_cache());
if let Some(t) = cache.f64_tables.get(&key) {
return Arc::clone(t);
}
}
let table = Arc::new(compute_twiddle_table_f64(size, direction));
{
let mut cache = rwlock_write(global_twiddle_cache());
cache
.f64_tables
.entry(key)
.or_insert_with(|| Arc::clone(&table));
}
table
}
pub fn get_twiddle_table_f32(size: usize, direction: TwiddleDirection) -> Arc<TwiddleTable<f32>> {
let key = TwiddleKey {
size,
direction,
type_id: TypeId::of::<f32>(),
};
{
let cache = rwlock_read(global_twiddle_cache());
if let Some(t) = cache.f32_tables.get(&key) {
return Arc::clone(t);
}
}
let table = Arc::new(compute_twiddle_table_f32(size, direction));
{
let mut cache = rwlock_write(global_twiddle_cache());
cache
.f32_tables
.entry(key)
.or_insert_with(|| Arc::clone(&table));
}
table
}
pub fn clear_twiddle_cache() {
let mut cache = rwlock_write(global_twiddle_cache());
*cache = GlobalTwiddleCache::new();
}
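// Usage sketch for the process-wide cache: repeated lookups with the same key
// share one allocation, so callers can hold cheap `Arc` clones. (Single-
// threaded sketch; a concurrent `clear_twiddle_cache` could interleave.)
#[cfg(test)]
#[allow(dead_code)]
fn _global_cache_usage_sketch() {
    let a = get_twiddle_table_f64(64, TwiddleDirection::Forward);
    let b = get_twiddle_table_f64(64, TwiddleDirection::Forward);
    assert!(Arc::ptr_eq(&a, &b)); // the second call hits the cache
    assert_eq!(a.factors.len(), 64);
}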
pub fn get_twiddle_table_soa_f64(
size: usize,
direction: TwiddleDirection,
) -> Arc<TwiddleTableSoA<f64>> {
let key = TwiddleKey {
size,
direction,
type_id: TypeId::of::<f64>(),
};
{
let cache = rwlock_read(global_twiddle_cache());
if let Some(t) = cache.soa_f64_tables.get(&key) {
return Arc::clone(t);
}
}
let table = Arc::new(compute_twiddle_table_soa_f64(size, direction));
{
let mut cache = rwlock_write(global_twiddle_cache());
cache
.soa_f64_tables
.entry(key)
.or_insert_with(|| Arc::clone(&table));
}
table
}
pub fn get_twiddle_table_soa_f32(
size: usize,
direction: TwiddleDirection,
) -> Arc<TwiddleTableSoA<f32>> {
let key = TwiddleKey {
size,
direction,
type_id: TypeId::of::<f32>(),
};
{
let cache = rwlock_read(global_twiddle_cache());
if let Some(t) = cache.soa_f32_tables.get(&key) {
return Arc::clone(t);
}
}
let table = Arc::new(compute_twiddle_table_soa_f32(size, direction));
{
let mut cache = rwlock_write(global_twiddle_cache());
cache
.soa_f32_tables
.entry(key)
.or_insert_with(|| Arc::clone(&table));
}
table
}
fn compute_twiddle_table_soa_f64(size: usize, direction: TwiddleDirection) -> TwiddleTableSoA<f64> {
let sign = match direction {
TwiddleDirection::Forward => -1.0_f64,
TwiddleDirection::Inverse => 1.0_f64,
};
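    // Round the allocation up to a multiple of 8; only `size` elements are
    // pushed, so the rounding just pre-reserves SIMD-friendly slack.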
let capacity = (size + 7) & !7;
let mut re = Vec::with_capacity(capacity);
let mut im = Vec::with_capacity(capacity);
for k in 0..size {
let angle = sign * 2.0 * core::f64::consts::PI * k as f64 / size as f64;
re.push(angle.cos());
im.push(angle.sin());
}
TwiddleTableSoA { re, im }
}
fn compute_twiddle_table_soa_f32(size: usize, direction: TwiddleDirection) -> TwiddleTableSoA<f32> {
let sign = match direction {
TwiddleDirection::Forward => -1.0_f32,
TwiddleDirection::Inverse => 1.0_f32,
};
let capacity = (size + 7) & !7;
let mut re = Vec::with_capacity(capacity);
let mut im = Vec::with_capacity(capacity);
for k in 0..size {
let angle = sign * 2.0 * core::f32::consts::PI * k as f32 / size as f32;
re.push(angle.cos());
im.push(angle.sin());
}
TwiddleTableSoA { re, im }
}
fn compute_twiddle_table_f64(size: usize, direction: TwiddleDirection) -> TwiddleTable<f64> {
let sign = match direction {
TwiddleDirection::Forward => -1.0_f64,
TwiddleDirection::Inverse => 1.0_f64,
};
let factors = (0..size)
.map(|k| {
let angle = sign * 2.0 * core::f64::consts::PI * k as f64 / size as f64;
Complex::new(angle.cos(), angle.sin())
})
.collect();
TwiddleTable { factors }
}
fn compute_twiddle_table_f32(size: usize, direction: TwiddleDirection) -> TwiddleTable<f32> {
let sign = match direction {
TwiddleDirection::Forward => -1.0_f32,
TwiddleDirection::Inverse => 1.0_f32,
};
let factors = (0..size)
.map(|k| {
let angle = sign * 2.0 * core::f32::consts::PI * k as f32 / size as f32;
Complex::new(angle.cos(), angle.sin())
})
.collect();
TwiddleTable { factors }
}
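// Sanity sketch: forward and inverse factors at the same index are
// conjugates, so their product is 1 + 0i up to rounding.
#[cfg(test)]
#[allow(dead_code)]
fn _direction_conjugate_sketch() {
    let fwd = compute_twiddle_table_f64(8, TwiddleDirection::Forward);
    let inv = compute_twiddle_table_f64(8, TwiddleDirection::Inverse);
    for (a, b) in fwd.factors.iter().zip(inv.factors.iter()) {
        let re = a.re * b.re - a.im * b.im;
        let im = a.re * b.im + a.im * b.re;
        assert!((re - 1.0).abs() < 1e-12);
        assert!(im.abs() < 1e-12);
    }
}

/// Element-wise `data[i] *= twiddles[i]`, dispatching to the best SIMD
/// kernel detected at runtime and falling back to scalar code.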
pub fn twiddle_mul_simd_f64(data: &mut [Complex<f64>], twiddles: &[Complex<f64>]) {
assert_eq!(
data.len(),
twiddles.len(),
"twiddle_mul_simd_f64: length mismatch"
);
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
return unsafe { twiddle_mul_avx2_f64(data, twiddles) };
}
if is_x86_feature_detected!("sse2") {
return unsafe { twiddle_mul_sse2_f64(data, twiddles) };
}
return twiddle_mul_scalar_f64(data, twiddles);
}
#[cfg(target_arch = "aarch64")]
{
unsafe { twiddle_mul_neon_f64(data, twiddles) }
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
twiddle_mul_scalar_f64(data, twiddles);
}
pub fn twiddle_mul_scalar_f64(data: &mut [Complex<f64>], twiddles: &[Complex<f64>]) {
for (d, t) in data.iter_mut().zip(twiddles.iter()) {
*d = *d * *t;
}
}
pub fn twiddle_mul_soa_scalar_f64(
data: &mut [Complex<f64>],
twiddle_re: &[f64],
twiddle_im: &[f64],
) {
debug_assert_eq!(data.len(), twiddle_re.len(), "SoA re length mismatch");
debug_assert_eq!(data.len(), twiddle_im.len(), "SoA im length mismatch");
for ((d, &tw_re), &tw_im) in data
.iter_mut()
.zip(twiddle_re.iter())
.zip(twiddle_im.iter())
{
let (d_re, d_im) = (d.re, d.im);
d.re = d_re * tw_re - d_im * tw_im;
d.im = d_re * tw_im + d_im * tw_re;
}
}
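// Cross-check sketch: the SoA scalar kernel agrees with the interleaved (AoS)
// scalar kernel. Assumes `Complex<f64>` is `Copy`, which `*d = *d * *t` above
// already requires.
#[cfg(test)]
#[allow(dead_code)]
fn _soa_matches_aos_sketch() {
    let aos = compute_twiddle_table_f64(4, TwiddleDirection::Forward);
    let soa = compute_twiddle_table_soa_f64(4, TwiddleDirection::Forward);
    let mut a = [
        Complex::new(1.0_f64, 2.0),
        Complex::new(-0.5, 0.25),
        Complex::new(3.0, -1.0),
        Complex::new(0.0, 1.0),
    ];
    let mut b = a;
    twiddle_mul_scalar_f64(&mut a, &aos.factors);
    twiddle_mul_soa_scalar_f64(&mut b, &soa.re, &soa.im);
    for (x, y) in a.iter().zip(b.iter()) {
        assert!((x.re - y.re).abs() < 1e-12);
        assert!((x.im - y.im).abs() < 1e-12);
    }
}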
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn twiddle_mul_avx2_f64(data: &mut [Complex<f64>], twiddles: &[Complex<f64>]) {
use core::arch::x86_64::*;
    let chunks = data.len() / 2; // two Complex<f64> (four f64 lanes) per 256-bit vector
    let data_ptr = data.as_mut_ptr() as *mut f64;
let tw_ptr = twiddles.as_ptr() as *const f64;
for i in 0..chunks {
let d = unsafe { _mm256_loadu_pd(data_ptr.add(i * 4)) };
let t = unsafe { _mm256_loadu_pd(tw_ptr.add(i * 4)) };
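        // Duplicate/swap/addsub idiom for interleaved complex multiply:
        // a_re = [re0, re0, re1, re1], a_im = [im0, im0, im1, im1],
        // t_swap swaps each (re, im) pair, and addsub yields
        // [re*t_re - im*t_im, re*t_im + im*t_re] per complex element.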
let a_re = _mm256_permute_pd(d, 0b0000);
let a_im = _mm256_permute_pd(d, 0b1111);
let t_swap = _mm256_permute_pd(t, 0b0101);
let prod1 = _mm256_mul_pd(a_re, t);
let prod2 = _mm256_mul_pd(a_im, t_swap);
let result = _mm256_addsub_pd(prod1, prod2);
unsafe { _mm256_storeu_pd(data_ptr.add(i * 4), result) };
}
let remainder_start = chunks * 2;
for i in remainder_start..data.len() {
data[i] = data[i] * twiddles[i];
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn twiddle_mul_sse2_f64(data: &mut [Complex<f64>], twiddles: &[Complex<f64>]) {
use core::arch::x86_64::*;
let data_ptr = data.as_mut_ptr() as *mut f64;
let tw_ptr = twiddles.as_ptr() as *const f64;
for i in 0..data.len() {
let d = unsafe { _mm_loadu_pd(data_ptr.add(i * 2)) };
let t = unsafe { _mm_loadu_pd(tw_ptr.add(i * 2)) };
let a_re = _mm_unpacklo_pd(d, d);
let a_im = _mm_unpackhi_pd(d, d);
let t_swap = _mm_shuffle_pd(t, t, 0b01);
let prod1 = _mm_mul_pd(a_re, t);
let prod2 = _mm_mul_pd(a_im, t_swap);
        // SSE2 lacks `addsub`, so negate the even lane of prod2 by XOR-ing
        // its sign bit before adding.
        let sign = _mm_set_pd(0.0_f64, -0.0_f64);
        let prod2_signed = _mm_xor_pd(prod2, sign);
let result = _mm_add_pd(prod1, prod2_signed);
unsafe { _mm_storeu_pd(data_ptr.add(i * 2), result) };
}
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn twiddle_mul_neon_f64(data: &mut [Complex<f64>], twiddles: &[Complex<f64>]) {
use core::arch::aarch64::*;
let data_ptr = data.as_mut_ptr() as *mut f64;
let tw_ptr = twiddles.as_ptr() as *const f64;
for i in 0..data.len() {
unsafe {
let d = vld1q_f64(data_ptr.add(i * 2));
let t = vld1q_f64(tw_ptr.add(i * 2));
let a_re = vdupq_lane_f64(vget_low_f64(d), 0);
let a_im = vdupq_lane_f64(vget_high_f64(d), 0);
let t_swap = vextq_f64(t, t, 1);
let prod1 = vmulq_f64(a_re, t);
let prod2 = vmulq_f64(a_im, t_swap);
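            // sign = [-1, 1] folds the real-part subtract and imaginary-part
            // add into a single fused multiply-add.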
let sign = vld1q_f64([(-1.0_f64), 1.0_f64].as_ptr());
let result = vfmaq_f64(prod1, prod2, sign);
vst1q_f64(data_ptr.add(i * 2), result);
}
}
}
pub fn twiddle_mul_soa_simd_f64(data: &mut [Complex<f64>], twiddle_re: &[f64], twiddle_im: &[f64]) {
assert_eq!(
data.len(),
twiddle_re.len(),
"twiddle_mul_soa_simd_f64: re length mismatch"
);
assert_eq!(
data.len(),
twiddle_im.len(),
"twiddle_mul_soa_simd_f64: im length mismatch"
);
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
return unsafe { twiddle_mul_soa_avx2_f64(data, twiddle_re, twiddle_im) };
}
if is_x86_feature_detected!("sse2") {
return unsafe { twiddle_mul_soa_sse2_f64(data, twiddle_re, twiddle_im) };
}
return twiddle_mul_soa_scalar_f64(data, twiddle_re, twiddle_im);
}
#[cfg(target_arch = "aarch64")]
{
unsafe { twiddle_mul_soa_neon_f64(data, twiddle_re, twiddle_im) }
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
twiddle_mul_soa_scalar_f64(data, twiddle_re, twiddle_im);
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn twiddle_mul_soa_avx2_f64(
data: &mut [Complex<f64>],
twiddle_re: &[f64],
twiddle_im: &[f64],
) {
use core::arch::x86_64::*;
let chunks = data.len() / 2;
let data_ptr = data.as_mut_ptr() as *mut f64;
let re_ptr = twiddle_re.as_ptr();
let im_ptr = twiddle_im.as_ptr();
for i in 0..chunks {
let d = unsafe { _mm256_loadu_pd(data_ptr.add(i * 4)) };
let tw_re0 = unsafe { *re_ptr.add(i * 2) };
let tw_re1 = unsafe { *re_ptr.add(i * 2 + 1) };
let tw_im0 = unsafe { *im_ptr.add(i * 2) };
let tw_im1 = unsafe { *im_ptr.add(i * 2 + 1) };
let t = _mm256_set_pd(tw_im1, tw_re1, tw_im0, tw_re0);
let t_swap = _mm256_permute_pd(t, 0b0101);
        let d_re = _mm256_permute_pd(d, 0b0000);
        let d_im = _mm256_permute_pd(d, 0b1111);
        let prod1 = _mm256_mul_pd(d_re, t);
        let prod2 = _mm256_mul_pd(d_im, t_swap);
        let result = _mm256_addsub_pd(prod1, prod2);
unsafe { _mm256_storeu_pd(data_ptr.add(i * 4), result) };
}
let remainder_start = chunks * 2;
for i in remainder_start..data.len() {
let d_re = data[i].re;
let d_im = data[i].im;
let tw_re = twiddle_re[i];
let tw_im = twiddle_im[i];
data[i].re = d_re * tw_re - d_im * tw_im;
data[i].im = d_re * tw_im + d_im * tw_re;
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn twiddle_mul_soa_sse2_f64(
data: &mut [Complex<f64>],
twiddle_re: &[f64],
twiddle_im: &[f64],
) {
use core::arch::x86_64::*;
let data_ptr = data.as_mut_ptr() as *mut f64;
for i in 0..data.len() {
let d = unsafe { _mm_loadu_pd(data_ptr.add(i * 2)) };
let tw_re = twiddle_re[i];
let tw_im = twiddle_im[i];
let t_re = _mm_set1_pd(tw_re);
let t_im = _mm_set1_pd(tw_im);
let d_re = _mm_unpacklo_pd(d, d);
let d_im = _mm_unpackhi_pd(d, d);
let prod_re = _mm_mul_pd(d_re, t_re);
let prod_im = _mm_mul_pd(d_im, t_im);
let prod_cross_re = _mm_mul_pd(d_re, t_im);
let prod_cross_im = _mm_mul_pd(d_im, t_re);
let res_re = _mm_sub_pd(prod_re, prod_im);
let res_im = _mm_add_pd(prod_cross_re, prod_cross_im);
let result = _mm_unpacklo_pd(res_re, res_im);
unsafe { _mm_storeu_pd(data_ptr.add(i * 2), result) };
}
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn twiddle_mul_soa_neon_f64(
data: &mut [Complex<f64>],
twiddle_re: &[f64],
twiddle_im: &[f64],
) {
use core::arch::aarch64::*;
let data_ptr = data.as_mut_ptr() as *mut f64;
for i in 0..data.len() {
unsafe {
let d = vld1q_f64(data_ptr.add(i * 2));
let tw_re = twiddle_re[i];
let tw_im = twiddle_im[i];
let d_re = vdupq_lane_f64(vget_low_f64(d), 0);
let d_im = vdupq_lane_f64(vget_high_f64(d), 0);
let t_re = vdupq_n_f64(tw_re);
let t_im = vdupq_n_f64(tw_im);
let prod_re_re = vmulq_f64(d_re, t_re);
let prod_im_im = vmulq_f64(d_im, t_im);
let prod_re_im = vmulq_f64(d_re, t_im);
let prod_im_re = vmulq_f64(d_im, t_re);
let res_re = vsubq_f64(prod_re_re, prod_im_im);
let res_im = vaddq_f64(prod_re_im, prod_im_re);
let result = vzip1q_f64(res_re, res_im);
vst1q_f64(data_ptr.add(i * 2), result);
}
}
}
pub fn twiddle_mul_soa_scalar_f32(
data: &mut [Complex<f32>],
twiddle_re: &[f32],
twiddle_im: &[f32],
) {
debug_assert_eq!(data.len(), twiddle_re.len(), "SoA re length mismatch (f32)");
debug_assert_eq!(data.len(), twiddle_im.len(), "SoA im length mismatch (f32)");
for ((d, &tw_re), &tw_im) in data
.iter_mut()
.zip(twiddle_re.iter())
.zip(twiddle_im.iter())
{
let (d_re, d_im) = (d.re, d.im);
d.re = d_re * tw_re - d_im * tw_im;
d.im = d_re * tw_im + d_im * tw_re;
}
}
pub fn twiddle_mul_soa_simd_f32(data: &mut [Complex<f32>], twiddle_re: &[f32], twiddle_im: &[f32]) {
assert_eq!(
data.len(),
twiddle_re.len(),
"twiddle_mul_soa_simd_f32: re length mismatch"
);
assert_eq!(
data.len(),
twiddle_im.len(),
"twiddle_mul_soa_simd_f32: im length mismatch"
);
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
return unsafe { twiddle_mul_soa_avx2_f32(data, twiddle_re, twiddle_im) };
}
if is_x86_feature_detected!("sse2") {
return unsafe { twiddle_mul_soa_sse2_f32(data, twiddle_re, twiddle_im) };
}
return twiddle_mul_soa_scalar_f32(data, twiddle_re, twiddle_im);
}
#[cfg(target_arch = "aarch64")]
{
unsafe { twiddle_mul_soa_neon_f32(data, twiddle_re, twiddle_im) }
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
twiddle_mul_soa_scalar_f32(data, twiddle_re, twiddle_im);
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn twiddle_mul_soa_avx2_f32(
data: &mut [Complex<f32>],
twiddle_re: &[f32],
twiddle_im: &[f32],
) {
use core::arch::x86_64::*;
let chunks = data.len() / 4;
let data_ptr = data.as_mut_ptr() as *mut f32;
let re_ptr = twiddle_re.as_ptr();
let im_ptr = twiddle_im.as_ptr();
for i in 0..chunks {
let d = unsafe { _mm256_loadu_ps(data_ptr.add(i * 8)) };
let r0 = unsafe { *re_ptr.add(i * 4) };
let r1 = unsafe { *re_ptr.add(i * 4 + 1) };
let r2 = unsafe { *re_ptr.add(i * 4 + 2) };
let r3 = unsafe { *re_ptr.add(i * 4 + 3) };
let i0 = unsafe { *im_ptr.add(i * 4) };
let i1 = unsafe { *im_ptr.add(i * 4 + 1) };
let i2 = unsafe { *im_ptr.add(i * 4 + 2) };
let i3 = unsafe { *im_ptr.add(i * 4 + 3) };
let t = _mm256_set_ps(i3, r3, i2, r2, i1, r1, i0, r0);
let t_swap = _mm256_permute_ps(t, 0b10_11_00_01);
let d_re = _mm256_moveldup_ps(d);
let d_im = _mm256_movehdup_ps(d);
        let prod1 = _mm256_mul_ps(d_re, t);
        let prod2 = _mm256_mul_ps(d_im, t_swap);
        let result = _mm256_addsub_ps(prod1, prod2);
unsafe { _mm256_storeu_ps(data_ptr.add(i * 8), result) };
}
let remainder_start = chunks * 4;
for i in remainder_start..data.len() {
let d_re = data[i].re;
let d_im = data[i].im;
let tw_re = twiddle_re[i];
let tw_im = twiddle_im[i];
data[i].re = d_re * tw_re - d_im * tw_im;
data[i].im = d_re * tw_im + d_im * tw_re;
}
}
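// Element-wise fallback: this SSE2 SoA path currently performs the complex
// multiply per element rather than with explicit intrinsics.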
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn twiddle_mul_soa_sse2_f32(
data: &mut [Complex<f32>],
twiddle_re: &[f32],
twiddle_im: &[f32],
) {
for i in 0..data.len() {
let d_re = data[i].re;
let d_im = data[i].im;
let tw_re = twiddle_re[i];
let tw_im = twiddle_im[i];
data[i].re = d_re * tw_re - d_im * tw_im;
data[i].im = d_re * tw_im + d_im * tw_re;
}
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn twiddle_mul_soa_neon_f32(
data: &mut [Complex<f32>],
twiddle_re: &[f32],
twiddle_im: &[f32],
) {
use core::arch::aarch64::*;
let chunks = data.len() / 2;
let data_ptr = data.as_mut_ptr() as *mut f32;
for i in 0..chunks {
unsafe {
let d = vld1q_f32(data_ptr.add(i * 4));
let r0 = twiddle_re[i * 2];
let r1 = twiddle_re[i * 2 + 1];
let im0 = twiddle_im[i * 2];
let im1 = twiddle_im[i * 2 + 1];
let tw_re_arr = [r0, r0, r1, r1];
let tw_im_arr = [im0, im0, im1, im1];
let t_re = vld1q_f32(tw_re_arr.as_ptr());
let t_im = vld1q_f32(tw_im_arr.as_ptr());
let d_re = vtrn1q_f32(d, d);
let d_im = vtrn2q_f32(d, d);
let res_re = vmlsq_f32(vmulq_f32(d_re, t_re), d_im, t_im);
let res_im = vmlaq_f32(vmulq_f32(d_re, t_im), d_im, t_re);
let result = vtrn1q_f32(res_re, res_im);
vst1q_f32(data_ptr.add(i * 4), result);
}
}
let remainder_start = chunks * 2;
for i in remainder_start..data.len() {
let d_re = data[i].re;
let d_im = data[i].im;
let tw_re = twiddle_re[i];
let tw_im = twiddle_im[i];
data[i].re = d_re * tw_re - d_im * tw_im;
data[i].im = d_re * tw_im + d_im * tw_re;
}
}
pub fn twiddle_mul_simd_f32(data: &mut [Complex<f32>], twiddles: &[Complex<f32>]) {
assert_eq!(
data.len(),
twiddles.len(),
"twiddle_mul_simd_f32: length mismatch"
);
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
return unsafe { twiddle_mul_avx2_f32(data, twiddles) };
}
if is_x86_feature_detected!("sse2") {
return unsafe { twiddle_mul_sse2_f32(data, twiddles) };
}
return twiddle_mul_scalar_f32(data, twiddles);
}
#[cfg(target_arch = "aarch64")]
{
unsafe { twiddle_mul_neon_f32(data, twiddles) }
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
twiddle_mul_scalar_f32(data, twiddles);
}
pub fn twiddle_mul_scalar_f32(data: &mut [Complex<f32>], twiddles: &[Complex<f32>]) {
for (d, t) in data.iter_mut().zip(twiddles.iter()) {
*d = *d * *t;
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn twiddle_mul_avx2_f32(data: &mut [Complex<f32>], twiddles: &[Complex<f32>]) {
use core::arch::x86_64::*;
    let chunks = data.len() / 4; // four Complex<f32> (eight f32 lanes) per 256-bit vector
    let data_ptr = data.as_mut_ptr() as *mut f32;
let tw_ptr = twiddles.as_ptr() as *const f32;
for i in 0..chunks {
let d = unsafe { _mm256_loadu_ps(data_ptr.add(i * 8)) };
let t = unsafe { _mm256_loadu_ps(tw_ptr.add(i * 8)) };
let a_re = _mm256_moveldup_ps(d);
let a_im = _mm256_movehdup_ps(d);
let t_swap = _mm256_permute_ps(t, 0b10_11_00_01);
let prod1 = _mm256_mul_ps(a_re, t);
let prod2 = _mm256_mul_ps(a_im, t_swap);
let result = _mm256_addsub_ps(prod1, prod2);
unsafe { _mm256_storeu_ps(data_ptr.add(i * 8), result) };
}
let remainder_start = chunks * 4;
for i in remainder_start..data.len() {
data[i] = data[i] * twiddles[i];
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn twiddle_mul_sse2_f32(data: &mut [Complex<f32>], twiddles: &[Complex<f32>]) {
use core::arch::x86_64::*;
    let chunks = data.len() / 2; // two Complex<f32> (four f32 lanes) per 128-bit vector
    let data_ptr = data.as_mut_ptr() as *mut f32;
let tw_ptr = twiddles.as_ptr() as *const f32;
for i in 0..chunks {
let d = unsafe { _mm_loadu_ps(data_ptr.add(i * 4)) };
let t = unsafe { _mm_loadu_ps(tw_ptr.add(i * 4)) };
        let a_re = _mm_moveldup_ps(d);
        let a_im = _mm_movehdup_ps(d);
let t_swap = _mm_shuffle_ps(t, t, 0b10_11_00_01);
let prod1 = _mm_mul_ps(a_re, t);
let prod2 = _mm_mul_ps(a_im, t_swap);
        let result = _mm_addsub_ps(prod1, prod2);
unsafe { _mm_storeu_ps(data_ptr.add(i * 4), result) };
}
let remainder_start = chunks * 2;
for i in remainder_start..data.len() {
data[i] = data[i] * twiddles[i];
}
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn twiddle_mul_neon_f32(data: &mut [Complex<f32>], twiddles: &[Complex<f32>]) {
use core::arch::aarch64::*;
    let chunks = data.len() / 2; // two Complex<f32> (four f32 lanes) per 128-bit vector
    let data_ptr = data.as_mut_ptr() as *mut f32;
let tw_ptr = twiddles.as_ptr() as *const f32;
for i in 0..chunks {
unsafe {
let d = vld1q_f32(data_ptr.add(i * 4));
let t = vld1q_f32(tw_ptr.add(i * 4));
let a_re = vtrn1q_f32(d, d);
let a_im = vtrn2q_f32(d, d);
let t_swap = vrev64q_f32(t);
let prod1 = vmulq_f32(a_re, t);
let prod2 = vmulq_f32(a_im, t_swap);
let sign = vld1q_f32([(-1.0_f32), 1.0_f32, (-1.0_f32), 1.0_f32].as_ptr());
let result = vfmaq_f32(prod1, prod2, sign);
vst1q_f32(data_ptr.add(i * 4), result);
}
}
let remainder_start = chunks * 2;
for i in remainder_start..data.len() {
data[i] = data[i] * twiddles[i];
}
}
#[must_use]
pub fn twiddles_mixed_radix(
n: usize,
factors: &[u16],
direction: TwiddleDirection,
) -> Vec<Vec<Complex<f64>>> {
assert!(
!factors.is_empty(),
"twiddles_mixed_radix: factors must be non-empty"
);
let product: usize = factors.iter().map(|&r| r as usize).product();
assert_eq!(
product, n,
"twiddles_mixed_radix: product of factors ({product}) must equal n ({n})"
);
let sign = match direction {
TwiddleDirection::Forward => -1.0_f64,
TwiddleDirection::Inverse => 1.0_f64,
};
let mut tables: Vec<Vec<Complex<f64>>> = Vec::with_capacity(factors.len());
let mut current_n: usize = 1;
for &r_u16 in factors {
let r = r_u16 as usize;
current_n *= r;
let stride = current_n / r;
let table_len = (r - 1) * stride;
let mut table = Vec::with_capacity(table_len);
for j in 1..r {
for s in 0..stride {
let angle = sign * 2.0 * core::f64::consts::PI * (j * s) as f64 / current_n as f64;
table.push(Complex::new(angle.cos(), angle.sin()));
}
}
tables.push(table);
}
tables
}
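// Shape sketch: for n = 6 with factors [2, 3], stage 0 (r = 2, stride = 1)
// stores (r - 1) * stride = 1 twiddle and stage 1 (r = 3, stride = 2) stores 4.
#[cfg(test)]
#[allow(dead_code)]
fn _mixed_radix_shape_sketch() {
    let tables = twiddles_mixed_radix(6, &[2, 3], TwiddleDirection::Forward);
    assert_eq!(tables.len(), 2);
    assert_eq!(tables[0].len(), 1);
    assert_eq!(tables[1].len(), 4);
}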
#[must_use]
pub fn twiddles_mixed_radix_f32(
n: usize,
factors: &[u16],
direction: TwiddleDirection,
) -> Vec<Vec<Complex<f32>>> {
assert!(
!factors.is_empty(),
"twiddles_mixed_radix_f32: factors must be non-empty"
);
let product: usize = factors.iter().map(|&r| r as usize).product();
assert_eq!(
product, n,
"twiddles_mixed_radix_f32: product of factors ({product}) must equal n ({n})"
);
let sign = match direction {
TwiddleDirection::Forward => -1.0_f32,
TwiddleDirection::Inverse => 1.0_f32,
};
let mut tables: Vec<Vec<Complex<f32>>> = Vec::with_capacity(factors.len());
let mut current_n: usize = 1;
for &r_u16 in factors {
let r = r_u16 as usize;
current_n *= r;
let stride = current_n / r;
let table_len = (r - 1) * stride;
let mut table = Vec::with_capacity(table_len);
for j in 1..r {
for s in 0..stride {
let angle = sign * 2.0 * core::f32::consts::PI * (j * s) as f32 / current_n as f32;
table.push(Complex::new(angle.cos(), angle.sin()));
}
}
tables.push(table);
}
tables
}
#[cfg(test)]
mod twiddle_tests;