realfft 3.2.0

Real-to-complex FFT and complex-to-real iFFT for Rust
//! # RealFFT: Real-to-complex FFT and complex-to-real iFFT based on RustFFT
//!
//! This library is a wrapper for [RustFFT](https://crates.io/crates/rustfft) that enables performing FFT of real-valued data.
//! The API is designed to be as similar as possible to RustFFT.
//!
//! Using this library instead of RustFFT directly avoids the need to convert real-valued data to complex before performing an FFT.
//! If the length is even, it also enables faster computations by using a complex FFT of half the length.
//! In that case, the library packs a 2N long real vector into an N long complex vector, which is transformed using a standard FFT.
//! The FFT result is then post-processed to give only the first half of the complex spectrum, as an N+1 long complex vector.
//!
//! The iFFT goes through the same steps backwards, to transform an N+1 long complex spectrum to a 2N long real result.
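//!
//! As a rough sketch of the packing step for a 6 element input (so N = 3), consecutive pairs of real values
//! are treated as the real and imaginary parts of one complex value:
//! ```text
//! x      = [x0r, x1r, x2r, x3r, x4r, x5r]
//! packed = [(x0r, x1r), (x2r, x3r), (x4r, x5r)]
//! ```
//! A complex FFT of length 3 is then run on `packed`, and the post-processing step turns its output into the
//! 4 element (N+1) spectrum.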
//!
//! The speed increase compared to just converting the input to a 2N long complex vector
//! and using a 2N long FFT depends on the length of the input data.
//! The largest improvements are for long FFTs: for lengths over around 1000 elements, the speedup is about a factor of 2.
//! The difference shrinks for shorter lengths, and at around 30 elements there is no longer any difference.
//!
//! ## Why use real-to-complex FFT?
//! ### Using a complex-to-complex FFT
//! A simple way to get the FFT of a real-valued vector is to convert it to complex and use a complex-to-complex FFT.
//!
//! Let's assume `x` is a 6 element long real vector:
//! ```text
//! x = [x0r, x1r, x2r, x3r, x4r, x5r]
//! ```
//!
//! We now convert `x` to complex by adding an imaginary part with value zero. Using the notation `(xNr, xNi)` for the complex value `xN`, this becomes:
//! ```text
//! x_c = [(x0r, 0), (x1r, 0), (x2r, 0), (x3r, 0), (x4r, 0), (x5r, 0)]
//! ```
//!
//! Performing a normal complex FFT, the result of `FFT(x_c)` is:
//! ```text
//! FFT(x_c) = [(X0r, X0i), (X1r, X1i), (X2r, X2i), (X3r, X3i), (X4r, X4i), (X5r, X5i)]
//! ```
//!
//! But because our `x_c` is real-valued (all imaginary parts are zero), some of this becomes redundant:
//! ```text
//! FFT(x_c) = [(X0r, 0), (X1r, X1i), (X2r, X2i), (X3r, 0), (X2r, -X2i), (X1r, -X1i)]
//! ```
//!
//! The last two values are the complex conjugates of `X1` and `X2`. Additionally, `X0i` and `X3i` are zero.
//! As we can see, the output contains 6 independent values, and the rest is redundant.
//! But it still takes time for the FFT to calculate the redundant values.
//! Converting the input data to complex also takes a little bit of time.
//!
//! If the length of `x` had instead been 7, the result would have been:
//! ```text
//! FFT(x_c) = [(X0r, 0), (X1r, X1i), (X2r, X2i), (X3r, X3i), (X3r, -X3i), (X2r, -X2i), (X1r, -X1i)]
//! ```
//!
//! The result is similar, but this time there is no zero at `X3i`. Also in this case we got the same number of independent values as we started with.
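//!
//! The symmetry described above can be checked numerically with a naive DFT. This is only a sketch using the
//! standard library: `naive_dft` is a hypothetical helper written for this illustration, and it is not part of
//! RealFFT or RustFFT. Complex values are represented as `(re, im)` tuples, and the input values are arbitrary.
//! ```rust
//! // Naive O(n^2) DFT of a real-valued slice, returning (re, im) pairs.
//! // Hypothetical helper for illustration only.
//! fn naive_dft(x: &[f64]) -> Vec<(f64, f64)> {
//!     let n = x.len();
//!     (0..n)
//!         .map(|k| {
//!             let mut re = 0.0;
//!             let mut im = 0.0;
//!             for (j, &v) in x.iter().enumerate() {
//!                 let angle = -2.0 * std::f64::consts::PI * (k * j) as f64 / n as f64;
//!                 re += v * angle.cos();
//!                 im += v * angle.sin();
//!             }
//!             (re, im)
//!         })
//!         .collect()
//! }
//!
//! fn main() {
//!     let x = [1.0, 2.0, 0.5, -1.0, 3.0, 0.0];
//!     let n = x.len();
//!     let spectrum = naive_dft(&x);
//!
//!     // X0 and X3 (the bin at n/2) are purely real, up to rounding.
//!     assert!(spectrum[0].1.abs() < 1e-9);
//!     assert!(spectrum[n / 2].1.abs() < 1e-9);
//!
//!     // X4 and X5 are the complex conjugates of X2 and X1.
//!     for k in 1..n / 2 {
//!         assert!((spectrum[n - k].0 - spectrum[k].0).abs() < 1e-9);
//!         assert!((spectrum[n - k].1 + spectrum[k].1).abs() < 1e-9);
//!     }
//! }
//! ```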
//!
//! ### Real-to-complex
//! Using a real-to-complex FFT removes the need for converting the input data to complex.
//! It also avoids calculating the redundant output values.
//!
//! The result for 6 elements is:
//! ```text
//! RealFFT(x) = [(X0r, 0), (X1r, X1i), (X2r, X2i), (X3r, 0)]
//! ```
//!
//! The result for 7 elements is:
//! ```text
//! RealFFT(x) = [(X0r, 0), (X1r, X1i), (X2r, X2i), (X3r, X3i)]
//! ```
//!
//! This is the data layout output by the real-to-complex FFT, and the one expected as input to the complex-to-real iFFT.
//!
//! ## Scaling
//! RealFFT matches the behaviour of RustFFT and does not normalize the output of either the FFT or the iFFT. To get normalized results, each element must be scaled by `1/sqrt(length)`. If the processing involves both an FFT and an iFFT step, it is advisable to merge the two normalization steps into a single step, scaling by `1/length`.
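//!
//! As a small worked illustration (the input values are chosen arbitrarily), an unnormalized FFT followed by an
//! unnormalized iFFT returns the input multiplied by the length, here 4:
//! ```text
//! x                     = [1, 2, 3, 4]
//! iFFT(FFT(x))          = [4, 8, 12, 16]
//! iFFT(FFT(x)) / length = [1, 2, 3, 4]
//! ```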
//!
//! ## Documentation
//!
//! The full documentation can be generated by rustdoc. To generate and view it run:
//! ```text
//! cargo doc --open
//! ```
//!
//! ## Benchmarks
//!
//! To run a set of benchmarks comparing real-to-complex FFT with standard complex-to-complex, type:
//! ```text
//! cargo bench
//! ```
//! The results are printed while running, and are compiled into an HTML report containing much more detail.
//! To view, open `target/criterion/report/index.html` in a browser.
//!
//! ## Example
//! Transform a vector, and then inverse transform the result.
//! ```
//! use realfft::RealFftPlanner;
//! use rustfft::num_complex::Complex;
//! use rustfft::num_traits::Zero;
//!
//! let length = 256;
//!
//! // make a planner
//! let mut real_planner = RealFftPlanner::<f64>::new();
//!
//! // create an FFT
//! let r2c = real_planner.plan_fft_forward(length);
//! // make input and output vectors
//! let mut indata = r2c.make_input_vec();
//! let mut spectrum = r2c.make_output_vec();
//!
//! // Are they the length we expect?
//! assert_eq!(indata.len(), length);
//! assert_eq!(spectrum.len(), length/2+1);
//!
//! // Forward transform the input data
//! r2c.process(&mut indata, &mut spectrum).unwrap();
//!
//! // create an iFFT and an output vector
//! let c2r = real_planner.plan_fft_inverse(length);
//! let mut outdata = c2r.make_output_vec();
//! assert_eq!(outdata.len(), length);
//!
//! c2r.process(&mut spectrum, &mut outdata).unwrap();
//! ```
//!
//! ### Versions
//! - 3.2.0: Allow scratch buffer to be larger than needed.
//! - 3.1.0: Update to RustFFT 6.1 with Neon support.
//! - 3.0.2: Fix confusing typos in errors about scratch length.
//! - 3.0.1: More helpful error messages, fix confusing typos.
//! - 3.0.0: Improved error reporting.
//! - 2.0.1: Minor bugfix.
//! - 2.0.0: Update RustFFT to 6.0.0 and num-complex to 0.4.0.
//! - 1.1.0: Add missing Sync+Send.
//! - 1.0.0: First version with new api.
//!
//! ### Compatibility
//!
//! The `realfft` crate has the same rustc version requirements as RustFFT.
//! The minimum rustc version is 1.37 on all platforms except AArch64.
//! On AArch64 the minimum rustc version is 1.61.

pub use rustfft::num_complex;
pub use rustfft::num_traits;
pub use rustfft::FftNum;

use rustfft::num_complex::Complex;
use rustfft::num_traits::Zero;
use rustfft::FftPlanner;
use std::collections::HashMap;
use std::error;
use std::fmt;
use std::sync::Arc;

type Res<T> = Result<T, FftError>;

/// Custom error returned by FFTs
pub enum FftError {
    /// The input buffer has the wrong size. The transform was not performed.
    ///
    /// The first member of the tuple is the expected size and the second member is the received
    /// size.
    InputBuffer(usize, usize),
    /// The output buffer has the wrong size. The transform was not performed.
    ///
    /// The first member of the tuple is the expected size and the second member is the received
    /// size.
    OutputBuffer(usize, usize),
    /// The scratch buffer has the wrong size. The transform was not performed.
    ///
    /// The first member of the tuple is the minimum size and the second member is the received
    /// size.
    ScratchBuffer(usize, usize),
    /// The input data contained a non-zero imaginary part where there should have been a zero.
    /// The transform was performed, but the result may not be correct.
    ///
    /// The first member of the tuple represents the first index of the complex buffer and the
    /// second member represents the last index of the complex buffer. The values are set to true
    /// if the corresponding complex value contains a non-zero imaginary part.
    InputValues(bool, bool),
}

impl FftError {
    fn fmt_internal(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let desc = match self {
            Self::InputBuffer(expected, got) => {
                format!("Wrong length of input, expected {}, got {}", expected, got)
            }
            Self::OutputBuffer(expected, got) => {
                format!("Wrong length of output, expected {}, got {}", expected, got)
            }
            Self::ScratchBuffer(expected, got) => {
                format!(
                    "Scratch buffer of size {} is too small, must be at least {} long",
                    got, expected
                )
            }
            Self::InputValues(first, last) => match (first, last) {
                (true, false) => "Imaginary part of first value was non-zero.".to_string(),
                (false, true) => "Imaginary part of last value was non-zero.".to_string(),
                (true, true) => {
                    "Imaginary parts of both first and last values were non-zero.".to_string()
                }
                (false, false) => unreachable!(),
            },
        };
        write!(f, "{}", desc)
    }
}

impl fmt::Debug for FftError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.fmt_internal(f)
    }
}

impl fmt::Display for FftError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.fmt_internal(f)
    }
}

impl error::Error for FftError {}

fn compute_twiddle<T: FftNum>(index: usize, fft_len: usize) -> Complex<T> {
    let constant = -2f64 * std::f64::consts::PI / fft_len as f64;
    let angle = constant * index as f64;
    Complex {
        re: T::from_f64(angle.cos()).unwrap(),
        im: T::from_f64(angle.sin()).unwrap(),
    }
}

pub struct RealToComplexOdd<T> {
    length: usize,
    fft: std::sync::Arc<dyn rustfft::Fft<T>>,
    scratch_len: usize,
}

pub struct RealToComplexEven<T> {
    twiddles: Vec<Complex<T>>,
    length: usize,
    fft: std::sync::Arc<dyn rustfft::Fft<T>>,
    scratch_len: usize,
}

pub struct ComplexToRealOdd<T> {
    length: usize,
    fft: std::sync::Arc<dyn rustfft::Fft<T>>,
    scratch_len: usize,
}

pub struct ComplexToRealEven<T> {
    twiddles: Vec<Complex<T>>,
    length: usize,
    fft: std::sync::Arc<dyn rustfft::Fft<T>>,
    scratch_len: usize,
}

/// An FFT that takes a real-valued input vector of length N and transforms it to a complex
/// spectrum of length N/2+1 (with N/2 rounded down).
#[allow(clippy::len_without_is_empty)]
pub trait RealToComplex<T>: Sync + Send {
    /// Transform a vector of N real-valued samples, storing the result in the N/2+1 (with N/2 rounded down) element long complex output vector.
    /// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling.
    /// It also allocates additional scratch space as needed.
    /// An error is returned if any of the given slices has the wrong length.
    fn process(&self, input: &mut [T], output: &mut [Complex<T>]) -> Res<()>;

    /// Transform a vector of N real-valued samples, storing the result in the N/2+1 (with N/2 rounded down) element long complex output vector.
    /// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling.
    /// It also uses the provided scratch vector instead of allocating, which will be faster if it is called more than once.
    /// An error is returned if any of the given slices has the wrong length.
    fn process_with_scratch(
        &self,
        input: &mut [T],
        output: &mut [Complex<T>],
        scratch: &mut [Complex<T>],
    ) -> Res<()>;

    /// Get the minimum length of the scratch space needed for `process_with_scratch`.
    fn get_scratch_len(&self) -> usize;

    /// Get the number of points that this FFT can process.
    fn len(&self) -> usize;

    /// Convenience method to make an input vector of the right type and length.
    fn make_input_vec(&self) -> Vec<T>;

    /// Convenience method to make an output vector of the right type and length.
    fn make_output_vec(&self) -> Vec<Complex<T>>;

    /// Convenience method to make a scratch vector of the right type and length.
    fn make_scratch_vec(&self) -> Vec<Complex<T>>;
}

/// An inverse FFT that takes a complex spectrum of length N/2+1 (with N/2 rounded down) and
/// transforms it to a real-valued output vector of length N.
#[allow(clippy::len_without_is_empty)]
pub trait ComplexToReal<T>: Sync + Send {
    /// Transform a complex spectrum of N/2+1 (with N/2 rounded down) values and store the real result in the N long output.
    /// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling.
    /// It also allocates additional scratch space as needed.
    /// An error is returned if any of the given slices has the wrong length.
    /// If the input data is invalid, meaning that one of the positions that should contain a zero holds a different value,
    /// the transform is still performed. The function then returns an `FftError::InputValues` error to tell that the
    /// result may not be correct.
    fn process(&self, input: &mut [Complex<T>], output: &mut [T]) -> Res<()>;

    /// Transform a complex spectrum of N/2+1 (with N/2 rounded down) values and store the real result in the N long output.
    /// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling.
    /// It also uses the provided scratch vector instead of allocating, which will be faster if it is called more than once.
    /// An error is returned if any of the given slices has the wrong length.
    /// If the input data is invalid, meaning that one of the positions that should contain a zero holds a different value,
    /// the transform is still performed. The function then returns an `FftError::InputValues` error to tell that the
    /// result may not be correct.
    fn process_with_scratch(
        &self,
        input: &mut [Complex<T>],
        output: &mut [T],
        scratch: &mut [Complex<T>],
    ) -> Res<()>;

    /// Get the minimum length of the scratch space needed for `process_with_scratch`.
    fn get_scratch_len(&self) -> usize;

    /// Get the number of points that this FFT can process.
    fn len(&self) -> usize;

    /// Convenience method to make an input vector of the right type and length.
    fn make_input_vec(&self) -> Vec<Complex<T>>;

    /// Convenience method to make an output vector of the right type and length.
    fn make_output_vec(&self) -> Vec<T>;

    /// Convenience method to make a scratch vector of the right type and length.
    fn make_scratch_vec(&self) -> Vec<Complex<T>>;
}

fn zip3<A, B, C>(a: A, b: B, c: C) -> impl Iterator<Item = (A::Item, B::Item, C::Item)>
where
    A: IntoIterator,
    B: IntoIterator,
    C: IntoIterator,
{
    a.into_iter()
        .zip(b.into_iter().zip(c))
        .map(|(x, (y, z))| (x, y, z))
}

/// A planner is used to create FFTs. It caches results internally,
/// so when making more than one FFT it is advisable to reuse the same planner.
pub struct RealFftPlanner<T: FftNum> {
    planner: FftPlanner<T>,
    r2c_cache: HashMap<usize, Arc<dyn RealToComplex<T>>>,
    c2r_cache: HashMap<usize, Arc<dyn ComplexToReal<T>>>,
}

impl<T: FftNum> RealFftPlanner<T> {
    /// Create a new planner.
    pub fn new() -> Self {
        let planner = FftPlanner::<T>::new();
        Self {
            r2c_cache: HashMap::new(),
            c2r_cache: HashMap::new(),
            planner,
        }
    }

    /// Plan a Real-to-Complex forward FFT. Returns the FFT in a shared reference.
    /// If requesting a second FFT of the same length, this will return a new reference to the already existing one.
    pub fn plan_fft_forward(&mut self, len: usize) -> Arc<dyn RealToComplex<T>> {
        if let Some(fft) = self.r2c_cache.get(&len) {
            Arc::clone(fft)
        } else {
            let fft = if len % 2 > 0 {
                Arc::new(RealToComplexOdd::new(len, &mut self.planner)) as Arc<dyn RealToComplex<T>>
            } else {
                Arc::new(RealToComplexEven::new(len, &mut self.planner))
                    as Arc<dyn RealToComplex<T>>
            };
            self.r2c_cache.insert(len, Arc::clone(&fft));
            fft
        }
    }

    /// Plan a Complex-to-Real inverse FFT. Returns the FFT in a shared reference.
    /// If requesting a second FFT of the same length, this will return a new reference to the already existing one.
    pub fn plan_fft_inverse(&mut self, len: usize) -> Arc<dyn ComplexToReal<T>> {
        if let Some(fft) = self.c2r_cache.get(&len) {
            Arc::clone(fft)
        } else {
            let fft = if len % 2 > 0 {
                Arc::new(ComplexToRealOdd::new(len, &mut self.planner)) as Arc<dyn ComplexToReal<T>>
            } else {
                Arc::new(ComplexToRealEven::new(len, &mut self.planner))
                    as Arc<dyn ComplexToReal<T>>
            };
            self.c2r_cache.insert(len, Arc::clone(&fft));
            fft
        }
    }
}

impl<T: FftNum> Default for RealFftPlanner<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T: FftNum> RealToComplexOdd<T> {
    /// Create a new RealToComplex FFT for input data of a given length, using the given FftPlanner to build the inner FFT.
    /// Panics if the length is not odd.
    pub fn new(length: usize, fft_planner: &mut FftPlanner<T>) -> Self {
        if length % 2 == 0 {
            panic!("Length must be odd, got {}", length,);
        }
        let fft = fft_planner.plan_fft_forward(length);
        let scratch_len = fft.get_inplace_scratch_len() + length;
        RealToComplexOdd {
            length,
            fft,
            scratch_len,
        }
    }
}

impl<T: FftNum> RealToComplex<T> for RealToComplexOdd<T> {
    /// Transform a vector of N real-valued samples, storing the result in the N/2+1 (with N/2 rounded down) element long complex output vector.
    /// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling.
    /// It also allocates additional scratch space as needed.
    /// An error is returned if any of the given slices has the wrong length.
    fn process(&self, input: &mut [T], output: &mut [Complex<T>]) -> Res<()> {
        let mut scratch = self.make_scratch_vec();
        self.process_with_scratch(input, output, &mut scratch)
    }

    /// Transform a vector of N real-valued samples, storing the result in the N/2+1 (with N/2 rounded down) element long complex output vector.
    /// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling.
    /// It also uses the provided scratch vector instead of allocating, which will be faster if it is called more than once.
    /// An error is returned if any of the given slices has the wrong length.
    fn process_with_scratch(
        &self,
        input: &mut [T],
        output: &mut [Complex<T>],
        scratch: &mut [Complex<T>],
    ) -> Res<()> {
        if input.len() != self.length {
            return Err(FftError::InputBuffer(self.length, input.len()));
        }
        let expected_output_buffer_size = self.length / 2 + 1;
        if output.len() != expected_output_buffer_size {
            return Err(FftError::OutputBuffer(
                expected_output_buffer_size,
                output.len(),
            ));
        }
        if scratch.len() < (self.scratch_len) {
            return Err(FftError::ScratchBuffer(self.scratch_len, scratch.len()));
        }
        let (buffer, fft_scratch) = scratch.split_at_mut(self.length);

        for (val, buf) in input.iter().zip(buffer.iter_mut()) {
            *buf = Complex::new(*val, T::zero());
        }
        // In-place FFT of `buffer`, then copy the first N/2+1 bins to `output`
        self.fft.process_with_scratch(buffer, fft_scratch);
        output.copy_from_slice(&buffer[0..self.length / 2 + 1]);
        Ok(())
    }

    fn get_scratch_len(&self) -> usize {
        self.scratch_len
    }

    fn len(&self) -> usize {
        self.length
    }

    fn make_input_vec(&self) -> Vec<T> {
        vec![T::zero(); self.len()]
    }

    fn make_output_vec(&self) -> Vec<Complex<T>> {
        vec![Complex::zero(); self.len() / 2 + 1]
    }

    fn make_scratch_vec(&self) -> Vec<Complex<T>> {
        vec![Complex::zero(); self.get_scratch_len()]
    }
}

impl<T: FftNum> RealToComplexEven<T> {
    /// Create a new RealToComplex FFT for input data of a given length, using the given FftPlanner to build the inner FFT.
    /// Panics if the length is not even.
    pub fn new(length: usize, fft_planner: &mut FftPlanner<T>) -> Self {
        if length % 2 > 0 {
            panic!("Length must be even, got {}", length,);
        }
        let twiddle_count = if length % 4 == 0 {
            length / 4
        } else {
            length / 4 + 1
        };
        let twiddles: Vec<Complex<T>> = (1..twiddle_count)
            .map(|i| compute_twiddle(i, length) * T::from_f64(0.5).unwrap())
            .collect();
        //let mut fft_planner = FftPlanner::<T>::new();
        let fft = fft_planner.plan_fft_forward(length / 2);
        let scratch_len = fft.get_outofplace_scratch_len();
        RealToComplexEven {
            twiddles,
            length,
            fft,
            scratch_len,
        }
    }
}

impl<T: FftNum> RealToComplex<T> for RealToComplexEven<T> {
    /// Transform a vector of N real-valued samples, storing the result in the N/2+1 element long complex output vector.
    /// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling.
    /// It also allocates additional scratch space as needed.
    /// An error is returned if any of the given slices has the wrong length.
    fn process(&self, input: &mut [T], output: &mut [Complex<T>]) -> Res<()> {
        let mut scratch = self.make_scratch_vec();
        self.process_with_scratch(input, output, &mut scratch)
    }

    /// Transform a vector of N real-valued samples, storing the result in the N/2+1 element long complex output vector.
    /// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling.
    /// It also uses the provided scratch vector instead of allocating, which will be faster if it is called more than once.
    /// An error is returned if any of the given slices has the wrong length.
    fn process_with_scratch(
        &self,
        input: &mut [T],
        output: &mut [Complex<T>],
        scratch: &mut [Complex<T>],
    ) -> Res<()> {
        if input.len() != self.length {
            return Err(FftError::InputBuffer(self.length, input.len()));
        }
        let expected_output_buffer_size = self.length / 2 + 1;
        if output.len() != expected_output_buffer_size {
            return Err(FftError::OutputBuffer(
                expected_output_buffer_size,
                output.len(),
            ));
        }
        if scratch.len() < (self.scratch_len) {
            return Err(FftError::ScratchBuffer(self.scratch_len, scratch.len()));
        }

        let fftlen = self.length / 2;
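        // SAFETY: `Complex<T>` is `#[repr(C)]` with two `T` fields, so it has the same layout and
        // alignment as `[T; 2]`. The length check above guarantees `input.len()` equals the even
        // `self.length`, so viewing it as `len / 2` complex values stays within the slice.
        let buf_in = unsafe {
            let ptr = input.as_mut_ptr() as *mut Complex<T>;
            let len = input.len();
            std::slice::from_raw_parts_mut(ptr, len / 2)
        };

        // FFT the packed input and store the result in the first N/2 elements of `output`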
        self.fft
            .process_outofplace_with_scratch(buf_in, &mut output[0..fftlen], scratch);
        let (mut output_left, mut output_right) = output.split_at_mut(output.len() / 2);

        // The first and last element don't require any twiddle factors, so skip that work
        match (output_left.first_mut(), output_right.last_mut()) {
            (Some(first_element), Some(last_element)) => {
                // The first and last elements are just a sum and difference of the first value's real and imaginary values
                let first_value = *first_element;
                *first_element = Complex {
                    re: first_value.re + first_value.im,
                    im: T::zero(),
                };
                *last_element = Complex {
                    re: first_value.re - first_value.im,
                    im: T::zero(),
                };

                // Chop the first and last element off of our slices so that the loop below doesn't have to deal with them
                output_left = &mut output_left[1..];
                let right_len = output_right.len();
                output_right = &mut output_right[..right_len - 1];
            }
            _ => {
                return Ok(());
            }
        }
        // Loop over the remaining elements and apply twiddle factors on them
        for (twiddle, out, out_rev) in zip3(
            self.twiddles.iter(),
            output_left.iter_mut(),
            output_right.iter_mut().rev(),
        ) {
            let sum = *out + *out_rev;
            let diff = *out - *out_rev;
            let half = T::from_f64(0.5).unwrap();
            // Apply twiddle factors. Theoretically we'd have to load 2 separate twiddle factors here, one for the beginning
            // and one for the end. But the twiddle factor for the end is just the twiddle for the beginning, with the
            // real part negated. Since it's the same twiddle, we can factor out a ton of math ops and cut the number of
            // multiplications in half.
            let twiddled_re_sum = sum * twiddle.re;
            let twiddled_im_sum = sum * twiddle.im;
            let twiddled_re_diff = diff * twiddle.re;
            let twiddled_im_diff = diff * twiddle.im;
            let half_sum_re = half * sum.re;
            let half_diff_im = half * diff.im;

            let output_twiddled_real = twiddled_re_sum.im + twiddled_im_diff.re;
            let output_twiddled_im = twiddled_im_sum.im - twiddled_re_diff.re;

            // We finally have all the data we need to write the transformed data back out where we found it.
            *out = Complex {
                re: half_sum_re + output_twiddled_real,
                im: half_diff_im + output_twiddled_im,
            };

            *out_rev = Complex {
                re: half_sum_re - output_twiddled_real,
                im: output_twiddled_im - half_diff_im,
            };
        }

        // If the output len is odd, the loop above can't postprocess the centermost element, so handle that separately.
        if output.len() % 2 == 1 {
            if let Some(center_element) = output.get_mut(output.len() / 2) {
                center_element.im = -center_element.im;
            }
        }
        Ok(())
    }
    fn get_scratch_len(&self) -> usize {
        self.scratch_len
    }

    fn len(&self) -> usize {
        self.length
    }

    fn make_input_vec(&self) -> Vec<T> {
        vec![T::zero(); self.len()]
    }

    fn make_output_vec(&self) -> Vec<Complex<T>> {
        vec![Complex::zero(); self.len() / 2 + 1]
    }

    fn make_scratch_vec(&self) -> Vec<Complex<T>> {
        vec![Complex::zero(); self.get_scratch_len()]
    }
}

impl<T: FftNum> ComplexToRealOdd<T> {
    /// Create a new ComplexToReal FFT for input data of a given length, using the given FftPlanner to build the inner FFT.
    /// Panics if the length is not odd.
    pub fn new(length: usize, fft_planner: &mut FftPlanner<T>) -> Self {
        if length % 2 == 0 {
            panic!("Length must be odd, got {}", length,);
        }
        //let mut fft_planner = FftPlanner::<T>::new();
        let fft = fft_planner.plan_fft_inverse(length);
        let scratch_len = length + fft.get_inplace_scratch_len();
        ComplexToRealOdd {
            length,
            fft,
            scratch_len,
        }
    }
}

impl<T: FftNum> ComplexToReal<T> for ComplexToRealOdd<T> {
    /// Transform a complex spectrum of N/2+1 (with N/2 rounded down) values and store the real result in the N long output.
    /// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling.
    /// It also allocates additional scratch space as needed.
    /// An error is returned if any of the given slices has the wrong length.
    /// If the input data is invalid, meaning that one of the positions that should contain a zero holds a different value,
    /// these non-zero values are ignored and the transform is still performed.
    /// The function then returns an `FftError::InputValues` error to tell that the result may not be correct.
    fn process(&self, input: &mut [Complex<T>], output: &mut [T]) -> Res<()> {
        let mut scratch = self.make_scratch_vec();
        self.process_with_scratch(input, output, &mut scratch)
    }

    /// Transform a complex spectrum of N/2+1 (with N/2 rounded down) values and store the real result in the N long output.
    /// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling.
    /// It also uses the provided scratch vector instead of allocating, which will be faster if it is called more than once.
    /// An error is returned if any of the given slices has the wrong length.
    /// If the input data is invalid, meaning that one of the positions that should contain a zero holds a different value,
    /// these non-zero values are ignored and the transform is still performed.
    /// The function then returns an `FftError::InputValues` error to tell that the result may not be correct.
    fn process_with_scratch(
        &self,
        input: &mut [Complex<T>],
        output: &mut [T],
        scratch: &mut [Complex<T>],
    ) -> Res<()> {
        let expected_input_buffer_size = self.length / 2 + 1;
        if input.len() != expected_input_buffer_size {
            return Err(FftError::InputBuffer(
                expected_input_buffer_size,
                input.len(),
            ));
        }
        if output.len() != self.length {
            return Err(FftError::OutputBuffer(self.length, output.len()));
        }
        if scratch.len() < (self.scratch_len) {
            return Err(FftError::ScratchBuffer(self.scratch_len, scratch.len()));
        }

        let first_invalid = if input[0].im != T::zero() {
            input[0].im = T::zero();
            true
        } else {
            false
        };

        let (buffer, fft_scratch) = scratch.split_at_mut(self.length);

        buffer[0..input.len()].copy_from_slice(input);
        for (buf, val) in buffer
            .iter_mut()
            .rev()
            .take(self.length / 2)
            .zip(input.iter().skip(1))
        {
            *buf = val.conj();
            //buf.im = -val.im;
        }
        self.fft.process_with_scratch(buffer, fft_scratch);
        for (val, out) in buffer.iter().zip(output.iter_mut()) {
            *out = val.re;
        }
        if first_invalid {
            return Err(FftError::InputValues(true, false));
        }
        Ok(())
    }

    fn get_scratch_len(&self) -> usize {
        self.scratch_len
    }

    fn len(&self) -> usize {
        self.length
    }

    fn make_input_vec(&self) -> Vec<Complex<T>> {
        vec![Complex::zero(); self.len() / 2 + 1]
    }

    fn make_output_vec(&self) -> Vec<T> {
        vec![T::zero(); self.len()]
    }

    fn make_scratch_vec(&self) -> Vec<Complex<T>> {
        vec![Complex::zero(); self.get_scratch_len()]
    }
}

impl<T: FftNum> ComplexToRealEven<T> {
    /// Create a new ComplexToReal FFT for input data of a given length, using the given FftPlanner to build the inner FFT.
    /// Panics if the length is not even.
    pub fn new(length: usize, fft_planner: &mut FftPlanner<T>) -> Self {
        if length % 2 > 0 {
            panic!("Length must be even, got {}", length,);
        }
        let twiddle_count = if length % 4 == 0 {
            length / 4
        } else {
            length / 4 + 1
        };
        let twiddles: Vec<Complex<T>> = (1..twiddle_count)
            .map(|i| compute_twiddle(i, length).conj())
            .collect();
        //let mut fft_planner = FftPlanner::<T>::new();
        let fft = fft_planner.plan_fft_inverse(length / 2);
        let scratch_len = fft.get_outofplace_scratch_len();
        ComplexToRealEven {
            twiddles,
            length,
            fft,
            scratch_len,
        }
    }
}
impl<T: FftNum> ComplexToReal<T> for ComplexToRealEven<T> {
    /// Transform a complex spectrum of N/2+1 values and store the real result in the N long output.
    /// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling.
    /// It also allocates additional scratch space as needed.
    /// An error is returned if any of the given slices has the wrong length.
    /// The input data is invalid if the imaginary part of the first or the last element is non-zero.
    /// In that case the non-zero values are ignored (treated as zero) and the transform is still performed.
    /// The function then returns an `FftError::InputValues` error to signal that the result may not be correct.
    fn process(&self, input: &mut [Complex<T>], output: &mut [T]) -> Res<()> {
        let mut scratch = self.make_scratch_vec();
        self.process_with_scratch(input, output, &mut scratch)
    }

    /// Transform a complex spectrum of N/2+1 values and store the real result in the N long output.
    /// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling.
    /// It also uses the provided scratch vector instead of allocating, which will be faster if it is called more than once.
    /// An error is returned if any of the given slices has the wrong length.
    /// The input data is invalid if the imaginary part of the first or the last element is non-zero.
    /// In that case the non-zero values are ignored (treated as zero) and the transform is still performed.
    /// The function then returns an `FftError::InputValues` error to signal that the result may not be correct.
    fn process_with_scratch(
        &self,
        input: &mut [Complex<T>],
        output: &mut [T],
        scratch: &mut [Complex<T>],
    ) -> Res<()> {
        let expected_input_buffer_size = self.length / 2 + 1;
        if input.len() != expected_input_buffer_size {
            return Err(FftError::InputBuffer(
                expected_input_buffer_size,
                input.len(),
            ));
        }
        if output.len() != self.length {
            return Err(FftError::OutputBuffer(self.length, output.len()));
        }
        if scratch.len() < self.scratch_len {
            return Err(FftError::ScratchBuffer(self.scratch_len, scratch.len()));
        }
        if input.is_empty() {
            return Ok(());
        }
        let first_invalid = if input[0].im != T::zero() {
            input[0].im = T::zero();
            true
        } else {
            false
        };
        let last_invalid = if input[input.len() - 1].im != T::zero() {
            input[input.len() - 1].im = T::zero();
            true
        } else {
            false
        };

        let (mut input_left, mut input_right) = input.split_at_mut(input.len() / 2);

        // We have to preprocess the input in-place before we send it to the FFT.
        // The first and last values have to be preprocessed separately from the rest, so do that now.
        match (input_left.first_mut(), input_right.last_mut()) {
            (Some(first_input), Some(last_input)) => {
                let first_sum = *first_input + *last_input;
                let first_diff = *first_input - *last_input;

                *first_input = Complex {
                    re: first_sum.re - first_sum.im,
                    im: first_diff.re - first_diff.im,
                };

                input_left = &mut input_left[1..];
                let right_len = input_right.len();
                input_right = &mut input_right[..right_len - 1];
            }
            _ => return Ok(()),
        };

        // Now, in a loop, preprocess the rest of the elements two at a time.
        for (twiddle, fft_input, fft_input_rev) in zip3(
            self.twiddles.iter(),
            input_left.iter_mut(),
            input_right.iter_mut().rev(),
        ) {
            let sum = *fft_input + *fft_input_rev;
            let diff = *fft_input - *fft_input_rev;

            // Apply twiddle factors. Theoretically we'd have to load 2 separate twiddle factors here, one for the beginning
            // and one for the end. But the twiddle factor for the end is just the twiddle for the beginning, with the
            // real part negated. Since it's the same twiddle, we can factor out a ton of math ops and cut the number of
            // multiplications in half.
            let twiddled_re_sum = sum * twiddle.re;
            let twiddled_im_sum = sum * twiddle.im;
            let twiddled_re_diff = diff * twiddle.re;
            let twiddled_im_diff = diff * twiddle.im;

            let output_twiddled_real = twiddled_re_sum.im + twiddled_im_diff.re;
            let output_twiddled_im = twiddled_im_sum.im - twiddled_re_diff.re;

            // We finally have all the data we need to write our preprocessed data back where we got it from.
            *fft_input = Complex {
                re: sum.re - output_twiddled_real,
                im: diff.im - output_twiddled_im,
            };
            *fft_input_rev = Complex {
                re: sum.re + output_twiddled_real,
                im: -output_twiddled_im - diff.im,
            }
        }

        // If the input length is odd, the loop above can't preprocess the centermost element, so handle that separately.
        if input.len() % 2 == 1 {
            let center_element = input[input.len() / 2];
            let doubled = center_element + center_element;
            input[input.len() / 2] = doubled.conj();
        }

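        // Run the inner FFT and store the result in buf_out, which views the real output as complex pairs.
        // SAFETY: Complex<T> is #[repr(C)] with two T fields, and output.len() equals self.length, which is even,
        // so the N real values can be reinterpreted as N/2 complex values.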
        let buf_out = unsafe {
            let ptr = output.as_mut_ptr() as *mut Complex<T>;
            let len = output.len();
            std::slice::from_raw_parts_mut(ptr, len / 2)
        };
        self.fft
            .process_outofplace_with_scratch(&mut input[..output.len() / 2], buf_out, scratch);
        if first_invalid || last_invalid {
            return Err(FftError::InputValues(first_invalid, last_invalid));
        }
        Ok(())
    }

    fn get_scratch_len(&self) -> usize {
        self.scratch_len
    }

    fn len(&self) -> usize {
        self.length
    }

    fn make_input_vec(&self) -> Vec<Complex<T>> {
        vec![Complex::zero(); self.len() / 2 + 1]
    }

    fn make_output_vec(&self) -> Vec<T> {
        vec![T::zero(); self.len()]
    }

    fn make_scratch_vec(&self) -> Vec<Complex<T>> {
        vec![Complex::zero(); self.get_scratch_len()]
    }
}

#[cfg(test)]
mod tests {
    use crate::FftError;
    use crate::RealFftPlanner;
    use rand::Rng;
    use rustfft::num_complex::Complex;
    use rustfft::num_traits::Zero;
    use rustfft::FftPlanner;
    use std::error::Error;

    // Return the largest absolute difference between two complex slices, compared element by element.
    fn compare_complex(a: &[Complex<f64>], b: &[Complex<f64>]) -> f64 {
        a.iter().zip(b.iter()).fold(0.0, |maxdiff, (val_a, val_b)| {
            let diff = (val_a - val_b).norm();
            if maxdiff > diff {
                maxdiff
            } else {
                diff
            }
        })
    }

    // Return the largest absolute difference between two real slices, compared element by element.
    fn compare_f64(a: &[f64], b: &[f64]) -> f64 {
        a.iter().zip(b.iter()).fold(0.0, |maxdiff, (val_a, val_b)| {
            let diff = (val_a - val_b).abs();
            if maxdiff > diff {
                maxdiff
            } else {
                diff
            }
        })
    }

    // Compare ComplexToReal with standard iFFT
    #[test]
    fn complex_to_real() {
        for length in 1..1000 {
            let mut real_planner = RealFftPlanner::<f64>::new();
            let c2r = real_planner.plan_fft_inverse(length);
            let mut out_a = c2r.make_output_vec();
            let mut indata = c2r.make_input_vec();
            let mut rustfft_check: Vec<Complex<f64>> = vec![Complex::zero(); length];
            let mut rng = rand::thread_rng();
            for val in indata.iter_mut() {
                *val = Complex::new(rng.gen::<f64>(), rng.gen::<f64>());
            }
            indata[0].im = 0.0;
            if length % 2 == 0 {
                indata[length / 2].im = 0.0;
            }
            for (val_long, val) in rustfft_check
                .iter_mut()
                .take(length / 2 + 1)
                .zip(indata.iter())
            {
                *val_long = *val;
            }
            for (val_long, val) in rustfft_check
                .iter_mut()
                .rev()
                .take(length / 2)
                .zip(indata.iter().skip(1))
            {
                *val_long = val.conj();
            }
            let mut fft_planner = FftPlanner::<f64>::new();
            let fft = fft_planner.plan_fft_inverse(length);

            c2r.process(&mut indata, &mut out_a).unwrap();
            fft.process(&mut rustfft_check);

            let check_real = rustfft_check.iter().map(|val| val.re).collect::<Vec<f64>>();
            let maxdiff = compare_f64(&out_a, &check_real);
            assert!(
                maxdiff < 1.0e-9,
                "Length: {}, too large error: {}",
                length,
                maxdiff
            );
        }
    }

    // Test that ComplexToReal returns the right errors for an even length
    #[test]
    fn complex_to_real_errors_even() {
        let length = 100;
        let mut real_planner = RealFftPlanner::<f64>::new();
        let c2r = real_planner.plan_fft_inverse(length);
        let mut out_a = c2r.make_output_vec();
        let mut indata = c2r.make_input_vec();
        let mut rng = rand::thread_rng();

        // Make some valid data
        for val in indata.iter_mut() {
            *val = Complex::new(rng.gen::<f64>(), rng.gen::<f64>());
        }
        indata[0].im = 0.0;
        indata[50].im = 0.0;
        // this should be ok
        assert!(c2r.process(&mut indata, &mut out_a).is_ok());

        // Make some invalid data, first point invalid
        for val in indata.iter_mut() {
            *val = Complex::new(rng.gen::<f64>(), rng.gen::<f64>());
        }
        indata[50].im = 0.0;
        let res = c2r.process(&mut indata, &mut out_a);
        assert!(res.is_err());
        assert!(matches!(res, Err(FftError::InputValues(true, false))));

        // Make some invalid data, last point invalid
        for val in indata.iter_mut() {
            *val = Complex::new(rng.gen::<f64>(), rng.gen::<f64>());
        }
        indata[0].im = 0.0;
        let res = c2r.process(&mut indata, &mut out_a);
        assert!(res.is_err());
        assert!(matches!(res, Err(FftError::InputValues(false, true))));
    }

    // Test that ComplexToReal returns the right errors for an odd length
    #[test]
    fn complex_to_real_errors_odd() {
        let length = 101;
        let mut real_planner = RealFftPlanner::<f64>::new();
        let c2r = real_planner.plan_fft_inverse(length);
        let mut out_a = c2r.make_output_vec();
        let mut indata = c2r.make_input_vec();
        let mut rng = rand::thread_rng();

        // Make some valid data
        for val in indata.iter_mut() {
            *val = Complex::new(rng.gen::<f64>(), rng.gen::<f64>());
        }
        indata[0].im = 0.0;
        // this should be ok
        assert!(c2r.process(&mut indata, &mut out_a).is_ok());

        // Make some invalid data, first point invalid
        for val in indata.iter_mut() {
            *val = Complex::new(rng.gen::<f64>(), rng.gen::<f64>());
        }
        let res = c2r.process(&mut indata, &mut out_a);
        assert!(res.is_err());
        assert!(matches!(res, Err(FftError::InputValues(true, false))));
    }
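
    // A sketch of how wrong buffer lengths are reported, based on the length checks in
    // `process_with_scratch` above, where the error variants carry (expected, actual) lengths.
    // The length 100 and the test name are arbitrary choices made for illustration.
    #[test]
    fn complex_to_real_wrong_buffer_lengths() {
        let length = 100;
        let mut real_planner = RealFftPlanner::<f64>::new();
        let c2r = real_planner.plan_fft_inverse(length);

        // The input must hold length / 2 + 1 = 51 complex values, so 50 is one too few.
        let mut short_input: Vec<Complex<f64>> = vec![Complex::zero(); 50];
        let mut out_a = c2r.make_output_vec();
        let res = c2r.process(&mut short_input, &mut out_a);
        assert!(matches!(res, Err(FftError::InputBuffer(51, 50))));

        // The output must hold exactly `length` real values.
        let mut indata = c2r.make_input_vec();
        let mut short_output: Vec<f64> = vec![0.0; 99];
        let res = c2r.process(&mut indata, &mut short_output);
        assert!(matches!(res, Err(FftError::OutputBuffer(100, 99))));
    }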

    // Compare RealToComplex with standard FFT
    #[test]
    fn real_to_complex() {
        for length in 1..1000 {
            let mut real_planner = RealFftPlanner::<f64>::new();
            let r2c = real_planner.plan_fft_forward(length);
            let mut out_a = r2c.make_output_vec();
            let mut indata = r2c.make_input_vec();
            let mut rng = rand::thread_rng();
            for val in indata.iter_mut() {
                *val = rng.gen::<f64>();
            }
            let mut rustfft_check = indata
                .iter()
                .map(Complex::from)
                .collect::<Vec<Complex<f64>>>();
            let mut fft_planner = FftPlanner::<f64>::new();
            let fft = fft_planner.plan_fft_forward(length);

            fft.process(&mut rustfft_check);
            r2c.process(&mut indata, &mut out_a).unwrap();
            let maxdiff = compare_complex(&out_a, &rustfft_check[0..(length / 2 + 1)]);
            assert!(
                maxdiff < 1.0e-9,
                "Length: {}, too large error: {}",
                length,
                maxdiff
            );
        }
    }
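
    // A sketch of a round-trip check. Like RustFFT, the transforms are not normalized, so a forward
    // transform followed by an inverse one returns the input scaled by the length.
    // The deterministic test signal, the chosen lengths and the tolerance are arbitrary choices for illustration.
    #[test]
    fn roundtrip_scales_by_length() {
        for &length in [1usize, 2, 7, 64, 101, 256].iter() {
            let mut real_planner = RealFftPlanner::<f64>::new();
            let r2c = real_planner.plan_fft_forward(length);
            let c2r = real_planner.plan_fft_inverse(length);

            let original: Vec<f64> = (0..length).map(|i| (0.37 * i as f64).sin()).collect();
            // Work on a copy, since the input buffer may be used as scratch space.
            let mut indata = original.clone();
            let mut spectrum = r2c.make_output_vec();
            let mut outdata = c2r.make_output_vec();

            r2c.process(&mut indata, &mut spectrum).unwrap();
            // Clear the imaginary parts that must be zero, in case rounding left a tiny residue there.
            spectrum[0].im = 0.0;
            if length % 2 == 0 {
                spectrum[length / 2].im = 0.0;
            }
            c2r.process(&mut spectrum, &mut outdata).unwrap();

            // Compare against the original signal multiplied by the length.
            let scaled: Vec<f64> = original.iter().map(|v| v * length as f64).collect();
            let maxdiff = compare_f64(&outdata, &scaled);
            assert!(
                maxdiff < 1.0e-9,
                "Length: {}, too large error: {}",
                length,
                maxdiff
            );
        }
    }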

    // Check that the ? operator works on the custom errors. No need to run, just needs to compile.
    #[allow(dead_code)]
    fn test_error() -> Result<(), Box<dyn Error>> {
        let mut real_planner = RealFftPlanner::<f64>::new();
        let r2c = real_planner.plan_fft_forward(100);
        let mut out_a = r2c.make_output_vec();
        let mut indata = r2c.make_input_vec();
        r2c.process(&mut indata, &mut out_a)?;
        Ok(())
    }
}