1use num_complex::Complex;
4
5use ferray_core::Array;
6use ferray_core::dimension::{Dimension, IxDyn};
7use ferray_core::error::{FerrayError, FerrayResult};
8
9use crate::nd::{fft_1d_along_axis, fft_along_axes};
10use crate::norm::FftNorm;
11
12fn to_complex_flat<D: Dimension>(a: &Array<Complex<f64>, D>) -> Vec<Complex<f64>> {
18 a.iter().copied().collect()
19}
20
21fn resolve_axis(ndim: usize, axis: Option<usize>) -> FerrayResult<usize> {
23 match axis {
24 Some(ax) => {
25 if ax >= ndim {
26 Err(FerrayError::axis_out_of_bounds(ax, ndim))
27 } else {
28 Ok(ax)
29 }
30 }
31 None => {
32 if ndim == 0 {
33 Err(FerrayError::invalid_value(
34 "cannot compute FFT on a 0-dimensional array",
35 ))
36 } else {
37 Ok(ndim - 1)
38 }
39 }
40 }
41}
42
43fn resolve_axes(ndim: usize, axes: Option<&[usize]>) -> FerrayResult<Vec<usize>> {
45 match axes {
46 Some(ax) => {
47 for &a in ax {
48 if a >= ndim {
49 return Err(FerrayError::axis_out_of_bounds(a, ndim));
50 }
51 }
52 Ok(ax.to_vec())
53 }
54 None => Ok((0..ndim).collect()),
55 }
56}
57
58fn resolve_shapes(
60 input_shape: &[usize],
61 axes: &[usize],
62 s: Option<&[usize]>,
63) -> FerrayResult<Vec<Option<usize>>> {
64 match s {
65 Some(sizes) => {
66 if sizes.len() != axes.len() {
67 return Err(FerrayError::invalid_value(format!(
68 "shape parameter length {} does not match axes length {}",
69 sizes.len(),
70 axes.len(),
71 )));
72 }
73 Ok(sizes.iter().map(|&sz| Some(sz)).collect())
74 }
75 None => Ok(axes.iter().map(|&ax| Some(input_shape[ax])).collect()),
76 }
77}
78
79pub fn fft<D: Dimension>(
98 a: &Array<Complex<f64>, D>,
99 n: Option<usize>,
100 axis: Option<usize>,
101 norm: FftNorm,
102) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
103 let shape = a.shape().to_vec();
104 let ndim = shape.len();
105 let ax = resolve_axis(ndim, axis)?;
106 let data = to_complex_flat(a);
107
108 let (new_shape, result) = fft_1d_along_axis(&data, &shape, ax, n, false, norm)?;
109
110 Array::from_vec(IxDyn::new(&new_shape), result)
111}
112
113pub fn ifft<D: Dimension>(
126 a: &Array<Complex<f64>, D>,
127 n: Option<usize>,
128 axis: Option<usize>,
129 norm: FftNorm,
130) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
131 let shape = a.shape().to_vec();
132 let ndim = shape.len();
133 let ax = resolve_axis(ndim, axis)?;
134 let data = to_complex_flat(a);
135
136 let (new_shape, result) = fft_1d_along_axis(&data, &shape, ax, n, true, norm)?;
137
138 Array::from_vec(IxDyn::new(&new_shape), result)
139}
140
141pub fn fft2<D: Dimension>(
161 a: &Array<Complex<f64>, D>,
162 s: Option<&[usize]>,
163 axes: Option<&[usize]>,
164 norm: FftNorm,
165) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
166 let ndim = a.shape().len();
167 let axes = match axes {
168 Some(ax) => ax.to_vec(),
169 None => {
170 if ndim < 2 {
171 return Err(FerrayError::invalid_value(
172 "fft2 requires at least 2 dimensions",
173 ));
174 }
175 vec![ndim - 2, ndim - 1]
176 }
177 };
178 fftn_impl(a, s, &axes, false, norm)
179}
180
181pub fn ifft2<D: Dimension>(
195 a: &Array<Complex<f64>, D>,
196 s: Option<&[usize]>,
197 axes: Option<&[usize]>,
198 norm: FftNorm,
199) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
200 let ndim = a.shape().len();
201 let axes = match axes {
202 Some(ax) => ax.to_vec(),
203 None => {
204 if ndim < 2 {
205 return Err(FerrayError::invalid_value(
206 "ifft2 requires at least 2 dimensions",
207 ));
208 }
209 vec![ndim - 2, ndim - 1]
210 }
211 };
212 fftn_impl(a, s, &axes, true, norm)
213}
214
215pub fn fftn<D: Dimension>(
235 a: &Array<Complex<f64>, D>,
236 s: Option<&[usize]>,
237 axes: Option<&[usize]>,
238 norm: FftNorm,
239) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
240 let ax = resolve_axes(a.shape().len(), axes)?;
241 fftn_impl(a, s, &ax, false, norm)
242}
243
244pub fn ifftn<D: Dimension>(
259 a: &Array<Complex<f64>, D>,
260 s: Option<&[usize]>,
261 axes: Option<&[usize]>,
262 norm: FftNorm,
263) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
264 let ax = resolve_axes(a.shape().len(), axes)?;
265 fftn_impl(a, s, &ax, true, norm)
266}
267
268fn fftn_impl<D: Dimension>(
273 a: &Array<Complex<f64>, D>,
274 s: Option<&[usize]>,
275 axes: &[usize],
276 inverse: bool,
277 norm: FftNorm,
278) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
279 let shape = a.shape().to_vec();
280 let sizes = resolve_shapes(&shape, axes, s)?;
281 let data = to_complex_flat(a);
282
283 let axes_and_sizes: Vec<(usize, Option<usize>)> = axes.iter().copied().zip(sizes).collect();
284
285 let (new_shape, result) = fft_along_axes(&data, &shape, &axes_and_sizes, inverse, norm)?;
286
287 Array::from_vec(IxDyn::new(&new_shape), result)
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::Ix1;

    fn c(re: f64, im: f64) -> Complex<f64> {
        Complex::new(re, im)
    }

    fn make_1d(data: Vec<Complex<f64>>) -> Array<Complex<f64>, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    /// Asserts that `actual` matches `expected` to within `tol` in both the
    /// real and imaginary parts.
    fn assert_close(actual: Complex<f64>, expected: Complex<f64>, tol: f64) {
        assert!(
            (actual.re - expected.re).abs() < tol && (actual.im - expected.im).abs() < tol,
            "expected {:?}, got {:?}",
            expected,
            actual
        );
    }

    #[test]
    fn fft_impulse() {
        // The spectrum of a unit impulse is flat.
        let a = make_1d(vec![c(1.0, 0.0), c(0.0, 0.0), c(0.0, 0.0), c(0.0, 0.0)]);
        let result = fft(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[4]);
        for val in result.iter() {
            assert!((val.re - 1.0).abs() < 1e-12);
            assert!(val.im.abs() < 1e-12);
        }
    }

    #[test]
    fn fft_constant() {
        // A constant signal has all of its energy in the DC bin.
        let a = make_1d(vec![c(1.0, 0.0); 4]);
        let result = fft(&a, None, None, FftNorm::Backward).unwrap();
        let vals: Vec<_> = result.iter().copied().collect();
        assert!((vals[0].re - 4.0).abs() < 1e-12);
        for v in &vals[1..] {
            assert!(v.re.abs() < 1e-12);
            assert!(v.im.abs() < 1e-12);
        }
    }

    #[test]
    fn fft_known_spectrum() {
        // Hand-computed DFT of [1, 2, 3, 4] (unscaled forward transform).
        let a = make_1d(vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)]);
        let result = fft(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[4]);
        let expected = [c(10.0, 0.0), c(-2.0, 2.0), c(-2.0, 0.0), c(-2.0, -2.0)];
        for (got, want) in result.iter().zip(expected.iter()) {
            assert_close(*got, *want, 1e-12);
        }
    }

    #[test]
    fn fft_ifft_roundtrip() {
        let data = vec![
            c(1.0, 2.0),
            c(-1.0, 0.5),
            c(3.0, -1.0),
            c(0.0, 0.0),
            c(-2.5, 1.5),
            c(0.7, -0.3),
            c(1.2, 0.8),
            c(-0.4, 2.1),
        ];
        let a = make_1d(data.clone());
        let spectrum = fft(&a, None, None, FftNorm::Backward).unwrap();
        let recovered = ifft(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (orig, rec) in data.iter().zip(recovered.iter()) {
            assert!(
                (orig.re - rec.re).abs() < 1e-10,
                "re mismatch: {} vs {}",
                orig.re,
                rec.re
            );
            assert!(
                (orig.im - rec.im).abs() < 1e-10,
                "im mismatch: {} vs {}",
                orig.im,
                rec.im
            );
        }
    }

    #[test]
    fn fft_with_n_padding() {
        // [1, 1] zero-padded to length 4; the DC bin is the sum, 2.
        let a = make_1d(vec![c(1.0, 0.0), c(1.0, 0.0)]);
        let result = fft(&a, Some(4), None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[4]);
        let vals: Vec<_> = result.iter().copied().collect();
        assert!((vals[0].re - 2.0).abs() < 1e-12);
    }

    #[test]
    fn fft_with_n_truncation() {
        // Truncated to [1, 2], whose DFT is [1 + 2, 1 - 2] = [3, -1].
        let a = make_1d(vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)]);
        let result = fft(&a, Some(2), None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[2]);
        let vals: Vec<_> = result.iter().copied().collect();
        assert!((vals[0].re - 3.0).abs() < 1e-12);
        assert!((vals[1].re - (-1.0)).abs() < 1e-12);
    }

    #[test]
    fn fft_non_power_of_two() {
        // Length 7 (prime) exercises the non-radix-2 path.
        let n = 7;
        let data: Vec<Complex<f64>> = (0..n).map(|i| c(i as f64, 0.0)).collect();
        let a = make_1d(data.clone());
        let spectrum = fft(&a, None, None, FftNorm::Backward).unwrap();
        let recovered = ifft(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (orig, rec) in data.iter().zip(recovered.iter()) {
            assert!((orig.re - rec.re).abs() < 1e-10);
            assert!((orig.im - rec.im).abs() < 1e-10);
        }
    }

    #[test]
    fn fft2_basic() {
        use ferray_core::dimension::Ix2;
        let data = vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)];
        let a = Array::from_vec(Ix2::new([2, 2]), data).unwrap();
        let result = fft2(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[2, 2]);

        let recovered = ifft2(&result, None, None, FftNorm::Backward).unwrap();
        let orig: Vec<_> = a.iter().copied().collect();
        for (o, r) in orig.iter().zip(recovered.iter()) {
            assert!((o.re - r.re).abs() < 1e-10);
            assert!((o.im - r.im).abs() < 1e-10);
        }
    }

    #[test]
    fn fft2_known_spectrum() {
        use ferray_core::dimension::Ix2;
        // [[1, 2], [3, 4]]: row DFTs give [[3, -1], [7, -1]], then the column
        // DFTs give [[10, -2], [-4, 0]].
        let data = vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)];
        let a = Array::from_vec(Ix2::new([2, 2]), data).unwrap();
        let result = fft2(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        let expected = [c(10.0, 0.0), c(-2.0, 0.0), c(-4.0, 0.0), c(0.0, 0.0)];
        for (got, want) in result.iter().zip(expected.iter()) {
            assert_close(*got, *want, 1e-12);
        }
    }

    #[test]
    fn fft2_and_ifft2_reject_1d_input_without_axes() {
        let a = make_1d(vec![c(1.0, 0.0), c(2.0, 0.0)]);
        assert!(fft2(&a, None, None, FftNorm::Backward).is_err());
        assert!(ifft2(&a, None, None, FftNorm::Backward).is_err());
    }

    #[test]
    fn fftn_roundtrip_3d() {
        use ferray_core::dimension::Ix3;
        let n = 2 * 3 * 4;
        let data: Vec<Complex<f64>> = (0..n).map(|i| c(i as f64, -(i as f64) * 0.5)).collect();
        let a = Array::from_vec(Ix3::new([2, 3, 4]), data.clone()).unwrap();
        let spectrum = fftn(&a, None, None, FftNorm::Backward).unwrap();
        let recovered = ifftn(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (o, r) in data.iter().zip(recovered.iter()) {
            assert!((o.re - r.re).abs() < 1e-9, "re: {} vs {}", o.re, r.re);
            assert!((o.im - r.im).abs() < 1e-9, "im: {} vs {}", o.im, r.im);
        }
    }

    #[test]
    fn fftn_rejects_shape_axes_length_mismatch() {
        // Two sizes for a single axis must be rejected.
        let a = make_1d(vec![c(1.0, 0.0); 4]);
        let s = [2usize, 2];
        let axes = [0usize];
        assert!(fftn(&a, Some(&s[..]), Some(&axes[..]), FftNorm::Backward).is_err());
        assert!(ifftn(&a, Some(&s[..]), Some(&axes[..]), FftNorm::Backward).is_err());
    }

    #[test]
    fn fftn_axis_out_of_bounds() {
        let a = make_1d(vec![c(1.0, 0.0); 4]);
        let axes = [1usize];
        assert!(fftn(&a, None, Some(&axes[..]), FftNorm::Backward).is_err());
        assert!(ifftn(&a, None, Some(&axes[..]), FftNorm::Backward).is_err());
    }

    #[test]
    fn fft_axis_out_of_bounds() {
        let a = make_1d(vec![c(1.0, 0.0)]);
        assert!(fft(&a, None, Some(1), FftNorm::Backward).is_err());
        assert!(ifft(&a, None, Some(1), FftNorm::Backward).is_err());
    }
}