Skip to main content

ferray_fft/
complex.rs

1// ferray-fft: Complex FFTs — fft, ifft, fft2, ifft2, fftn, ifftn (REQ-1..REQ-4)
2
3use num_complex::Complex;
4
5use ferray_core::Array;
6use ferray_core::dimension::{Dimension, IxDyn};
7use ferray_core::error::{FerrayError, FerrayResult};
8
9use crate::nd::{fft_1d_along_axis, fft_along_axes};
10use crate::norm::FftNorm;
11
12// ---------------------------------------------------------------------------
13// Helpers to convert input to Complex<f64> flat data
14// ---------------------------------------------------------------------------
15
16/// Convert an Array<Complex<f64>, D> to flat row-major data.
17fn to_complex_flat<D: Dimension>(a: &Array<Complex<f64>, D>) -> Vec<Complex<f64>> {
18    a.iter().copied().collect()
19}
20
21/// Resolve an axis parameter: if None, use the last axis.
22fn resolve_axis(ndim: usize, axis: Option<usize>) -> FerrayResult<usize> {
23    match axis {
24        Some(ax) => {
25            if ax >= ndim {
26                Err(FerrayError::axis_out_of_bounds(ax, ndim))
27            } else {
28                Ok(ax)
29            }
30        }
31        None => {
32            if ndim == 0 {
33                Err(FerrayError::invalid_value(
34                    "cannot compute FFT on a 0-dimensional array",
35                ))
36            } else {
37                Ok(ndim - 1)
38            }
39        }
40    }
41}
42
43/// Resolve axes for multi-dimensional FFT. If None, use all axes.
44fn resolve_axes(ndim: usize, axes: Option<&[usize]>) -> FerrayResult<Vec<usize>> {
45    match axes {
46        Some(ax) => {
47            for &a in ax {
48                if a >= ndim {
49                    return Err(FerrayError::axis_out_of_bounds(a, ndim));
50                }
51            }
52            Ok(ax.to_vec())
53        }
54        None => Ok((0..ndim).collect()),
55    }
56}
57
58/// Resolve shapes for multi-dimensional FFT. If None, use the input shape.
59fn resolve_shapes(
60    input_shape: &[usize],
61    axes: &[usize],
62    s: Option<&[usize]>,
63) -> FerrayResult<Vec<Option<usize>>> {
64    match s {
65        Some(sizes) => {
66            if sizes.len() != axes.len() {
67                return Err(FerrayError::invalid_value(format!(
68                    "shape parameter length {} does not match axes length {}",
69                    sizes.len(),
70                    axes.len(),
71                )));
72            }
73            Ok(sizes.iter().map(|&sz| Some(sz)).collect())
74        }
75        None => Ok(axes.iter().map(|&ax| Some(input_shape[ax])).collect()),
76    }
77}
78
79// ---------------------------------------------------------------------------
80// 1-D FFT (REQ-1)
81// ---------------------------------------------------------------------------
82
83/// Compute the one-dimensional discrete Fourier Transform.
84///
85/// Analogous to `numpy.fft.fft`. The input must be a complex array.
86///
87/// # Parameters
88/// - `a`: Input complex array of any dimensionality.
89/// - `n`: Length of the transformed axis. If `None`, uses the length of
90///   the input along `axis`. If shorter, the input is truncated. If longer,
91///   it is zero-padded.
92/// - `axis`: Axis along which to compute the FFT. Defaults to the last axis.
93/// - `norm`: Normalization mode. Defaults to `FftNorm::Backward`.
94///
95/// # Errors
96/// Returns an error if `axis` is out of bounds or `n` is 0.
97pub fn fft<D: Dimension>(
98    a: &Array<Complex<f64>, D>,
99    n: Option<usize>,
100    axis: Option<usize>,
101    norm: FftNorm,
102) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
103    let shape = a.shape().to_vec();
104    let ndim = shape.len();
105    let ax = resolve_axis(ndim, axis)?;
106    let data = to_complex_flat(a);
107
108    let (new_shape, result) = fft_1d_along_axis(&data, &shape, ax, n, false, norm)?;
109
110    Array::from_vec(IxDyn::new(&new_shape), result)
111}
112
113/// Compute the one-dimensional inverse discrete Fourier Transform.
114///
115/// Analogous to `numpy.fft.ifft`.
116///
117/// # Parameters
118/// - `a`: Input complex array.
119/// - `n`: Length of the transformed axis. Defaults to the input length.
120/// - `axis`: Axis along which to compute. Defaults to the last axis.
121/// - `norm`: Normalization mode. Defaults to `FftNorm::Backward` (divides by `n`).
122///
123/// # Errors
124/// Returns an error if `axis` is out of bounds or `n` is 0.
125pub fn ifft<D: Dimension>(
126    a: &Array<Complex<f64>, D>,
127    n: Option<usize>,
128    axis: Option<usize>,
129    norm: FftNorm,
130) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
131    let shape = a.shape().to_vec();
132    let ndim = shape.len();
133    let ax = resolve_axis(ndim, axis)?;
134    let data = to_complex_flat(a);
135
136    let (new_shape, result) = fft_1d_along_axis(&data, &shape, ax, n, true, norm)?;
137
138    Array::from_vec(IxDyn::new(&new_shape), result)
139}
140
141// ---------------------------------------------------------------------------
142// 2-D FFT (REQ-3)
143// ---------------------------------------------------------------------------
144
145/// Compute the 2-dimensional discrete Fourier Transform.
146///
147/// Analogous to `numpy.fft.fft2`. Equivalent to calling `fftn` with
148/// `axes = [-2, -1]` (the last two axes).
149///
150/// # Parameters
151/// - `a`: Input complex array (must have at least 2 dimensions).
152/// - `s`: Shape `(n_rows, n_cols)` of the output along the transform axes.
153///   If `None`, uses the input shape.
154/// - `axes`: Axes over which to compute the FFT. Defaults to the last 2 axes.
155/// - `norm`: Normalization mode.
156///
157/// # Errors
158/// Returns an error if the array has fewer than 2 dimensions, axes are
159/// out of bounds, or shape parameters are invalid.
160pub fn fft2<D: Dimension>(
161    a: &Array<Complex<f64>, D>,
162    s: Option<&[usize]>,
163    axes: Option<&[usize]>,
164    norm: FftNorm,
165) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
166    let ndim = a.shape().len();
167    let axes = match axes {
168        Some(ax) => ax.to_vec(),
169        None => {
170            if ndim < 2 {
171                return Err(FerrayError::invalid_value(
172                    "fft2 requires at least 2 dimensions",
173                ));
174            }
175            vec![ndim - 2, ndim - 1]
176        }
177    };
178    fftn_impl(a, s, &axes, false, norm)
179}
180
181/// Compute the 2-dimensional inverse discrete Fourier Transform.
182///
183/// Analogous to `numpy.fft.ifft2`.
184///
185/// # Parameters
186/// - `a`: Input complex array (must have at least 2 dimensions).
187/// - `s`: Output shape along the transform axes. If `None`, uses input shape.
188/// - `axes`: Axes over which to compute. Defaults to the last 2 axes.
189/// - `norm`: Normalization mode.
190///
191/// # Errors
192/// Returns an error if the array has fewer than 2 dimensions, axes are
193/// out of bounds, or shape parameters are invalid.
194pub fn ifft2<D: Dimension>(
195    a: &Array<Complex<f64>, D>,
196    s: Option<&[usize]>,
197    axes: Option<&[usize]>,
198    norm: FftNorm,
199) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
200    let ndim = a.shape().len();
201    let axes = match axes {
202        Some(ax) => ax.to_vec(),
203        None => {
204            if ndim < 2 {
205                return Err(FerrayError::invalid_value(
206                    "ifft2 requires at least 2 dimensions",
207                ));
208            }
209            vec![ndim - 2, ndim - 1]
210        }
211    };
212    fftn_impl(a, s, &axes, true, norm)
213}
214
215// ---------------------------------------------------------------------------
216// N-D FFT (REQ-4)
217// ---------------------------------------------------------------------------
218
219/// Compute the N-dimensional discrete Fourier Transform.
220///
221/// Analogous to `numpy.fft.fftn`. Transforms along each of the specified
222/// axes in sequence.
223///
224/// # Parameters
225/// - `a`: Input complex array.
226/// - `s`: Shape of the output along each transform axis. If `None`,
227///   uses the input shape.
228/// - `axes`: Axes over which to compute. If `None`, uses all axes.
229/// - `norm`: Normalization mode.
230///
231/// # Errors
232/// Returns an error if axes are out of bounds or shape parameters
233/// are inconsistent.
234pub fn fftn<D: Dimension>(
235    a: &Array<Complex<f64>, D>,
236    s: Option<&[usize]>,
237    axes: Option<&[usize]>,
238    norm: FftNorm,
239) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
240    let ax = resolve_axes(a.shape().len(), axes)?;
241    fftn_impl(a, s, &ax, false, norm)
242}
243
244/// Compute the N-dimensional inverse discrete Fourier Transform.
245///
246/// Analogous to `numpy.fft.ifftn`.
247///
248/// # Parameters
249/// - `a`: Input complex array.
250/// - `s`: Shape of the output along each transform axis. If `None`,
251///   uses the input shape.
252/// - `axes`: Axes over which to compute. If `None`, uses all axes.
253/// - `norm`: Normalization mode.
254///
255/// # Errors
256/// Returns an error if axes are out of bounds or shape parameters
257/// are inconsistent.
258pub fn ifftn<D: Dimension>(
259    a: &Array<Complex<f64>, D>,
260    s: Option<&[usize]>,
261    axes: Option<&[usize]>,
262    norm: FftNorm,
263) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
264    let ax = resolve_axes(a.shape().len(), axes)?;
265    fftn_impl(a, s, &ax, true, norm)
266}
267
268// ---------------------------------------------------------------------------
269// Internal N-D implementation
270// ---------------------------------------------------------------------------
271
272fn fftn_impl<D: Dimension>(
273    a: &Array<Complex<f64>, D>,
274    s: Option<&[usize]>,
275    axes: &[usize],
276    inverse: bool,
277    norm: FftNorm,
278) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
279    let shape = a.shape().to_vec();
280    let sizes = resolve_shapes(&shape, axes, s)?;
281    let data = to_complex_flat(a);
282
283    let axes_and_sizes: Vec<(usize, Option<usize>)> = axes.iter().copied().zip(sizes).collect();
284
285    let (new_shape, result) = fft_along_axes(&data, &shape, &axes_and_sizes, inverse, norm)?;
286
287    Array::from_vec(IxDyn::new(&new_shape), result)
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::Ix1;

    fn c(re: f64, im: f64) -> Complex<f64> {
        Complex::new(re, im)
    }

    fn make_1d(data: Vec<Complex<f64>>) -> Array<Complex<f64>, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    /// Assert that two complex sequences have equal length and agree
    /// element-wise (real and imaginary parts) within absolute `tol`.
    /// Unlike a bare `zip`, a length mismatch fails instead of passing silently.
    fn assert_all_close(actual: &[Complex<f64>], expected: &[Complex<f64>], tol: f64) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (got, want)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (got.re - want.re).abs() < tol && (got.im - want.im).abs() < tol,
                "index {}: got {:?}, expected {:?}",
                i,
                got,
                want
            );
        }
    }

    #[test]
    fn fft_impulse() {
        // FFT of [1, 0, 0, 0] = [1, 1, 1, 1]
        let a = make_1d(vec![c(1.0, 0.0), c(0.0, 0.0), c(0.0, 0.0), c(0.0, 0.0)]);
        let result = fft(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[4]);
        let vals: Vec<_> = result.iter().copied().collect();
        assert_all_close(&vals, &[c(1.0, 0.0); 4], 1e-12);
    }

    #[test]
    fn fft_constant() {
        // FFT of [1, 1, 1, 1] = [4, 0, 0, 0]
        let a = make_1d(vec![c(1.0, 0.0); 4]);
        let result = fft(&a, None, None, FftNorm::Backward).unwrap();
        let vals: Vec<_> = result.iter().copied().collect();
        assert_all_close(
            &vals,
            &[c(4.0, 0.0), c(0.0, 0.0), c(0.0, 0.0), c(0.0, 0.0)],
            1e-12,
        );
    }

    #[test]
    fn ifft_backward_divides_by_n() {
        // IFFT of [4, 0, 0, 0] with backward normalization = [1, 1, 1, 1]
        let a = make_1d(vec![c(4.0, 0.0), c(0.0, 0.0), c(0.0, 0.0), c(0.0, 0.0)]);
        let result = ifft(&a, None, None, FftNorm::Backward).unwrap();
        let vals: Vec<_> = result.iter().copied().collect();
        assert_all_close(&vals, &[c(1.0, 0.0); 4], 1e-12);
    }

    #[test]
    fn fft_ifft_roundtrip() {
        // AC-1: ifft(fft(a)) recovers `a` for complex f64 input. This checks an
        // absolute tolerance of 1e-10, not a strict ULP bound.
        let data = vec![
            c(1.0, 2.0),
            c(-1.0, 0.5),
            c(3.0, -1.0),
            c(0.0, 0.0),
            c(-2.5, 1.5),
            c(0.7, -0.3),
            c(1.2, 0.8),
            c(-0.4, 2.1),
        ];
        let a = make_1d(data.clone());
        let spectrum = fft(&a, None, None, FftNorm::Backward).unwrap();
        let recovered = ifft(&spectrum, None, None, FftNorm::Backward).unwrap();
        let recovered: Vec<_> = recovered.iter().copied().collect();
        assert_all_close(&recovered, &data, 1e-10);
    }

    #[test]
    fn fft_with_n_padding() {
        // Pad [1, 1] to length 4 -> FFT of [1, 1, 0, 0] = [2, 1-i, 0, 1+i]
        let a = make_1d(vec![c(1.0, 0.0), c(1.0, 0.0)]);
        let result = fft(&a, Some(4), None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[4]);
        let vals: Vec<_> = result.iter().copied().collect();
        assert_all_close(
            &vals,
            &[c(2.0, 0.0), c(1.0, -1.0), c(0.0, 0.0), c(1.0, 1.0)],
            1e-12,
        );
    }

    #[test]
    fn fft_with_n_truncation() {
        // Truncate [1, 2, 3, 4] to length 2 -> FFT of [1, 2] = [3, -1]
        let a = make_1d(vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)]);
        let result = fft(&a, Some(2), None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[2]);
        let vals: Vec<_> = result.iter().copied().collect();
        assert_all_close(&vals, &[c(3.0, 0.0), c(-1.0, 0.0)], 1e-12);
    }

    #[test]
    fn fft_non_power_of_two() {
        // AC-2 partial: round-trip a non-power-of-2 length
        let n = 7;
        let data: Vec<Complex<f64>> = (0..n).map(|i| c(i as f64, 0.0)).collect();
        let a = make_1d(data.clone());
        let spectrum = fft(&a, None, None, FftNorm::Backward).unwrap();
        let recovered = ifft(&spectrum, None, None, FftNorm::Backward).unwrap();
        let recovered: Vec<_> = recovered.iter().copied().collect();
        assert_all_close(&recovered, &data, 1e-10);
    }

    #[test]
    fn fft2_basic() {
        use ferray_core::dimension::Ix2;
        let data = vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)];
        let a = Array::from_vec(Ix2::new([2, 2]), data).unwrap();
        let result = fft2(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[2, 2]);

        // fft2([[1, 2], [3, 4]]) = [[10, -2], [-4, 0]] (row-major)
        let spectrum: Vec<_> = result.iter().copied().collect();
        assert_all_close(
            &spectrum,
            &[c(10.0, 0.0), c(-2.0, 0.0), c(-4.0, 0.0), c(0.0, 0.0)],
            1e-12,
        );

        let recovered = ifft2(&result, None, None, FftNorm::Backward).unwrap();
        let recovered: Vec<_> = recovered.iter().copied().collect();
        let orig: Vec<_> = a.iter().copied().collect();
        assert_all_close(&recovered, &orig, 1e-10);
    }

    #[test]
    fn fft2_requires_two_dims_by_default() {
        let a = make_1d(vec![c(1.0, 0.0), c(2.0, 0.0)]);
        assert!(fft2(&a, None, None, FftNorm::Backward).is_err());
        assert!(ifft2(&a, None, None, FftNorm::Backward).is_err());
    }

    #[test]
    fn fftn_roundtrip_3d() {
        use ferray_core::dimension::Ix3;
        let n = 2 * 3 * 4;
        let data: Vec<Complex<f64>> = (0..n).map(|i| c(i as f64, -(i as f64) * 0.5)).collect();
        let a = Array::from_vec(Ix3::new([2, 3, 4]), data.clone()).unwrap();
        let spectrum = fftn(&a, None, None, FftNorm::Backward).unwrap();
        let recovered = ifftn(&spectrum, None, None, FftNorm::Backward).unwrap();
        let recovered: Vec<_> = recovered.iter().copied().collect();
        assert_all_close(&recovered, &data, 1e-9);
    }

    #[test]
    fn fftn_rejects_mismatched_shape_length() {
        use ferray_core::dimension::Ix2;
        let a = Array::from_vec(Ix2::new([2, 2]), vec![c(1.0, 0.0); 4]).unwrap();
        // Two axes are transformed by default, but `s` supplies only one length.
        assert!(fftn(&a, Some(&[2][..]), None, FftNorm::Backward).is_err());
    }

    #[test]
    fn fft_axis_out_of_bounds() {
        let a = make_1d(vec![c(1.0, 0.0)]);
        assert!(fft(&a, None, Some(1), FftNorm::Backward).is_err());
    }
}