advanced_algorithms/numerical/
fft.rs

//! Fast Fourier Transform (FFT) Implementation
//!
//! The FFT is one of the most important numerical algorithms, used in signal processing,
//! image compression, audio processing, and many other applications.
//!
//! This implementation uses the radix-2 Cooley-Tukey algorithm, so the real-input entry
//! points ([`fft`] and [`fft_parallel`]) require the input length to be a power of two.
//! [`fft_parallel`] additionally uses parallel processing for improved performance on
//! large datasets.
//!
//! # Examples
//!
//! ```
//! use advanced_algorithms::numerical::fft;
//!
//! // Transform a simple signal (length 8, a power of two)
//! let signal = vec![1.0, 2.0, 3.0, 4.0, 3.0, 2.0, 1.0, 0.0];
//! let spectrum = fft::fft(&signal);
//!
//! // Transform back to time domain
//! let reconstructed = fft::ifft(&spectrum);
//! ```
21
22use num_complex::Complex64;
23use rayon::prelude::*;
24use std::f64::consts::PI;
25
26/// Performs a Fast Fourier Transform on the input signal
27///
28/// # Arguments
29///
30/// * `input` - A slice of real-valued samples
31///
32/// # Returns
33///
34/// A vector of complex frequency components
35///
36/// # Panics
37///
38/// Panics if the input length is not a power of 2
39///
40/// # Performance
41///
42/// Time complexity: O(n log n) where n is the input length
43/// Uses parallel processing for inputs larger than 1024 samples
44pub fn fft(input: &[f64]) -> Vec<Complex64> {
45    let n = input.len();
46    assert!(n.is_power_of_two(), "Input length must be a power of 2");
47    
48    let complex_input: Vec<Complex64> = input.iter()
49        .map(|&x| Complex64::new(x, 0.0))
50        .collect();
51    
52    fft_complex(&complex_input)
53}
54
55/// Performs FFT on complex-valued input
56///
57/// # Arguments
58///
59/// * `input` - A slice of complex samples
60///
61/// # Returns
62///
63/// A vector of complex frequency components
64pub fn fft_complex(input: &[Complex64]) -> Vec<Complex64> {
65    let n = input.len();
66    
67    if n <= 1 {
68        return input.to_vec();
69    }
70    
71    if n <= 32 {
72        // Use DFT for small inputs
73        return dft(input);
74    }
75    
76    // Cooley-Tukey FFT algorithm
77    fft_recursive(input)
78}
79
80/// Inverse Fast Fourier Transform
81///
82/// Converts frequency domain back to time domain
83///
84/// # Arguments
85///
86/// * `input` - Frequency domain complex samples
87///
88/// # Returns
89///
90/// Time domain complex samples
91pub fn ifft(input: &[Complex64]) -> Vec<Complex64> {
92    let n = input.len();
93    
94    // Conjugate the input
95    let conjugated: Vec<Complex64> = input.iter()
96        .map(|&x| x.conj())
97        .collect();
98    
99    // Perform FFT
100    let result = fft_complex(&conjugated);
101    
102    // Conjugate and scale the result
103    result.iter()
104        .map(|&x| x.conj() / (n as f64))
105        .collect()
106}
107
108/// Performs FFT with parallel processing for large inputs
109///
110/// # Arguments
111///
112/// * `input` - A slice of real-valued samples
113///
114/// # Returns
115///
116/// A vector of complex frequency components
117pub fn fft_parallel(input: &[f64]) -> Vec<Complex64> {
118    let n = input.len();
119    assert!(n.is_power_of_two(), "Input length must be a power of 2");
120    
121    let complex_input: Vec<Complex64> = input.par_iter()
122        .map(|&x| Complex64::new(x, 0.0))
123        .collect();
124    
125    fft_recursive_parallel(&complex_input)
126}
127
128// Internal recursive FFT implementation
129fn fft_recursive(input: &[Complex64]) -> Vec<Complex64> {
130    let n = input.len();
131    
132    if n <= 1 {
133        return input.to_vec();
134    }
135    
136    // Split into even and odd indices
137    let even: Vec<Complex64> = input.iter()
138        .step_by(2)
139        .copied()
140        .collect();
141    
142    let odd: Vec<Complex64> = input.iter()
143        .skip(1)
144        .step_by(2)
145        .copied()
146        .collect();
147    
148    // Recursively compute FFT
149    let fft_even = fft_recursive(&even);
150    let fft_odd = fft_recursive(&odd);
151    
152    // Combine results
153    let mut result = vec![Complex64::new(0.0, 0.0); n];
154    
155    for k in 0..n/2 {
156        let angle = -2.0 * PI * (k as f64) / (n as f64);
157        let w = Complex64::new(angle.cos(), angle.sin());
158        let t = w * fft_odd[k];
159        
160        result[k] = fft_even[k] + t;
161        result[k + n/2] = fft_even[k] - t;
162    }
163    
164    result
165}
166
167// Parallel FFT implementation
168fn fft_recursive_parallel(input: &[Complex64]) -> Vec<Complex64> {
169    let n = input.len();
170    
171    if n <= 1024 {
172        return fft_recursive(input);
173    }
174    
175    // Split into even and odd indices (same as serial version)
176    let even: Vec<Complex64> = input.iter()
177        .step_by(2)
178        .copied()
179        .collect();
180    
181    let odd: Vec<Complex64> = input.iter()
182        .skip(1)
183        .step_by(2)
184        .copied()
185        .collect();
186    
187    // Recursively compute FFT in parallel
188    let (fft_even, fft_odd) = rayon::join(
189        || fft_recursive_parallel(&even),
190        || fft_recursive_parallel(&odd)
191    );
192    
193    // Combine results (same as serial version but parallelized)
194    let mut result = vec![Complex64::new(0.0, 0.0); n];
195    
196    result.par_iter_mut()
197        .enumerate()
198        .for_each(|(k, r)| {
199            if k < n/2 {
200                let angle = -2.0 * PI * (k as f64) / (n as f64);
201                let w = Complex64::new(angle.cos(), angle.sin());
202                let t = w * fft_odd[k];
203                *r = fft_even[k] + t;
204            } else {
205                let k = k - n/2;
206                let angle = -2.0 * PI * (k as f64) / (n as f64);
207                let w = Complex64::new(angle.cos(), angle.sin());
208                let t = w * fft_odd[k];
209                *r = fft_even[k] - t;
210            }
211        });
212    
213    result
214}
215
216// Direct DFT for small inputs
217fn dft(input: &[Complex64]) -> Vec<Complex64> {
218    let n = input.len();
219    let mut result = vec![Complex64::new(0.0, 0.0); n];
220    
221    for (k, r) in result.iter_mut().enumerate() {
222        let mut sum = Complex64::new(0.0, 0.0);
223        for (j, &x) in input.iter().enumerate() {
224            let angle = -2.0 * PI * (k * j) as f64 / n as f64;
225            let w = Complex64::new(angle.cos(), angle.sin());
226            sum += x * w;
227        }
228        *r = sum;
229    }
230    
231    result
232}
233
234/// Compute the power spectrum (magnitude squared) from FFT output
235///
236/// # Arguments
237///
238/// * `fft_output` - Output from FFT
239///
240/// # Returns
241///
242/// Vector of power values (magnitude squared)
243pub fn power_spectrum(fft_output: &[Complex64]) -> Vec<f64> {
244    fft_output.iter()
245        .map(|c| c.norm_sqr())
246        .collect()
247}
248
249/// Compute the magnitude spectrum from FFT output
250///
251/// # Arguments
252///
253/// * `fft_output` - Output from FFT
254///
255/// # Returns
256///
257/// Vector of magnitude values
258pub fn magnitude_spectrum(fft_output: &[Complex64]) -> Vec<f64> {
259    fft_output.iter()
260        .map(|c| c.norm())
261        .collect()
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies within Euclidean distance `tol` of `expected`.
    fn assert_close(actual: Complex64, expected: Complex64, tol: f64) {
        let err = (actual - expected).norm();
        assert!(
            err < tol,
            "expected {:?}, got {:?} (error {})",
            expected,
            actual,
            err
        );
    }

    #[test]
    fn test_fft_basic() {
        // Rectangular pulse: X[k] = sum_{j=0}^{3} exp(-2*pi*i*k*j/8), known in closed form.
        let input = vec![1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0];
        let output = fft(&input);
        assert_eq!(output.len(), 8);

        let sqrt2 = std::f64::consts::SQRT_2;
        assert_close(output[0], Complex64::new(4.0, 0.0), 1e-10);
        assert_close(output[1], Complex64::new(1.0, -(1.0 + sqrt2)), 1e-10);
        assert_close(output[2], Complex64::new(0.0, 0.0), 1e-10);
        assert_close(output[3], Complex64::new(1.0, 1.0 - sqrt2), 1e-10);
        assert_close(output[4], Complex64::new(0.0, 0.0), 1e-10);

        // A real input must produce a conjugate-symmetric spectrum.
        for k in 1..8 {
            assert_close(output[8 - k], output[k].conj(), 1e-10);
        }
    }

    #[test]
    fn test_fft_ifft_roundtrip() {
        let input = vec![1.0, 2.0, 3.0, 4.0, 3.0, 2.0, 1.0, 0.0];
        let spectrum = fft(&input);
        let reconstructed = ifft(&spectrum);

        for (i, &val) in input.iter().enumerate() {
            assert!((reconstructed[i].re - val).abs() < 1e-10);
            assert!(reconstructed[i].im.abs() < 1e-10);
        }
    }

    #[test]
    fn test_fft_matches_dft_on_radix2_path() {
        // n = 64 > 32, so `fft` takes the recursive Cooley-Tukey path instead of the DFT.
        let input: Vec<f64> = (0..64).map(|i| ((i * 7) % 13) as f64 - 6.0).collect();
        let complex_input: Vec<Complex64> =
            input.iter().map(|&x| Complex64::new(x, 0.0)).collect();

        let expected = dft(&complex_input);
        let actual = fft(&input);

        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(&expected) {
            assert_close(*a, *e, 1e-9);
        }
    }

    #[test]
    fn test_fft_ifft_roundtrip_radix2_path() {
        let input: Vec<f64> = (0..64).map(|i| (i as f64 * 0.3).cos()).collect();
        let reconstructed = ifft(&fft(&input));

        assert_eq!(reconstructed.len(), input.len());
        for (r, &x) in reconstructed.iter().zip(&input) {
            assert!((r.re - x).abs() < 1e-10);
            assert!(r.im.abs() < 1e-10);
        }
    }

    #[test]
    fn test_fft_parallel() {
        let input: Vec<f64> = (0..2048).map(|i| (i as f64).sin()).collect();
        let serial = fft(&input);
        let parallel = fft_parallel(&input);

        // Parallel computation may have slightly different rounding due to different order
        // of operations. We check that the maximum error is small.
        let max_error = serial
            .iter()
            .zip(parallel.iter())
            .map(|(s, p)| (s - p).norm())
            .fold(0.0, f64::max);

        assert!(max_error < 1e-6, "Maximum error: {}", max_error);
    }

    #[test]
    fn test_spectra() {
        let bins = [Complex64::new(3.0, 4.0), Complex64::new(0.0, -2.0)];

        assert_eq!(power_spectrum(&bins), vec![25.0, 4.0]);

        let magnitudes = magnitude_spectrum(&bins);
        assert_eq!(magnitudes.len(), 2);
        assert!((magnitudes[0] - 5.0).abs() < 1e-12);
        assert!((magnitudes[1] - 2.0).abs() < 1e-12);

        let empty: [Complex64; 0] = [];
        assert!(power_spectrum(&empty).is_empty());
        assert!(magnitude_spectrum(&empty).is_empty());
        assert!(ifft(&empty).is_empty());
    }
}