Skip to main content

gpu_fft/
ifft.rs

1use cubecl::prelude::*;
2
3use crate::butterfly::{
4    bit_reverse,
5    butterfly_inner, butterfly_inner_batch,
6    butterfly_stage, butterfly_stage_batch,
7    butterfly_stage_radix4, butterfly_stage_radix4_batch,
8};
9use crate::{TILE_BITS, TILE_SIZE, WORKGROUP_SIZE};
10
11/// Computes the Cooley-Tukey radix-2 DIT IFFT of `(input_real, input_imag)`.
12///
13/// Both slices must have the **same power-of-two length** (the direct output of
14/// [`fft`][crate::fft::fft]). Uses the same inner/outer launch strategy as `fft`:
15/// inner stages are fused into a single shared-memory dispatch; outer stages use
16/// one global-memory dispatch each.  After the butterflies, a CPU-side 1/N divide
17/// is applied.
18///
19/// # Returns
20///
21/// `Vec<f32>` of length `2 * N`:
22/// - `[0..N]`  — reconstructed real signal
23/// - `[N..2N]` — reconstructed imaginary signal (≈ 0 for real-valued inputs)
24///
25/// # Panics
26///
27/// Panics if the slice lengths differ or are not a power of two.
28///
29/// # Example
30///
31/// ```ignore
32/// use cubecl::wgpu::WgpuRuntime;
33/// use gpu_fft::ifft::ifft;
34/// let real = vec![1.0f32, 0.0, 0.0, 0.0];
35/// let imag = vec![0.0f32, 0.0, 0.0, 0.0];
36/// let output = ifft::<WgpuRuntime>(&Default::default(), &real, &imag);
37/// ```
38#[must_use]
39pub fn ifft<R: Runtime>(
40    device: &R::Device,
41    input_real: &[f32],
42    input_imag: &[f32],
43) -> Vec<f32> {
44    assert_eq!(
45        input_real.len(),
46        input_imag.len(),
47        "ifft: real and imag slices must have the same length"
48    );
49    let n = input_real.len();
50    assert!(
51        n.is_power_of_two(),
52        "ifft: input length {n} is not a power of two (pass the direct output of fft)"
53    );
54
55    // Edge case: trivial inverse transform.
56    if n <= 1 {
57        let mut out = input_real.to_vec();
58        out.extend_from_slice(input_imag);
59        return out;
60    }
61
62    let m = n.ilog2() as usize;
63
64    // ── Bit-reverse permute the input on the CPU (O(N)) ───────────────────────
65    let mut real = vec![0.0f32; n];
66    let mut imag = vec![0.0f32; n];
67    for i in 0..n {
68        let j = bit_reverse(i, m as u32);
69        real[j] = input_real[i];
70        imag[j] = input_imag[i];
71    }
72
73    let client = R::client(device);
74    let real_handle = client.create_from_slice(f32::as_bytes(&real));
75    let imag_handle = client.create_from_slice(f32::as_bytes(&imag));
76
77    // ── Inner stages: fused into one launch via shared memory ─────────────────
78    let inner_stages = m.min(TILE_BITS);
79    let tile         = TILE_SIZE.min(n);
80    let num_tiles    = (n / TILE_SIZE).max(1) as u32;
81    let wg_threads   = (n / 2).min(TILE_SIZE / 2) as u32;
82
83    unsafe {
84        butterfly_inner::launch::<f32, R>(
85            &client,
86            CubeCount::Static(num_tiles, 1, 1),
87            CubeDim::new_1d(wg_threads),
88            ArrayArg::from_raw_parts::<f32>(&real_handle, n, 1),
89            ArrayArg::from_raw_parts::<f32>(&imag_handle, n, 1),
90            tile,         // comptime
91            inner_stages, // comptime
92            false,        // comptime — inverse FFT (positive twiddle)
93        )
94        .expect("IFFT inner (shared-memory) launch failed")
95    };
96
97    // ── Outer stages: radix-4 pairs, then one radix-2 if the count is odd ────
98    let outer_wg_r4 = ((n / 4) as u32 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
99    let outer_wg_r2 = ((n / 2) as u32 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
100
101    let mut s = inner_stages;
102    while s + 1 < m {
103        let q = 1_usize << s;
104        unsafe {
105            butterfly_stage_radix4::launch::<f32, R>(
106                &client,
107                CubeCount::Static(outer_wg_r4, 1, 1),
108                CubeDim::new_1d(WORKGROUP_SIZE),
109                ArrayArg::from_raw_parts::<f32>(&real_handle, n, 1),
110                ArrayArg::from_raw_parts::<f32>(&imag_handle, n, 1),
111                n,     // comptime
112                q,     // comptime
113                false, // comptime — inverse FFT
114            )
115            .expect("IFFT outer radix-4 butterfly launch failed")
116        };
117        s += 2;
118    }
119    if s < m {
120        let hs = 1_usize << s;
121        unsafe {
122            butterfly_stage::launch::<f32, R>(
123                &client,
124                CubeCount::Static(outer_wg_r2, 1, 1),
125                CubeDim::new_1d(WORKGROUP_SIZE),
126                ArrayArg::from_raw_parts::<f32>(&real_handle, n, 1),
127                ArrayArg::from_raw_parts::<f32>(&imag_handle, n, 1),
128                n,     // comptime
129                hs,    // comptime
130                false, // comptime — inverse FFT
131            )
132            .expect("IFFT outer radix-2 trailing butterfly launch failed")
133        };
134    }
135
136    // ── Read back and apply 1/N scaling on the CPU (O(N)) ────────────────────
137    let mut real_out = f32::from_bytes(&client.read_one(real_handle)).to_vec();
138    let mut imag_out = f32::from_bytes(&client.read_one(imag_handle)).to_vec();
139
140    let scale = (n as f32).recip();
141    for v in &mut real_out {
142        *v *= scale;
143    }
144    for v in &mut imag_out {
145        *v *= scale;
146    }
147
148    // Return [real[0..n] ++ imag[0..n]] — same layout as before.
149    real_out.extend_from_slice(&imag_out);
150    real_out
151}
152
153/// Computes the Cooley-Tukey radix-2 DIT IFFT for a **batch** of complex spectra
154/// in a single GPU pass.
155///
156/// Each element of `signals` is a `(real, imag)` pair produced by [`fft_batch`]
157/// (or by calling [`fft`][crate::fft::fft] repeatedly).  All pairs must share
158/// the **same power-of-two length** — pass the direct output of
159/// [`fft_batch`][crate::fft::fft_batch] unchanged.
160///
161/// Returns one `Vec<f32>` per input signal, each of length `2 * n`:
162/// - `[0..n]`  — reconstructed real signal
163/// - `[n..2n]` — reconstructed imaginary signal (≈ 0 for real-valued inputs)
164///
165/// ### Panics
166///
167/// Panics if any pair has mismatched lengths, or if the shared length is not a
168/// power of two.  An empty batch returns an empty `Vec`.
169///
170/// # Example
171///
172/// ```ignore
173/// use cubecl::wgpu::WgpuRuntime;
174/// use gpu_fft::{fft::fft_batch, ifft::ifft_batch};
175/// let signals = vec![vec![1.0f32, 2.0, 3.0, 4.0]];
176/// let spectra = fft_batch::<WgpuRuntime>(&Default::default(), &signals);
177/// let pairs: Vec<_> = spectra.into_iter().collect();
178/// let recovered = ifft_batch::<WgpuRuntime>(&Default::default(), &pairs);
179/// ```
180#[must_use]
181pub fn ifft_batch<R: Runtime>(
182    device: &R::Device,
183    signals: &[(Vec<f32>, Vec<f32>)],
184) -> Vec<Vec<f32>> {
185    if signals.is_empty() {
186        return Vec::new();
187    }
188
189    let batch_size = signals.len();
190
191    // Validate: all pairs must have identical power-of-two lengths.
192    let n = signals[0].0.len();
193    for (b, (re, im)) in signals.iter().enumerate() {
194        assert_eq!(
195            re.len(), im.len(),
196            "ifft_batch: signal {b}: real and imag slices have different lengths"
197        );
198        assert_eq!(
199            re.len(), n,
200            "ifft_batch: all signals must have the same length (expected {n}, got {})", re.len()
201        );
202    }
203    assert!(
204        n.is_power_of_two(),
205        "ifft_batch: signal length {n} is not a power of two"
206    );
207
208    // Edge case: trivial inverse transform.
209    if n <= 1 {
210        return signals
211            .iter()
212            .map(|(re, im)| {
213                let mut out = re.clone();
214                out.extend_from_slice(im);
215                out
216            })
217            .collect();
218    }
219
220    let m = n.ilog2() as usize;
221
222    // ── Bit-reverse permute each signal on the CPU and pack flat ──────────────
223    let mut real_flat = vec![0.0f32; batch_size * n];
224    let mut imag_flat = vec![0.0f32; batch_size * n];
225
226    for (b, (input_real, input_imag)) in signals.iter().enumerate() {
227        let base = b * n;
228        for i in 0..n {
229            let j = bit_reverse(i, m as u32);
230            real_flat[base + j] = input_real[i];
231            imag_flat[base + j] = input_imag[i];
232        }
233    }
234
235    let client  = R::client(device);
236    let total   = batch_size * n;
237    let real_handle = client.create_from_slice(f32::as_bytes(&real_flat));
238    let imag_handle = client.create_from_slice(f32::as_bytes(&imag_flat));
239
240    // ── Inner stages: one flat 1D dispatch covers all tiles in all signals ──────
241    let inner_stages     = m.min(TILE_BITS);
242    let tile             = TILE_SIZE.min(n);
243    let tiles_per_signal = (n / tile).max(1);
244    let wg_count         = (tiles_per_signal * batch_size) as u32;
245    let wg_threads       = (tile / 2) as u32;
246
247    unsafe {
248        butterfly_inner_batch::launch::<f32, R>(
249            &client,
250            CubeCount::Static(wg_count, 1, 1),
251            CubeDim::new_1d(wg_threads),
252            ArrayArg::from_raw_parts::<f32>(&real_handle, total, 1),
253            ArrayArg::from_raw_parts::<f32>(&imag_handle, total, 1),
254            n,            // comptime — per-signal length
255            tile,         // comptime — tile size
256            inner_stages, // comptime — stages fused per tile
257            false,        // comptime — inverse FFT
258        )
259        .expect("IFFT batch inner (shared-memory) launch failed")
260    };
261
262    // ── Outer stages: radix-4 pairs, then one radix-2 if the count is odd ────
263    let total_groups_r4 = batch_size * (n / 4);
264    let total_pairs_r2  = batch_size * (n / 2);
265    let outer_wg_r4 = ((total_groups_r4 as u32) + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
266    let outer_wg_r2 = ((total_pairs_r2  as u32) + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
267
268    let mut s = inner_stages;
269    while s + 1 < m {
270        let q = 1_usize << s;
271        unsafe {
272            butterfly_stage_radix4_batch::launch::<f32, R>(
273                &client,
274                CubeCount::Static(outer_wg_r4, 1, 1),
275                CubeDim::new_1d(WORKGROUP_SIZE),
276                ArrayArg::from_raw_parts::<f32>(&real_handle, total, 1),
277                ArrayArg::from_raw_parts::<f32>(&imag_handle, total, 1),
278                n,          // comptime
279                q,          // comptime
280                batch_size, // comptime
281                false,      // comptime — inverse FFT
282            )
283            .expect("IFFT batch outer radix-4 butterfly launch failed")
284        };
285        s += 2;
286    }
287    if s < m {
288        let hs = 1_usize << s;
289        unsafe {
290            butterfly_stage_batch::launch::<f32, R>(
291                &client,
292                CubeCount::Static(outer_wg_r2, 1, 1),
293                CubeDim::new_1d(WORKGROUP_SIZE),
294                ArrayArg::from_raw_parts::<f32>(&real_handle, total, 1),
295                ArrayArg::from_raw_parts::<f32>(&imag_handle, total, 1),
296                n,          // comptime
297                hs,         // comptime
298                batch_size, // comptime
299                false,      // comptime — inverse FFT
300            )
301            .expect("IFFT batch outer radix-2 trailing butterfly launch failed")
302        };
303    }
304
305    // ── Read back and apply 1/N scaling on the CPU ────────────────────────────
306    let mut real_out = f32::from_bytes(&client.read_one(real_handle)).to_vec();
307    let mut imag_out = f32::from_bytes(&client.read_one(imag_handle)).to_vec();
308
309    let scale = (n as f32).recip();
310    for v in real_out.iter_mut() { *v *= scale; }
311    for v in imag_out.iter_mut() { *v *= scale; }
312
313    // ── Unpack: each output is [real[0..n] ++ imag[0..n]] ────────────────────
314    (0..batch_size)
315        .map(|b| {
316            let start = b * n;
317            let end   = start + n;
318            let mut out = real_out[start..end].to_vec();
319            out.extend_from_slice(&imag_out[start..end]);
320            out
321        })
322        .collect()
323}