use num_complex::Complex;
use ferray_core::Array;
use ferray_core::dimension::{Dimension, IxDyn};
use ferray_core::error::{FerrayError, FerrayResult};
use crate::axes::{resolve_axes, resolve_axis};
use crate::float::FftFloat;
use crate::nd::{fft_along_axes, fft_along_axis};
use crate::norm::FftNorm;
/// Flat view over the elements of a complex array, used as FFT input.
///
/// `Borrowed` wraps the array's own storage when it can be exposed as a slice.
/// `Owned` holds a gathered copy of the elements otherwise.
enum ComplexData<'a, T>
where
    T: FftFloat,
    Complex<T>: ferray_core::Element,
{
    /// Zero-copy slice into the source array's storage.
    Borrowed(&'a [Complex<T>]),
    /// Elements copied out of a source array that could not be borrowed as a slice.
    Owned(Vec<Complex<T>>),
}
impl<T> std::ops::Deref for ComplexData<'_, T>
where
    T: FftFloat,
    Complex<T>: ferray_core::Element,
{
    type Target = [Complex<T>];

    /// Exposes the elements as a slice, whichever variant holds them.
    fn deref(&self) -> &Self::Target {
        match *self {
            ComplexData::Borrowed(slice) => slice,
            ComplexData::Owned(ref buffer) => buffer.as_slice(),
        }
    }
}
/// Returns the elements of `a` as a flat slice-like value, avoiding a copy when possible.
///
/// If `a.as_slice()` yields a slice, it is borrowed directly. Otherwise the
/// elements are copied, in `a.iter()` order, into an owned buffer.
fn borrow_complex_flat<T: FftFloat, D: Dimension>(a: &Array<Complex<T>, D>) -> ComplexData<'_, T>
where
    Complex<T>: ferray_core::Element,
{
    match a.as_slice() {
        Some(contiguous) => ComplexData::Borrowed(contiguous),
        None => ComplexData::Owned(a.iter().copied().collect()),
    }
}
/// Pairs each transform axis with the output length to use along it.
///
/// * `input_shape` - shape of the array being transformed.
/// * `axes` - axis indices to transform; callers in this module pass indices
///   already produced by `resolve_axes` (or computed from `ndim`).
/// * `s` - optional explicit output lengths, one per entry of `axes`.
///
/// When `s` is `None`, each axis keeps its current length from `input_shape`.
///
/// # Errors
///
/// Returns an invalid-value error when `s` and `axes` have different lengths,
/// or when any explicitly requested length is zero. A zero-point FFT is
/// undefined; `numpy.fft` rejects it as well.
fn resolve_shapes(
    input_shape: &[usize],
    axes: &[usize],
    s: Option<&[usize]>,
) -> FerrayResult<Vec<Option<usize>>> {
    let sizes = match s {
        Some(sizes) => sizes,
        None => return Ok(axes.iter().map(|&ax| Some(input_shape[ax])).collect()),
    };
    if sizes.len() != axes.len() {
        return Err(FerrayError::invalid_value(format!(
            "shape parameter length {} does not match axes length {}",
            sizes.len(),
            axes.len(),
        )));
    }
    // Reject zero lengths up front instead of handing an empty transform
    // length to the FFT backend.
    if let Some(pos) = sizes.iter().position(|&sz| sz == 0) {
        return Err(FerrayError::invalid_value(format!(
            "invalid number of FFT data points (0) specified for axis {}",
            axes[pos],
        )));
    }
    Ok(sizes.iter().copied().map(Some).collect())
}
/// Computes the one-dimensional discrete Fourier transform of `a` along one axis.
///
/// * `n` - output length along the transformed axis, forwarded to
///   `fft_along_axis`. `None` keeps the input length.
/// * `axis` - axis to transform, resolved by `resolve_axis`. Negative indices
///   count from the end; `None` picks the default chosen by `resolve_axis`.
/// * `norm` - normalization mode applied to the forward transform.
///
/// # Errors
///
/// Propagates errors from axis resolution, the transform, and construction of
/// the output array.
pub fn fft<T: FftFloat, D: Dimension>(
    a: &Array<Complex<T>, D>,
    n: Option<usize>,
    axis: Option<isize>,
    norm: FftNorm,
) -> FerrayResult<Array<Complex<T>, IxDyn>>
where
    Complex<T>: ferray_core::Element,
{
    let input_shape = a.shape().to_vec();
    let axis_index = resolve_axis(input_shape.len(), axis)?;
    let flat = borrow_complex_flat(a);
    let (out_shape, out_data) =
        fft_along_axis::<T>(&flat, &input_shape, axis_index, n, false, norm)?;
    Array::from_vec(IxDyn::new(&out_shape), out_data)
}
/// Computes the one-dimensional inverse discrete Fourier transform of `a`.
///
/// Mirrors [`fft`] except that `fft_along_axis` is run in inverse mode.
///
/// * `n` - output length along the transformed axis (`None` keeps the input length).
/// * `axis` - axis to transform, resolved by `resolve_axis`.
/// * `norm` - normalization mode. Using the same mode for `fft` and `ifft`
///   round-trips the data, as the tests check for all three modes.
///
/// # Errors
///
/// Propagates errors from axis resolution, the transform, and array construction.
pub fn ifft<T: FftFloat, D: Dimension>(
    a: &Array<Complex<T>, D>,
    n: Option<usize>,
    axis: Option<isize>,
    norm: FftNorm,
) -> FerrayResult<Array<Complex<T>, IxDyn>>
where
    Complex<T>: ferray_core::Element,
{
    let input_shape = a.shape().to_vec();
    let target_axis = resolve_axis(input_shape.len(), axis)?;
    let samples = borrow_complex_flat(a);
    let inverse = true;
    let (out_shape, out_data) =
        fft_along_axis::<T>(&samples, &input_shape, target_axis, n, inverse, norm)?;
    Array::from_vec(IxDyn::new(&out_shape), out_data)
}
/// Computes the two-dimensional discrete Fourier transform of `a`.
///
/// * `s` - optional output lengths, one per transformed axis.
/// * `axes` - axes to transform. When `None`, the last two axes are used.
/// * `norm` - normalization mode.
///
/// # Errors
///
/// Returns an invalid-value error if `axes` is `None` and `a` has fewer than
/// two dimensions. Also propagates axis-resolution, shape, and transform errors.
pub fn fft2<T: FftFloat, D: Dimension>(
    a: &Array<Complex<T>, D>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
    norm: FftNorm,
) -> FerrayResult<Array<Complex<T>, IxDyn>>
where
    Complex<T>: ferray_core::Element,
{
    let ndim = a.shape().len();
    let transform_axes = if let Some(requested) = axes {
        resolve_axes(ndim, Some(requested))?
    } else if ndim >= 2 {
        vec![ndim - 2, ndim - 1]
    } else {
        return Err(FerrayError::invalid_value(
            "fft2 requires at least 2 dimensions",
        ));
    };
    fftn_impl::<T, D>(a, s, &transform_axes, false, norm)
}
/// Computes the two-dimensional inverse discrete Fourier transform of `a`.
///
/// * `s` - optional output lengths, one per transformed axis.
/// * `axes` - axes to transform. When `None`, the last two axes are used.
/// * `norm` - normalization mode.
///
/// # Errors
///
/// Returns an invalid-value error if `axes` is `None` and `a` has fewer than
/// two dimensions. Also propagates axis-resolution, shape, and transform errors.
pub fn ifft2<T: FftFloat, D: Dimension>(
    a: &Array<Complex<T>, D>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
    norm: FftNorm,
) -> FerrayResult<Array<Complex<T>, IxDyn>>
where
    Complex<T>: ferray_core::Element,
{
    let ndim = a.shape().len();
    let transform_axes = match axes {
        Some(requested) => resolve_axes(ndim, Some(requested))?,
        None if ndim < 2 => {
            return Err(FerrayError::invalid_value(
                "ifft2 requires at least 2 dimensions",
            ));
        }
        None => vec![ndim - 2, ndim - 1],
    };
    fftn_impl::<T, D>(a, s, &transform_axes, true, norm)
}
pub fn fftn<T: FftFloat, D: Dimension>(
a: &Array<Complex<T>, D>,
s: Option<&[usize]>,
axes: Option<&[isize]>,
norm: FftNorm,
) -> FerrayResult<Array<Complex<T>, IxDyn>>
where
Complex<T>: ferray_core::Element,
{
let ax = resolve_axes(a.shape().len(), axes)?;
fftn_impl::<T, D>(a, s, &ax, false, norm)
}
pub fn ifftn<T: FftFloat, D: Dimension>(
a: &Array<Complex<T>, D>,
s: Option<&[usize]>,
axes: Option<&[isize]>,
norm: FftNorm,
) -> FerrayResult<Array<Complex<T>, IxDyn>>
where
Complex<T>: ferray_core::Element,
{
let ax = resolve_axes(a.shape().len(), axes)?;
fftn_impl::<T, D>(a, s, &ax, true, norm)
}
/// Promotes every element of a real array to a complex value with a zero
/// imaginary part, returning them in `a.iter()` order.
fn real_to_complex_vec<T: FftFloat, D: Dimension>(a: &Array<T, D>) -> Vec<Complex<T>>
where
    Complex<T>: ferray_core::Element,
{
    let zero = <T as num_traits::Zero>::zero();
    a.iter().map(move |&re| Complex::new(re, zero)).collect()
}
/// Computes the one-dimensional FFT of real-valued input.
///
/// The input is promoted to complex values with zero imaginary parts and passed
/// to [`fft`], so `n`, `axis`, `norm`, and the error conditions are those of
/// `fft`. The full complex spectrum is returned, with the same length as the
/// complex transform. This differs from NumPy's `rfft`, which returns only the
/// non-negative frequencies.
///
/// # Errors
///
/// Propagates array-construction errors and any error from [`fft`].
pub fn fft_real<T: FftFloat, D: Dimension>(
    a: &Array<T, D>,
    n: Option<usize>,
    axis: Option<isize>,
    norm: FftNorm,
) -> FerrayResult<Array<Complex<T>, IxDyn>>
where
    Complex<T>: ferray_core::Element,
{
    let dims = a.shape().to_vec();
    let promoted = real_to_complex_vec(a);
    Array::<Complex<T>, IxDyn>::from_vec(IxDyn::new(&dims), promoted)
        .and_then(|complex_input| fft::<T, IxDyn>(&complex_input, n, axis, norm))
}
/// Computes the one-dimensional inverse FFT of `a` and keeps only the real part.
///
/// The imaginary components of the inverse transform are discarded. The output
/// has the same shape as the result of [`ifft`]. This is not a Hermitian
/// inverse like NumPy's `irfft`.
///
/// # Errors
///
/// Propagates any error from [`ifft`] or from building the output array.
pub fn ifft_real<T: FftFloat, D: Dimension>(
    a: &Array<Complex<T>, D>,
    n: Option<usize>,
    axis: Option<isize>,
    norm: FftNorm,
) -> FerrayResult<Array<T, IxDyn>>
where
    Complex<T>: ferray_core::Element,
{
    let inverse = ifft::<T, D>(a, n, axis, norm)?;
    let out_shape = inverse.shape().to_vec();
    let real_parts = inverse.iter().map(|z| z.re).collect::<Vec<T>>();
    Array::from_vec(IxDyn::new(&out_shape), real_parts)
}
/// Computes the two-dimensional FFT of real-valued input.
///
/// The input is promoted to complex values with zero imaginary parts and passed
/// to [`fft2`], so `s`, `axes`, `norm`, and the error conditions are those of
/// `fft2`. The full complex spectrum is returned.
///
/// # Errors
///
/// Propagates array-construction errors and any error from [`fft2`].
pub fn fft_real2<T: FftFloat, D: Dimension>(
    a: &Array<T, D>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
    norm: FftNorm,
) -> FerrayResult<Array<Complex<T>, IxDyn>>
where
    Complex<T>: ferray_core::Element,
{
    let dims = a.shape().to_vec();
    let promoted = real_to_complex_vec(a);
    Array::<Complex<T>, IxDyn>::from_vec(IxDyn::new(&dims), promoted)
        .and_then(|complex_input| fft2::<T, IxDyn>(&complex_input, s, axes, norm))
}
/// Computes the N-dimensional FFT of real-valued input.
///
/// The input is promoted to complex values with zero imaginary parts and passed
/// to [`fftn`], so `s`, `axes`, `norm`, and the error conditions are those of
/// `fftn`. The full complex spectrum is returned.
///
/// # Errors
///
/// Propagates array-construction errors and any error from [`fftn`].
pub fn fft_realn<T: FftFloat, D: Dimension>(
    a: &Array<T, D>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
    norm: FftNorm,
) -> FerrayResult<Array<Complex<T>, IxDyn>>
where
    Complex<T>: ferray_core::Element,
{
    let dims = a.shape().to_vec();
    let promoted = real_to_complex_vec(a);
    let complex_input = Array::<Complex<T>, IxDyn>::from_vec(IxDyn::new(&dims), promoted)?;
    fftn::<T, IxDyn>(&complex_input, s, axes, norm)
}
/// Shared driver for the multi-axis forward and inverse transforms.
///
/// * `axes` - axis indices to transform, each paired with its output length
///   from `resolve_shapes`.
/// * `inverse` - `true` for an inverse transform, `false` for a forward one.
///
/// # Errors
///
/// Propagates shape-resolution errors, transform errors from `fft_along_axes`,
/// and output-array construction errors.
fn fftn_impl<T: FftFloat, D: Dimension>(
    a: &Array<Complex<T>, D>,
    s: Option<&[usize]>,
    axes: &[usize],
    inverse: bool,
    norm: FftNorm,
) -> FerrayResult<Array<Complex<T>, IxDyn>>
where
    Complex<T>: ferray_core::Element,
{
    let input_shape = a.shape().to_vec();
    let lengths = resolve_shapes(&input_shape, axes, s)?;
    let flat = borrow_complex_flat(a);
    let axis_lengths: Vec<(usize, Option<usize>)> = axes
        .iter()
        .zip(lengths)
        .map(|(&axis, len)| (axis, len))
        .collect();
    let (out_shape, out_data) =
        fft_along_axes::<T>(&flat, &input_shape, &axis_lengths, inverse, norm)?;
    Array::from_vec(IxDyn::new(&out_shape), out_data)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::Ix1;

    fn c(re: f64, im: f64) -> Complex<f64> {
        Complex::new(re, im)
    }

    fn make_1d(data: Vec<Complex<f64>>) -> Array<Complex<f64>, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    /// Asserts that two complex values agree component-wise within `tol`.
    fn assert_complex_close(actual: Complex<f64>, expected: Complex<f64>, tol: f64) {
        assert!(
            (actual.re - expected.re).abs() < tol && (actual.im - expected.im).abs() < tol,
            "expected {expected:?}, got {actual:?}"
        );
    }

    #[test]
    fn fft_impulse() {
        let a = make_1d(vec![c(1.0, 0.0), c(0.0, 0.0), c(0.0, 0.0), c(0.0, 0.0)]);
        let result = fft(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[4]);
        for val in result.iter() {
            assert!((val.re - 1.0).abs() < 1e-12);
            assert!(val.im.abs() < 1e-12);
        }
    }

    #[test]
    fn fft_length_one() {
        let a = make_1d(vec![c(7.0, -2.0)]);
        let result = fft(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[1]);
        let v = result.iter().next().unwrap();
        assert!((v.re - 7.0).abs() < 1e-12);
        assert!((v.im + 2.0).abs() < 1e-12);
        let recovered = ifft(&result, None, None, FftNorm::Backward).unwrap();
        let r = recovered.iter().next().unwrap();
        assert!((r.re - 7.0).abs() < 1e-12);
        assert!((r.im + 2.0).abs() < 1e-12);
    }

    #[test]
    fn fft_negative_axis_matches_explicit() {
        use ferray_core::dimension::Ix2;
        let data = vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)];
        let a = Array::<Complex<f64>, Ix2>::from_vec(Ix2::new([2, 2]), data).unwrap();
        let neg = fft(&a, None, Some(-1), FftNorm::Backward).unwrap();
        let pos = fft(&a, None, Some(1), FftNorm::Backward).unwrap();
        assert_eq!(neg.shape(), pos.shape());
        for (n, p) in neg.iter().zip(pos.iter()) {
            assert!((n.re - p.re).abs() < 1e-12);
            assert!((n.im - p.im).abs() < 1e-12);
        }
    }

    #[test]
    fn fftn_negative_axes_matches_explicit() {
        use ferray_core::dimension::Ix2;
        let data: Vec<Complex<f64>> = (0..6).map(|i| c(i as f64, 0.0)).collect();
        let a = Array::<Complex<f64>, Ix2>::from_vec(Ix2::new([2, 3]), data).unwrap();
        let neg = fftn(&a, None, Some(&[-2, -1][..]), FftNorm::Backward).unwrap();
        let pos = fftn(&a, None, Some(&[0, 1][..]), FftNorm::Backward).unwrap();
        // Without a shape check, `zip` would silently accept differing lengths.
        assert_eq!(neg.shape(), pos.shape());
        for (n, p) in neg.iter().zip(pos.iter()) {
            assert_complex_close(*n, *p, 1e-12);
        }
    }

    #[test]
    fn fft_constant() {
        let a = make_1d(vec![c(1.0, 0.0); 4]);
        let result = fft(&a, None, None, FftNorm::Backward).unwrap();
        let vals: Vec<_> = result.iter().copied().collect();
        assert!((vals[0].re - 4.0).abs() < 1e-12);
        for v in &vals[1..] {
            assert!(v.re.abs() < 1e-12);
            assert!(v.im.abs() < 1e-12);
        }
    }

    #[test]
    fn fft_ifft_roundtrip() {
        let data = vec![
            c(1.0, 2.0),
            c(-1.0, 0.5),
            c(3.0, -1.0),
            c(0.0, 0.0),
            c(-2.5, 1.5),
            c(0.7, -0.3),
            c(1.2, 0.8),
            c(-0.4, 2.1),
        ];
        let a = make_1d(data.clone());
        let spectrum = fft(&a, None, None, FftNorm::Backward).unwrap();
        let recovered = ifft(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (orig, rec) in data.iter().zip(recovered.iter()) {
            assert!(
                (orig.re - rec.re).abs() < 1e-10,
                "re mismatch: {} vs {}",
                orig.re,
                rec.re
            );
            assert!(
                (orig.im - rec.im).abs() < 1e-10,
                "im mismatch: {} vs {}",
                orig.im,
                rec.im
            );
        }
    }

    #[test]
    fn fft_with_n_padding() {
        let a = make_1d(vec![c(1.0, 0.0), c(1.0, 0.0)]);
        let result = fft(&a, Some(4), None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[4]);
        let vals: Vec<_> = result.iter().copied().collect();
        assert!((vals[0].re - 2.0).abs() < 1e-12);
    }

    #[test]
    fn fft_with_n_truncation() {
        let a = make_1d(vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)]);
        let result = fft(&a, Some(2), None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[2]);
        let vals: Vec<_> = result.iter().copied().collect();
        assert!((vals[0].re - 3.0).abs() < 1e-12);
        assert!((vals[1].re - (-1.0)).abs() < 1e-12);
    }

    #[test]
    fn fft_non_power_of_two() {
        let n = 7;
        let data: Vec<Complex<f64>> = (0..n).map(|i| c(i as f64, 0.0)).collect();
        let a = make_1d(data.clone());
        let spectrum = fft(&a, None, None, FftNorm::Backward).unwrap();
        let recovered = ifft(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (orig, rec) in data.iter().zip(recovered.iter()) {
            assert!((orig.re - rec.re).abs() < 1e-10);
            assert!((orig.im - rec.im).abs() < 1e-10);
        }
    }

    #[test]
    fn fft2_basic() {
        use ferray_core::dimension::Ix2;
        let data = vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)];
        let a = Array::from_vec(Ix2::new([2, 2]), data).unwrap();
        let result = fft2(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        let recovered = ifft2(&result, None, None, FftNorm::Backward).unwrap();
        let orig: Vec<_> = a.iter().copied().collect();
        for (o, r) in orig.iter().zip(recovered.iter()) {
            assert!((o.re - r.re).abs() < 1e-10);
            assert!((o.im - r.im).abs() < 1e-10);
        }
    }

    #[test]
    fn fft2_rejects_1d_input_without_axes() {
        let a = make_1d(vec![c(1.0, 0.0), c(2.0, 0.0)]);
        assert!(fft2(&a, None, None, FftNorm::Backward).is_err());
        assert!(ifft2(&a, None, None, FftNorm::Backward).is_err());
    }

    #[test]
    fn fftn_rejects_shape_axes_length_mismatch() {
        use ferray_core::dimension::Ix2;
        let a = Array::from_vec(Ix2::new([2, 2]), vec![c(1.0, 0.0); 4]).unwrap();
        assert!(fftn(&a, Some(&[4][..]), Some(&[0, 1][..]), FftNorm::Backward).is_err());
    }

    #[test]
    fn fftn_roundtrip_3d() {
        use ferray_core::dimension::Ix3;
        let n = 2 * 3 * 4;
        let data: Vec<Complex<f64>> = (0..n).map(|i| c(i as f64, -(i as f64) * 0.5)).collect();
        let a = Array::from_vec(Ix3::new([2, 3, 4]), data.clone()).unwrap();
        let spectrum = fftn(&a, None, None, FftNorm::Backward).unwrap();
        let recovered = ifftn(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (o, r) in data.iter().zip(recovered.iter()) {
            assert!((o.re - r.re).abs() < 1e-9, "re: {} vs {}", o.re, r.re);
            assert!((o.im - r.im).abs() < 1e-9, "im: {} vs {}", o.im, r.im);
        }
    }

    #[test]
    fn fft_axis_out_of_bounds() {
        let a = make_1d(vec![c(1.0, 0.0)]);
        assert!(fft(&a, None, Some(1), FftNorm::Backward).is_err());
    }

    #[test]
    fn fft2_with_shape_padding() {
        use ferray_core::dimension::Ix2;
        let data = vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)];
        let a = Array::from_vec(Ix2::new([2, 2]), data).unwrap();
        let result = fft2(&a, Some(&[4, 4]), None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[4, 4]);
    }

    #[test]
    fn fft2_with_shape_truncation() {
        use ferray_core::dimension::Ix2;
        let data: Vec<Complex<f64>> = (0..16).map(|i| c(i as f64, 0.0)).collect();
        let a = Array::from_vec(Ix2::new([4, 4]), data).unwrap();
        let result = fft2(&a, Some(&[2, 2]), None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
    }

    #[test]
    fn fftn_with_shape_roundtrip() {
        use ferray_core::dimension::Ix2;
        let data: Vec<Complex<f64>> = (0..12).map(|i| c(i as f64, 0.0)).collect();
        let a = Array::from_vec(Ix2::new([3, 4]), data).unwrap();
        let spectrum = fftn(&a, Some(&[4, 8]), None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[4, 8]);
        let recovered = ifftn(&spectrum, Some(&[4, 8]), None, FftNorm::Backward).unwrap();
        assert_eq!(recovered.shape(), &[4, 8]);
        // Collect once: calling `iter().nth(idx)` per element re-walks the array
        // from the start each time.
        let flat: Vec<Complex<f64>> = recovered.iter().copied().collect();
        for i in 0..3 {
            for j in 0..4 {
                let idx = i * 8 + j;
                let orig_val = (i * 4 + j) as f64;
                assert!(
                    (flat[idx].re - orig_val).abs() < 1e-9,
                    "mismatch at ({i},{j})"
                );
            }
        }
    }

    #[test]
    fn fft_ifft_ortho_roundtrip() {
        let data = vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)];
        let a = make_1d(data.clone());
        let spectrum = fft(&a, None, None, FftNorm::Ortho).unwrap();
        let recovered = ifft(&spectrum, None, None, FftNorm::Ortho).unwrap();
        for (orig, rec) in data.iter().zip(recovered.iter()) {
            assert!((orig.re - rec.re).abs() < 1e-10);
            assert!((orig.im - rec.im).abs() < 1e-10);
        }
    }

    #[test]
    fn fft_ifft_forward_roundtrip() {
        let data = vec![c(1.0, 2.0), c(-1.0, 0.5), c(3.0, -1.0), c(0.0, 0.0)];
        let a = make_1d(data.clone());
        let spectrum = fft(&a, None, None, FftNorm::Forward).unwrap();
        let recovered = ifft(&spectrum, None, None, FftNorm::Forward).unwrap();
        for (orig, rec) in data.iter().zip(recovered.iter()) {
            assert!((orig.re - rec.re).abs() < 1e-10);
            assert!((orig.im - rec.im).abs() < 1e-10);
        }
    }

    #[test]
    fn fft_ortho_energy_preservation() {
        let data = vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)];
        let a = make_1d(data.clone());
        let spectrum = fft(&a, None, None, FftNorm::Ortho).unwrap();
        let energy_time: f64 = data.iter().map(|x| x.re * x.re + x.im * x.im).sum();
        let energy_freq: f64 = spectrum.iter().map(|x| x.re * x.re + x.im * x.im).sum();
        assert!(
            (energy_time - energy_freq).abs() < 1e-10,
            "Parseval: time={energy_time}, freq={energy_freq}"
        );
    }

    #[test]
    fn fft_forward_scaling() {
        let a = make_1d(vec![c(1.0, 0.0); 4]);
        let result = fft(&a, None, None, FftNorm::Forward).unwrap();
        let vals: Vec<_> = result.iter().copied().collect();
        assert!((vals[0].re - 1.0).abs() < 1e-12);
        for v in &vals[1..] {
            assert!(v.re.abs() < 1e-12);
            assert!(v.im.abs() < 1e-12);
        }
    }

    #[test]
    fn fft_ifft_f32_roundtrip() {
        let data: Vec<Complex<f32>> = (0..16)
            .map(|i| Complex::new(i as f32 * 0.25, (i as f32).sin()))
            .collect();
        let a = Array::from_vec(Ix1::new([16]), data.clone()).unwrap();
        let spectrum = fft::<f32, Ix1>(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[16]);
        let recovered = ifft::<f32, IxDyn>(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (orig, rec) in data.iter().zip(recovered.iter()) {
            assert!(
                (orig.re - rec.re).abs() < 1e-4,
                "f32 re mismatch: {} vs {}",
                orig.re,
                rec.re
            );
            assert!(
                (orig.im - rec.im).abs() < 1e-4,
                "f32 im mismatch: {} vs {}",
                orig.im,
                rec.im
            );
        }
    }

    #[test]
    fn fft_f32_impulse() {
        let data = vec![
            Complex::<f32>::new(1.0, 0.0),
            Complex::<f32>::new(0.0, 0.0),
            Complex::<f32>::new(0.0, 0.0),
            Complex::<f32>::new(0.0, 0.0),
        ];
        let a = Array::from_vec(Ix1::new([4]), data).unwrap();
        let result = fft::<f32, Ix1>(&a, None, None, FftNorm::Backward).unwrap();
        for val in result.iter() {
            assert!((val.re - 1.0).abs() < 1e-6);
            assert!(val.im.abs() < 1e-6);
        }
    }

    #[test]
    fn fft2_f32_roundtrip() {
        use ferray_core::dimension::Ix2;
        let data: Vec<Complex<f32>> = (0..16)
            .map(|i| Complex::new(i as f32, -(i as f32) * 0.25))
            .collect();
        let a = Array::from_vec(Ix2::new([4, 4]), data.clone()).unwrap();
        let spectrum = fft2::<f32, Ix2>(&a, None, None, FftNorm::Backward).unwrap();
        let recovered = ifft2::<f32, IxDyn>(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (o, r) in data.iter().zip(recovered.iter()) {
            assert!((o.re - r.re).abs() < 1e-4);
            assert!((o.im - r.im).abs() < 1e-4);
        }
    }

    #[test]
    fn fft_real_ifft_real_roundtrip_f64() {
        let original = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([8]), original.clone()).unwrap();
        let spectrum = fft_real::<f64, Ix1>(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[8]);
        let recovered = ifft_real::<f64, IxDyn>(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (o, r) in original.iter().zip(recovered.iter()) {
            assert!((o - r).abs() < 1e-10, "mismatch: {} vs {}", o, r);
        }
    }

    #[test]
    fn fft_real_ifft_real_roundtrip_f32() {
        let original: Vec<f32> = (0..16).map(|i| i as f32 * 0.5 - 2.0).collect();
        let a = Array::<f32, Ix1>::from_vec(Ix1::new([16]), original.clone()).unwrap();
        let spectrum = fft_real::<f32, Ix1>(&a, None, None, FftNorm::Backward).unwrap();
        let recovered = ifft_real::<f32, IxDyn>(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (o, r) in original.iter().zip(recovered.iter()) {
            assert!((o - r).abs() < 1e-4, "f32 mismatch: {} vs {}", o, r);
        }
    }

    #[test]
    fn fft_real_dc_component() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 1.0, 1.0, 1.0]).unwrap();
        let spectrum = fft_real::<f64, Ix1>(&a, None, None, FftNorm::Backward).unwrap();
        let vals: Vec<_> = spectrum.iter().copied().collect();
        assert!((vals[0].re - 4.0).abs() < 1e-12);
        assert!(vals[0].im.abs() < 1e-12);
        for v in &vals[1..] {
            assert!(v.re.abs() < 1e-12);
            assert!(v.im.abs() < 1e-12);
        }
    }

    #[test]
    fn fft_real2_roundtrip() {
        use ferray_core::dimension::Ix2;
        let data: Vec<f64> = (0..12).map(|i| i as f64 * 0.3).collect();
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([3, 4]), data.clone()).unwrap();
        let spectrum = fft_real2::<f64, Ix2>(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[3, 4]);
        let recovered = ifft2::<f64, IxDyn>(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (o, r) in data.iter().zip(recovered.iter()) {
            assert!((o - r.re).abs() < 1e-10);
            assert!(r.im.abs() < 1e-10);
        }
    }

    #[test]
    fn fft_realn_3d_roundtrip() {
        use ferray_core::dimension::Ix3;
        let data: Vec<f64> = (0..24).map(|i| (i as f64).sin()).collect();
        let a = Array::<f64, Ix3>::from_vec(Ix3::new([2, 3, 4]), data.clone()).unwrap();
        let spectrum = fft_realn::<f64, Ix3>(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[2, 3, 4]);
        let recovered = ifftn::<f64, IxDyn>(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (o, r) in data.iter().zip(recovered.iter()) {
            assert!((o - r.re).abs() < 1e-10);
            assert!(r.im.abs() < 1e-10);
        }
    }
}