1use crate::buffer::Arena;
19use crate::kernels::{
20 CopyParams, FftGpuParams, Kernel, copy_kernel, fft_gpu_bit_reverse_kernel,
21 fft_gpu_inner_kernel, fft_gpu_outer_r2_kernel, fft_gpu_outer_r4_kernel,
22 fft_gpu_radix2_full_kernel,
23};
24
/// Threads per workgroup along x assumed for the 1-D FFT dispatches sized by [`grid_1d`]
/// (presumably matches the kernels' `@workgroup_size` — confirm against the WGSL).
const WG: u32 = 256;

/// Number of `WG`-wide workgroups needed to cover `n` items.
///
/// Returns `0` when `n == 0` (unlike `dispatch_dims`, no minimum of one group is enforced).
fn grid_1d(n: u32) -> u32 {
    let whole_groups = n / WG;
    let tail_group = u32::from(n % WG != 0);
    whole_groups + tail_group
}
30
/// 1-D dispatch grid `(x, 1, 1)` covering `n` items with `wg` threads per workgroup.
///
/// Always dispatches at least one workgroup along x, even for `n == 0`.
///
/// # Panics
/// If `wg == 0` (division by zero in `div_ceil`).
fn dispatch_dims(n: u32, wg: u32) -> (u32, u32, u32) {
    let groups_x = match n.div_ceil(wg) {
        0 => 1,
        g => g,
    };
    (groups_x, 1, 1)
}
34
35pub struct FftGpuResources {
37 pub uniform: wgpu::Buffer,
38 pub copy_uniform: wgpu::Buffer,
39 pub bg_radix2_full: wgpu::BindGroup,
40 pub bg_bit_reverse: wgpu::BindGroup,
41 pub bg_inner: wgpu::BindGroup,
42 pub bg_outer_r4: wgpu::BindGroup,
43 pub bg_outer_r2: wgpu::BindGroup,
44 pub bg_copy: wgpu::BindGroup,
45}
46
47impl FftGpuResources {
48 pub fn new(device: &wgpu::Device, arena: &wgpu::Buffer) -> Self {
49 let uniform = device.create_buffer(&wgpu::BufferDescriptor {
50 label: Some("rlx-wgpu fft uniform"),
51 size: std::mem::size_of::<FftGpuParams>() as u64,
52 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
53 mapped_at_creation: false,
54 });
55 let copy_uniform = device.create_buffer(&wgpu::BufferDescriptor {
56 label: Some("rlx-wgpu fft copy uniform"),
57 size: std::mem::size_of::<CopyParams>() as u64,
58 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
59 mapped_at_creation: false,
60 });
61 let mk_bg = |k: &Kernel| k.bind_two(device, arena, &uniform);
62 Self {
63 bg_radix2_full: mk_bg(fft_gpu_radix2_full_kernel(device)),
64 bg_bit_reverse: mk_bg(fft_gpu_bit_reverse_kernel(device)),
65 bg_inner: mk_bg(fft_gpu_inner_kernel(device)),
66 bg_outer_r4: mk_bg(fft_gpu_outer_r4_kernel(device)),
67 bg_outer_r2: mk_bg(fft_gpu_outer_r2_kernel(device)),
68 bg_copy: copy_kernel(device).bind_two(device, arena, ©_uniform),
69 uniform,
70 copy_uniform,
71 }
72 }
73}
74
75fn dispatch_with_bg(
76 pass: &mut wgpu::ComputePass<'_>,
77 pipeline: &wgpu::ComputePipeline,
78 bg: &wgpu::BindGroup,
79 gx: u32,
80 gy: u32,
81 gz: u32,
82) {
83 pass.set_pipeline(pipeline);
84 pass.set_bind_group(0, bg, &[]);
85 pass.dispatch_workgroups(gx, gy, gz);
86}
87
88pub fn dispatch_fft_gpu_in_pass(
90 device: &wgpu::Device,
91 queue: &wgpu::Queue,
92 pass: &mut wgpu::ComputePass<'_>,
93 res: &FftGpuResources,
94 src_off: u32,
95 dst_off: u32,
96 outer: u32,
97 n: u32,
98 inverse: bool,
99 norm_scale: f32,
100) {
101 if outer == 0 {
102 return;
103 }
104 let plan = rlx_ir::fft::FftGpuPlan::new(n as usize).expect("run_fft_gpu: n must be pow2");
105 let inv = if inverse { 1u32 } else { 0u32 };
106 let log2n = n.trailing_zeros();
107 if src_off != dst_off && !plan.single_inner_only() {
108 let count = outer * n * 2;
109 let cp = CopyParams {
110 n: count,
111 in_off: src_off,
112 out_off: dst_off,
113 _p0: 0,
114 _p1: 0,
115 _p2: 0,
116 _p3: 0,
117 _p4: 0,
118 };
119 queue.write_buffer(&res.copy_uniform, 0, bytemuck::bytes_of(&cp));
120 let (gx, gy, gz) = dispatch_dims(count, 64);
121 dispatch_with_bg(
122 pass,
123 ©_kernel(device).pipeline,
124 &res.bg_copy,
125 gx,
126 gy,
127 gz,
128 );
129 }
130 let off = dst_off;
131
132 if plan.single_inner_only() {
133 let p = FftGpuParams {
134 off: src_off,
135 dst_off,
136 n,
137 log2n,
138 inverse: inv,
139 norm_scale,
140 outer,
141 tile: 0,
142 inner_stages: 0,
143 q_or_hs: 0,
144 };
145 queue.write_buffer(&res.uniform, 0, bytemuck::bytes_of(&p));
146 dispatch_with_bg(
147 pass,
148 &fft_gpu_radix2_full_kernel(device).pipeline,
149 &res.bg_radix2_full,
150 1,
151 outer,
152 1,
153 );
154 return;
155 }
156
157 let mut p = FftGpuParams {
158 off,
159 dst_off,
160 n,
161 log2n,
162 inverse: inv,
163 norm_scale: 1.0,
164 outer,
165 tile: 0,
166 inner_stages: 0,
167 q_or_hs: 0,
168 };
169
170 queue.write_buffer(&res.uniform, 0, bytemuck::bytes_of(&p));
171 dispatch_with_bg(
172 pass,
173 &fft_gpu_bit_reverse_kernel(device).pipeline,
174 &res.bg_bit_reverse,
175 grid_1d(n),
176 outer,
177 1,
178 );
179
180 let tile = rlx_ir::fft::FFT_TILE_SIZE.min(n as usize) as u32;
181 let inner_stages = plan.inner_stages as u32;
182 let num_tiles = (n / tile).max(1);
183 p.tile = tile;
184 p.inner_stages = inner_stages;
185 p.norm_scale = 1.0;
186 queue.write_buffer(&res.uniform, 0, bytemuck::bytes_of(&p));
187 dispatch_with_bg(
188 pass,
189 &fft_gpu_inner_kernel(device).pipeline,
190 &res.bg_inner,
191 num_tiles,
192 outer,
193 1,
194 );
195
196 let r4_count = plan.outer_rad4_q.len();
197 for (i, q) in plan.outer_rad4_q.iter().enumerate() {
198 p.q_or_hs = *q as u32;
199 p.norm_scale = if plan.outer_r2_hs.is_none() && i + 1 == r4_count {
200 norm_scale
201 } else {
202 1.0
203 };
204 queue.write_buffer(&res.uniform, 0, bytemuck::bytes_of(&p));
205 dispatch_with_bg(
206 pass,
207 &fft_gpu_outer_r4_kernel(device).pipeline,
208 &res.bg_outer_r4,
209 grid_1d((n / 4).max(1)),
210 outer,
211 1,
212 );
213 }
214
215 if let Some(hs) = plan.outer_r2_hs {
216 p.q_or_hs = hs as u32;
217 p.norm_scale = norm_scale;
218 queue.write_buffer(&res.uniform, 0, bytemuck::bytes_of(&p));
219 dispatch_with_bg(
220 pass,
221 &fft_gpu_outer_r2_kernel(device).pipeline,
222 &res.bg_outer_r2,
223 grid_1d(n / 2),
224 outer,
225 1,
226 );
227 }
228}
229
230pub fn run_fft_gpu_cached(
232 device: &wgpu::Device,
233 queue: &wgpu::Queue,
234 _arena: &Arena,
235 res: &FftGpuResources,
236 src_off: u32,
237 dst_off: u32,
238 outer: u32,
239 n: u32,
240 inverse: bool,
241 norm_scale: f32,
242) {
243 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
244 label: Some("rlx-wgpu fft gpu"),
245 });
246 {
247 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
248 label: Some("rlx-wgpu fft gpu pass"),
249 timestamp_writes: None,
250 });
251 dispatch_fft_gpu_in_pass(
252 device, queue, &mut pass, res, src_off, dst_off, outer, n, inverse, norm_scale,
253 );
254 }
255 queue.submit(std::iter::once(encoder.finish()));
256}
257
258pub fn run_fft_gpu(
260 device: &wgpu::Device,
261 queue: &wgpu::Queue,
262 arena: &Arena,
263 src_off: u32,
264 dst_off: u32,
265 outer: u32,
266 n: u32,
267 inverse: bool,
268 norm_scale: f32,
269) {
270 let res = FftGpuResources::new(device, &arena.buffer);
271 run_fft_gpu_cached(
272 device, queue, arena, &res, src_off, dst_off, outer, n, inverse, norm_scale,
273 );
274}