//! # RealFFT: Real-to-complex forward FFT and complex-to-real inverse FFT based on RustFFT
//!
//! This library is a wrapper for [RustFFT](https://crates.io/crates/rustfft)
//! that enables fast and convenient FFT of real-valued data.
//! The API is designed to be as similar as possible to RustFFT.
//!
//! Using this library instead of RustFFT directly avoids the need to convert
//! real-valued data to complex before performing an FFT.
//! If the length is even, it also enables faster computations by using a complex FFT of half the length.
//! In that case it packs a real-valued signal of length N into an N/2 long complex buffer,
//! treating each consecutive pair of samples as the real and imaginary parts of one complex value.
//! This buffer is transformed using a standard complex-to-complex FFT.
//! The FFT result is then post-processed to give only the first half of the complex spectrum,
//! as an N/2+1 long complex vector.
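/*!
The packing and post-processing steps described above can be sketched in plain Rust.
This is a minimal illustration, not the crate's implementation.
`Cpx`, `twiddle`, `dft` and `rfft_even` are hypothetical names.
A naive O(n^2) DFT stands in for the inner RustFFT transform.
The length N = 2M is assumed to be even and non-zero.

```rust
use std::f64::consts::PI;

/// Minimal complex number, defined here only to keep the sketch dependency-free.
#[derive(Clone, Copy, Debug)]
struct Cpx {
    re: f64,
    im: f64,
}

impl Cpx {
    fn new(re: f64, im: f64) -> Self {
        Cpx { re, im }
    }
    fn add(self, o: Cpx) -> Cpx {
        Cpx::new(self.re + o.re, self.im + o.im)
    }
    fn sub(self, o: Cpx) -> Cpx {
        Cpx::new(self.re - o.re, self.im - o.im)
    }
    fn mul(self, o: Cpx) -> Cpx {
        Cpx::new(self.re * o.re - self.im * o.im, self.re * o.im + self.im * o.re)
    }
    fn conj(self) -> Cpx {
        Cpx::new(self.re, -self.im)
    }
}

/// Forward twiddle factor exp(-2*pi*i*k/n).
fn twiddle(k: usize, n: usize) -> Cpx {
    let angle = -2.0 * PI * k as f64 / n as f64;
    Cpx::new(angle.cos(), angle.sin())
}

/// Naive O(n^2) forward DFT, standing in for the inner complex FFT.
fn dft(input: &[Cpx]) -> Vec<Cpx> {
    let n = input.len();
    (0..n)
        .map(|k| {
            input.iter().enumerate().fold(Cpx::new(0.0, 0.0), |acc, (j, &v)| {
                acc.add(v.mul(twiddle(j * k % n, n)))
            })
        })
        .collect()
}

/// Real-to-complex transform of an even-length signal via a half-length complex DFT.
fn rfft_even(x: &[f64]) -> Vec<Cpx> {
    assert!(!x.is_empty() && x.len() % 2 == 0, "length must be even and non-zero");
    let m = x.len() / 2;
    // Pack pairs of real samples into one complex value: z[j] = x[2j] + i*x[2j+1].
    let z: Vec<Cpx> = x.chunks_exact(2).map(|p| Cpx::new(p[0], p[1])).collect();
    let zf = dft(&z);
    // Post-process: split Z into the spectra of the even and odd samples,
    // then combine them using the twiddle factors of the full length N.
    (0..=m)
        .map(|k| {
            let a = zf[k % m];
            let b = zf[(m - k) % m].conj();
            let even = a.add(b).mul(Cpx::new(0.5, 0.0));
            let d = a.sub(b);
            // Multiply d by -i/2: (re + i*im) times (-i/2) equals (im/2) - i*(re/2).
            let odd = Cpx::new(0.5 * d.im, -0.5 * d.re);
            even.add(twiddle(k, x.len()).mul(odd))
        })
        .collect()
}
```

For an 8-sample signal, `rfft_even` returns 5 complex values. These should match the first 5 bins of a full-length complex DFT of the same signal, up to rounding.
*/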
//!
//! The inverse FFT goes through the same steps backwards,
//! to transform a complex spectrum of length N/2+1 to a real-valued signal of length N.
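/*!
The inverse steps can be sketched the same way. Again, this is only an illustration under stated assumptions:
- `Cpx`, `twiddle`, `dft` and `irfft_even` are hypothetical names.
- A naive O(n^2) DFT replaces the inner inverse FFT.
- The length N = 2M is assumed to be even and non-zero.

Like RealFFT, the sketch does not normalize, so it returns N times the original signal.

```rust
use std::f64::consts::PI;

/// Minimal complex number, defined here only to keep the sketch dependency-free.
#[derive(Clone, Copy, Debug)]
struct Cpx {
    re: f64,
    im: f64,
}

impl Cpx {
    fn new(re: f64, im: f64) -> Self {
        Cpx { re, im }
    }
    fn add(self, o: Cpx) -> Cpx {
        Cpx::new(self.re + o.re, self.im + o.im)
    }
    fn sub(self, o: Cpx) -> Cpx {
        Cpx::new(self.re - o.re, self.im - o.im)
    }
    fn mul(self, o: Cpx) -> Cpx {
        Cpx::new(self.re * o.re - self.im * o.im, self.re * o.im + self.im * o.re)
    }
    fn conj(self) -> Cpx {
        Cpx::new(self.re, -self.im)
    }
}

/// Forward twiddle factor exp(-2*pi*i*k/n).
fn twiddle(k: usize, n: usize) -> Cpx {
    let angle = -2.0 * PI * k as f64 / n as f64;
    Cpx::new(angle.cos(), angle.sin())
}

/// Naive unnormalized DFT: `sign` is -1.0 for the forward and +1.0 for the inverse transform.
fn dft(input: &[Cpx], sign: f64) -> Vec<Cpx> {
    let n = input.len();
    (0..n)
        .map(|k| {
            input.iter().enumerate().fold(Cpx::new(0.0, 0.0), |acc, (j, &v)| {
                let angle = sign * 2.0 * PI * ((j * k) % n) as f64 / n as f64;
                acc.add(v.mul(Cpx::new(angle.cos(), angle.sin())))
            })
        })
        .collect()
}

/// Complex-to-real inverse for an even length `n`, from the n/2+1 spectrum values.
/// Not normalized: the result is `n` times the original signal.
fn irfft_even(spectrum: &[Cpx], n: usize) -> Vec<f64> {
    assert!(n > 0 && n % 2 == 0 && spectrum.len() == n / 2 + 1);
    let m = n / 2;
    // Pre-process: recover 2*E[k] and 2*O[k] (spectra of the even and odd samples)
    // and combine them into the half-length spectrum 2*(E[k] + i*O[k]).
    let z: Vec<Cpx> = (0..m)
        .map(|k| {
            let a = spectrum[k];
            let b = spectrum[m - k].conj();
            let even = a.add(b);
            let odd = a.sub(b).mul(twiddle(k, n).conj());
            // even + i*odd
            Cpx::new(even.re - odd.im, even.im + odd.re)
        })
        .collect();
    // Half-length inverse transform, then unpack: the real parts hold the
    // even-indexed samples and the imaginary parts the odd-indexed ones.
    dft(&z, 1.0).iter().flat_map(|c| [c.re, c.im]).collect()
}
```

Dividing the output of `irfft_even` by N should give back the original samples, up to rounding.
*/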
//!
//! The speed increase compared to just converting the input to a length N complex vector
//! and using a length N complex-to-complex FFT depends on the length of the input data.
//! The largest improvements are for longer FFTs: for lengths over around 1000 elements,
//! the real-to-complex FFT is about a factor 2 faster.
//! The difference shrinks for shorter lengths,
//! and at around 30 elements there is no longer any difference.
//!
//! ## Why use real-to-complex FFT?
//! ### Using a complex-to-complex FFT
//! A simple way to transform a real-valued signal is to convert it to complex,
//! and then use a complex-to-complex FFT.
//!
//! Let's assume `x` is a 6 element long real vector:
//! ```text
//! x = [x0r, x1r, x2r, x3r, x4r, x5r]
//! ```
//!
//! We now convert `x` to complex by adding an imaginary part with value zero.
//! Using the notation `(xNr, xNi)` for the complex value `xN`, this becomes:
//! ```text
//! x_c = [(x0r, 0), (x1r, 0), (x2r, 0), (x3r, 0), (x4r, 0), (x5r, 0)]
//! ```
//!
//! Performing a normal complex FFT, the result of `FFT(x_c)` is:
//! ```text
//! FFT(x_c) = [(X0r, X0i), (X1r, X1i), (X2r, X2i), (X3r, X3i), (X4r, X4i), (X5r, X5i)]
//! ```
//!
//! But because our `x_c` is real-valued (all imaginary parts are zero), some of this becomes redundant:
//! ```text
//! FFT(x_c) = [(X0r, 0), (X1r, X1i), (X2r, X2i), (X3r, 0), (X2r, -X2i), (X1r, -X1i)]
//! ```
//!
//! The last two values are the complex conjugates of `X1` and `X2`. Additionally, `X0i` and `X3i` are zero.
//! As we can see, the output contains 6 independent values (`X0r`, `X1r`, `X1i`, `X2r`, `X2i` and `X3r`),
//! and the rest is redundant.
//! But it still takes time for the FFT to calculate the redundant values.
//! Converting the input data to complex also takes a little bit of time.
//!
//! If the length of `x` instead had been 7, the result would have been:
//! ```text
//! FFT(x_c) = [(X0r, 0), (X1r, X1i), (X2r, X2i), (X3r, X3i), (X3r, -X3i), (X2r, -X2i), (X1r, -X1i)]
//! ```
//!
//! The result is similar, but this time there is no zero at `X3i`.
//! In this case too, we get the same number of independent values as we started with (7).
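/*!
The redundancy can be checked numerically. The sketch below is a hypothetical illustration, and `real_dft`, `is_conjugate_symmetric` and `kept_bins` are not part of this crate. It uses a naive O(n^2) DFT and represents complex values as `(re, im)` tuples.

For a real-valued signal, it verifies that `X[N-k]` is the complex conjugate of `X[k]`. This is why only the first N/2+1 bins need to be computed.

```rust
use std::f64::consts::PI;

/// Naive O(n^2) DFT of a real-valued signal, returning all n bins as (re, im) pairs.
fn real_dft(x: &[f64]) -> Vec<(f64, f64)> {
    let n = x.len();
    (0..n)
        .map(|k| {
            x.iter().enumerate().fold((0.0_f64, 0.0_f64), |(re, im), (j, &v)| {
                let angle = -2.0 * PI * ((j * k) % n) as f64 / n as f64;
                (re + v * angle.cos(), im + v * angle.sin())
            })
        })
        .collect()
}

/// Checks that X[n - k] is the complex conjugate of X[k] for every k in 1..n,
/// i.e. that the upper part of the spectrum is redundant.
fn is_conjugate_symmetric(spectrum: &[(f64, f64)], tol: f64) -> bool {
    let n = spectrum.len();
    (1..n).all(|k| {
        let (re_k, im_k) = spectrum[k];
        let (re_m, im_m) = spectrum[n - k];
        (re_k - re_m).abs() < tol && (im_k + im_m).abs() < tol
    })
}

/// Number of complex values kept by a real-to-complex FFT of length n.
fn kept_bins(n: usize) -> usize {
    n / 2 + 1
}
```

For both length 6 and length 7, `kept_bins` gives 4. This matches the 4 complex values in the real-to-complex layouts of the next section.
*/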
//!
//! ### Real-to-complex
//! Using a real-to-complex FFT removes the need for converting the input data to complex.
//! It also avoids calculating the redundant output values.
//!
//! The result for 6 elements is:
//! ```text
//! RealFFT(x) = [(X0r, 0), (X1r, X1i), (X2r, X2i), (X3r, 0)]
//! ```
//!
//! The result for 7 elements is:
//! ```text
//! RealFFT(x) = [(X0r, 0), (X1r, X1i), (X2r, X2i), (X3r, X3i)]
//! ```
//!
//! This is the data layout output by the real-to-complex FFT,
//! and the one expected as input to the complex-to-real inverse FFT.
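/*!
For odd lengths, the complex-to-real inverse FFT in this crate rebuilds the full spectrum from this layout before running a full-length complex inverse FFT.

A minimal sketch of that expansion follows. It assumes `(re, im)` tuples for complex values and uses the hypothetical helper names `expand_half_spectrum` and `zero_bins_are_real`. The missing bins are the complex conjugates of the stored ones, in reverse order. The bins that must be real are checked in the same spirit as the `FftError::InputValues` check.

```rust
/// Rebuilds the full length-`n` spectrum of a real-valued signal from the
/// `n / 2 + 1` values of the real-to-complex layout, using X[n - k] = conj(X[k]).
fn expand_half_spectrum(half: &[(f64, f64)], n: usize) -> Vec<(f64, f64)> {
    assert!(n > 0, "length must be non-zero");
    assert_eq!(half.len(), n / 2 + 1, "expected n/2+1 complex values");
    // The stored bins X[0] ..= X[n/2] are kept as they are.
    let mut full = half.to_vec();
    // The remaining bins are the complex conjugates of the stored ones, in reverse order.
    for k in half.len()..n {
        let (re, im) = half[n - k];
        full.push((re, -im));
    }
    full
}

/// Checks the bins that must be real for a real-valued signal:
/// X[0] always, and X[n/2] when `n` is even.
fn zero_bins_are_real(half: &[(f64, f64)], n: usize) -> bool {
    let first_ok = half[0].1 == 0.0;
    let last_ok = n % 2 == 1 || half[n / 2].1 == 0.0;
    first_ok && last_ok
}
```

The same 4 stored values expand to 6 bins for N = 6 and to 7 bins for N = 7. This is why the inverse FFT must also be told the length N of the real-valued signal.
*/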
//!
//! ## Scaling
//! RealFFT matches the behaviour of RustFFT and does not normalize the output of either the forward or the inverse FFT.
//! To get normalized results, each element must be scaled by `1/sqrt(length)`,
//! where `length` is the length of the real-valued signal.
//! If the processing involves both a forward and an inverse FFT step,
//! it is advisable to merge the two normalization steps into a single one, by scaling by `1/length`.
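/*!
Here is a small numerical sketch of this advice. A naive O(n^2) complex DFT on `(re, im)` tuples stands in for the real FFTs, and `dft` and `roundtrip_normalized` are hypothetical helpers.

An unnormalized forward transform followed by an unnormalized inverse returns `length` times the input. A single scaling by `1/length` at the end restores it.

```rust
use std::f64::consts::PI;

/// Naive unnormalized DFT on (re, im) pairs; `sign` is -1.0 for forward, +1.0 for inverse.
fn dft(input: &[(f64, f64)], sign: f64) -> Vec<(f64, f64)> {
    let n = input.len();
    (0..n)
        .map(|k| {
            input
                .iter()
                .enumerate()
                .fold((0.0_f64, 0.0_f64), |(acc_re, acc_im), (j, &(re, im))| {
                    let angle = sign * 2.0 * PI * ((j * k) % n) as f64 / n as f64;
                    let (c, s) = (angle.cos(), angle.sin());
                    (acc_re + re * c - im * s, acc_im + re * s + im * c)
                })
        })
        .collect()
}

/// Forward then inverse transform without normalization, followed by a single
/// 1/length scaling step instead of two separate 1/sqrt(length) steps.
fn roundtrip_normalized(x: &[f64]) -> Vec<f64> {
    let n = x.len();
    let signal: Vec<(f64, f64)> = x.iter().map(|&v| (v, 0.0)).collect();
    let spectrum = dft(&signal, -1.0);
    let restored = dft(&spectrum, 1.0);
    // An unnormalized forward + inverse pair scales the signal by n; undo it once.
    let scale = 1.0 / n as f64;
    restored.iter().map(|&(re, _)| re * scale).collect()
}
```

Scaling by `1/sqrt(length)` after each of the two transforms would give the same result, at the cost of an extra pass over the data.
*/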
//!
//! ## Documentation
//!
//! The full documentation can be generated by rustdoc. To generate and view it, run:
//! ```text
//! cargo doc --open
//! ```
//!
//! ## Benchmarks
//!
//! To run a set of benchmarks comparing real-to-complex FFT with standard complex-to-complex, type:
//! ```text
//! cargo bench
//! ```
//! The results are printed while running, and are compiled into an HTML report containing much more detail.
//! To view it, open `target/criterion/report/index.html` in a browser.
//!
//! ## Example
//! Transform a signal, and then inverse transform the result.
//! ```
//! use realfft::RealFftPlanner;
//! use rustfft::num_complex::Complex;
//! use rustfft::num_traits::Zero;
//!
//! let length = 256;
//!
//! // make a planner
//! let mut real_planner = RealFftPlanner::<f64>::new();
//!
//! // create an FFT
//! let r2c = real_planner.plan_fft_forward(length);
//! // make a dummy real-valued signal (filled with zeros)
//! let mut indata = r2c.make_input_vec();
//! // make a vector for storing the spectrum
//! let mut spectrum = r2c.make_output_vec();
//!
//! // Are they the length we expect?
//! assert_eq!(indata.len(), length);
//! assert_eq!(spectrum.len(), length/2+1);
//!
//! // forward transform the signal
//! r2c.process(&mut indata, &mut spectrum).unwrap();
//!
//! // create an inverse FFT
//! let c2r = real_planner.plan_fft_inverse(length);
//!
//! // create a vector for storing the output
//! let mut outdata = c2r.make_output_vec();
//! assert_eq!(outdata.len(), length);
//!
//! // inverse transform the spectrum back to a real-valued signal
//! c2r.process(&mut spectrum, &mut outdata).unwrap();
//! ```
//!
//! ### Versions
//! - 3.4.0: Fix undefined behavior reported by Miri.
//!          Update to latest RustFFT.
//! - 3.3.0: Add method for getting the length of the complex input/output.
//!          Bugfix: clean up numerical noise in the zero imaginary components.
//! - 3.2.0: Allow scratch buffer to be larger than needed.
//! - 3.1.0: Update to RustFFT 6.1 with Neon support.
//! - 3.0.2: Fix confusing typos in errors about scratch length.
//! - 3.0.1: More helpful error messages, fix confusing typos.
//! - 3.0.0: Improved error reporting.
//! - 2.0.1: Minor bugfix.
//! - 2.0.0: Update RustFFT to 6.0.0 and num-complex to 0.4.0.
//! - 1.1.0: Add missing Sync+Send.
//! - 1.0.0: First version with new api.
//!
//! ### Compatibility
//!
//! The `realfft` crate has the same rustc version requirements as RustFFT.
//! The minimum rustc version is 1.61.

pub use rustfft::num_complex;
pub use rustfft::num_traits;
pub use rustfft::FftNum;

use rustfft::num_complex::Complex;
use rustfft::num_traits::Zero;
use rustfft::FftPlanner;
use std::collections::HashMap;
use std::error;
use std::fmt;
use std::sync::Arc;

type Res<T> = Result<T, FftError>;

/// Custom error returned by FFTs
pub enum FftError {
    /// The input buffer has the wrong size. The transform was not performed.
    ///
    /// The first member of the tuple is the expected size and the second member is the received
    /// size.
    InputBuffer(usize, usize),
    /// The output buffer has the wrong size. The transform was not performed.
    ///
    /// The first member of the tuple is the expected size and the second member is the received
    /// size.
    OutputBuffer(usize, usize),
    /// The scratch buffer is too small. The transform was not performed.
    ///
    /// The first member of the tuple is the minimum size and the second member is the received
    /// size.
    ScratchBuffer(usize, usize),
    /// The input data contained a non-zero imaginary part where there should have been a zero.
    /// The transform was performed, but the result may not be correct.
    ///
    /// The first member of the tuple is `true` if the first value of the complex input buffer
    /// had a non-zero imaginary part, and the second member is `true` if the last value did.
    InputValues(bool, bool),
}

impl FftError {
    fn fmt_internal(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let desc = match self {
            Self::InputBuffer(expected, got) => {
                format!("Wrong length of input, expected {}, got {}", expected, got)
            }
            Self::OutputBuffer(expected, got) => {
                format!("Wrong length of output, expected {}, got {}", expected, got)
            }
            Self::ScratchBuffer(expected, got) => {
                format!(
                    "Scratch buffer of size {} is too small, must be at least {} long",
                    got, expected
                )
            }
            Self::InputValues(first, last) => match (first, last) {
                (true, false) => "Imaginary part of first value was non-zero.".to_string(),
                (false, true) => "Imaginary part of last value was non-zero.".to_string(),
                (true, true) => {
                    "Imaginary parts of both first and last values were non-zero.".to_string()
                }
                (false, false) => unreachable!(),
            },
        };
        write!(f, "{}", desc)
    }
}

impl fmt::Debug for FftError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.fmt_internal(f)
    }
}

impl fmt::Display for FftError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.fmt_internal(f)
    }
}

impl error::Error for FftError {}

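/**
The even-length transforms rely on a symmetry of these twiddle factors. For N = 2M, the twiddle for index M - k equals the twiddle for index k with the real part negated.

Below is a standalone sketch of that identity. It uses a hypothetical `twiddle` helper on `(re, im)` tuples that computes the same angle as this function, without the generic `T`. `mirrored_twiddle_matches` is likewise a hypothetical name.

```rust
use std::f64::consts::PI;

/// Forward twiddle factor exp(-2*pi*i*index/fft_len) as an (re, im) pair.
fn twiddle(index: usize, fft_len: usize) -> (f64, f64) {
    let angle = -2.0 * PI * index as f64 / fft_len as f64;
    (angle.cos(), angle.sin())
}

/// For an even length n = 2m, checks that the twiddle for index m - k is the
/// twiddle for index k with the real part negated, i.e. W^(m-k) = -conj(W^k).
fn mirrored_twiddle_matches(k: usize, n: usize) -> bool {
    let (re, im) = twiddle(k, n);
    let (re_mirror, im_mirror) = twiddle(n / 2 - k, n);
    (re_mirror + re).abs() < 1e-12 && (im_mirror - im).abs() < 1e-12
}
```

This is why the forward post-processing loop can use a single stored twiddle factor for each pair of outputs.
*/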
fn compute_twiddle<T: FftNum>(index: usize, fft_len: usize) -> Complex<T> {
    // Returns exp(-2*pi*i*index/fft_len), computed in f64 and converted to T.
    let constant = -2f64 * std::f64::consts::PI / fft_len as f64;
    let angle = constant * index as f64;
    Complex {
        re: T::from_f64(angle.cos()).unwrap(),
        im: T::from_f64(angle.sin()).unwrap(),
    }
}

/// A forward FFT for real-valued data of odd length.
/// The input is converted to complex and transformed with a complex FFT of the full length,
/// and the first N/2+1 (with N/2 rounded down) values of the result are kept.
pub struct RealToComplexOdd<T> {
    length: usize,
    fft: std::sync::Arc<dyn rustfft::Fft<T>>,
    scratch_len: usize,
}

/// A forward FFT for real-valued data of even length.
/// Pairs of real samples are packed into a complex buffer of length N/2, transformed with
/// a complex FFT of half the length, and post-processed using precomputed twiddle factors.
pub struct RealToComplexEven<T> {
    twiddles: Vec<Complex<T>>,
    length: usize,
    fft: std::sync::Arc<dyn rustfft::Fft<T>>,
    scratch_len: usize,
}

/// An inverse FFT producing real-valued data of odd length.
/// The N/2+1 input values (with N/2 rounded down) are extended to a full-length spectrum
/// using complex conjugate symmetry, transformed with a complex inverse FFT of the full length,
/// and the real parts of the result are kept.
pub struct ComplexToRealOdd<T> {
    length: usize,
    fft: std::sync::Arc<dyn rustfft::Fft<T>>,
    scratch_len: usize,
}

/// An inverse FFT producing real-valued data of even length.
/// The N/2+1 input values are pre-processed using precomputed twiddle factors,
/// and then transformed with a complex inverse FFT of half the length.
pub struct ComplexToRealEven<T> {
    twiddles: Vec<Complex<T>>,
    length: usize,
    fft: std::sync::Arc<dyn rustfft::Fft<T>>,
    scratch_len: usize,
}

/// A forward FFT that takes a real-valued input signal of length N
/// and transforms it to a complex spectrum of length N/2+1.
#[allow(clippy::len_without_is_empty)]
pub trait RealToComplex<T>: Sync + Send {
    /// Transform a signal of N real-valued samples,
    /// storing the resulting complex spectrum in the N/2+1
    /// (with N/2 rounded down) element long output slice.
    /// The input buffer is used as scratch space,
    /// so the contents of input should be considered garbage after calling.
    /// It also allocates additional scratch space as needed.
    /// An error is returned if any of the given slices has the wrong length.
    fn process(&self, input: &mut [T], output: &mut [Complex<T>]) -> Res<()>;

    /// Transform a signal of N real-valued samples,
    /// similar to [`process()`](RealToComplex::process).
    /// The difference is that this method uses the provided
    /// scratch buffer instead of allocating new scratch space.
    /// This is faster if the same scratch buffer is used for multiple calls.
    fn process_with_scratch(
        &self,
        input: &mut [T],
        output: &mut [Complex<T>],
        scratch: &mut [Complex<T>],
    ) -> Res<()>;

    /// Get the minimum length of the scratch buffer needed for `process_with_scratch`.
    fn get_scratch_len(&self) -> usize;

    /// The FFT length.
    /// Get the length of the real-valued signal that this FFT takes as input.
    fn len(&self) -> usize;

    /// Get the number of complex data points that this FFT returns.
    fn complex_len(&self) -> usize {
        self.len() / 2 + 1
    }

    /// Convenience method to make an input vector of the right type and length.
    fn make_input_vec(&self) -> Vec<T>;

    /// Convenience method to make an output vector of the right type and length.
    fn make_output_vec(&self) -> Vec<Complex<T>>;

    /// Convenience method to make a scratch vector of the right type and length.
    fn make_scratch_vec(&self) -> Vec<Complex<T>>;
}

/// An inverse FFT that takes a complex spectrum of length N/2+1
/// and transforms it to a real-valued signal of length N.
#[allow(clippy::len_without_is_empty)]
pub trait ComplexToReal<T>: Sync + Send {
    /// Inverse transform a complex spectrum corresponding to a real-valued signal of length N.
    /// The input is a slice of complex values with length N/2+1 (with N/2 rounded down).
    /// The resulting real-valued signal is stored in the output slice of length N.
    /// The input buffer is used as scratch space,
    /// so the contents of input should be considered garbage after calling.
    /// It also allocates additional scratch space as needed.
    /// An error is returned if any of the given slices has the wrong length.
    /// If the input data is invalid, meaning that one of the positions that should
    /// contain a zero holds a non-zero value, the transform is still performed.
    /// The function then returns an `FftError::InputValues` error to indicate that the
    /// result may not be correct.
    fn process(&self, input: &mut [Complex<T>], output: &mut [T]) -> Res<()>;

    /// Inverse transform a complex spectrum,
    /// similar to [`process()`](ComplexToReal::process).
    /// The difference is that this method uses the provided
    /// scratch buffer instead of allocating new scratch space.
    /// This is faster if the same scratch buffer is used for multiple calls.
    fn process_with_scratch(
        &self,
        input: &mut [Complex<T>],
        output: &mut [T],
        scratch: &mut [Complex<T>],
    ) -> Res<()>;

    /// Get the minimum length of the scratch space needed for `process_with_scratch`.
    fn get_scratch_len(&self) -> usize;

    /// The FFT length.
    /// Get the length of the real-valued signal that this FFT returns.
    fn len(&self) -> usize;

    /// Get the length of the slice of complex values that this FFT accepts as input.
    fn complex_len(&self) -> usize {
        self.len() / 2 + 1
    }

    /// Convenience method to make an input vector of the right type and length.
    fn make_input_vec(&self) -> Vec<Complex<T>>;

    /// Convenience method to make an output vector of the right type and length.
    fn make_output_vec(&self) -> Vec<T>;

    /// Convenience method to make a scratch vector of the right type and length.
    fn make_scratch_vec(&self) -> Vec<Complex<T>>;
}

/// Zip three iterables into a single iterator over triples.
fn zip3<A, B, C>(a: A, b: B, c: C) -> impl Iterator<Item = (A::Item, B::Item, C::Item)>
where
    A: IntoIterator,
    B: IntoIterator,
    C: IntoIterator,
{
    a.into_iter()
        .zip(b.into_iter().zip(c))
        .map(|(x, (y, z))| (x, y, z))
}

/// A planner is used to create FFTs.
/// It caches results internally,
/// so when making more than one FFT it is advisable to reuse the same planner.
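/**
The caching follows a common pattern: build an FFT on the first request for a given length, then hand out clones of the shared `Arc` for later requests.

Below is a minimal, hypothetical sketch of that pattern. `Plan` and `LengthCache` are not part of this crate, and a plain struct stands in for a planned FFT.

```rust
use std::collections::HashMap;
use std::sync::Arc;

/// Hypothetical stand-in for a planned FFT; only its length matters here.
struct Plan {
    len: usize,
}

/// Hypothetical length-keyed cache: builds a plan on the first request
/// and returns clones of the shared `Arc` afterwards.
struct LengthCache {
    cache: HashMap<usize, Arc<Plan>>,
}

impl LengthCache {
    fn new() -> Self {
        LengthCache { cache: HashMap::new() }
    }

    fn plan(&mut self, len: usize) -> Arc<Plan> {
        // `entry` builds the plan only if this length has not been requested before.
        let plan = self.cache.entry(len).or_insert_with(|| Arc::new(Plan { len }));
        Arc::clone(plan)
    }
}
```

Two requests for the same length return the same `Arc`, which can be checked with `Arc::ptr_eq`. The planning work is therefore done only once per length.
*/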
pub struct RealFftPlanner<T: FftNum> {
    planner: FftPlanner<T>,
    r2c_cache: HashMap<usize, Arc<dyn RealToComplex<T>>>,
    c2r_cache: HashMap<usize, Arc<dyn ComplexToReal<T>>>,
}

impl<T: FftNum> RealFftPlanner<T> {
    /// Create a new planner.
    pub fn new() -> Self {
        let planner = FftPlanner::<T>::new();
        Self {
            r2c_cache: HashMap::new(),
            c2r_cache: HashMap::new(),
            planner,
        }
    }

    /// Plan a real-to-complex forward FFT. Returns the FFT in a shared reference.
    /// If requesting a second forward FFT of the same length,
    /// the planner will return a new reference to the already existing one.
    pub fn plan_fft_forward(&mut self, len: usize) -> Arc<dyn RealToComplex<T>> {
        if let Some(fft) = self.r2c_cache.get(&len) {
            Arc::clone(fft)
        } else {
            // Odd lengths use a full-length complex FFT, even lengths a half-length one
            let fft = if len % 2 > 0 {
                Arc::new(RealToComplexOdd::new(len, &mut self.planner)) as Arc<dyn RealToComplex<T>>
            } else {
                Arc::new(RealToComplexEven::new(len, &mut self.planner))
                    as Arc<dyn RealToComplex<T>>
            };
            self.r2c_cache.insert(len, Arc::clone(&fft));
            fft
        }
    }

    /// Plan a complex-to-real inverse FFT. Returns the FFT in a shared reference.
    /// If requesting a second inverse FFT of the same length,
    /// the planner will return a new reference to the already existing one.
    pub fn plan_fft_inverse(&mut self, len: usize) -> Arc<dyn ComplexToReal<T>> {
        if let Some(fft) = self.c2r_cache.get(&len) {
            Arc::clone(fft)
        } else {
            // Odd lengths use a full-length complex FFT, even lengths a half-length one
            let fft = if len % 2 > 0 {
                Arc::new(ComplexToRealOdd::new(len, &mut self.planner)) as Arc<dyn ComplexToReal<T>>
            } else {
                Arc::new(ComplexToRealEven::new(len, &mut self.planner))
                    as Arc<dyn ComplexToReal<T>>
            };
            self.c2r_cache.insert(len, Arc::clone(&fft));
            fft
        }
    }
}

impl<T: FftNum> Default for RealFftPlanner<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T: FftNum> RealToComplexOdd<T> {
    /// Create a new RealToComplex forward FFT for real-valued input data of a given length.
    /// Uses the given FftPlanner to build the inner FFT.
    /// Panics if the length is not odd.
    pub fn new(length: usize, fft_planner: &mut FftPlanner<T>) -> Self {
        if length % 2 == 0 {
            panic!("Length must be odd, got {}", length);
        }
        let fft = fft_planner.plan_fft_forward(length);
        // The scratch buffer holds a complex copy of the input, followed by the inner FFT's scratch
        let scratch_len = fft.get_inplace_scratch_len() + length;
        RealToComplexOdd {
            length,
            fft,
            scratch_len,
        }
    }
}

impl<T: FftNum> RealToComplex<T> for RealToComplexOdd<T> {
    fn process(&self, input: &mut [T], output: &mut [Complex<T>]) -> Res<()> {
        let mut scratch = self.make_scratch_vec();
        self.process_with_scratch(input, output, &mut scratch)
    }

    fn process_with_scratch(
        &self,
        input: &mut [T],
        output: &mut [Complex<T>],
        scratch: &mut [Complex<T>],
    ) -> Res<()> {
        if input.len() != self.length {
            return Err(FftError::InputBuffer(self.length, input.len()));
        }
        let expected_output_buffer_size = self.complex_len();
        if output.len() != expected_output_buffer_size {
            return Err(FftError::OutputBuffer(
                expected_output_buffer_size,
                output.len(),
            ));
        }
        if scratch.len() < self.scratch_len {
            return Err(FftError::ScratchBuffer(self.scratch_len, scratch.len()));
        }
        let (buffer, fft_scratch) = scratch.split_at_mut(self.length);

        // Convert the real input to complex, with zero imaginary parts
        for (val, buf) in input.iter().zip(buffer.iter_mut()) {
            *buf = Complex::new(*val, T::zero());
        }
        // Full-length FFT in place in `buffer`, then copy the non-redundant first part to the output
        self.fft.process_with_scratch(buffer, fft_scratch);
        output.copy_from_slice(&buffer[0..self.complex_len()]);
        // The imaginary part of the first value is zero in theory; remove any numerical noise
        if let Some(elem) = output.first_mut() {
            elem.im = T::zero();
        }
        Ok(())
    }

    fn get_scratch_len(&self) -> usize {
        self.scratch_len
    }

    fn len(&self) -> usize {
        self.length
    }

    fn make_input_vec(&self) -> Vec<T> {
        vec![T::zero(); self.len()]
    }

    fn make_output_vec(&self) -> Vec<Complex<T>> {
        vec![Complex::zero(); self.complex_len()]
    }

    fn make_scratch_vec(&self) -> Vec<Complex<T>> {
        vec![Complex::zero(); self.get_scratch_len()]
    }
}

impl<T: FftNum> RealToComplexEven<T> {
    /// Create a new RealToComplex forward FFT for real-valued input data of a given length.
    /// Uses the given FftPlanner to build the inner FFT.
    /// Panics if the length is not even.
    pub fn new(length: usize, fft_planner: &mut FftPlanner<T>) -> Self {
        if length % 2 > 0 {
            panic!("Length must be even, got {}", length);
        }
        // Each twiddle factor serves a pair of outputs (k and N/2 - k), and k = 0 needs none,
        // so only indices 1..twiddle_count are stored. The factor 0.5 of the post-processing
        // step is folded into the stored values.
        let twiddle_count = if length % 4 == 0 {
            length / 4
        } else {
            length / 4 + 1
        };
        let twiddles: Vec<Complex<T>> = (1..twiddle_count)
            .map(|i| compute_twiddle(i, length) * T::from_f64(0.5).unwrap())
            .collect();
        let fft = fft_planner.plan_fft_forward(length / 2);
        let scratch_len = fft.get_outofplace_scratch_len();
        RealToComplexEven {
            twiddles,
            length,
            fft,
            scratch_len,
        }
    }
}

impl<T: FftNum> RealToComplex<T> for RealToComplexEven<T> {
    fn process(&self, input: &mut [T], output: &mut [Complex<T>]) -> Res<()> {
        let mut scratch = self.make_scratch_vec();
        self.process_with_scratch(input, output, &mut scratch)
    }

    fn process_with_scratch(
        &self,
        input: &mut [T],
        output: &mut [Complex<T>],
        scratch: &mut [Complex<T>],
    ) -> Res<()> {
        if input.len() != self.length {
            return Err(FftError::InputBuffer(self.length, input.len()));
        }
        let expected_output_buffer_size = self.complex_len();
        if output.len() != expected_output_buffer_size {
            return Err(FftError::OutputBuffer(
                expected_output_buffer_size,
                output.len(),
            ));
        }
        if scratch.len() < self.scratch_len {
            return Err(FftError::ScratchBuffer(self.scratch_len, scratch.len()));
        }

        let fftlen = self.length / 2;
        // SAFETY: `Complex<T>` is `#[repr(C)]` with two fields `re` and `im` of type `T`,
        // so an even-length slice of `T` can be viewed as a slice of half as many `Complex<T>`,
        // with each consecutive pair of samples as the real and imaginary parts.
        // The input length was checked above to equal `self.length`, which is even.
        let buf_in = unsafe {
            let ptr = input.as_mut_ptr() as *mut Complex<T>;
            let len = input.len();
            std::slice::from_raw_parts_mut(ptr, len / 2)
        };

        // Half-length FFT of the packed data, storing the result in the first `fftlen` elements of `output`
        self.fft
            .process_outofplace_with_scratch(buf_in, &mut output[0..fftlen], scratch);
        let (mut output_left, mut output_right) = output.split_at_mut(output.len() / 2);

        // The first and last elements don't require any twiddle factors, so skip that work
        match (output_left.first_mut(), output_right.last_mut()) {
            (Some(first_element), Some(last_element)) => {
                // The first and last elements are just the sum and difference of the first value's real and imaginary parts
                let first_value = *first_element;
                *first_element = Complex {
                    re: first_value.re + first_value.im,
                    im: T::zero(),
                };
                *last_element = Complex {
                    re: first_value.re - first_value.im,
                    im: T::zero(),
                };

                // Chop the first and last element off of our slices so that the loop below doesn't have to deal with them
                output_left = &mut output_left[1..];
                let right_len = output_right.len();
                output_right = &mut output_right[..right_len - 1];
            }
            _ => {
                return Ok(());
            }
        }
        // Loop over the remaining elements and apply twiddle factors on them
        for (twiddle, out, out_rev) in zip3(
            self.twiddles.iter(),
            output_left.iter_mut(),
            output_right.iter_mut().rev(),
        ) {
            let sum = *out + *out_rev;
            let diff = *out - *out_rev;
            let half = T::from_f64(0.5).unwrap();
            // Apply twiddle factors. Theoretically we'd have to load 2 separate twiddle factors here, one for the beginning
            // and one for the end. But the twiddle factor for the end is just the twiddle for the beginning, with the
            // real part negated. Since it's the same twiddle, we can factor out a ton of math ops and cut the number of
            // multiplications in half.
            let twiddled_re_sum = sum * twiddle.re;
            let twiddled_im_sum = sum * twiddle.im;
            let twiddled_re_diff = diff * twiddle.re;
            let twiddled_im_diff = diff * twiddle.im;
            let half_sum_re = half * sum.re;
            let half_diff_im = half * diff.im;

            let output_twiddled_real = twiddled_re_sum.im + twiddled_im_diff.re;
            let output_twiddled_im = twiddled_im_sum.im - twiddled_re_diff.re;

            // We finally have all the data we need to write the transformed data back out where we found it.
            *out = Complex {
                re: half_sum_re + output_twiddled_real,
                im: half_diff_im + output_twiddled_im,
            };

            *out_rev = Complex {
                re: half_sum_re - output_twiddled_real,
                im: output_twiddled_im - half_diff_im,
            };
        }

        // If the output len is odd, the loop above can't postprocess the centermost element, so handle that separately.
        if output.len() % 2 == 1 {
            if let Some(center_element) = output.get_mut(output.len() / 2) {
                center_element.im = -center_element.im;
            }
        }
        Ok(())
    }

    fn get_scratch_len(&self) -> usize {
        self.scratch_len
    }

    fn len(&self) -> usize {
        self.length
    }

    fn make_input_vec(&self) -> Vec<T> {
        vec![T::zero(); self.len()]
    }

    fn make_output_vec(&self) -> Vec<Complex<T>> {
        vec![Complex::zero(); self.complex_len()]
    }

    fn make_scratch_vec(&self) -> Vec<Complex<T>> {
        vec![Complex::zero(); self.get_scratch_len()]
    }
}

impl<T: FftNum> ComplexToRealOdd<T> {
    /// Create a new ComplexToRealOdd inverse FFT for complex input spectra.
    /// The `length` parameter refers to the length of the resulting real-valued signal.
    /// Uses the given FftPlanner to build the inner FFT.
    /// Panics if the length is not odd.
    pub fn new(length: usize, fft_planner: &mut FftPlanner<T>) -> Self {
        if length % 2 == 0 {
            panic!("Length must be odd, got {}", length);
        }
        let fft = fft_planner.plan_fft_inverse(length);
        // The scratch buffer holds the full-length complex spectrum, followed by the inner FFT's scratch
        let scratch_len = length + fft.get_inplace_scratch_len();
        ComplexToRealOdd {
            length,
            fft,
            scratch_len,
        }
    }
}

impl<T: FftNum> ComplexToReal<T> for ComplexToRealOdd<T> {
    fn process(&self, input: &mut [Complex<T>], output: &mut [T]) -> Res<()> {
        let mut scratch = self.make_scratch_vec();
        self.process_with_scratch(input, output, &mut scratch)
    }

    fn process_with_scratch(
        &self,
        input: &mut [Complex<T>],
        output: &mut [T],
        scratch: &mut [Complex<T>],
    ) -> Res<()> {
        let expected_input_buffer_size = self.complex_len();
        if input.len() != expected_input_buffer_size {
            return Err(FftError::InputBuffer(
                expected_input_buffer_size,
                input.len(),
            ));
        }
        if output.len() != self.length {
            return Err(FftError::OutputBuffer(self.length, output.len()));
        }
        if scratch.len() < self.scratch_len {
            return Err(FftError::ScratchBuffer(self.scratch_len, scratch.len()));
        }

        // The imaginary part of the first value must be zero for a real-valued signal.
        // Clear it, and remember to report the invalid input after the transform.
        let first_invalid = if input[0].im != T::from_f64(0.0).unwrap() {
            input[0].im = T::from_f64(0.0).unwrap();
            true
        } else {
            false
        };

        let (buffer, fft_scratch) = scratch.split_at_mut(self.length);

        // Copy the given part of the spectrum, then fill the rest of the buffer with
        // the complex conjugates of the input values, in reverse order.
        buffer[0..input.len()].copy_from_slice(input);
        for (buf, val) in buffer
            .iter_mut()
            .rev()
            .take(self.length / 2)
            .zip(input.iter().skip(1))
        {
            *buf = val.conj();
        }
        self.fft.process_with_scratch(buffer, fft_scratch);
        // The result is real-valued in theory, so keep only the real parts
        for (val, out) in buffer.iter().zip(output.iter_mut()) {
            *out = val.re;
        }
        if first_invalid {
            return Err(FftError::InputValues(true, false));
        }
        Ok(())
    }

    fn get_scratch_len(&self) -> usize {
        self.scratch_len
    }

    fn len(&self) -> usize {
        self.length
    }

    fn make_input_vec(&self) -> Vec<Complex<T>> {
        vec![Complex::zero(); self.complex_len()]
    }

    fn make_output_vec(&self) -> Vec<T> {
        vec![T::zero(); self.len()]
    }

    fn make_scratch_vec(&self) -> Vec<Complex<T>> {
        vec![Complex::zero(); self.get_scratch_len()]
    }
}

impl<T: FftNum> ComplexToRealEven<T> {
    /// Create a new ComplexToRealEven inverse FFT for complex input spectra.
    /// The `length` parameter refers to the length of the resulting real-valued signal.
    /// Uses the given FftPlanner to build the inner FFT.
    /// Panics if the length is not even.
    pub fn new(length: usize, fft_planner: &mut FftPlanner<T>) -> Self {
        if length % 2 > 0 {
            panic!("Length must be even, got {}", length);
        }
        let twiddle_count = if length % 4 == 0 {
            length / 4
        } else {
            length / 4 + 1
        };
        // Same twiddle indices as the forward transform, but conjugated and without the factor 0.5
        let twiddles: Vec<Complex<T>> = (1..twiddle_count)
            .map(|i| compute_twiddle(i, length).conj())
            .collect();
        let fft = fft_planner.plan_fft_inverse(length / 2);
        let scratch_len = fft.get_outofplace_scratch_len();
        ComplexToRealEven {
            twiddles,
            length,
            fft,
            scratch_len,
        }
    }
}
impl<T: FftNum> ComplexToReal<T> for ComplexToRealEven<T> {
    fn process(&self, input: &mut [Complex<T>], output: &mut [T]) -> Res<()> {
        let mut scratch = self.make_scratch_vec();
        self.process_with_scratch(input, output, &mut scratch)
    }

    fn process_with_scratch(
        &self,
        input: &mut [Complex<T>],
        output: &mut [T],
        scratch: &mut [Complex<T>],
    ) -> Res<()> {
        let expected_input_buffer_size = self.complex_len();
        if input.len() != expected_input_buffer_size {
            return Err(FftError::InputBuffer(
                expected_input_buffer_size,
                input.len(),
            ));
        }
        if output.len() != self.length {
            return Err(FftError::OutputBuffer(self.length, output.len()));
        }
        if scratch.len() < self.scratch_len {
            return Err(FftError::ScratchBuffer(self.scratch_len, scratch.len()));
        }
        if input.is_empty() {
            return Ok(());
        }
        // The DC bin (first) and the Nyquist bin (last) of a real signal's spectrum must be
        // purely real. Zero any imaginary part here and report an InputValues error once the
        // transform is done.
        let first_invalid = if input[0].im != T::zero() {
            input[0].im = T::zero();
            true
        } else {
            false
        };
        let last_invalid = if input[input.len() - 1].im != T::zero() {
            input[input.len() - 1].im = T::zero();
            true
        } else {
            false
        };

        let (mut input_left, mut input_right) = input.split_at_mut(input.len() / 2);

        // We have to preprocess the input in-place before we send it to the FFT.
        // The first and last values have to be preprocessed separately from the rest, so do that now.
        // (The centermost value, if there is one, is handled after the loop below.)
        match (input_left.first_mut(), input_right.last_mut()) {
            (Some(first_input), Some(last_input)) => {
                let first_sum = *first_input + *last_input;
                let first_diff = *first_input - *last_input;

                *first_input = Complex {
                    re: first_sum.re - first_sum.im,
                    im: first_diff.re - first_diff.im,
                };

                input_left = &mut input_left[1..];
                let right_len = input_right.len();
                input_right = &mut input_right[..right_len - 1];
            }
            _ => return Ok(()),
        };

        // Now preprocess the remaining elements two at a time, pairing each element on the
        // left with its mirror on the right.
        for (twiddle, fft_input, fft_input_rev) in zip3(
            self.twiddles.iter(),
            input_left.iter_mut(),
            input_right.iter_mut().rev(),
        ) {
            let sum = *fft_input + *fft_input_rev;
            let diff = *fft_input - *fft_input_rev;

            // Apply twiddle factors. Theoretically we'd have to load 2 separate twiddle factors here, one for the beginning
            // and one for the end. But the twiddle factor for the end is just the twiddle for the beginning, with the
            // real part negated. Since it's the same twiddle, we can factor out a ton of math ops and cut the number of
            // multiplications in half.
            let twiddled_re_sum = sum * twiddle.re;
            let twiddled_im_sum = sum * twiddle.im;
            let twiddled_re_diff = diff * twiddle.re;
            let twiddled_im_diff = diff * twiddle.im;

            let output_twiddled_real = twiddled_re_sum.im + twiddled_im_diff.re;
            let output_twiddled_im = twiddled_im_sum.im - twiddled_re_diff.re;

            // We finally have all the data we need to write our preprocessed data back where we got it from.
            *fft_input = Complex {
                re: sum.re - output_twiddled_real,
                im: diff.im - output_twiddled_im,
            };
            *fft_input_rev = Complex {
                re: sum.re + output_twiddled_real,
                im: -output_twiddled_im - diff.im,
            };
        }

        // If the input length (self.length / 2 + 1) is odd, i.e. when self.length is a multiple
        // of 4, the loop above can't preprocess the centermost element, so handle that separately.
        if input.len() % 2 == 1 {
            let center_element = input[input.len() / 2];
            let doubled = center_element + center_element;
            input[input.len() / 2] = doubled.conj();
        }

        // Run the half-length inverse FFT and write its result directly into `output`.
        // SAFETY: `Complex<T>` is `#[repr(C)]` with fields `re` and `im`, so the even-length
        // `output` slice can be viewed as `output.len() / 2` complex values. Each complex value
        // holds two consecutive samples of the real-valued result.
        let buf_out = unsafe {
            let ptr = output.as_mut_ptr() as *mut Complex<T>;
            let len = output.len();
            std::slice::from_raw_parts_mut(ptr, len / 2)
        };
        self.fft
            .process_outofplace_with_scratch(&mut input[..buf_out.len()], buf_out, scratch);
        if first_invalid || last_invalid {
            return Err(FftError::InputValues(first_invalid, last_invalid));
        }
        Ok(())
    }

    fn get_scratch_len(&self) -> usize {
        self.scratch_len
    }

    fn len(&self) -> usize {
        self.length
    }

    fn make_input_vec(&self) -> Vec<Complex<T>> {
        vec![Complex::zero(); self.complex_len()]
    }

    fn make_output_vec(&self) -> Vec<T> {
        vec![T::zero(); self.len()]
    }

    fn make_scratch_vec(&self) -> Vec<Complex<T>> {
        vec![Complex::zero(); self.get_scratch_len()]
    }
}

#[cfg(test)]
mod tests {
    use crate::FftError;
    use crate::RealFftPlanner;
    use rand::Rng;
    use rustfft::num_complex::Complex;
    use rustfft::num_traits::{Float, Zero};
    use rustfft::FftPlanner;
    use std::error::Error;
    use std::ops::Sub;

    // Return the largest absolute difference (norm of a[i] - b[i]) between corresponding elements
    fn compare_complex<T: Float + Sub>(a: &[Complex<T>], b: &[Complex<T>]) -> T {
        a.iter()
            .zip(b.iter())
            .fold(T::zero(), |maxdiff, (val_a, val_b)| {
                let diff = (val_a - val_b).norm();
                if maxdiff > diff {
                    maxdiff
                } else {
                    diff
                }
            })
    }
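
    // Sketch (illustrative, std-only): checks the symmetry that the preprocessing loop in
    // ComplexToRealEven relies on, namely that the twiddle for the mirrored index n/2 - k
    // equals the twiddle for index k with its real part negated. This assumes the forward
    // twiddle convention w_k = exp(-2*pi*i*k/n), which `compute_twiddle` is taken to follow
    // here; the inverse uses conj(w_k), for which the same relation holds.
    // `twiddle_pair_f64` and `mirrored_twiddle_matches` are hypothetical helpers, not crate API.
    fn twiddle_pair_f64(k: usize, n: usize) -> (f64, f64) {
        let angle = -2.0 * std::f64::consts::PI * k as f64 / n as f64;
        // conjugated, as in ComplexToRealEven::new
        (angle.cos(), -angle.sin())
    }

    fn mirrored_twiddle_matches(n: usize) -> bool {
        (1..n / 4 + 1).all(|k| {
            let (re, im) = twiddle_pair_f64(k, n);
            let (re_m, im_m) = twiddle_pair_f64(n / 2 - k, n);
            // real part negated, imaginary part unchanged
            (re_m + re).abs() < 1e-12 && (im_m - im).abs() < 1e-12
        })
    }

    #[test]
    fn mirrored_twiddle_symmetry() {
        for n in (2..200).step_by(2) {
            assert!(mirrored_twiddle_matches(n), "Length: {}", n);
        }
    }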

    // Return the largest absolute difference between corresponding elements
    fn compare_scalars<T: Float + Sub>(a: &[T], b: &[T]) -> T {
        a.iter()
            .zip(b.iter())
            .fold(T::zero(), |maxdiff, (val_a, val_b)| {
                let diff = (*val_a - *val_b).abs();
                if maxdiff > diff {
                    maxdiff
                } else {
                    diff
                }
            })
    }
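
    // Sketch (illustrative, std-only): re-derives the index bookkeeping of ComplexToRealEven
    // for an even output length `n`. The complex input has n/2 + 1 elements; the first/last
    // pair is handled up front, the loop pairs the remaining left and right elements using one
    // twiddle per pair, and a single centre element is left over exactly when n/2 + 1 is odd.
    // `c2r_even_bookkeeping` is a hypothetical helper that mirrors the arithmetic in `new`
    // and `process_with_scratch`; it returns (twiddles, left, right, has_center).
    fn c2r_even_bookkeeping(n: usize) -> (usize, usize, usize, bool) {
        let complex_len = n / 2 + 1;
        let twiddle_count = if n % 4 == 0 { n / 4 } else { n / 4 + 1 };
        // `new` collects twiddles for indices 1..twiddle_count
        let num_twiddles = twiddle_count - 1;
        // slice lengths after split_at_mut(complex_len / 2) and trimming the first/last pair
        let left = complex_len / 2 - 1;
        let right = complex_len - complex_len / 2 - 1;
        (num_twiddles, left, right, complex_len % 2 == 1)
    }

    #[test]
    fn c2r_even_twiddles_cover_all_pairs() {
        for n in (2..1000).step_by(2) {
            let (num_twiddles, left, right, has_center) = c2r_even_bookkeeping(n);
            // every remaining left element gets exactly one twiddle
            assert_eq!(num_twiddles, left, "Length: {}", n);
            // the right side has one extra (centre) element only when n/2 + 1 is odd
            assert_eq!(right, left + has_center as usize, "Length: {}", n);
        }
    }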

    // Compare ComplexToReal with standard inverse FFT
    #[test]
    fn complex_to_real_64() {
        for length in 1..1000 {
            let mut real_planner = RealFftPlanner::<f64>::new();
            let c2r = real_planner.plan_fft_inverse(length);
            let mut out_a = c2r.make_output_vec();
            let mut indata = c2r.make_input_vec();
            let mut rustfft_check: Vec<Complex<f64>> = vec![Complex::zero(); length];
            let mut rng = rand::thread_rng();
            for val in indata.iter_mut() {
                *val = Complex::new(rng.gen::<f64>(), rng.gen::<f64>());
            }
            // The DC bin, and the Nyquist bin for even lengths, must be purely real.
            indata[0].im = 0.0;
            if length % 2 == 0 {
                indata[length / 2].im = 0.0;
            }
            // Build the full-length reference spectrum: copy the stored bins, then mirror
            // them as X[length - k] = conj(X[k]) for k = 1..=length / 2.
            for (val_long, val) in rustfft_check
                .iter_mut()
                .take(c2r.complex_len())
                .zip(indata.iter())
            {
                *val_long = *val;
            }
            for (val_long, val) in rustfft_check
                .iter_mut()
                .rev()
                .take(length / 2)
                .zip(indata.iter().skip(1))
            {
                *val_long = val.conj();
            }
            let mut fft_planner = FftPlanner::<f64>::new();
            let fft = fft_planner.plan_fft_inverse(length);

            c2r.process(&mut indata, &mut out_a).unwrap();
            fft.process(&mut rustfft_check);

            let check_real = rustfft_check.iter().map(|val| val.re).collect::<Vec<f64>>();
            let maxdiff = compare_scalars(&out_a, &check_real);
            assert!(
                maxdiff < 1.0e-9,
                "Length: {}, too large error: {}",
                length,
                maxdiff
            );
        }
    }
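
    // Sketch (illustrative, std-only): for a real signal x of even length n, the DC bin is
    // X[0] = sum of all samples and the Nyquist bin is X[n/2] = sum of (-1)^t * x[t]. Once their
    // imaginary parts are zero, the first/last step in ComplexToRealEven::process_with_scratch
    // stores (X[0] + X[n/2], X[0] - X[n/2]), which is twice (sum of even-indexed samples,
    // sum of odd-indexed samples), i.e. twice the DC term of the packed sequence
    // z[m] = x[2m] + i*x[2m+1]. `dc_and_nyquist` is a hypothetical helper, not crate API.
    fn dc_and_nyquist(x: &[f64]) -> (f64, f64) {
        let dc: f64 = x.iter().sum();
        let nyquist: f64 = x
            .iter()
            .enumerate()
            .map(|(t, v)| if t % 2 == 0 { *v } else { -*v })
            .sum();
        (dc, nyquist)
    }

    #[test]
    fn first_element_packs_even_and_odd_sums() {
        let x = [0.5, -1.0, 2.0, 3.5, -0.25, 1.0];
        let (dc, nyquist) = dc_and_nyquist(&x);
        let even_sum: f64 = x.iter().step_by(2).sum();
        let odd_sum: f64 = x.iter().skip(1).step_by(2).sum();
        assert!((dc + nyquist - 2.0 * even_sum).abs() < 1e-12);
        assert!((dc - nyquist - 2.0 * odd_sum).abs() < 1e-12);
    }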

    // Compare ComplexToReal with standard inverse FFT
    #[test]
    fn complex_to_real_32() {
        for length in 1..1000 {
            let mut real_planner = RealFftPlanner::<f32>::new();
            let c2r = real_planner.plan_fft_inverse(length);
            let mut out_a = c2r.make_output_vec();
            let mut indata = c2r.make_input_vec();
            let mut rustfft_check: Vec<Complex<f32>> = vec![Complex::zero(); length];
            let mut rng = rand::thread_rng();
            for val in indata.iter_mut() {
                *val = Complex::new(rng.gen::<f32>(), rng.gen::<f32>());
            }
            // The DC bin, and the Nyquist bin for even lengths, must be purely real.
            indata[0].im = 0.0;
            if length % 2 == 0 {
                indata[length / 2].im = 0.0;
            }
            // Build the full-length reference spectrum: copy the stored bins, then mirror
            // them as X[length - k] = conj(X[k]) for k = 1..=length / 2.
            for (val_long, val) in rustfft_check
                .iter_mut()
                .take(c2r.complex_len())
                .zip(indata.iter())
            {
                *val_long = *val;
            }
            for (val_long, val) in rustfft_check
                .iter_mut()
                .rev()
                .take(length / 2)
                .zip(indata.iter().skip(1))
            {
                *val_long = val.conj();
            }
            let mut fft_planner = FftPlanner::<f32>::new();
            let fft = fft_planner.plan_fft_inverse(length);

            c2r.process(&mut indata, &mut out_a).unwrap();
            fft.process(&mut rustfft_check);

            let check_real = rustfft_check.iter().map(|val| val.re).collect::<Vec<f32>>();
            let maxdiff = compare_scalars(&out_a, &check_real);
            assert!(
                maxdiff < 5.0e-4,
                "Length: {}, too large error: {}",
                length,
                maxdiff
            );
        }
    }
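
    // Sketch (illustrative, std-only): shows why the centre element of ComplexToRealEven
    // (present when n % 4 == 0) can be handled as `doubled.conj()`. At the centre index n/4
    // the inverse twiddle is conj(exp(-i*pi/2)) = i. If the general pair formula were applied
    // with the centre element on both sides, it would collapse to conj(2 * c), which is what
    // the special case computes. `preprocess_pair` is a hypothetical transcription of the loop
    // body using (re, im) tuples instead of `Complex<T>`.
    fn preprocess_pair(
        a: (f64, f64),
        b: (f64, f64),
        tw: (f64, f64),
    ) -> ((f64, f64), (f64, f64)) {
        let sum = (a.0 + b.0, a.1 + b.1);
        let diff = (a.0 - b.0, a.1 - b.1);
        // only the products the loop actually uses are formed
        let output_twiddled_real = sum.1 * tw.0 + diff.0 * tw.1;
        let output_twiddled_im = sum.1 * tw.1 - diff.0 * tw.0;
        let new_a = (sum.0 - output_twiddled_real, diff.1 - output_twiddled_im);
        let new_b = (sum.0 + output_twiddled_real, -output_twiddled_im - diff.1);
        (new_a, new_b)
    }

    #[test]
    fn c2r_even_center_element_matches_pair_formula() {
        let c = (0.3, -1.7);
        // inverse twiddle at the centre index is i
        let (new_a, new_b) = preprocess_pair(c, c, (0.0, 1.0));
        // equivalent of `doubled.conj()` in process_with_scratch
        let expected = (2.0 * c.0, -2.0 * c.1);
        for v in [new_a, new_b].iter() {
            assert!((v.0 - expected.0).abs() < 1e-12);
            assert!((v.1 - expected.1).abs() < 1e-12);
        }
    }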

    // Test that ComplexToReal returns the right errors
    #[test]
    fn complex_to_real_errors_even() {
        let length = 100;
        let mut real_planner = RealFftPlanner::<f64>::new();
        let c2r = real_planner.plan_fft_inverse(length);
        let mut out_a = c2r.make_output_vec();
        let mut indata = c2r.make_input_vec();
        let mut rng = rand::thread_rng();

        // Make some valid data
        for val in indata.iter_mut() {
            *val = Complex::new(rng.gen::<f64>(), rng.gen::<f64>());
        }
        indata[0].im = 0.0;
        indata[50].im = 0.0;
        // this should be ok
        assert!(c2r.process(&mut indata, &mut out_a).is_ok());

        // Make some invalid data, first point invalid
        for val in indata.iter_mut() {
            *val = Complex::new(rng.gen::<f64>(), rng.gen::<f64>());
        }
        indata[50].im = 0.0;
        let res = c2r.process(&mut indata, &mut out_a);
        assert!(res.is_err());
        assert!(matches!(res, Err(FftError::InputValues(true, false))));

        // Make some invalid data, last point invalid
        for val in indata.iter_mut() {
            *val = Complex::new(rng.gen::<f64>(), rng.gen::<f64>());
        }
        indata[0].im = 0.0;
        let res = c2r.process(&mut indata, &mut out_a);
        assert!(res.is_err());
        assert!(matches!(res, Err(FftError::InputValues(false, true))));
    }
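
    // Sketch (illustrative, std-only): mirrors the `unsafe` reinterpretation at the end of
    // ComplexToRealEven::process_with_scratch, where the real output buffer is viewed as half
    // as many complex values. `ReIm` is a hypothetical stand-in for `Complex<f64>`; the cast is
    // only sound because the struct is `#[repr(C)]` with two f64 fields, so it has the size and
    // alignment of `[f64; 2]`. `write_pairs_through_view` is likewise hypothetical.
    #[allow(dead_code)]
    #[repr(C)]
    struct ReIm {
        re: f64,
        im: f64,
    }

    fn write_pairs_through_view(output: &mut [f64]) {
        assert!(output.len() % 2 == 0, "the view needs an even number of reals");
        let pairs = unsafe {
            std::slice::from_raw_parts_mut(output.as_mut_ptr() as *mut ReIm, output.len() / 2)
        };
        for (m, z) in pairs.iter_mut().enumerate() {
            // writing z[m] sets x[2m] (re) and x[2m + 1] (im)
            z.re = (2 * m) as f64;
            z.im = (2 * m + 1) as f64;
        }
    }

    #[test]
    fn complex_view_interleaves_samples() {
        let mut output = vec![0.0; 6];
        write_pairs_through_view(&mut output);
        assert_eq!(output, vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
    }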

    // Test that ComplexToReal returns the right errors
    #[test]
    fn complex_to_real_errors_odd() {
        let length = 101;
        let mut real_planner = RealFftPlanner::<f64>::new();
        let c2r = real_planner.plan_fft_inverse(length);
        let mut out_a = c2r.make_output_vec();
        let mut indata = c2r.make_input_vec();
        let mut rng = rand::thread_rng();

        // Make some valid data
        for val in indata.iter_mut() {
            *val = Complex::new(rng.gen::<f64>(), rng.gen::<f64>());
        }
        indata[0].im = 0.0;
        // this should be ok
        assert!(c2r.process(&mut indata, &mut out_a).is_ok());

        // Make some invalid data, first point invalid
        for val in indata.iter_mut() {
            *val = Complex::new(rng.gen::<f64>(), rng.gen::<f64>());
        }
        let res = c2r.process(&mut indata, &mut out_a);
        assert!(res.is_err());
        assert!(matches!(res, Err(FftError::InputValues(true, false))));
    }
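
    // Sketch (illustrative, std-only): the inverse tests build the full length-n spectrum from
    // the n/2 + 1 stored values via X[n - k] = conj(X[k]). A naive O(n^2) inverse DFT
    // (unnormalized, positive exponent) of such a spectrum is purely real only when X[0], and
    // X[n/2] for even n, have zero imaginary parts, which is what the InputValues errors guard
    // against. `hermitian_extend` and `naive_idft` are hypothetical reference helpers.
    fn hermitian_extend(half: &[(f64, f64)], n: usize) -> Vec<(f64, f64)> {
        (0..n)
            .map(|k| {
                if k < half.len() {
                    half[k]
                } else {
                    // mirrored bin: X[k] = conj(X[n - k])
                    let (re, im) = half[n - k];
                    (re, -im)
                }
            })
            .collect()
    }

    fn naive_idft(spectrum: &[(f64, f64)]) -> Vec<(f64, f64)> {
        let n = spectrum.len();
        (0..n)
            .map(|t| {
                spectrum.iter().enumerate().fold((0.0, 0.0), |acc, (k, &(re, im))| {
                    // reduce k * t modulo n to keep the angle small
                    let angle = 2.0 * std::f64::consts::PI * ((k * t) % n) as f64 / n as f64;
                    let (c, s) = (angle.cos(), angle.sin());
                    (acc.0 + re * c - im * s, acc.1 + re * s + im * c)
                })
            })
            .collect()
    }

    #[test]
    fn hermitian_spectrum_gives_real_signal() {
        for n in 1..40 {
            let mut half: Vec<(f64, f64)> = (0..n / 2 + 1)
                .map(|k| (k as f64 + 0.5, 1.0 - k as f64))
                .collect();
            half[0].1 = 0.0;
            if n % 2 == 0 {
                half[n / 2].1 = 0.0;
            }
            let signal = naive_idft(&hermitian_extend(&half, n));
            assert!(signal.iter().all(|v| v.1.abs() < 1e-9), "Length: {}", n);
        }
    }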

    // Compare RealToComplex with standard FFT
    #[test]
    fn real_to_complex_64() {
        for length in 1..1000 {
            let mut real_planner = RealFftPlanner::<f64>::new();
            let r2c = real_planner.plan_fft_forward(length);
            let mut out_a = r2c.make_output_vec();
            let mut indata = r2c.make_input_vec();
            let mut rng = rand::thread_rng();
            for val in indata.iter_mut() {
                *val = rng.gen::<f64>();
            }
            let mut rustfft_check = indata
                .iter()
                .map(Complex::from)
                .collect::<Vec<Complex<f64>>>();
            let mut fft_planner = FftPlanner::<f64>::new();
            let fft = fft_planner.plan_fft_forward(length);
            fft.process(&mut rustfft_check);
            r2c.process(&mut indata, &mut out_a).unwrap();
            assert_eq!(out_a[0].im, 0.0, "First imaginary component must be zero");
            if length % 2 == 0 {
                assert_eq!(
                    out_a.last().unwrap().im,
                    0.0,
                    "Last imaginary component for even lengths must be zero"
                );
            }
            let maxdiff = compare_complex(&out_a, &rustfft_check[0..r2c.complex_len()]);
            assert!(
                maxdiff < 1.0e-9,
                "Length: {}, too large error: {}",
                length,
                maxdiff
            );
        }
    }

    // Compare RealToComplex with standard FFT
    #[test]
    fn real_to_complex_32() {
        for length in 1..1000 {
            let mut real_planner = RealFftPlanner::<f32>::new();
            let r2c = real_planner.plan_fft_forward(length);
            let mut out_a = r2c.make_output_vec();
            let mut indata = r2c.make_input_vec();
            let mut rng = rand::thread_rng();
            for val in indata.iter_mut() {
                *val = rng.gen::<f32>();
            }
            let mut rustfft_check = indata
                .iter()
                .map(Complex::from)
                .collect::<Vec<Complex<f32>>>();
            let mut fft_planner = FftPlanner::<f32>::new();
            let fft = fft_planner.plan_fft_forward(length);
            fft.process(&mut rustfft_check);
            r2c.process(&mut indata, &mut out_a).unwrap();
            assert_eq!(out_a[0].im, 0.0, "First imaginary component must be zero");
            if length % 2 == 0 {
                assert_eq!(
                    out_a.last().unwrap().im,
                    0.0,
                    "Last imaginary component for even lengths must be zero"
                );
            }
            let maxdiff = compare_complex(&out_a, &rustfft_check[0..r2c.complex_len()]);
            assert!(
                maxdiff < 5.0e-4,
                "Length: {}, too large error: {}",
                length,
                maxdiff
            );
        }
    }
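
    // Sketch (illustrative, std-only): a naive O(n^2) forward DFT of real input that keeps only
    // the first n/2 + 1 bins, the same layout RealToComplex produces. It shows why the tests can
    // require out[0].im == 0 for every length and out[n/2].im == 0 for even lengths: the kernel
    // exp(-2*pi*i*k*t/n) is real (1, or (-1)^t) at k = 0 and k = n/2, so those bins are real up
    // to rounding. `naive_rdft` is a hypothetical reference helper, not crate API.
    fn naive_rdft(signal: &[f64]) -> Vec<(f64, f64)> {
        let n = signal.len();
        (0..n / 2 + 1)
            .map(|k| {
                signal.iter().enumerate().fold((0.0, 0.0), |acc, (t, &x)| {
                    // reduce k * t modulo n to keep the angle small
                    let angle = -2.0 * std::f64::consts::PI * ((k * t) % n) as f64 / n as f64;
                    (acc.0 + x * angle.cos(), acc.1 + x * angle.sin())
                })
            })
            .collect()
    }

    #[test]
    fn naive_rdft_edge_bins_are_real() {
        for n in 1..64 {
            let signal: Vec<f64> = (0..n).map(|t| ((t * 7 + 3) % 11) as f64 - 5.0).collect();
            let spectrum = naive_rdft(&signal);
            assert_eq!(spectrum.len(), n / 2 + 1);
            assert!(spectrum[0].1.abs() < 1e-9, "Length: {}", n);
            if n % 2 == 0 {
                assert!(spectrum[n / 2].1.abs() < 1e-9, "Length: {}", n);
            }
        }
    }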

    // Check that the ? operator works on the custom errors. No need to run, just needs to compile.
    #[allow(dead_code)]
    fn test_error() -> Result<(), Box<dyn Error>> {
        let mut real_planner = RealFftPlanner::<f64>::new();
        let r2c = real_planner.plan_fft_forward(100);
        let mut out_a = r2c.make_output_vec();
        let mut indata = r2c.make_input_vec();
        r2c.process(&mut indata, &mut out_a)?;
        Ok(())
    }
}