use num_complex::Complex;
use ferray_core::Array;
use ferray_core::dimension::{Dimension, IxDyn};
use ferray_core::error::{FerrayError, FerrayResult};
use crate::nd::{fft_1d_along_axis, fft_along_axes};
use crate::norm::FftNorm;
/// Copies every element of `a` into a newly allocated flat `Vec`, in the
/// order produced by the array's iterator.
fn to_complex_flat<D: Dimension>(a: &Array<Complex<f64>, D>) -> Vec<Complex<f64>> {
    let mut flat = Vec::new();
    flat.extend(a.iter().copied());
    flat
}
/// Picks the axis a 1-D transform should run along.
///
/// An explicit `axis` is returned unchanged when it is smaller than `ndim`;
/// otherwise an axis-out-of-bounds error is produced. When `axis` is `None`
/// the last axis (`ndim - 1`) is selected, and a 0-dimensional input is
/// rejected with an invalid-value error.
fn resolve_axis(ndim: usize, axis: Option<usize>) -> FerrayResult<usize> {
    if let Some(requested) = axis {
        return if requested < ndim {
            Ok(requested)
        } else {
            Err(FerrayError::axis_out_of_bounds(requested, ndim))
        };
    }

    // `checked_sub` yields `None` exactly when `ndim == 0`, i.e. when there is
    // no last axis to fall back to.
    match ndim.checked_sub(1) {
        Some(last) => Ok(last),
        None => Err(FerrayError::invalid_value(
            "cannot compute FFT on a 0-dimensional array",
        )),
    }
}
/// Resolves the list of axes for a multi-dimensional transform.
///
/// With `axes == None` every axis `0..ndim` is returned. Otherwise the given
/// axes are returned as-is after checking that each one is smaller than
/// `ndim`; the first offending axis is reported as an out-of-bounds error.
fn resolve_axes(ndim: usize, axes: Option<&[usize]>) -> FerrayResult<Vec<usize>> {
    match axes {
        None => Ok((0..ndim).collect()),
        Some(requested) => {
            let first_invalid = requested.iter().copied().find(|&axis| axis >= ndim);
            match first_invalid {
                Some(bad) => Err(FerrayError::axis_out_of_bounds(bad, ndim)),
                None => Ok(requested.to_vec()),
            }
        }
    }
}
/// Determines the transform length to use for each entry of `axes`.
///
/// * `input_shape` – shape of the array being transformed.
/// * `axes` – axes that will be transformed, in order.
/// * `s` – optional explicit lengths, one per entry of `axes`.
///
/// When `s` is given, its entries are returned (wrapped in `Some`) in the
/// same order. When `s` is `None`, each axis keeps its current length taken
/// from `input_shape`.
///
/// # Errors
///
/// * invalid-value error if `s` is given and its length differs from
///   `axes.len()`;
/// * axis-out-of-bounds error if `s` is `None` and an axis does not index
///   `input_shape`.
fn resolve_shapes(
    input_shape: &[usize],
    axes: &[usize],
    s: Option<&[usize]>,
) -> FerrayResult<Vec<Option<usize>>> {
    match s {
        Some(sizes) => {
            if sizes.len() != axes.len() {
                return Err(FerrayError::invalid_value(format!(
                    "shape parameter length {} does not match axes length {}",
                    sizes.len(),
                    axes.len(),
                )));
            }
            Ok(sizes.iter().copied().map(Some).collect())
        }
        // Use a checked lookup: an axis that was never validated by the caller
        // must surface as an error rather than an index-out-of-bounds panic.
        None => axes
            .iter()
            .map(|&ax| {
                input_shape
                    .get(ax)
                    .copied()
                    .map(Some)
                    .ok_or_else(|| FerrayError::axis_out_of_bounds(ax, input_shape.len()))
            })
            .collect(),
    }
}
/// One-dimensional forward FFT of a complex array along a single axis.
///
/// * `a` – input array; its elements are copied into a flat buffer first.
/// * `n` – optional transform length along the axis, forwarded to
///   `fft_1d_along_axis`.
/// * `axis` – axis to transform; `None` selects the last axis.
/// * `norm` – normalization mode, forwarded unchanged.
///
/// The result is returned as a dynamically-dimensioned array whose shape is
/// the one reported by `fft_1d_along_axis`.
///
/// # Errors
///
/// Fails if `axis` is out of bounds, if the array is 0-dimensional, or if
/// the underlying transform or the final array construction fails.
pub fn fft<D: Dimension>(
    a: &Array<Complex<f64>, D>,
    n: Option<usize>,
    axis: Option<usize>,
    norm: FftNorm,
) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
    let dims = a.shape().to_vec();
    let target_axis = resolve_axis(dims.len(), axis)?;
    let flat = to_complex_flat(a);
    let (out_shape, spectrum) = fft_1d_along_axis(&flat, &dims, target_axis, n, false, norm)?;
    Array::from_vec(IxDyn::new(&out_shape), spectrum)
}
/// One-dimensional inverse FFT of a complex array along a single axis.
///
/// * `a` – input array; its elements are copied into a flat buffer first.
/// * `n` – optional transform length along the axis, forwarded to
///   `fft_1d_along_axis`.
/// * `axis` – axis to transform; `None` selects the last axis.
/// * `norm` – normalization mode, forwarded unchanged.
///
/// The result is returned as a dynamically-dimensioned array whose shape is
/// the one reported by `fft_1d_along_axis`.
///
/// # Errors
///
/// Fails if `axis` is out of bounds, if the array is 0-dimensional, or if
/// the underlying transform or the final array construction fails.
pub fn ifft<D: Dimension>(
    a: &Array<Complex<f64>, D>,
    n: Option<usize>,
    axis: Option<usize>,
    norm: FftNorm,
) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
    let dims = a.shape().to_vec();
    let target_axis = resolve_axis(dims.len(), axis)?;
    let flat = to_complex_flat(a);
    let (out_shape, samples) = fft_1d_along_axis(&flat, &dims, target_axis, n, true, norm)?;
    Array::from_vec(IxDyn::new(&out_shape), samples)
}
pub fn fft2<D: Dimension>(
a: &Array<Complex<f64>, D>,
s: Option<&[usize]>,
axes: Option<&[usize]>,
norm: FftNorm,
) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
let ndim = a.shape().len();
let axes = match axes {
Some(ax) => ax.to_vec(),
None => {
if ndim < 2 {
return Err(FerrayError::invalid_value(
"fft2 requires at least 2 dimensions",
));
}
vec![ndim - 2, ndim - 1]
}
};
fftn_impl(a, s, &axes, false, norm)
}
pub fn ifft2<D: Dimension>(
a: &Array<Complex<f64>, D>,
s: Option<&[usize]>,
axes: Option<&[usize]>,
norm: FftNorm,
) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
let ndim = a.shape().len();
let axes = match axes {
Some(ax) => ax.to_vec(),
None => {
if ndim < 2 {
return Err(FerrayError::invalid_value(
"ifft2 requires at least 2 dimensions",
));
}
vec![ndim - 2, ndim - 1]
}
};
fftn_impl(a, s, &axes, true, norm)
}
/// N-dimensional forward FFT of a complex array.
///
/// * `s` – optional transform length per transformed axis; must match the
///   number of axes when given.
/// * `axes` – axes to transform; `None` transforms every axis of `a`.
/// * `norm` – normalization mode, forwarded unchanged.
///
/// # Errors
///
/// Fails if any requested axis is out of bounds, if `s` and the axes differ
/// in length, or if the underlying transform fails.
pub fn fftn<D: Dimension>(
    a: &Array<Complex<f64>, D>,
    s: Option<&[usize]>,
    axes: Option<&[usize]>,
    norm: FftNorm,
) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
    let ndim = a.shape().len();
    resolve_axes(ndim, axes).and_then(|resolved| fftn_impl(a, s, &resolved, false, norm))
}
/// N-dimensional inverse FFT of a complex array.
///
/// * `s` – optional transform length per transformed axis; must match the
///   number of axes when given.
/// * `axes` – axes to transform; `None` transforms every axis of `a`.
/// * `norm` – normalization mode, forwarded unchanged.
///
/// # Errors
///
/// Fails if any requested axis is out of bounds, if `s` and the axes differ
/// in length, or if the underlying transform fails.
pub fn ifftn<D: Dimension>(
    a: &Array<Complex<f64>, D>,
    s: Option<&[usize]>,
    axes: Option<&[usize]>,
    norm: FftNorm,
) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
    let ndim = a.shape().len();
    resolve_axes(ndim, axes).and_then(|resolved| fftn_impl(a, s, &resolved, true, norm))
}
/// Shared implementation behind the multi-axis transforms.
///
/// Resolves the per-axis lengths from `s`, flattens `a`, pairs each axis with
/// its length and hands everything to `fft_along_axes`. `inverse` selects the
/// transform direction and `norm` is forwarded unchanged. The returned array
/// takes the shape reported by `fft_along_axes`.
fn fftn_impl<D: Dimension>(
    a: &Array<Complex<f64>, D>,
    s: Option<&[usize]>,
    axes: &[usize],
    inverse: bool,
    norm: FftNorm,
) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
    let dims = a.shape().to_vec();
    let lengths = resolve_shapes(&dims, axes, s)?;
    let flat = to_complex_flat(a);

    // One `(axis, length)` entry per transformed axis, in the caller's order.
    let plan: Vec<(usize, Option<usize>)> = axes
        .iter()
        .zip(lengths)
        .map(|(&axis, len)| (axis, len))
        .collect();

    let (out_shape, transformed) = fft_along_axes(&flat, &dims, &plan, inverse, norm)?;
    Array::from_vec(IxDyn::new(&out_shape), transformed)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::Ix1;

    /// Shorthand for building a complex sample from its two components.
    fn c(re: f64, im: f64) -> Complex<f64> {
        Complex::new(re, im)
    }

    /// Wraps `data` in a one-dimensional array whose length matches the vector.
    fn make_1d(data: Vec<Complex<f64>>) -> Array<Complex<f64>, Ix1> {
        let len = data.len();
        Array::from_vec(Ix1::new([len]), data).unwrap()
    }

    /// Asserts that paired samples agree within `tol` on both components.
    fn assert_pairs_close(expected: &[Complex<f64>], actual: &[Complex<f64>], tol: f64) {
        for (want, got) in expected.iter().zip(actual) {
            assert!((want.re - got.re).abs() < tol);
            assert!((want.im - got.im).abs() < tol);
        }
    }

    /// A unit impulse must transform to a spectrum of all ones.
    #[test]
    fn fft_impulse() {
        let impulse = make_1d(vec![c(1.0, 0.0), c(0.0, 0.0), c(0.0, 0.0), c(0.0, 0.0)]);
        let spectrum = fft(&impulse, None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[4]);
        spectrum.iter().for_each(|bin| {
            assert!((bin.re - 1.0).abs() < 1e-12);
            assert!(bin.im.abs() < 1e-12);
        });
    }

    /// A constant signal must put all of its energy in the first bin.
    #[test]
    fn fft_constant() {
        let flat_signal = make_1d(vec![c(1.0, 0.0); 4]);
        let bins: Vec<_> = fft(&flat_signal, None, None, FftNorm::Backward)
            .unwrap()
            .iter()
            .copied()
            .collect();
        let (dc, rest) = bins.split_first().unwrap();
        assert!((dc.re - 4.0).abs() < 1e-12);
        for bin in rest {
            assert!(bin.re.abs() < 1e-12);
            assert!(bin.im.abs() < 1e-12);
        }
    }

    /// `ifft(fft(x))` must reproduce `x` for a power-of-two length.
    #[test]
    fn fft_ifft_roundtrip() {
        let samples = vec![
            c(1.0, 2.0),
            c(-1.0, 0.5),
            c(3.0, -1.0),
            c(0.0, 0.0),
            c(-2.5, 1.5),
            c(0.7, -0.3),
            c(1.2, 0.8),
            c(-0.4, 2.1),
        ];
        let spectrum = fft(&make_1d(samples.clone()), None, None, FftNorm::Backward).unwrap();
        let recovered = ifft(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (orig, rec) in samples.iter().zip(recovered.iter()) {
            let re_err = (orig.re - rec.re).abs();
            let im_err = (orig.im - rec.im).abs();
            assert!(re_err < 1e-10, "re mismatch: {} vs {}", orig.re, rec.re);
            assert!(im_err < 1e-10, "im mismatch: {} vs {}", orig.im, rec.im);
        }
    }

    /// Requesting `n` larger than the input must grow the output to `n`.
    #[test]
    fn fft_with_n_padding() {
        let pair = make_1d(vec![c(1.0, 0.0); 2]);
        let padded = fft(&pair, Some(4), None, FftNorm::Backward).unwrap();
        assert_eq!(padded.shape(), &[4]);
        let dc = padded.iter().copied().next().unwrap();
        assert!((dc.re - 2.0).abs() < 1e-12);
    }

    /// Requesting `n` smaller than the input must transform only the first `n` samples.
    #[test]
    fn fft_with_n_truncation() {
        let ramp = make_1d(vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)]);
        let truncated = fft(&ramp, Some(2), None, FftNorm::Backward).unwrap();
        assert_eq!(truncated.shape(), &[2]);
        let bins: Vec<_> = truncated.iter().copied().collect();
        assert!((bins[0].re - 3.0).abs() < 1e-12);
        assert!((bins[1].re + 1.0).abs() < 1e-12);
    }

    /// The round trip must also hold for a length that is not a power of two.
    #[test]
    fn fft_non_power_of_two() {
        let samples: Vec<Complex<f64>> = (0..7).map(|k| c(f64::from(k), 0.0)).collect();
        let spectrum = fft(&make_1d(samples.clone()), None, None, FftNorm::Backward).unwrap();
        let recovered: Vec<_> = ifft(&spectrum, None, None, FftNorm::Backward)
            .unwrap()
            .iter()
            .copied()
            .collect();
        assert_pairs_close(&samples, &recovered, 1e-10);
    }

    /// `ifft2(fft2(x))` must reproduce a 2x2 input and keep its shape.
    #[test]
    fn fft2_basic() {
        use ferray_core::dimension::Ix2;
        let grid = Array::from_vec(
            Ix2::new([2, 2]),
            vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)],
        )
        .unwrap();
        let spectrum = fft2(&grid, None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[2, 2]);
        let recovered: Vec<_> = ifft2(&spectrum, None, None, FftNorm::Backward)
            .unwrap()
            .iter()
            .copied()
            .collect();
        let original: Vec<_> = grid.iter().copied().collect();
        assert_pairs_close(&original, &recovered, 1e-10);
    }

    /// `ifftn(fftn(x))` must reproduce a 2x3x4 input.
    #[test]
    fn fftn_roundtrip_3d() {
        use ferray_core::dimension::Ix3;
        let total = 2 * 3 * 4;
        let samples: Vec<Complex<f64>> = (0..total)
            .map(|k| {
                let x = k as f64;
                c(x, -x * 0.5)
            })
            .collect();
        let cube = Array::from_vec(Ix3::new([2, 3, 4]), samples.clone()).unwrap();
        let spectrum = fftn(&cube, None, None, FftNorm::Backward).unwrap();
        let recovered = ifftn(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (o, r) in samples.iter().zip(recovered.iter()) {
            let re_err = (o.re - r.re).abs();
            let im_err = (o.im - r.im).abs();
            assert!(re_err < 1e-9, "re: {} vs {}", o.re, r.re);
            assert!(im_err < 1e-9, "im: {} vs {}", o.im, r.im);
        }
    }

    /// An axis index equal to the number of dimensions must be rejected.
    #[test]
    fn fft_axis_out_of_bounds() {
        let single = make_1d(vec![c(1.0, 0.0)]);
        let outcome = fft(&single, None, Some(1), FftNorm::Backward);
        assert!(outcome.is_err());
    }
}