1use num_complex::Complex;
2use num_traits::Zero;
3
4use crate::{twiddles, FftDirection};
5use crate::{Direction, Fft, FftNum, Length};
6
7pub struct Dft<T> {
23 twiddles: Vec<Complex<T>>,
24 direction: FftDirection,
25}
26
27impl<T: FftNum> Dft<T> {
28 pub fn new(len: usize, direction: FftDirection) -> Self {
30 let twiddles = (0..len)
31 .map(|i| twiddles::compute_twiddle(i, len, direction))
32 .collect();
33 Self {
34 twiddles,
35 direction,
36 }
37 }
38
39 fn inplace_scratch_len(&self) -> usize {
40 self.len()
41 }
42 fn outofplace_scratch_len(&self) -> usize {
43 0
44 }
45 fn immut_scratch_len(&self) -> usize {
46 0
47 }
48
49 fn perform_fft_immut(
50 &self,
51 signal: &[Complex<T>],
52 spectrum: &mut [Complex<T>],
53 _scratch: &mut [Complex<T>],
54 ) {
55 for k in 0..spectrum.len() {
56 let output_cell = spectrum.get_mut(k).unwrap();
57
58 *output_cell = Zero::zero();
59 let mut twiddle_index = 0;
60
61 for input_cell in signal {
62 let twiddle = self.twiddles[twiddle_index];
63 *output_cell = *output_cell + twiddle * input_cell;
64
65 twiddle_index += k;
66 if twiddle_index >= self.twiddles.len() {
67 twiddle_index -= self.twiddles.len();
68 }
69 }
70 }
71 }
72
73 fn perform_fft_out_of_place(
74 &self,
75 signal: &[Complex<T>],
76 spectrum: &mut [Complex<T>],
77 _scratch: &mut [Complex<T>],
78 ) {
79 self.perform_fft_immut(signal, spectrum, _scratch);
80 }
81}
82boilerplate_fft_oop!(Dft, |this: &Dft<_>| this.twiddles.len());
83
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::test_utils::{compare_vectors, random_signal};
    use num_complex::Complex;
    use num_traits::Zero;
    use std::f32::consts::PI;

    /// Reference forward DFT that computes every twiddle from scratch with `from_polar`.
    /// Writes one output bin per element of `spectrum`.
    fn dft(signal: &[Complex<f32>], spectrum: &mut [Complex<f32>]) {
        for (k, spec_bin) in spectrum.iter_mut().enumerate() {
            let mut sum: Complex<f32> = Zero::zero();
            for (i, &x) in signal.iter().enumerate() {
                let angle = -1f32 * (i * k) as f32 * 2f32 * PI / signal.len() as f32;
                let twiddle = Complex::from_polar(1f32, angle);

                sum = sum + twiddle * x;
            }
            *spec_bin = sum;
        }
    }

    /// Runs `process`, `process_with_scratch` (with clean and with dirty scratch) and
    /// `process_outofplace_with_scratch` on copies of `input`, asserting each result matches
    /// `expected_output`.
    fn check_all_entry_points(
        dft_instance: &Dft<f32>,
        input: &[Complex<f32>],
        expected_output: &[Complex<f32>],
    ) {
        let len = dft_instance.len();

        // process()
        let mut inplace_buffer = input.to_vec();
        dft_instance.process(&mut inplace_buffer);
        assert!(
            compare_vectors(expected_output, &inplace_buffer),
            "process() failed, length = {}",
            len
        );

        // process_with_scratch() with freshly zeroed scratch.
        let mut inplace_with_scratch_buffer = input.to_vec();
        let mut inplace_scratch: Vec<Complex<f32>> =
            vec![Zero::zero(); dft_instance.get_inplace_scratch_len()];
        dft_instance.process_with_scratch(&mut inplace_with_scratch_buffer, &mut inplace_scratch);
        assert!(
            compare_vectors(expected_output, &inplace_with_scratch_buffer),
            "process_with_scratch() failed, length = {}",
            len
        );

        // process_with_scratch() again with garbage in the scratch buffer, to catch any
        // algorithm that reads scratch before writing it.
        for item in inplace_scratch.iter_mut() {
            *item = Complex::new(100.0, 100.0);
        }
        inplace_with_scratch_buffer.copy_from_slice(input);
        dft_instance.process_with_scratch(&mut inplace_with_scratch_buffer, &mut inplace_scratch);
        assert!(
            compare_vectors(expected_output, &inplace_with_scratch_buffer),
            "process_with_scratch() failed the 'dirty scratch' test for len = {}",
            len
        );

        // process_outofplace_with_scratch(). The output starts zeroed, not pre-filled with the
        // expected answer, so an implementation that never writes the output cannot pass.
        let mut outofplace_input = input.to_vec();
        let mut outofplace_output: Vec<Complex<f32>> = vec![Zero::zero(); input.len()];
        let mut outofplace_scratch: Vec<Complex<f32>> =
            vec![Zero::zero(); dft_instance.get_outofplace_scratch_len()];
        dft_instance.process_outofplace_with_scratch(
            &mut outofplace_input,
            &mut outofplace_output,
            &mut outofplace_scratch,
        );
        assert!(
            compare_vectors(expected_output, &outofplace_output),
            "process_outofplace_with_scratch() failed, length = {}",
            len
        );
    }

    #[test]
    fn test_matches_dft() {
        // Each test buffer holds several back-to-back FFT-length chunks.
        let num_chunks = 4;

        for len in 1..20 {
            let dft_instance = Dft::<f32>::new(len, FftDirection::Forward);
            assert_eq!(
                dft_instance.len(),
                len,
                "Dft instance reported incorrect length"
            );

            let input = random_signal(len * num_chunks);
            let mut expected_output = input.clone();
            for (input_chunk, output_chunk) in
                input.chunks(len).zip(expected_output.chunks_mut(len))
            {
                dft(input_chunk, output_chunk);
            }

            check_all_entry_points(&dft_instance, &input, &expected_output);
        }

        // A zero-length DFT must accept empty buffers on every entry point without panicking.
        let zero_dft = Dft::<f32>::new(0, FftDirection::Forward);
        let mut zero_input: Vec<Complex<f32>> = Vec::new();
        let mut zero_output: Vec<Complex<f32>> = Vec::new();
        let mut zero_scratch: Vec<Complex<f32>> = Vec::new();

        zero_dft.process(&mut zero_input);
        zero_dft.process_with_scratch(&mut zero_input, &mut zero_scratch);
        zero_dft.process_outofplace_with_scratch(
            &mut zero_input,
            &mut zero_output,
            &mut zero_scratch,
        );
    }

    /// Checks that both the reference `dft` and a `Dft` instance of length `input.len()`
    /// reproduce the hand-computed `expected_output`.
    fn test_dft_correct(input: &[Complex<f32>], expected_output: &[Complex<f32>]) {
        assert_eq!(input.len(), expected_output.len());
        let len = input.len();

        let mut reference_output: Vec<Complex<f32>> = vec![Zero::zero(); len];
        dft(input, &mut reference_output);
        assert!(
            compare_vectors(expected_output, &reference_output),
            "Reference implementation failed for len={}",
            len
        );

        let dft_instance = Dft::<f32>::new(len, FftDirection::Forward);
        check_all_entry_points(&dft_instance, input, expected_output);
    }

    #[test]
    fn test_dft_known_len_2() {
        let signal: [Complex<f32>; 2] = [Complex::new(1.0, 0.0), Complex::new(-1.0, 0.0)];
        let spectrum: [Complex<f32>; 2] = [Complex::new(0.0, 0.0), Complex::new(2.0, 0.0)];
        test_dft_correct(&signal, &spectrum);
    }

    #[test]
    fn test_dft_known_len_3() {
        let signal: [Complex<f32>; 3] = [
            Complex::new(1.0, 1.0),
            Complex::new(2.0, -3.0),
            Complex::new(-1.0, 4.0),
        ];
        let spectrum: [Complex<f32>; 3] = [
            Complex::new(2.0, 2.0),
            Complex::new(-5.562178, -2.098076),
            Complex::new(6.562178, 3.098076),
        ];
        test_dft_correct(&signal, &spectrum);
    }

    #[test]
    fn test_dft_known_len_4() {
        let signal: [Complex<f32>; 4] = [
            Complex::new(0.0, 1.0),
            Complex::new(2.5, -3.0),
            Complex::new(-1.0, -1.0),
            Complex::new(4.0, 0.0),
        ];
        let spectrum: [Complex<f32>; 4] = [
            Complex::new(5.5, -3.0),
            Complex::new(-2.0, 3.5),
            Complex::new(-7.5, 3.0),
            Complex::new(4.0, 0.5),
        ];
        test_dft_correct(&signal, &spectrum);
    }

    #[test]
    fn test_dft_known_len_6() {
        let signal: [Complex<f32>; 6] = [
            Complex::new(1.0, 1.0),
            Complex::new(2.0, 2.0),
            Complex::new(3.0, 3.0),
            Complex::new(4.0, 4.0),
            Complex::new(5.0, 5.0),
            Complex::new(6.0, 6.0),
        ];
        // Exact values: for k != 0 the k-th bin is 6(1+i) / (w^k - 1) with w = e^(-2*pi*i/6),
        // e.g. bin 1 = (-3 + 3*sqrt(3)i)(1 + i) = -8.196152 + 2.196152i.
        let spectrum: [Complex<f32>; 6] = [
            Complex::new(21.0, 21.0),
            Complex::new(-8.196152, 2.196152),
            Complex::new(-4.732051, -1.267949),
            Complex::new(-3.0, -3.0),
            Complex::new(-1.267949, -4.732051),
            Complex::new(2.196152, -8.196152),
        ];
        test_dft_correct(&signal, &spectrum);
    }
}