rustfft/algorithm/
dft.rs

1use num_complex::Complex;
2use num_traits::Zero;
3
4use crate::{twiddles, FftDirection};
5use crate::{Direction, Fft, FftNum, Length};
6
7/// Naive O(n^2 ) Discrete Fourier Transform implementation
8///
9/// This implementation is primarily used to test other FFT algorithms.
10///
11/// ~~~
12/// // Computes a naive DFT of size 123
13/// use rustfft::algorithm::Dft;
14/// use rustfft::{Fft, FftDirection};
15/// use rustfft::num_complex::Complex;
16///
17/// let mut buffer = vec![Complex{ re: 0.0f32, im: 0.0f32 }; 123];
18///
19/// let dft = Dft::new(123, FftDirection::Forward);
20/// dft.process(&mut buffer);
21/// ~~~
22pub struct Dft<T> {
23    twiddles: Vec<Complex<T>>,
24    direction: FftDirection,
25}
26
27impl<T: FftNum> Dft<T> {
28    /// Preallocates necessary arrays and precomputes necessary data to efficiently compute Dft
29    pub fn new(len: usize, direction: FftDirection) -> Self {
30        let twiddles = (0..len)
31            .map(|i| twiddles::compute_twiddle(i, len, direction))
32            .collect();
33        Self {
34            twiddles,
35            direction,
36        }
37    }
38
39    fn inplace_scratch_len(&self) -> usize {
40        self.len()
41    }
42    fn outofplace_scratch_len(&self) -> usize {
43        0
44    }
45    fn immut_scratch_len(&self) -> usize {
46        0
47    }
48
49    fn perform_fft_immut(
50        &self,
51        signal: &[Complex<T>],
52        spectrum: &mut [Complex<T>],
53        _scratch: &mut [Complex<T>],
54    ) {
55        for k in 0..spectrum.len() {
56            let output_cell = spectrum.get_mut(k).unwrap();
57
58            *output_cell = Zero::zero();
59            let mut twiddle_index = 0;
60
61            for input_cell in signal {
62                let twiddle = self.twiddles[twiddle_index];
63                *output_cell = *output_cell + twiddle * input_cell;
64
65                twiddle_index += k;
66                if twiddle_index >= self.twiddles.len() {
67                    twiddle_index -= self.twiddles.len();
68                }
69            }
70        }
71    }
72
73    fn perform_fft_out_of_place(
74        &self,
75        signal: &[Complex<T>],
76        spectrum: &mut [Complex<T>],
77        _scratch: &mut [Complex<T>],
78    ) {
79        self.perform_fft_immut(signal, spectrum, _scratch);
80    }
81}
82boilerplate_fft_oop!(Dft, |this: &Dft<_>| this.twiddles.len());
83
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::test_utils::{compare_vectors, random_signal};
    use num_complex::Complex;
    use num_traits::Zero;
    use std::f32;

    /// Reference DFT computed straight from the definition
    /// `X[k] = Σ_i x[i] · e^(-2πi·ik/N)`. Each twiddle is evaluated independently with
    /// `from_polar`, so it shares no twiddle table or index arithmetic with `Dft`.
    fn dft(signal: &[Complex<f32>], spectrum: &mut [Complex<f32>]) {
        for (k, spec_bin) in spectrum.iter_mut().enumerate() {
            let mut sum = Zero::zero();
            for (i, &x) in signal.iter().enumerate() {
                let angle = -1f32 * (i * k) as f32 * 2f32 * f32::consts::PI / signal.len() as f32;
                let twiddle = Complex::from_polar(1f32, angle);

                sum = sum + twiddle * x;
            }
            *spec_bin = sum;
        }
    }

    /// Runs `input` through every processing entry point of `dft_instance`: `process`,
    /// `process_with_scratch` (with clean and with dirty scratch) and
    /// `process_outofplace_with_scratch`. Asserts that each result matches
    /// `expected_output`, as judged by `compare_vectors`.
    ///
    /// `input.len()` may be a multiple of the instance length; the expected output must then
    /// hold one transform per consecutive chunk.
    fn assert_all_process_methods_match(
        dft_instance: &Dft<f32>,
        input: &[Complex<f32>],
        expected_output: &[Complex<f32>],
    ) {
        let len = dft_instance.len();

        // test process()
        {
            let mut inplace_buffer = input.to_vec();

            dft_instance.process(&mut inplace_buffer);

            assert!(
                compare_vectors(expected_output, &inplace_buffer),
                "process() failed, length = {}",
                len
            );
        }

        // test process_with_scratch()
        {
            let mut inplace_with_scratch_buffer = input.to_vec();
            let mut inplace_scratch = vec![Zero::zero(); dft_instance.get_inplace_scratch_len()];

            dft_instance
                .process_with_scratch(&mut inplace_with_scratch_buffer, &mut inplace_scratch);

            assert!(
                compare_vectors(expected_output, &inplace_with_scratch_buffer),
                "process_with_scratch() failed, length = {}",
                len
            );

            // one more thing: make sure that the Dft algorithm even works with dirty scratch space
            for item in inplace_scratch.iter_mut() {
                *item = Complex::new(100.0, 100.0);
            }
            inplace_with_scratch_buffer.copy_from_slice(input);

            dft_instance
                .process_with_scratch(&mut inplace_with_scratch_buffer, &mut inplace_scratch);

            assert!(
                compare_vectors(expected_output, &inplace_with_scratch_buffer),
                "process_with_scratch() failed the 'dirty scratch' test for len = {}",
                len
            );
        }

        // test process_outofplace_with_scratch()
        {
            let mut outofplace_input = input.to_vec();
            // Start from garbage rather than the expected answer, so that an implementation
            // which leaves the output untouched (or accumulates into it) cannot pass.
            let mut outofplace_output: Vec<Complex<f32>> =
                vec![Complex::new(100.0f32, 100.0f32); input.len()];

            dft_instance.process_outofplace_with_scratch(
                &mut outofplace_input,
                &mut outofplace_output,
                &mut [],
            );

            assert!(
                compare_vectors(expected_output, &outofplace_output),
                "process_outofplace_with_scratch() failed, length = {}",
                len
            );
        }
    }

    #[test]
    fn test_matches_dft() {
        // Number of back-to-back transforms packed into each test buffer, so multi-chunk
        // processing is exercised too.
        let n = 4;

        for len in 1..20 {
            let dft_instance = Dft::new(len, FftDirection::Forward);
            assert_eq!(
                dft_instance.len(),
                len,
                "Dft instance reported incorrect length"
            );

            let input = random_signal(len * n);
            let mut expected_output = input.clone();

            // Compute the control data using our simplified Dft definition
            for (input_chunk, output_chunk) in
                input.chunks(len).zip(expected_output.chunks_mut(len))
            {
                dft(input_chunk, output_chunk);
            }

            assert_all_process_methods_match(&dft_instance, &input, &expected_output);
        }

        // verify that it doesn't crash or infinite loop if we have a length of 0
        let zero_dft = Dft::new(0, FftDirection::Forward);
        let mut zero_input: Vec<Complex<f32>> = Vec::new();
        let mut zero_output: Vec<Complex<f32>> = Vec::new();
        let mut zero_scratch: Vec<Complex<f32>> = Vec::new();

        assert_eq!(zero_dft.len(), 0, "Dft instance reported incorrect length");

        zero_dft.process(&mut zero_input);
        zero_dft.process_with_scratch(&mut zero_input, &mut zero_scratch);
        zero_dft.process_outofplace_with_scratch(
            &mut zero_input,
            &mut zero_output,
            &mut zero_scratch,
        );
    }

    /// Asserts that the reference `dft` function and rustfft's `Dft` struct (through every
    /// processing entry point) both transform `input` into `expected_output`, as judged by
    /// `compare_vectors`.
    fn test_dft_correct(input: &[Complex<f32>], expected_output: &[Complex<f32>]) {
        assert_eq!(input.len(), expected_output.len());
        let len = input.len();

        let mut reference_output = vec![Zero::zero(); len];
        dft(input, &mut reference_output);
        assert!(
            compare_vectors(expected_output, &reference_output),
            "Reference implementation failed for len={}",
            len
        );

        let dft_instance = Dft::new(len, FftDirection::Forward);
        assert_all_process_methods_match(&dft_instance, input, expected_output);
    }

    #[test]
    fn test_dft_known_len_2() {
        let signal = [
            Complex { re: 1f32, im: 0f32 },
            Complex {
                re: -1f32,
                im: 0f32,
            },
        ];
        let spectrum = [
            Complex { re: 0f32, im: 0f32 },
            Complex { re: 2f32, im: 0f32 },
        ];
        test_dft_correct(&signal[..], &spectrum[..]);
    }

    #[test]
    fn test_dft_known_len_3() {
        let signal = [
            Complex { re: 1f32, im: 1f32 },
            Complex {
                re: 2f32,
                im: -3f32,
            },
            Complex {
                re: -1f32,
                im: 4f32,
            },
        ];
        let spectrum = [
            Complex { re: 2f32, im: 2f32 },
            Complex {
                re: -5.562177f32,
                im: -2.098076f32,
            },
            Complex {
                re: 6.562178f32,
                im: 3.09807f32,
            },
        ];
        test_dft_correct(&signal[..], &spectrum[..]);
    }

    #[test]
    fn test_dft_known_len_4() {
        let signal = [
            Complex { re: 0f32, im: 1f32 },
            Complex {
                re: 2.5f32,
                im: -3f32,
            },
            Complex {
                re: -1f32,
                im: -1f32,
            },
            Complex { re: 4f32, im: 0f32 },
        ];
        let spectrum = [
            Complex {
                re: 5.5f32,
                im: -3f32,
            },
            Complex {
                re: -2f32,
                im: 3.5f32,
            },
            Complex {
                re: -7.5f32,
                im: 3f32,
            },
            Complex {
                re: 4f32,
                im: 0.5f32,
            },
        ];
        test_dft_correct(&signal[..], &spectrum[..]);
    }

    #[test]
    fn test_dft_known_len_6() {
        let signal = [
            Complex { re: 1f32, im: 1f32 },
            Complex { re: 2f32, im: 2f32 },
            Complex { re: 3f32, im: 3f32 },
            Complex { re: 4f32, im: 4f32 },
            Complex { re: 5f32, im: 5f32 },
            Complex { re: 6f32, im: 6f32 },
        ];
        let spectrum = [
            Complex {
                re: 21f32,
                im: 21f32,
            },
            Complex {
                re: -8.16f32,
                im: 2.16f32,
            },
            Complex {
                re: -4.76f32,
                im: -1.24f32,
            },
            Complex {
                re: -3f32,
                im: -3f32,
            },
            Complex {
                re: -1.24f32,
                im: -4.76f32,
            },
            Complex {
                re: 2.16f32,
                im: -8.16f32,
            },
        ];
        test_dft_correct(&signal[..], &spectrum[..]);
    }
}