Skip to main content

gpu_fft/
fft.rs

1use cubecl::prelude::*;
2
3use crate::butterfly::{
4    bit_reverse,
5    butterfly_inner, butterfly_inner_batch,
6    butterfly_stage, butterfly_stage_batch,
7    butterfly_stage_radix4, butterfly_stage_radix4_batch,
8};
9use crate::{TILE_BITS, TILE_SIZE, WORKGROUP_SIZE};
10
11/// Computes the Cooley-Tukey radix-2 DIT FFT of `input`.
12///
13/// If `input.len()` is not a power of two the signal is **zero-padded** to the
14/// next power of two. Both returned vectors have length `input.len().next_power_of_two()`.
15///
16/// ### Launch strategy
17///
18/// All stages where `half_stride < TILE_SIZE / 2` are fused into a **single**
19/// `butterfly_inner` dispatch using workgroup shared memory — eliminating the
20/// per-stage kernel-launch overhead that dominates small-N performance.
21/// The remaining outer stages use `butterfly_stage_radix4` (two radix-2 stages
22/// per dispatch) where possible, falling back to `butterfly_stage` for a single
23/// trailing stage when the outer-stage count is odd.
24///
25/// | N        | Inner | Outer        | Total |
26/// |----------|------:|-------------:|------:|
27/// | ≤ 1 024  | 1     | 0            | **1** |
28/// | 4 096    | 1     | 1 (r4)       | **2** |
29/// | 65 536   | 1     | 3 (r4)       | **4** |
30///
31/// # Example
32///
33/// ```ignore
34/// use cubecl::wgpu::WgpuRuntime;
35/// use gpu_fft::fft::fft;
36/// let (real, imag) = fft::<WgpuRuntime>(&Default::default(), &[1.0f32, 0.0, 0.0, 0.0]);
37/// ```
38#[must_use]
39pub fn fft<R: Runtime>(device: &R::Device, input: &[f32]) -> (Vec<f32>, Vec<f32>) {
40    let n_orig = input.len();
41    let n = n_orig.next_power_of_two();
42
43    // Edge case: trivial transform for zero or single element.
44    if n <= 1 {
45        let mut real = vec![0.0f32; n];
46        if n == 1 && n_orig == 1 {
47            real[0] = input[0];
48        }
49        return (real, vec![0.0f32; n]);
50    }
51
52    let m = n.ilog2() as usize;
53
54    // ── Bit-reverse permute the input on the CPU (O(N)) ───────────────────────
55    let mut real = vec![0.0f32; n];
56    for (i, &v) in input.iter().enumerate() {
57        real[bit_reverse(i, m as u32)] = v;
58    }
59    let imag = vec![0.0f32; n];
60
61    let client = R::client(device);
62    let real_handle = client.create_from_slice(f32::as_bytes(&real));
63    let imag_handle = client.create_from_slice(f32::as_bytes(&imag));
64
65    // ── Inner stages: fused into one launch via shared memory ─────────────────
66    // inner_stages = how many stages fit inside a TILE_SIZE-element workgroup tile.
67    // tile         = actual tile size (≤ TILE_SIZE; equals N when N < TILE_SIZE).
68    let inner_stages = m.min(TILE_BITS);
69    let tile         = TILE_SIZE.min(n);     // comptime specialisation value
70    let num_tiles    = (n / TILE_SIZE).max(1) as u32;
71    let wg_threads   = (n / 2).min(TILE_SIZE / 2) as u32; // threads per workgroup
72
73    unsafe {
74        butterfly_inner::launch::<f32, R>(
75            &client,
76            CubeCount::Static(num_tiles, 1, 1),
77            CubeDim::new_1d(wg_threads),
78            ArrayArg::from_raw_parts::<f32>(&real_handle, n, 1),
79            ArrayArg::from_raw_parts::<f32>(&imag_handle, n, 1),
80            tile,         // comptime
81            inner_stages, // comptime
82            true,         // comptime — forward FFT
83        )
84        .expect("FFT inner (shared-memory) launch failed")
85    };
86
87    // ── Outer stages: radix-4 pairs, then one radix-2 if the count is odd ────
88    // Two consecutive radix-2 stages (strides q and 2q) are fused into a single
89    // radix-4 dispatch, halving the number of kernel launches for large N.
90    let outer_wg_r4 = ((n / 4) as u32 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
91    let outer_wg_r2 = ((n / 2) as u32 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
92
93    let mut s = inner_stages;
94    while s + 1 < m {
95        let q = 1_usize << s;   // quarter-stride = half-stride of the lower stage
96        unsafe {
97            butterfly_stage_radix4::launch::<f32, R>(
98                &client,
99                CubeCount::Static(outer_wg_r4, 1, 1),
100                CubeDim::new_1d(WORKGROUP_SIZE),
101                ArrayArg::from_raw_parts::<f32>(&real_handle, n, 1),
102                ArrayArg::from_raw_parts::<f32>(&imag_handle, n, 1),
103                n,    // comptime
104                q,    // comptime — specialises kernel for this stage pair
105                true, // comptime — forward FFT
106            )
107            .expect("FFT outer radix-4 butterfly launch failed")
108        };
109        s += 2;
110    }
111    // Trailing radix-2 stage when (m − inner_stages) is odd.
112    if s < m {
113        let hs = 1_usize << s;
114        unsafe {
115            butterfly_stage::launch::<f32, R>(
116                &client,
117                CubeCount::Static(outer_wg_r2, 1, 1),
118                CubeDim::new_1d(WORKGROUP_SIZE),
119                ArrayArg::from_raw_parts::<f32>(&real_handle, n, 1),
120                ArrayArg::from_raw_parts::<f32>(&imag_handle, n, 1),
121                n,    // comptime
122                hs,   // comptime
123                true, // comptime — forward FFT
124            )
125            .expect("FFT outer radix-2 trailing butterfly launch failed")
126        };
127    }
128
129    let real_out = f32::from_bytes(&client.read_one(real_handle)).to_vec();
130    let imag_out = f32::from_bytes(&client.read_one(imag_handle)).to_vec();
131
132    (real_out, imag_out)
133}
134
135/// Computes the Cooley-Tukey radix-2 DIT FFT for a **batch** of signals in a
136/// single GPU pass.
137///
138/// All signals are zero-padded to the same length: the next power-of-two of the
139/// **longest** signal.  Every other signal is padded to that same length so the
140/// batch forms a rectangular `batch_size × n` matrix in GPU memory.
141///
142/// Returns one `(real, imag)` pair per input signal, each of length `n`.
143///
144/// ### Performance
145///
146/// All `batch_size` signals are processed simultaneously using a 2-D kernel
147/// dispatch — the Y-dimension of the grid indexes the signal and the X-dimension
148/// covers butterfly pairs within a signal.  This amortises kernel-launch overhead
149/// over the entire batch.
150///
151/// ### Panics
152///
153/// Does not panic.  An empty batch returns an empty `Vec`.
154///
155/// # Example
156///
157/// ```ignore
158/// use cubecl::wgpu::WgpuRuntime;
159/// use gpu_fft::fft::fft_batch;
160/// let signals = vec![vec![1.0f32, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]];
161/// let results = fft_batch::<WgpuRuntime>(&Default::default(), &signals);
162/// assert_eq!(results.len(), 2);
163/// ```
164#[must_use]
165pub fn fft_batch<R: Runtime>(device: &R::Device, signals: &[Vec<f32>]) -> Vec<(Vec<f32>, Vec<f32>)> {
166    if signals.is_empty() {
167        return Vec::new();
168    }
169
170    let batch_size = signals.len();
171    let max_len    = signals.iter().map(|s| s.len()).max().unwrap_or(0);
172
173    // Edge case: all signals are empty or length 0/1.
174    let n_raw = max_len.next_power_of_two().max(1);
175    if n_raw <= 1 {
176        return signals
177            .iter()
178            .map(|s| {
179                let mut real = vec![0.0f32; n_raw];
180                if n_raw == 1 && !s.is_empty() {
181                    real[0] = s[0];
182                }
183                (real, vec![0.0f32; n_raw])
184            })
185            .collect();
186    }
187
188    let n = n_raw;
189    let m = n.ilog2() as usize;
190
191    // ── Bit-reverse permute each signal on the CPU and pack flat ──────────────
192    let mut real_flat = vec![0.0f32; batch_size * n];
193    let     imag_flat = vec![0.0f32; batch_size * n];
194
195    for (b, signal) in signals.iter().enumerate() {
196        let base = b * n;
197        for (i, &v) in signal.iter().enumerate() {
198            real_flat[base + bit_reverse(i, m as u32)] = v;
199        }
200    }
201
202    let client  = R::client(device);
203    let total   = batch_size * n;
204    let real_handle = client.create_from_slice(f32::as_bytes(&real_flat));
205    let imag_handle = client.create_from_slice(f32::as_bytes(&imag_flat));
206
207    // ── Inner stages: one flat 1D dispatch covers all tiles in all signals ──────
208    // Each workgroup = one tile (tile/2 threads). Total workgroups = tiles_per_signal * batch_size.
209    let inner_stages     = m.min(TILE_BITS);
210    let tile             = TILE_SIZE.min(n);
211    let tiles_per_signal = (n / tile).max(1);
212    let wg_count         = (tiles_per_signal * batch_size) as u32;
213    let wg_threads       = (tile / 2) as u32;
214
215    unsafe {
216        butterfly_inner_batch::launch::<f32, R>(
217            &client,
218            CubeCount::Static(wg_count, 1, 1),
219            CubeDim::new_1d(wg_threads),
220            ArrayArg::from_raw_parts::<f32>(&real_handle, total, 1),
221            ArrayArg::from_raw_parts::<f32>(&imag_handle, total, 1),
222            n,            // comptime — per-signal length
223            tile,         // comptime — tile size
224            inner_stages, // comptime — stages fused per tile
225            true,         // comptime — forward FFT
226        )
227        .expect("FFT batch inner (shared-memory) launch failed")
228    };
229
230    // ── Outer stages: radix-4 pairs, then one radix-2 if the count is odd ────
231    let total_groups_r4 = batch_size * (n / 4);
232    let total_pairs_r2  = batch_size * (n / 2);
233    let outer_wg_r4 = ((total_groups_r4 as u32) + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
234    let outer_wg_r2 = ((total_pairs_r2  as u32) + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
235
236    let mut s = inner_stages;
237    while s + 1 < m {
238        let q = 1_usize << s;
239        unsafe {
240            butterfly_stage_radix4_batch::launch::<f32, R>(
241                &client,
242                CubeCount::Static(outer_wg_r4, 1, 1),
243                CubeDim::new_1d(WORKGROUP_SIZE),
244                ArrayArg::from_raw_parts::<f32>(&real_handle, total, 1),
245                ArrayArg::from_raw_parts::<f32>(&imag_handle, total, 1),
246                n,          // comptime
247                q,          // comptime
248                batch_size, // comptime
249                true,       // comptime — forward FFT
250            )
251            .expect("FFT batch outer radix-4 butterfly launch failed")
252        };
253        s += 2;
254    }
255    if s < m {
256        let hs = 1_usize << s;
257        unsafe {
258            butterfly_stage_batch::launch::<f32, R>(
259                &client,
260                CubeCount::Static(outer_wg_r2, 1, 1),
261                CubeDim::new_1d(WORKGROUP_SIZE),
262                ArrayArg::from_raw_parts::<f32>(&real_handle, total, 1),
263                ArrayArg::from_raw_parts::<f32>(&imag_handle, total, 1),
264                n,          // comptime
265                hs,         // comptime
266                batch_size, // comptime
267                true,       // comptime — forward FFT
268            )
269            .expect("FFT batch outer radix-2 trailing butterfly launch failed")
270        };
271    }
272
273    let real_out = f32::from_bytes(&client.read_one(real_handle)).to_vec();
274    let imag_out = f32::from_bytes(&client.read_one(imag_handle)).to_vec();
275
276    // ── Unpack flat buffer into per-signal pairs ──────────────────────────────
277    (0..batch_size)
278        .map(|b| {
279            let start = b * n;
280            let end   = start + n;
281            (real_out[start..end].to_vec(), imag_out[start..end].to_vec())
282        })
283        .collect()
284}