cute_dsp/
fft.rs

//! Fast Fourier Transform implementation
//!
//! This module provides simple, portable FFT implementations for power-of-two sizes.
//! It includes complex and real-input FFTs, each available with interleaved (`Complex<T>`)
//! and split (separate real/imaginary slices) data layouts.
5
6#![allow(unused_imports)]
7
8#[cfg(feature = "std")]
9use std::{f64::consts::PI, vec::Vec};
10
11
12#[cfg(not(feature = "std"))]
13use core::f64::consts::PI;
14
15#[cfg(all(not(feature = "std"), feature = "alloc"))]
16use alloc::vec::Vec;
17
18use num_complex::Complex;
19use num_traits::Float;
20use num_traits::FromPrimitive;
21
22use crate::perf;
23
24/// Helper functions for complex multiplication and data interleaving
25mod helpers {
26    use super::*;
27
28    /// Complex multiplication
29    pub fn complex_mul<T: Float>(
30        a: &mut [Complex<T>],
31        b: &[Complex<T>],
32        c: &[Complex<T>],
33        size: usize,
34    ) {
35        for i in 0..size {
36            let bi = b[i];
37            let ci = c[i];
38            a[i] = Complex::new(
39                bi.re * ci.re - bi.im * ci.im,
40                bi.im * ci.re + bi.re * ci.im,
41            );
42        }
43    }
44
45    /// Complex multiplication with conjugate of second argument
46    pub fn complex_mul_conj<T: Float>(
47        a: &mut [Complex<T>],
48        b: &[Complex<T>],
49        c: &[Complex<T>],
50        size: usize,
51    ) {
52        for i in 0..size {
53            let bi = b[i];
54            let ci = c[i];
55            a[i] = Complex::new(
56                bi.re * ci.re + bi.im * ci.im,
57                bi.im * ci.re - bi.re * ci.im,
58            );
59        }
60    }
61
62    /// Complex multiplication with split complex representation
63    pub fn complex_mul_split<T: Float>(
64        ar: &mut [T],
65        ai: &mut [T],
66        br: &[T],
67        bi: &[T],
68        cr: &[T],
69        ci: &[T],
70        size: usize,
71    ) {
72        for i in 0..size {
73            let rr = br[i] * cr[i] - bi[i] * ci[i];
74            let ri = br[i] * ci[i] + bi[i] * cr[i];
75            ar[i] = rr;
76            ai[i] = ri;
77        }
78    }
79
80    /// Complex multiplication with conjugate and split complex representation
81    pub fn complex_mul_conj_split<T: Float>(
82        ar: &mut [T],
83        ai: &mut [T],
84        br: &[T],
85        bi: &[T],
86        cr: &[T],
87        ci: &[T],
88        size: usize,
89    ) {
90        for i in 0..size {
91            let rr = cr[i] * br[i] + ci[i] * bi[i];
92            let ri = cr[i] * bi[i] - ci[i] * br[i];
93            ar[i] = rr;
94            ai[i] = ri;
95        }
96    }
97
98    /// Interleave copy with fixed stride
99    pub fn interleave_copy<T: Copy>(a: &[T], b: &mut [T], a_stride: usize, b_stride: usize) {
100        for bi in 0..b_stride {
101            for ai in 0..a_stride {
102                b[bi + ai * b_stride] = a[bi * a_stride + ai];
103            }
104        }
105    }
106
107    /// Interleave copy with split complex representation
108    pub fn interleave_copy_split<T: Copy>(
109        a_real: &[T],
110        a_imag: &[T],
111        b_real: &mut [T],
112        b_imag: &mut [T],
113        a_stride: usize,
114        b_stride: usize,
115    ) {
116        for bi in 0..b_stride {
117            for ai in 0..a_stride {
118                b_real[bi + ai * b_stride] = a_real[bi * a_stride + ai];
119                b_imag[bi + ai * b_stride] = a_imag[bi * a_stride + ai];
120            }
121        }
122    }
123}
124
125/// A simple and portable power-of-2 FFT implementation
126pub struct SimpleFFT<T: Float> {
127    twiddles: Vec<Complex<T>>,
128    working: Vec<Complex<T>>,
129}
130
131impl<T: Float + FromPrimitive> SimpleFFT<T> {
132    /// Create a new FFT with the specified size
133    pub fn new(size: usize) -> Self {
134        let mut result = Self {
135            twiddles: Vec::new(),
136            working: Vec::new(),
137        };
138        result.resize(size);
139        result
140    }
141
142    /// Resize the FFT to handle a different size
143    pub fn resize(&mut self, size: usize) {
144        self.twiddles.resize(size * 3 / 4, Complex::new(T::zero(), T::zero()));
145        for i in 0..self.twiddles.len() {
146            let twiddle_phase = -T::from_f64(2.0).unwrap() * T::from_f64(PI as f64).unwrap() * T::from_f64(i as f64).unwrap() / T::from_f64(size as f64).unwrap();
147            self.twiddles[i] = Complex::new(
148                twiddle_phase.cos(),
149                twiddle_phase.sin(),
150            );
151        }
152        self.working.resize(size, Complex::new(T::zero(), T::zero()));
153    }
154
155    /// Perform a forward FFT
156    pub fn fft(&mut self, time: &[Complex<T>], freq: &mut [Complex<T>]) {
157        let size = self.working.len();
158        if size <= 1 {
159            if size == 1 {
160                freq[0] = time[0];
161            }
162            return;
163        }
164        let working_size = self.working.len();
165            let working_mut = &mut self.working;
166            Self::fft_pass::<false>(
167                working_size,
168                &self.twiddles,
169                size, 
170                1, 
171                time, 
172                freq, 
173                working_mut);
174    }
175
176    /// Perform an inverse FFT
177    pub fn ifft(&mut self, freq: &[Complex<T>], time: &mut [Complex<T>]) {
178        let size = self.working.len();
179        if size <= 1 {
180            if size == 1 {
181                time[0] = freq[0];
182            }
183            return;
184        }
185        let working_size = self.working.len();
186        let working_mut = &mut self.working;
187        Self::fft_pass::<true>(working_size,&self.twiddles,size, 1, freq, time, working_mut);
188    }
189
190    /// Perform a forward FFT with split complex representation
191    pub fn fft_split(&self, in_r: &[T], in_i: &[T], out_r: &mut [T], out_i: &mut [T]) {
192        let size = self.working.len();
193        if size <= 1 {
194            if size == 1 {
195                out_r[0] = in_r[0];
196                out_i[0] = in_i[0];
197            }
198            return;
199        }
200        
201        // Create temporary buffers for working space
202        let mut working_r = vec![T::zero(); size];
203        let mut working_i = vec![T::zero(); size];
204        
205        self.fft_pass_split::<false>(size, 1, in_r, in_i, out_r, out_i, &mut working_r, &mut working_i);
206    }
207
208    /// Perform an inverse FFT with split complex representation
209    pub fn ifft_split(&self, in_r: &[T], in_i: &[T], out_r: &mut [T], out_i: &mut [T]) {
210        let size = self.working.len();
211        if size <= 1 {
212            if size == 1 {
213                out_r[0] = in_r[0];
214                out_i[0] = in_i[0];
215            }
216            return;
217        }
218        
219        // Create temporary buffers for working space
220        let mut working_r = vec![T::zero(); size];
221        let mut working_i = vec![T::zero(); size];
222        
223        self.fft_pass_split::<true>(size, 1, in_r, in_i, out_r, out_i, &mut working_r, &mut working_i);
224    }
225
226    // Internal implementation of FFT pass
227     fn fft_pass<const INVERSE: bool>(
228        orignal_working_size:usize,
229        twiddles: &[Complex<T>],
230        size: usize,
231        stride: usize,
232        input: &[Complex<T>],
233        output: &mut [Complex<T>],
234        working: &mut [Complex<T>],
235    ) {
236        if size / 4 > 1 {
237            // Calculate four quarter-size FFTs
238            Self::fft_pass::<INVERSE>(orignal_working_size,twiddles,size / 4, stride * 4, input, working, output);
239            Self::combine4::<INVERSE>(orignal_working_size,twiddles,size, stride, working, output);
240        } else if size == 4 {
241            Self::combine4::<INVERSE>(orignal_working_size,twiddles,4, stride, input, output);
242        } else {
243            // 2-point FFT
244            for s in 0..stride {
245                let b = input[s + stride];
246                let a = input[s];
247                output[s] = a + b;
248                output[s + stride] = a - b;
249            }
250        }
251    }
252
253    // Internal implementation of FFT pass with split complex representation
254    fn fft_pass_split<const INVERSE: bool>(
255        &self,
256        size: usize,
257        stride: usize,
258        in_r: &[T],
259        in_i: &[T],
260        out_r: &mut [T],
261        out_i: &mut [T],
262        working_r: &mut [T],
263        working_i: &mut [T],
264    ) {
265        if size / 4 > 1 {
266            // Calculate four quarter-size FFTs
267            self.fft_pass_split::<INVERSE>(
268                size / 4,
269                stride * 4,
270                in_r,
271                in_i,
272                working_r,
273                working_i,
274                out_r,
275                out_i,
276            );
277            self.combine4_split::<INVERSE>(size, stride, working_r, working_i, out_r, out_i);
278        } else if size == 4 {
279            self.combine4_split::<INVERSE>(4, stride, in_r, in_i, out_r, out_i);
280        } else {
281            // 2-point FFT
282            for s in 0..stride {
283                let ar = in_r[s];
284                let ai = in_i[s];
285                let br = in_r[s + stride];
286                let bi = in_i[s + stride];
287                out_r[s] = ar + br;
288                out_i[s] = ai + bi;
289                out_r[s + stride] = ar - br;
290                out_i[s + stride] = ai - bi;
291            }
292        }
293    }
294
295    // Combine interleaved results into a single spectrum
296    fn combine4<const INVERSE: bool>(
297        working_buf_len:usize,
298        twiddles: &[Complex<T>],
299        size: usize,
300        stride: usize,
301        input: &[Complex<T>],
302        output: &mut [Complex<T>],
303    ) {
304        let twiddle_step = working_buf_len / size;
305        
306        for i in 0..size / 4 {
307            let twiddle_b = twiddles[i * twiddle_step];
308            let twiddle_c = twiddles[i * 2 * twiddle_step];
309            let twiddle_d = twiddles[i * 3 * twiddle_step];
310            
311            let input_a = &input[4 * i * stride..];
312            let input_b = &input[(4 * i + 1) * stride..];
313            let input_c = &input[(4 * i + 2) * stride..];
314            let input_d = &input[(4 * i + 3) * stride..];
315            
316            let (output_first_half, output_second_half) = output.split_at_mut((size / 4 * 2) * stride);
317            let (output_a_chunk, output_b_chunk) = output_first_half.split_at_mut((size / 4) * stride);
318            let (output_c_chunk, output_d_chunk) = output_second_half.split_at_mut((size / 4) * stride);
319
320            let output_a = &mut output_a_chunk[i * stride..];
321            let output_b = &mut output_b_chunk[i * stride..];
322            let output_c = &mut output_c_chunk[i * stride..];
323            let output_d = &mut output_d_chunk[i * stride..];
324            
325            for s in 0..stride {
326                let a = input_a[s];
327                let b = if INVERSE {
328                    Complex::new(
329                        input_b[s].re * twiddle_b.re + input_b[s].im * twiddle_b.im,
330                        input_b[s].im * twiddle_b.re - input_b[s].re * twiddle_b.im,
331                    )
332                } else {
333                    Complex::new(
334                        input_b[s].re * twiddle_b.re - input_b[s].im * twiddle_b.im,
335                        input_b[s].im * twiddle_b.re + input_b[s].re * twiddle_b.im,
336                    )
337                };
338                
339                let c = if INVERSE {
340                    Complex::new(
341                        input_c[s].re * twiddle_c.re + input_c[s].im * twiddle_c.im,
342                        input_c[s].im * twiddle_c.re - input_c[s].re * twiddle_c.im,
343                    )
344                } else {
345                    Complex::new(
346                        input_c[s].re * twiddle_c.re - input_c[s].im * twiddle_c.im,
347                        input_c[s].im * twiddle_c.re + input_c[s].re * twiddle_c.im,
348                    )
349                };
350                
351                let d = if INVERSE {
352                    Complex::new(
353                        input_d[s].re * twiddle_d.re + input_d[s].im * twiddle_d.im,
354                        input_d[s].im * twiddle_d.re - input_d[s].re * twiddle_d.im,
355                    )
356                } else {
357                    Complex::new(
358                        input_d[s].re * twiddle_d.re - input_d[s].im * twiddle_d.im,
359                        input_d[s].im * twiddle_d.re + input_d[s].re * twiddle_d.im,
360                    )
361                };
362                
363                let ac0 = a + c;
364                let ac1 = a - c;
365                let bd0 = b + d;
366                let bd1 = if INVERSE { b - d } else { d - b };
367                let bd1i = Complex::new(-bd1.im, bd1.re);
368                
369                output_a[s] = ac0 + bd0;
370                output_b[s] = ac1 + bd1i;
371                output_c[s] = ac0 - bd0;
372                output_d[s] = ac1 - bd1i;
373            }
374        }
375    }
376
377    // Combine interleaved results into a single spectrum with split complex representation
378    fn combine4_split<const INVERSE: bool>(
379        &self,
380        size: usize,
381        stride: usize,
382        input_r: &[T],
383        input_i: &[T],
384        output_r: &mut [T],
385        output_i: &mut [T],
386    ) {
387        let twiddle_step = self.working.len() / size;
388        
389        for i in 0..size / 4 {
390            let twiddle_b = self.twiddles[i * twiddle_step];
391            let twiddle_c = self.twiddles[i * 2 * twiddle_step];
392            let twiddle_d = self.twiddles[i * 3 * twiddle_step];
393            
394            for s in 0..stride {
395                // Get input values
396                let a_r = input_r[4 * i * stride + s];
397                let a_i = input_i[4 * i * stride + s];
398                
399                let b_r = input_r[(4 * i + 1) * stride + s];
400                let b_i = input_i[(4 * i + 1) * stride + s];
401                
402                let c_r = input_r[(4 * i + 2) * stride + s];
403                let c_i = input_i[(4 * i + 2) * stride + s];
404                
405                let d_r = input_r[(4 * i + 3) * stride + s];
406                let d_i = input_i[(4 * i + 3) * stride + s];
407                
408                // Apply twiddle factors
409                let (b_r_tw, b_i_tw) = if INVERSE {
410                    (
411                        b_r * twiddle_b.re + b_i * twiddle_b.im,
412                        b_i * twiddle_b.re - b_r * twiddle_b.im,
413                    )
414                } else {
415                    (
416                        b_r * twiddle_b.re - b_i * twiddle_b.im,
417                        b_i * twiddle_b.re + b_r * twiddle_b.im,
418                    )
419                };
420                
421                let (c_r_tw, c_i_tw) = if INVERSE {
422                    (
423                        c_r * twiddle_c.re + c_i * twiddle_c.im,
424                        c_i * twiddle_c.re - c_r * twiddle_c.im,
425                    )
426                } else {
427                    (
428                        c_r * twiddle_c.re - c_i * twiddle_c.im,
429                        c_i * twiddle_c.re + c_r * twiddle_c.im,
430                    )
431                };
432                
433                let (d_r_tw, d_i_tw) = if INVERSE {
434                    (
435                        d_r * twiddle_d.re + d_i * twiddle_d.im,
436                        d_i * twiddle_d.re - d_r * twiddle_d.im,
437                    )
438                } else {
439                    (
440                        d_r * twiddle_d.re - d_i * twiddle_d.im,
441                        d_i * twiddle_d.re + d_r * twiddle_d.im,
442                    )
443                };
444                
445                // Butterfly calculations
446                let ac0_r = a_r + c_r_tw;
447                let ac0_i = a_i + c_i_tw;
448                let ac1_r = a_r - c_r_tw;
449                let ac1_i = a_i - c_i_tw;
450                
451                let bd0_r = b_r_tw + d_r_tw;
452                let bd0_i = b_i_tw + d_i_tw;
453                
454                let (bd1_r, bd1_i) = if INVERSE {
455                    (b_r_tw - d_r_tw, b_i_tw - d_i_tw)
456                } else {
457                    (d_r_tw - b_r_tw, d_i_tw - b_i_tw)
458                };
459                
460                let bd1i_r = -bd1_i;
461                let bd1i_i = bd1_r;
462                
463                // Store results
464                output_r[i * stride + s] = ac0_r + bd0_r;
465                output_i[i * stride + s] = ac0_i + bd0_i;
466                
467                output_r[(i + size / 4) * stride + s] = ac1_r + bd1i_r;
468                output_i[(i + size / 4) * stride + s] = ac1_i + bd1i_i;
469                
470                output_r[(i + size / 4 * 2) * stride + s] = ac0_r - bd0_r;
471                output_i[(i + size / 4 * 2) * stride + s] = ac0_i - bd0_i;
472                
473                output_r[(i + size / 4 * 3) * stride + s] = ac1_r - bd1i_r;
474                output_i[(i + size / 4 * 3) * stride + s] = ac1_i - bd1i_i;
475            }
476        }
477    }
478}
479
480/// A wrapper for complex FFT to handle real data
481pub struct SimpleRealFFT<T: Float> {
482    complex_fft: SimpleFFT<T>,
483    tmp_time: Vec<Complex<T>>,
484    tmp_freq: Vec<Complex<T>>,
485}
486
487impl<T: Float + num_traits::FromPrimitive> SimpleRealFFT<T> {
488    /// Create a new real FFT with the specified size
489    pub fn new(size: usize) -> Self {
490        let mut result = Self {
491            complex_fft: SimpleFFT::new(size),
492            tmp_time: Vec::new(),
493            tmp_freq: Vec::new(),
494        };
495        result.resize(size);
496        result
497    }
498
499    /// Resize the FFT to handle a different size
500    pub fn resize(&mut self, size: usize) {
501        self.complex_fft.resize(size);
502        self.tmp_time.resize(size, Complex::new(T::zero(), T::zero()));
503        self.tmp_freq.resize(size, Complex::new(T::zero(), T::zero()));
504    }
505
506    /// Perform a forward FFT on real data
507    pub fn fft(&mut self, time: &[T], freq: &mut [Complex<T>]) {
508        let size = self.tmp_time.len();
509
510        // Copy real data to complex buffer
511        for i in 0..size {
512            self.tmp_time[i] = Complex::new(time[i], T::zero());
513        }
514
515        // Perform complex FFT
516        self.complex_fft.fft(&self.tmp_time, &mut self.tmp_freq);
517
518        // Corrected output handling:
519        // DC component (real only)
520        freq[0] = Complex::new(self.tmp_freq[0].re, T::zero());
521        // Positive frequencies
522        for i in 1..size/2 {
523            freq[i] = self.tmp_freq[i];
524        }
525        // Nyquist component (real only)
526        freq[size/2] = Complex::new(self.tmp_freq[size/2].re, T::zero());
527    }
528
529    /// Perform a forward FFT on real data with split output
530    pub fn fft_split(&self, in_r: &[T], out_r: &mut [T], out_i: &mut [T]) {
531        let size = self.tmp_time.len();
532        
533        // Create temporary zero buffer for imaginary part
534        let tmp_i = vec![T::zero(); size];
535        
536        // Perform complex FFT with split representation
537        self.complex_fft.fft_split(in_r, &tmp_i, out_r, out_i);
538        
539        // Special case for Nyquist frequency
540        out_i[0] = out_r[size / 2];
541    }
542
543    /// Perform an inverse FFT to real data
544    pub fn ifft(&mut self, freq: &[Complex<T>], time: &mut [T]) {
545        let size = self.tmp_freq.len();
546
547        // DC component
548        self.tmp_freq[0] = Complex::new(freq[0].re, T::zero());
549        // Nyquist component
550        self.tmp_freq[size/2] = Complex::new(freq[size/2].re, T::zero());
551
552        // Fill the rest of the spectrum using conjugate symmetry
553        for i in 1..size/2 {
554            self.tmp_freq[i] = freq[i];
555            self.tmp_freq[size - i] = freq[i].conj();
556        }
557
558        // Perform inverse complex FFT
559        self.complex_fft.ifft(&self.tmp_freq, &mut self.tmp_time);
560
561        // Extract real part
562        for i in 0..size {
563            time[i] = self.tmp_time[i].re;
564        }
565    }
566
567    /// Perform an inverse FFT from split complex to real data
568    pub fn ifft_split(&self, in_r: &[T], in_i: &[T], out_r: &mut [T]) {
569        let size = self.tmp_freq.len();
570
571        // Create temporary buffers for the full spectrum
572        let mut tmp_freq_r = vec![T::zero(); size];
573        let mut tmp_freq_i = vec![T::zero(); size];
574
575        // DC component
576        tmp_freq_r[0] = in_r[0];
577        tmp_freq_i[0] = T::zero();
578
579        // Nyquist component
580        tmp_freq_r[size / 2] = in_i[0];
581        tmp_freq_i[size / 2] = T::zero();
582
583        // Fill the rest of the spectrum using conjugate symmetry
584        for i in 1..size / 2 {
585            tmp_freq_r[i] = in_r[i];
586            tmp_freq_i[i] = in_i[i];
587            tmp_freq_r[size - i] = in_r[i];
588            tmp_freq_i[size - i] = -in_i[i];
589        }
590
591        // Create temporary buffer for imaginary output (will be discarded)
592        let mut tmp_out_i = vec![T::zero(); size];
593
594        // Perform inverse complex FFT
595        self.complex_fft.ifft_split(&tmp_freq_r, &tmp_freq_i, out_r, &mut tmp_out_i);
596    }
597}
598
599/// A power-of-2 FFT implementation that can be specialized for different platforms
600pub struct Pow2FFT<T: Float> {
601    simple_fft: SimpleFFT<T>,
602    tmp: Vec<Complex<T>>,
603}
604
605impl<T: Float+ FromPrimitive> Pow2FFT<T> {
606    /// Whether this FFT implementation is faster when given split-complex inputs
607    pub const PREFERS_SPLIT: bool = true;
608
609    /// Create a new FFT with the specified size
610    pub fn new(size: usize) -> Self {
611        let mut result = Self {
612            simple_fft: SimpleFFT::new(size),
613            tmp: Vec::new(),
614        };
615        result.resize(size);
616        result
617    }
618
619    /// Resize the FFT to handle a different size
620    pub fn resize(&mut self, size: usize) {
621        self.simple_fft.resize(size);
622        self.tmp.resize(size, Complex::new(T::zero(), T::zero()));
623    }
624
625    /// Perform a forward FFT
626    pub fn fft(&mut self, time: &[Complex<T>], freq: &mut [Complex<T>]) {
627        self.simple_fft.fft(time, freq);
628    }
629
630    /// Perform a forward FFT with split complex representation
631    pub fn fft_split(&self, in_r: &[T], in_i: &[T], out_r: &mut [T], out_i: &mut [T]) {
632        self.simple_fft.fft_split(in_r, in_i, out_r, out_i);
633    }
634
635    /// Perform an inverse FFT
636    pub fn ifft(&mut self, freq: &[Complex<T>], time: &mut [Complex<T>]) {
637        self.simple_fft.ifft(freq, time);
638    }
639
640    /// Perform an inverse FFT with split complex representation
641    pub fn ifft_split(&self, in_r: &[T], in_i: &[T], out_r: &mut [T], out_i: &mut [T]) {
642        self.simple_fft.ifft_split(in_r, in_i, out_r, out_i);
643    }
644}
645
646/// A power-of-2 real FFT implementation
647pub struct Pow2RealFFT<T: Float> {
648    simple_real_fft: SimpleRealFFT<T>,
649}
650
651impl<T: Float + FromPrimitive> Pow2RealFFT<T> {
652    /// Whether this FFT implementation is faster when given split-complex inputs
653    pub const PREFERS_SPLIT: bool = Pow2FFT::<T>::PREFERS_SPLIT;
654
655    /// Create a new real FFT with the specified size
656    pub fn new(size: usize) -> Self {
657        Self {
658            simple_real_fft: SimpleRealFFT::new(size),
659        }
660    }
661
662    /// Resize the FFT to handle a different size
663    pub fn resize(&mut self, size: usize) {
664        self.simple_real_fft.resize(size);
665    }
666
667    /// Perform a forward FFT on real data
668    pub fn fft(&mut self, time: &[T], freq: &mut [Complex<T>]) {
669        self.simple_real_fft.fft(time, freq);
670    }
671
672    /// Perform a forward FFT on real data with split output
673    pub fn fft_split(&self, in_r: &[T], out_r: &mut [T], out_i: &mut [T]) {
674        self.simple_real_fft.fft_split(in_r, out_r, out_i);
675    }
676
677    /// Perform an inverse FFT to real data
678    pub fn ifft(&mut self, freq: &[Complex<T>], time: &mut [T]) {
679        self.simple_real_fft.ifft(freq, time);
680    }
681
682    /// Perform an inverse FFT from split complex to real data
683    pub fn ifft_split(&self, in_r: &[T], in_i: &[T], out_r: &mut [T]) {
684        self.simple_real_fft.ifft_split(in_r, in_i, out_r);
685    }
686}
687
#[cfg(test)]
mod tests {
    use num_complex::ComplexFloat;
    use super::*;

    /// Tolerance for `f32` results. Machine epsilon is ~1.2e-7, so a 1e-10
    /// bound would only pass when every operation happens to be exact.
    const TOL_F32: f32 = 1e-5;
    /// Tolerance for `f64` results on the small sizes used below.
    const TOL_F64: f64 = 1e-9;

    /// Component-wise closeness check for `f64` complex values.
    fn close(a: Complex<f64>, b: Complex<f64>, tol: f64) -> bool {
        (a.re - b.re).abs() < tol && (a.im - b.im).abs() < tol
    }

    /// Deterministic, non-trivial complex test signal of length `n`.
    fn test_signal(n: usize) -> Vec<Complex<f64>> {
        (0..n)
            .map(|i| {
                let x = i as f64;
                let re: f64 = (x * 0.37).sin() + 0.25;
                let im: f64 = (x * 1.3).cos() - 0.5 * (x * 0.11).sin();
                Complex::new(re, im)
            })
            .collect()
    }

    /// Reference O(n²) DFT. `inverse` selects the `+i` exponent sign. Like the
    /// FFTs under test, neither direction is normalised.
    fn naive_dft(input: &[Complex<f64>], inverse: bool) -> Vec<Complex<f64>> {
        let n = input.len();
        let sign: f64 = if inverse { 1.0 } else { -1.0 };
        (0..n)
            .map(|k| {
                input
                    .iter()
                    .enumerate()
                    .fold(Complex::new(0.0, 0.0), |acc, (j, &x)| {
                        // Reduce j*k modulo n first to keep the phase small and accurate.
                        let phase: f64 = sign * 2.0 * core::f64::consts::PI
                            * ((j * k) % n) as f64
                            / n as f64;
                        acc + x * Complex::new(phase.cos(), phase.sin())
                    })
            })
            .collect()
    }

    #[test]
    fn test_simple_fft() {
        // Create a 4-point FFT
        let mut fft = SimpleFFT::<f32>::new(4);

        // Create input and output buffers
        let input = vec![
            Complex::new(1.0, 0.0),
            Complex::new(0.0, 0.0),
            Complex::new(0.0, 0.0),
            Complex::new(0.0, 0.0),
        ];
        let mut output = vec![Complex::new(0.0, 0.0); 4];

        // Perform forward FFT
        fft.fft(&input, &mut output);

        // All values should be 1.0 for a delta function input
        for i in 0..4 {
            assert!((output[i].re - 1.0).abs() < TOL_F32);
            assert!(output[i].im.abs() < TOL_F32);
        }

        // Create a new input with a sine wave
        let input = vec![
            Complex::new(0.0, 0.0),
            Complex::new(1.0, 0.0),
            Complex::new(0.0, 0.0),
            Complex::new(-1.0, 0.0),
        ];

        // Perform forward FFT
        fft.fft(&input, &mut output);

        // For this input, we should have energy at frequency bins 1 and 3
        assert!(output[0].abs() < TOL_F32);
        assert!((output[1].im + 2.0).abs() < TOL_F32);
        assert!(output[2].abs() < TOL_F32);
        assert!((output[3].im - 2.0).abs() < TOL_F32);

        // Test inverse FFT
        let mut inverse_output = vec![Complex::new(0.0, 0.0); 4];
        fft.ifft(&output, &mut inverse_output);

        // Scale by 1/N
        for i in 0..4 {
            inverse_output[i] = inverse_output[i] / 4.0;
        }

        // Should recover the original signal
        for i in 0..4 {
            assert!((inverse_output[i].re - input[i].re).abs() < TOL_F32);
            assert!((inverse_output[i].im - input[i].im).abs() < TOL_F32);
        }
    }

    #[test]
    fn test_real_fft() {
        // Create an 8-point real FFT
        let mut real_fft = SimpleRealFFT::<f32>::new(8);

        // Create input and output buffers
        let input = vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let mut output = vec![Complex::new(0.0, 0.0); 5]; // Only need N/2+1 for real FFT

        // Perform forward FFT
        real_fft.fft(&input, &mut output);

        // All values should be 1.0 for a delta function input
        for i in 0..5 {
            assert!((output[i].re - 1.0).abs() < TOL_F32);
            assert!(output[i].im.abs() < TOL_F32);
        }

        // Test inverse FFT
        let mut inverse_output = vec![0.0; 8];
        real_fft.ifft(&output, &mut inverse_output);

        // Scale by 1/N
        for i in 0..8 {
            inverse_output[i] /= 8.0;
        }

        // Should recover the original signal
        for i in 0..8 {
            assert!((inverse_output[i] - input[i]).abs() < TOL_F32);
        }
    }

    #[test]
    fn test_fft_matches_naive_dft() {
        for &n in &[2usize, 4, 8, 16, 32, 64, 128] {
            let input = test_signal(n);
            let mut fft = SimpleFFT::<f64>::new(n);

            // Interleaved forward transform against the reference DFT.
            let mut freq = vec![Complex::new(0.0, 0.0); n];
            fft.fft(&input, &mut freq);
            let expected = naive_dft(&input, false);
            for (k, (got, want)) in freq.iter().zip(&expected).enumerate() {
                assert!(close(*got, *want, TOL_F64), "fft n={} bin {}: {:?} vs {:?}", n, k, got, want);
            }

            // The unnormalised inverse followed by 1/n scaling recovers the input.
            let mut time = vec![Complex::new(0.0, 0.0); n];
            fft.ifft(&freq, &mut time);
            for (k, (got, want)) in time.iter().zip(&input).enumerate() {
                let scaled = *got / n as f64;
                assert!(close(scaled, *want, TOL_F64), "ifft n={} index {}: {:?} vs {:?}", n, k, scaled, want);
            }

            // The split layout agrees with the interleaved layout.
            let in_r: Vec<f64> = input.iter().map(|c| c.re).collect();
            let in_i: Vec<f64> = input.iter().map(|c| c.im).collect();
            let mut out_r = vec![0.0; n];
            let mut out_i = vec![0.0; n];
            fft.fft_split(&in_r, &in_i, &mut out_r, &mut out_i);
            for k in 0..n {
                let got = Complex::new(out_r[k], out_i[k]);
                assert!(close(got, freq[k], TOL_F64), "fft_split n={} bin {}: {:?} vs {:?}", n, k, got, freq[k]);
            }

            let mut back_r = vec![0.0; n];
            let mut back_i = vec![0.0; n];
            fft.ifft_split(&out_r, &out_i, &mut back_r, &mut back_i);
            for k in 0..n {
                let got = Complex::new(back_r[k], back_i[k]) / n as f64;
                assert!(close(got, input[k], TOL_F64), "ifft_split n={} index {}: {:?} vs {:?}", n, k, got, input[k]);
            }
        }
    }

    #[test]
    fn test_real_fft_matches_naive_dft() {
        let n = 32;
        let signal: Vec<f64> = (0..n)
            .map(|i| {
                let x = i as f64;
                let value: f64 = (x * 0.45).sin() + x * 0.1;
                value
            })
            .collect();
        let as_complex: Vec<Complex<f64>> = signal.iter().map(|&x| Complex::new(x, 0.0)).collect();
        let expected = naive_dft(&as_complex, false);
        let mut real_fft = SimpleRealFFT::<f64>::new(n);

        // Interleaved layout: bins 0..=n/2.
        let mut spectrum = vec![Complex::new(0.0, 0.0); n / 2 + 1];
        real_fft.fft(&signal, &mut spectrum);
        for k in 0..=n / 2 {
            assert!(close(spectrum[k], expected[k], TOL_F64), "bin {}: {:?} vs {:?}", k, spectrum[k], expected[k]);
        }

        let mut recovered = vec![0.0; n];
        real_fft.ifft(&spectrum, &mut recovered);
        for (got, want) in recovered.iter().zip(&signal) {
            assert!((got / n as f64 - want).abs() < TOL_F64);
        }

        // Split layout: the Nyquist real part is packed into the imaginary slot of bin 0.
        let mut packed_r = vec![0.0; n];
        let mut packed_i = vec![0.0; n];
        real_fft.fft_split(&signal, &mut packed_r, &mut packed_i);
        assert!((packed_r[0] - expected[0].re).abs() < TOL_F64);
        assert!((packed_i[0] - expected[n / 2].re).abs() < TOL_F64);
        for k in 1..n / 2 {
            let got = Complex::new(packed_r[k], packed_i[k]);
            assert!(close(got, expected[k], TOL_F64), "split bin {}: {:?} vs {:?}", k, got, expected[k]);
        }

        let mut recovered_split = vec![0.0; n];
        real_fft.ifft_split(&packed_r, &packed_i, &mut recovered_split);
        for (got, want) in recovered_split.iter().zip(&signal) {
            assert!((got / n as f64 - want).abs() < TOL_F64);
        }
    }

    #[test]
    fn test_pow2_fft_delegates_to_simple_fft() {
        let n = 16;
        let input = test_signal(n);
        let mut simple = SimpleFFT::<f64>::new(n);
        let mut pow2 = Pow2FFT::<f64>::new(n);

        let mut expected = vec![Complex::new(0.0, 0.0); n];
        let mut got = vec![Complex::new(0.0, 0.0); n];
        simple.fft(&input, &mut expected);
        pow2.fft(&input, &mut got);
        assert_eq!(got, expected);
    }
}