1use cubecl::prelude::*;
2
3use crate::butterfly::{
4 bit_reverse,
5 butterfly_inner, butterfly_inner_batch,
6 butterfly_stage, butterfly_stage_batch,
7 butterfly_stage_radix4, butterfly_stage_radix4_batch,
8};
9use crate::{TILE_BITS, TILE_SIZE, WORKGROUP_SIZE};
10
11#[must_use]
39pub fn ifft<R: Runtime>(
40 device: &R::Device,
41 input_real: &[f32],
42 input_imag: &[f32],
43) -> Vec<f32> {
44 assert_eq!(
45 input_real.len(),
46 input_imag.len(),
47 "ifft: real and imag slices must have the same length"
48 );
49 let n = input_real.len();
50 assert!(
51 n.is_power_of_two(),
52 "ifft: input length {n} is not a power of two (pass the direct output of fft)"
53 );
54
55 if n <= 1 {
57 let mut out = input_real.to_vec();
58 out.extend_from_slice(input_imag);
59 return out;
60 }
61
62 let m = n.ilog2() as usize;
63
64 let mut real = vec![0.0f32; n];
66 let mut imag = vec![0.0f32; n];
67 for i in 0..n {
68 let j = bit_reverse(i, m as u32);
69 real[j] = input_real[i];
70 imag[j] = input_imag[i];
71 }
72
73 let client = R::client(device);
74 let real_handle = client.create_from_slice(f32::as_bytes(&real));
75 let imag_handle = client.create_from_slice(f32::as_bytes(&imag));
76
77 let inner_stages = m.min(TILE_BITS);
79 let tile = TILE_SIZE.min(n);
80 let num_tiles = (n / TILE_SIZE).max(1) as u32;
81 let wg_threads = (n / 2).min(TILE_SIZE / 2) as u32;
82
83 unsafe {
84 butterfly_inner::launch::<f32, R>(
85 &client,
86 CubeCount::Static(num_tiles, 1, 1),
87 CubeDim::new_1d(wg_threads),
88 ArrayArg::from_raw_parts::<f32>(&real_handle, n, 1),
89 ArrayArg::from_raw_parts::<f32>(&imag_handle, n, 1),
90 tile, inner_stages, false, )
94 .expect("IFFT inner (shared-memory) launch failed")
95 };
96
97 let outer_wg_r4 = ((n / 4) as u32 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
99 let outer_wg_r2 = ((n / 2) as u32 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
100
101 let mut s = inner_stages;
102 while s + 1 < m {
103 let q = 1_usize << s;
104 unsafe {
105 butterfly_stage_radix4::launch::<f32, R>(
106 &client,
107 CubeCount::Static(outer_wg_r4, 1, 1),
108 CubeDim::new_1d(WORKGROUP_SIZE),
109 ArrayArg::from_raw_parts::<f32>(&real_handle, n, 1),
110 ArrayArg::from_raw_parts::<f32>(&imag_handle, n, 1),
111 n, q, false, )
115 .expect("IFFT outer radix-4 butterfly launch failed")
116 };
117 s += 2;
118 }
119 if s < m {
120 let hs = 1_usize << s;
121 unsafe {
122 butterfly_stage::launch::<f32, R>(
123 &client,
124 CubeCount::Static(outer_wg_r2, 1, 1),
125 CubeDim::new_1d(WORKGROUP_SIZE),
126 ArrayArg::from_raw_parts::<f32>(&real_handle, n, 1),
127 ArrayArg::from_raw_parts::<f32>(&imag_handle, n, 1),
128 n, hs, false, )
132 .expect("IFFT outer radix-2 trailing butterfly launch failed")
133 };
134 }
135
136 let mut real_out = f32::from_bytes(&client.read_one(real_handle)).to_vec();
138 let mut imag_out = f32::from_bytes(&client.read_one(imag_handle)).to_vec();
139
140 let scale = (n as f32).recip();
141 for v in &mut real_out {
142 *v *= scale;
143 }
144 for v in &mut imag_out {
145 *v *= scale;
146 }
147
148 real_out.extend_from_slice(&imag_out);
150 real_out
151}
152
153#[must_use]
181pub fn ifft_batch<R: Runtime>(
182 device: &R::Device,
183 signals: &[(Vec<f32>, Vec<f32>)],
184) -> Vec<Vec<f32>> {
185 if signals.is_empty() {
186 return Vec::new();
187 }
188
189 let batch_size = signals.len();
190
191 let n = signals[0].0.len();
193 for (b, (re, im)) in signals.iter().enumerate() {
194 assert_eq!(
195 re.len(), im.len(),
196 "ifft_batch: signal {b}: real and imag slices have different lengths"
197 );
198 assert_eq!(
199 re.len(), n,
200 "ifft_batch: all signals must have the same length (expected {n}, got {})", re.len()
201 );
202 }
203 assert!(
204 n.is_power_of_two(),
205 "ifft_batch: signal length {n} is not a power of two"
206 );
207
208 if n <= 1 {
210 return signals
211 .iter()
212 .map(|(re, im)| {
213 let mut out = re.clone();
214 out.extend_from_slice(im);
215 out
216 })
217 .collect();
218 }
219
220 let m = n.ilog2() as usize;
221
222 let mut real_flat = vec![0.0f32; batch_size * n];
224 let mut imag_flat = vec![0.0f32; batch_size * n];
225
226 for (b, (input_real, input_imag)) in signals.iter().enumerate() {
227 let base = b * n;
228 for i in 0..n {
229 let j = bit_reverse(i, m as u32);
230 real_flat[base + j] = input_real[i];
231 imag_flat[base + j] = input_imag[i];
232 }
233 }
234
235 let client = R::client(device);
236 let total = batch_size * n;
237 let real_handle = client.create_from_slice(f32::as_bytes(&real_flat));
238 let imag_handle = client.create_from_slice(f32::as_bytes(&imag_flat));
239
240 let inner_stages = m.min(TILE_BITS);
242 let tile = TILE_SIZE.min(n);
243 let tiles_per_signal = (n / tile).max(1);
244 let wg_count = (tiles_per_signal * batch_size) as u32;
245 let wg_threads = (tile / 2) as u32;
246
247 unsafe {
248 butterfly_inner_batch::launch::<f32, R>(
249 &client,
250 CubeCount::Static(wg_count, 1, 1),
251 CubeDim::new_1d(wg_threads),
252 ArrayArg::from_raw_parts::<f32>(&real_handle, total, 1),
253 ArrayArg::from_raw_parts::<f32>(&imag_handle, total, 1),
254 n, tile, inner_stages, false, )
259 .expect("IFFT batch inner (shared-memory) launch failed")
260 };
261
262 let total_groups_r4 = batch_size * (n / 4);
264 let total_pairs_r2 = batch_size * (n / 2);
265 let outer_wg_r4 = ((total_groups_r4 as u32) + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
266 let outer_wg_r2 = ((total_pairs_r2 as u32) + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
267
268 let mut s = inner_stages;
269 while s + 1 < m {
270 let q = 1_usize << s;
271 unsafe {
272 butterfly_stage_radix4_batch::launch::<f32, R>(
273 &client,
274 CubeCount::Static(outer_wg_r4, 1, 1),
275 CubeDim::new_1d(WORKGROUP_SIZE),
276 ArrayArg::from_raw_parts::<f32>(&real_handle, total, 1),
277 ArrayArg::from_raw_parts::<f32>(&imag_handle, total, 1),
278 n, q, batch_size, false, )
283 .expect("IFFT batch outer radix-4 butterfly launch failed")
284 };
285 s += 2;
286 }
287 if s < m {
288 let hs = 1_usize << s;
289 unsafe {
290 butterfly_stage_batch::launch::<f32, R>(
291 &client,
292 CubeCount::Static(outer_wg_r2, 1, 1),
293 CubeDim::new_1d(WORKGROUP_SIZE),
294 ArrayArg::from_raw_parts::<f32>(&real_handle, total, 1),
295 ArrayArg::from_raw_parts::<f32>(&imag_handle, total, 1),
296 n, hs, batch_size, false, )
301 .expect("IFFT batch outer radix-2 trailing butterfly launch failed")
302 };
303 }
304
305 let mut real_out = f32::from_bytes(&client.read_one(real_handle)).to_vec();
307 let mut imag_out = f32::from_bytes(&client.read_one(imag_handle)).to_vec();
308
309 let scale = (n as f32).recip();
310 for v in real_out.iter_mut() { *v *= scale; }
311 for v in imag_out.iter_mut() { *v *= scale; }
312
313 (0..batch_size)
315 .map(|b| {
316 let start = b * n;
317 let end = start + n;
318 let mut out = real_out[start..end].to_vec();
319 out.extend_from_slice(&imag_out[start..end]);
320 out
321 })
322 .collect()
323}