ndrustfft/
lib.rs

1//! # ndrustfft: *n*-dimensional complex-to-complex FFT, real-to-complex FFT and real-to-real DCT
2//!
3//! This library is a wrapper for `RustFFT`, `RustDCT` and `RealFft`
4//! that enables performing FFTs and DCTs of complex- and real-valued
5//! data on *n*-dimensional arrays (ndarray).
6//!
//! ndrustfft provides handler structs for FFTs and DCTs, which must be passed
//! alongside the arrays to the respective functions (see below).
//! The handlers hold the transform plans and apply them lane by lane, wrapping
//! the `process` methods of `RustFFT`, `RealFft` and `RustDCT`.
//! Transforms along the last axis of a standard-layout (row-major) array are in
//! general the fastest, since its lanes are contiguous; transforms along other
//! axes (or on arrays that are not in standard layout) may copy each lane into
//! a temporary buffer.
13//!
14//! ## Parallel
15//! The library ships all functions with a parallel version
16//! which leverages the parallel iterators of the ndarray crate.
17//!
18//! ## Available transforms
19//! ### Complex-to-complex
20//! - `fft` : [`ndfft`],[`ndfft_inplace`],[`ndfft_par`],[`ndfft_inplace_par`]
//! - `ifft`: [`ndifft`],[`ndifft_inplace`],[`ndifft_par`],[`ndifft_inplace_par`]
22//! ### Real-to-complex
//! - `fft_r2c` : [`ndfft_r2c`], [`ndfft_r2c_par`]
24//! ### Complex-to-real
25//! - `ifft_r2c`: [`ndifft_r2c`],[`ndifft_r2c_par`]
26//! ### Real-to-real
27//! - `dct1`: [`nddct1`],[`nddct1_inplace`],[`nddct1_par`],[`nddct1_inplace_par`]
28//! - `dct2`: [`nddct2`],[`nddct2_inplace`],[`nddct2_par`],[`nddct2_inplace_par`]
29//! - `dct3`: [`nddct3`],[`nddct3_inplace`],[`nddct3_par`],[`nddct3_inplace_par`]
30//! - `dct4`: [`nddct4`],[`nddct4_inplace`],[`nddct4_par`],[`nddct4_inplace_par`]
31//!
32//! ## Example
33//! 2-Dimensional real-to-complex fft along first axis
34//! ```
35//! use ndarray::{Array2, Dim, Ix};
36//! use ndrustfft::{ndfft_r2c, Complex, R2cFftHandler};
37//!
38//! let (nx, ny) = (6, 4);
39//! let mut data = Array2::<f64>::zeros((nx, ny));
40//! let mut vhat = Array2::<Complex<f64>>::zeros((nx / 2 + 1, ny));
41//! for (i, v) in data.iter_mut().enumerate() {
42//!     *v = i as f64;
43//! }
44//! let mut fft_handler = R2cFftHandler::<f64>::new(nx);
45//! ndfft_r2c(
46//!     &data.view(),
47//!     &mut vhat.view_mut(),
48//!     &mut fft_handler,
49//!     0,
50//! );
51//! ```
52//!
53//! # Normalization
//! `RustFFT`, `RustDCT` and `RealFft` do not normalize,
//! whereas this library applies scipy-like normalization by default.
//! This means inverse FFTs are divided by the transform length *n*,
//! and DCTs are multiplied by two. To disable normalization or to apply a
//! custom one, pass a [`Normalization`] variant to the `normalization`
//! builder method of the respective handler.
60//!
61//! See: `examples/fft_norm`
62//!
63//! # Features
64//!
65//! - parallel: Enables parallel transform using `ndarrays` + `rayon` (enabled by default)
66//! - avx: Enables `rustfft`'s avx feature (enabled by default)
67//! - sse: Enables `rustfft`'s sse feature (enabled by default)
68//! - neon: Enables `rustfft`'s neon feature (enabled by default)
69//!
70//! # Documentation
71//! [docs.rs](https://docs.rs/ndrustfft/)
72//!
73//! # Versions
74//! [Changelog](CHANGELOG.md)
75#![warn(missing_docs)]
76extern crate ndarray;
77extern crate rustfft;
78use ndarray::{Array1, ArrayBase, Axis, Dimension, Zip};
79use ndarray::{Data, DataMut};
80use num_traits::FloatConst;
81use realfft::{ComplexToReal, RealFftPlanner, RealToComplex};
82use rustdct::{Dct1, DctPlanner, TransformType2And3, TransformType4};
83pub use rustfft::num_complex::Complex;
84pub use rustfft::num_traits::Zero;
85pub use rustfft::FftNum;
86use rustfft::{Fft, FftPlanner};
87use std::sync::Arc;
88
89/// Represents different types of normalization methods.
90#[derive(Clone)]
91pub enum Normalization<T> {
92    /// No normalization applied, output equals `rustfft`, `realfft` or `rustdct`.
93    None,
94    /// Applies normalization similar to scipy's default behavior.
95    Default,
96    /// Applies a custom normalization function provided as a closure.
97    Custom(fn(&mut [T])),
98}
99
100macro_rules! create_transform {
101    (
102        $(#[$meta:meta])* $i: ident, $a: ty, $b: ty, $h: ty, $p: ident
103    ) => {
104        $(#[$meta])*
105        pub fn $i<R, S, T, D>(
106            input: &ArrayBase<R, D>,
107            output: &mut ArrayBase<S, D>,
108            handler: &$h,
109            axis: usize,
110        ) where
111            T: FftNum + FloatConst,
112            R: Data<Elem = $a>,
113            S: Data<Elem = $b> + DataMut,
114            D: Dimension,
115        {
116            let n = output.shape()[axis];
117            if input.is_standard_layout() && output.is_standard_layout() {
118                let outer_axis = input.ndim() - 1;
119                if axis == outer_axis {
120                    Zip::from(input.rows())
121                        .and(output.rows_mut())
122                        .for_each(|x, mut y| {
123                            handler.$p(x.as_slice().unwrap(), y.as_slice_mut().unwrap());
124                        });
125                } else {
126                    let mut outvec = Array1::zeros(output.shape()[axis]);
127                    let mut input = input.view();
128                    input.swap_axes(outer_axis, axis);
129                    output.swap_axes(outer_axis, axis);
130                    Zip::from(input.rows())
131                        .and(output.rows_mut())
132                        .for_each(|x, mut y| {
133                            handler.$p(&x.to_vec(), outvec.as_slice_mut().unwrap());
134                            y.assign(&outvec);
135                        });
136                    output.swap_axes(outer_axis, axis);
137                }
138            } else {
139                Zip::from(input.lanes(Axis(axis)))
140                .and(output.lanes_mut(Axis(axis)))
141                .for_each(|x, mut y| {
142                    if let Some(x_s) = x.as_slice() {
143                        if let Some(y_s) = y.as_slice_mut() {
144                            // x and y are contiguous
145                            handler.$p(x_s, y_s);
146                        } else {
147                            let mut outvec = Array1::zeros(n);
148                            // x is contiguous, y is not contiguous
149                            handler.$p(x_s, outvec.as_slice_mut().unwrap());
150                            y.assign(&outvec);
151                        }
152                    } else {
153                        if let Some(y_s) = y.as_slice_mut() {
154                            // x is not contiguous, y is contiguous
155                            handler.$p(&x.to_vec(), y_s);
156                        } else {
157                            let mut outvec = Array1::zeros(n);
158                            // x and y are not contiguous
159                            handler.$p(&x.to_vec(), outvec.as_slice_mut().unwrap());
160                            y.assign(&outvec);
161                        }
162                    }
163                });
164            }
165        }
166    };
167}
168
169macro_rules! create_transform_inplace {
170    (
171        $(#[$meta:meta])* $i: ident, $a: ty, $h: ty, $p: ident
172    ) => {
173        $(#[$meta])*
174        pub fn $i<R, T, D>(
175            data: &mut ArrayBase<R, D>,
176            handler: &$h,
177            axis: usize,
178        ) where
179            T: FftNum + FloatConst,
180            R: DataMut<Elem = $a>,
181            D: Dimension,
182        {
183            if data.is_standard_layout() {
184                let outer_axis = data.ndim() - 1;
185                if axis == outer_axis {
186                    for mut row in data.rows_mut() {
187                        handler.$p(row.as_slice_mut().unwrap());
188                    }
189                } else {
190                    let mut temp = Array1::zeros(data.shape()[axis]);
191                    data.swap_axes(outer_axis, axis);
192                    for mut row in data.rows_mut() {
193                        temp.assign(&row);
194                        handler.$p(temp.as_slice_mut().unwrap());
195                        row.assign(&temp);
196                    }
197                    data.swap_axes(outer_axis, axis);
198                }
199            } else {
200                Zip::from(data.lanes_mut(Axis(axis)))
201                    .for_each(|mut lane| {
202                        if let Some(slice) = lane.as_slice_mut() {
203                            handler.$p(slice);
204                        } else {
205                            let mut temp = lane.to_vec();
206                            handler.$p(&mut temp);
207                            lane.assign(&Array1::from(temp));
208                        }
209                    });
210            }
211        }
212    };
213}
214
215#[cfg(feature = "parallel")]
216macro_rules! create_transform_par {
217    (
218        $(#[$meta:meta])* $i: ident, $a: ty, $b: ty, $h: ty, $p: ident
219    ) => {
220        $(#[$meta])*
221        pub fn $i<R, S, T, D>(
222            input: &ArrayBase<R, D>,
223            output: &mut ArrayBase<S, D>,
224            handler: &$h,
225            axis: usize,
226        ) where
227            T: FftNum + FloatConst,
228            R: Data<Elem = $a>,
229            S: Data<Elem = $b> + DataMut,
230            D: Dimension,
231        {
232            let n = output.shape()[axis];
233            if input.is_standard_layout() && output.is_standard_layout() {
234                let outer_axis = input.ndim() - 1;
235                if axis == outer_axis {
236                    Zip::from(input.rows())
237                        .and(output.rows_mut())
238                        .par_for_each(|x, mut y| {
239                            handler.$p(x.as_slice().unwrap(), y.as_slice_mut().unwrap());
240                        });
241                } else {
242                    let n = output.shape()[axis];
243                    let mut input = input.view();
244                    input.swap_axes(outer_axis, axis);
245                    output.swap_axes(outer_axis, axis);
246                    Zip::from(input.rows())
247                        .and(output.rows_mut())
248                        .par_for_each(|x, mut y| {
249                            let mut outvec = Array1::zeros(n);
250                            handler.$p(&x.to_vec(), outvec.as_slice_mut().unwrap());
251                            y.assign(&outvec);
252                        });
253                    output.swap_axes(outer_axis, axis);
254                }
255            } else {
256                Zip::from(input.lanes(Axis(axis)))
257                    .and(output.lanes_mut(Axis(axis)))
258                    .par_for_each(|x, mut y| {
259                        if let Some(x_s) = x.as_slice() {
260                            if let Some(y_s) = y.as_slice_mut() {
261                                // x and y are contiguous
262                                handler.$p(x_s, y_s);
263                            } else {
264                                let mut outvec = Array1::zeros(n);
265                                // x is contiguous, y is not contiguous
266                                handler.$p(x_s, outvec.as_slice_mut().unwrap());
267                                y.assign(&outvec);
268                            }
269                        } else {
270                            if let Some(y_s) = y.as_slice_mut() {
271                                // x is not contiguous, y is contiguous
272                                handler.$p(&x.to_vec(), y_s);
273                            } else {
274                                let mut outvec = Array1::zeros(n);
275                                // x and y are not contiguous
276                                handler.$p(&x.to_vec(), outvec.as_slice_mut().unwrap());
277                                y.assign(&outvec);
278                            }
279                        }
280                    });
281            }
282        }
283    };
284}
285
286#[cfg(feature = "parallel")]
287macro_rules! create_transform_inplace_par {
288    (
289        $(#[$meta:meta])* $i: ident, $a: ty, $h: ty, $p: ident
290    ) => {
291        $(#[$meta])*
292        pub fn $i<R, T, D>(
293            input: &mut ArrayBase<R, D>,
294            handler: &$h,
295            axis: usize,
296        ) where
297            T: FftNum + FloatConst,
298            R: Data<Elem = $a> + DataMut,
299            D: Dimension,
300        {
301            let n = input.shape()[axis];
302            if input.is_standard_layout() {
303                let outer_axis = input.ndim() - 1;
304                if axis == outer_axis {
305                    Zip::from(input.rows_mut())
306                        .par_for_each(|mut x| {
307                            handler.$p(x.as_slice_mut().unwrap());
308                        });
309                } else {
310                    input.swap_axes(outer_axis, axis);
311                    Zip::from(input.rows_mut())
312                        .par_for_each(|mut x| {
313                            let mut tmp = x.to_owned();
314                            handler.$p(tmp.as_slice_mut().unwrap());
315                            x.assign(&tmp);
316                        });
317                    input.swap_axes(outer_axis, axis);
318                }
319            } else {
320                Zip::from(input.lanes_mut(Axis(axis)))
321                    .par_for_each(|mut x| {
322                        if let Some(x_s) = x.as_slice_mut() {
323                            handler.$p(x_s);
324                        } else {
325                            let mut tmp = Array1::zeros(n);
326                            handler.$p(tmp.as_slice_mut().unwrap());
327                            x.assign(&tmp);
328                        }
329                    });
330            }
331        }
332    };
333}
334
335/// # *n*-dimensional complex-to-complex Fourier Transform.
336///
337/// Transforms a complex ndarray of size *n* to a complex array of size
338/// *n* and vice versa. The transformation is performed along a single
339/// axis, all other array dimensions are unaffected.
340/// Performs best on sizes which are mutiple of 2 or 3.
341///
342/// The accompanying functions for the forward transform are [`ndfft`] (serial) and
343/// [`ndfft_par`] (parallel).
344///
345/// The accompanying functions for the inverse transform are [`ndifft`] (serial) and
346/// [`ndifft_par`] (parallel).
347///
348/// # Example
349/// 2-Dimensional complex-to-complex fft along first axis
350/// ```
351/// use ndarray::{Array2, Dim, Ix};
352/// use ndrustfft::{ndfft, Complex, FftHandler};
353///
354/// let (nx, ny) = (6, 4);
355/// let mut data = Array2::<Complex<f64>>::zeros((nx, ny));
356/// let mut vhat = Array2::<Complex<f64>>::zeros((nx, ny));
357/// for (i, v) in data.iter_mut().enumerate() {
358///     v.re = i as f64;
359///     v.im = i as f64;
360/// }
361/// let mut fft_handler: FftHandler<f64> = FftHandler::new(nx);
362/// ndfft(&data, &mut vhat, &mut fft_handler, 0);
363/// ```
364#[derive(Clone)]
365pub struct FftHandler<T> {
366    n: usize,
367    plan_fwd: Arc<dyn Fft<T>>,
368    plan_bwd: Arc<dyn Fft<T>>,
369    norm: Normalization<Complex<T>>,
370}
371
372impl<T: FftNum> FftHandler<T> {
373    /// Creates a new `FftHandler`.
374    ///
375    /// # Arguments
376    ///
377    /// * `n` - Length of array along axis of which fft will be performed.
378    ///   The size of the complex array after the fft is performed will be of
379    ///   size *n*.
380    ///
381    /// # Examples
382    ///
383    /// ```
384    /// use ndrustfft::FftHandler;
385    /// let handler: FftHandler<f64> = FftHandler::new(10);
386    /// ```
387    #[allow(clippy::similar_names)]
388    #[must_use]
389    pub fn new(n: usize) -> Self {
390        let mut planner = FftPlanner::<T>::new();
391        let fwd = planner.plan_fft_forward(n);
392        let bwd = planner.plan_fft_inverse(n);
393        FftHandler::<T> {
394            n,
395            plan_fwd: Arc::clone(&fwd),
396            plan_bwd: Arc::clone(&bwd),
397            norm: Normalization::Default,
398        }
399    }
400
401    /// This method allows modifying the normalization applied to the backward transform. See [`Normalization`] for more details.
402    #[must_use]
403    pub fn normalization(mut self, norm: Normalization<Complex<T>>) -> Self {
404        self.norm = norm;
405        self
406    }
407
408    fn fft_lane(&self, data: &[Complex<T>], out: &mut [Complex<T>]) {
409        Self::assert_size(self.n, data.len());
410        Self::assert_size(self.n, out.len());
411        out.clone_from_slice(data);
412        self.plan_fwd.process(out);
413    }
414
415    #[allow(clippy::cast_precision_loss)]
416    fn ifft_lane(&self, data: &[Complex<T>], out: &mut [Complex<T>]) {
417        Self::assert_size(self.n, data.len());
418        Self::assert_size(self.n, out.len());
419        out.clone_from_slice(data);
420        self.plan_bwd.process(out);
421        match self.norm {
422            Normalization::None => (),
423            Normalization::Default => Self::norm_default(out),
424            Normalization::Custom(f) => f(out),
425        }
426    }
427
428    fn fft_lane_inplace(&self, data: &mut [Complex<T>]) {
429        Self::assert_size(self.n, data.len());
430        self.plan_fwd.process(data);
431    }
432
433    #[allow(clippy::cast_precision_loss)]
434    fn ifft_lane_inplace(&self, data: &mut [Complex<T>]) {
435        Self::assert_size(self.n, data.len());
436        self.plan_bwd.process(data);
437        match self.norm {
438            Normalization::None => (),
439            Normalization::Default => Self::norm_default(data),
440            Normalization::Custom(f) => f(data),
441        }
442    }
443
444    fn norm_default(data: &mut [Complex<T>]) {
445        let n = T::one() / T::from_usize(data.len()).unwrap();
446        for d in &mut *data {
447            *d = *d * n;
448        }
449    }
450
451    fn assert_size(n: usize, size: usize) {
452        assert!(
453            n == size,
454            "Size mismatch in fft, got {} expected {}",
455            size,
456            n
457        );
458    }
459}
460
461create_transform!(
462    /// Complex-to-complex Fourier Transform (serial).
463    /// # Example
464    /// ```
465    /// use ndarray::{Array2, Dim, Ix};
466    /// use ndrustfft::{ndfft, Complex, FftHandler};
467    ///
468    /// let (nx, ny) = (6, 4);
469    /// let mut data = Array2::<Complex<f64>>::zeros((nx, ny));
470    /// let mut vhat = Array2::<Complex<f64>>::zeros((nx, ny));
471    /// for (i, v) in data.iter_mut().enumerate() {
472    ///     v.re = i as f64;
473    ///     v.im = -1.0*i as f64;
474    /// }
475    /// let mut handler: FftHandler<f64> = FftHandler::new(ny);
476    /// ndfft(&data, &mut vhat, &mut handler, 1);
477    /// ```
478    ndfft,
479    Complex<T>,
480    Complex<T>,
481    FftHandler<T>,
482    fft_lane
483);
484
485create_transform!(
486    /// Complex-to-complex Inverse Fourier Transform (serial).
487    /// # Example
488    /// ```
489    /// use ndarray::Array2;
490    /// use ndrustfft::{ndfft, ndifft, Complex, FftHandler};
491    ///
492    /// let (nx, ny) = (6, 4);
493    /// let mut data = Array2::<Complex<f64>>::zeros((nx, ny));
494    /// let mut vhat = Array2::<Complex<f64>>::zeros((nx, ny));
495    /// for (i, v) in data.iter_mut().enumerate() {
496    ///     v.re = i as f64;
497    ///     v.im = -1.0*i as f64;
498    /// }
499    /// let mut handler: FftHandler<f64> = FftHandler::new(ny);
500    /// ndfft(&data, &mut vhat, &mut handler, 1);
501    /// ndifft(&vhat, &mut data, &mut handler, 1);
502    /// ```
503    ndifft,
504    Complex<T>,
505    Complex<T>,
506    FftHandler<T>,
507    ifft_lane
508);
509
510create_transform_inplace!(
511    /// Complex-to-complex Fourier Transform (inplace, serial).
512    /// # Example
513    /// ```
514    /// use ndarray::{Array2, Dim, Ix};
515    /// use ndrustfft::{ndfft_inplace, Complex, FftHandler};
516    ///
517    /// let (nx, ny) = (6, 4);
518    /// let mut data = Array2::<Complex<f64>>::zeros((nx, ny));
519    /// for (i, v) in data.iter_mut().enumerate() {
520    ///     v.re = i as f64;
521    ///     v.im = -1.0*i as f64;
522    /// }
523    /// let mut handler: FftHandler<f64> = FftHandler::new(ny);
524    /// ndfft_inplace(&mut data, &mut handler, 1);
525    /// ```
526    ndfft_inplace,
527    Complex<T>,
528    FftHandler<T>,
529    fft_lane_inplace
530);
531
532create_transform_inplace!(
533    /// Complex-to-complex Inverse Fourier Transform (inplace, serial).
534    /// # Example
535    /// ```
536    /// use ndarray::Array2;
537    /// use ndrustfft::{ndfft_inplace, ndifft_inplace, Complex, FftHandler};
538    ///
539    /// let (nx, ny) = (6, 4);
540    /// let mut data = Array2::<Complex<f64>>::zeros((nx, ny));
541    /// for (i, v) in data.iter_mut().enumerate() {
542    ///     v.re = i as f64;
543    ///     v.im = -1.0*i as f64;
544    /// }
545    /// let mut handler: FftHandler<f64> = FftHandler::new(ny);
546    /// ndfft_inplace(&mut data,  &mut handler, 1);
547    /// ndifft_inplace(&mut data, &mut handler, 1);
548    /// ```
549    ndifft_inplace,
550    Complex<T>,
551    FftHandler<T>,
552    ifft_lane_inplace
553);
554
555#[cfg(feature = "parallel")]
556create_transform_par!(
557    /// Complex-to-complex Fourier Transform (parallel).
558    ///
559    /// Further infos: see [`ndfft`]
560    ndfft_par,
561    Complex<T>,
562    Complex<T>,
563    FftHandler<T>,
564    fft_lane
565);
566
567#[cfg(feature = "parallel")]
568create_transform_inplace_par!(
569    /// Complex-to-complex Fourier Transform (inplace, parallel).
570    ///
571    /// Further infos: see [`ndfft`]
572    ndfft_inplace_par,
573    Complex<T>,
574    FftHandler<T>,
575    fft_lane_inplace
576);
577
578#[cfg(feature = "parallel")]
579create_transform_par!(
580    /// Complex-to-complex inverse Fourier Transform (parallel).
581    ///
582    /// Further infos: see [`ndifft`]
583    ndifft_par,
584    Complex<T>,
585    Complex<T>,
586    FftHandler<T>,
587    ifft_lane
588);
589
590#[cfg(feature = "parallel")]
591create_transform_inplace_par!(
592    /// Complex-to-complex inverse Fourier Transform (inplace, parallel).
593    ///
594    /// Further infos: see [`ndifft`]
595    ndifft_inplace_par,
596    Complex<T>,
597    FftHandler<T>,
598    ifft_lane_inplace
599);
600
601/// # *n*-dimensional real-to-complex Fourier Transform.
602///
603/// Transforms a real ndarray of size *n* to a complex array of size
604/// *n/2+1* and vice versa. The transformation is performed along a single
605/// axis, all other array dimensions are unaffected.
606/// Performs best on sizes which are mutiple of 2 or 3.
607///
608/// The accompanying functions for the forward transform are [`ndfft_r2c`] (serial) and
609/// [`ndfft_r2c_par`] (parallel).
610///
611/// The accompanying functions for the inverse transform are [`ndifft_r2c`] (serial) and
612/// [`ndifft_r2c_par`] (parallel).
613///
614/// # Example
615/// 2-Dimensional real-to-complex fft along first axis
616/// ```
617/// use ndarray::{Array2, Dim, Ix};
618/// use ndrustfft::{ndfft_r2c, Complex, R2cFftHandler};
619///
620/// let (nx, ny) = (6, 4);
621/// let mut data = Array2::<f64>::zeros((nx, ny));
622/// let mut vhat = Array2::<Complex<f64>>::zeros((nx / 2 + 1, ny));
623/// for (i, v) in data.iter_mut().enumerate() {
624///     *v = i as f64;
625/// }
626/// let mut fft_handler = R2cFftHandler::<f64>::new(nx);
627/// ndfft_r2c(&data, &mut vhat, &mut fft_handler, 0);
628/// ```
629#[derive(Clone)]
630pub struct R2cFftHandler<T> {
631    n: usize,
632    m: usize,
633    plan_fwd: Arc<dyn RealToComplex<T>>,
634    plan_bwd: Arc<dyn ComplexToReal<T>>,
635    norm: Normalization<Complex<T>>,
636}
637
638impl<T: FftNum> R2cFftHandler<T> {
639    /// Creates a new `RealFftPlanner`.
640    ///
641    /// # Arguments
642    ///
643    /// * `n` - Length of array along axis of which fft will be performed.
644    ///   The size of the complex array after the fft is performed will be of
645    ///   size *n / 2 + 1*.
646    ///
647    /// # Examples
648    ///
649    /// ```
650    /// use ndrustfft::R2cFftHandler;
651    /// let handler = R2cFftHandler::<f64>::new(10);
652    /// ```
653    #[allow(clippy::similar_names)]
654    #[must_use]
655    pub fn new(n: usize) -> Self {
656        let mut planner = RealFftPlanner::<T>::new();
657        let fwd = planner.plan_fft_forward(n);
658        let bwd = planner.plan_fft_inverse(n);
659        Self {
660            n,
661            m: n / 2 + 1,
662            plan_fwd: Arc::clone(&fwd),
663            plan_bwd: Arc::clone(&bwd),
664            norm: Normalization::Default,
665        }
666    }
667
668    /// This method allows modifying the normalization applied to the backward transform. See [`Normalization`] for more details.
669    #[must_use]
670    pub fn normalization(mut self, norm: Normalization<Complex<T>>) -> Self {
671        self.norm = norm;
672        self
673    }
674
675    fn fft_r2c_lane(&self, data: &[T], out: &mut [Complex<T>]) {
676        Self::assert_size(self.n, data.len());
677        Self::assert_size(self.m, out.len());
678        let mut buffer = vec![T::zero(); self.n];
679        buffer.clone_from_slice(data);
680        self.plan_fwd.process(&mut buffer, out).unwrap();
681    }
682
683    #[allow(clippy::cast_precision_loss)]
684    fn ifft_r2c_lane(&self, data: &[Complex<T>], out: &mut [T]) {
685        Self::assert_size(self.m, data.len());
686        Self::assert_size(self.n, out.len());
687        let mut buffer = vec![Complex::zero(); self.m];
688        buffer.clone_from_slice(data);
689        match self.norm {
690            Normalization::None => (),
691            Normalization::Default => Self::norm_default(&mut buffer, self.n),
692            Normalization::Custom(f) => f(&mut buffer),
693        }
694        // First element must be real
695        buffer[0].im = T::zero();
696        // If original vector is even, last element must be real
697        if self.n.is_multiple_of(2) {
698            buffer[self.m - 1].im = T::zero();
699        }
700        self.plan_bwd.process(&mut buffer, out).unwrap();
701    }
702
703    fn norm_default(data: &mut [Complex<T>], size: usize) {
704        let n = T::one() / T::from_usize(size).unwrap();
705        for d in &mut *data {
706            d.re = d.re * n;
707            d.im = d.im * n;
708        }
709    }
710
711    fn assert_size(n: usize, size: usize) {
712        assert!(
713            n == size,
714            "Size mismatch in fft, got {} expected {}",
715            size,
716            n
717        );
718    }
719}
720
721create_transform!(
722    /// Real-to-complex Fourier Transform (serial).
723    /// # Example
724    /// ```
725    /// use ndarray::Array2;
726    /// use ndrustfft::{ndfft_r2c, Complex, R2cFftHandler};
727    ///
728    /// let (nx, ny) = (6, 4);
729    /// let mut data = Array2::<f64>::zeros((nx, ny));
730    /// let mut vhat = Array2::<Complex<f64>>::zeros((nx / 2 + 1, ny));
731    /// for (i, v) in data.iter_mut().enumerate() {
732    ///     *v = i as f64;
733    /// }
734    /// let mut handler = R2cFftHandler::<f64>::new(nx);
735    /// ndfft_r2c(&data, &mut vhat, &mut handler, 0);
736    /// ```
737    ndfft_r2c,
738    T,
739    Complex<T>,
740    R2cFftHandler<T>,
741    fft_r2c_lane
742);
743
744create_transform!(
745    /// Complex-to-real inverse Fourier Transform (serial).
746    /// # Example
747    /// ```
748    /// use ndarray::Array2;
749    /// use ndrustfft::{ndifft_r2c, Complex, R2cFftHandler};
750    ///
751    /// let (nx, ny) = (6, 4);
752    /// let mut data = Array2::<f64>::zeros((nx, ny));
753    /// let mut vhat = Array2::<Complex<f64>>::zeros((nx / 2 + 1, ny));
754    /// for (i, v) in vhat.iter_mut().enumerate() {
755    ///     v.re = i as f64;
756    /// }
757    /// let mut handler = R2cFftHandler::<f64>::new(nx);
758    /// ndifft_r2c(&vhat, &mut data, &mut handler, 0);
759    /// ```
760    ndifft_r2c,
761    Complex<T>,
762    T,
763    R2cFftHandler<T>,
764    ifft_r2c_lane
765);
766
767#[cfg(feature = "parallel")]
768create_transform_par!(
769    /// Real-to-complex Fourier Transform (parallel).
770    ///
771    /// Further infos: see [`ndfft_r2c`]
772    ndfft_r2c_par,
773    T,
774    Complex<T>,
775    R2cFftHandler<T>,
776    fft_r2c_lane
777);
778
779#[cfg(feature = "parallel")]
780create_transform_par!(
781    /// Complex-to-real inverse Fourier Transform (parallel).
782    ///
783    /// Further infos: see [`ndifft_r2c`]
784    ndifft_r2c_par,
785    Complex<T>,
786    T,
787    R2cFftHandler<T>,
788    ifft_r2c_lane
789);
790
791/// # *n*-dimensional real-to-real Cosine Transform.
792///
793/// The dct transforms a real ndarray of size *n* to a real array of size *n*.
794/// The transformation is performed along a single axis, all other array
795/// dimensions are unaffected.
796/// Performs best on sizes where *2(n-1)* is a mutiple of 2 or 3. The crate
797/// contains benchmarks, see benches folder, where different sizes can be
798/// tested to optmize performance.
799///
800/// The accompanying functions are [`nddct1`] (serial) and
801/// [`nddct1_par`] (parallel).
802///
803/// # Example
804/// 2-Dimensional real-to-real dft along second axis
805/// ```
806/// use ndarray::Array2;
807/// use ndrustfft::{DctHandler, nddct1};
808///
809/// let (nx, ny) = (6, 4);
810/// let mut data = Array2::<f64>::zeros((nx, ny));
811/// let mut vhat = Array2::<f64>::zeros((nx, ny));
812/// for (i, v) in data.iter_mut().enumerate() {
813///     *v = i as f64;
814/// }
815/// let mut handler: DctHandler<f64> = DctHandler::new(ny);
816/// nddct1(&data, &mut vhat, &mut handler, 1);
817/// ```
818#[derive(Clone)]
819pub struct DctHandler<T> {
820    n: usize,
821    plan_dct1: Arc<dyn Dct1<T>>,
822    plan_dct2: Arc<dyn TransformType2And3<T>>,
823    plan_dct3: Arc<dyn TransformType2And3<T>>,
824    plan_dct4: Arc<dyn TransformType4<T>>,
825    norm: Normalization<T>,
826}
827
828impl<T: FftNum + FloatConst> DctHandler<T> {
829    /// Creates a new `DctHandler`.
830    ///
831    /// # Arguments
832    ///
833    /// * `n` - Length of array along axis of which dct will be performed.
834    ///   The size and type of the array will be the same after the transform.
835    ///
836    /// # Examples
837    ///
838    /// ```
839    /// use ndrustfft::DctHandler;
840    /// let handler: DctHandler<f64> = DctHandler::new(10);
841    /// ```
842    #[must_use]
843    pub fn new(n: usize) -> Self {
844        let mut planner = DctPlanner::<T>::new();
845        let dct1 = planner.plan_dct1(n);
846        let dct2 = planner.plan_dct2(n);
847        let dct3 = planner.plan_dct3(n);
848        let dct4 = planner.plan_dct4(n);
849        Self {
850            n,
851            plan_dct1: Arc::clone(&dct1),
852            plan_dct2: Arc::clone(&dct2),
853            plan_dct3: Arc::clone(&dct3),
854            plan_dct4: Arc::clone(&dct4),
855            norm: Normalization::Default,
856        }
857    }
858
859    /// This method allows modifying the normalization applied to the backward transform. See [`Normalization`] for more details.
860    #[must_use]
861    pub fn normalization(mut self, norm: Normalization<T>) -> Self {
862        self.norm = norm;
863        self
864    }
865
866    fn dct1_lane(&self, data: &[T], out: &mut [T]) {
867        Self::assert_size(self, data.len());
868        Self::assert_size(self, out.len());
869        out.clone_from_slice(data);
870        match self.norm {
871            Normalization::None => (),
872            Normalization::Default => Self::norm_default(out),
873            Normalization::Custom(f) => f(out),
874        }
875        self.plan_dct1.process_dct1(out);
876    }
877
878    fn dct1_lane_inplace(&self, data: &mut [T]) {
879        Self::assert_size(self, data.len());
880        match self.norm {
881            Normalization::None => (),
882            Normalization::Default => Self::norm_default(data),
883            Normalization::Custom(f) => f(data),
884        }
885        self.plan_dct1.process_dct1(data);
886    }
887
888    fn dct2_lane(&self, data: &[T], out: &mut [T]) {
889        Self::assert_size(self, data.len());
890        Self::assert_size(self, out.len());
891        out.clone_from_slice(data);
892        match self.norm {
893            Normalization::None => (),
894            Normalization::Default => Self::norm_default(out),
895            Normalization::Custom(f) => f(out),
896        }
897        self.plan_dct2.process_dct2(out);
898    }
899
900    fn dct2_lane_inplace(&self, data: &mut [T]) {
901        Self::assert_size(self, data.len());
902        match self.norm {
903            Normalization::None => (),
904            Normalization::Default => Self::norm_default(data),
905            Normalization::Custom(f) => f(data),
906        }
907        self.plan_dct2.process_dct2(data);
908    }
909
910    fn dct3_lane(&self, data: &[T], out: &mut [T]) {
911        Self::assert_size(self, data.len());
912        Self::assert_size(self, out.len());
913        out.clone_from_slice(data);
914        match self.norm {
915            Normalization::None => (),
916            Normalization::Default => Self::norm_default(out),
917            Normalization::Custom(f) => f(out),
918        }
919        self.plan_dct2.process_dct3(out);
920    }
921
922    fn dct3_lane_inplace(&self, data: &mut [T]) {
923        Self::assert_size(self, data.len());
924        match self.norm {
925            Normalization::None => (),
926            Normalization::Default => Self::norm_default(data),
927            Normalization::Custom(f) => f(data),
928        }
929        self.plan_dct3.process_dct3(data);
930    }
931
932    fn dct4_lane(&self, data: &[T], out: &mut [T]) {
933        Self::assert_size(self, data.len());
934        Self::assert_size(self, out.len());
935        out.clone_from_slice(data);
936        match self.norm {
937            Normalization::None => (),
938            Normalization::Default => Self::norm_default(out),
939            Normalization::Custom(f) => f(out),
940        }
941        self.plan_dct4.process_dct4(out);
942    }
943
944    fn dct4_lane_inplace(&self, data: &mut [T]) {
945        Self::assert_size(self, data.len());
946        match self.norm {
947            Normalization::None => (),
948            Normalization::Default => Self::norm_default(data),
949            Normalization::Custom(f) => f(data),
950        }
951        self.plan_dct4.process_dct4(data);
952    }
953
954    fn norm_default(data: &mut [T]) {
955        let two = T::one() + T::one();
956        for d in &mut *data {
957            *d = *d * two;
958        }
959    }
960
961    fn assert_size(&self, size: usize) {
962        assert!(
963            self.n == size,
964            "Size mismatch in dct, got {} expected {}",
965            size,
966            self.n
967        );
968    }
969}
970
create_transform!(
    /// Real-to-real Discrete Cosine Transform of type 1 (DCT-I), serial.
    ///
    /// Transforms the input array into the output array along the axis given by the
    /// last argument. The handler must have been created with the length of that axis
    /// (`DctHandler::new(ny)` for axis 1 in the example below). With the default
    /// [`Normalization`], results agree with `scipy.fft.dct(x, type=1)`.
    ///
    /// # Example
    /// ```
    /// use ndarray::Array2;
    /// use ndrustfft::{DctHandler, nddct1};
    ///
    /// let (nx, ny) = (6, 4);
    /// let mut data = Array2::<f64>::zeros((nx, ny));
    /// let mut vhat = Array2::<f64>::zeros((nx, ny));
    /// for (i, v) in data.iter_mut().enumerate() {
    ///     *v = i as f64;
    /// }
    /// let mut handler: DctHandler<f64> = DctHandler::new(ny);
    /// nddct1(&data, &mut vhat, &mut handler, 1);
    /// ```
    nddct1,
    T,
    T,
    DctHandler<T>,
    dct1_lane
);

create_transform_inplace!(
    /// Real-to-real Discrete Cosine Transform of type 1 (DCT-I), in place, serial.
    ///
    /// Overwrites the array with its transform along the axis given by the last
    /// argument. The handler must have been created with the length of that axis.
    ///
    /// # Example
    /// ```
    /// use ndarray::Array2;
    /// use ndrustfft::{DctHandler, nddct1_inplace};
    ///
    /// let (nx, ny) = (6, 4);
    /// let mut data = Array2::<f64>::zeros((nx, ny));
    /// for (i, v) in data.iter_mut().enumerate() {
    ///     *v = i as f64;
    /// }
    /// let mut handler: DctHandler<f64> = DctHandler::new(ny);
    /// nddct1_inplace(&mut data, &mut handler, 1);
    /// ```
    nddct1_inplace,
    T,
    DctHandler<T>,
    dct1_lane_inplace
);

#[cfg(feature = "parallel")]
create_transform_par!(
    /// Real-to-real Discrete Cosine Transform of type 1 (DCT-I), parallel.
    ///
    /// Parallel counterpart of [`nddct1`], with the same arguments. Only available
    /// with the `parallel` feature.
    nddct1_par,
    T,
    T,
    DctHandler<T>,
    dct1_lane
);

#[cfg(feature = "parallel")]
create_transform_inplace_par!(
    /// Real-to-real Discrete Cosine Transform of type 1 (DCT-I), in place, parallel.
    ///
    /// Parallel counterpart of [`nddct1_inplace`], with the same arguments. Only
    /// available with the `parallel` feature.
    nddct1_inplace_par,
    T,
    DctHandler<T>,
    dct1_lane_inplace
);
1037
create_transform!(
    /// Real-to-real Discrete Cosine Transform of type 2 (DCT-II), serial.
    ///
    /// Same arguments as [`nddct1`]: input array, output array, handler and axis. The
    /// handler must have been created with the length of the transformed axis.
    nddct2,
    T,
    T,
    DctHandler<T>,
    dct2_lane
);

create_transform_inplace!(
    /// Real-to-real Discrete Cosine Transform of type 2 (DCT-II), in place, serial.
    ///
    /// Same arguments as [`nddct1_inplace`]: array, handler and axis.
    nddct2_inplace,
    T,
    DctHandler<T>,
    dct2_lane_inplace
);

#[cfg(feature = "parallel")]
create_transform_par!(
    /// Real-to-real Discrete Cosine Transform of type 2 (DCT-II), parallel.
    ///
    /// Parallel counterpart of [`nddct2`]. Only available with the `parallel` feature.
    nddct2_par,
    T,
    T,
    DctHandler<T>,
    dct2_lane
);

#[cfg(feature = "parallel")]
create_transform_inplace_par!(
    /// Real-to-real Discrete Cosine Transform of type 2 (DCT-II), in place, parallel.
    ///
    /// Parallel counterpart of [`nddct2_inplace`]. Only available with the `parallel`
    /// feature.
    nddct2_inplace_par,
    T,
    DctHandler<T>,
    dct2_lane_inplace
);
1073
create_transform!(
    /// Real-to-real Discrete Cosine Transform of type 3 (DCT-III), serial.
    ///
    /// Same arguments as [`nddct1`]: input array, output array, handler and axis. The
    /// handler must have been created with the length of the transformed axis.
    nddct3,
    T,
    T,
    DctHandler<T>,
    dct3_lane
);

create_transform_inplace!(
    /// Real-to-real Discrete Cosine Transform of type 3 (DCT-III), in place, serial.
    ///
    /// Same arguments as [`nddct1_inplace`]: array, handler and axis.
    nddct3_inplace,
    T,
    DctHandler<T>,
    dct3_lane_inplace
);

#[cfg(feature = "parallel")]
create_transform_par!(
    /// Real-to-real Discrete Cosine Transform of type 3 (DCT-III), parallel.
    ///
    /// Parallel counterpart of [`nddct3`]. Only available with the `parallel` feature.
    nddct3_par,
    T,
    T,
    DctHandler<T>,
    dct3_lane
);

#[cfg(feature = "parallel")]
create_transform_inplace_par!(
    /// Real-to-real Discrete Cosine Transform of type 3 (DCT-III), in place, parallel.
    ///
    /// Parallel counterpart of [`nddct3_inplace`]. Only available with the `parallel`
    /// feature.
    nddct3_inplace_par,
    T,
    DctHandler<T>,
    dct3_lane_inplace
);
1109
create_transform!(
    /// Real-to-real Discrete Cosine Transform of type 4 (DCT-IV), serial.
    ///
    /// Same arguments as [`nddct1`]: input array, output array, handler and axis. The
    /// handler must have been created with the length of the transformed axis.
    nddct4,
    T,
    T,
    DctHandler<T>,
    dct4_lane
);

create_transform_inplace!(
    /// Real-to-real Discrete Cosine Transform of type 4 (DCT-IV), in place, serial.
    ///
    /// Same arguments as [`nddct1_inplace`]: array, handler and axis.
    nddct4_inplace,
    T,
    DctHandler<T>,
    dct4_lane_inplace
);

#[cfg(feature = "parallel")]
create_transform_par!(
    /// Real-to-real Discrete Cosine Transform of type 4 (DCT-IV), parallel.
    ///
    /// Parallel counterpart of [`nddct4`]. Only available with the `parallel` feature.
    nddct4_par,
    T,
    T,
    DctHandler<T>,
    dct4_lane
);

#[cfg(feature = "parallel")]
create_transform_inplace_par!(
    /// Real-to-real Discrete Cosine Transform of type 4 (DCT-IV), in place, parallel.
    ///
    /// Parallel counterpart of [`nddct4_inplace`]. Only available with the `parallel`
    /// feature.
    nddct4_inplace_par,
    T,
    DctHandler<T>,
    dct4_lane_inplace
);
1145
1146/// Tests
1147#[cfg(test)]
1148mod test {
1149    use super::*;
1150    use ndarray::{array, Array2, ShapeBuilder};
1151
1152    fn approx_eq<A, S, D>(result: &ArrayBase<S, D>, expected: &ArrayBase<S, D>)
1153    where
1154        A: FftNum + std::fmt::Display + std::cmp::PartialOrd,
1155        S: ndarray::Data<Elem = A>,
1156        D: Dimension,
1157    {
1158        let dif = A::from_f64(1e-3).unwrap();
1159        for (a, b) in expected.iter().zip(result.iter()) {
1160            if (*a - *b).abs() > dif {
1161                panic!("Large difference of values, got {} expected {}.", b, a)
1162            }
1163        }
1164    }
1165
1166    fn approx_eq_complex<A, S, D>(result: &ArrayBase<S, D>, expected: &ArrayBase<S, D>)
1167    where
1168        A: FftNum + std::fmt::Display + std::cmp::PartialOrd,
1169        S: ndarray::Data<Elem = Complex<A>>,
1170        D: Dimension,
1171    {
1172        let dif = A::from_f64(1e-3).unwrap();
1173        for (a, b) in expected.iter().zip(result.iter()) {
1174            if (a.re - b.re).abs() > dif || (a.im - b.im).abs() > dif {
1175                panic!("Large difference of values, got {} expected {}.", b, a)
1176            }
1177        }
1178    }
1179
1180    fn test_matrix() -> Array2<f64> {
1181        array![
1182            [0.1, 1.908, -0.035, -0.278, 0.264, -1.349],
1183            [0.88, 0.86, -0.267, -0.809, 1.374, 0.757],
1184            [1.418, -0.68, 0.814, 0.852, -0.613, 0.468],
1185            [0.817, -0.697, -2.157, 0.447, -0.949, 2.243],
1186            [-0.474, -0.09, -0.567, -0.772, 0.021, 2.455],
1187            [-0.745, 1.52, 0.509, -0.066, 2.802, -0.042],
1188        ]
1189    }
1190
1191    fn test_matrix_complex() -> Array2<Complex<f64>> {
1192        test_matrix().mapv(|x| Complex::new(x, x))
1193    }
1194
1195    fn test_matrix_complex_f() -> Array2<Complex<f64>> {
1196        let mut arr = Array2::zeros((6, 6).f());
1197        for (a, b) in arr.iter_mut().zip(test_matrix_complex().iter()) {
1198            *a = *b
1199        }
1200        arr
1201    }
1202
1203    #[test]
1204    fn test_fft() {
1205        // Solution from np.fft.fft
1206        let solution_re = array![
1207            [0.61, 3.105, 2.508, 0.048, -3.652, -2.019],
1208            [2.795, 0.612, 0.219, 1.179, -2.801, 3.276],
1209            [2.259, 0.601, 0.045, 0.979, 4.506, 0.118],
1210            [-0.296, -0.896, 0.544, -4.282, 3.544, 6.288],
1211            [0.573, -0.96, -3.85, -2.613, -0.461, 4.467],
1212            [3.978, -2.229, 0.133, 1.154, -6.544, -0.962],
1213        ];
1214
1215        let solution_im = array![
1216            [0.61, -2.019, -3.652, 0.048, 2.508, 3.105],
1217            [2.795, 3.276, -2.801, 1.179, 0.219, 0.612],
1218            [2.259, 0.118, 4.506, 0.979, 0.045, 0.601],
1219            [-0.296, 6.288, 3.544, -4.282, 0.544, -0.896],
1220            [0.573, 4.467, -0.461, -2.613, -3.85, -0.96],
1221            [3.978, -0.962, -6.544, 1.154, 0.133, -2.229],
1222        ];
1223
1224        let mut solution: Array2<Complex<f64>> = Array2::zeros(solution_re.raw_dim());
1225        for (s, (s_re, s_im)) in solution
1226            .iter_mut()
1227            .zip(solution_re.iter().zip(solution_im.iter()))
1228        {
1229            s.re = *s_re;
1230            s.im = *s_im;
1231        }
1232
1233        // Setup
1234        let mut v = test_matrix_complex();
1235        let v_copy = v.clone();
1236        let (nx, ny) = (v.shape()[0], v.shape()[1]);
1237        let mut vhat = Array2::<Complex<f64>>::zeros((nx, ny));
1238        let mut handler: FftHandler<f64> = FftHandler::new(ny);
1239
1240        // Transform
1241        ndfft(&v, &mut vhat, &mut handler, 1);
1242        ndifft(&vhat, &mut v, &mut handler, 1);
1243
1244        // Assert
1245        approx_eq_complex(&vhat, &solution);
1246        approx_eq_complex(&v, &v_copy);
1247    }
1248
1249    #[test]
1250    fn test_fft_axis0() {
1251        // Solution from np.fft.fft
1252        let solution_re = array![
1253            [0.61, 3.105, 2.508, 0.048, -3.652, -2.019],
1254            [2.795, 0.612, 0.219, 1.179, -2.801, 3.276],
1255            [2.259, 0.601, 0.045, 0.979, 4.506, 0.118],
1256            [-0.296, -0.896, 0.544, -4.282, 3.544, 6.288],
1257            [0.573, -0.96, -3.85, -2.613, -0.461, 4.467],
1258            [3.978, -2.229, 0.133, 1.154, -6.544, -0.962],
1259        ];
1260
1261        let solution_im = array![
1262            [0.61, -2.019, -3.652, 0.048, 2.508, 3.105],
1263            [2.795, 3.276, -2.801, 1.179, 0.219, 0.612],
1264            [2.259, 0.118, 4.506, 0.979, 0.045, 0.601],
1265            [-0.296, 6.288, 3.544, -4.282, 0.544, -0.896],
1266            [0.573, 4.467, -0.461, -2.613, -3.85, -0.96],
1267            [3.978, -0.962, -6.544, 1.154, 0.133, -2.229],
1268        ];
1269
1270        // Transpose the arrays
1271        let solution_re_t = solution_re.t().as_standard_layout().to_owned();
1272        let solution_im_t = solution_im.t().as_standard_layout().to_owned();
1273
1274        let mut solution: Array2<Complex<f64>> = Array2::zeros(solution_re_t.raw_dim());
1275        for (s, (s_re, s_im)) in solution
1276            .iter_mut()
1277            .zip(solution_re_t.iter().zip(solution_im_t.iter()))
1278        {
1279            s.re = *s_re;
1280            s.im = *s_im;
1281        }
1282
1283        // Setup
1284        let mut v = test_matrix_complex().t().as_standard_layout().to_owned();
1285        let v_copy = v.clone();
1286        let (nx, ny) = (v.shape()[0], v.shape()[1]);
1287        let mut vhat = Array2::<Complex<f64>>::zeros((nx, ny));
1288        let mut handler: FftHandler<f64> = FftHandler::new(ny);
1289
1290        // Transform
1291        ndfft(&v, &mut vhat, &mut handler, 0);
1292        ndifft(&vhat, &mut v, &mut handler, 0);
1293
1294        // Assert
1295        approx_eq_complex(&vhat, &solution);
1296        approx_eq_complex(&v, &v_copy);
1297    }
1298
1299    #[test]
1300    fn test_fft_inplace() {
1301        // Solution from np.fft.fft
1302        let solution_re = array![
1303            [0.61, 3.105, 2.508, 0.048, -3.652, -2.019],
1304            [2.795, 0.612, 0.219, 1.179, -2.801, 3.276],
1305            [2.259, 0.601, 0.045, 0.979, 4.506, 0.118],
1306            [-0.296, -0.896, 0.544, -4.282, 3.544, 6.288],
1307            [0.573, -0.96, -3.85, -2.613, -0.461, 4.467],
1308            [3.978, -2.229, 0.133, 1.154, -6.544, -0.962],
1309        ];
1310
1311        let solution_im = array![
1312            [0.61, -2.019, -3.652, 0.048, 2.508, 3.105],
1313            [2.795, 3.276, -2.801, 1.179, 0.219, 0.612],
1314            [2.259, 0.118, 4.506, 0.979, 0.045, 0.601],
1315            [-0.296, 6.288, 3.544, -4.282, 0.544, -0.896],
1316            [0.573, 4.467, -0.461, -2.613, -3.85, -0.96],
1317            [3.978, -0.962, -6.544, 1.154, 0.133, -2.229],
1318        ];
1319
1320        let mut solution: Array2<Complex<f64>> = Array2::zeros(solution_re.raw_dim());
1321        for (s, (s_re, s_im)) in solution
1322            .iter_mut()
1323            .zip(solution_re.iter().zip(solution_im.iter()))
1324        {
1325            s.re = *s_re;
1326            s.im = *s_im;
1327        }
1328
1329        // Setup
1330        let mut v = test_matrix_complex();
1331        let v_copy = v.clone();
1332        let (_, ny) = (v.shape()[0], v.shape()[1]);
1333        let mut handler: FftHandler<f64> = FftHandler::new(ny);
1334
1335        // Transform
1336        ndfft_inplace(&mut v, &mut handler, 1);
1337        approx_eq_complex(&v, &solution);
1338
1339        ndifft_inplace(&mut v, &mut handler, 1);
1340        approx_eq_complex(&v, &v_copy);
1341    }
1342
1343    #[test]
1344    fn test_fft_inplace_axis0() {
1345        // Solution from np.fft.fft
1346        let solution_re = array![
1347            [0.61, 3.105, 2.508, 0.048, -3.652, -2.019],
1348            [2.795, 0.612, 0.219, 1.179, -2.801, 3.276],
1349            [2.259, 0.601, 0.045, 0.979, 4.506, 0.118],
1350            [-0.296, -0.896, 0.544, -4.282, 3.544, 6.288],
1351            [0.573, -0.96, -3.85, -2.613, -0.461, 4.467],
1352            [3.978, -2.229, 0.133, 1.154, -6.544, -0.962],
1353        ];
1354
1355        let solution_im = array![
1356            [0.61, -2.019, -3.652, 0.048, 2.508, 3.105],
1357            [2.795, 3.276, -2.801, 1.179, 0.219, 0.612],
1358            [2.259, 0.118, 4.506, 0.979, 0.045, 0.601],
1359            [-0.296, 6.288, 3.544, -4.282, 0.544, -0.896],
1360            [0.573, 4.467, -0.461, -2.613, -3.85, -0.96],
1361            [3.978, -0.962, -6.544, 1.154, 0.133, -2.229],
1362        ];
1363
1364        // Transpose the arrays
1365        let solution_re_t = solution_re.t().as_standard_layout().to_owned();
1366        let solution_im_t = solution_im.t().as_standard_layout().to_owned();
1367
1368        let mut solution: Array2<Complex<f64>> = Array2::zeros(solution_re_t.raw_dim());
1369        for (s, (s_re, s_im)) in solution
1370            .iter_mut()
1371            .zip(solution_re_t.iter().zip(solution_im_t.iter()))
1372        {
1373            s.re = *s_re;
1374            s.im = *s_im;
1375        }
1376
1377        // Setup
1378        let mut v = test_matrix_complex().t().as_standard_layout().to_owned();
1379        let v_copy = v.clone();
1380        let (nx, _) = (v.shape()[0], v.shape()[1]);
1381        let mut handler: FftHandler<f64> = FftHandler::new(nx);
1382
1383        // Transform
1384        ndfft_inplace(&mut v, &mut handler, 0);
1385        approx_eq_complex(&v, &solution);
1386
1387        ndifft_inplace(&mut v, &mut handler, 0);
1388        approx_eq_complex(&v, &v_copy);
1389    }
1390
1391    #[cfg(feature = "parallel")]
1392    #[test]
1393    fn test_fft_par() {
1394        // Solution from np.fft.fft
1395        let solution_re = array![
1396            [0.61, 3.105, 2.508, 0.048, -3.652, -2.019],
1397            [2.795, 0.612, 0.219, 1.179, -2.801, 3.276],
1398            [2.259, 0.601, 0.045, 0.979, 4.506, 0.118],
1399            [-0.296, -0.896, 0.544, -4.282, 3.544, 6.288],
1400            [0.573, -0.96, -3.85, -2.613, -0.461, 4.467],
1401            [3.978, -2.229, 0.133, 1.154, -6.544, -0.962],
1402        ];
1403
1404        let solution_im = array![
1405            [0.61, -2.019, -3.652, 0.048, 2.508, 3.105],
1406            [2.795, 3.276, -2.801, 1.179, 0.219, 0.612],
1407            [2.259, 0.118, 4.506, 0.979, 0.045, 0.601],
1408            [-0.296, 6.288, 3.544, -4.282, 0.544, -0.896],
1409            [0.573, 4.467, -0.461, -2.613, -3.85, -0.96],
1410            [3.978, -0.962, -6.544, 1.154, 0.133, -2.229],
1411        ];
1412
1413        let mut solution: Array2<Complex<f64>> = Array2::zeros(solution_re.raw_dim());
1414        for (s, (s_re, s_im)) in solution
1415            .iter_mut()
1416            .zip(solution_re.iter().zip(solution_im.iter()))
1417        {
1418            s.re = *s_re;
1419            s.im = *s_im;
1420        }
1421
1422        // Setup
1423        let mut v = test_matrix_complex();
1424        let v_copy = v.clone();
1425        let (nx, ny) = (v.shape()[0], v.shape()[1]);
1426        let mut vhat = Array2::<Complex<f64>>::zeros((nx, ny));
1427        let mut handler: FftHandler<f64> = FftHandler::new(ny);
1428
1429        // Transform
1430        ndfft_par(&v, &mut vhat, &mut handler, 1);
1431        ndifft_par(&vhat, &mut v, &mut handler, 1);
1432
1433        // Assert
1434        approx_eq_complex(&vhat, &solution);
1435        approx_eq_complex(&v, &v_copy);
1436    }
1437
1438    #[cfg(feature = "parallel")]
1439    #[test]
1440    fn test_fft_inplace_par() {
1441        // Solution from np.fft.fft
1442        let solution_re = array![
1443            [0.61, 3.105, 2.508, 0.048, -3.652, -2.019],
1444            [2.795, 0.612, 0.219, 1.179, -2.801, 3.276],
1445            [2.259, 0.601, 0.045, 0.979, 4.506, 0.118],
1446            [-0.296, -0.896, 0.544, -4.282, 3.544, 6.288],
1447            [0.573, -0.96, -3.85, -2.613, -0.461, 4.467],
1448            [3.978, -2.229, 0.133, 1.154, -6.544, -0.962],
1449        ];
1450
1451        let solution_im = array![
1452            [0.61, -2.019, -3.652, 0.048, 2.508, 3.105],
1453            [2.795, 3.276, -2.801, 1.179, 0.219, 0.612],
1454            [2.259, 0.118, 4.506, 0.979, 0.045, 0.601],
1455            [-0.296, 6.288, 3.544, -4.282, 0.544, -0.896],
1456            [0.573, 4.467, -0.461, -2.613, -3.85, -0.96],
1457            [3.978, -0.962, -6.544, 1.154, 0.133, -2.229],
1458        ];
1459
1460        let mut solution: Array2<Complex<f64>> = Array2::zeros(solution_re.raw_dim());
1461        for (s, (s_re, s_im)) in solution
1462            .iter_mut()
1463            .zip(solution_re.iter().zip(solution_im.iter()))
1464        {
1465            s.re = *s_re;
1466            s.im = *s_im;
1467        }
1468
1469        // Setup
1470        let mut v = test_matrix_complex();
1471        let v_copy = v.clone();
1472        let (nx, _) = (v.shape()[0], v.shape()[1]);
1473        let mut handler: FftHandler<f64> = FftHandler::new(nx);
1474
1475        // Transform
1476        ndfft_inplace_par(&mut v, &mut handler, 1);
1477        approx_eq_complex(&v, &solution);
1478
1479        ndifft_inplace_par(&mut v, &mut handler, 1);
1480        approx_eq_complex(&v, &v_copy);
1481    }
1482
1483    #[cfg(feature = "parallel")]
1484    #[test]
1485    fn test_fft_inplace_par_axis0() {
1486        // Solution from np.fft.fft
1487        let solution_re = array![
1488            [0.61, 3.105, 2.508, 0.048, -3.652, -2.019],
1489            [2.795, 0.612, 0.219, 1.179, -2.801, 3.276],
1490            [2.259, 0.601, 0.045, 0.979, 4.506, 0.118],
1491            [-0.296, -0.896, 0.544, -4.282, 3.544, 6.288],
1492            [0.573, -0.96, -3.85, -2.613, -0.461, 4.467],
1493            [3.978, -2.229, 0.133, 1.154, -6.544, -0.962],
1494        ];
1495
1496        let solution_im = array![
1497            [0.61, -2.019, -3.652, 0.048, 2.508, 3.105],
1498            [2.795, 3.276, -2.801, 1.179, 0.219, 0.612],
1499            [2.259, 0.118, 4.506, 0.979, 0.045, 0.601],
1500            [-0.296, 6.288, 3.544, -4.282, 0.544, -0.896],
1501            [0.573, 4.467, -0.461, -2.613, -3.85, -0.96],
1502            [3.978, -0.962, -6.544, 1.154, 0.133, -2.229],
1503        ];
1504
1505        // Transpose the arrays
1506        let solution_re_t = solution_re.t().as_standard_layout().to_owned();
1507        let solution_im_t = solution_im.t().as_standard_layout().to_owned();
1508
1509        let mut solution: Array2<Complex<f64>> = Array2::zeros(solution_re_t.raw_dim());
1510        for (s, (s_re, s_im)) in solution
1511            .iter_mut()
1512            .zip(solution_re_t.iter().zip(solution_im_t.iter()))
1513        {
1514            s.re = *s_re;
1515            s.im = *s_im;
1516        }
1517
1518        // Setup
1519        let mut v = test_matrix_complex().t().as_standard_layout().to_owned();
1520        let v_copy = v.clone();
1521        let (nx, _) = (v.shape()[0], v.shape()[1]);
1522        let mut handler: FftHandler<f64> = FftHandler::new(nx);
1523
1524        // Transform
1525        ndfft_inplace_par(&mut v, &mut handler, 0);
1526        approx_eq_complex(&v, &solution);
1527
1528        ndifft_inplace_par(&mut v, &mut handler, 0);
1529        approx_eq_complex(&v, &v_copy);
1530    }
1531
1532    #[test]
1533    fn test_fft_f_layout() {
1534        // Solution from np.fft.fft
1535        let solution_re = array![
1536            [0.61, 3.105, 2.508, 0.048, -3.652, -2.019],
1537            [2.795, 0.612, 0.219, 1.179, -2.801, 3.276],
1538            [2.259, 0.601, 0.045, 0.979, 4.506, 0.118],
1539            [-0.296, -0.896, 0.544, -4.282, 3.544, 6.288],
1540            [0.573, -0.96, -3.85, -2.613, -0.461, 4.467],
1541            [3.978, -2.229, 0.133, 1.154, -6.544, -0.962],
1542        ];
1543
1544        let solution_im = array![
1545            [0.61, -2.019, -3.652, 0.048, 2.508, 3.105],
1546            [2.795, 3.276, -2.801, 1.179, 0.219, 0.612],
1547            [2.259, 0.118, 4.506, 0.979, 0.045, 0.601],
1548            [-0.296, 6.288, 3.544, -4.282, 0.544, -0.896],
1549            [0.573, 4.467, -0.461, -2.613, -3.85, -0.96],
1550            [3.978, -0.962, -6.544, 1.154, 0.133, -2.229],
1551        ];
1552
1553        let mut solution: Array2<Complex<f64>> = Array2::zeros(solution_re.raw_dim());
1554        for (s, (s_re, s_im)) in solution
1555            .iter_mut()
1556            .zip(solution_re.iter().zip(solution_im.iter()))
1557        {
1558            s.re = *s_re;
1559            s.im = *s_im;
1560        }
1561
1562        // Setup
1563        let mut v = test_matrix_complex_f();
1564        let v_copy = v.clone();
1565        let (nx, ny) = (v.shape()[0], v.shape()[1]);
1566        let mut vhat = Array2::<Complex<f64>>::zeros((nx, ny));
1567        let mut handler: FftHandler<f64> = FftHandler::new(ny);
1568
1569        // Transform
1570        ndfft(&v, &mut vhat, &mut handler, 1);
1571        ndifft(&vhat, &mut v, &mut handler, 1);
1572
1573        // Assert
1574        approx_eq_complex(&vhat, &solution);
1575        approx_eq_complex(&v, &v_copy);
1576    }
1577
1578    #[test]
1579    fn test_fft_inplace_f_layout() {
1580        // Solution from np.fft.fft
1581        let solution_re = array![
1582            [0.61, 3.105, 2.508, 0.048, -3.652, -2.019],
1583            [2.795, 0.612, 0.219, 1.179, -2.801, 3.276],
1584            [2.259, 0.601, 0.045, 0.979, 4.506, 0.118],
1585            [-0.296, -0.896, 0.544, -4.282, 3.544, 6.288],
1586            [0.573, -0.96, -3.85, -2.613, -0.461, 4.467],
1587            [3.978, -2.229, 0.133, 1.154, -6.544, -0.962],
1588        ];
1589
1590        let solution_im = array![
1591            [0.61, -2.019, -3.652, 0.048, 2.508, 3.105],
1592            [2.795, 3.276, -2.801, 1.179, 0.219, 0.612],
1593            [2.259, 0.118, 4.506, 0.979, 0.045, 0.601],
1594            [-0.296, 6.288, 3.544, -4.282, 0.544, -0.896],
1595            [0.573, 4.467, -0.461, -2.613, -3.85, -0.96],
1596            [3.978, -0.962, -6.544, 1.154, 0.133, -2.229],
1597        ];
1598
1599        let mut solution: Array2<Complex<f64>> = Array2::zeros(solution_re.raw_dim());
1600        for (s, (s_re, s_im)) in solution
1601            .iter_mut()
1602            .zip(solution_re.iter().zip(solution_im.iter()))
1603        {
1604            s.re = *s_re;
1605            s.im = *s_im;
1606        }
1607
1608        // Setup
1609        let mut v = test_matrix_complex_f();
1610        let v_copy = v.clone();
1611        let (_, ny) = (v.shape()[0], v.shape()[1]);
1612        let mut handler: FftHandler<f64> = FftHandler::new(ny);
1613
1614        // Transform
1615        ndfft_inplace(&mut v, &mut handler, 1);
1616        approx_eq_complex(&v, &solution);
1617
1618        ndifft_inplace(&mut v, &mut handler, 1);
1619        approx_eq_complex(&v, &v_copy);
1620    }
1621
1622    #[test]
1623    fn test_fft_r2c() {
1624        // Solution from np.fft.rfft
1625        let solution_re = array![
1626            [0.61, 0.543, -0.572, 0.048],
1627            [2.795, 1.944, -1.291, 1.179],
1628            [2.259, 0.36, 2.275, 0.979],
1629            [-0.296, 2.696, 2.044, -4.282],
1630            [0.573, 1.753, -2.155, -2.613],
1631            [3.978, -1.596, -3.205, 1.154],
1632        ];
1633
1634        let solution_im = array![
1635            [0., -2.562, -3.08, 0.],
1636            [0., 1.332, -1.51, 0.],
1637            [0., -0.242, 2.23, 0.],
1638            [0., 3.592, 1.5, 0.],
1639            [0., 2.713, 1.695, 0.],
1640            [0., 0.633, -3.339, 0.],
1641        ];
1642
1643        let mut solution: Array2<Complex<f64>> = Array2::zeros(solution_re.raw_dim());
1644        for (s, (s_re, s_im)) in solution
1645            .iter_mut()
1646            .zip(solution_re.iter().zip(solution_im.iter()))
1647        {
1648            s.re = *s_re;
1649            s.im = *s_im;
1650        }
1651
1652        // Setup
1653        let mut v = test_matrix();
1654        let v_copy = v.clone();
1655        let (nx, ny) = (v.shape()[0], v.shape()[1]);
1656        let mut vhat = Array2::<Complex<f64>>::zeros((nx, ny / 2 + 1));
1657        let mut handler = R2cFftHandler::<f64>::new(ny);
1658
1659        // Transform
1660        ndfft_r2c(&v, &mut vhat, &mut handler, 1);
1661        ndifft_r2c(&vhat, &mut v, &mut handler, 1);
1662
1663        // Assert
1664        approx_eq_complex(&vhat, &solution);
1665        approx_eq(&v, &v_copy);
1666    }
1667
1668    #[cfg(feature = "parallel")]
1669    #[test]
1670    fn test_fft_r2c_par() {
1671        // Solution from np.fft.rfft
1672        let solution_re = array![
1673            [0.61, 0.543, -0.572, 0.048],
1674            [2.795, 1.944, -1.291, 1.179],
1675            [2.259, 0.36, 2.275, 0.979],
1676            [-0.296, 2.696, 2.044, -4.282],
1677            [0.573, 1.753, -2.155, -2.613],
1678            [3.978, -1.596, -3.205, 1.154],
1679        ];
1680
1681        let solution_im = array![
1682            [0., -2.562, -3.08, 0.],
1683            [0., 1.332, -1.51, 0.],
1684            [0., -0.242, 2.23, 0.],
1685            [0., 3.592, 1.5, 0.],
1686            [0., 2.713, 1.695, 0.],
1687            [0., 0.633, -3.339, 0.],
1688        ];
1689
1690        let mut solution: Array2<Complex<f64>> = Array2::zeros(solution_re.raw_dim());
1691        for (s, (s_re, s_im)) in solution
1692            .iter_mut()
1693            .zip(solution_re.iter().zip(solution_im.iter()))
1694        {
1695            s.re = *s_re;
1696            s.im = *s_im;
1697        }
1698
1699        // Setup
1700        let mut v = test_matrix();
1701        let v_copy = v.clone();
1702        let (nx, ny) = (v.shape()[0], v.shape()[1]);
1703        let mut vhat = Array2::<Complex<f64>>::zeros((nx, ny / 2 + 1));
1704        let mut handler = R2cFftHandler::<f64>::new(ny);
1705
1706        // Transform
1707        ndfft_r2c_par(&v, &mut vhat, &mut handler, 1);
1708        ndifft_r2c_par(&vhat, &mut v, &mut handler, 1);
1709
1710        // Assert
1711        approx_eq_complex(&vhat, &solution);
1712        approx_eq(&v, &v_copy);
1713    }
1714
1715    #[test]
1716    fn test_ifft_c2r_first_last_element() {
1717        let n = 6;
1718        let mut v = Array1::<f64>::zeros(n);
1719        let mut vhat = Array1::<Complex<f64>>::zeros(n / 2 + 1);
1720        let solution_numpy_first_elem: Array1<f64> =
1721            array![0.16667, 0.16667, 0.16667, 0.16667, 0.16667, 0.16667];
1722        let solution_numpy_last_elem: Array1<f64> =
1723            array![0.16667, -0.16667, 0.16667, -0.16667, 0.16667, -0.16667];
1724        let mut rfft_handler = R2cFftHandler::<f64>::new(n);
1725
1726        // First element should be purely real, thus the imaginary
1727        // part should not matter. However, original realfft gives
1728        // different results for different imaginary parts
1729        vhat[0].re = 1.;
1730        vhat[0].im = 100.;
1731        // backward
1732        ndifft_r2c(&vhat, &mut v, &mut rfft_handler, 0);
1733        // assert
1734        approx_eq(&v, &solution_numpy_first_elem);
1735
1736        // Same for last element, if input is even
1737        for v in vhat.iter_mut() {
1738            v.re = 0.;
1739            v.im = 0.;
1740        }
1741        vhat[3].re = 1.;
1742        vhat[3].im = 100.;
1743        // backward
1744        ndifft_r2c(&vhat, &mut v, &mut rfft_handler, 0);
1745        // assert
1746        approx_eq(&v, &solution_numpy_last_elem);
1747    }
1748
1749    #[test]
1750    fn test_fft_r2c_odd() {
1751        // Setup
1752        let mut v = array![[1., 2., 3.], [4., 5., 6.], [7., 8., 9.],];
1753        let v_copy = v.clone();
1754        let (nx, ny) = (v.shape()[0], v.shape()[1]);
1755        let mut vhat = Array2::<Complex<f64>>::zeros((nx, ny / 2 + 1));
1756        let mut handler = R2cFftHandler::<f64>::new(ny);
1757
1758        // Transform
1759        ndfft_r2c(&v, &mut vhat, &mut handler, 1);
1760        ndifft_r2c(&vhat, &mut v, &mut handler, 1);
1761
1762        // Assert
1763        approx_eq(&v, &v_copy);
1764    }
1765
1766    #[cfg(feature = "parallel")]
1767    #[test]
1768    fn test_fft_r2c_odd_par() {
1769        // Setup
1770        let mut v = array![[1., 2., 3.], [4., 5., 6.], [7., 8., 9.],];
1771        let v_copy = v.clone();
1772        let (nx, ny) = (v.shape()[0], v.shape()[1]);
1773        let mut vhat = Array2::<Complex<f64>>::zeros((nx, ny / 2 + 1));
1774        let mut handler = R2cFftHandler::<f64>::new(ny);
1775
1776        // Transform
1777        ndfft_r2c_par(&v, &mut vhat, &mut handler, 1);
1778        ndifft_r2c_par(&vhat, &mut v, &mut handler, 1);
1779
1780        // Assert
1781        approx_eq(&v, &v_copy);
1782    }
1783
1784    #[test]
1785    fn test_dct1() {
1786        // Solution from scipy.fft.dct(x, type=1)
1787        let solution = array![
1788            [2.469, 4.259, 0.6, 0.04, -4.957, -1.353],
1789            [3.953, -0.374, 4.759, -0.436, -2.643, 2.235],
1790            [2.632, 0.818, -1.609, 1.053, 5.008, 1.008],
1791            [-3.652, -2.628, 4.81, 2.632, 4.666, -7.138],
1792            [-0.835, -2.982, 4.105, -3.192, 1.265, -2.297],
1793            [8.743, -2.422, 1.167, -0.841, -7.506, 3.011],
1794        ];
1795
1796        // Setup
1797        let v = test_matrix();
1798        let (nx, ny) = (v.shape()[0], v.shape()[1]);
1799        let mut vhat = Array2::<f64>::zeros((nx, ny));
1800        let mut handler: DctHandler<f64> = DctHandler::new(ny);
1801
1802        // Transform
1803        nddct1(&v, &mut vhat, &mut handler, 1);
1804
1805        // Assert
1806        approx_eq(&vhat, &solution);
1807    }
1808
1809    #[test]
1810    fn test_dct1_inplace() {
1811        // Solution from scipy.fft.dct(x, type=1)
1812        let solution = array![
1813            [2.469, 4.259, 0.6, 0.04, -4.957, -1.353],
1814            [3.953, -0.374, 4.759, -0.436, -2.643, 2.235],
1815            [2.632, 0.818, -1.609, 1.053, 5.008, 1.008],
1816            [-3.652, -2.628, 4.81, 2.632, 4.666, -7.138],
1817            [-0.835, -2.982, 4.105, -3.192, 1.265, -2.297],
1818            [8.743, -2.422, 1.167, -0.841, -7.506, 3.011],
1819        ];
1820
1821        // Setup
1822        let mut v = test_matrix();
1823        let (_, ny) = (v.shape()[0], v.shape()[1]);
1824        let mut handler: DctHandler<f64> = DctHandler::new(ny);
1825
1826        // Transform
1827        nddct1_inplace(&mut v, &mut handler, 1);
1828
1829        // Assert
1830        approx_eq(&v, &solution);
1831    }
1832
1833    #[cfg(feature = "parallel")]
1834    #[test]
1835    fn test_dct1_par() {
1836        // Solution from scipy.fft.dct(x, type=1)
1837        let solution = array![
1838            [2.469, 4.259, 0.6, 0.04, -4.957, -1.353],
1839            [3.953, -0.374, 4.759, -0.436, -2.643, 2.235],
1840            [2.632, 0.818, -1.609, 1.053, 5.008, 1.008],
1841            [-3.652, -2.628, 4.81, 2.632, 4.666, -7.138],
1842            [-0.835, -2.982, 4.105, -3.192, 1.265, -2.297],
1843            [8.743, -2.422, 1.167, -0.841, -7.506, 3.011],
1844        ];
1845
1846        // Setup
1847        let v = test_matrix();
1848        let (nx, ny) = (v.shape()[0], v.shape()[1]);
1849        let mut vhat = Array2::<f64>::zeros((nx, ny));
1850        let mut handler: DctHandler<f64> = DctHandler::new(ny);
1851
1852        // Transform
1853        nddct1_par(&v, &mut vhat, &mut handler, 1);
1854
1855        // Assert
1856        approx_eq(&vhat, &solution);
1857    }
1858
1859    #[cfg(feature = "parallel")]
1860    #[test]
1861    fn test_dct1_inplace_par() {
1862        // Solution from scipy.fft.dct(x, type=1)
1863        let solution = array![
1864            [2.469, 4.259, 0.6, 0.04, -4.957, -1.353],
1865            [3.953, -0.374, 4.759, -0.436, -2.643, 2.235],
1866            [2.632, 0.818, -1.609, 1.053, 5.008, 1.008],
1867            [-3.652, -2.628, 4.81, 2.632, 4.666, -7.138],
1868            [-0.835, -2.982, 4.105, -3.192, 1.265, -2.297],
1869            [8.743, -2.422, 1.167, -0.841, -7.506, 3.011],
1870        ];
1871
1872        // Setup
1873        let mut v = test_matrix();
1874        let (_, ny) = (v.shape()[0], v.shape()[1]);
1875        let mut handler: DctHandler<f64> = DctHandler::new(ny);
1876
1877        // Transform
1878        nddct1_inplace(&mut v, &mut handler, 1);
1879
1880        // Assert
1881        approx_eq(&v, &solution);
1882    }
1883
1884    #[test]
1885    fn test_dct2() {
1886        // Solution from scipy.fft.dct(x, type=2)
1887        let solution = array![
1888            [1.22, 5.25, -1.621, -0.619, -5.906, -1.105],
1889            [5.59, -0.209, 4.699, 0.134, -3.907, 1.838],
1890            [4.518, 1.721, 0.381, 1.492, 6.138, 0.513],
1891            [-0.592, -3.746, 8.262, 1.31, 4.642, -6.125],
1892            [1.146, -5.709, 5.75, -4.275, 0.78, -0.963],
1893            [7.956, -2.873, -2.13, 0.006, -8.988, 2.56],
1894        ];
1895
1896        // Setup
1897        let v = test_matrix();
1898        let (nx, ny) = (v.shape()[0], v.shape()[1]);
1899        let mut vhat = Array2::<f64>::zeros((nx, ny));
1900        let mut handler: DctHandler<f64> = DctHandler::new(ny);
1901
1902        // Transform
1903        nddct2(&v, &mut vhat, &mut handler, 1);
1904
1905        // Assert
1906        approx_eq(&vhat, &solution);
1907    }
1908
1909    #[cfg(feature = "parallel")]
1910    #[test]
1911    fn test_dct2_par() {
1912        // Solution from scipy.fft.dct(x, type=2)
1913        let solution = array![
1914            [1.22, 5.25, -1.621, -0.619, -5.906, -1.105],
1915            [5.59, -0.209, 4.699, 0.134, -3.907, 1.838],
1916            [4.518, 1.721, 0.381, 1.492, 6.138, 0.513],
1917            [-0.592, -3.746, 8.262, 1.31, 4.642, -6.125],
1918            [1.146, -5.709, 5.75, -4.275, 0.78, -0.963],
1919            [7.956, -2.873, -2.13, 0.006, -8.988, 2.56],
1920        ];
1921
1922        // Setup
1923        let v = test_matrix();
1924        let (nx, ny) = (v.shape()[0], v.shape()[1]);
1925        let mut vhat = Array2::<f64>::zeros((nx, ny));
1926        let mut handler: DctHandler<f64> = DctHandler::new(ny);
1927
1928        // Transform
1929        nddct2_par(&v, &mut vhat, &mut handler, 1);
1930
1931        // Assert
1932        approx_eq(&vhat, &solution);
1933    }
1934
1935    #[cfg(feature = "parallel")]
1936    #[test]
1937    fn test_dct2_inplace_par() {
1938        // Solution from scipy.fft.dct(x, type=2)
1939        let solution = array![
1940            [1.22, 5.25, -1.621, -0.619, -5.906, -1.105],
1941            [5.59, -0.209, 4.699, 0.134, -3.907, 1.838],
1942            [4.518, 1.721, 0.381, 1.492, 6.138, 0.513],
1943            [-0.592, -3.746, 8.262, 1.31, 4.642, -6.125],
1944            [1.146, -5.709, 5.75, -4.275, 0.78, -0.963],
1945            [7.956, -2.873, -2.13, 0.006, -8.988, 2.56],
1946        ];
1947
1948        // Setup
1949        let mut v = test_matrix();
1950        let (_, ny) = (v.shape()[0], v.shape()[1]);
1951        let mut handler: DctHandler<f64> = DctHandler::new(ny);
1952
1953        // Transform
1954        nddct2_inplace_par(&mut v, &mut handler, 1);
1955
1956        // Assert
1957        approx_eq(&v, &solution);
1958    }
1959
1960    #[cfg(feature = "parallel")]
1961    #[test]
1962    fn test_dct2_inplace_par_axis0() {
1963        // Solution from scipy.fft.dct(x, type=2)
1964        let solution = array![
1965            [1.22, 5.25, -1.621, -0.619, -5.906, -1.105],
1966            [5.59, -0.209, 4.699, 0.134, -3.907, 1.838],
1967            [4.518, 1.721, 0.381, 1.492, 6.138, 0.513],
1968            [-0.592, -3.746, 8.262, 1.31, 4.642, -6.125],
1969            [1.146, -5.709, 5.75, -4.275, 0.78, -0.963],
1970            [7.956, -2.873, -2.13, 0.006, -8.988, 2.56],
1971        ]
1972        .t()
1973        .as_standard_layout()
1974        .to_owned();
1975
1976        // Setup
1977        let mut v = test_matrix().t().as_standard_layout().to_owned();
1978        let (_, ny) = (v.shape()[0], v.shape()[1]);
1979        let mut handler: DctHandler<f64> = DctHandler::new(ny);
1980
1981        // Transform
1982        nddct2_inplace_par(&mut v, &mut handler, 0);
1983
1984        // Assert
1985        approx_eq(&v, &solution);
1986    }
1987
1988    #[test]
1989    fn test_dct3() {
1990        // Solution from scipy.fft.dct(x, type=3)
1991        let solution = array![
1992            [2.898, 4.571, -0.801, 1.65, -5.427, -2.291],
1993            [2.701, -0.578, 5.768, -0.335, -3.158, 0.882],
1994            [2.348, -0.184, -1.258, 0.048, 5.472, 2.081],
1995            [-3.421, -2.075, 6.944, 0.264, 7.505, -4.315],
1996            [-1.43, -3.023, 6.317, -5.259, 1.991, -1.44],
1997            [5.76, -4.047, 1.974, 0.376, -8.651, 0.117],
1998        ];
1999
2000        // Setup
2001        let v = test_matrix();
2002        let (nx, ny) = (v.shape()[0], v.shape()[1]);
2003        let mut vhat = Array2::<f64>::zeros((nx, ny));
2004        let mut handler: DctHandler<f64> = DctHandler::new(ny);
2005
2006        // Transform
2007        nddct3(&v, &mut vhat, &mut handler, 1);
2008
2009        // Assert
2010        approx_eq(&vhat, &solution);
2011    }
2012
2013    #[cfg(feature = "parallel")]
2014    #[test]
2015    // See: https://github.com/preiter93/ndrustfft/issues/25
2016    fn test_dct2_3d_inplace_par_vs_serial_axis0() {
2017        use ndarray::Array3;
2018
2019        let (length, height, width) = (3, 3, 3);
2020        let handler = DctHandler::new(length).normalization(Normalization::None);
2021
2022        let mut v_par =
2023            Array3::from_shape_fn((length, height, width), |(i, j, k)| (i + j + k) as f64);
2024        let mut v_ser = v_par.clone();
2025
2026        nddct2_inplace_par(&mut v_par, &handler, 0);
2027        nddct2_inplace(&mut v_ser, &handler, 0);
2028
2029        approx_eq(&v_par, &v_ser);
2030    }
2031
2032    #[cfg(feature = "parallel")]
2033    #[test]
2034    fn test_dct3_par() {
2035        // Solution from scipy.fft.dct(x, type=3)
2036        let solution = array![
2037            [2.898, 4.571, -0.801, 1.65, -5.427, -2.291],
2038            [2.701, -0.578, 5.768, -0.335, -3.158, 0.882],
2039            [2.348, -0.184, -1.258, 0.048, 5.472, 2.081],
2040            [-3.421, -2.075, 6.944, 0.264, 7.505, -4.315],
2041            [-1.43, -3.023, 6.317, -5.259, 1.991, -1.44],
2042            [5.76, -4.047, 1.974, 0.376, -8.651, 0.117],
2043        ];
2044
2045        // Setup
2046        let v = test_matrix();
2047        let (nx, ny) = (v.shape()[0], v.shape()[1]);
2048        let mut vhat = Array2::<f64>::zeros((nx, ny));
2049        let mut handler: DctHandler<f64> = DctHandler::new(ny);
2050
2051        // Transform
2052        nddct3_par(&v, &mut vhat, &mut handler, 1);
2053
2054        // Assert
2055        approx_eq(&vhat, &solution);
2056    }
2057
2058    #[test]
2059    fn test_dct4() {
2060        // Solution from scipy.fft.dct(x, type=4)
2061        let solution = array![
2062            [3.18, 2.73, -2.314, -2.007, -5.996, 2.127],
2063            [3.175, 0.865, 4.939, -4.305, -0.443, 1.568],
2064            [3.537, 0.677, 0.371, 4.186, 4.528, -1.531],
2065            [-2.687, 1.838, 6.968, 0.899, 2.456, -8.79],
2066            [-2.289, -1.002, 3.67, -5.705, 3.867, -4.349],
2067            [4.192, -5.626, 1.789, -6.057, -4.61, 4.627],
2068        ];
2069
2070        // Setup
2071        let v = test_matrix();
2072        let (nx, ny) = (v.shape()[0], v.shape()[1]);
2073        let mut vhat = Array2::<f64>::zeros((nx, ny));
2074        let mut handler: DctHandler<f64> = DctHandler::new(ny);
2075
2076        // Transform
2077        nddct4(&v, &mut vhat, &mut handler, 1);
2078
2079        // Assert
2080        approx_eq(&vhat, &solution);
2081    }
2082
2083    #[cfg(feature = "parallel")]
2084    #[test]
2085    fn test_dct4_par() {
2086        // Solution from scipy.fft.dct(x, type=4)
2087        let solution = array![
2088            [3.18, 2.73, -2.314, -2.007, -5.996, 2.127],
2089            [3.175, 0.865, 4.939, -4.305, -0.443, 1.568],
2090            [3.537, 0.677, 0.371, 4.186, 4.528, -1.531],
2091            [-2.687, 1.838, 6.968, 0.899, 2.456, -8.79],
2092            [-2.289, -1.002, 3.67, -5.705, 3.867, -4.349],
2093            [4.192, -5.626, 1.789, -6.057, -4.61, 4.627],
2094        ];
2095
2096        // Setup
2097        let v = test_matrix();
2098        let (nx, ny) = (v.shape()[0], v.shape()[1]);
2099        let mut vhat = Array2::<f64>::zeros((nx, ny));
2100        let mut handler: DctHandler<f64> = DctHandler::new(ny);
2101
2102        // Transform
2103        nddct4_par(&v, &mut vhat, &mut handler, 1);
2104
2105        // Assert
2106        approx_eq(&vhat, &solution);
2107    }
2108}