gpu_fft/lib.rs
1pub mod fft;
2pub mod ifft;
3pub mod psd;
4pub mod utils;
5
6// Shared Cooley-Tukey butterfly kernel and helpers; not part of the public API.
7pub(crate) mod butterfly;
8// Work-in-progress precomputed-twiddle path; not yet wired into the public API.
9#[allow(dead_code)]
10pub(crate) mod twiddles;
11
/// Number of threads per workgroup used for kernel dispatches.
///
/// 1024 is the usual per-workgroup invocation ceiling on desktop hardware
/// (the CUDA per-block maximum and the common Metal / Vulkan device limit).
/// It is *not* guaranteed by WebGPU, whose default
/// `maxComputeInvocationsPerWorkgroup` is only 256. On the wgpu backend this
/// value therefore relies on the device being created with the adapter's
/// raised limits.
/// NOTE(review): confirm the cubecl wgpu runtime requests those limits.
pub(crate) const WORKGROUP_SIZE: u32 = 1024;
15
// Shared-memory tile for the inner (fused) butterfly kernel.
//
// Each workgroup loads TILE_SIZE elements into two SharedMemory<f32> arrays:
// 2 × TILE_SIZE × 4 bytes = 8 192 bytes, which is below the 16 384-byte
// `maxComputeWorkgroupStorageSize` that WebGPU guarantees by default.
// The inner kernel runs TILE_SIZE / 2 threads per workgroup (one thread per
// butterfly pair). TILE_BITS = log₂(TILE_SIZE) stages fit inside one tile.
pub(crate) const TILE_SIZE: usize = 1024;

// Derived from TILE_SIZE rather than hard-coded, so retuning the tile size
// cannot leave the stage count out of sync.
pub(crate) const TILE_BITS: usize = TILE_SIZE.trailing_zeros() as usize;

// Compile-time guards for the invariants documented above:
// - `trailing_zeros` equals log₂ only when TILE_SIZE is a power of two.
// - The two f32 shared-memory arrays must fit WebGPU's default workgroup
//   storage limit.
const _: () = assert!(TILE_SIZE.is_power_of_two(), "TILE_SIZE must be a power of two");
const _: () = assert!(
    2 * TILE_SIZE * core::mem::size_of::<f32>() <= 16_384,
    "shared-memory tile exceeds WebGPU's 16 KiB default workgroup storage limit"
);
24
// GPU backend used by the public convenience wrappers, chosen at compile time.
//
// Cargo features are additive, so a build can end up with both `wgpu` and
// `cuda` enabled, for example when two different dependents each turn one on.
// Two unconditional aliases would then collide with a duplicate-definition
// error. CUDA therefore takes precedence whenever it is enabled.
//
// NOTE(review): with neither feature enabled, `Runtime` is undefined and the
// crate fails to build with an "unresolved type" error. Consider adding a
// `compile_error!` guard with a clearer message.
#[cfg(feature = "cuda")]
type Runtime = cubecl::cuda::CudaRuntime;

#[cfg(all(feature = "wgpu", not(feature = "cuda")))]
type Runtime = cubecl::wgpu::WgpuRuntime;
30
31/// Computes the Cooley-Tukey radix-2 FFT of a real-valued signal.
32///
33/// Runs in **O(N log₂ N)** on the GPU using `log₂ N` butterfly-stage kernel
34/// dispatches of N/2 threads each.
35///
36/// If `input.len()` is not a power of two the signal is zero-padded to the
37/// next power of two before the transform. Both returned vectors have length
38/// `input.len().next_power_of_two()`.
39///
40/// # Example
41///
42/// ```no_run
43/// use gpu_fft::fft;
44/// let input = vec![0.0f32, 1.0, 0.0, 0.0];
45/// let (real, imag) = fft(&input);
46/// assert_eq!(real.len(), 4); // already a power of two
47/// ```
48#[must_use]
49pub fn fft(input: &[f32]) -> (Vec<f32>, Vec<f32>) {
50 fft::fft::<Runtime>(&Default::default(), input)
51}
52
53/// Computes the Cooley-Tukey radix-2 IFFT of a complex spectrum.
54///
55/// Runs in **O(N log₂ N)** using `log₂ N` butterfly-stage kernels with
56/// positive twiddle factors, followed by a CPU-side 1/N scaling pass.
57///
58/// Both slices must have the **same power-of-two length** — i.e. pass the
59/// direct output of [`fft`] unchanged.
60///
61/// # Returns
62///
63/// A `Vec<f32>` of length `2 * N`:
64/// - `output[0..N]` — reconstructed real signal
65/// - `output[N..2N]` — reconstructed imaginary signal (≈ 0 for real inputs)
66///
67/// # Example
68///
69/// ```no_run
70/// use gpu_fft::ifft;
71/// let real = vec![0.0f32, 1.0, 0.0, 0.0];
72/// let imag = vec![0.0f32, 0.0, 0.0, 0.0];
73/// let output = ifft(&real, &imag);
74/// let reconstructed = &output[..4]; // real part
75/// ```
76#[must_use]
77pub fn ifft(input_real: &[f32], input_imag: &[f32]) -> Vec<f32> {
78 ifft::ifft::<Runtime>(&Default::default(), input_real, input_imag)
79}
80