gpu_fft/lib.rs
1pub mod fft;
2pub mod ifft;
3pub mod psd;
4pub mod utils;
5
6// Shared Cooley-Tukey butterfly kernel and helpers; not part of the public API.
7pub(crate) mod butterfly;
8// Work-in-progress precomputed-twiddle path; not yet wired into the public API.
9#[allow(dead_code)]
10pub(crate) mod twiddles;
11
// Threads launched per workgroup: 1024 (= 2^10). That is enough to saturate
// most desktop GPUs and is the per-workgroup ceiling Metal / Vulkan / WebGPU
// expose on typical hardware.
// NOTE(review): WebGPU's *default* limit is lower (256 invocations per
// workgroup); this presumably relies on the backend requesting the adapter's
// higher limits — confirm.
pub(crate) const WORKGROUP_SIZE: u32 = 1 << 10;
15
// Shared-memory tile for the inner (fused) butterfly kernel.
//
// Each workgroup loads TILE_SIZE elements into two SharedMemory<f32> arrays:
// 2 × TILE_SIZE × 4 bytes = 8 192 bytes, below the 16 384-byte WebGPU minimum.
// The inner kernel runs TILE_SIZE / 2 threads per workgroup (one thread per
// butterfly pair).
//
// TILE_SIZE must be a power of two so that log₂(TILE_SIZE) is exact; this is
// checked at compile time right below.
pub(crate) const TILE_SIZE: usize = 1024;
const _: () = assert!(TILE_SIZE.is_power_of_two(), "TILE_SIZE must be a power of two");

// log₂(TILE_SIZE): the number of butterfly stages that fit inside one tile.
// Derived from TILE_SIZE (rather than hard-coded) so the two constants cannot
// drift apart. The u32 → usize cast is a lossless widening.
pub(crate) const TILE_BITS: usize = TILE_SIZE.trailing_zeros() as usize;
24
// Backend selection for the public wrappers in this file. Exactly one alias
// must be active, otherwise `Runtime` would be defined twice (or not at all).
//
// Cargo unifies features across the dependency graph, so `wgpu` and `cuda` can
// end up enabled together; the cfg predicates are therefore mutually
// exclusive, with CUDA taking precedence when both are on.
// NOTE(review): precedence assumes `cuda` is the explicit opt-in on top of a
// default-enabled `wgpu` — confirm against Cargo.toml.
#[cfg(all(feature = "wgpu", not(feature = "cuda")))]
type Runtime = cubecl::wgpu::WgpuRuntime;

#[cfg(feature = "cuda")]
type Runtime = cubecl::cuda::CudaRuntime;
30
31/// Computes the Cooley-Tukey radix-2 FFT of a real-valued signal.
32///
33/// Runs in **O(N log₂ N)** on the GPU using `log₂ N` butterfly-stage kernel
34/// dispatches of N/2 threads each.
35///
36/// If `input.len()` is not a power of two the signal is zero-padded to the
37/// next power of two before the transform. Both returned vectors have length
38/// `input.len().next_power_of_two()`.
39///
40/// # Example
41///
42/// ```no_run
43/// use gpu_fft::fft;
44/// let input = vec![0.0f32, 1.0, 0.0, 0.0];
45/// let (real, imag) = fft(&input);
46/// assert_eq!(real.len(), 4); // already a power of two
47/// ```
48#[must_use]
49pub fn fft(input: &[f32]) -> (Vec<f32>, Vec<f32>) {
50 fft::fft::<Runtime>(&Default::default(), input)
51}
52
53/// Computes the Cooley-Tukey radix-2 FFT of a **batch** of real-valued signals
54/// in a single GPU pass.
55///
56/// All signals are zero-padded to the next power-of-two of the **longest** signal
57/// so they share a common length `n`. The batch is processed with a 2-D kernel
58/// dispatch — the Y-dimension selects the signal, and the X-dimension covers
59/// butterfly pairs within a signal.
60///
61/// Returns one `(real, imag)` pair per input signal, each of length `n`.
62///
63/// # Example
64///
65/// ```no_run
66/// use gpu_fft::fft_batch;
67/// let signals = vec![
68/// vec![1.0f32, 0.0, 0.0, 0.0], // impulse → all-ones spectrum
69/// vec![1.0f32, 1.0, 1.0, 1.0], // DC → [4, 0, 0, 0]
70/// ];
71/// let results = fft_batch(&signals);
72/// assert_eq!(results.len(), 2);
73/// ```
74#[must_use]
75pub fn fft_batch(signals: &[Vec<f32>]) -> Vec<(Vec<f32>, Vec<f32>)> {
76 fft::fft_batch::<Runtime>(&Default::default(), signals)
77}
78
79/// Computes the Cooley-Tukey radix-2 IFFT of a complex spectrum.
80///
81/// Runs in **O(N log₂ N)** using `log₂ N` butterfly-stage kernels with
82/// positive twiddle factors, followed by a CPU-side 1/N scaling pass.
83///
84/// Both slices must have the **same power-of-two length** — i.e. pass the
85/// direct output of [`fft`] unchanged.
86///
87/// # Returns
88///
89/// A `Vec<f32>` of length `2 * N`:
90/// - `output[0..N]` — reconstructed real signal
91/// - `output[N..2N]` — reconstructed imaginary signal (≈ 0 for real inputs)
92///
93/// # Example
94///
95/// ```no_run
96/// use gpu_fft::ifft;
97/// let real = vec![0.0f32, 1.0, 0.0, 0.0];
98/// let imag = vec![0.0f32, 0.0, 0.0, 0.0];
99/// let output = ifft(&real, &imag);
100/// let reconstructed = &output[..4]; // real part
101/// ```
102#[must_use]
103pub fn ifft(input_real: &[f32], input_imag: &[f32]) -> Vec<f32> {
104 ifft::ifft::<Runtime>(&Default::default(), input_real, input_imag)
105}
106
107/// Computes the Cooley-Tukey radix-2 IFFT for a **batch** of complex spectra
108/// in a single GPU pass.
109///
110/// Each element of `signals` is a `(real, imag)` pair — the direct output of
111/// [`fft_batch`]. All pairs must share the **same power-of-two length**.
112///
113/// Returns one `Vec<f32>` per input signal, each of length `2 * n`:
114/// - `[0..n]` — reconstructed real signal
115/// - `[n..2n]` — reconstructed imaginary signal (≈ 0 for real-valued inputs)
116///
117/// # Example
118///
119/// ```no_run
120/// use gpu_fft::{fft_batch, ifft_batch};
121/// let signals = vec![vec![1.0f32, 2.0, 3.0, 4.0]];
122/// let spectra = fft_batch(&signals);
123/// let recovered = ifft_batch(&spectra);
124/// ```
125#[must_use]
126pub fn ifft_batch(signals: &[(Vec<f32>, Vec<f32>)]) -> Vec<Vec<f32>> {
127 ifft::ifft_batch::<Runtime>(&Default::default(), signals)
128}
129