Skip to main content

cjc_runtime/
fft.rs

1//! FFT — Cooley-Tukey radix-2 FFT, inverse FFT, real FFT, PSD.
2//!
3//! # Determinism Contract
4//! - Bit-reversal permutation is deterministic.
5//! - Butterfly operations in fixed order.
6//! - Zero-padding to next power of 2 is deterministic.
7
8use std::f64::consts::PI;
9
10/// Compute the Cooley-Tukey radix-2 FFT in-place.
11///
12/// # Arguments
13///
14/// * `data` - Input signal as `(re, im)` pairs. Length **must** be a power of 2.
15///
16/// # Returns
17///
18/// Frequency-domain representation as `Vec<(f64, f64)>` of `(re, im)` pairs.
19///
20/// # Panics
21///
22/// Panics if `data.len()` is not a power of 2.
23///
24/// # Algorithm
25///
26/// Iterative radix-2 decimation-in-time with bit-reversal permutation followed
27/// by butterfly stages. All twiddle factors are computed in fixed order.
28pub fn fft(data: &[(f64, f64)]) -> Vec<(f64, f64)> {
29    let n = data.len();
30    assert!(n.is_power_of_two(), "FFT: input length must be power of 2, got {n}");
31
32    let mut buf = data.to_vec();
33
34    // Bit-reversal permutation
35    let bits = n.trailing_zeros() as usize;
36    for i in 0..n {
37        let j = bit_reverse(i, bits);
38        if i < j {
39            buf.swap(i, j);
40        }
41    }
42
43    // Butterfly stages
44    let mut size = 2;
45    while size <= n {
46        let half = size / 2;
47        let angle = -2.0 * PI / size as f64;
48        let w_base = (angle.cos(), angle.sin());
49        for start in (0..n).step_by(size) {
50            let mut w = (1.0, 0.0);
51            for k in 0..half {
52                let i = start + k;
53                let j = start + k + half;
54                let t = complex_mul(w, buf[j]);
55                buf[j] = (buf[i].0 - t.0, buf[i].1 - t.1);
56                buf[i] = (buf[i].0 + t.0, buf[i].1 + t.1);
57                w = complex_mul(w, w_base);
58            }
59        }
60        size *= 2;
61    }
62    buf
63}
64
65/// Compute the inverse FFT by conjugating, applying [`fft`], then conjugating
66/// and scaling by `1/N`.
67pub fn ifft(data: &[(f64, f64)]) -> Vec<(f64, f64)> {
68    let n = data.len();
69    // Conjugate input
70    let conjugated: Vec<(f64, f64)> = data.iter().map(|&(r, i)| (r, -i)).collect();
71    // Forward FFT
72    let mut result = fft(&conjugated);
73    // Conjugate and scale
74    let scale = 1.0 / n as f64;
75    for v in &mut result {
76        v.0 *= scale;
77        v.1 = -v.1 * scale;
78    }
79    result
80}
81
82/// Compute the FFT of a real-valued signal.
83///
84/// Automatically zero-pads to the next power of 2 if `data.len()` is not
85/// already a power of 2.
86pub fn rfft(data: &[f64]) -> Vec<(f64, f64)> {
87    let n = next_power_of_2(data.len());
88    let mut complex_data: Vec<(f64, f64)> = Vec::with_capacity(n);
89    for i in 0..n {
90        let val = if i < data.len() { data[i] } else { 0.0 };
91        complex_data.push((val, 0.0));
92    }
93    fft(&complex_data)
94}
95
96/// Compute the power spectral density: `|FFT(x)|^2` for each frequency bin.
97pub fn psd(data: &[f64]) -> Vec<f64> {
98    let spectrum = rfft(data);
99    spectrum.iter().map(|&(r, i)| r * r + i * i).collect()
100}
101
102// ---------------------------------------------------------------------------
103// Helpers
104// ---------------------------------------------------------------------------
105
/// Return the lowest `bits` bits of `x` in reversed order.
///
/// Bits of `x` at positions `bits` and above are ignored; `bits == 0` yields 0.
fn bit_reverse(x: usize, bits: usize) -> usize {
    // State is (reversed-so-far, remaining input). Each step moves the lowest
    // remaining bit onto the low end of the accumulator.
    let (reversed, _) = (0..bits).fold((0, x), |(acc, rest), _| {
        ((acc << 1) | (rest & 1), rest >> 1)
    });
    reversed
}
115
/// Return the product of two complex numbers given as `(re, im)` tuples:
/// `(ar + i*ai) * (br + i*bi)`.
fn complex_mul(a: (f64, f64), b: (f64, f64)) -> (f64, f64) {
    let (ar, ai) = a;
    let (br, bi) = b;
    let real = ar * br - ai * bi;
    let imag = ar * bi + ai * br;
    (real, imag)
}
120
/// Return the smallest power of 2 that is >= `n`; `n == 0` yields 1.
///
/// # Panics
///
/// Panics if the result does not fit in `usize`
/// (`n > 2^(usize::BITS - 1)`).
///
/// This uses the checked std helper because a manual `p <<= 1` loop wraps to
/// 0 on overflow and then never terminates, since `0 < n` stays true.
fn next_power_of_2(n: usize) -> usize {
    n.checked_next_power_of_two()
        .expect("next_power_of_2: result overflows usize")
}
127
128// ---------------------------------------------------------------------------
129// Phase B6: Window functions
130// ---------------------------------------------------------------------------
131
/// Symmetric Hann window of length `n`:
/// `w[k] = 0.5 * (1 - cos(2*pi*k / (N-1)))`.
///
/// Both endpoints are (numerically) 0. For degenerate lengths, `n == 0` gives
/// an empty vector and `n == 1` gives `[1.0]`.
pub fn hann_window(n: usize) -> Vec<f64> {
    match n {
        0 | 1 => vec![1.0; n],
        _ => {
            let span = (n - 1) as f64;
            let mut w = Vec::with_capacity(n);
            for k in 0..n {
                let phase = 2.0 * PI * k as f64 / span;
                w.push(0.5 * (1.0 - phase.cos()));
            }
            w
        }
    }
}
138
/// Symmetric Hamming window of length `n`:
/// `w[k] = 0.54 - 0.46 * cos(2*pi*k / (N-1))`.
///
/// Both endpoints are (approximately) 0.08. For degenerate lengths, `n == 0`
/// gives an empty vector and `n == 1` gives `[1.0]`.
pub fn hamming_window(n: usize) -> Vec<f64> {
    if n < 2 {
        return vec![1.0; n];
    }
    let span = (n - 1) as f64;
    (0..n)
        .map(|k| 2.0 * PI * k as f64 / span)
        .map(|phase| 0.54 - 0.46 * phase.cos())
        .collect()
}
145
/// Symmetric Blackman window of length `n`:
/// `w[k] = 0.42 - 0.5*cos(2*pi*k/(N-1)) + 0.08*cos(4*pi*k/(N-1))`.
///
/// Both endpoints are (numerically) 0. For degenerate lengths, `n == 0` gives
/// an empty vector and `n == 1` gives `[1.0]`.
pub fn blackman_window(n: usize) -> Vec<f64> {
    if n < 2 {
        return vec![1.0; n];
    }
    let last = (n - 1) as f64;
    let mut out = Vec::with_capacity(n);
    for k in 0..n {
        let frac = k as f64 / last;
        let first_harmonic = (2.0 * PI * frac).cos();
        let second_harmonic = (4.0 * PI * frac).cos();
        out.push(0.42 - 0.5 * first_harmonic + 0.08 * second_harmonic);
    }
    out
}
155
156// ---------------------------------------------------------------------------
157// Phase B6: Arbitrary-length FFT (Bluestein's algorithm)
158// ---------------------------------------------------------------------------
159
160/// Compute the FFT for an arbitrary-length signal using Bluestein's chirp-z
161/// algorithm.
162///
163/// Delegates to [`fft`] when the input length is already a power of 2.
164/// For non-power-of-2 lengths, the signal is convolved with a chirp sequence
165/// via zero-padded radix-2 FFTs.
166pub fn fft_arbitrary(data: &[(f64, f64)]) -> Vec<(f64, f64)> {
167    let n = data.len();
168    if n == 0 { return vec![]; }
169    if n == 1 { return data.to_vec(); }
170
171    // If already power of 2, delegate to radix-2
172    if n.is_power_of_two() {
173        return fft(data);
174    }
175
176    // Chirp sequence: w[k] = exp(-i * pi * k^2 / N)
177    let chirp: Vec<(f64, f64)> = (0..n).map(|k| {
178        let angle = -PI * (k * k) as f64 / n as f64;
179        (angle.cos(), angle.sin())
180    }).collect();
181
182    // Multiply input by chirp: a[k] = x[k] * chirp[k]
183    let a: Vec<(f64, f64)> = data.iter().zip(chirp.iter()).map(|(&x, &w)| complex_mul(x, w)).collect();
184
185    // Convolution sequence: b[k] = conj(chirp[k])
186    // We need b extended for circular convolution
187    let m = next_power_of_2(2 * n - 1);
188
189    // Zero-pad a to length m
190    let mut a_padded = vec![(0.0, 0.0); m];
191    for (i, &v) in a.iter().enumerate() {
192        a_padded[i] = v;
193    }
194
195    // Build b_padded: b[0..n] = conj(chirp[0..n]), b[m-n+1..m] = conj(chirp[n-1..1])
196    let mut b_padded = vec![(0.0, 0.0); m];
197    for i in 0..n {
198        b_padded[i] = (chirp[i].0, -chirp[i].1); // conj
199    }
200    for i in 1..n {
201        b_padded[m - i] = (chirp[i].0, -chirp[i].1); // conj
202    }
203
204    // Convolve via FFT
205    let a_fft = fft(&a_padded);
206    let b_fft = fft(&b_padded);
207    let c_fft: Vec<(f64, f64)> = a_fft.iter().zip(b_fft.iter()).map(|(&a, &b)| complex_mul(a, b)).collect();
208    let c = ifft(&c_fft);
209
210    // Extract and multiply by chirp
211    (0..n).map(|k| complex_mul(chirp[k], c[k])).collect()
212}
213
214// ---------------------------------------------------------------------------
215// Phase B6: 2D FFT
216// ---------------------------------------------------------------------------
217
218/// Compute the 2-D FFT by applying 1-D [`fft`] along rows then along columns.
219///
220/// Both `rows` and `cols` must be powers of 2.
221///
222/// # Errors
223///
224/// Returns `Err` if `data.len() != rows * cols` or dimensions are not powers
225/// of 2.
226pub fn fft_2d(data: &[(f64, f64)], rows: usize, cols: usize) -> Result<Vec<(f64, f64)>, String> {
227    if data.len() != rows * cols {
228        return Err(format!("fft_2d: expected {} elements, got {}", rows * cols, data.len()));
229    }
230    if !rows.is_power_of_two() || !cols.is_power_of_two() {
231        return Err("fft_2d: rows and cols must be powers of 2".into());
232    }
233
234    // FFT along rows
235    let mut result = data.to_vec();
236    for r in 0..rows {
237        let row: Vec<(f64, f64)> = result[r * cols..(r + 1) * cols].to_vec();
238        let fft_row = fft(&row);
239        result[r * cols..(r + 1) * cols].copy_from_slice(&fft_row);
240    }
241
242    // FFT along columns
243    for c in 0..cols {
244        let col: Vec<(f64, f64)> = (0..rows).map(|r| result[r * cols + c]).collect();
245        let fft_col = fft(&col);
246        for r in 0..rows {
247            result[r * cols + c] = fft_col[r];
248        }
249    }
250
251    Ok(result)
252}
253
254/// Compute the 2-D inverse FFT by applying 1-D [`ifft`] along rows then columns.
255pub fn ifft_2d(data: &[(f64, f64)], rows: usize, cols: usize) -> Result<Vec<(f64, f64)>, String> {
256    if data.len() != rows * cols {
257        return Err(format!("ifft_2d: expected {} elements, got {}", rows * cols, data.len()));
258    }
259    if !rows.is_power_of_two() || !cols.is_power_of_two() {
260        return Err("ifft_2d: rows and cols must be powers of 2".into());
261    }
262
263    // IFFT along rows
264    let mut result = data.to_vec();
265    for r in 0..rows {
266        let row: Vec<(f64, f64)> = result[r * cols..(r + 1) * cols].to_vec();
267        let ifft_row = ifft(&row);
268        result[r * cols..(r + 1) * cols].copy_from_slice(&ifft_row);
269    }
270
271    // IFFT along columns
272    for c in 0..cols {
273        let col: Vec<(f64, f64)> = (0..rows).map(|r| result[r * cols + c]).collect();
274        let ifft_col = ifft(&col);
275        for r in 0..rows {
276            result[r * cols + c] = ifft_col[r];
277        }
278    }
279
280    Ok(result)
281}
282
283// ---------------------------------------------------------------------------
284// Tests
285// ---------------------------------------------------------------------------
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Brute-force O(N^2) DFT, used as an independent reference for the fast
    /// transforms. `k * j` is reduced mod N so the angle stays small.
    fn naive_dft(data: &[(f64, f64)]) -> Vec<(f64, f64)> {
        let n = data.len();
        (0..n)
            .map(|k| {
                data.iter().enumerate().fold((0.0, 0.0), |(re, im), (j, &(xr, xi))| {
                    let angle = -2.0 * PI * ((k * j) % n) as f64 / n as f64;
                    let (s, c) = angle.sin_cos();
                    (re + xr * c - xi * s, im + xr * s + xi * c)
                })
            })
            .collect()
    }

    /// Assert that two complex sequences have equal length and agree
    /// element-wise within `tol`.
    fn assert_close(got: &[(f64, f64)], want: &[(f64, f64)], tol: f64) {
        assert_eq!(got.len(), want.len());
        for (i, (g, w)) in got.iter().zip(want).enumerate() {
            assert!((g.0 - w.0).abs() < tol, "re[{i}]: got {} expected {}", g.0, w.0);
            assert!((g.1 - w.1).abs() < tol, "im[{i}]: got {} expected {}", g.1, w.1);
        }
    }

    #[test]
    fn test_fft_constant() {
        // FFT of [1,1,1,1] = [4, 0, 0, 0]
        let data = vec![(1.0, 0.0); 4];
        let result = fft(&data);
        assert!((result[0].0 - 4.0).abs() < 1e-12);
        for i in 1..4 {
            assert!(result[i].0.abs() < 1e-12);
            assert!(result[i].1.abs() < 1e-12);
        }
    }

    #[test]
    fn test_fft_matches_naive_dft() {
        // Non-trivial complex input at a size with six butterfly stages.
        let data: Vec<(f64, f64)> = (0..64)
            .map(|k| {
                let t = k as f64;
                ((0.3 * t).sin() + 0.5, (1.7 * t).cos())
            })
            .collect();
        assert_close(&fft(&data), &naive_dft(&data), 1e-9);
    }

    #[test]
    #[should_panic]
    fn test_fft_rejects_non_power_of_two() {
        let _ = fft(&[(1.0, 0.0); 3]);
    }

    #[test]
    fn test_fft_ifft_roundtrip() {
        let data = vec![(1.0, 0.0), (2.0, 0.0), (3.0, 0.0), (4.0, 0.0)];
        let spectrum = fft(&data);
        let recovered = ifft(&spectrum);
        for (orig, rec) in data.iter().zip(recovered.iter()) {
            assert!((orig.0 - rec.0).abs() < 1e-10);
            assert!((orig.1 - rec.1).abs() < 1e-10);
        }
    }

    #[test]
    fn test_rfft_basic() {
        let data = [1.0, 0.0, 0.0, 0.0];
        let result = rfft(&data);
        // All ones for impulse
        for &(re, im) in &result {
            assert!((re - 1.0).abs() < 1e-12);
            assert!(im.abs() < 1e-12);
        }
    }

    #[test]
    fn test_rfft_zero_pads_to_power_of_two() {
        let got = rfft(&[1.0, 2.0, 3.0]);
        let padded = [(1.0, 0.0), (2.0, 0.0), (3.0, 0.0), (0.0, 0.0)];
        assert_close(&got, &naive_dft(&padded), 1e-12);
    }

    #[test]
    fn test_psd_basic() {
        let data = [1.0, 0.0, 0.0, 0.0];
        let power = psd(&data);
        // All 1.0 for impulse
        for &p in &power {
            assert!((p - 1.0).abs() < 1e-12);
        }
    }

    #[test]
    fn test_next_power_of_2() {
        assert_eq!(next_power_of_2(0), 1);
        assert_eq!(next_power_of_2(1), 1);
        assert_eq!(next_power_of_2(5), 8);
        assert_eq!(next_power_of_2(8), 8);
        assert_eq!(next_power_of_2(9), 16);
    }

    #[test]
    fn test_determinism() {
        let data = vec![(1.0, 2.0), (3.0, 4.0), (5.0, 6.0), (7.0, 8.0)];
        let r1 = fft(&data);
        let r2 = fft(&data);
        for (a, b) in r1.iter().zip(r2.iter()) {
            assert_eq!(a.0.to_bits(), b.0.to_bits());
            assert_eq!(a.1.to_bits(), b.1.to_bits());
        }
    }

    // --- B6: Window functions ---

    #[test]
    fn test_hann_endpoints() {
        let w = hann_window(8);
        assert!(w[0].abs() < 1e-12, "hann[0] = {}", w[0]);
        assert!(w[7].abs() < 1e-12, "hann[N-1] = {}", w[7]);
    }

    #[test]
    fn test_hann_midpoint() {
        let w = hann_window(9); // odd, so exact midpoint
        assert!((w[4] - 1.0).abs() < 1e-12, "hann[4] = {}", w[4]);
    }

    #[test]
    fn test_hann_symmetry() {
        let w = hann_window(16);
        for k in 0..8 {
            assert!((w[k] - w[15 - k]).abs() < 1e-12);
        }
    }

    #[test]
    fn test_hamming_endpoints() {
        let w = hamming_window(8);
        assert!((w[0] - 0.08).abs() < 1e-12, "hamming[0] = {}", w[0]);
        assert!((w[7] - 0.08).abs() < 1e-12, "hamming[N-1] = {}", w[7]);
    }

    #[test]
    fn test_blackman_endpoints() {
        let w = blackman_window(16);
        assert!(w[0].abs() < 1e-12, "blackman[0] = {}", w[0]);
        assert!(w[15].abs() < 1e-12, "blackman[N-1] = {}", w[15]);
    }

    // --- B6: Arbitrary FFT ---

    #[test]
    fn test_fft_arbitrary_prime() {
        // 7-element signal, brute-force DFT
        let n = 7;
        let data: Vec<(f64, f64)> = (0..n).map(|k| ((k + 1) as f64, 0.0)).collect();
        let result = fft_arbitrary(&data);

        // Brute-force DFT for comparison
        for k in 0..n {
            let mut re = 0.0;
            let mut im = 0.0;
            for j in 0..n {
                let angle = -2.0 * PI * (k * j) as f64 / n as f64;
                re += data[j].0 * angle.cos() - data[j].1 * angle.sin();
                im += data[j].0 * angle.sin() + data[j].1 * angle.cos();
            }
            assert!((result[k].0 - re).abs() < 1e-8, "re[{k}]: got {} expected {re}", result[k].0);
            assert!((result[k].1 - im).abs() < 1e-8, "im[{k}]: got {} expected {im}", result[k].1);
        }
    }

    #[test]
    fn test_fft_arbitrary_composite_matches_naive() {
        // 12 = 4 * 3 is not a power of 2, so this exercises Bluestein with a
        // complex input.
        let data: Vec<(f64, f64)> = (0..12)
            .map(|k| (k as f64 * 0.25 - 1.0, (k % 5) as f64))
            .collect();
        assert_close(&fft_arbitrary(&data), &naive_dft(&data), 1e-9);
    }

    #[test]
    fn test_fft_arbitrary_large_n_matches_naive() {
        // Large enough that k^2 is far from zero, stressing the chirp angles.
        let data: Vec<(f64, f64)> = (0..1000)
            .map(|k| ((0.01 * k as f64).sin(), 0.1 * (k % 7) as f64))
            .collect();
        assert_close(&fft_arbitrary(&data), &naive_dft(&data), 1e-7);
    }

    #[test]
    fn test_fft_arbitrary_matches_radix2() {
        let data: Vec<(f64, f64)> = vec![(1.0, 0.0), (2.0, 0.0), (3.0, 0.0), (4.0, 0.0)];
        let r_radix2 = fft(&data);
        let r_arb = fft_arbitrary(&data);
        for (a, b) in r_radix2.iter().zip(r_arb.iter()) {
            assert!((a.0 - b.0).abs() < 1e-10);
            assert!((a.1 - b.1).abs() < 1e-10);
        }
    }

    #[test]
    fn test_fft_arbitrary_parseval() {
        let data: Vec<(f64, f64)> = vec![(1.0, 0.0), (2.0, 1.0), (3.0, -1.0), (0.5, 0.5), (4.0, 0.0)];
        let n = data.len();
        let time_energy: f64 = data.iter().map(|&(r, i)| r * r + i * i).sum();
        let freq = fft_arbitrary(&data);
        let freq_energy: f64 = freq.iter().map(|&(r, i)| r * r + i * i).sum::<f64>() / n as f64;
        assert!((time_energy - freq_energy).abs() < 1e-8, "time={time_energy} freq={freq_energy}");
    }

    // --- B6: 2D FFT ---

    #[test]
    fn test_fft_2d_constant() {
        let data = vec![(1.0, 0.0); 4]; // 2x2 constant
        let result = fft_2d(&data, 2, 2).unwrap();
        // DC component should be N*M = 4
        assert!((result[0].0 - 4.0).abs() < 1e-10);
        for i in 1..4 {
            assert!(result[i].0.abs() < 1e-10);
            assert!(result[i].1.abs() < 1e-10);
        }
    }

    #[test]
    fn test_fft_2d_non_square_matches_naive() {
        // A non-square grid detects row/column indexing mix-ups that a
        // square input would hide.
        let (rows, cols) = (2, 4);
        let data: Vec<(f64, f64)> = (0..rows * cols)
            .map(|i| (i as f64, (i * i % 3) as f64))
            .collect();
        let got = fft_2d(&data, rows, cols).unwrap();
        // X[u][v] = sum_{r,c} x[r][c] * exp(-2*pi*i*(u*r/rows + v*c/cols))
        for u in 0..rows {
            for v in 0..cols {
                let (mut re, mut im) = (0.0, 0.0);
                for r in 0..rows {
                    for c in 0..cols {
                        let (xr, xi) = data[r * cols + c];
                        let angle = -2.0
                            * PI
                            * ((u * r) as f64 / rows as f64 + (v * c) as f64 / cols as f64);
                        let (s, co) = angle.sin_cos();
                        re += xr * co - xi * s;
                        im += xr * s + xi * co;
                    }
                }
                let g = got[u * cols + v];
                assert!((g.0 - re).abs() < 1e-9, "re[{u},{v}]: got {} expected {re}", g.0);
                assert!((g.1 - im).abs() < 1e-9, "im[{u},{v}]: got {} expected {im}", g.1);
            }
        }
    }

    #[test]
    fn test_fft_2d_roundtrip() {
        let data: Vec<(f64, f64)> = vec![(1.0, 0.0), (2.0, 0.0), (3.0, 0.0), (4.0, 0.0)]; // 2x2
        let freq = fft_2d(&data, 2, 2).unwrap();
        let recovered = ifft_2d(&freq, 2, 2).unwrap();
        for (orig, rec) in data.iter().zip(recovered.iter()) {
            assert!((orig.0 - rec.0).abs() < 1e-10);
            assert!((orig.1 - rec.1).abs() < 1e-10);
        }
    }

    #[test]
    fn test_fft_2d_roundtrip_non_square() {
        let (rows, cols) = (4, 2);
        let data: Vec<(f64, f64)> = (0..rows * cols)
            .map(|i| (i as f64 - 3.5, 0.5 * i as f64))
            .collect();
        let freq = fft_2d(&data, rows, cols).unwrap();
        let back = ifft_2d(&freq, rows, cols).unwrap();
        assert_close(&back, &data, 1e-10);
    }

    #[test]
    fn test_fft_2d_rejects_bad_shapes() {
        let data = vec![(0.0, 0.0); 6];
        assert!(fft_2d(&data, 2, 4).is_err()); // length mismatch
        assert!(fft_2d(&data, 2, 3).is_err()); // 3 is not a power of 2
        assert!(ifft_2d(&data, 2, 4).is_err());
        assert!(ifft_2d(&data, 3, 2).is_err());
    }

    #[test]
    fn test_b6_fft_determinism() {
        let data: Vec<(f64, f64)> = vec![(1.0, 2.0), (3.0, 0.0), (5.0, -1.0)];
        let r1 = fft_arbitrary(&data);
        let r2 = fft_arbitrary(&data);
        for (a, b) in r1.iter().zip(r2.iter()) {
            assert_eq!(a.0.to_bits(), b.0.to_bits());
            assert_eq!(a.1.to_bits(), b.1.to_bits());
        }
    }
}