rifft 1.1.2

RIFFT FFT/DLPack/FFI bridge
Documentation
use num_complex::Complex32;

#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

/// Multiplies `a` by `b` element-wise, in place (`a[i] *= b[i]`).
///
/// The work is routed to an architecture-specific kernel:
/// - on `aarch64`, the NEON kernel is called unconditionally;
/// - on `x86`/`x86_64`, the AVX2 kernel is used if runtime detection reports
///   AVX2, otherwise the scalar loop;
/// - on any other architecture, the scalar loop.
///
/// Empty input is a no-op.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
pub fn complex_mul_inplace(a: &mut [Complex32], b: &[Complex32]) {
    assert_eq!(a.len(), b.len());
    if !a.is_empty() {
        dispatch(a, b);
    }

    // Exactly one `dispatch` definition survives `cfg` evaluation for any target.
    #[cfg(target_arch = "aarch64")]
    fn dispatch(a: &mut [Complex32], b: &[Complex32]) {
        neon_complex_mul(a, b);
    }

    #[cfg(not(target_arch = "aarch64"))]
    fn dispatch(a: &mut [Complex32], b: &[Complex32]) {
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            if try_x86_avx2(a, b) {
                return;
            }
        }
        complex_mul_scalar(a, b);
    }
}

/// Portable fallback: multiplies each `a[i]` by `b[i]` with `num_complex`'s `*=`.
///
/// Elements are paired with `zip`, so only the first `min(a.len(), b.len())`
/// entries of `a` are updated.
fn complex_mul_scalar(a: &mut [Complex32], b: &[Complex32]) {
    a.iter_mut()
        .zip(b)
        .for_each(|(dst, &factor)| *dst *= factor);
}

/// NEON kernel for `aarch64`: computes `a[i] *= b[i]` four complex values per
/// iteration, then finishes any leftover (`len % 4`) elements with the scalar loop.
///
/// `vld2q_f32` de-interleaves eight `f32`s into a vector of real parts and a
/// vector of imaginary parts. The products are formed lane-wise as
/// `re = ar·br − ai·bi` and `im = ar·bi + ai·br`, and `vst2q_f32` re-interleaves
/// them on store.
///
/// NOTE(review): this relies on `Complex32` being laid out as `[re, im]` in
/// memory (`num_complex::Complex` is `#[repr(C)]`). It also assumes the target
/// enables NEON, since the kernel is called without runtime detection. That is
/// the default for standard aarch64 targets; confirm for soft-float targets.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[cfg(target_arch = "aarch64")]
fn neon_complex_mul(a: &mut [Complex32], b: &[Complex32]) {
    use core::arch::aarch64::*;

    // The SIMD loop reads `b` through raw pointers at the same offsets as `a`.
    // Checking here keeps this safe function sound on its own instead of
    // depending on every caller having validated the lengths.
    assert_eq!(a.len(), b.len());

    let len = a.len();
    let mut i = 0;
    while i + 4 <= len {
        // SAFETY: `i + 4 <= len == b.len()`, so both pointers address four
        // in-bounds `Complex32`s (eight contiguous `f32`s). `vld2q_f32` and
        // `vst2q_f32` need only `f32` alignment, which `Complex32` has. `a` is a
        // unique borrow, so it cannot overlap `b`.
        unsafe {
            let a_ptr = a.as_mut_ptr().add(i) as *mut f32;
            let b_ptr = b.as_ptr().add(i) as *const f32;
            let a_vals = vld2q_f32(a_ptr as *const f32);
            let b_vals = vld2q_f32(b_ptr);
            let are = a_vals.0;
            let aim = a_vals.1;
            let bre = b_vals.0;
            let bim = b_vals.1;
            let real = vsubq_f32(vmulq_f32(are, bre), vmulq_f32(aim, bim));
            let imag = vaddq_f32(vmulq_f32(are, bim), vmulq_f32(aim, bre));
            vst2q_f32(a_ptr, float32x4x2_t(real, imag));
        }
        i += 4;
    }
    if i < len {
        complex_mul_scalar(&mut a[i..], &b[i..]);
    }
}

/// Runs the AVX2 kernel if the CPU reports AVX2 support at runtime.
///
/// Returns `true` when `a` has been updated in place. Returns `false`, leaving
/// `a` untouched, so the caller can fall back to the scalar path.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn try_x86_avx2(a: &mut [Complex32], b: &[Complex32]) -> bool {
    let avx2_available = std::arch::is_x86_feature_detected!("avx2");
    if avx2_available {
        // SAFETY: AVX2 support was confirmed by the runtime check just above.
        // `complex_mul_avx2` reads `b` at the same offsets as `a`, so the
        // slices must have equal length. The public `complex_mul_inplace`
        // asserts this before reaching here.
        unsafe { complex_mul_avx2(a, b) };
    }
    avx2_available
}

/// AVX2 kernel: computes `a[i] *= b[i]` four complex values (eight `f32`s) per
/// iteration, then finishes leftover elements with the scalar loop.
///
/// Each 256-bit load holds `[re0, im0, re1, im1, ...]`. For `a = ar + i·ai` and
/// `b = br + i·bi` the kernel forms:
/// - `[ar·br, ar·bi, ...]`: the real parts of `a` duplicated, times `b`;
/// - `[ai·bi, ai·br, ...]`: the imaginary parts of `a` duplicated, times `b`
///   with re/im swapped.
///
/// `addsub` then subtracts in even lanes and adds in odd lanes, which yields
/// `[ar·br − ai·bi, ar·bi + ai·br, ...]`.
///
/// NOTE(review): per the Intel Intrinsics Guide, every intrinsic used here is an
/// AVX (not AVX2) intrinsic, so the `avx2` requirement is stricter than needed.
/// It is kept in lock-step with the runtime check in `try_x86_avx2`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2, for example via
/// `is_x86_feature_detected!("avx2")`.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn complex_mul_avx2(a: &mut [Complex32], b: &[Complex32]) {
    // Enforce the length contract here rather than trusting callers: the loop
    // below reads `b` through raw pointers at the same offsets as `a`.
    assert_eq!(a.len(), b.len());

    let len = a.len();
    let mut i = 0;
    while i + 4 <= len {
        // SAFETY: `i + 4 <= len == b.len()`, so each pointer covers four
        // in-bounds `Complex32`s (eight `f32`s). The unaligned `loadu`/`storeu`
        // intrinsics impose no extra alignment requirement. The caller guarantees
        // AVX2 (and therefore AVX) support.
        // NOTE(review): this assumes `Complex32` is laid out as `[re, im]`.
        unsafe {
            let a_ptr = a.as_mut_ptr().add(i) as *mut f32;
            let b_ptr = b.as_ptr().add(i) as *const f32;
            let a_vec = _mm256_loadu_ps(a_ptr);
            let b_vec = _mm256_loadu_ps(b_ptr);
            let a_real = _mm256_moveldup_ps(a_vec);
            let a_imag = _mm256_movehdup_ps(a_vec);
            // 0b1011_0001 picks lanes [1, 0, 3, 2] in each 128-bit half: swaps re/im.
            let b_swapped = _mm256_shuffle_ps(b_vec, b_vec, 0b1011_0001);
            let mult_re = _mm256_mul_ps(a_real, b_vec);
            let mult_im = _mm256_mul_ps(a_imag, b_swapped);
            let result = _mm256_addsub_ps(mult_re, mult_im);
            _mm256_storeu_ps(a_ptr, result);
        }
        i += 4;
    }
    if i < len {
        complex_mul_scalar(&mut a[i..], &b[i..]);
    }
}

#[cfg(test)]
mod tests {
    use super::complex_mul_inplace;
    use num_complex::Complex32;

    /// Deterministic operands of length `len`, using the same formulas as the
    /// original 37-element test.
    fn inputs(len: usize) -> (Vec<Complex32>, Vec<Complex32>) {
        let lhs = (0..len)
            .map(|i| Complex32::new(i as f32 * 0.5, -(i as f32) * 0.25))
            .collect();
        let rhs = (0..len)
            .map(|i| Complex32::new((i as f32).sin(), (i as f32).cos()))
            .collect();
        (lhs, rhs)
    }

    /// Checks `complex_mul_inplace` against plain `num_complex` `*=` for `len` elements.
    fn assert_matches_scalar(len: usize) {
        let (mut lhs, rhs) = inputs(len);
        let mut expected = lhs.clone();
        for (a, b) in expected.iter_mut().zip(rhs.iter()) {
            *a *= *b;
        }
        complex_mul_inplace(&mut lhs, &rhs);
        for (idx, (got, want)) in lhs.iter().zip(expected.iter()).enumerate() {
            // Relative tolerance: for |value| >= 16 one f32 ulp already exceeds
            // 1e-6, so a fixed absolute bound could reject a correctly-rounded result.
            let tol = 1e-6 * want.norm().max(1.0);
            assert!(
                (got.re - want.re).abs() <= tol,
                "len {}, index {}: re {} != {}",
                len,
                idx,
                got.re,
                want.re
            );
            assert!(
                (got.im - want.im).abs() <= tol,
                "len {}, index {}: im {} != {}",
                len,
                idx,
                got.im,
                want.im
            );
        }
    }

    #[test]
    fn complex_mul_matches_scalar() {
        assert_matches_scalar(37);
    }

    #[test]
    fn complex_mul_matches_scalar_around_simd_width() {
        // Covers empty input, tail-only lengths, exact multiples of the
        // 4-element SIMD step, and one-past-a-multiple.
        for len in [0, 1, 3, 4, 5, 7, 8, 9] {
            assert_matches_scalar(len);
        }
    }

    #[test]
    #[should_panic]
    fn complex_mul_rejects_length_mismatch() {
        let mut lhs = vec![Complex32::new(1.0, 0.0); 4];
        let rhs = vec![Complex32::new(1.0, 0.0); 3];
        complex_mul_inplace(&mut lhs, &rhs);
    }
}