1use num_complex::Complex;
7
8use ferray_core::Array;
9use ferray_core::dimension::{Dimension, IxDyn};
10use ferray_core::error::{FerrayError, FerrayResult};
11
12use crate::axes::{resolve_axes, resolve_axis};
13use crate::float::FftFloat;
14use crate::nd::{fft_along_axes, fft_along_axis};
15use crate::norm::FftNorm;
16
17enum ComplexData<'a, T: FftFloat>
24where
25 Complex<T>: ferray_core::Element,
26{
27 Borrowed(&'a [Complex<T>]),
28 Owned(Vec<Complex<T>>),
29}
30
31impl<T: FftFloat> std::ops::Deref for ComplexData<'_, T>
32where
33 Complex<T>: ferray_core::Element,
34{
35 type Target = [Complex<T>];
36 fn deref(&self) -> &[Complex<T>] {
37 match self {
38 ComplexData::Borrowed(s) => s,
39 ComplexData::Owned(v) => v,
40 }
41 }
42}
43
44fn borrow_complex_flat<T: FftFloat, D: Dimension>(a: &Array<Complex<T>, D>) -> ComplexData<'_, T>
45where
46 Complex<T>: ferray_core::Element,
47{
48 if let Some(s) = a.as_slice() {
49 ComplexData::Borrowed(s)
50 } else {
51 ComplexData::Owned(a.iter().copied().collect())
52 }
53}
54
55fn resolve_shapes(
57 input_shape: &[usize],
58 axes: &[usize],
59 s: Option<&[usize]>,
60) -> FerrayResult<Vec<Option<usize>>> {
61 match s {
62 Some(sizes) => {
63 if sizes.len() != axes.len() {
64 return Err(FerrayError::invalid_value(format!(
65 "shape parameter length {} does not match axes length {}",
66 sizes.len(),
67 axes.len(),
68 )));
69 }
70 Ok(sizes.iter().map(|&sz| Some(sz)).collect())
71 }
72 None => Ok(axes.iter().map(|&ax| Some(input_shape[ax])).collect()),
73 }
74}
75
76pub fn fft<T: FftFloat, D: Dimension>(
95 a: &Array<Complex<T>, D>,
96 n: Option<usize>,
97 axis: Option<isize>,
98 norm: FftNorm,
99) -> FerrayResult<Array<Complex<T>, IxDyn>>
100where
101 Complex<T>: ferray_core::Element,
102{
103 let shape = a.shape().to_vec();
104 let ndim = shape.len();
105 let ax = resolve_axis(ndim, axis)?;
106 let data = borrow_complex_flat(a);
107
108 let (new_shape, result) = fft_along_axis::<T>(&data, &shape, ax, n, false, norm)?;
109
110 Array::from_vec(IxDyn::new(&new_shape), result)
111}
112
113pub fn ifft<T: FftFloat, D: Dimension>(
126 a: &Array<Complex<T>, D>,
127 n: Option<usize>,
128 axis: Option<isize>,
129 norm: FftNorm,
130) -> FerrayResult<Array<Complex<T>, IxDyn>>
131where
132 Complex<T>: ferray_core::Element,
133{
134 let shape = a.shape().to_vec();
135 let ndim = shape.len();
136 let ax = resolve_axis(ndim, axis)?;
137 let data = borrow_complex_flat(a);
138
139 let (new_shape, result) = fft_along_axis::<T>(&data, &shape, ax, n, true, norm)?;
140
141 Array::from_vec(IxDyn::new(&new_shape), result)
142}
143
144pub fn fft2<T: FftFloat, D: Dimension>(
164 a: &Array<Complex<T>, D>,
165 s: Option<&[usize]>,
166 axes: Option<&[isize]>,
167 norm: FftNorm,
168) -> FerrayResult<Array<Complex<T>, IxDyn>>
169where
170 Complex<T>: ferray_core::Element,
171{
172 let ndim = a.shape().len();
173 let axes = match axes {
174 Some(ax) => resolve_axes(ndim, Some(ax))?,
175 None => {
176 if ndim < 2 {
177 return Err(FerrayError::invalid_value(
178 "fft2 requires at least 2 dimensions",
179 ));
180 }
181 vec![ndim - 2, ndim - 1]
182 }
183 };
184 fftn_impl::<T, D>(a, s, &axes, false, norm)
185}
186
187pub fn ifft2<T: FftFloat, D: Dimension>(
201 a: &Array<Complex<T>, D>,
202 s: Option<&[usize]>,
203 axes: Option<&[isize]>,
204 norm: FftNorm,
205) -> FerrayResult<Array<Complex<T>, IxDyn>>
206where
207 Complex<T>: ferray_core::Element,
208{
209 let ndim = a.shape().len();
210 let axes = match axes {
211 Some(ax) => resolve_axes(ndim, Some(ax))?,
212 None => {
213 if ndim < 2 {
214 return Err(FerrayError::invalid_value(
215 "ifft2 requires at least 2 dimensions",
216 ));
217 }
218 vec![ndim - 2, ndim - 1]
219 }
220 };
221 fftn_impl::<T, D>(a, s, &axes, true, norm)
222}
223
224pub fn fftn<T: FftFloat, D: Dimension>(
244 a: &Array<Complex<T>, D>,
245 s: Option<&[usize]>,
246 axes: Option<&[isize]>,
247 norm: FftNorm,
248) -> FerrayResult<Array<Complex<T>, IxDyn>>
249where
250 Complex<T>: ferray_core::Element,
251{
252 let ax = resolve_axes(a.shape().len(), axes)?;
253 fftn_impl::<T, D>(a, s, &ax, false, norm)
254}
255
256pub fn ifftn<T: FftFloat, D: Dimension>(
271 a: &Array<Complex<T>, D>,
272 s: Option<&[usize]>,
273 axes: Option<&[isize]>,
274 norm: FftNorm,
275) -> FerrayResult<Array<Complex<T>, IxDyn>>
276where
277 Complex<T>: ferray_core::Element,
278{
279 let ax = resolve_axes(a.shape().len(), axes)?;
280 fftn_impl::<T, D>(a, s, &ax, true, norm)
281}
282
283fn real_to_complex_vec<T: FftFloat, D: Dimension>(a: &Array<T, D>) -> Vec<Complex<T>>
299where
300 Complex<T>: ferray_core::Element,
301{
302 a.iter()
303 .map(|&v| Complex::new(v, <T as num_traits::Zero>::zero()))
304 .collect()
305}
306
307pub fn fft_real<T: FftFloat, D: Dimension>(
316 a: &Array<T, D>,
317 n: Option<usize>,
318 axis: Option<isize>,
319 norm: FftNorm,
320) -> FerrayResult<Array<Complex<T>, IxDyn>>
321where
322 Complex<T>: ferray_core::Element,
323{
324 let shape = a.shape().to_vec();
325 let complex_data = real_to_complex_vec(a);
326 let complex_arr = Array::<Complex<T>, IxDyn>::from_vec(IxDyn::new(&shape), complex_data)?;
327 fft::<T, IxDyn>(&complex_arr, n, axis, norm)
328}
329
330pub fn ifft_real<T: FftFloat, D: Dimension>(
340 a: &Array<Complex<T>, D>,
341 n: Option<usize>,
342 axis: Option<isize>,
343 norm: FftNorm,
344) -> FerrayResult<Array<T, IxDyn>>
345where
346 Complex<T>: ferray_core::Element,
347{
348 let spectrum = ifft::<T, D>(a, n, axis, norm)?;
349 let shape = spectrum.shape().to_vec();
350 let real_data: Vec<T> = spectrum.iter().map(|c| c.re).collect();
351 Array::from_vec(IxDyn::new(&shape), real_data)
352}
353
354pub fn fft_real2<T: FftFloat, D: Dimension>(
359 a: &Array<T, D>,
360 s: Option<&[usize]>,
361 axes: Option<&[isize]>,
362 norm: FftNorm,
363) -> FerrayResult<Array<Complex<T>, IxDyn>>
364where
365 Complex<T>: ferray_core::Element,
366{
367 let shape = a.shape().to_vec();
368 let complex_data = real_to_complex_vec(a);
369 let complex_arr = Array::<Complex<T>, IxDyn>::from_vec(IxDyn::new(&shape), complex_data)?;
370 fft2::<T, IxDyn>(&complex_arr, s, axes, norm)
371}
372
373pub fn fft_realn<T: FftFloat, D: Dimension>(
378 a: &Array<T, D>,
379 s: Option<&[usize]>,
380 axes: Option<&[isize]>,
381 norm: FftNorm,
382) -> FerrayResult<Array<Complex<T>, IxDyn>>
383where
384 Complex<T>: ferray_core::Element,
385{
386 let shape = a.shape().to_vec();
387 let complex_data = real_to_complex_vec(a);
388 let complex_arr = Array::<Complex<T>, IxDyn>::from_vec(IxDyn::new(&shape), complex_data)?;
389 fftn::<T, IxDyn>(&complex_arr, s, axes, norm)
390}
391
392fn fftn_impl<T: FftFloat, D: Dimension>(
397 a: &Array<Complex<T>, D>,
398 s: Option<&[usize]>,
399 axes: &[usize],
400 inverse: bool,
401 norm: FftNorm,
402) -> FerrayResult<Array<Complex<T>, IxDyn>>
403where
404 Complex<T>: ferray_core::Element,
405{
406 let shape = a.shape().to_vec();
407 let sizes = resolve_shapes(&shape, axes, s)?;
408 let data = borrow_complex_flat(a);
409
410 let axes_and_sizes: Vec<(usize, Option<usize>)> = axes.iter().copied().zip(sizes).collect();
411
412 let (new_shape, result) = fft_along_axes::<T>(&data, &shape, &axes_and_sizes, inverse, norm)?;
413
414 Array::from_vec(IxDyn::new(&new_shape), result)
415}
416
#[cfg(test)]
mod tests {
    // Unit tests for the complex and real FFT entry points. Results are
    // compared against hand-computed values with small absolute tolerances.
    use super::*;
    use ferray_core::dimension::Ix1;

    /// Shorthand for building a `Complex<f64>`.
    fn c(re: f64, im: f64) -> Complex<f64> {
        Complex::new(re, im)
    }

    /// Wrap `data` in a 1-D array of matching length.
    fn make_1d(data: Vec<Complex<f64>>) -> Array<Complex<f64>, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    #[test]
    fn fft_impulse() {
        // The DFT of a unit impulse is flat: every bin equals 1.
        let a = make_1d(vec![c(1.0, 0.0), c(0.0, 0.0), c(0.0, 0.0), c(0.0, 0.0)]);
        let result = fft(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[4]);
        for val in result.iter() {
            assert!((val.re - 1.0).abs() < 1e-12);
            assert!(val.im.abs() < 1e-12);
        }
    }

    #[test]
    fn fft_length_one() {
        // A length-1 transform is the identity in both directions.
        let a = make_1d(vec![c(7.0, -2.0)]);
        let result = fft(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[1]);
        let v = result.iter().next().unwrap();
        assert!((v.re - 7.0).abs() < 1e-12);
        assert!((v.im + 2.0).abs() < 1e-12);

        let recovered = ifft(&result, None, None, FftNorm::Backward).unwrap();
        let r = recovered.iter().next().unwrap();
        assert!((r.re - 7.0).abs() < 1e-12);
        assert!((r.im + 2.0).abs() < 1e-12);
    }

    #[test]
    fn fft_negative_axis_matches_explicit() {
        use ferray_core::dimension::Ix2;
        let data = vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)];
        let a = Array::<Complex<f64>, Ix2>::from_vec(Ix2::new([2, 2]), data).unwrap();
        let neg = fft(&a, None, Some(-1), FftNorm::Backward).unwrap();
        let pos = fft(&a, None, Some(1), FftNorm::Backward).unwrap();
        assert_eq!(neg.shape(), pos.shape());
        for (n, p) in neg.iter().zip(pos.iter()) {
            assert!((n.re - p.re).abs() < 1e-12);
            assert!((n.im - p.im).abs() < 1e-12);
        }
    }

    #[test]
    fn fftn_negative_axes_matches_explicit() {
        use ferray_core::dimension::Ix2;
        let data: Vec<Complex<f64>> = (0..6).map(|i| c(i as f64, 0.0)).collect();
        let a = Array::<Complex<f64>, Ix2>::from_vec(Ix2::new([2, 3]), data).unwrap();
        let neg = fftn(&a, None, Some(&[-2, -1][..]), FftNorm::Backward).unwrap();
        let pos = fftn(&a, None, Some(&[0, 1][..]), FftNorm::Backward).unwrap();
        // Guard the element-wise comparison: `zip` would silently stop at the
        // shorter side if the shapes ever diverged.
        assert_eq!(neg.shape(), pos.shape());
        assert_eq!(neg.shape(), &[2, 3]);
        for (n, p) in neg.iter().zip(pos.iter()) {
            assert!((n.re - p.re).abs() < 1e-12);
            assert!((n.im - p.im).abs() < 1e-12);
        }
    }

    #[test]
    fn fft_constant() {
        // A constant signal concentrates all energy in the DC bin.
        let a = make_1d(vec![c(1.0, 0.0); 4]);
        let result = fft(&a, None, None, FftNorm::Backward).unwrap();
        let vals: Vec<_> = result.iter().copied().collect();
        assert!((vals[0].re - 4.0).abs() < 1e-12);
        for v in &vals[1..] {
            assert!(v.re.abs() < 1e-12);
            assert!(v.im.abs() < 1e-12);
        }
    }

    #[test]
    fn fft_ifft_roundtrip() {
        let data = vec![
            c(1.0, 2.0),
            c(-1.0, 0.5),
            c(3.0, -1.0),
            c(0.0, 0.0),
            c(-2.5, 1.5),
            c(0.7, -0.3),
            c(1.2, 0.8),
            c(-0.4, 2.1),
        ];
        let a = make_1d(data.clone());
        let spectrum = fft(&a, None, None, FftNorm::Backward).unwrap();
        let recovered = ifft(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (orig, rec) in data.iter().zip(recovered.iter()) {
            assert!(
                (orig.re - rec.re).abs() < 1e-10,
                "re mismatch: {} vs {}",
                orig.re,
                rec.re
            );
            assert!(
                (orig.im - rec.im).abs() < 1e-10,
                "im mismatch: {} vs {}",
                orig.im,
                rec.im
            );
        }
    }

    #[test]
    fn fft_with_n_padding() {
        // Zero-padding [1, 1] to length 4 leaves the DC bin at 2.
        let a = make_1d(vec![c(1.0, 0.0), c(1.0, 0.0)]);
        let result = fft(&a, Some(4), None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[4]);
        let vals: Vec<_> = result.iter().copied().collect();
        assert!((vals[0].re - 2.0).abs() < 1e-12);
    }

    #[test]
    fn fft_with_n_truncation() {
        // Truncating to n = 2 keeps [1, 2], whose DFT is [3, -1].
        let a = make_1d(vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)]);
        let result = fft(&a, Some(2), None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[2]);
        let vals: Vec<_> = result.iter().copied().collect();
        assert!((vals[0].re - 3.0).abs() < 1e-12);
        assert!((vals[1].re - (-1.0)).abs() < 1e-12);
    }

    #[test]
    fn fft_non_power_of_two() {
        let n = 7;
        let data: Vec<Complex<f64>> = (0..n).map(|i| c(i as f64, 0.0)).collect();
        let a = make_1d(data.clone());
        let spectrum = fft(&a, None, None, FftNorm::Backward).unwrap();
        let recovered = ifft(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (orig, rec) in data.iter().zip(recovered.iter()) {
            assert!((orig.re - rec.re).abs() < 1e-10);
            assert!((orig.im - rec.im).abs() < 1e-10);
        }
    }

    #[test]
    fn fft2_basic() {
        use ferray_core::dimension::Ix2;
        let data = vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)];
        let a = Array::from_vec(Ix2::new([2, 2]), data).unwrap();
        let result = fft2(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[2, 2]);

        let recovered = ifft2(&result, None, None, FftNorm::Backward).unwrap();
        let orig: Vec<_> = a.iter().copied().collect();
        for (o, r) in orig.iter().zip(recovered.iter()) {
            assert!((o.re - r.re).abs() < 1e-10);
            assert!((o.im - r.im).abs() < 1e-10);
        }
    }

    #[test]
    fn fftn_roundtrip_3d() {
        use ferray_core::dimension::Ix3;
        let n = 2 * 3 * 4;
        let data: Vec<Complex<f64>> = (0..n).map(|i| c(i as f64, -(i as f64) * 0.5)).collect();
        let a = Array::from_vec(Ix3::new([2, 3, 4]), data.clone()).unwrap();
        let spectrum = fftn(&a, None, None, FftNorm::Backward).unwrap();
        let recovered = ifftn(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (o, r) in data.iter().zip(recovered.iter()) {
            assert!((o.re - r.re).abs() < 1e-9, "re: {} vs {}", o.re, r.re);
            assert!((o.im - r.im).abs() < 1e-9, "im: {} vs {}", o.im, r.im);
        }
    }

    #[test]
    fn fft_axis_out_of_bounds() {
        let a = make_1d(vec![c(1.0, 0.0)]);
        assert!(fft(&a, None, Some(1), FftNorm::Backward).is_err());
    }

    #[test]
    fn fft2_with_shape_padding() {
        use ferray_core::dimension::Ix2;
        let data = vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)];
        // Zero-pad a 2x2 input to 4x4.
        let a = Array::from_vec(Ix2::new([2, 2]), data).unwrap();
        let result = fft2(&a, Some(&[4, 4]), None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[4, 4]);
        // Padding with zeros does not change the sum, so DC = 1 + 2 + 3 + 4.
        let dc = result.iter().next().unwrap();
        assert!((dc.re - 10.0).abs() < 1e-10, "dc.re = {}", dc.re);
        assert!(dc.im.abs() < 1e-10, "dc.im = {}", dc.im);
    }

    #[test]
    fn fft2_with_shape_truncation() {
        use ferray_core::dimension::Ix2;
        let data: Vec<Complex<f64>> = (0..16).map(|i| c(i as f64, 0.0)).collect();
        // Truncate a 4x4 input to its leading 2x2 block: [[0, 1], [4, 5]].
        let a = Array::from_vec(Ix2::new([4, 4]), data).unwrap();
        let result = fft2(&a, Some(&[2, 2]), None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        // DC bin is the sum of the retained block: 0 + 1 + 4 + 5.
        let dc = result.iter().next().unwrap();
        assert!((dc.re - 10.0).abs() < 1e-10, "dc.re = {}", dc.re);
        assert!(dc.im.abs() < 1e-10, "dc.im = {}", dc.im);
    }

    #[test]
    fn fftn_with_shape_roundtrip() {
        use ferray_core::dimension::Ix2;
        let data: Vec<Complex<f64>> = (0..12).map(|i| c(i as f64, 0.0)).collect();
        // Pad 3x4 up to 4x8, then invert at the same size.
        let a = Array::from_vec(Ix2::new([3, 4]), data).unwrap();
        let spectrum = fftn(&a, Some(&[4, 8]), None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[4, 8]);
        let recovered = ifftn(&spectrum, Some(&[4, 8]), None, FftNorm::Backward).unwrap();
        assert_eq!(recovered.shape(), &[4, 8]);
        // Collect once so each lookup is O(1) instead of re-walking the iterator.
        let values: Vec<Complex<f64>> = recovered.iter().copied().collect();
        for i in 0..4 {
            for j in 0..8 {
                // Inside the original 3x4 footprint we expect the input values;
                // the zero-padded remainder must come back as zeros.
                let expected = if i < 3 && j < 4 {
                    (i * 4 + j) as f64
                } else {
                    0.0
                };
                let v = values[i * 8 + j];
                assert!((v.re - expected).abs() < 1e-9, "re mismatch at ({i},{j})");
                assert!(v.im.abs() < 1e-9, "im mismatch at ({i},{j})");
            }
        }
    }

    #[test]
    fn fft_ifft_ortho_roundtrip() {
        let data = vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)];
        let a = make_1d(data.clone());
        let spectrum = fft(&a, None, None, FftNorm::Ortho).unwrap();
        let recovered = ifft(&spectrum, None, None, FftNorm::Ortho).unwrap();
        for (orig, rec) in data.iter().zip(recovered.iter()) {
            assert!((orig.re - rec.re).abs() < 1e-10);
            assert!((orig.im - rec.im).abs() < 1e-10);
        }
    }

    #[test]
    fn fft_ifft_forward_roundtrip() {
        let data = vec![c(1.0, 2.0), c(-1.0, 0.5), c(3.0, -1.0), c(0.0, 0.0)];
        let a = make_1d(data.clone());
        let spectrum = fft(&a, None, None, FftNorm::Forward).unwrap();
        let recovered = ifft(&spectrum, None, None, FftNorm::Forward).unwrap();
        for (orig, rec) in data.iter().zip(recovered.iter()) {
            assert!((orig.re - rec.re).abs() < 1e-10);
            assert!((orig.im - rec.im).abs() < 1e-10);
        }
    }

    #[test]
    fn fft_ortho_energy_preservation() {
        // With orthonormal scaling, Parseval's theorem holds exactly.
        let data = vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)];
        let a = make_1d(data.clone());
        let spectrum = fft(&a, None, None, FftNorm::Ortho).unwrap();

        let energy_time: f64 = data.iter().map(|x| x.re * x.re + x.im * x.im).sum();
        let energy_freq: f64 = spectrum.iter().map(|x| x.re * x.re + x.im * x.im).sum();
        assert!(
            (energy_time - energy_freq).abs() < 1e-10,
            "Parseval: time={energy_time}, freq={energy_freq}"
        );
    }

    #[test]
    fn fft_forward_scaling() {
        // Forward normalization divides by n, so a constant 1 gives DC = 1.
        let a = make_1d(vec![c(1.0, 0.0); 4]);
        let result = fft(&a, None, None, FftNorm::Forward).unwrap();
        let vals: Vec<_> = result.iter().copied().collect();
        assert!((vals[0].re - 1.0).abs() < 1e-12);
        for v in &vals[1..] {
            assert!(v.re.abs() < 1e-12);
            assert!(v.im.abs() < 1e-12);
        }
    }

    #[test]
    fn fft_ifft_f32_roundtrip() {
        let data: Vec<Complex<f32>> = (0..16)
            .map(|i| Complex::new(i as f32 * 0.25, (i as f32).sin()))
            .collect();
        let a = Array::from_vec(Ix1::new([16]), data.clone()).unwrap();
        let spectrum = fft::<f32, Ix1>(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[16]);
        let recovered = ifft::<f32, IxDyn>(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (orig, rec) in data.iter().zip(recovered.iter()) {
            assert!(
                (orig.re - rec.re).abs() < 1e-4,
                "f32 re mismatch: {} vs {}",
                orig.re,
                rec.re
            );
            assert!(
                (orig.im - rec.im).abs() < 1e-4,
                "f32 im mismatch: {} vs {}",
                orig.im,
                rec.im
            );
        }
    }

    #[test]
    fn fft_f32_impulse() {
        let data = vec![
            Complex::<f32>::new(1.0, 0.0),
            Complex::<f32>::new(0.0, 0.0),
            Complex::<f32>::new(0.0, 0.0),
            Complex::<f32>::new(0.0, 0.0),
        ];
        let a = Array::from_vec(Ix1::new([4]), data).unwrap();
        let result = fft::<f32, Ix1>(&a, None, None, FftNorm::Backward).unwrap();
        for val in result.iter() {
            assert!((val.re - 1.0).abs() < 1e-6);
            assert!(val.im.abs() < 1e-6);
        }
    }

    #[test]
    fn fft2_f32_roundtrip() {
        use ferray_core::dimension::Ix2;
        let data: Vec<Complex<f32>> = (0..16)
            .map(|i| Complex::new(i as f32, -(i as f32) * 0.25))
            .collect();
        let a = Array::from_vec(Ix2::new([4, 4]), data.clone()).unwrap();
        let spectrum = fft2::<f32, Ix2>(&a, None, None, FftNorm::Backward).unwrap();
        let recovered = ifft2::<f32, IxDyn>(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (o, r) in data.iter().zip(recovered.iter()) {
            assert!((o.re - r.re).abs() < 1e-4);
            assert!((o.im - r.im).abs() < 1e-4);
        }
    }

    #[test]
    fn fft_real_ifft_real_roundtrip_f64() {
        let original = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([8]), original.clone()).unwrap();
        let spectrum = fft_real::<f64, Ix1>(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[8]);
        let recovered = ifft_real::<f64, IxDyn>(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (o, r) in original.iter().zip(recovered.iter()) {
            assert!((o - r).abs() < 1e-10, "mismatch: {} vs {}", o, r);
        }
    }

    #[test]
    fn fft_real_ifft_real_roundtrip_f32() {
        let original: Vec<f32> = (0..16).map(|i| i as f32 * 0.5 - 2.0).collect();
        let a = Array::<f32, Ix1>::from_vec(Ix1::new([16]), original.clone()).unwrap();
        let spectrum = fft_real::<f32, Ix1>(&a, None, None, FftNorm::Backward).unwrap();
        let recovered = ifft_real::<f32, IxDyn>(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (o, r) in original.iter().zip(recovered.iter()) {
            assert!((o - r).abs() < 1e-4, "f32 mismatch: {} vs {}", o, r);
        }
    }

    #[test]
    fn fft_real_dc_component() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 1.0, 1.0, 1.0]).unwrap();
        let spectrum = fft_real::<f64, Ix1>(&a, None, None, FftNorm::Backward).unwrap();
        let vals: Vec<_> = spectrum.iter().copied().collect();
        assert!((vals[0].re - 4.0).abs() < 1e-12);
        assert!(vals[0].im.abs() < 1e-12);
        for v in &vals[1..] {
            assert!(v.re.abs() < 1e-12);
            assert!(v.im.abs() < 1e-12);
        }
    }

    #[test]
    fn fft_real2_roundtrip() {
        use ferray_core::dimension::Ix2;
        let data: Vec<f64> = (0..12).map(|i| i as f64 * 0.3).collect();
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([3, 4]), data.clone()).unwrap();
        let spectrum = fft_real2::<f64, Ix2>(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[3, 4]);
        let recovered = ifft2::<f64, IxDyn>(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (o, r) in data.iter().zip(recovered.iter()) {
            assert!((o - r.re).abs() < 1e-10);
            assert!(r.im.abs() < 1e-10);
        }
    }

    #[test]
    fn fft_realn_3d_roundtrip() {
        use ferray_core::dimension::Ix3;
        let data: Vec<f64> = (0..24).map(|i| (i as f64).sin()).collect();
        let a = Array::<f64, Ix3>::from_vec(Ix3::new([2, 3, 4]), data.clone()).unwrap();
        let spectrum = fft_realn::<f64, Ix3>(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[2, 3, 4]);
        let recovered = ifftn::<f64, IxDyn>(&spectrum, None, None, FftNorm::Backward).unwrap();
        for (o, r) in data.iter().zip(recovered.iter()) {
            assert!((o - r.re).abs() < 1e-10);
            assert!(r.im.abs() < 1e-10);
        }
    }
}
836}