Skip to main content

phasm_core/stego/armor/
fft2d.rs

1// Copyright (c) 2026 Christoph Gaffga
2// SPDX-License-Identifier: GPL-3.0-only
3// https://github.com/cgaffga/phasmcore
4
5//! Deterministic 2D FFT/IFFT using only WASM-intrinsic f64 operations.
6//!
7//! Replaces `rustfft` with an in-house implementation:
8//! - Radix-2 Cooley-Tukey for power-of-2 sizes
9//! - Bluestein's chirp-z transform for arbitrary sizes
10//! All twiddle factors computed via `det_sincos()`.
11//!
12//! Memory optimizations (Phase 3):
13//! - P1a: Column FFTs use gather-FFT-scatter with a single column buffer
14//!   instead of a full transposed copy (saves ~186 MB for 4032x3024).
15//! - P1b: All FFT data uses f32 (Complex32) — template detection only needs
16//!   coarse peak finding. Twiddle factors still computed in f64 via
17//!   `det_sincos()` then cast to f32.
18//! - P2a: BluesteinPlan precomputes chirp factors and FFT(b_hat) for reuse
19//!   across all rows/columns of the same length.
20
21use num_complex::Complex;
22use crate::det_math::{det_sincos, det_hypot};
23use std::f64::consts::PI;
24
25/// Complex32 type alias for f32 complex numbers.
26pub type Complex32 = Complex<f32>;
27
28/// 2D complex spectrum using f32 for memory efficiency.
29pub struct Spectrum2D {
30    pub data: Vec<Complex32>,
31    pub width: usize,
32    pub height: usize,
33}
34
35// ──────────────────────────────────────────────────────────────────────────
36// Bluestein plan: precomputed chirp factors for reuse (P2a)
37// ──────────────────────────────────────────────────────────────────────────
38
39/// Precomputed Bluestein chirp factors and FFT(b_hat) for a given (n, sign).
40///
41/// Eliminates redundant chirp computation and FFT(b) calls when processing
42/// many rows or columns of the same length.
43struct BluesteinPlan {
44    n: usize,
45    m: usize, // next_pow2(2*n - 1)
46    chirp: Vec<Complex32>,
47    b_hat: Vec<Complex32>, // FFT of padded conjugate chirp
48}
49
50impl BluesteinPlan {
51    /// Create a new Bluestein plan for length `n` and direction `sign`.
52    fn new(n: usize, sign: f64) -> Self {
53        let m = next_pow2(2 * n - 1);
54
55        // Chirp factors: w_k = exp(sign * i * pi * k^2 / n)
56        let mut chirp = vec![Complex32::new(0.0, 0.0); n];
57        for k in 0..n {
58            let angle = sign * PI * (k as f64 * k as f64) / n as f64;
59            let (s, c) = det_sincos(angle);
60            chirp[k] = Complex32::new(c as f32, s as f32);
61        }
62
63        // b[k] = chirp[k], with wrap-around for negative indices, zero-padded
64        let mut b = vec![Complex32::new(0.0, 0.0); m];
65        b[0] = chirp[0];
66        for k in 1..n {
67            b[k] = chirp[k];
68            b[m - k] = chirp[k];
69        }
70
71        // Precompute FFT(b)
72        fft_radix2_f32(&mut b, -1.0);
73
74        BluesteinPlan { n, m, chirp, b_hat: b }
75    }
76
77    /// Execute Bluestein FFT using precomputed plan.
78    fn execute(&self, input: &[Complex32]) -> Vec<Complex32> {
79        debug_assert_eq!(input.len(), self.n);
80
81        // a[k] = x[k] * conj(chirp[k]), zero-padded to length m
82        let mut a = vec![Complex32::new(0.0, 0.0); self.m];
83        for k in 0..self.n {
84            a[k] = input[k] * self.chirp[k].conj();
85        }
86
87        // Convolve: A = FFT(a), C = IFFT(A * B_hat)
88        fft_radix2_f32(&mut a, -1.0);
89        for i in 0..self.m {
90            a[i] *= self.b_hat[i];
91        }
92        fft_radix2_f32(&mut a, 1.0);
93
94        // Normalize radix-2 inverse and apply chirp
95        let inv_m = 1.0 / self.m as f32;
96        let mut result = vec![Complex32::new(0.0, 0.0); self.n];
97        for k in 0..self.n {
98            result[k] = a[k] * inv_m * self.chirp[k].conj();
99        }
100
101        result
102    }
103}
104
105// ──────────────────────────────────────────────────────────────────────────
106// 1D FFT primitives (f32)
107// ──────────────────────────────────────────────────────────────────────────
108
/// Smallest power of two that is `>= n` (`1` for `n == 0`).
///
/// Delegates to [`usize::next_power_of_two`] instead of a hand-rolled doubling
/// loop, which would spin forever once the candidate overflowed to zero; the
/// std version panics in debug builds on overflow instead.
fn next_pow2(n: usize) -> usize {
    n.next_power_of_two()
}
117
118/// In-place radix-2 Cooley-Tukey FFT for f32.  `data.len()` must be a power of 2.
119/// `sign`: -1.0 for forward FFT, +1.0 for inverse FFT.
120fn fft_radix2_f32(data: &mut [Complex32], sign: f64) {
121    let n = data.len();
122    debug_assert!(n.is_power_of_two());
123    if n <= 1 {
124        return;
125    }
126
127    // Bit-reversal permutation
128    let mut j = 0usize;
129    for i in 1..n {
130        let mut bit = n >> 1;
131        while j & bit != 0 {
132            j ^= bit;
133            bit >>= 1;
134        }
135        j ^= bit;
136        if i < j {
137            data.swap(i, j);
138        }
139    }
140
141    // Pre-allocate the twiddle buffer once; largest stage needs n/2 entries.
142    let mut twiddles: Vec<Complex32> = Vec::with_capacity(n / 2);
143
144    // Butterfly stages
145    let mut len = 2;
146    while len <= n {
147        let half = len / 2;
148        let angle_step = sign * PI / half as f64;
149
150        // Pre-compute twiddles for this stage. The previous code recomputed
151        // (s, c) for every (start, k) pair, but the values depend only on
152        // (half, k) — within a stage of width `len`, each twiddle was being
153        // recomputed `n / len` times. With this cache, det_sincos is called
154        // `half` times per stage instead of `(n / len) * half = n / 2`. Total
155        // det_sincos calls across all stages drop from `~(n / 2) · log2(n)`
156        // to `~n` — a `log2(n)` reduction (≈12× on 4096-point row FFTs).
157        twiddles.clear();
158        twiddles.extend((0..half).map(|k| {
159            let angle = angle_step * k as f64;
160            let (s, c) = det_sincos(angle);
161            Complex32::new(c as f32, s as f32)
162        }));
163
164        for start in (0..n).step_by(len) {
165            for k in 0..half {
166                let w = twiddles[k];
167                let u = data[start + k];
168                let v = data[start + k + half] * w;
169                data[start + k] = u + v;
170                data[start + k + half] = u - v;
171            }
172        }
173        len <<= 1;
174    }
175}
176
177/// 1D FFT for arbitrary length using f32.
178/// Uses BluesteinPlan if available (for 2D FFT reuse), or creates one on the fly.
179/// `sign`: -1.0 for forward, +1.0 for inverse.
180fn fft1d_f32_with_plan(input: &[Complex32], sign: f64, plan: Option<&BluesteinPlan>) -> Vec<Complex32> {
181    let n = input.len();
182    if n == 0 {
183        return vec![];
184    }
185    if n == 1 {
186        return input.to_vec();
187    }
188    if n.is_power_of_two() {
189        let mut buf = input.to_vec();
190        fft_radix2_f32(&mut buf, sign);
191        return buf;
192    }
193
194    // Use precomputed plan if available
195    if let Some(p) = plan {
196        debug_assert_eq!(p.n, n);
197        return p.execute(input);
198    }
199
200    // Fallback: create a temporary plan
201    let temp_plan = BluesteinPlan::new(n, sign);
202    temp_plan.execute(input)
203}
204
205/// 1D forward FFT (arbitrary length, f32).
206#[allow(dead_code)]
207fn fft1d_f32(data: &[Complex32]) -> Vec<Complex32> {
208    fft1d_f32_with_plan(data, -1.0, None)
209}
210
211/// 1D inverse FFT (arbitrary length, f32) — unnormalized.
212#[allow(dead_code)]
213fn ifft1d_f32(data: &[Complex32]) -> Vec<Complex32> {
214    fft1d_f32_with_plan(data, 1.0, None)
215}
216
217// ──────────────────────────────────────────────────────────────────────────
218// 2D FFT / IFFT — public API (f32, memory-optimized)
219// ──────────────────────────────────────────────────────────────────────────
220
221/// Real-valued pixel array -> 2D complex spectrum (f32).
222///
223/// The input is a row-major f64 array of size `width * height`.
224/// Uses gather-FFT-scatter for columns (P1a) and precomputed Bluestein
225/// plans for chirp reuse (P2a).
226pub fn fft2d(pixels: &[f64], width: usize, height: usize) -> Spectrum2D {
227    assert_eq!(pixels.len(), width * height);
228
229    let mut data: Vec<Complex32> = pixels.iter().map(|&v| Complex32::new(v as f32, 0.0)).collect();
230
231    // P2a: Precompute Bluestein plans for row and column lengths (if non-power-of-2)
232    let row_plan = if !width.is_power_of_two() && width > 1 {
233        Some(BluesteinPlan::new(width, -1.0))
234    } else {
235        None
236    };
237    let col_plan = if !height.is_power_of_two() && height > 1 {
238        Some(BluesteinPlan::new(height, -1.0))
239    } else {
240        None
241    };
242
243    // FFT each row
244    for row in 0..height {
245        let start = row * width;
246        let row_data = &data[start..start + width];
247        let transformed = fft1d_f32_with_plan(row_data, -1.0, row_plan.as_ref());
248        data[start..start + width].copy_from_slice(&transformed);
249    }
250
251    // P1a: FFT each column using gather-FFT-scatter with a single column buffer.
252    // No full transposed buffer needed.
253    let mut col_buf = vec![Complex32::new(0.0, 0.0); height];
254    for col in 0..width {
255        // Gather column
256        for r in 0..height {
257            col_buf[r] = data[r * width + col];
258        }
259        // FFT
260        let transformed = fft1d_f32_with_plan(&col_buf, -1.0, col_plan.as_ref());
261        // Scatter back
262        for r in 0..height {
263            data[r * width + col] = transformed[r];
264        }
265    }
266
267    Spectrum2D { data, width, height }
268}
269
270/// 2D complex spectrum -> real-valued pixel array.
271///
272/// Takes the real parts after inverse FFT, normalized by `1/(width*height)`.
273/// Uses gather-IFFT-scatter for columns (P1a) and precomputed Bluestein
274/// plans for chirp reuse (P2a).
275pub fn ifft2d(spectrum: &Spectrum2D) -> Vec<f64> {
276    let width = spectrum.width;
277    let height = spectrum.height;
278    let mut data = spectrum.data.clone();
279
280    // P2a: Precompute Bluestein plans for row and column lengths (if non-power-of-2)
281    let row_plan = if !width.is_power_of_two() && width > 1 {
282        Some(BluesteinPlan::new(width, 1.0))
283    } else {
284        None
285    };
286    let col_plan = if !height.is_power_of_two() && height > 1 {
287        Some(BluesteinPlan::new(height, 1.0))
288    } else {
289        None
290    };
291
292    // IFFT each row
293    for row in 0..height {
294        let start = row * width;
295        let row_data = &data[start..start + width];
296        let transformed = fft1d_f32_with_plan(row_data, 1.0, row_plan.as_ref());
297        data[start..start + width].copy_from_slice(&transformed);
298    }
299
300    // P1a: IFFT each column using gather-IFFT-scatter with a single column buffer.
301    let mut col_buf = vec![Complex32::new(0.0, 0.0); height];
302    for col in 0..width {
303        // Gather column
304        for r in 0..height {
305            col_buf[r] = data[r * width + col];
306        }
307        // IFFT
308        let transformed = fft1d_f32_with_plan(&col_buf, 1.0, col_plan.as_ref());
309        // Scatter back
310        for r in 0..height {
311            data[r * width + col] = transformed[r];
312        }
313    }
314
315    // Normalize and extract real parts
316    let norm = 1.0 / (width * height) as f64;
317    let mut result = vec![0.0f64; width * height];
318    for i in 0..data.len() {
319        result[i] = data[i].re as f64 * norm;
320    }
321
322    result
323}
324
325/// Compute magnitude of each spectrum element (returns f32).
326pub fn magnitude_spectrum(spectrum: &Spectrum2D) -> Vec<f32> {
327    spectrum.data.iter().map(|c| {
328        // Use f64 det_hypot for precision, cast result to f32
329        det_hypot(c.re as f64, c.im as f64) as f32
330    }).collect()
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fft_ifft_roundtrip() {
        let width = 16;
        let height = 16;
        let pixels: Vec<f64> = (0..width * height).map(|i| (i as f64) * 0.1 + 50.0).collect();

        let spectrum = fft2d(&pixels, width, height);
        let recovered = ifft2d(&spectrum);

        for i in 0..pixels.len() {
            assert!(
                (pixels[i] - recovered[i]).abs() < 1e-3,
                "Mismatch at {i}: expected {}, got {}",
                pixels[i],
                recovered[i]
            );
        }
    }

    #[test]
    fn fft_ifft_roundtrip_non_pow2() {
        // Test with non-power-of-2 dimensions (Bluestein path)
        let width = 12;
        let height = 10;
        let pixels: Vec<f64> = (0..width * height).map(|i| (i as f64) * 0.3 + 20.0).collect();

        let spectrum = fft2d(&pixels, width, height);
        let recovered = ifft2d(&spectrum);

        for i in 0..pixels.len() {
            assert!(
                (pixels[i] - recovered[i]).abs() < 0.1,
                "Mismatch at {i}: expected {}, got {}",
                pixels[i],
                recovered[i]
            );
        }
    }

    #[test]
    fn fft_ifft_roundtrip_mixed_dims() {
        // One power-of-2 axis (radix-2) combined with one Bluestein axis, in
        // both orientations, so row-only and column-only plans are exercised.
        for &(width, height) in &[(8usize, 6usize), (5, 4)] {
            let pixels: Vec<f64> = (0..width * height).map(|i| ((i * 13 + 7) % 97) as f64).collect();

            let recovered = ifft2d(&fft2d(&pixels, width, height));

            for i in 0..pixels.len() {
                assert!(
                    (pixels[i] - recovered[i]).abs() < 0.1,
                    "{width}x{height} mismatch at {i}: expected {}, got {}",
                    pixels[i],
                    recovered[i]
                );
            }
        }
    }

    #[test]
    fn parseval_theorem() {
        let width = 8;
        let height = 8;
        let pixels: Vec<f64> = (0..width * height).map(|i| ((i * 7 + 3) % 256) as f64).collect();

        let spatial_energy: f64 = pixels.iter().map(|v| v * v).sum();

        let spectrum = fft2d(&pixels, width, height);
        let freq_energy: f64 = spectrum.data.iter().map(|c| {
            let re = c.re as f64;
            let im = c.im as f64;
            re * re + im * im
        }).sum();

        let n = (width * height) as f64;
        // Relaxed tolerance for f32 spectrum
        assert!(
            (spatial_energy - freq_energy / n).abs() < 10.0,
            "Parseval's theorem violated: spatial={spatial_energy}, freq/N={}", freq_energy / n
        );
    }

    #[test]
    fn parseval_theorem_non_pow2() {
        // Same identity through the Bluestein path on both axes.
        let width = 12;
        let height = 10;
        let pixels: Vec<f64> = (0..width * height).map(|i| ((i * 7 + 3) % 256) as f64).collect();

        let spatial_energy: f64 = pixels.iter().map(|v| v * v).sum();

        let spectrum = fft2d(&pixels, width, height);
        let freq_energy: f64 = spectrum
            .data
            .iter()
            .map(|c| (c.re as f64).powi(2) + (c.im as f64).powi(2))
            .sum();

        let n = (width * height) as f64;
        // Relative tolerance: the energies are ~1e6 and the spectrum is f32.
        let rel_err = ((spatial_energy - freq_energy / n) / spatial_energy).abs();
        assert!(
            rel_err < 1e-3,
            "Parseval's theorem violated: spatial={spatial_energy}, freq/N={}", freq_energy / n
        );
    }

    #[test]
    fn dc_component_is_sum() {
        let width = 4;
        let height = 4;
        let pixels = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
                          9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0];

        let spectrum = fft2d(&pixels, width, height);

        let expected_dc: f64 = pixels.iter().sum();
        // Relaxed tolerance for f32
        assert!(
            (spectrum.data[0].re as f64 - expected_dc).abs() < 0.1,
            "DC component should be sum of all pixels: expected {expected_dc}, got {}",
            spectrum.data[0].re
        );
        assert!((spectrum.data[0].im as f64).abs() < 0.1);
    }

    #[test]
    fn dc_component_is_sum_non_pow2() {
        let width = 12;
        let height = 10;
        let pixels: Vec<f64> = (0..width * height).map(|i| (i as f64) * 0.3 + 20.0).collect();

        let spectrum = fft2d(&pixels, width, height);

        let expected_dc: f64 = pixels.iter().sum();
        // Relative tolerance for the f32 Bluestein path.
        let tol = expected_dc * 1e-4;
        assert!(
            (spectrum.data[0].re as f64 - expected_dc).abs() < tol,
            "DC component should be sum of all pixels: expected {expected_dc}, got {}",
            spectrum.data[0].re
        );
        assert!((spectrum.data[0].im as f64).abs() < tol);
    }

    #[test]
    fn fft1d_basic() {
        // FFT of [1, 0, 0, 0] should be [1, 1, 1, 1]
        let input = vec![
            Complex32::new(1.0, 0.0),
            Complex32::new(0.0, 0.0),
            Complex32::new(0.0, 0.0),
            Complex32::new(0.0, 0.0),
        ];
        let output = fft1d_f32(&input);
        for k in 0..4 {
            assert!((output[k].re - 1.0).abs() < 1e-5, "Re[{k}]={}", output[k].re);
            assert!(output[k].im.abs() < 1e-5, "Im[{k}]={}", output[k].im);
        }
    }

    #[test]
    fn bluestein_matches_radix2() {
        // fft1d_f32 dispatches power-of-2 lengths to radix-2 and all other
        // lengths to Bluestein. This checks (a) that a precomputed plan agrees
        // with the on-the-fly Bluestein path for n = 7, and (b) that the radix-2
        // dispatch for n = 8 is the plain in-place radix-2 transform.
        let n = 8;
        let input: Vec<Complex32> = (0..n).map(|i| Complex32::new((i * 3 + 1) as f32, (i * 2) as f32)).collect();

        let mut radix2_buf = input.clone();
        fft_radix2_f32(&mut radix2_buf, -1.0);

        // (a) Precomputed plan vs. temporary plan, non-power-of-2 length.
        let n2 = 7;
        let input2: Vec<Complex32> = (0..n2).map(|i| Complex32::new((i * 3 + 1) as f32, (i * 2) as f32)).collect();
        let plan2 = BluesteinPlan::new(n2, -1.0);
        let result_plan = plan2.execute(&input2);
        let result_direct = fft1d_f32(&input2);
        for k in 0..n2 {
            assert!(
                (result_plan[k].re - result_direct[k].re).abs() < 1e-3 &&
                (result_plan[k].im - result_direct[k].im).abs() < 1e-3,
                "Plan vs direct mismatch at {k}: plan={}, direct={}",
                result_plan[k], result_direct[k]
            );
        }

        // (b) Power-of-2 dispatch is the radix-2 transform.
        let result_r2 = fft1d_f32(&input);
        for k in 0..n {
            assert!(
                (radix2_buf[k].re - result_r2[k].re).abs() < 1e-3 &&
                (radix2_buf[k].im - result_r2[k].im).abs() < 1e-3,
                "Mismatch at {k}: radix2={}, fft1d={}",
                radix2_buf[k], result_r2[k]
            );
        }
    }

    #[test]
    fn bluestein_plan_reuse() {
        // Verify that reusing a BluesteinPlan gives the same result each time
        let n = 13; // non-power-of-2
        let plan = BluesteinPlan::new(n, -1.0);

        let input1: Vec<Complex32> = (0..n).map(|i| Complex32::new(i as f32, 0.0)).collect();
        let input2: Vec<Complex32> = (0..n).map(|i| Complex32::new(0.0, i as f32)).collect();

        let r1a = plan.execute(&input1);
        let r2 = plan.execute(&input2);
        let r1b = plan.execute(&input1);

        for k in 0..n {
            assert!(
                (r1a[k].re - r1b[k].re).abs() < 1e-5 &&
                (r1a[k].im - r1b[k].im).abs() < 1e-5,
                "Plan reuse gave different results at {k}: first={}, second={}",
                r1a[k], r1b[k]
            );
        }
        // Just verify r2 didn't corrupt anything (no specific value check needed)
        assert_eq!(r2.len(), n);
    }
}