use crate::Result;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Zip};
use scirs2_core::simd_ops::SimdUnifiedOps;
use scirs2_core::Complex;
use std::f64::consts::PI;
pub struct SimdComplexOps;
impl SimdComplexOps {
pub fn complex_multiply(
a: &ArrayView1<Complex<f64>>,
b: &ArrayView1<Complex<f64>>,
) -> Result<Array1<Complex<f64>>> {
if a.len() != b.len() {
return Err(crate::NumRs2Error::DimensionMismatch(format!(
"Array lengths must match: {} != {}",
a.len(),
b.len()
)));
}
let a_real: Array1<f64> = a.map(|c| c.re);
let a_imag: Array1<f64> = a.map(|c| c.im);
let b_real: Array1<f64> = b.map(|c| c.re);
let b_imag: Array1<f64> = b.map(|c| c.im);
let ac = f64::simd_mul(&a_real.view(), &b_real.view());
let bd = f64::simd_mul(&a_imag.view(), &b_imag.view());
let real_part = f64::simd_sub(&ac.view(), &bd.view());
let ad = f64::simd_mul(&a_real.view(), &b_imag.view());
let bc = f64::simd_mul(&a_imag.view(), &b_real.view());
let imag_part = f64::simd_add(&ad.view(), &bc.view());
let mut result = Array1::zeros(a.len());
Zip::from(&mut result)
.and(&real_part)
.and(&imag_part)
.for_each(|out, &re, &im| {
*out = Complex::new(re, im);
});
Ok(result)
}
pub fn complex_conjugate(a: &ArrayView1<Complex<f64>>) -> Array1<Complex<f64>> {
let real_part: Array1<f64> = a.map(|c| c.re);
let imag_part: Array1<f64> = a.map(|c| c.im);
let neg_one = Array1::from_elem(imag_part.len(), -1.0);
let neg_imag = f64::simd_mul(&imag_part.view(), &neg_one.view());
let mut result = Array1::zeros(a.len());
Zip::from(&mut result)
.and(&real_part)
.and(&neg_imag)
.for_each(|out, &re, &im| {
*out = Complex::new(re, im);
});
result
}
pub fn complex_magnitude(a: &ArrayView1<Complex<f64>>) -> Array1<f64> {
let real_part: Array1<f64> = a.map(|c| c.re);
let imag_part: Array1<f64> = a.map(|c| c.im);
let real_squared = f64::simd_mul(&real_part.view(), &real_part.view());
let imag_squared = f64::simd_mul(&imag_part.view(), &imag_part.view());
let sum_squared = f64::simd_add(&real_squared.view(), &imag_squared.view());
f64::simd_sqrt(&sum_squared.view())
}
pub fn complex_phase(a: &ArrayView1<Complex<f64>>) -> Array1<f64> {
use crate::optimized_ops::enhanced_math::parallel_atan2;
let real_part: Array1<f64> = a.map(|c| c.re);
let imag_part: Array1<f64> = a.map(|c| c.im);
parallel_atan2(&imag_part.view(), &real_part.view())
.unwrap_or_else(|_| Array1::zeros(a.len()))
}
pub fn complex_exp(a: &ArrayView1<Complex<f64>>) -> Array1<Complex<f64>> {
use crate::optimized_ops::enhanced_exp::parallel_exp;
use crate::optimized_ops::enhanced_math::{parallel_cos, parallel_sin};
let real_part: Array1<f64> = a.map(|c| c.re);
let imag_part: Array1<f64> = a.map(|c| c.im);
let exp_real = parallel_exp(&real_part.view());
let cos_imag = parallel_cos(&imag_part.view());
let sin_imag = parallel_sin(&imag_part.view());
let result_real = f64::simd_mul(&exp_real.view(), &cos_imag.view());
let result_imag = f64::simd_mul(&exp_real.view(), &sin_imag.view());
let mut result = Array1::zeros(a.len());
Zip::from(&mut result)
.and(&result_real)
.and(&result_imag)
.for_each(|out, &re, &im| {
*out = Complex::new(re, im);
});
result
}
}
/// Radix-2 fast Fourier transforms for power-of-two lengths, built on top of
/// [`SimdComplexOps`] and the crate's `parallel_cos` / `parallel_sin` helpers.
pub struct SimdFft;

impl SimdFft {
    /// Computes the forward FFT of `input` and returns it as a new array.
    ///
    /// # Errors
    ///
    /// Returns `NumRs2Error::InvalidOperation` when the input is empty or its
    /// length is not a power of two.
    pub fn fft(input: &ArrayView1<Complex<f64>>) -> Result<Array1<Complex<f64>>> {
        if !input.len().is_power_of_two() {
            return Err(crate::NumRs2Error::InvalidOperation(
                "FFT input length must be a power of 2".to_string(),
            ));
        }
        let mut spectrum = input.to_owned();
        Self::fft_recursive(&mut spectrum.view_mut());
        Ok(spectrum)
    }

    /// Transforms `data` in place by recursive decimation in time.
    ///
    /// The even- and odd-indexed samples are copied out, transformed
    /// recursively, and recombined with the twiddle factors
    /// `exp(-2*pi*i*k / len)` for `k` in `0..len / 2`.
    fn fft_recursive(data: &mut scirs2_core::ndarray::ArrayViewMut1<Complex<f64>>) {
        let len = data.len();
        if len <= 1 {
            return;
        }
        let half = len / 2;

        // Split into the two half-length sub-sequences.
        let mut evens: Array1<Complex<f64>> =
            data.iter().step_by(2).take(half).copied().collect();
        let mut odds: Array1<Complex<f64>> =
            data.iter().skip(1).step_by(2).take(half).copied().collect();

        Self::fft_recursive(&mut evens.view_mut());
        Self::fft_recursive(&mut odds.view_mut());

        // Twiddle factors for this level: cos(theta_k) + i * sin(theta_k).
        let step = -2.0 * PI / len as f64;
        let phases: Array1<f64> = (0..half).map(|k| step * k as f64).collect();
        let cosines = crate::optimized_ops::enhanced_math::parallel_cos(&phases.view());
        let sines = crate::optimized_ops::enhanced_math::parallel_sin(&phases.view());
        let twiddles: Array1<Complex<f64>> = cosines
            .iter()
            .zip(sines.iter())
            .map(|(&c, &s)| Complex::new(c, s))
            .collect();

        let rotated = SimdComplexOps::complex_multiply(&twiddles.view(), &odds.view()).expect(
            "complex multiplication of equal-length twiddle factors and odd array should succeed",
        );

        // Butterfly: the lower half receives the sum, the upper half the difference.
        for (k, (&e, &t)) in evens.iter().zip(rotated.iter()).enumerate() {
            data[k] = e + t;
            data[k + half] = e - t;
        }
    }

    /// Computes the inverse FFT as `conj(fft(conj(input))) / len`.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`SimdFft::fft`] when the length is zero or
    /// not a power of two.
    pub fn ifft(input: &ArrayView1<Complex<f64>>) -> Result<Array1<Complex<f64>>> {
        let len = input.len();
        let flipped = SimdComplexOps::complex_conjugate(input);
        let transformed = Self::fft(&flipped.view())?;
        let mut output = SimdComplexOps::complex_conjugate(&transformed.view());
        let inv_len = 1.0 / len as f64;
        for z in output.iter_mut() {
            *z *= inv_len;
        }
        Ok(output)
    }

    /// Computes the 2-D FFT by transforming every row and then every column.
    ///
    /// # Errors
    ///
    /// Returns `NumRs2Error::InvalidOperation` unless both dimensions are
    /// non-zero powers of two.
    pub fn fft2d(input: &ArrayView2<Complex<f64>>) -> Result<Array2<Complex<f64>>> {
        let (n_rows, n_cols) = input.dim();
        if !(n_rows.is_power_of_two() && n_cols.is_power_of_two()) {
            return Err(crate::NumRs2Error::InvalidOperation(
                "FFT2D dimensions must be powers of 2".to_string(),
            ));
        }
        let mut grid = input.to_owned();

        // Pass 1: transform along each row.
        for r in 0..n_rows {
            let transformed = Self::fft(&grid.row(r))?;
            grid.row_mut(r).assign(&transformed);
        }
        // Pass 2: transform along each column of the row-transformed data.
        for c in 0..n_cols {
            let transformed = Self::fft(&grid.column(c))?;
            grid.column_mut(c).assign(&transformed);
        }
        Ok(grid)
    }
}
/// FFT-based linear convolution of real 1-D signals.
pub struct SimdConvolution;

impl SimdConvolution {
    /// Returns the full linear convolution of `signal` and `kernel`.
    ///
    /// Both inputs are zero-padded to the next power of two that is at least
    /// `signal.len() + kernel.len() - 1`, transformed with [`SimdFft::fft`],
    /// multiplied point-wise, and transformed back. The real part of the
    /// inverse transform, truncated to `signal.len() + kernel.len() - 1`
    /// samples, is the result.
    ///
    /// If either input is empty the result is an empty array.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the FFT or the complex multiply.
    pub fn convolve(signal: &ArrayView1<f64>, kernel: &ArrayView1<f64>) -> Result<Array1<f64>> {
        let n = signal.len();
        let m = kernel.len();

        // Guard the empty cases up front: `n + m - 1` would underflow when
        // both are empty, and with exactly one empty input the padded buffer
        // can be shorter than the other input, making the slice below panic.
        if n == 0 || m == 0 {
            return Ok(Array1::zeros(0));
        }

        let output_len = n + m - 1;
        let fft_len = output_len.next_power_of_two();

        // Zero-pad both operands to the common FFT length.
        let mut padded_signal = Array1::zeros(fft_len);
        let mut padded_kernel = Array1::zeros(fft_len);
        padded_signal
            .slice_mut(scirs2_core::ndarray::s![..n])
            .assign(signal);
        padded_kernel
            .slice_mut(scirs2_core::ndarray::s![..m])
            .assign(kernel);

        // Promote the real samples to complex values with zero imaginary part.
        let complex_signal: Array1<Complex<f64>> = padded_signal.map(|&x| Complex::new(x, 0.0));
        let complex_kernel: Array1<Complex<f64>> = padded_kernel.map(|&x| Complex::new(x, 0.0));

        // Convolution theorem: multiply the spectra, then invert.
        let signal_fft = SimdFft::fft(&complex_signal.view())?;
        let kernel_fft = SimdFft::fft(&complex_kernel.view())?;
        let product = SimdComplexOps::complex_multiply(&signal_fft.view(), &kernel_fft.view())?;
        let result_complex = SimdFft::ifft(&product.view())?;

        // Keep the real part and drop the padding beyond the true output length.
        let result_real: Array1<f64> = result_complex.map(|c| c.re);
        Ok(result_real
            .slice(scirs2_core::ndarray::s![..output_len])
            .to_owned())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// (1 + 2i)(5 + 6i) = -7 + 16i and (3 + 4i)(7 + 8i) = -11 + 52i.
    #[test]
    fn test_complex_multiply() {
        let lhs = Array1::from_vec(vec![Complex::new(1.0, 2.0), Complex::new(3.0, 4.0)]);
        let rhs = Array1::from_vec(vec![Complex::new(5.0, 6.0), Complex::new(7.0, 8.0)]);
        let product = SimdComplexOps::complex_multiply(&lhs.view(), &rhs.view())
            .expect("complex_multiply should succeed for equal length arrays");

        let expected = [(-7.0, 16.0), (-11.0, 52.0)];
        for (idx, &(want_re, want_im)) in expected.iter().enumerate() {
            assert_relative_eq!(product[idx].re, want_re, epsilon = 1e-10);
            assert_relative_eq!(product[idx].im, want_im, epsilon = 1e-10);
        }
    }

    /// Checks the magnitudes of the 3-4-5 and 5-12-13 triples.
    #[test]
    fn test_complex_magnitude() {
        let values = Array1::from_vec(vec![Complex::new(3.0, 4.0), Complex::new(5.0, 12.0)]);
        let magnitudes = SimdComplexOps::complex_magnitude(&values.view());

        let expected = [5.0, 13.0];
        for (idx, &want) in expected.iter().enumerate() {
            assert_relative_eq!(magnitudes[idx], want, epsilon = 1e-10);
        }
    }

    /// A constant signal of four ones has all its energy in bin 0.
    #[test]
    fn test_fft_simple() {
        let ones = Array1::from_elem(4, Complex::new(1.0, 0.0));
        let spectrum =
            SimdFft::fft(&ones.view()).expect("FFT should succeed for power-of-2 length input");

        assert_relative_eq!(spectrum[0].re, 4.0, epsilon = 1e-10);
        assert_relative_eq!(spectrum[0].im, 0.0, epsilon = 1e-10);
        for bin in 1..ones.len() {
            assert_relative_eq!(spectrum[bin].norm(), 0.0, epsilon = 1e-10);
        }
    }

    /// Applying `ifft` to the output of `fft` must reproduce the input.
    #[test]
    fn test_fft_inverse() {
        let samples = Array1::from_vec(vec![
            Complex::new(1.0, 2.0),
            Complex::new(3.0, 4.0),
            Complex::new(5.0, 6.0),
            Complex::new(7.0, 8.0),
        ]);
        let forward =
            SimdFft::fft(&samples.view()).expect("FFT should succeed for power-of-2 length input");
        let roundtrip =
            SimdFft::ifft(&forward.view()).expect("IFFT should succeed for valid FFT output");

        for (idx, original) in samples.iter().enumerate() {
            assert_relative_eq!(roundtrip[idx].re, original.re, epsilon = 1e-10);
            assert_relative_eq!(roundtrip[idx].im, original.im, epsilon = 1e-10);
        }
    }
}