1use cubecl::prelude::*;
2
3use crate::butterfly::{
4 bit_reverse,
5 butterfly_inner, butterfly_inner_batch,
6 butterfly_stage, butterfly_stage_batch,
7 butterfly_stage_radix4, butterfly_stage_radix4_batch,
8};
9use crate::{TILE_BITS, TILE_SIZE, WORKGROUP_SIZE};
10
11#[must_use]
39pub fn fft<R: Runtime>(device: &R::Device, input: &[f32]) -> (Vec<f32>, Vec<f32>) {
40 let n_orig = input.len();
41 let n = n_orig.next_power_of_two();
42
43 if n <= 1 {
45 let mut real = vec![0.0f32; n];
46 if n == 1 && n_orig == 1 {
47 real[0] = input[0];
48 }
49 return (real, vec![0.0f32; n]);
50 }
51
52 let m = n.ilog2() as usize;
53
54 let mut real = vec![0.0f32; n];
56 for (i, &v) in input.iter().enumerate() {
57 real[bit_reverse(i, m as u32)] = v;
58 }
59 let imag = vec![0.0f32; n];
60
61 let client = R::client(device);
62 let real_handle = client.create_from_slice(f32::as_bytes(&real));
63 let imag_handle = client.create_from_slice(f32::as_bytes(&imag));
64
65 let inner_stages = m.min(TILE_BITS);
69 let tile = TILE_SIZE.min(n); let num_tiles = (n / TILE_SIZE).max(1) as u32;
71 let wg_threads = (n / 2).min(TILE_SIZE / 2) as u32; unsafe {
74 butterfly_inner::launch::<f32, R>(
75 &client,
76 CubeCount::Static(num_tiles, 1, 1),
77 CubeDim::new_1d(wg_threads),
78 ArrayArg::from_raw_parts::<f32>(&real_handle, n, 1),
79 ArrayArg::from_raw_parts::<f32>(&imag_handle, n, 1),
80 tile, inner_stages, true, )
84 .expect("FFT inner (shared-memory) launch failed")
85 };
86
87 let outer_wg_r4 = ((n / 4) as u32 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
91 let outer_wg_r2 = ((n / 2) as u32 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
92
93 let mut s = inner_stages;
94 while s + 1 < m {
95 let q = 1_usize << s; unsafe {
97 butterfly_stage_radix4::launch::<f32, R>(
98 &client,
99 CubeCount::Static(outer_wg_r4, 1, 1),
100 CubeDim::new_1d(WORKGROUP_SIZE),
101 ArrayArg::from_raw_parts::<f32>(&real_handle, n, 1),
102 ArrayArg::from_raw_parts::<f32>(&imag_handle, n, 1),
103 n, q, true, )
107 .expect("FFT outer radix-4 butterfly launch failed")
108 };
109 s += 2;
110 }
111 if s < m {
113 let hs = 1_usize << s;
114 unsafe {
115 butterfly_stage::launch::<f32, R>(
116 &client,
117 CubeCount::Static(outer_wg_r2, 1, 1),
118 CubeDim::new_1d(WORKGROUP_SIZE),
119 ArrayArg::from_raw_parts::<f32>(&real_handle, n, 1),
120 ArrayArg::from_raw_parts::<f32>(&imag_handle, n, 1),
121 n, hs, true, )
125 .expect("FFT outer radix-2 trailing butterfly launch failed")
126 };
127 }
128
129 let real_out = f32::from_bytes(&client.read_one(real_handle)).to_vec();
130 let imag_out = f32::from_bytes(&client.read_one(imag_handle)).to_vec();
131
132 (real_out, imag_out)
133}
134
135#[must_use]
165pub fn fft_batch<R: Runtime>(device: &R::Device, signals: &[Vec<f32>]) -> Vec<(Vec<f32>, Vec<f32>)> {
166 if signals.is_empty() {
167 return Vec::new();
168 }
169
170 let batch_size = signals.len();
171 let max_len = signals.iter().map(|s| s.len()).max().unwrap_or(0);
172
173 let n_raw = max_len.next_power_of_two().max(1);
175 if n_raw <= 1 {
176 return signals
177 .iter()
178 .map(|s| {
179 let mut real = vec![0.0f32; n_raw];
180 if n_raw == 1 && !s.is_empty() {
181 real[0] = s[0];
182 }
183 (real, vec![0.0f32; n_raw])
184 })
185 .collect();
186 }
187
188 let n = n_raw;
189 let m = n.ilog2() as usize;
190
191 let mut real_flat = vec![0.0f32; batch_size * n];
193 let imag_flat = vec![0.0f32; batch_size * n];
194
195 for (b, signal) in signals.iter().enumerate() {
196 let base = b * n;
197 for (i, &v) in signal.iter().enumerate() {
198 real_flat[base + bit_reverse(i, m as u32)] = v;
199 }
200 }
201
202 let client = R::client(device);
203 let total = batch_size * n;
204 let real_handle = client.create_from_slice(f32::as_bytes(&real_flat));
205 let imag_handle = client.create_from_slice(f32::as_bytes(&imag_flat));
206
207 let inner_stages = m.min(TILE_BITS);
210 let tile = TILE_SIZE.min(n);
211 let tiles_per_signal = (n / tile).max(1);
212 let wg_count = (tiles_per_signal * batch_size) as u32;
213 let wg_threads = (tile / 2) as u32;
214
215 unsafe {
216 butterfly_inner_batch::launch::<f32, R>(
217 &client,
218 CubeCount::Static(wg_count, 1, 1),
219 CubeDim::new_1d(wg_threads),
220 ArrayArg::from_raw_parts::<f32>(&real_handle, total, 1),
221 ArrayArg::from_raw_parts::<f32>(&imag_handle, total, 1),
222 n, tile, inner_stages, true, )
227 .expect("FFT batch inner (shared-memory) launch failed")
228 };
229
230 let total_groups_r4 = batch_size * (n / 4);
232 let total_pairs_r2 = batch_size * (n / 2);
233 let outer_wg_r4 = ((total_groups_r4 as u32) + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
234 let outer_wg_r2 = ((total_pairs_r2 as u32) + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
235
236 let mut s = inner_stages;
237 while s + 1 < m {
238 let q = 1_usize << s;
239 unsafe {
240 butterfly_stage_radix4_batch::launch::<f32, R>(
241 &client,
242 CubeCount::Static(outer_wg_r4, 1, 1),
243 CubeDim::new_1d(WORKGROUP_SIZE),
244 ArrayArg::from_raw_parts::<f32>(&real_handle, total, 1),
245 ArrayArg::from_raw_parts::<f32>(&imag_handle, total, 1),
246 n, q, batch_size, true, )
251 .expect("FFT batch outer radix-4 butterfly launch failed")
252 };
253 s += 2;
254 }
255 if s < m {
256 let hs = 1_usize << s;
257 unsafe {
258 butterfly_stage_batch::launch::<f32, R>(
259 &client,
260 CubeCount::Static(outer_wg_r2, 1, 1),
261 CubeDim::new_1d(WORKGROUP_SIZE),
262 ArrayArg::from_raw_parts::<f32>(&real_handle, total, 1),
263 ArrayArg::from_raw_parts::<f32>(&imag_handle, total, 1),
264 n, hs, batch_size, true, )
269 .expect("FFT batch outer radix-2 trailing butterfly launch failed")
270 };
271 }
272
273 let real_out = f32::from_bytes(&client.read_one(real_handle)).to_vec();
274 let imag_out = f32::from_bytes(&client.read_one(imag_handle)).to_vec();
275
276 (0..batch_size)
278 .map(|b| {
279 let start = b * n;
280 let end = start + n;
281 (real_out[start..end].to_vec(), imag_out[start..end].to_vec())
282 })
283 .collect()
284}