gpu_fft/lib.rs
1pub mod fft;
2pub mod ifft;
3pub mod psd;
4pub mod utils;
5
6// Shared Cooley-Tukey butterfly kernel and helpers; not part of the public API.
7pub(crate) mod butterfly;
8// Work-in-progress precomputed-twiddle path; not yet wired into the public API.
9#[allow(dead_code)]
10pub(crate) mod twiddles;
11
12// ── Optional Apple Silicon backend ───────────────────────────────────────────
13
14/// MLX backend wrapping Apple's MLX framework FFT.
15///
16/// Calls MLX's GPU FFT implementation via the MLX-C API.
17#[cfg(all(target_os = "macos", feature = "mlx"))]
18pub mod mlx;
19
20// ── Runtime backend selection ────────────────────────────────────────────────
21
/// Available FFT backends, selected at runtime.
///
/// Only variants whose corresponding feature flag is enabled will be present
/// (`Mlx` additionally requires a macOS target). The set of variants therefore
/// differs between builds; use [`available_backends`] to discover them at runtime.
///
/// # Example
///
/// ```no_run
/// use gpu_fft::{available_backends, fft_with};
///
/// // Compiles regardless of which backend features are enabled.
/// if let Some(&backend) = available_backends().first() {
///     let (_real, _imag) = fft_with(&[1.0, 0.0, 0.0, 0.0], backend);
/// }
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Backend {
    /// CubeCL → WGSL → Metal/Vulkan/DX12 (cross-platform).
    #[cfg(feature = "wgpu")]
    Wgpu,
    /// CubeCL → CUDA (NVIDIA only).
    #[cfg(feature = "cuda")]
    Cuda,
    /// Apple MLX FFT (Apple Silicon).
    #[cfg(all(target_os = "macos", feature = "mlx"))]
    Mlx,
}
44
45/// Returns a list of all backends that were compiled into this build.
46///
47/// Useful for CLI tools, logging, or dynamically selecting a backend.
48///
49/// # Example
50///
51/// ```no_run
52/// for b in gpu_fft::available_backends() {
53/// println!("{b:?}");
54/// }
55/// ```
56#[must_use]
57pub fn available_backends() -> Vec<Backend> {
58 vec![
59 #[cfg(feature = "wgpu")]
60 Backend::Wgpu,
61 #[cfg(feature = "cuda")]
62 Backend::Cuda,
63 #[cfg(all(target_os = "macos", feature = "mlx"))]
64 Backend::Mlx,
65 ]
66}
67
68/// Forward FFT using the specified backend.
69///
70/// Same semantics as [`fft`]: zero-pads to next power of two, returns
71/// `(real, imag)` of length `n.next_power_of_two()`.
72#[must_use]
73pub fn fft_with(input: &[f32], backend: Backend) -> (Vec<f32>, Vec<f32>) {
74 match backend {
75 #[cfg(feature = "wgpu")]
76 Backend::Wgpu => fft::fft::<cubecl::wgpu::WgpuRuntime>(&Default::default(), input),
77 #[cfg(feature = "cuda")]
78 Backend::Cuda => fft::fft::<cubecl::cuda::CudaRuntime>(&Default::default(), input),
79 #[cfg(all(target_os = "macos", feature = "mlx"))]
80 Backend::Mlx => mlx::fft::fft(input),
81 }
82}
83
84/// Inverse FFT using the specified backend.
85///
86/// Same semantics as [`ifft`]: returns `Vec<f32>` of length `2*n` where
87/// `[0..n]` is real and `[n..2n]` is imaginary.
88#[must_use]
89pub fn ifft_with(input_real: &[f32], input_imag: &[f32], backend: Backend) -> Vec<f32> {
90 match backend {
91 #[cfg(feature = "wgpu")]
92 Backend::Wgpu => ifft::ifft::<cubecl::wgpu::WgpuRuntime>(&Default::default(), input_real, input_imag),
93 #[cfg(feature = "cuda")]
94 Backend::Cuda => ifft::ifft::<cubecl::cuda::CudaRuntime>(&Default::default(), input_real, input_imag),
95 #[cfg(all(target_os = "macos", feature = "mlx"))]
96 Backend::Mlx => mlx::fft::ifft(input_real, input_imag),
97 }
98}
99
// Threads per workgroup for the butterfly kernels. 1024 is the usual hardware
// ceiling on Metal, Vulkan and DX12. WebGPU only guarantees
// `maxComputeInvocationsPerWorkgroup = 256` by default, so a wgpu device must be
// created with a raised limit for this size to be accepted.
// NOTE(review): confirm the wgpu device setup requests that limit.
pub(crate) const WORKGROUP_SIZE: u32 = 1024;

// Shared-memory tile for the inner (fused) butterfly kernel.
// Each workgroup loads TILE_SIZE elements into two SharedMemory<f32> arrays:
// 2 × TILE_SIZE × 4 bytes = 8 192 bytes, below WebGPU's guaranteed
// `maxComputeWorkgroupStorageSize` of 16 384 bytes.
// The inner kernel uses TILE_SIZE / 2 threads per workgroup (one thread per
// butterfly pair). TILE_BITS = log₂(TILE_SIZE) is the number of stages that fit
// inside one tile.
pub(crate) const TILE_SIZE: usize = 1024;
// Derived from TILE_SIZE so the two constants can never drift apart.
pub(crate) const TILE_BITS: usize = TILE_SIZE.trailing_zeros() as usize;

// Compile-time guards for the invariants described above.
const _: () = assert!(TILE_SIZE.is_power_of_two(), "TILE_SIZE must be a power of two");
const _: () = assert!(
    2 * TILE_SIZE * std::mem::size_of::<f32>() <= 16_384,
    "shared-memory tile exceeds WebGPU's 16 KiB workgroup storage minimum"
);
112
// Compile-time default runtime used by `fft`, `fft_batch`, `ifft` and
// `ifft_batch`. If both `wgpu` and `cuda` are enabled (e.g. `--all-features`),
// wgpu becomes the default. The CUDA alias is gated on `not(feature = "wgpu")`
// so the two definitions never coexist, which would otherwise be a
// duplicate-definition compile error. CUDA stays reachable via
// `fft_with` / `ifft_with` with `Backend::Cuda`.
#[cfg(feature = "wgpu")]
type Runtime = cubecl::wgpu::WgpuRuntime;

#[cfg(all(feature = "cuda", not(feature = "wgpu")))]
type Runtime = cubecl::cuda::CudaRuntime;
118
119/// Computes the Cooley-Tukey radix-2 FFT of a real-valued signal.
120///
121/// Runs in **O(N log₂ N)** on the GPU using `log₂ N` butterfly-stage kernel
122/// dispatches of N/2 threads each.
123///
124/// If `input.len()` is not a power of two the signal is zero-padded to the
125/// next power of two before the transform. Both returned vectors have length
126/// `input.len().next_power_of_two()`.
127///
128/// # Example
129///
130/// ```no_run
131/// use gpu_fft::fft;
132/// let input = vec![0.0f32, 1.0, 0.0, 0.0];
133/// let (real, imag) = fft(&input);
134/// assert_eq!(real.len(), 4); // already a power of two
135/// ```
136#[must_use]
137pub fn fft(input: &[f32]) -> (Vec<f32>, Vec<f32>) {
138 fft::fft::<Runtime>(&Default::default(), input)
139}
140
141/// Computes the Cooley-Tukey radix-2 FFT of a **batch** of real-valued signals
142/// in a single GPU pass.
143///
144/// All signals are zero-padded to the next power-of-two of the **longest** signal
145/// so they share a common length `n`. The batch is processed with a 2-D kernel
146/// dispatch — the Y-dimension selects the signal, and the X-dimension covers
147/// butterfly pairs within a signal.
148///
149/// Returns one `(real, imag)` pair per input signal, each of length `n`.
150///
151/// # Example
152///
153/// ```no_run
154/// use gpu_fft::fft_batch;
155/// let signals = vec![
156/// vec![1.0f32, 0.0, 0.0, 0.0], // impulse → all-ones spectrum
157/// vec![1.0f32, 1.0, 1.0, 1.0], // DC → [4, 0, 0, 0]
158/// ];
159/// let results = fft_batch(&signals);
160/// assert_eq!(results.len(), 2);
161/// ```
162#[must_use]
163pub fn fft_batch(signals: &[Vec<f32>]) -> Vec<(Vec<f32>, Vec<f32>)> {
164 fft::fft_batch::<Runtime>(&Default::default(), signals)
165}
166
167/// Computes the Cooley-Tukey radix-2 IFFT of a complex spectrum.
168///
169/// Runs in **O(N log₂ N)** using `log₂ N` butterfly-stage kernels with
170/// positive twiddle factors, followed by a CPU-side 1/N scaling pass.
171///
172/// Both slices must have the **same power-of-two length** — i.e. pass the
173/// direct output of [`fft`] unchanged.
174///
175/// # Returns
176///
177/// A `Vec<f32>` of length `2 * N`:
178/// - `output[0..N]` — reconstructed real signal
179/// - `output[N..2N]` — reconstructed imaginary signal (≈ 0 for real inputs)
180///
181/// # Example
182///
183/// ```no_run
184/// use gpu_fft::ifft;
185/// let real = vec![0.0f32, 1.0, 0.0, 0.0];
186/// let imag = vec![0.0f32, 0.0, 0.0, 0.0];
187/// let output = ifft(&real, &imag);
188/// let reconstructed = &output[..4]; // real part
189/// ```
190#[must_use]
191pub fn ifft(input_real: &[f32], input_imag: &[f32]) -> Vec<f32> {
192 ifft::ifft::<Runtime>(&Default::default(), input_real, input_imag)
193}
194
195/// Computes the Cooley-Tukey radix-2 IFFT for a **batch** of complex spectra
196/// in a single GPU pass.
197///
198/// Each element of `signals` is a `(real, imag)` pair — the direct output of
199/// [`fft_batch`]. All pairs must share the **same power-of-two length**.
200///
201/// Returns one `Vec<f32>` per input signal, each of length `2 * n`:
202/// - `[0..n]` — reconstructed real signal
203/// - `[n..2n]` — reconstructed imaginary signal (≈ 0 for real-valued inputs)
204///
205/// # Example
206///
207/// ```no_run
208/// use gpu_fft::{fft_batch, ifft_batch};
209/// let signals = vec![vec![1.0f32, 2.0, 3.0, 4.0]];
210/// let spectra = fft_batch(&signals);
211/// let recovered = ifft_batch(&spectra);
212/// ```
213#[must_use]
214pub fn ifft_batch(signals: &[(Vec<f32>, Vec<f32>)]) -> Vec<Vec<f32>> {
215 ifft::ifft_batch::<Runtime>(&Default::default(), signals)
216}
217
218// ── MLX convenience wrappers ─────────────────────────────────────────────────
219
/// Forward FFT computed by Apple's MLX framework.
///
/// Shorthand for [`fft_with`] with the MLX backend; returns `(real, imag)`.
#[cfg(all(target_os = "macos", feature = "mlx"))]
#[must_use]
pub fn fft_mlx(input: &[f32]) -> (Vec<f32>, Vec<f32>) {
    fft_with(input, Backend::Mlx)
}
226
/// Inverse FFT computed by Apple's MLX framework.
///
/// Shorthand for [`ifft_with`] with the MLX backend.
#[cfg(all(target_os = "macos", feature = "mlx"))]
#[must_use]
pub fn ifft_mlx(input_real: &[f32], input_imag: &[f32]) -> Vec<f32> {
    ifft_with(input_real, input_imag, Backend::Mlx)
}