use crate::error::{FFTError, FFTResult};
use crate::fft::{fft, ifft};
use scirs2_core::ndarray::{Array2, ArrayD, Axis};
use scirs2_core::numeric::Complex64;
/// Computes the N-dimensional forward discrete Fourier transform of a complex array.
///
/// Every axis listed in `axes` is transformed in turn with a 1-D forward FFT;
/// when `axes` is `None`, all axes `0..x.ndim()` are transformed in ascending
/// order. Axes are processed in the order given, and an axis listed more than
/// once is transformed that many times. `x` is not modified; a transformed copy
/// with the same shape is returned.
///
/// # Errors
///
/// Returns `FFTError::ValueError` if any entry of `axes` is `>= x.ndim()`, and
/// propagates any error reported by the underlying 1-D transform.
pub fn fftn_complex(
    x: &ArrayD<Complex64>,
    axes: Option<&[usize]>,
) -> FFTResult<ArrayD<Complex64>> {
    transform_along_axes(x, axes, false)
}

/// Computes the N-dimensional inverse discrete Fourier transform of a complex array.
///
/// Mirrors [`fftn_complex`] but applies a 1-D inverse FFT along each selected
/// axis, so that `ifftn_complex(&fftn_complex(&x, axes)?, axes)` reproduces `x`
/// up to rounding. `x` is not modified.
///
/// # Errors
///
/// Returns `FFTError::ValueError` if any entry of `axes` is `>= x.ndim()`, and
/// propagates any error reported by the underlying 1-D inverse transform.
pub fn ifftn_complex(
    x: &ArrayD<Complex64>,
    axes: Option<&[usize]>,
) -> FFTResult<ArrayD<Complex64>> {
    transform_along_axes(x, axes, true)
}

/// Checks every requested axis against `ndim` and returns the axes to transform.
///
/// `None` expands to all axes in ascending order; `Some(a)` is returned as-is
/// (order and duplicates preserved) once every entry is in bounds. The first
/// out-of-bounds entry is the one reported in the error.
fn resolve_transform_axes(ndim: usize, axes: Option<&[usize]>) -> FFTResult<Vec<usize>> {
    match axes {
        None => Ok((0..ndim).collect()),
        Some(requested) => {
            if let Some(&ax) = requested.iter().find(|&&ax| ax >= ndim) {
                return Err(FFTError::ValueError(format!(
                    "axis {ax} out of bounds for array of ndim={ndim}"
                )));
            }
            Ok(requested.to_vec())
        }
    }
}

/// Shared driver for [`fftn_complex`] and [`ifftn_complex`]: validates `axes`
/// before doing any work, copies `x`, then runs the 1-D transform (forward or
/// inverse per `inverse`) along each selected axis in turn.
fn transform_along_axes(
    x: &ArrayD<Complex64>,
    axes: Option<&[usize]>,
    inverse: bool,
) -> FFTResult<ArrayD<Complex64>> {
    let axes_to_transform = resolve_transform_axes(x.ndim(), axes)?;
    let mut result = x.to_owned();
    for ax in axes_to_transform {
        apply_fft1d_along_axis(&mut result, ax, inverse)?;
    }
    Ok(result)
}
/// Moves the zero-frequency bin of a 2-D spectrum to the centre of the array.
///
/// Each axis is circularly rolled forward by `len / 2`, so the element at
/// `[0, 0]` ends up at `[rows / 2, cols / 2]` (the 2-D counterpart of NumPy's
/// `fftshift`). Use [`ifftshift2`] to undo the shift for any shape.
pub fn fftshift2(x: &Array2<Complex64>) -> Array2<Complex64> {
    let undo = false;
    shift2_impl(x, undo)
}
/// Reverses [`fftshift2`], moving a centred zero-frequency bin back to `[0, 0]`.
///
/// Each axis is circularly rolled forward by `len - len / 2`. For even lengths
/// this coincides with `fftshift2`; for odd lengths it does not, and only this
/// function restores the original ordering (the 2-D counterpart of NumPy's
/// `ifftshift`).
pub fn ifftshift2(x: &Array2<Complex64>) -> Array2<Complex64> {
    let undo = true;
    shift2_impl(x, undo)
}
/// Returns the sample frequencies for every axis of an N-dimensional FFT grid.
///
/// Entry `i` of the result is the 1-D frequency vector for an axis of length
/// `shape[i]` sampled with spacing `d[i]`, in standard FFT order (zero, then
/// positive, then negative frequencies). An axis of length zero yields an
/// empty vector.
///
/// # Errors
///
/// Returns `FFTError::ValueError` if `shape` and `d` differ in length, or if any
/// spacing is not a finite, strictly positive number.
pub fn fftfreq_nd(shape: &[usize], d: &[f64]) -> FFTResult<Vec<Vec<f64>>> {
    if shape.len() != d.len() {
        return Err(FFTError::ValueError(format!(
            "shape.len()={} must equal d.len()={}",
            shape.len(),
            d.len()
        )));
    }
    for (i, &spacing) in d.iter().enumerate() {
        // Written as a positive condition on purpose: `NaN <= 0.0` is false, so
        // a plain `spacing <= 0.0` check would accept NaN (yielding an all-NaN
        // grid). Infinite spacing would silently collapse every frequency to 0.
        if !(spacing.is_finite() && spacing > 0.0) {
            return Err(FFTError::ValueError(format!(
                "sample spacing d[{i}]={spacing} must be finite and > 0"
            )));
        }
    }
    shape
        .iter()
        .zip(d.iter())
        .map(|(&n, &spacing)| fftfreq_1d(n, spacing))
        .collect()
}
/// Applies a 1-D FFT (or, when `inverse` is true, an inverse FFT) in place to
/// every lane of `data` that runs along `axis`.
///
/// `axis` must be `< data.ndim()`; `data.shape()[axis]` panics otherwise. The
/// public callers in this module validate their axes before calling.
///
/// A zero-length axis is treated as a no-op: every lane is empty, so the
/// transformed lanes are empty too, and empty buffers are never handed to the
/// 1-D `fft`/`ifft`. This keeps the outcome independent of the lengths of the
/// other axes (when one of them is zero there are no lanes at all).
///
/// # Errors
///
/// Propagates any error returned by the 1-D `fft`/`ifft`.
fn apply_fft1d_along_axis(
    data: &mut ArrayD<Complex64>,
    axis: usize,
    inverse: bool,
) -> FFTResult<()> {
    let axis_len = data.shape()[axis];
    if axis_len == 0 {
        return Ok(());
    }
    // Lanes may be strided, so each one is gathered into a contiguous scratch
    // buffer; the buffer is allocated once and reused for every lane.
    let mut buf = vec![Complex64::new(0.0, 0.0); axis_len];
    for mut lane in data.lanes_mut(Axis(axis)) {
        for (dst, &src) in buf.iter_mut().zip(lane.iter()) {
            *dst = src;
        }
        let transformed = if inverse {
            ifft(&buf, Some(axis_len))?
        } else {
            fft(&buf, Some(axis_len))?
        };
        for (dst, &src) in lane.iter_mut().zip(transformed.iter()) {
            *dst = src;
        }
    }
    Ok(())
}
/// Circularly rolls both axes of a 2-D array; shared core of `fftshift2` and
/// `ifftshift2`.
///
/// The forward shift rolls each axis by `len / 2`; the inverse (`inverse ==
/// true`) rolls it by `len - len / 2`, which undoes the forward shift for odd
/// lengths as well. Implemented as a gather: each output cell reads the input
/// cell that the roll moves into it.
fn shift2_impl(x: &Array2<Complex64>, inverse: bool) -> Array2<Complex64> {
    let (n_rows, n_cols) = x.dim();
    // Rolling forward by `s` means out[i] = x[(i + len - s) % len]; these are
    // the `len - s` source offsets for each direction.
    let (src_row_offset, src_col_offset) = if inverse {
        (n_rows / 2, n_cols / 2)
    } else {
        (n_rows - n_rows / 2, n_cols - n_cols / 2)
    };
    // `from_shape_fn` never calls the closure for a shape with zero elements,
    // so the modulo operations below cannot divide by zero.
    Array2::from_shape_fn((n_rows, n_cols), |(r, c)| {
        x[[(r + src_row_offset) % n_rows, (c + src_col_offset) % n_cols]]
    })
}
/// Returns the sample frequencies of a length-`n` DFT with sample spacing `d`.
///
/// The layout matches NumPy's `fftfreq`:
/// - Bins `0..=(n - 1) / 2` hold the non-negative frequencies `k / (n * d)`.
/// - The remaining bins hold the negative frequencies, ending at `-1 / (n * d)`.
/// - For even `n`, the Nyquist bin is reported as negative (`-1 / (2 * d)`).
/// - `n == 0` yields an empty vector.
///
/// The caller `fftfreq_nd` rejects non-positive spacing before calling. This
/// function itself never fails; the `FFTResult` return type lets callers
/// collect results with `?`/`collect`.
fn fftfreq_1d(n: usize, d: f64) -> FFTResult<Vec<f64>> {
    if n == 0 {
        return Ok(Vec::new());
    }
    let scale = 1.0 / (n as f64 * d);
    let len = n as i64;
    // Last bin with a non-negative frequency: n/2 - 1 for even n and n/2 for
    // odd n, both of which equal (n - 1) / 2 in integer division.
    let last_non_negative = (len - 1) / 2;
    let freqs: Vec<f64> = (0..len)
        .map(|i| {
            let k = if i <= last_non_negative { i } else { i - len };
            k as f64 * scale
        })
        .collect();
    Ok(freqs)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the N-D FFT helpers, shifts and frequency grids.

    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::IxDyn;
    use std::f64::consts::PI;

    /// Builds a deterministic complex array of `shape` whose `i`-th element in
    /// logical order has real part `i` and imaginary part `-i / 2`.
    fn make_complex_array(shape: &[usize]) -> ArrayD<Complex64> {
        let n: usize = shape.iter().product();
        let data: Vec<Complex64> = (0..n)
            .map(|i| Complex64::new(i as f64, -(i as f64) * 0.5))
            .collect();
        ArrayD::from_shape_vec(IxDyn(shape), data).expect("shape ok")
    }

    #[test]
    fn test_fftn_ifftn_roundtrip_1d() {
        let x = make_complex_array(&[16]);
        let s = fftn_complex(&x, None).expect("fftn");
        let r = ifftn_complex(&s, None).expect("ifftn");
        for (a, b) in x.iter().zip(r.iter()) {
            assert_relative_eq!(a.re, b.re, epsilon = 1e-10);
            assert_relative_eq!(a.im, b.im, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_fftn_ifftn_roundtrip_2d() {
        let x = make_complex_array(&[4, 8]);
        let s = fftn_complex(&x, None).expect("fftn 2d");
        let r = ifftn_complex(&s, None).expect("ifftn 2d");
        for (a, b) in x.iter().zip(r.iter()) {
            assert_relative_eq!(a.re, b.re, epsilon = 1e-10);
            assert_relative_eq!(a.im, b.im, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_fftn_ifftn_roundtrip_3d() {
        let x = make_complex_array(&[2, 3, 4]);
        let s = fftn_complex(&x, None).expect("fftn 3d");
        let r = ifftn_complex(&s, None).expect("ifftn 3d");
        for (a, b) in x.iter().zip(r.iter()) {
            assert_relative_eq!(a.re, b.re, epsilon = 1e-10);
            assert_relative_eq!(a.im, b.im, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_fftn_partial_axes() {
        let x = make_complex_array(&[4, 8]);
        let s1 = fftn_complex(&x, Some(&[1])).expect("fftn axis 1");
        let r1 = ifftn_complex(&s1, Some(&[1])).expect("ifftn axis 1");
        for (a, b) in x.iter().zip(r1.iter()) {
            assert_relative_eq!(a.re, b.re, epsilon = 1e-10);
            assert_relative_eq!(a.im, b.im, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_fftn_out_of_bounds_axis() {
        let x = make_complex_array(&[4, 8]);
        assert!(fftn_complex(&x, Some(&[2])).is_err());
        assert!(ifftn_complex(&x, Some(&[5])).is_err());
    }

    #[test]
    fn test_fftn_shape_preserved() {
        let x = make_complex_array(&[3, 5, 7]);
        let s = fftn_complex(&x, None).expect("fftn");
        assert_eq!(s.shape(), x.shape());
    }

    #[test]
    fn test_fftshift2_roundtrip_even() {
        let rows = 4;
        let cols = 6;
        let data: Vec<Complex64> = (0..(rows * cols) as i32)
            .map(|i| Complex64::new(i as f64, 0.0))
            .collect();
        let x = Array2::from_shape_vec((rows, cols), data).expect("shape");
        let shifted = fftshift2(&x);
        let recovered = ifftshift2(&shifted);
        for r in 0..rows {
            for c in 0..cols {
                assert_relative_eq!(x[[r, c]].re, recovered[[r, c]].re, epsilon = 1e-12);
                assert_relative_eq!(x[[r, c]].im, recovered[[r, c]].im, epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn test_fftshift2_roundtrip_odd() {
        let rows = 5;
        let cols = 7;
        let data: Vec<Complex64> = (0..(rows * cols) as i32)
            .map(|i| Complex64::new(i as f64, i as f64 * 0.1))
            .collect();
        let x = Array2::from_shape_vec((rows, cols), data).expect("shape");
        let shifted = fftshift2(&x);
        let recovered = ifftshift2(&shifted);
        for r in 0..rows {
            for c in 0..cols {
                assert_relative_eq!(x[[r, c]].re, recovered[[r, c]].re, epsilon = 1e-12);
                assert_relative_eq!(x[[r, c]].im, recovered[[r, c]].im, epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn test_fftshift2_dc_to_centre() {
        let mut data = Array2::<Complex64>::zeros((4, 4));
        data[[0, 0]] = Complex64::new(1.0, 0.0);
        let shifted = fftshift2(&data);
        assert_relative_eq!(shifted[[2, 2]].re, 1.0, epsilon = 1e-12);
        assert_relative_eq!(shifted[[0, 0]].re, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_ifftshift2_dc_back() {
        let mut data = Array2::<Complex64>::zeros((4, 4));
        data[[0, 0]] = Complex64::new(1.0, 0.0);
        let shifted = fftshift2(&data);
        let recovered = ifftshift2(&shifted);
        assert_relative_eq!(recovered[[0, 0]].re, 1.0, epsilon = 1e-12);
    }

    #[test]
    fn test_fftfreq_nd_basic() {
        let freqs = fftfreq_nd(&[4, 8], &[1.0, 1.0]).expect("fftfreq_nd");
        assert_eq!(freqs.len(), 2);
        assert_eq!(freqs[0].len(), 4);
        assert_eq!(freqs[1].len(), 8);
        assert_relative_eq!(freqs[0][0], 0.0, epsilon = 1e-15);
        assert_relative_eq!(freqs[1][0], 0.0, epsilon = 1e-15);
    }

    #[test]
    fn test_fftfreq_nd_matches_1d_fftfreq() {
        use crate::helper::fftfreq;
        let n = 16;
        let d = 0.5;
        let nd_freqs = fftfreq_nd(&[n], &[d]).expect("nd");
        let scalar_freqs = fftfreq(n, d).expect("1d");
        assert_eq!(nd_freqs[0].len(), scalar_freqs.len());
        for (a, b) in nd_freqs[0].iter().zip(scalar_freqs.iter()) {
            assert_relative_eq!(*a, *b, epsilon = 1e-14);
        }
    }

    #[test]
    fn test_fftfreq_nd_spacing() {
        let f1 = fftfreq_nd(&[8], &[1.0]).expect("d=1");
        let f2 = fftfreq_nd(&[8], &[0.5]).expect("d=0.5");
        assert_relative_eq!(f1[0][3], 3.0 / 8.0, epsilon = 1e-14);
        assert_relative_eq!(f2[0][3], 3.0 / 4.0, epsilon = 1e-14);
    }

    #[test]
    fn test_fftfreq_nd_mismatch_error() {
        assert!(fftfreq_nd(&[4, 8], &[1.0]).is_err());
        assert!(fftfreq_nd(&[4], &[0.0]).is_err());
        assert!(fftfreq_nd(&[4], &[-1.0]).is_err());
    }

    #[test]
    fn test_fftfreq_nd_empty_axis() {
        let freqs = fftfreq_nd(&[0, 4], &[1.0, 1.0]).expect("empty axis ok");
        assert_eq!(freqs[0].len(), 0);
        assert_eq!(freqs[1].len(), 4);
    }

    #[test]
    fn test_fftshift2_known_pattern() {
        let rows = 4;
        let cols = 4;
        let mut x = Array2::<Complex64>::zeros((rows, cols));
        x[[0, 0]] = Complex64::new(1.0, 0.0);
        x[[0, 2]] = Complex64::new(2.0, 0.0);
        x[[2, 0]] = Complex64::new(3.0, 0.0);
        x[[2, 2]] = Complex64::new(4.0, 0.0);
        let shifted = fftshift2(&x);
        assert_relative_eq!(shifted[[2, 2]].re, 1.0, epsilon = 1e-12);
        assert_relative_eq!(shifted[[2, 0]].re, 2.0, epsilon = 1e-12);
        assert_relative_eq!(shifted[[0, 2]].re, 3.0, epsilon = 1e-12);
        assert_relative_eq!(shifted[[0, 0]].re, 4.0, epsilon = 1e-12);
    }

    #[test]
    fn test_fftn_then_shift_preserves_energy() {
        let n = 8;
        let data: Vec<Complex64> = (0..n * n)
            .map(|k| {
                let r = k / n;
                let c = k % n;
                let re = (2.0 * PI * r as f64 / n as f64).cos()
                    * (2.0 * PI * c as f64 / n as f64).cos();
                Complex64::new(re, 0.0)
            })
            .collect();
        let x = ArrayD::from_shape_vec(IxDyn(&[n, n]), data).expect("shape");
        let spec = fftn_complex(&x, None).expect("fftn");
        let energy_x: f64 = x.iter().map(|c| c.norm_sqr()).sum();
        let energy_s: f64 = spec.iter().map(|c| c.norm_sqr()).sum();
        let n2 = (n * n) as f64;
        // Parseval for the unnormalised forward DFT: sum|X|^2 = N * sum|x|^2.
        assert_relative_eq!(energy_s, n2 * energy_x, epsilon = 1e-8 * energy_s.max(1.0));

        // fftshift2 only permutes bins, so it must leave the total energy unchanged.
        let bins: Vec<Complex64> = spec.iter().copied().collect();
        let spec2 = Array2::from_shape_vec((n, n), bins).expect("spectrum is n x n");
        let shifted = fftshift2(&spec2);
        let energy_shifted: f64 = shifted.iter().map(|c| c.norm_sqr()).sum();
        assert_relative_eq!(energy_shifted, energy_s, epsilon = 1e-10 * energy_s.max(1.0));
    }
}