realfft/
lib.rs

1#![doc = include_str!("../README.md")]
2
3pub use rustfft::num_complex;
4pub use rustfft::num_traits;
5pub use rustfft::FftNum;
6
7use rustfft::num_complex::Complex;
8use rustfft::num_traits::Zero;
9use rustfft::FftPlanner;
10use std::collections::HashMap;
11use std::error;
12use std::fmt;
13use std::sync::Arc;
14
type Res<T> = Result<T, FftError>;

/// Custom error returned by FFTs
///
/// Both `Display` and `Debug` print the same human-readable message.
#[derive(Clone, PartialEq, Eq)]
pub enum FftError {
    /// The input buffer has the wrong size. The transform was not performed.
    ///
    /// The first member of the tuple is the expected size and the second member is the received
    /// size.
    InputBuffer(usize, usize),
    /// The output buffer has the wrong size. The transform was not performed.
    ///
    /// The first member of the tuple is the expected size and the second member is the received
    /// size.
    OutputBuffer(usize, usize),
    /// The scratch buffer has the wrong size. The transform was not performed.
    ///
    /// The first member of the tuple is the minimum size and the second member is the received
    /// size.
    ScratchBuffer(usize, usize),
    /// The input data contained a non-zero imaginary part where there should have been a zero.
    /// The transform was performed, but the result may not be correct.
    ///
    /// The first member of the tuple represents the first index of the complex buffer and the
    /// second member represents the last index of the complex buffer. The values are set to true
    /// if the corresponding complex value contains a non-zero imaginary part.
    /// The transforms in this crate only return this variant with at least one flag set.
    InputValues(bool, bool),
}

impl FftError {
    /// Write the human-readable description of this error into `f`.
    ///
    /// Shared by the `Debug` and `Display` implementations so both print the same text.
    /// Writes directly to the formatter instead of building an intermediate `String`.
    fn fmt_internal(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InputBuffer(expected, got) => {
                write!(f, "Wrong length of input, expected {}, got {}", expected, got)
            }
            Self::OutputBuffer(expected, got) => {
                write!(f, "Wrong length of output, expected {}, got {}", expected, got)
            }
            Self::ScratchBuffer(expected, got) => write!(
                f,
                "Scratch buffer of size {} is too small, must be at least {} long",
                got, expected
            ),
            Self::InputValues(first, last) => {
                let msg = match (*first, *last) {
                    (true, false) => "Imaginary part of first value was non-zero.",
                    (false, true) => "Imaginary part of last value was non-zero.",
                    (true, true) => "Imaginary parts of both first and last values were non-zero.",
                    // Never produced by the transforms, but the enum is public and anyone can
                    // construct this value; formatting an error must not panic.
                    (false, false) => "Imaginary parts of first and last values were zero.",
                };
                f.write_str(msg)
            }
        }
    }
}

// `Debug` deliberately prints the same message as `Display`, so e.g. `.unwrap()` on a failed
// transform shows the readable description.
impl fmt::Debug for FftError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.fmt_internal(f)
    }
}

impl fmt::Display for FftError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.fmt_internal(f)
    }
}

impl error::Error for FftError {}
84
85fn compute_twiddle<T: FftNum>(index: usize, fft_len: usize) -> Complex<T> {
86    let constant = -2f64 * std::f64::consts::PI / fft_len as f64;
87    let angle = constant * index as f64;
88    Complex {
89        re: T::from_f64(angle.cos()).unwrap(),
90        im: T::from_f64(angle.sin()).unwrap(),
91    }
92}
93
94pub struct RealToComplexOdd<T> {
95    length: usize,
96    fft: std::sync::Arc<dyn rustfft::Fft<T>>,
97    scratch_len: usize,
98}
99
100pub struct RealToComplexEven<T> {
101    twiddles: Vec<Complex<T>>,
102    length: usize,
103    fft: std::sync::Arc<dyn rustfft::Fft<T>>,
104    scratch_len: usize,
105}
106
107pub struct ComplexToRealOdd<T> {
108    length: usize,
109    fft: std::sync::Arc<dyn rustfft::Fft<T>>,
110    scratch_len: usize,
111}
112
113pub struct ComplexToRealEven<T> {
114    twiddles: Vec<Complex<T>>,
115    length: usize,
116    fft: std::sync::Arc<dyn rustfft::Fft<T>>,
117    scratch_len: usize,
118}
119
120/// A forward FFT that takes a real-valued input signal of length N
121/// and transforms it to a complex spectrum of length N/2+1.
122#[allow(clippy::len_without_is_empty)]
123pub trait RealToComplex<T>: Sync + Send {
124    /// Transform a signal of N real-valued samples,
125    /// storing the resulting complex spectrum in the N/2+1
126    /// (with N/2 rounded down) element long output slice.
127    /// The input buffer is used as scratch space,
128    /// so the contents of input should be considered garbage after calling.
129    /// It also allocates additional scratch space as needed.
130    /// An error is returned if any of the given slices has the wrong length.
131    fn process(&self, input: &mut [T], output: &mut [Complex<T>]) -> Res<()>;
132
133    /// Transform a signal of N real-valued samples,
134    /// similar to [`process()`](RealToComplex::process).
135    /// The difference is that this method uses the provided
136    /// scratch buffer instead of allocating new scratch space.
137    /// This is faster if the same scratch buffer is used for multiple calls.
138    fn process_with_scratch(
139        &self,
140        input: &mut [T],
141        output: &mut [Complex<T>],
142        scratch: &mut [Complex<T>],
143    ) -> Res<()>;
144
145    /// Get the minimum length of the scratch buffer needed for `process_with_scratch`.
146    fn get_scratch_len(&self) -> usize;
147
148    /// The FFT length.
149    /// Get the length of the real signal that this FFT takes as input.
150    fn len(&self) -> usize;
151
152    /// Get the number of complex data points that this FFT returns.
153    fn complex_len(&self) -> usize {
154        self.len() / 2 + 1
155    }
156
157    /// Convenience method to make an input vector of the right type and length.
158    fn make_input_vec(&self) -> Vec<T>;
159
160    /// Convenience method to make an output vector of the right type and length.
161    fn make_output_vec(&self) -> Vec<Complex<T>>;
162
163    /// Convenience method to make a scratch vector of the right type and length.
164    fn make_scratch_vec(&self) -> Vec<Complex<T>>;
165}
166
167/// An inverse FFT that takes a complex spectrum of length N/2+1
168/// and transforms it to a real-valued signal of length N.
169#[allow(clippy::len_without_is_empty)]
170pub trait ComplexToReal<T>: Sync + Send {
171    /// Inverse transform a complex spectrum corresponding to a real-valued signal of length N.
172    /// The input is a slice of complex values with length N/2+1 (with N/2 rounded down).
173    /// The resulting real-valued signal is stored in the output slice of length N.
174    /// The input buffer is used as scratch space,
175    /// so the contents of input should be considered garbage after calling.
176    /// It also allocates additional scratch space as needed.
177    /// An error is returned if any of the given slices has the wrong length.
178    /// If the input data is invalid, meaning that one of the positions that should
179    /// contain a zero holds a non-zero value, the transform is still performed.
180    /// The function then returns an `FftError::InputValues` error to tell that the
181    /// result may not be correct.
182    fn process(&self, input: &mut [Complex<T>], output: &mut [T]) -> Res<()>;
183
184    /// Inverse transform a complex spectrum,
185    /// similar to [`process()`](ComplexToReal::process).
186    /// The difference is that this method uses the provided
187    /// scratch buffer instead of allocating new scratch space.
188    /// This is faster if the same scratch buffer is used for multiple calls.
189    fn process_with_scratch(
190        &self,
191        input: &mut [Complex<T>],
192        output: &mut [T],
193        scratch: &mut [Complex<T>],
194    ) -> Res<()>;
195
196    /// Get the minimum length of the scratch space needed for `process_with_scratch`.
197    fn get_scratch_len(&self) -> usize;
198
199    /// The FFT length.
200    /// Get the length of the real-valued signal that this FFT returns.
201    fn len(&self) -> usize;
202
203    /// Get the length of the slice slice of complex values that this FFT accepts as input.
204    fn complex_len(&self) -> usize {
205        self.len() / 2 + 1
206    }
207
208    /// Convenience method to make an input vector of the right type and length.
209    fn make_input_vec(&self) -> Vec<Complex<T>>;
210
211    /// Convenience method to make an output vector of the right type and length.
212    fn make_output_vec(&self) -> Vec<T>;
213
214    /// Convenience method to make a scratch vector of the right type and length.
215    fn make_scratch_vec(&self) -> Vec<Complex<T>>;
216}
217
/// Iterate over three iterables in lockstep, yielding flat `(a, b, c)` triples.
///
/// Iteration stops as soon as any of the three inputs runs out.
fn zip3<A, B, C>(a: A, b: B, c: C) -> impl Iterator<Item = (A::Item, B::Item, C::Item)>
where
    A: IntoIterator,
    B: IntoIterator,
    C: IntoIterator,
{
    let pairs = a.into_iter().zip(b);
    pairs.zip(c).map(|((first, second), third)| (first, second, third))
}
228
229/// A planner is used to create FFTs.
230/// It caches results internally,
231/// so when making more than one FFT it is advisable to reuse the same planner.
232pub struct RealFftPlanner<T: FftNum> {
233    planner: FftPlanner<T>,
234    r2c_cache: HashMap<usize, Arc<dyn RealToComplex<T>>>,
235    c2r_cache: HashMap<usize, Arc<dyn ComplexToReal<T>>>,
236}
237
238impl<T: FftNum> RealFftPlanner<T> {
239    /// Create a new planner.
240    pub fn new() -> Self {
241        let planner = FftPlanner::<T>::new();
242        Self {
243            r2c_cache: HashMap::new(),
244            c2r_cache: HashMap::new(),
245            planner,
246        }
247    }
248
249    /// Plan a real-to-complex forward FFT. Returns the FFT in a shared reference.
250    /// If requesting a second forward FFT of the same length,
251    /// the planner will return a new reference to the already existing one.
252    pub fn plan_fft_forward(&mut self, len: usize) -> Arc<dyn RealToComplex<T>> {
253        if let Some(fft) = self.r2c_cache.get(&len) {
254            Arc::clone(fft)
255        } else {
256            let fft = if len % 2 > 0 {
257                Arc::new(RealToComplexOdd::new(len, &mut self.planner)) as Arc<dyn RealToComplex<T>>
258            } else {
259                Arc::new(RealToComplexEven::new(len, &mut self.planner))
260                    as Arc<dyn RealToComplex<T>>
261            };
262            self.r2c_cache.insert(len, Arc::clone(&fft));
263            fft
264        }
265    }
266
267    /// Plan a complex-to-real inverse FFT. Returns the FFT in a shared reference.
268    /// If requesting a second inverse FFT of the same length,
269    /// the planner will return a new reference to the already existing one.
270    pub fn plan_fft_inverse(&mut self, len: usize) -> Arc<dyn ComplexToReal<T>> {
271        if let Some(fft) = self.c2r_cache.get(&len) {
272            Arc::clone(fft)
273        } else {
274            let fft = if len % 2 > 0 {
275                Arc::new(ComplexToRealOdd::new(len, &mut self.planner)) as Arc<dyn ComplexToReal<T>>
276            } else {
277                Arc::new(ComplexToRealEven::new(len, &mut self.planner))
278                    as Arc<dyn ComplexToReal<T>>
279            };
280            self.c2r_cache.insert(len, Arc::clone(&fft));
281            fft
282        }
283    }
284}
285
286impl<T: FftNum> Default for RealFftPlanner<T> {
287    fn default() -> Self {
288        Self::new()
289    }
290}
291
292impl<T: FftNum> RealToComplexOdd<T> {
293    /// Create a new RealToComplex forward FFT for real-valued input data of a given length,
294    /// and uses the given FftPlanner to build the inner FFT.
295    /// Panics if the length is not odd.
296    pub fn new(length: usize, fft_planner: &mut FftPlanner<T>) -> Self {
297        if length % 2 == 0 {
298            panic!("Length must be odd, got {}", length,);
299        }
300        let fft = fft_planner.plan_fft_forward(length);
301        let scratch_len = fft.get_inplace_scratch_len() + length;
302        RealToComplexOdd {
303            length,
304            fft,
305            scratch_len,
306        }
307    }
308}
309
310impl<T: FftNum> RealToComplex<T> for RealToComplexOdd<T> {
311    fn process(&self, input: &mut [T], output: &mut [Complex<T>]) -> Res<()> {
312        let mut scratch = self.make_scratch_vec();
313        self.process_with_scratch(input, output, &mut scratch)
314    }
315
316    fn process_with_scratch(
317        &self,
318        input: &mut [T],
319        output: &mut [Complex<T>],
320        scratch: &mut [Complex<T>],
321    ) -> Res<()> {
322        if input.len() != self.length {
323            return Err(FftError::InputBuffer(self.length, input.len()));
324        }
325        let expected_output_buffer_size = self.complex_len();
326        if output.len() != expected_output_buffer_size {
327            return Err(FftError::OutputBuffer(
328                expected_output_buffer_size,
329                output.len(),
330            ));
331        }
332        if scratch.len() < (self.scratch_len) {
333            return Err(FftError::ScratchBuffer(self.scratch_len, scratch.len()));
334        }
335        let (buffer, fft_scratch) = scratch.split_at_mut(self.length);
336
337        for (val, buf) in input.iter().zip(buffer.iter_mut()) {
338            *buf = Complex::new(*val, T::zero());
339        }
340        // FFT and store result in buffer_out
341        self.fft.process_with_scratch(buffer, fft_scratch);
342        output.copy_from_slice(&buffer[0..self.complex_len()]);
343        if let Some(elem) = output.first_mut() {
344            elem.im = T::zero();
345        }
346        Ok(())
347    }
348
349    fn get_scratch_len(&self) -> usize {
350        self.scratch_len
351    }
352
353    fn len(&self) -> usize {
354        self.length
355    }
356
357    fn make_input_vec(&self) -> Vec<T> {
358        vec![T::zero(); self.len()]
359    }
360
361    fn make_output_vec(&self) -> Vec<Complex<T>> {
362        vec![Complex::zero(); self.complex_len()]
363    }
364
365    fn make_scratch_vec(&self) -> Vec<Complex<T>> {
366        vec![Complex::zero(); self.get_scratch_len()]
367    }
368}
369
370impl<T: FftNum> RealToComplexEven<T> {
371    /// Create a new RealToComplex forward FFT for real-valued input data of a given length,
372    /// and uses the given FftPlanner to build the inner FFT.
373    /// Panics if the length is not even.
374    pub fn new(length: usize, fft_planner: &mut FftPlanner<T>) -> Self {
375        if length % 2 > 0 {
376            panic!("Length must be even, got {}", length,);
377        }
378        let twiddle_count = if length % 4 == 0 {
379            length / 4
380        } else {
381            length / 4 + 1
382        };
383        let twiddles: Vec<Complex<T>> = (1..twiddle_count)
384            .map(|i| compute_twiddle(i, length) * T::from_f64(0.5).unwrap())
385            .collect();
386        let fft = fft_planner.plan_fft_forward(length / 2);
387        let scratch_len = fft.get_outofplace_scratch_len();
388        RealToComplexEven {
389            twiddles,
390            length,
391            fft,
392            scratch_len,
393        }
394    }
395}
396
397impl<T: FftNum> RealToComplex<T> for RealToComplexEven<T> {
398    fn process(&self, input: &mut [T], output: &mut [Complex<T>]) -> Res<()> {
399        let mut scratch = self.make_scratch_vec();
400        self.process_with_scratch(input, output, &mut scratch)
401    }
402
403    fn process_with_scratch(
404        &self,
405        input: &mut [T],
406        output: &mut [Complex<T>],
407        scratch: &mut [Complex<T>],
408    ) -> Res<()> {
409        if input.len() != self.length {
410            return Err(FftError::InputBuffer(self.length, input.len()));
411        }
412        let expected_output_buffer_size = self.complex_len();
413        if output.len() != expected_output_buffer_size {
414            return Err(FftError::OutputBuffer(
415                expected_output_buffer_size,
416                output.len(),
417            ));
418        }
419        if scratch.len() < (self.scratch_len) {
420            return Err(FftError::ScratchBuffer(self.scratch_len, scratch.len()));
421        }
422
423        let fftlen = self.length / 2;
424        let buf_in = unsafe {
425            let ptr = input.as_mut_ptr() as *mut Complex<T>;
426            let len = input.len();
427            std::slice::from_raw_parts_mut(ptr, len / 2)
428        };
429
430        // FFT and store result in buffer_out
431        self.fft
432            .process_outofplace_with_scratch(buf_in, &mut output[0..fftlen], scratch);
433        let (mut output_left, mut output_right) = output.split_at_mut(output.len() / 2);
434
435        // The first and last element don't require any twiddle factors, so skip that work
436        match (output_left.first_mut(), output_right.last_mut()) {
437            (Some(first_element), Some(last_element)) => {
438                // The first and last elements are just a sum and difference of the first value's real and imaginary values
439                let first_value = *first_element;
440                *first_element = Complex {
441                    re: first_value.re + first_value.im,
442                    im: T::zero(),
443                };
444                *last_element = Complex {
445                    re: first_value.re - first_value.im,
446                    im: T::zero(),
447                };
448
449                // Chop the first and last element off of our slices so that the loop below doesn't have to deal with them
450                output_left = &mut output_left[1..];
451                let right_len = output_right.len();
452                output_right = &mut output_right[..right_len - 1];
453            }
454            _ => {
455                return Ok(());
456            }
457        }
458        // Loop over the remaining elements and apply twiddle factors on them
459        for (twiddle, out, out_rev) in zip3(
460            self.twiddles.iter(),
461            output_left.iter_mut(),
462            output_right.iter_mut().rev(),
463        ) {
464            let sum = *out + *out_rev;
465            let diff = *out - *out_rev;
466            let half = T::from_f64(0.5).unwrap();
467            // Apply twiddle factors. Theoretically we'd have to load 2 separate twiddle factors here, one for the beginning
468            // and one for the end. But the twiddle factor for the end is just the twiddle for the beginning, with the
469            // real part negated. Since it's the same twiddle, we can factor out a ton of math ops and cut the number of
470            // multiplications in half.
471            let twiddled_re_sum = sum * twiddle.re;
472            let twiddled_im_sum = sum * twiddle.im;
473            let twiddled_re_diff = diff * twiddle.re;
474            let twiddled_im_diff = diff * twiddle.im;
475            let half_sum_re = half * sum.re;
476            let half_diff_im = half * diff.im;
477
478            let output_twiddled_real = twiddled_re_sum.im + twiddled_im_diff.re;
479            let output_twiddled_im = twiddled_im_sum.im - twiddled_re_diff.re;
480
481            // We finally have all the data we need to write the transformed data back out where we found it.
482            *out = Complex {
483                re: half_sum_re + output_twiddled_real,
484                im: half_diff_im + output_twiddled_im,
485            };
486
487            *out_rev = Complex {
488                re: half_sum_re - output_twiddled_real,
489                im: output_twiddled_im - half_diff_im,
490            };
491        }
492
493        // If the output len is odd, the loop above can't postprocess the centermost element, so handle that separately.
494        if output.len() % 2 == 1 {
495            if let Some(center_element) = output.get_mut(output.len() / 2) {
496                center_element.im = -center_element.im;
497            }
498        }
499        Ok(())
500    }
501    fn get_scratch_len(&self) -> usize {
502        self.scratch_len
503    }
504
505    fn len(&self) -> usize {
506        self.length
507    }
508
509    fn make_input_vec(&self) -> Vec<T> {
510        vec![T::zero(); self.len()]
511    }
512
513    fn make_output_vec(&self) -> Vec<Complex<T>> {
514        vec![Complex::zero(); self.complex_len()]
515    }
516
517    fn make_scratch_vec(&self) -> Vec<Complex<T>> {
518        vec![Complex::zero(); self.get_scratch_len()]
519    }
520}
521
522impl<T: FftNum> ComplexToRealOdd<T> {
523    /// Create a new ComplexToRealOdd inverse FFT for complex input spectra.
524    /// The `length` parameter refers to the length of the resulting real-valued signal.
525    /// Uses the given FftPlanner to build the inner FFT.
526    /// Panics if the length is not odd.
527    pub fn new(length: usize, fft_planner: &mut FftPlanner<T>) -> Self {
528        if length % 2 == 0 {
529            panic!("Length must be odd, got {}", length,);
530        }
531        let fft = fft_planner.plan_fft_inverse(length);
532        let scratch_len = length + fft.get_inplace_scratch_len();
533        ComplexToRealOdd {
534            length,
535            fft,
536            scratch_len,
537        }
538    }
539}
540
541impl<T: FftNum> ComplexToReal<T> for ComplexToRealOdd<T> {
542    fn process(&self, input: &mut [Complex<T>], output: &mut [T]) -> Res<()> {
543        let mut scratch = self.make_scratch_vec();
544        self.process_with_scratch(input, output, &mut scratch)
545    }
546
547    fn process_with_scratch(
548        &self,
549        input: &mut [Complex<T>],
550        output: &mut [T],
551        scratch: &mut [Complex<T>],
552    ) -> Res<()> {
553        let expected_input_buffer_size = self.complex_len();
554        if input.len() != expected_input_buffer_size {
555            return Err(FftError::InputBuffer(
556                expected_input_buffer_size,
557                input.len(),
558            ));
559        }
560        if output.len() != self.length {
561            return Err(FftError::OutputBuffer(self.length, output.len()));
562        }
563        if scratch.len() < (self.scratch_len) {
564            return Err(FftError::ScratchBuffer(self.scratch_len, scratch.len()));
565        }
566
567        let first_invalid = if input[0].im != T::from_f64(0.0).unwrap() {
568            input[0].im = T::from_f64(0.0).unwrap();
569            true
570        } else {
571            false
572        };
573
574        let (buffer, fft_scratch) = scratch.split_at_mut(self.length);
575
576        buffer[0..input.len()].copy_from_slice(input);
577        for (buf, val) in buffer
578            .iter_mut()
579            .rev()
580            .take(self.length / 2)
581            .zip(input.iter().skip(1))
582        {
583            *buf = val.conj();
584        }
585        self.fft.process_with_scratch(buffer, fft_scratch);
586        for (val, out) in buffer.iter().zip(output.iter_mut()) {
587            *out = val.re;
588        }
589        if first_invalid {
590            return Err(FftError::InputValues(true, false));
591        }
592        Ok(())
593    }
594
595    fn get_scratch_len(&self) -> usize {
596        self.scratch_len
597    }
598
599    fn len(&self) -> usize {
600        self.length
601    }
602
603    fn make_input_vec(&self) -> Vec<Complex<T>> {
604        vec![Complex::zero(); self.complex_len()]
605    }
606
607    fn make_output_vec(&self) -> Vec<T> {
608        vec![T::zero(); self.len()]
609    }
610
611    fn make_scratch_vec(&self) -> Vec<Complex<T>> {
612        vec![Complex::zero(); self.get_scratch_len()]
613    }
614}
615
616impl<T: FftNum> ComplexToRealEven<T> {
617    /// Create a new ComplexToRealEven inverse FFT for complex input spectra.
618    /// The `length` parameter refers to the length of the resulting real-valued signal.
619    /// Uses the given FftPlanner to build the inner FFT.
620    /// Panics if the length is not even.
621    pub fn new(length: usize, fft_planner: &mut FftPlanner<T>) -> Self {
622        if length % 2 > 0 {
623            panic!("Length must be even, got {}", length,);
624        }
625        let twiddle_count = if length % 4 == 0 {
626            length / 4
627        } else {
628            length / 4 + 1
629        };
630        let twiddles: Vec<Complex<T>> = (1..twiddle_count)
631            .map(|i| compute_twiddle(i, length).conj())
632            .collect();
633        let fft = fft_planner.plan_fft_inverse(length / 2);
634        let scratch_len = fft.get_outofplace_scratch_len();
635        ComplexToRealEven {
636            twiddles,
637            length,
638            fft,
639            scratch_len,
640        }
641    }
642}
643impl<T: FftNum> ComplexToReal<T> for ComplexToRealEven<T> {
644    fn process(&self, input: &mut [Complex<T>], output: &mut [T]) -> Res<()> {
645        let mut scratch = self.make_scratch_vec();
646        self.process_with_scratch(input, output, &mut scratch)
647    }
648
649    fn process_with_scratch(
650        &self,
651        input: &mut [Complex<T>],
652        output: &mut [T],
653        scratch: &mut [Complex<T>],
654    ) -> Res<()> {
655        let expected_input_buffer_size = self.complex_len();
656        if input.len() != expected_input_buffer_size {
657            return Err(FftError::InputBuffer(
658                expected_input_buffer_size,
659                input.len(),
660            ));
661        }
662        if output.len() != self.length {
663            return Err(FftError::OutputBuffer(self.length, output.len()));
664        }
665        if scratch.len() < (self.scratch_len) {
666            return Err(FftError::ScratchBuffer(self.scratch_len, scratch.len()));
667        }
668        if input.is_empty() {
669            return Ok(());
670        }
671        let first_invalid = if input[0].im != T::from_f64(0.0).unwrap() {
672            input[0].im = T::from_f64(0.0).unwrap();
673            true
674        } else {
675            false
676        };
677        let last_invalid = if input[input.len() - 1].im != T::from_f64(0.0).unwrap() {
678            input[input.len() - 1].im = T::from_f64(0.0).unwrap();
679            true
680        } else {
681            false
682        };
683
684        let (mut input_left, mut input_right) = input.split_at_mut(input.len() / 2);
685
686        // We have to preprocess the input in-place before we send it to the FFT.
687        // The first and centermost values have to be preprocessed separately from the rest, so do that now.
688        match (input_left.first_mut(), input_right.last_mut()) {
689            (Some(first_input), Some(last_input)) => {
690                let first_sum = *first_input + *last_input;
691                let first_diff = *first_input - *last_input;
692
693                *first_input = Complex {
694                    re: first_sum.re - first_sum.im,
695                    im: first_diff.re - first_diff.im,
696                };
697
698                input_left = &mut input_left[1..];
699                let right_len = input_right.len();
700                input_right = &mut input_right[..right_len - 1];
701            }
702            _ => return Ok(()),
703        };
704
705        // now, in a loop, preprocess the rest of the elements 2 at a time.
706        for (twiddle, fft_input, fft_input_rev) in zip3(
707            self.twiddles.iter(),
708            input_left.iter_mut(),
709            input_right.iter_mut().rev(),
710        ) {
711            let sum = *fft_input + *fft_input_rev;
712            let diff = *fft_input - *fft_input_rev;
713
714            // Apply twiddle factors. Theoretically we'd have to load 2 separate twiddle factors here, one for the beginning
715            // and one for the end. But the twiddle factor for the end is just the twiddle for the beginning, with the
716            // real part negated. Since it's the same twiddle, we can factor out a ton of math ops and cut the number of
717            // multiplications in half.
718            let twiddled_re_sum = sum * twiddle.re;
719            let twiddled_im_sum = sum * twiddle.im;
720            let twiddled_re_diff = diff * twiddle.re;
721            let twiddled_im_diff = diff * twiddle.im;
722
723            let output_twiddled_real = twiddled_re_sum.im + twiddled_im_diff.re;
724            let output_twiddled_im = twiddled_im_sum.im - twiddled_re_diff.re;
725
726            // We finally have all the data we need to write our preprocessed data back where we got it from.
727            *fft_input = Complex {
728                re: sum.re - output_twiddled_real,
729                im: diff.im - output_twiddled_im,
730            };
731            *fft_input_rev = Complex {
732                re: sum.re + output_twiddled_real,
733                im: -output_twiddled_im - diff.im,
734            }
735        }
736
737        // If the output len is odd, the loop above can't preprocess the centermost element, so handle that separately
738        if input.len() % 2 == 1 {
739            let center_element = input[input.len() / 2];
740            let doubled = center_element + center_element;
741            input[input.len() / 2] = doubled.conj();
742        }
743
744        // FFT and store result in buffer_out
745        let buf_out = unsafe {
746            let ptr = output.as_mut_ptr() as *mut Complex<T>;
747            let len = output.len();
748            std::slice::from_raw_parts_mut(ptr, len / 2)
749        };
750        self.fft
751            .process_outofplace_with_scratch(&mut input[..buf_out.len()], buf_out, scratch);
752        if first_invalid || last_invalid {
753            return Err(FftError::InputValues(first_invalid, last_invalid));
754        }
755        Ok(())
756    }
757
758    fn get_scratch_len(&self) -> usize {
759        self.scratch_len
760    }
761
762    fn len(&self) -> usize {
763        self.length
764    }
765
766    fn make_input_vec(&self) -> Vec<Complex<T>> {
767        vec![Complex::zero(); self.complex_len()]
768    }
769
770    fn make_output_vec(&self) -> Vec<T> {
771        vec![T::zero(); self.len()]
772    }
773
774    fn make_scratch_vec(&self) -> Vec<Complex<T>> {
775        vec![Complex::zero(); self.get_scratch_len()]
776    }
777}
778
#[cfg(test)]
mod tests {
    use crate::FftError;
    use crate::RealFftPlanner;
    use rand::Rng;
    use rustfft::num_complex::Complex;
    use rustfft::num_traits::{Float, Zero};
    use rustfft::FftPlanner;
    use std::error::Error;

    /// Returns the larger of two differences, letting NaN win over any number.
    ///
    /// Every ordered comparison involving NaN is false, so a plain "keep the larger" fold
    /// replaces a NaN with whatever finite value is compared against it next. Making NaN sticky
    /// guarantees that a transform producing NaN fails the `maxdiff < tolerance` checks.
    fn max_propagating_nan<T: Float>(maxdiff: T, diff: T) -> T {
        if diff.is_nan() || diff > maxdiff {
            diff
        } else {
            maxdiff
        }
    }

    /// Largest element-wise distance `|a[i] - b[i]|` between two complex slices.
    /// The result is NaN if any distance is NaN.
    ///
    /// # Panics
    /// Panics if the slices differ in length, since `zip` would otherwise silently skip the
    /// tail of the longer one.
    fn compare_complex<T: Float>(a: &[Complex<T>], b: &[Complex<T>]) -> T {
        assert_eq!(a.len(), b.len(), "compared slices must have the same length");
        a.iter()
            .zip(b)
            .map(|(val_a, val_b)| (val_a - val_b).norm())
            .fold(T::zero(), max_propagating_nan)
    }

    /// Largest element-wise distance `|a[i] - b[i]|` between two real slices.
    /// The result is NaN if any distance is NaN.
    ///
    /// # Panics
    /// Panics if the slices differ in length, since `zip` would otherwise silently skip the
    /// tail of the longer one.
    fn compare_scalars<T: Float>(a: &[T], b: &[T]) -> T {
        assert_eq!(a.len(), b.len(), "compared slices must have the same length");
        a.iter()
            .zip(b)
            .map(|(&val_a, &val_b)| (val_a - val_b).abs())
            .fold(T::zero(), max_propagating_nan)
    }

    /// Generates a test that compares ComplexToReal with a full-length inverse RustFFT for
    /// every length in `1..1000`, using float type `$ty` and maximum allowed error `$tolerance`.
    macro_rules! complex_to_real_test {
        ($name:ident, $ty:ty, $tolerance:expr) => {
            #[test]
            fn $name() {
                let mut rng = rand::rng();
                for length in 1..1000 {
                    let mut real_planner = RealFftPlanner::<$ty>::new();
                    let c2r = real_planner.plan_fft_inverse(length);
                    let mut out_a = c2r.make_output_vec();
                    let mut indata = c2r.make_input_vec();
                    let mut rustfft_check: Vec<Complex<$ty>> = vec![Complex::zero(); length];
                    for val in indata.iter_mut() {
                        *val = Complex::new(rng.random::<$ty>(), rng.random::<$ty>());
                    }
                    // The first bin, and for even lengths also the middle bin, must be real.
                    indata[0].im = 0.0;
                    if length % 2 == 0 {
                        indata[length / 2].im = 0.0;
                    }
                    // Reference spectrum, first half: the input bins as given...
                    for (val_long, val) in rustfft_check
                        .iter_mut()
                        .take(c2r.complex_len())
                        .zip(indata.iter())
                    {
                        *val_long = *val;
                    }
                    // ...second half: mirrored conjugates, check[N - k] = conj(indata[k]).
                    for (val_long, val) in rustfft_check
                        .iter_mut()
                        .rev()
                        .take(length / 2)
                        .zip(indata.iter().skip(1))
                    {
                        *val_long = val.conj();
                    }
                    let mut fft_planner = FftPlanner::<$ty>::new();
                    let fft = fft_planner.plan_fft_inverse(length);

                    c2r.process(&mut indata, &mut out_a).unwrap();
                    fft.process(&mut rustfft_check);

                    let check_real = rustfft_check.iter().map(|val| val.re).collect::<Vec<$ty>>();
                    let maxdiff = compare_scalars(&out_a, &check_real);
                    assert!(
                        maxdiff < $tolerance,
                        "Length: {}, too large error: {}",
                        length,
                        maxdiff
                    );
                }
            }
        };
    }

    // Compare ComplexToReal with standard inverse FFT
    complex_to_real_test!(complex_to_real_64, f64, 1.0e-9);
    complex_to_real_test!(complex_to_real_32, f32, 5.0e-4);

    /// Overwrites every element of `data` with a complex value whose real and imaginary parts
    /// both come from `rng.random::<f64>()`.
    fn fill_random(data: &mut [Complex<f64>], rng: &mut impl Rng) {
        for val in data.iter_mut() {
            *val = Complex::new(rng.random::<f64>(), rng.random::<f64>());
        }
    }

    // Test that ComplexToReal returns the right errors for an even length, where both the
    // first and the last complex input value must have a zero imaginary part.
    #[test]
    fn complex_to_real_errors_even() {
        let length = 100;
        let mut real_planner = RealFftPlanner::<f64>::new();
        let c2r = real_planner.plan_fft_inverse(length);
        let mut out_a = c2r.make_output_vec();
        let mut indata = c2r.make_input_vec();
        let last = indata.len() - 1;
        let mut rng = rand::rng();

        // Valid data: first and last imaginary parts are zero.
        fill_random(&mut indata, &mut rng);
        indata[0].im = 0.0;
        indata[last].im = 0.0;
        let res = c2r.process(&mut indata, &mut out_a);
        assert!(res.is_ok(), "unexpected result: {:?}", res);

        // Invalid data: only the first point has a non-zero imaginary part.
        fill_random(&mut indata, &mut rng);
        indata[0].im = 1.0;
        indata[last].im = 0.0;
        let res = c2r.process(&mut indata, &mut out_a);
        assert!(
            matches!(res, Err(FftError::InputValues(true, false))),
            "unexpected result: {:?}",
            res
        );

        // Invalid data: only the last point has a non-zero imaginary part.
        fill_random(&mut indata, &mut rng);
        indata[0].im = 0.0;
        indata[last].im = 1.0;
        let res = c2r.process(&mut indata, &mut out_a);
        assert!(
            matches!(res, Err(FftError::InputValues(false, true))),
            "unexpected result: {:?}",
            res
        );
    }

    // Test that ComplexToReal returns the right errors for an odd length, where only the
    // first complex input value must have a zero imaginary part.
    #[test]
    fn complex_to_real_errors_odd() {
        let length = 101;
        let mut real_planner = RealFftPlanner::<f64>::new();
        let c2r = real_planner.plan_fft_inverse(length);
        let mut out_a = c2r.make_output_vec();
        let mut indata = c2r.make_input_vec();
        let mut rng = rand::rng();

        // Valid data: the first imaginary part is zero, the last one is unconstrained.
        fill_random(&mut indata, &mut rng);
        indata[0].im = 0.0;
        let res = c2r.process(&mut indata, &mut out_a);
        assert!(res.is_ok(), "unexpected result: {:?}", res);

        // Invalid data: the first point has a non-zero imaginary part.
        fill_random(&mut indata, &mut rng);
        indata[0].im = 1.0;
        let res = c2r.process(&mut indata, &mut out_a);
        assert!(
            matches!(res, Err(FftError::InputValues(true, false))),
            "unexpected result: {:?}",
            res
        );
    }

    /// Generates a test that compares RealToComplex with the first `complex_len()` bins of a
    /// full-length forward RustFFT for every length in `1..1000`, using float type `$ty` and
    /// maximum allowed error `$tolerance`.
    macro_rules! real_to_complex_test {
        ($name:ident, $ty:ty, $tolerance:expr) => {
            #[test]
            fn $name() {
                let mut rng = rand::rng();
                for length in 1..1000 {
                    let mut real_planner = RealFftPlanner::<$ty>::new();
                    let r2c = real_planner.plan_fft_forward(length);
                    let mut out_a = r2c.make_output_vec();
                    let mut indata = r2c.make_input_vec();
                    for val in indata.iter_mut() {
                        *val = rng.random::<$ty>();
                    }
                    // Build the reference before `r2c.process`, which takes the input by
                    // `&mut` and so may leave it modified.
                    let mut rustfft_check = indata
                        .iter()
                        .map(Complex::from)
                        .collect::<Vec<Complex<$ty>>>();
                    let mut fft_planner = FftPlanner::<$ty>::new();
                    let fft = fft_planner.plan_fft_forward(length);
                    fft.process(&mut rustfft_check);
                    r2c.process(&mut indata, &mut out_a).unwrap();
                    assert_eq!(out_a[0].im, 0.0, "First imaginary component must be zero");
                    if length % 2 == 0 {
                        assert_eq!(
                            out_a.last().unwrap().im,
                            0.0,
                            "Last imaginary component for even lengths must be zero"
                        );
                    }
                    let maxdiff = compare_complex(&out_a, &rustfft_check[0..r2c.complex_len()]);
                    assert!(
                        maxdiff < $tolerance,
                        "Length: {}, too large error: {}",
                        length,
                        maxdiff
                    );
                }
            }
        };
    }

    // Compare RealToComplex with standard FFT
    real_to_complex_test!(real_to_complex_64, f64, 1.0e-9);
    real_to_complex_test!(real_to_complex_32, f32, 5.0e-4);

    // Check that the ? operator works on the custom errors. No need to run, just needs to compile.
    #[allow(dead_code)]
    fn test_error() -> Result<(), Box<dyn Error>> {
        let mut real_planner = RealFftPlanner::<f64>::new();
        let r2c = real_planner.plan_fft_forward(100);
        let mut out_a = r2c.make_output_vec();
        let mut indata = r2c.make_input_vec();
        r2c.process(&mut indata, &mut out_a)?;
        Ok(())
    }
}