advanced_algorithms/numerical/
fft.rs1use num_complex::Complex64;
23use rayon::prelude::*;
24use std::f64::consts::PI;
25
26pub fn fft(input: &[f64]) -> Vec<Complex64> {
45 let n = input.len();
46 assert!(n.is_power_of_two(), "Input length must be a power of 2");
47
48 let complex_input: Vec<Complex64> = input.iter()
49 .map(|&x| Complex64::new(x, 0.0))
50 .collect();
51
52 fft_complex(&complex_input)
53}
54
55pub fn fft_complex(input: &[Complex64]) -> Vec<Complex64> {
65 let n = input.len();
66
67 if n <= 1 {
68 return input.to_vec();
69 }
70
71 if n <= 32 {
72 return dft(input);
74 }
75
76 fft_recursive(input)
78}
79
80pub fn ifft(input: &[Complex64]) -> Vec<Complex64> {
92 let n = input.len();
93
94 let conjugated: Vec<Complex64> = input.iter()
96 .map(|&x| x.conj())
97 .collect();
98
99 let result = fft_complex(&conjugated);
101
102 result.iter()
104 .map(|&x| x.conj() / (n as f64))
105 .collect()
106}
107
108pub fn fft_parallel(input: &[f64]) -> Vec<Complex64> {
118 let n = input.len();
119 assert!(n.is_power_of_two(), "Input length must be a power of 2");
120
121 let complex_input: Vec<Complex64> = input.par_iter()
122 .map(|&x| Complex64::new(x, 0.0))
123 .collect();
124
125 fft_recursive_parallel(&complex_input)
126}
127
128fn fft_recursive(input: &[Complex64]) -> Vec<Complex64> {
130 let n = input.len();
131
132 if n <= 1 {
133 return input.to_vec();
134 }
135
136 let even: Vec<Complex64> = input.iter()
138 .step_by(2)
139 .copied()
140 .collect();
141
142 let odd: Vec<Complex64> = input.iter()
143 .skip(1)
144 .step_by(2)
145 .copied()
146 .collect();
147
148 let fft_even = fft_recursive(&even);
150 let fft_odd = fft_recursive(&odd);
151
152 let mut result = vec![Complex64::new(0.0, 0.0); n];
154
155 for k in 0..n/2 {
156 let angle = -2.0 * PI * (k as f64) / (n as f64);
157 let w = Complex64::new(angle.cos(), angle.sin());
158 let t = w * fft_odd[k];
159
160 result[k] = fft_even[k] + t;
161 result[k + n/2] = fft_even[k] - t;
162 }
163
164 result
165}
166
167fn fft_recursive_parallel(input: &[Complex64]) -> Vec<Complex64> {
169 let n = input.len();
170
171 if n <= 1024 {
172 return fft_recursive(input);
173 }
174
175 let even: Vec<Complex64> = input.iter()
177 .step_by(2)
178 .copied()
179 .collect();
180
181 let odd: Vec<Complex64> = input.iter()
182 .skip(1)
183 .step_by(2)
184 .copied()
185 .collect();
186
187 let (fft_even, fft_odd) = rayon::join(
189 || fft_recursive_parallel(&even),
190 || fft_recursive_parallel(&odd)
191 );
192
193 let mut result = vec![Complex64::new(0.0, 0.0); n];
195
196 result.par_iter_mut()
197 .enumerate()
198 .for_each(|(k, r)| {
199 if k < n/2 {
200 let angle = -2.0 * PI * (k as f64) / (n as f64);
201 let w = Complex64::new(angle.cos(), angle.sin());
202 let t = w * fft_odd[k];
203 *r = fft_even[k] + t;
204 } else {
205 let k = k - n/2;
206 let angle = -2.0 * PI * (k as f64) / (n as f64);
207 let w = Complex64::new(angle.cos(), angle.sin());
208 let t = w * fft_odd[k];
209 *r = fft_even[k] - t;
210 }
211 });
212
213 result
214}
215
216fn dft(input: &[Complex64]) -> Vec<Complex64> {
218 let n = input.len();
219 let mut result = vec![Complex64::new(0.0, 0.0); n];
220
221 for (k, r) in result.iter_mut().enumerate() {
222 let mut sum = Complex64::new(0.0, 0.0);
223 for (j, &x) in input.iter().enumerate() {
224 let angle = -2.0 * PI * (k * j) as f64 / n as f64;
225 let w = Complex64::new(angle.cos(), angle.sin());
226 sum += x * w;
227 }
228 *r = sum;
229 }
230
231 result
232}
233
234pub fn power_spectrum(fft_output: &[Complex64]) -> Vec<f64> {
244 fft_output.iter()
245 .map(|c| c.norm_sqr())
246 .collect()
247}
248
249pub fn magnitude_spectrum(fft_output: &[Complex64]) -> Vec<f64> {
259 fft_output.iter()
260 .map(|c| c.norm())
261 .collect()
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// Largest element-wise distance between two equally long spectra.
    fn max_abs_diff(a: &[Complex64], b: &[Complex64]) -> f64 {
        assert_eq!(a.len(), b.len());
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).norm())
            .fold(0.0, f64::max)
    }

    #[test]
    fn test_fft_basic() {
        let input = vec![1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0];
        let output = fft(&input);
        assert_eq!(output.len(), 8);

        // DC bin is the sum of the samples.
        assert!((output[0] - Complex64::new(4.0, 0.0)).norm() < 1e-12);
        // Bins 2 and 4 cancel exactly for this box signal.
        assert!(output[2].norm() < 1e-12);
        assert!(output[4].norm() < 1e-12);
        // X[1] = 1 + e^{-iπ/4} + e^{-iπ/2} + e^{-3iπ/4} = 1 - i(1 + √2).
        let expected_1 = Complex64::new(1.0, -(1.0 + 2f64.sqrt()));
        assert!((output[1] - expected_1).norm() < 1e-12);
        // Real input gives a Hermitian-symmetric spectrum: X[n-k] = conj(X[k]).
        for k in 1..8 {
            assert!((output[8 - k] - output[k].conj()).norm() < 1e-12);
        }
    }

    #[test]
    fn test_fft_matches_dft_on_radix2_path() {
        // 64 > 32, so `fft` takes the radix-2 recursive path; check it against
        // the direct DFT reference.
        let input: Vec<f64> = (0..64).map(|i| ((i * 7) % 11) as f64 - 5.0).collect();
        let signal: Vec<Complex64> = input.iter().map(|&x| Complex64::new(x, 0.0)).collect();

        let err = max_abs_diff(&fft(&input), &dft(&signal));
        assert!(err < 1e-9, "Maximum error: {}", err);
    }

    #[test]
    fn test_fft_ifft_roundtrip() {
        let input = vec![1.0, 2.0, 3.0, 4.0, 3.0, 2.0, 1.0, 0.0];
        let spectrum = fft(&input);
        let reconstructed = ifft(&spectrum);

        assert_eq!(reconstructed.len(), input.len());
        for (i, &val) in input.iter().enumerate() {
            assert!((reconstructed[i].re - val).abs() < 1e-10);
            assert!(reconstructed[i].im.abs() < 1e-10);
        }
    }

    #[test]
    fn test_fft_parallel() {
        let input: Vec<f64> = (0..2048).map(|i| (i as f64).sin()).collect();
        let serial = fft(&input);
        let parallel = fft_parallel(&input);

        let max_error = max_abs_diff(&serial, &parallel);
        assert!(max_error < 1e-6, "Maximum error: {}", max_error);
    }

    #[test]
    fn test_spectra_are_consistent() {
        let spectrum = vec![Complex64::new(3.0, 4.0), Complex64::new(0.0, -2.0)];

        assert_eq!(power_spectrum(&spectrum), vec![25.0, 4.0]);

        let magnitudes = magnitude_spectrum(&spectrum);
        assert_eq!(magnitudes.len(), 2);
        assert!((magnitudes[0] - 5.0).abs() < 1e-12);
        assert!((magnitudes[1] - 2.0).abs() < 1e-12);
    }
}