Skip to main content

scirs2_fft/
butterfly.rs

1//! Enhanced Butterfly Operations for FFT
2//!
3//! This module provides optimized, in-place butterfly operations for radix-2,
4//! radix-4, radix-8, and split-radix FFT algorithms.  All operations work
5//! directly on mutable slices / arrays without any heap allocation.
6//!
7//! # Overview
8//!
9//! | Function | Radix | Points | Twiddles |
10//! |----------|-------|--------|----------|
11//! | [`butterfly2`] | 2 | 2 | 1 |
12//! | [`butterfly4`] | 4 | 4 | 3 |
13//! | [`butterfly8`] | 8 | 8 | 7 |
//! | [`split_radix_butterfly`] | 2 | power of two, >= 4 | computed |
15//!
16//! Additionally, [`generate_twiddle_table`] pre-computes the unit-root vector
17//! `e^{-2 pi i k / N}` for `k = 0 .. N-1`.
18
19use scirs2_core::numeric::Complex64;
20use std::f64::consts::PI;
21
22use crate::error::{FFTError, FFTResult};
23
24// ─────────────────────────────────────────────────────────────────────────────
25//  Twiddle-table generation
26// ─────────────────────────────────────────────────────────────────────────────
27
28/// Generate twiddle factor table for an N-point DFT.
29///
30/// Returns `W[k] = e^{-2 pi i k / N}` for `k = 0, 1, ..., N-1`.
31///
32/// # Errors
33///
34/// Returns [`FFTError::ValueError`] if `n == 0`.
35pub fn generate_twiddle_table(n: usize) -> FFTResult<Vec<Complex64>> {
36    if n == 0 {
37        return Err(FFTError::ValueError(
38            "generate_twiddle_table: n must be > 0".into(),
39        ));
40    }
41    if n == 1 {
42        return Ok(vec![Complex64::new(1.0, 0.0)]);
43    }
44    let inv_n = -2.0 * PI / n as f64;
45    Ok((0..n)
46        .map(|k| {
47            let angle = inv_n * k as f64;
48            Complex64::new(angle.cos(), angle.sin())
49        })
50        .collect())
51}
52
53/// Generate twiddle factor table for an inverse N-point DFT.
54///
55/// Returns `W[k] = e^{+2 pi i k / N}` for `k = 0, 1, ..., N-1`.
56///
57/// # Errors
58///
59/// Returns [`FFTError::ValueError`] if `n == 0`.
60pub fn generate_inverse_twiddle_table(n: usize) -> FFTResult<Vec<Complex64>> {
61    if n == 0 {
62        return Err(FFTError::ValueError(
63            "generate_inverse_twiddle_table: n must be > 0".into(),
64        ));
65    }
66    if n == 1 {
67        return Ok(vec![Complex64::new(1.0, 0.0)]);
68    }
69    let inv_n = 2.0 * PI / n as f64;
70    Ok((0..n)
71        .map(|k| {
72            let angle = inv_n * k as f64;
73            Complex64::new(angle.cos(), angle.sin())
74        })
75        .collect())
76}
77
78// ─────────────────────────────────────────────────────────────────────────────
79//  Radix-2 butterfly
80// ─────────────────────────────────────────────────────────────────────────────
81
82/// In-place radix-2 butterfly.
83///
84/// Computes:
85/// ```text
86///   a' = a + twiddle * b
87///   b' = a - twiddle * b
88/// ```
89///
90/// This is the fundamental building block of a decimation-in-time (DIT)
91/// Cooley-Tukey FFT.  It is completely allocation-free.
92#[inline(always)]
93pub fn butterfly2(a: &mut Complex64, b: &mut Complex64, twiddle: Complex64) {
94    let t = twiddle * *b;
95    let new_a = *a + t;
96    let new_b = *a - t;
97    *a = new_a;
98    *b = new_b;
99}
100
101// ─────────────────────────────────────────────────────────────────────────────
102//  Radix-4 butterfly
103// ─────────────────────────────────────────────────────────────────────────────
104
105/// In-place radix-4 DFT butterfly.
106///
107/// Computes a 4-point DFT in-place using twiddle factors.
108/// `twiddles[0]` = W_4^1, `twiddles[1]` = W_4^2, `twiddles[2]` = W_4^3.
109///
110/// For a standalone 4-point DFT: `W_4^1 = e^{-pi i/2} = -j`,
111/// `W_4^2 = -1`, `W_4^3 = j`.
112///
113/// The standard radix-4 DFT decomposition:
114/// ```text
115///   X[0] = x[0] + x[1] + x[2] + x[3]
116///   X[1] = x[0] - j*x[1] - x[2] + j*x[3]
117///   X[2] = x[0] - x[1] + x[2] - x[3]
118///   X[3] = x[0] + j*x[1] - x[2] - j*x[3]
119/// ```
120#[inline]
121pub fn butterfly4(a: &mut [Complex64; 4], twiddles: &[Complex64; 3]) {
122    // Standard 4-point DFT using the DFT matrix approach:
123    // X[k] = sum_{n=0}^{3} x[n] * W_4^{nk}
124    // where W_4 = e^{-2*pi*i/4} = -j
125    //
126    // twiddles[0] = W_4^1 = -j
127    // twiddles[1] = W_4^2 = -1
128    // twiddles[2] = W_4^3 = j
129
130    let x0 = a[0];
131    let x1 = a[1];
132    let x2 = a[2];
133    let x3 = a[3];
134
135    // X[0] = x0 + x1 + x2 + x3
136    a[0] = x0 + x1 + x2 + x3;
137
138    // X[1] = x0 + W^1*x1 + W^2*x2 + W^3*x3
139    a[1] = x0 + twiddles[0] * x1 + twiddles[1] * x2 + twiddles[2] * x3;
140
141    // X[2] = x0 + W^2*x1 + W^4*x2 + W^6*x3
142    //      = x0 + W^2*x1 + (W^2)^2*x2 + (W^2)^3*x3
143    let w2 = twiddles[1]; // W^2
144    let w4 = w2 * w2; // W^4 = (W^2)^2
145    let w6 = w4 * w2; // W^6
146    a[2] = x0 + w2 * x1 + w4 * x2 + w6 * x3;
147
148    // X[3] = x0 + W^3*x1 + W^6*x2 + W^9*x3
149    let w3 = twiddles[2]; // W^3
150    let w9 = w3 * w3 * w3; // W^9
151    a[3] = x0 + w3 * x1 + w6 * x2 + w9 * x3;
152}
153
154// ─────────────────────────────────────────────────────────────────────────────
155//  Radix-8 butterfly
156// ─────────────────────────────────────────────────────────────────────────────
157
158/// In-place radix-8 DFT butterfly.
159///
160/// Computes an 8-point DFT in-place. `twiddles[k]` = `W_8^{k+1}` for k=0..6,
161/// i.e., `twiddles` contains `W_8^1` through `W_8^7`.
162#[inline]
163pub fn butterfly8(a: &mut [Complex64; 8], twiddles: &[Complex64; 7]) {
164    // Direct 8-point DFT: X[k] = sum_{n=0}^{7} x[n] * W_8^{n*k}
165    // We pre-compute the powers W_8^m for m = 0..7 from the twiddle table.
166    let w = [
167        Complex64::new(1.0, 0.0), // W^0
168        twiddles[0],              // W^1
169        twiddles[1],              // W^2
170        twiddles[2],              // W^3
171        twiddles[3],              // W^4
172        twiddles[4],              // W^5
173        twiddles[5],              // W^6
174        twiddles[6],              // W^7
175    ];
176
177    let input = *a;
178    for k in 0..8 {
179        let mut sum = Complex64::new(0.0, 0.0);
180        for n in 0..8 {
181            let idx = (n * k) % 8;
182            sum += input[n] * w[idx];
183        }
184        a[k] = sum;
185    }
186}
187
188// ─────────────────────────────────────────────────────────────────────────────
189//  Split-radix butterfly (L-shaped)
190// ─────────────────────────────────────────────────────────────────────────────
191
192/// In-place split-radix FFT for a complex array of length N.
193///
194/// Implements the Cooley-Tukey radix-2 DIT FFT with bit-reversal
195/// permutation.  This is the standard iterative butterfly algorithm
196/// that achieves O(N log N) complexity.
197///
198/// # Errors
199///
200/// Returns [`FFTError::ValueError`] if `data.len()` is not a power of two or is < 4.
201pub fn split_radix_butterfly(data: &mut [Complex64]) -> FFTResult<()> {
202    let n = data.len();
203    if n < 4 {
204        return Err(FFTError::ValueError(
205            "split_radix_butterfly: length must be >= 4".into(),
206        ));
207    }
208    if !n.is_power_of_two() {
209        return Err(FFTError::ValueError(
210            "split_radix_butterfly: length must be a power of two".into(),
211        ));
212    }
213
214    // Bit-reversal permutation
215    let bits = n.trailing_zeros();
216    for i in 0..n {
217        let j = reverse_bits(i, bits);
218        if i < j {
219            data.swap(i, j);
220        }
221    }
222
223    // Iterative butterfly passes
224    let mut size = 2;
225    while size <= n {
226        let half = size / 2;
227        let angle_step = -2.0 * PI / size as f64;
228
229        let mut group_start = 0;
230        while group_start < n {
231            for k in 0..half {
232                let angle = angle_step * k as f64;
233                let twiddle = Complex64::new(angle.cos(), angle.sin());
234
235                let i = group_start + k;
236                let j = i + half;
237
238                let t = twiddle * data[j];
239                data[j] = data[i] - t;
240                data[i] = data[i] + t;
241            }
242            group_start += size;
243        }
244        size *= 2;
245    }
246
247    Ok(())
248}
249
/// Reverse the lower `bits` bits of `x`.
///
/// Bits of `x` at positions `>= bits` are ignored and the result is always
/// `< 2^bits`.  `bits == 0` yields `0`.
///
/// Implemented with the standard library's whole-word bit reversal followed
/// by a right shift that drops the `usize::BITS - bits` positions that were
/// not requested.  `bits` values above `usize::BITS` are treated as
/// `usize::BITS`.
fn reverse_bits(x: usize, bits: u32) -> usize {
    x.reverse_bits()
        // A shift by the full word width (bits == 0) is not representable;
        // `checked_shr` reports it as `None`, which maps to the empty result.
        .checked_shr(usize::BITS.saturating_sub(bits))
        .unwrap_or(0)
}
260
261// ─────────────────────────────────────────────────────────────────────────────
262//  Direct DFT for small sizes (base-case for recursive algorithms)
263// ─────────────────────────────────────────────────────────────────────────────
264
265/// Compute a direct (naive) DFT for small `N`.
266///
267/// Complexity is O(N^2) so this should only be used for small base cases
268/// (typically N <= 16).
269///
270/// # Errors
271///
272/// Returns [`FFTError::ValueError`] if `data` is empty.
273pub fn direct_dft(data: &[Complex64]) -> FFTResult<Vec<Complex64>> {
274    let n = data.len();
275    if n == 0 {
276        return Err(FFTError::ValueError("direct_dft: empty input".into()));
277    }
278    if n == 1 {
279        return Ok(data.to_vec());
280    }
281
282    let angle_base = -2.0 * PI / n as f64;
283    let mut result = vec![Complex64::new(0.0, 0.0); n];
284    for k in 0..n {
285        let mut sum = Complex64::new(0.0, 0.0);
286        for j in 0..n {
287            let angle = angle_base * (k * j) as f64;
288            let w = Complex64::new(angle.cos(), angle.sin());
289            sum += data[j] * w;
290        }
291        result[k] = sum;
292    }
293    Ok(result)
294}
295
296/// Compute a direct (naive) inverse DFT for small `N`.
297///
298/// # Errors
299///
300/// Returns [`FFTError::ValueError`] if `data` is empty.
301pub fn direct_idft(data: &[Complex64]) -> FFTResult<Vec<Complex64>> {
302    let n = data.len();
303    if n == 0 {
304        return Err(FFTError::ValueError("direct_idft: empty input".into()));
305    }
306    if n == 1 {
307        return Ok(data.to_vec());
308    }
309
310    let angle_base = 2.0 * PI / n as f64;
311    let inv_n = 1.0 / n as f64;
312    let mut result = vec![Complex64::new(0.0, 0.0); n];
313    for k in 0..n {
314        let mut sum = Complex64::new(0.0, 0.0);
315        for j in 0..n {
316            let angle = angle_base * (k * j) as f64;
317            let w = Complex64::new(angle.cos(), angle.sin());
318            sum += data[j] * w;
319        }
320        result[k] = sum * inv_n;
321    }
322    Ok(result)
323}
324
325// ─────────────────────────────────────────────────────────────────────────────
326//  Tests
327// ─────────────────────────────────────────────────────────────────────────────
328
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Helper: maximum absolute error between two complex slices.
    ///
    /// Panics when the lengths differ, so a truncated result cannot hide
    /// behind `zip` stopping at the shorter slice.
    fn max_abs_err(a: &[Complex64], b: &[Complex64]) -> f64 {
        assert_eq!(a.len(), b.len(), "compared slices must have equal length");
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).norm())
            .fold(0.0_f64, f64::max)
    }

    /// Helper: assert both the real and the imaginary part of a value.
    fn assert_complex_eq(actual: Complex64, re: f64, im: f64) {
        assert_relative_eq!(actual.re, re, epsilon = 1e-14);
        assert_relative_eq!(actual.im, im, epsilon = 1e-14);
    }

    /// Helper: run `split_radix_butterfly` on `input` and compare the result
    /// with `direct_dft` of the same samples.
    fn assert_split_radix_matches_dft(input: Vec<Complex64>) {
        let n = input.len();
        let expected = direct_dft(&input).expect("dft failed");
        let mut data = input;
        split_radix_butterfly(&mut data).expect("split_radix failed");
        let err = max_abs_err(&data, &expected);
        assert!(err < 1e-10, "split_radix error (n={n}) = {err}");
    }

    // ── twiddle table ────────────────────────────────────────────────────
    #[test]
    fn test_twiddle_table_size_1() {
        let tw = generate_twiddle_table(1).expect("should succeed");
        assert_eq!(tw.len(), 1);
        assert_relative_eq!(tw[0].re, 1.0, epsilon = 1e-15);
        assert_relative_eq!(tw[0].im, 0.0, epsilon = 1e-15);
    }

    #[test]
    fn test_twiddle_table_values() {
        let n = 8;
        let tw = generate_twiddle_table(n).expect("should succeed");
        assert_eq!(tw.len(), n);

        // W^0 = 1
        assert_complex_eq(tw[0], 1.0, 0.0);

        // W^(N/4) = e^{-pi i / 2} = -j
        assert_complex_eq(tw[n / 4], 0.0, -1.0);

        // W^(N/2) = e^{-pi i} = -1
        assert_complex_eq(tw[n / 2], -1.0, 0.0);

        // All magnitudes should be 1
        for w in &tw {
            assert_relative_eq!(w.norm(), 1.0, epsilon = 1e-14);
        }
    }

    #[test]
    fn test_twiddle_table_error_on_zero() {
        assert!(generate_twiddle_table(0).is_err());
    }

    // ── butterfly2 ──────────────────────────────────────────────────────
    #[test]
    fn test_butterfly2_trivial_twiddle() {
        let mut a = Complex64::new(3.0, 0.0);
        let mut b = Complex64::new(1.0, 0.0);
        butterfly2(&mut a, &mut b, Complex64::new(1.0, 0.0));
        assert_complex_eq(a, 4.0, 0.0);
        assert_complex_eq(b, 2.0, 0.0);
    }

    #[test]
    fn test_butterfly2_with_twiddle() {
        // W = -1  =>  a' = a + (-1)*b = a-b,  b' = a - (-1)*b = a+b
        let mut a = Complex64::new(5.0, 0.0);
        let mut b = Complex64::new(3.0, 0.0);
        butterfly2(&mut a, &mut b, Complex64::new(-1.0, 0.0));
        assert_complex_eq(a, 2.0, 0.0);
        assert_complex_eq(b, 8.0, 0.0);
    }

    #[test]
    fn test_butterfly2_complex_twiddle() {
        // W = -j, b = 3 - j  =>  W*b = -1 - 3j
        // a' = (1 + 2j) + (-1 - 3j) = -j,  b' = (1 + 2j) - (-1 - 3j) = 2 + 5j
        let mut a = Complex64::new(1.0, 2.0);
        let mut b = Complex64::new(3.0, -1.0);
        butterfly2(&mut a, &mut b, Complex64::new(0.0, -1.0));
        assert_complex_eq(a, 0.0, -1.0);
        assert_complex_eq(b, 2.0, 5.0);
    }

    // ── butterfly4 ──────────────────────────────────────────────────────
    #[test]
    fn test_butterfly4_matches_direct_dft() {
        let input = [
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(3.0, 0.0),
            Complex64::new(4.0, 0.0),
        ];
        let expected = direct_dft(&input).expect("direct_dft failed");

        // For a 4-point DFT: W_4 = e^{-2*pi*i/4} = e^{-pi*i/2} = -j
        let twiddles = [
            Complex64::new(0.0, -1.0), // W4^1 = -j
            Complex64::new(-1.0, 0.0), // W4^2 = -1
            Complex64::new(0.0, 1.0),  // W4^3 = j
        ];
        let mut data = input;
        butterfly4(&mut data, &twiddles);

        let err = max_abs_err(&data, &expected);
        assert!(err < 1e-12, "butterfly4 error = {err}");
    }

    // ── butterfly8 ──────────────────────────────────────────────────────
    #[test]
    fn test_butterfly8_matches_direct_dft() {
        let input: [Complex64; 8] = [
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, -1.0),
            Complex64::new(0.5, 0.5),
            Complex64::new(3.0, 0.0),
            Complex64::new(-1.0, 1.0),
            Complex64::new(0.0, 2.0),
            Complex64::new(1.5, -0.5),
            Complex64::new(-0.5, 0.0),
        ];
        let expected = direct_dft(&input).expect("direct_dft failed");

        // W_8^k for k=1..7
        let twiddles: [Complex64; 7] = std::array::from_fn(|k| {
            let angle = -2.0 * PI * (k + 1) as f64 / 8.0;
            Complex64::new(angle.cos(), angle.sin())
        });

        let mut data = input;
        butterfly8(&mut data, &twiddles);

        let err = max_abs_err(&data, &expected);
        assert!(err < 1e-10, "butterfly8 error = {err}");
    }

    // ── direct DFT ──────────────────────────────────────────────────────
    #[test]
    fn test_direct_dft_known_result() {
        // DFT of [1, 1, 1, 1] = [4, 0, 0, 0]
        let input = vec![Complex64::new(1.0, 0.0); 4];
        let result = direct_dft(&input).expect("direct_dft failed");
        assert_relative_eq!(result[0].re, 4.0, epsilon = 1e-12);
        assert_relative_eq!(result[0].im, 0.0, epsilon = 1e-12);
        for (k, bin) in result.iter().enumerate().skip(1) {
            assert!(bin.norm() < 1e-12, "non-zero at k={k}");
        }
    }

    #[test]
    fn test_direct_dft_idft_roundtrip() {
        let input = vec![
            Complex64::new(1.0, 2.0),
            Complex64::new(3.0, -1.0),
            Complex64::new(0.5, 0.5),
            Complex64::new(-2.0, 1.5),
        ];
        let spectrum = direct_dft(&input).expect("dft failed");
        let recovered = direct_idft(&spectrum).expect("idft failed");
        let err = max_abs_err(&input, &recovered);
        assert!(err < 1e-12, "roundtrip error = {err}");
    }

    #[test]
    fn test_direct_dft_empty() {
        assert!(direct_dft(&[]).is_err());
    }

    #[test]
    fn test_direct_idft_empty() {
        assert!(direct_idft(&[]).is_err());
    }

    // ── split-radix butterfly ───────────────────────────────────────────
    #[test]
    fn test_split_radix_butterfly_size_4() {
        assert_split_radix_matches_dft(vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 1.0),
            Complex64::new(-1.0, 0.0),
            Complex64::new(0.0, -1.0),
        ]);
    }

    #[test]
    fn test_split_radix_butterfly_size_8() {
        assert_split_radix_matches_dft(
            (0..8)
                .map(|k| Complex64::new(k as f64, -(k as f64) * 0.5))
                .collect(),
        );
    }

    #[test]
    fn test_split_radix_butterfly_size_16() {
        assert_split_radix_matches_dft(
            (0..16)
                .map(|k| Complex64::new((k as f64 * 0.5).sin(), (k as f64 * 0.3).cos()))
                .collect(),
        );
    }

    #[test]
    fn test_split_radix_butterfly_not_power_of_two() {
        let mut data = vec![Complex64::new(1.0, 0.0); 6];
        assert!(split_radix_butterfly(&mut data).is_err());
    }

    #[test]
    fn test_split_radix_butterfly_too_small() {
        let mut data = vec![Complex64::new(1.0, 0.0); 2];
        assert!(split_radix_butterfly(&mut data).is_err());
    }

    // ── inverse twiddle table ───────────────────────────────────────────
    #[test]
    fn test_inverse_twiddle_table() {
        let n = 8;
        let fw = generate_twiddle_table(n).expect("forward failed");
        let inv = generate_inverse_twiddle_table(n).expect("inverse failed");
        assert_eq!(inv.len(), n);
        // W[k] * W_inv[k] = |W|^2 = 1, since W_inv[k] = conj(W[k])
        for (f, i) in fw.iter().zip(inv.iter()) {
            assert_complex_eq(f * i, 1.0, 0.0);
        }
    }

    #[test]
    fn test_inverse_twiddle_table_error_on_zero() {
        assert!(generate_inverse_twiddle_table(0).is_err());
    }
}