Skip to main content

ferray_fft/
complex.rs

1// ferray-fft: Complex FFTs — fft, ifft, fft2, ifft2, fftn, ifftn (REQ-1..REQ-4)
2//
3// Generic over the scalar precision via [`FftFloat`] — works for both
4// `Complex<f64>` and `Complex<f32>` arrays. See issue #426.
5
6use num_complex::Complex;
7
8use ferray_core::Array;
9use ferray_core::dimension::{Dimension, IxDyn};
10use ferray_core::error::{FerrayError, FerrayResult};
11
12use crate::axes::{resolve_axes, resolve_axis};
13use crate::float::FftFloat;
14use crate::nd::{fft_along_axes, fft_along_axis};
15use crate::norm::FftNorm;
16
17// ---------------------------------------------------------------------------
18// Helpers to convert input to flat complex data
19// ---------------------------------------------------------------------------
20
21/// Borrowed or owned contiguous complex data. Borrows when the input
22/// array is C-contiguous; otherwise materializes into an owned `Vec`.
23enum ComplexData<'a, T: FftFloat>
24where
25    Complex<T>: ferray_core::Element,
26{
27    Borrowed(&'a [Complex<T>]),
28    Owned(Vec<Complex<T>>),
29}
30
31impl<T: FftFloat> std::ops::Deref for ComplexData<'_, T>
32where
33    Complex<T>: ferray_core::Element,
34{
35    type Target = [Complex<T>];
36    fn deref(&self) -> &[Complex<T>] {
37        match self {
38            ComplexData::Borrowed(s) => s,
39            ComplexData::Owned(v) => v,
40        }
41    }
42}
43
44fn borrow_complex_flat<T: FftFloat, D: Dimension>(a: &Array<Complex<T>, D>) -> ComplexData<'_, T>
45where
46    Complex<T>: ferray_core::Element,
47{
48    if let Some(s) = a.as_slice() {
49        ComplexData::Borrowed(s)
50    } else {
51        ComplexData::Owned(a.iter().copied().collect())
52    }
53}
54
55/// Resolve shapes for multi-dimensional FFT. If None, use the input shape.
56fn resolve_shapes(
57    input_shape: &[usize],
58    axes: &[usize],
59    s: Option<&[usize]>,
60) -> FerrayResult<Vec<Option<usize>>> {
61    match s {
62        Some(sizes) => {
63            if sizes.len() != axes.len() {
64                return Err(FerrayError::invalid_value(format!(
65                    "shape parameter length {} does not match axes length {}",
66                    sizes.len(),
67                    axes.len(),
68                )));
69            }
70            Ok(sizes.iter().map(|&sz| Some(sz)).collect())
71        }
72        None => Ok(axes.iter().map(|&ax| Some(input_shape[ax])).collect()),
73    }
74}
75
76// ---------------------------------------------------------------------------
77// 1-D FFT (REQ-1)
78// ---------------------------------------------------------------------------
79
80/// Compute the one-dimensional discrete Fourier Transform.
81///
82/// Analogous to `numpy.fft.fft`. The input must be a complex array.
83///
84/// # Parameters
85/// - `a`: Input complex array of any dimensionality.
86/// - `n`: Length of the transformed axis. If `None`, uses the length of
87///   the input along `axis`. If shorter, the input is truncated. If longer,
88///   it is zero-padded.
89/// - `axis`: Axis along which to compute the FFT. Defaults to the last axis.
90/// - `norm`: Normalization mode. Defaults to `FftNorm::Backward`.
91///
92/// # Errors
93/// Returns an error if `axis` is out of bounds or `n` is 0.
94pub fn fft<T: FftFloat, D: Dimension>(
95    a: &Array<Complex<T>, D>,
96    n: Option<usize>,
97    axis: Option<isize>,
98    norm: FftNorm,
99) -> FerrayResult<Array<Complex<T>, IxDyn>>
100where
101    Complex<T>: ferray_core::Element,
102{
103    let shape = a.shape().to_vec();
104    let ndim = shape.len();
105    let ax = resolve_axis(ndim, axis)?;
106    let data = borrow_complex_flat(a);
107
108    let (new_shape, result) = fft_along_axis::<T>(&data, &shape, ax, n, false, norm)?;
109
110    Array::from_vec(IxDyn::new(&new_shape), result)
111}
112
113/// Compute the one-dimensional inverse discrete Fourier Transform.
114///
115/// Analogous to `numpy.fft.ifft`.
116///
117/// # Parameters
118/// - `a`: Input complex array.
119/// - `n`: Length of the transformed axis. Defaults to the input length.
120/// - `axis`: Axis along which to compute. Defaults to the last axis.
121/// - `norm`: Normalization mode. Defaults to `FftNorm::Backward` (divides by `n`).
122///
123/// # Errors
124/// Returns an error if `axis` is out of bounds or `n` is 0.
125pub fn ifft<T: FftFloat, D: Dimension>(
126    a: &Array<Complex<T>, D>,
127    n: Option<usize>,
128    axis: Option<isize>,
129    norm: FftNorm,
130) -> FerrayResult<Array<Complex<T>, IxDyn>>
131where
132    Complex<T>: ferray_core::Element,
133{
134    let shape = a.shape().to_vec();
135    let ndim = shape.len();
136    let ax = resolve_axis(ndim, axis)?;
137    let data = borrow_complex_flat(a);
138
139    let (new_shape, result) = fft_along_axis::<T>(&data, &shape, ax, n, true, norm)?;
140
141    Array::from_vec(IxDyn::new(&new_shape), result)
142}
143
144// ---------------------------------------------------------------------------
145// 2-D FFT (REQ-3)
146// ---------------------------------------------------------------------------
147
148/// Compute the 2-dimensional discrete Fourier Transform.
149///
150/// Analogous to `numpy.fft.fft2`. Equivalent to calling `fftn` with
151/// `axes = [-2, -1]` (the last two axes).
152///
153/// # Parameters
154/// - `a`: Input complex array (must have at least 2 dimensions).
155/// - `s`: Shape `(n_rows, n_cols)` of the output along the transform axes.
156///   If `None`, uses the input shape.
157/// - `axes`: Axes over which to compute the FFT. Defaults to the last 2 axes.
158/// - `norm`: Normalization mode.
159///
160/// # Errors
161/// Returns an error if the array has fewer than 2 dimensions, axes are
162/// out of bounds, or shape parameters are invalid.
163pub fn fft2<T: FftFloat, D: Dimension>(
164    a: &Array<Complex<T>, D>,
165    s: Option<&[usize]>,
166    axes: Option<&[isize]>,
167    norm: FftNorm,
168) -> FerrayResult<Array<Complex<T>, IxDyn>>
169where
170    Complex<T>: ferray_core::Element,
171{
172    let ndim = a.shape().len();
173    let axes = match axes {
174        Some(ax) => resolve_axes(ndim, Some(ax))?,
175        None => {
176            if ndim < 2 {
177                return Err(FerrayError::invalid_value(
178                    "fft2 requires at least 2 dimensions",
179                ));
180            }
181            vec![ndim - 2, ndim - 1]
182        }
183    };
184    fftn_impl::<T, D>(a, s, &axes, false, norm)
185}
186
187/// Compute the 2-dimensional inverse discrete Fourier Transform.
188///
189/// Analogous to `numpy.fft.ifft2`.
190///
191/// # Parameters
192/// - `a`: Input complex array (must have at least 2 dimensions).
193/// - `s`: Output shape along the transform axes. If `None`, uses input shape.
194/// - `axes`: Axes over which to compute. Defaults to the last 2 axes.
195/// - `norm`: Normalization mode.
196///
197/// # Errors
198/// Returns an error if the array has fewer than 2 dimensions, axes are
199/// out of bounds, or shape parameters are invalid.
200pub fn ifft2<T: FftFloat, D: Dimension>(
201    a: &Array<Complex<T>, D>,
202    s: Option<&[usize]>,
203    axes: Option<&[isize]>,
204    norm: FftNorm,
205) -> FerrayResult<Array<Complex<T>, IxDyn>>
206where
207    Complex<T>: ferray_core::Element,
208{
209    let ndim = a.shape().len();
210    let axes = match axes {
211        Some(ax) => resolve_axes(ndim, Some(ax))?,
212        None => {
213            if ndim < 2 {
214                return Err(FerrayError::invalid_value(
215                    "ifft2 requires at least 2 dimensions",
216                ));
217            }
218            vec![ndim - 2, ndim - 1]
219        }
220    };
221    fftn_impl::<T, D>(a, s, &axes, true, norm)
222}
223
224// ---------------------------------------------------------------------------
225// N-D FFT (REQ-4)
226// ---------------------------------------------------------------------------
227
228/// Compute the N-dimensional discrete Fourier Transform.
229///
230/// Analogous to `numpy.fft.fftn`. Transforms along each of the specified
231/// axes in sequence.
232///
233/// # Parameters
234/// - `a`: Input complex array.
235/// - `s`: Shape of the output along each transform axis. If `None`,
236///   uses the input shape.
237/// - `axes`: Axes over which to compute. If `None`, uses all axes.
238/// - `norm`: Normalization mode.
239///
240/// # Errors
241/// Returns an error if axes are out of bounds or shape parameters
242/// are inconsistent.
243pub fn fftn<T: FftFloat, D: Dimension>(
244    a: &Array<Complex<T>, D>,
245    s: Option<&[usize]>,
246    axes: Option<&[isize]>,
247    norm: FftNorm,
248) -> FerrayResult<Array<Complex<T>, IxDyn>>
249where
250    Complex<T>: ferray_core::Element,
251{
252    let ax = resolve_axes(a.shape().len(), axes)?;
253    fftn_impl::<T, D>(a, s, &ax, false, norm)
254}
255
256/// Compute the N-dimensional inverse discrete Fourier Transform.
257///
258/// Analogous to `numpy.fft.ifftn`.
259///
260/// # Parameters
261/// - `a`: Input complex array.
262/// - `s`: Shape of the output along each transform axis. If `None`,
263///   uses the input shape.
264/// - `axes`: Axes over which to compute. If `None`, uses all axes.
265/// - `norm`: Normalization mode.
266///
267/// # Errors
268/// Returns an error if axes are out of bounds or shape parameters
269/// are inconsistent.
270pub fn ifftn<T: FftFloat, D: Dimension>(
271    a: &Array<Complex<T>, D>,
272    s: Option<&[usize]>,
273    axes: Option<&[isize]>,
274    norm: FftNorm,
275) -> FerrayResult<Array<Complex<T>, IxDyn>>
276where
277    Complex<T>: ferray_core::Element,
278{
279    let ax = resolve_axes(a.shape().len(), axes)?;
280    fftn_impl::<T, D>(a, s, &ax, true, norm)
281}
282
283// ---------------------------------------------------------------------------
284// Real-input convenience wrappers (issue #427)
285//
286// NumPy's `np.fft.fft(real_array)` accepts real arrays directly and
287// promotes them to complex. ferray-fft's native `fft` requires
288// `Array<Complex<T>, D>`, so these wrappers bridge real → complex by
289// wrapping each element in `Complex::new(v, 0)` before calling the
290// generic complex path.
291//
292// For the half-spectrum (Hermitian) form of a real transform use
293// `rfft` / `rfftn` from `real.rs` instead — those return n/2+1 complex
294// values and are ~2× faster for the real-input case.
295// ---------------------------------------------------------------------------
296
297/// Promote a real array to a `Vec<Complex<T>>` with zero imaginary parts.
298fn real_to_complex_vec<T: FftFloat, D: Dimension>(a: &Array<T, D>) -> Vec<Complex<T>>
299where
300    Complex<T>: ferray_core::Element,
301{
302    a.iter()
303        .map(|&v| Complex::new(v, <T as num_traits::Zero>::zero()))
304        .collect()
305}
306
307/// 1-D FFT of a real-valued array. Equivalent to `np.fft.fft(real_array)`.
308///
309/// Auto-promotes the input to complex and delegates to [`fft`]. The
310/// returned spectrum is the full-length complex FFT, not the
311/// Hermitian-folded half-spectrum — use [`crate::rfft`] for that.
312///
313/// # Errors
314/// Forwards any error from [`fft`].
315pub fn fft_real<T: FftFloat, D: Dimension>(
316    a: &Array<T, D>,
317    n: Option<usize>,
318    axis: Option<isize>,
319    norm: FftNorm,
320) -> FerrayResult<Array<Complex<T>, IxDyn>>
321where
322    Complex<T>: ferray_core::Element,
323{
324    let shape = a.shape().to_vec();
325    let complex_data = real_to_complex_vec(a);
326    let complex_arr = Array::<Complex<T>, IxDyn>::from_vec(IxDyn::new(&shape), complex_data)?;
327    fft::<T, IxDyn>(&complex_arr, n, axis, norm)
328}
329
330/// 1-D inverse FFT returning only the real part. Equivalent to
331/// `np.fft.ifft(arr).real` for the common case where the caller knows
332/// the result is real-valued (e.g. the inverse of a real FFT done via
333/// the full complex path).
334///
335/// For the Hermitian-folded inverse use [`crate::irfft`] instead.
336///
337/// # Errors
338/// Forwards any error from [`ifft`].
339pub fn ifft_real<T: FftFloat, D: Dimension>(
340    a: &Array<Complex<T>, D>,
341    n: Option<usize>,
342    axis: Option<isize>,
343    norm: FftNorm,
344) -> FerrayResult<Array<T, IxDyn>>
345where
346    Complex<T>: ferray_core::Element,
347{
348    let spectrum = ifft::<T, D>(a, n, axis, norm)?;
349    let shape = spectrum.shape().to_vec();
350    let real_data: Vec<T> = spectrum.iter().map(|c| c.re).collect();
351    Array::from_vec(IxDyn::new(&shape), real_data)
352}
353
354/// 2-D FFT of a real-valued array.
355///
356/// Auto-promotes the input to complex and delegates to [`fft2`]. For the
357/// Hermitian-folded form use [`crate::rfft2`].
358pub fn fft_real2<T: FftFloat, D: Dimension>(
359    a: &Array<T, D>,
360    s: Option<&[usize]>,
361    axes: Option<&[isize]>,
362    norm: FftNorm,
363) -> FerrayResult<Array<Complex<T>, IxDyn>>
364where
365    Complex<T>: ferray_core::Element,
366{
367    let shape = a.shape().to_vec();
368    let complex_data = real_to_complex_vec(a);
369    let complex_arr = Array::<Complex<T>, IxDyn>::from_vec(IxDyn::new(&shape), complex_data)?;
370    fft2::<T, IxDyn>(&complex_arr, s, axes, norm)
371}
372
373/// N-D FFT of a real-valued array.
374///
375/// Auto-promotes the input to complex and delegates to [`fftn`]. For the
376/// Hermitian-folded form use [`crate::rfftn`].
377pub fn fft_realn<T: FftFloat, D: Dimension>(
378    a: &Array<T, D>,
379    s: Option<&[usize]>,
380    axes: Option<&[isize]>,
381    norm: FftNorm,
382) -> FerrayResult<Array<Complex<T>, IxDyn>>
383where
384    Complex<T>: ferray_core::Element,
385{
386    let shape = a.shape().to_vec();
387    let complex_data = real_to_complex_vec(a);
388    let complex_arr = Array::<Complex<T>, IxDyn>::from_vec(IxDyn::new(&shape), complex_data)?;
389    fftn::<T, IxDyn>(&complex_arr, s, axes, norm)
390}
391
392// ---------------------------------------------------------------------------
393// Internal N-D implementation
394// ---------------------------------------------------------------------------
395
396fn fftn_impl<T: FftFloat, D: Dimension>(
397    a: &Array<Complex<T>, D>,
398    s: Option<&[usize]>,
399    axes: &[usize],
400    inverse: bool,
401    norm: FftNorm,
402) -> FerrayResult<Array<Complex<T>, IxDyn>>
403where
404    Complex<T>: ferray_core::Element,
405{
406    let shape = a.shape().to_vec();
407    let sizes = resolve_shapes(&shape, axes, s)?;
408    let data = borrow_complex_flat(a);
409
410    let axes_and_sizes: Vec<(usize, Option<usize>)> = axes.iter().copied().zip(sizes).collect();
411
412    let (new_shape, result) = fft_along_axes::<T>(&data, &shape, &axes_and_sizes, inverse, norm)?;
413
414    Array::from_vec(IxDyn::new(&new_shape), result)
415}
416
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::Ix1;

    fn c(re: f64, im: f64) -> Complex<f64> {
        Complex::new(re, im)
    }

    fn make_1d(data: Vec<Complex<f64>>) -> Array<Complex<f64>, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    /// Assert that both components of `actual` are within `tol` of
    /// `expected_re + i * expected_im`.
    fn assert_c_close(actual: Complex<f64>, expected_re: f64, expected_im: f64, tol: f64) {
        assert!(
            (actual.re - expected_re).abs() < tol && (actual.im - expected_im).abs() < tol,
            "expected {expected_re}{expected_im:+}i, got {actual}"
        );
    }

    #[test]
    fn fft_impulse() {
        // FFT of [1, 0, 0, 0] = [1, 1, 1, 1]
        let a = make_1d(vec![c(1.0, 0.0), c(0.0, 0.0), c(0.0, 0.0), c(0.0, 0.0)]);
        let result = fft(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[4]);
        for val in result.iter() {
            assert!((val.re - 1.0).abs() < 1e-12);
            assert!(val.im.abs() < 1e-12);
        }
    }

    #[test]
    fn fft_length_one() {
        // FFT of a length-1 array is the identity (#229).
        let a = make_1d(vec![c(7.0, -2.0)]);
        let result = fft(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[1]);
        let v = result.iter().next().unwrap();
        assert!((v.re - 7.0).abs() < 1e-12);
        assert!((v.im + 2.0).abs() < 1e-12);

        // Roundtrip should also be exact for length 1.
        let recovered = ifft(&result, None, None, FftNorm::Backward).unwrap();
        let r = recovered.iter().next().unwrap();
        assert!((r.re - 7.0).abs() < 1e-12);
        assert!((r.im + 2.0).abs() < 1e-12);
    }

    #[test]
    fn fft_negative_axis_matches_explicit() {
        // axis=-1 on a 2-D array should match axis=1 (#434).
        use ferray_core::dimension::Ix2;
        let data = vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)];
        let a = Array::<Complex<f64>, Ix2>::from_vec(Ix2::new([2, 2]), data).unwrap();
        let neg = fft(&a, None, Some(-1), FftNorm::Backward).unwrap();
        let pos = fft(&a, None, Some(1), FftNorm::Backward).unwrap();
        assert_eq!(neg.shape(), pos.shape());
        for (n, p) in neg.iter().zip(pos.iter()) {
            assert!((n.re - p.re).abs() < 1e-12);
            assert!((n.im - p.im).abs() < 1e-12);
        }
    }

    #[test]
    fn fftn_negative_axes_matches_explicit() {
        // axes=[-2, -1] on a 2-D array should match axes=[0, 1] (#434).
        use ferray_core::dimension::Ix2;
        let data: Vec<Complex<f64>> = (0..6).map(|i| c(i as f64, 0.0)).collect();
        let a = Array::<Complex<f64>, Ix2>::from_vec(Ix2::new([2, 3]), data).unwrap();
        let neg = fftn(&a, None, Some(&[-2, -1][..]), FftNorm::Backward).unwrap();
        let pos = fftn(&a, None, Some(&[0, 1][..]), FftNorm::Backward).unwrap();
        assert_eq!(neg.shape(), pos.shape());
        for (n, p) in neg.iter().zip(pos.iter()) {
            assert!((n.re - p.re).abs() < 1e-12);
            assert!((n.im - p.im).abs() < 1e-12);
        }
    }

    #[test]
    fn fft_constant() {
        // FFT of [1, 1, 1, 1] = [4, 0, 0, 0]
        let a = make_1d(vec![c(1.0, 0.0); 4]);
        let result = fft(&a, None, None, FftNorm::Backward).unwrap();
        let vals: Vec<_> = result.iter().copied().collect();
        assert!((vals[0].re - 4.0).abs() < 1e-12);
        for v in &vals[1..] {
            assert!(v.re.abs() < 1e-12);
            assert!(v.im.abs() < 1e-12);
        }
    }

    #[test]
    fn fft_ifft_roundtrip() {
        // AC-1: ifft(fft(a)) recovers `a` for complex f64 (checked here with an
        // absolute tolerance of 1e-10).
        let data = vec![
            c(1.0, 2.0),
            c(-1.0, 0.5),
            c(3.0, -1.0),
            c(0.0, 0.0),
            c(-2.5, 1.5),
            c(0.7, -0.3),
            c(1.2, 0.8),
            c(-0.4, 2.1),
        ];
        let a = make_1d(data.clone());
        let spectrum = fft(&a, None, None, FftNorm::Backward).unwrap();
        let recovered = ifft(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (orig, rec) in data.iter().zip(recovered.iter()) {
            assert!(
                (orig.re - rec.re).abs() < 1e-10,
                "re mismatch: {} vs {}",
                orig.re,
                rec.re
            );
            assert!(
                (orig.im - rec.im).abs() < 1e-10,
                "im mismatch: {} vs {}",
                orig.im,
                rec.im
            );
        }
    }

    #[test]
    fn fft_with_n_padding() {
        // Pad [1, 1] to length 4 -> FFT of [1, 1, 0, 0] = [2, 1-i, 0, 1+i]
        let a = make_1d(vec![c(1.0, 0.0), c(1.0, 0.0)]);
        let result = fft(&a, Some(4), None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[4]);
        let vals: Vec<_> = result.iter().copied().collect();
        let expected = [(2.0, 0.0), (1.0, -1.0), (0.0, 0.0), (1.0, 1.0)];
        for (v, &(re, im)) in vals.iter().zip(expected.iter()) {
            assert_c_close(*v, re, im, 1e-12);
        }
    }

    #[test]
    fn fft_with_n_truncation() {
        // Truncate [1, 2, 3, 4] to length 2 -> FFT of [1, 2] = [3, -1]
        let a = make_1d(vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)]);
        let result = fft(&a, Some(2), None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[2]);
        let vals: Vec<_> = result.iter().copied().collect();
        assert_c_close(vals[0], 3.0, 0.0, 1e-12);
        assert_c_close(vals[1], -1.0, 0.0, 1e-12);
    }

    #[test]
    fn fft_non_power_of_two() {
        // AC-2 partial: test non-power-of-2 length
        let n = 7;
        let data: Vec<Complex<f64>> = (0..n).map(|i| c(i as f64, 0.0)).collect();
        let a = make_1d(data.clone());
        let spectrum = fft(&a, None, None, FftNorm::Backward).unwrap();
        let recovered = ifft(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (orig, rec) in data.iter().zip(recovered.iter()) {
            assert!((orig.re - rec.re).abs() < 1e-10);
            assert!((orig.im - rec.im).abs() < 1e-10);
        }
    }

    #[test]
    fn fft2_basic() {
        use ferray_core::dimension::Ix2;
        let data = vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)];
        let a = Array::from_vec(Ix2::new([2, 2]), data).unwrap();
        let result = fft2(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[2, 2]);

        let recovered = ifft2(&result, None, None, FftNorm::Backward).unwrap();
        let orig: Vec<_> = a.iter().copied().collect();
        for (o, r) in orig.iter().zip(recovered.iter()) {
            assert!((o.re - r.re).abs() < 1e-10);
            assert!((o.im - r.im).abs() < 1e-10);
        }
    }

    #[test]
    fn fft2_rejects_fewer_than_two_dims() {
        // With default axes, fft2/ifft2 need at least two dimensions.
        let a = make_1d(vec![c(1.0, 0.0), c(2.0, 0.0)]);
        assert!(fft2(&a, None, None, FftNorm::Backward).is_err());
        assert!(ifft2(&a, None, None, FftNorm::Backward).is_err());
    }

    #[test]
    fn fftn_roundtrip_3d() {
        use ferray_core::dimension::Ix3;
        let n = 2 * 3 * 4;
        let data: Vec<Complex<f64>> = (0..n).map(|i| c(i as f64, -(i as f64) * 0.5)).collect();
        let a = Array::from_vec(Ix3::new([2, 3, 4]), data.clone()).unwrap();
        let spectrum = fftn(&a, None, None, FftNorm::Backward).unwrap();
        let recovered = ifftn(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (o, r) in data.iter().zip(recovered.iter()) {
            assert!((o.re - r.re).abs() < 1e-9, "re: {} vs {}", o.re, r.re);
            assert!((o.im - r.im).abs() < 1e-9, "im: {} vs {}", o.im, r.im);
        }
    }

    #[test]
    fn fft_axis_out_of_bounds() {
        let a = make_1d(vec![c(1.0, 0.0)]);
        assert!(fft(&a, None, Some(1), FftNorm::Backward).is_err());
    }

    // --- Shape/padding parameter tests ---

    #[test]
    fn fft2_with_shape_padding() {
        use ferray_core::dimension::Ix2;
        // 2x2 input [[1, 2], [3, 4]], padded to 4x4 via the s parameter.
        let data = vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)];
        let a = Array::from_vec(Ix2::new([2, 2]), data).unwrap();
        let result = fft2(&a, Some(&[4, 4]), None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[4, 4]);
        let vals: Vec<_> = result.iter().copied().collect();
        // X[0,0] is the sum of all inputs; X[0,1] = colsum0 + colsum1 * e^{-i*pi/2}
        // = 4 + 6 * (-i).
        assert_c_close(vals[0], 10.0, 0.0, 1e-10);
        assert_c_close(vals[1], 4.0, -6.0, 1e-10);
    }

    #[test]
    fn fft2_with_shape_truncation() {
        use ferray_core::dimension::Ix2;
        // 4x4 input, truncated to its top-left 2x2 block [[0, 1], [4, 5]].
        let data: Vec<Complex<f64>> = (0..16).map(|i| c(i as f64, 0.0)).collect();
        let a = Array::from_vec(Ix2::new([4, 4]), data).unwrap();
        let result = fft2(&a, Some(&[2, 2]), None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        // 2x2 DFT of [[0, 1], [4, 5]] = [[10, -2], [-8, 0]].
        let vals: Vec<_> = result.iter().copied().collect();
        let expected = [10.0, -2.0, -8.0, 0.0];
        for (v, &re) in vals.iter().zip(expected.iter()) {
            assert_c_close(*v, re, 0.0, 1e-10);
        }
    }

    #[test]
    fn fftn_with_shape_roundtrip() {
        use ferray_core::dimension::Ix2;
        // 3x4 input zero-padded to 4x8. ifftn with the same shape must return
        // the padded array: the original values in the top-left 3x4 block and
        // zeros everywhere else.
        let data: Vec<Complex<f64>> = (0..12).map(|i| c(i as f64, 0.0)).collect();
        let a = Array::from_vec(Ix2::new([3, 4]), data).unwrap();
        let spectrum = fftn(&a, Some(&[4, 8]), None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[4, 8]);
        let recovered = ifftn(&spectrum, Some(&[4, 8]), None, FftNorm::Backward).unwrap();
        assert_eq!(recovered.shape(), &[4, 8]);
        // Collect once rather than calling `iter().nth(idx)` per element.
        let vals: Vec<_> = recovered.iter().copied().collect();
        for i in 0..4 {
            for j in 0..8 {
                let expected = if i < 3 && j < 4 {
                    (i * 4 + j) as f64
                } else {
                    0.0
                };
                let got = vals[i * 8 + j];
                assert!(
                    (got.re - expected).abs() < 1e-9 && got.im.abs() < 1e-9,
                    "mismatch at ({i},{j}): expected {expected}, got {got}"
                );
            }
        }
    }

    #[test]
    fn fftn_shape_length_mismatch_is_error() {
        use ferray_core::dimension::Ix2;
        // With axes=None every axis is transformed, so `s` must have one entry
        // per dimension.
        let data = vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)];
        let a = Array::from_vec(Ix2::new([2, 2]), data).unwrap();
        assert!(fftn(&a, Some(&[4][..]), None, FftNorm::Backward).is_err());
        assert!(ifftn(&a, Some(&[4][..]), None, FftNorm::Backward).is_err());
    }

    // --- Normalization mode tests ---

    #[test]
    fn fft_ifft_ortho_roundtrip() {
        // Ortho normalization: both forward and inverse scale by 1/sqrt(n)
        let data = vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)];
        let a = make_1d(data.clone());
        let spectrum = fft(&a, None, None, FftNorm::Ortho).unwrap();
        let recovered = ifft(&spectrum, None, None, FftNorm::Ortho).unwrap();
        for (orig, rec) in data.iter().zip(recovered.iter()) {
            assert!((orig.re - rec.re).abs() < 1e-10);
            assert!((orig.im - rec.im).abs() < 1e-10);
        }
    }

    #[test]
    fn fft_ifft_forward_roundtrip() {
        // Forward normalization: forward scales by 1/n, inverse no scaling
        let data = vec![c(1.0, 2.0), c(-1.0, 0.5), c(3.0, -1.0), c(0.0, 0.0)];
        let a = make_1d(data.clone());
        let spectrum = fft(&a, None, None, FftNorm::Forward).unwrap();
        let recovered = ifft(&spectrum, None, None, FftNorm::Forward).unwrap();
        for (orig, rec) in data.iter().zip(recovered.iter()) {
            assert!((orig.re - rec.re).abs() < 1e-10);
            assert!((orig.im - rec.im).abs() < 1e-10);
        }
    }

    #[test]
    fn fft_ortho_energy_preservation() {
        // Parseval's theorem: ||x||^2 = ||X||^2 for ortho normalization
        let data = vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)];
        let a = make_1d(data.clone());
        let spectrum = fft(&a, None, None, FftNorm::Ortho).unwrap();

        let energy_time: f64 = data.iter().map(|x| x.re * x.re + x.im * x.im).sum();
        let energy_freq: f64 = spectrum.iter().map(|x| x.re * x.re + x.im * x.im).sum();
        assert!(
            (energy_time - energy_freq).abs() < 1e-10,
            "Parseval: time={energy_time}, freq={energy_freq}"
        );
    }

    #[test]
    fn fft_forward_scaling() {
        // Forward normalization: FFT of constant [1,1,1,1] should be [1, 0, 0, 0]
        // (divided by n=4, so DC = 4/4 = 1)
        let a = make_1d(vec![c(1.0, 0.0); 4]);
        let result = fft(&a, None, None, FftNorm::Forward).unwrap();
        let vals: Vec<_> = result.iter().copied().collect();
        assert!((vals[0].re - 1.0).abs() < 1e-12);
        for v in &vals[1..] {
            assert!(v.re.abs() < 1e-12);
            assert!(v.im.abs() < 1e-12);
        }
    }

    // --- f32 generic path (#426) ---

    #[test]
    fn fft_ifft_f32_roundtrip() {
        // AC-426: The FFT functions must work for f32 as well as f64.
        let data: Vec<Complex<f32>> = (0..16)
            .map(|i| Complex::new(i as f32 * 0.25, (i as f32).sin()))
            .collect();
        let a = Array::from_vec(Ix1::new([16]), data.clone()).unwrap();
        let spectrum = fft::<f32, Ix1>(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[16]);
        let recovered = ifft::<f32, IxDyn>(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (orig, rec) in data.iter().zip(recovered.iter()) {
            assert!(
                (orig.re - rec.re).abs() < 1e-4,
                "f32 re mismatch: {} vs {}",
                orig.re,
                rec.re
            );
            assert!(
                (orig.im - rec.im).abs() < 1e-4,
                "f32 im mismatch: {} vs {}",
                orig.im,
                rec.im
            );
        }
    }

    #[test]
    fn fft_f32_impulse() {
        // FFT of [1, 0, 0, 0] in f32 = [1, 1, 1, 1]
        let data = vec![
            Complex::<f32>::new(1.0, 0.0),
            Complex::<f32>::new(0.0, 0.0),
            Complex::<f32>::new(0.0, 0.0),
            Complex::<f32>::new(0.0, 0.0),
        ];
        let a = Array::from_vec(Ix1::new([4]), data).unwrap();
        let result = fft::<f32, Ix1>(&a, None, None, FftNorm::Backward).unwrap();
        for val in result.iter() {
            assert!((val.re - 1.0).abs() < 1e-6);
            assert!(val.im.abs() < 1e-6);
        }
    }

    #[test]
    fn fft2_f32_roundtrip() {
        use ferray_core::dimension::Ix2;
        let data: Vec<Complex<f32>> = (0..16)
            .map(|i| Complex::new(i as f32, -(i as f32) * 0.25))
            .collect();
        let a = Array::from_vec(Ix2::new([4, 4]), data.clone()).unwrap();
        let spectrum = fft2::<f32, Ix2>(&a, None, None, FftNorm::Backward).unwrap();
        let recovered = ifft2::<f32, IxDyn>(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (o, r) in data.iter().zip(recovered.iter()) {
            assert!((o.re - r.re).abs() < 1e-4);
            assert!((o.im - r.im).abs() < 1e-4);
        }
    }

    // --- Real-input convenience wrappers (#427) ---

    #[test]
    fn fft_real_ifft_real_roundtrip_f64() {
        // AC-427: Real array should be transformable without manual promotion to Complex.
        let original = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([8]), original.clone()).unwrap();
        let spectrum = fft_real::<f64, Ix1>(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[8]);
        let recovered = ifft_real::<f64, IxDyn>(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (o, r) in original.iter().zip(recovered.iter()) {
            assert!((o - r).abs() < 1e-10, "mismatch: {} vs {}", o, r);
        }
    }

    #[test]
    fn fft_real_ifft_real_roundtrip_f32() {
        // Same real-input convenience, but on f32 to verify both layers work together.
        let original: Vec<f32> = (0..16).map(|i| i as f32 * 0.5 - 2.0).collect();
        let a = Array::<f32, Ix1>::from_vec(Ix1::new([16]), original.clone()).unwrap();
        let spectrum = fft_real::<f32, Ix1>(&a, None, None, FftNorm::Backward).unwrap();
        let recovered = ifft_real::<f32, IxDyn>(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (o, r) in original.iter().zip(recovered.iter()) {
            assert!((o - r).abs() < 1e-4, "f32 mismatch: {} vs {}", o, r);
        }
    }

    #[test]
    fn fft_real_dc_component() {
        // FFT of real constant [1,1,1,1] should have DC = 4, rest zero.
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 1.0, 1.0, 1.0]).unwrap();
        let spectrum = fft_real::<f64, Ix1>(&a, None, None, FftNorm::Backward).unwrap();
        let vals: Vec<_> = spectrum.iter().copied().collect();
        assert!((vals[0].re - 4.0).abs() < 1e-12);
        assert!(vals[0].im.abs() < 1e-12);
        for v in &vals[1..] {
            assert!(v.re.abs() < 1e-12);
            assert!(v.im.abs() < 1e-12);
        }
    }

    #[test]
    fn fft_real2_roundtrip() {
        use ferray_core::dimension::Ix2;
        let data: Vec<f64> = (0..12).map(|i| i as f64 * 0.3).collect();
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([3, 4]), data.clone()).unwrap();
        let spectrum = fft_real2::<f64, Ix2>(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[3, 4]);
        let recovered = ifft2::<f64, IxDyn>(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (o, r) in data.iter().zip(recovered.iter()) {
            assert!((o - r.re).abs() < 1e-10);
            assert!(r.im.abs() < 1e-10);
        }
    }

    #[test]
    fn fft_realn_3d_roundtrip() {
        use ferray_core::dimension::Ix3;
        let data: Vec<f64> = (0..24).map(|i| (i as f64).sin()).collect();
        let a = Array::<f64, Ix3>::from_vec(Ix3::new([2, 3, 4]), data.clone()).unwrap();
        let spectrum = fft_realn::<f64, Ix3>(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[2, 3, 4]);
        let recovered = ifftn::<f64, IxDyn>(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (o, r) in data.iter().zip(recovered.iter()) {
            assert!((o - r.re).abs() < 1e-10);
            assert!(r.im.abs() < 1e-10);
        }
    }
}