use crate::shift::{fftshift_axis, ifftshift_axis};
use ndarray::{Array, ArrayViewMut, Axis, Dimension};
use num_complex::Complex32;
use rustfft::{num_complex::Complex, FftPlanner};
use std::sync::Arc;
/// Inverse-transforms `a` in place along the two axes given in `axes`.
///
/// The work is done in three passes, each visiting the axes in tuple order:
/// `ifftshift_axis` on every axis, then `ifft_axis` on every axis, then
/// `fftshift_axis` on every axis. `ifft_axis` scales each transformed axis by
/// `1 / len`, so the overall scale is the reciprocal of the product of the
/// two axis lengths.
///
/// NOTE(review): the shift helpers come from `crate::shift`; presumably they
/// move the zero-frequency sample between the array centre and index 0 —
/// confirm against that module.
pub fn ifft2_inplace<D: Dimension>(a: &mut Array<Complex32, D>, axes: (usize, usize)) {
    let order = [axes.0, axes.1];

    // Pass 1: undo the centring on every axis before transforming.
    for &ax in &order {
        ifftshift_axis(a, ax);
    }
    // Pass 2: 1-D inverse transform along each axis in turn.
    for &ax in &order {
        ifft_axis(a.view_mut(), ax);
    }
    // Pass 3: re-centre the result on every axis.
    for &ax in &order {
        fftshift_axis(a, ax);
    }
}
/// Inverse-transforms `a` in place along the three axes given in `axes`.
///
/// Mirrors the 2-D variant: every axis goes through `ifftshift_axis`, then
/// every axis is inverse-transformed with `ifft_axis`, then every axis goes
/// through `fftshift_axis`. Within each pass the axes are visited in tuple
/// order. Because `ifft_axis` scales by `1 / len` per axis, the overall scale
/// is the reciprocal of the product of the three axis lengths.
///
/// NOTE(review): the exact semantics of the shift helpers are defined in
/// `crate::shift` and are not visible here — confirm against that module.
pub fn ifft3_inplace<D: Dimension>(a: &mut Array<Complex32, D>, axes: (usize, usize, usize)) {
    let order = [axes.0, axes.1, axes.2];

    // Pass 1: undo the centring on every axis before transforming.
    for &ax in &order {
        ifftshift_axis(a, ax);
    }
    // Pass 2: 1-D inverse transform along each axis in turn.
    for &ax in &order {
        ifft_axis(a.view_mut(), ax);
    }
    // Pass 3: re-centre the result on every axis.
    for &ax in &order {
        fftshift_axis(a, ax);
    }
}
/// Inverse-transforms `a` in place along a single `axis`.
///
/// The axis is passed through `ifftshift_axis`, inverse-transformed by
/// `ifft_axis` (which scales by `1 / len`), and finally passed through
/// `fftshift_axis`. All other axes are left untouched by the transform.
///
/// NOTE(review): the shift helpers come from `crate::shift`; their exact
/// behaviour is not visible here — confirm against that module.
pub fn ifft1_inplace<D: Dimension>(a: &mut Array<Complex32, D>, axis: usize) {
    // Shift before the transform ...
    ifftshift_axis(a, axis);
    {
        // ... transform through a mutable view of the whole array ...
        let whole = a.view_mut();
        ifft_axis(whole, axis);
    }
    // ... and shift back afterwards.
    fftshift_axis(a, axis);
}
/// Applies a 1-D inverse FFT to every lane of `a` along `axis`, in place,
/// scaling the result by `1 / len` where `len` is the length of that axis.
///
/// Axes of length 0 or 1 are left untouched. A fresh planner, scratch buffer
/// and staging buffer are created on each call and reused across all lanes.
///
/// Panics if `axis` is out of bounds for `a` (via `len_of`).
fn ifft_axis<D: Dimension>(mut a: ArrayViewMut<Complex32, D>, axis: usize) {
    let len = a.len_of(Axis(axis));
    if len < 2 {
        // Nothing to transform for an empty or single-sample axis.
        return;
    }

    let mut planner = FftPlanner::<f32>::new();
    let plan: Arc<dyn rustfft::Fft<f32>> = planner.plan_fft_inverse(len);

    let zero: Complex<f32> = Complex::new(0.0, 0.0);
    let mut work: Vec<Complex<f32>> = vec![zero; plan.get_inplace_scratch_len()];
    let mut staging: Vec<Complex<f32>> = vec![zero; len];
    let scale = 1.0f32 / (len as f32);

    for mut lane in a.lanes_mut(Axis(axis)) {
        // The transform operates on a slice, so each lane is staged through a
        // contiguous buffer rather than processed directly in the array.
        for (dst, src) in staging.iter_mut().zip(lane.iter()) {
            *dst = *src;
        }

        plan.process_with_scratch(&mut staging, &mut work);

        // Write back with the 1/len normalisation applied.
        for (dst, src) in lane.iter_mut().zip(staging.iter()) {
            *dst = *src * scale;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array2, Array3};

    /// A unit sample placed at `[n/2, n/2]` must come out with magnitude
    /// `1 / n^2` in every cell after the 2-D inverse transform.
    #[test]
    fn ifft2_impulse_gives_constant() {
        let n = 8;
        let mut k = Array2::<Complex32>::zeros((n, n));
        k[[n / 2, n / 2]] = Complex32::new(1.0, 0.0);

        ifft2_inplace(&mut k, (0, 1));

        let expected = 1.0 / (n as f32 * n as f32);
        for v in &k {
            assert!(
                (v.norm() - expected).abs() < 1e-6,
                "expected |{}| ~= {}, got {}",
                v,
                expected,
                v.norm()
            );
        }
    }

    /// Despite the name, this is an energy check rather than a forward/inverse
    /// round trip: the summed squared magnitude after the transform must equal
    /// the summed squared magnitude before it divided by `n * n`.
    #[test]
    fn ifft2_roundtrip() {
        let n = 16;
        let mut k = Array2::<Complex32>::zeros((n, n));
        k[[0, 0]] = Complex32::new(1.0, 0.0);
        k[[1, 2]] = Complex32::new(0.5, -0.25);

        // Sum of |c|^2 over all cells, in the array's iteration order.
        let energy =
            |arr: &Array2<Complex32>| -> f32 { arr.iter().map(|c| c.norm_sqr()).sum() };

        let before_sum = energy(&k);
        ifft2_inplace(&mut k, (0, 1));
        let after_sum = energy(&k);

        let n2 = (n * n) as f32;
        let expected = before_sum / n2;
        assert!(
            (after_sum - expected).abs() < 1e-5,
            "Parseval mismatch: before={before_sum}, after={after_sum}, expected={expected}"
        );
    }

    /// 3-D counterpart of the impulse test: a unit sample at the centre index
    /// must yield magnitude `1 / (nz * ny * nx)` everywhere.
    #[test]
    fn ifft3_impulse_gives_constant() {
        let (nz, ny, nx) = (4, 8, 8);
        let mut k = Array3::<Complex32>::zeros((nz, ny, nx));
        k[[nz / 2, ny / 2, nx / 2]] = Complex32::new(1.0, 0.0);

        ifft3_inplace(&mut k, (0, 1, 2));

        let expected = 1.0 / (nz as f32 * ny as f32 * nx as f32);
        for v in &k {
            assert!(
                (v.norm() - expected).abs() < 1e-6,
                "expected |{}| ~= {}, got {}",
                v,
                expected,
                v.norm()
            );
        }
    }

    /// The 3-D transform over axes (0, 1, 2) must agree, within `1e-4`, with a
    /// 1-D transform over axis 0 followed by a 2-D transform over axes (1, 2).
    #[test]
    fn ifft1_decouples_from_2d() {
        let (nz, ny, nx) = (4, 6, 6);

        // Each sample is the product of a factor that depends only on z and a
        // factor that depends only on (y, x).
        let k_full: Array3<Complex32> = Array3::from_shape_fn((nz, ny, nx), |(z, y, x)| {
            let along_z = Complex32::new((z as f32 + 1.0) * 0.5, 0.0);
            let in_plane = Complex32::new(y as f32 - 2.0, (x as f32) * 0.25);
            along_z * in_plane
        });

        let mut joint = k_full.clone();
        ifft3_inplace(&mut joint, (0, 1, 2));

        let mut split = k_full.clone();
        ifft1_inplace(&mut split, 0);
        ifft2_inplace(&mut split, (1, 2));

        for ((z, y, x), lhs) in joint.indexed_iter() {
            let err = (*lhs - split[[z, y, x]]).norm();
            assert!(err < 1e-4, "decouple mismatch at ({z},{y},{x}): {err}");
        }
    }
}