// rlx_wgpu/fft_dispatch.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

// Multi-kernel f32 FFT dispatch for wgpu (mirrors rlx-cuda/src/fft_dispatch.rs).

18use crate::buffer::Arena;
19use crate::kernels::{
20    CopyParams, FftGpuParams, Kernel, copy_kernel, fft_gpu_bit_reverse_kernel,
21    fft_gpu_inner_kernel, fft_gpu_outer_r2_kernel, fft_gpu_outer_r4_kernel,
22    fft_gpu_radix2_full_kernel,
23};
/// Invocations per workgroup used when sizing 1-D FFT grids with [`grid_1d`]
/// (presumably mirrors the kernels' `@workgroup_size` — confirm against the WGSL).
const WG: u32 = 256;

/// Number of `WG`-wide workgroups needed to cover `n` invocations.
///
/// Returns 0 when `n == 0`; callers always pass a non-zero count.
fn grid_1d(n: u32) -> u32 {
    let whole = n / WG;
    if n % WG == 0 {
        whole
    } else {
        whole + 1
    }
}

/// `(x, 1, 1)` workgroup counts covering `n` invocations with `wg`-wide groups.
///
/// Always dispatches at least one workgroup, even for `n == 0`.
/// Panics (division by zero) if `wg == 0`.
fn dispatch_dims(n: u32, wg: u32) -> (u32, u32, u32) {
    let groups = n / wg + u32::from(n % wg != 0);
    (groups.max(1), 1, 1)
}

35/// Pre-built uniform buffers + bind groups for FFT stages (per executable).
36pub struct FftGpuResources {
37    pub uniform: wgpu::Buffer,
38    pub copy_uniform: wgpu::Buffer,
39    pub bg_radix2_full: wgpu::BindGroup,
40    pub bg_bit_reverse: wgpu::BindGroup,
41    pub bg_inner: wgpu::BindGroup,
42    pub bg_outer_r4: wgpu::BindGroup,
43    pub bg_outer_r2: wgpu::BindGroup,
44    pub bg_copy: wgpu::BindGroup,
45}
46
47impl FftGpuResources {
48    pub fn new(device: &wgpu::Device, arena: &wgpu::Buffer) -> Self {
49        let uniform = device.create_buffer(&wgpu::BufferDescriptor {
50            label: Some("rlx-wgpu fft uniform"),
51            size: std::mem::size_of::<FftGpuParams>() as u64,
52            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
53            mapped_at_creation: false,
54        });
55        let copy_uniform = device.create_buffer(&wgpu::BufferDescriptor {
56            label: Some("rlx-wgpu fft copy uniform"),
57            size: std::mem::size_of::<CopyParams>() as u64,
58            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
59            mapped_at_creation: false,
60        });
61        let mk_bg = |k: &Kernel| k.bind_two(device, arena, &uniform);
62        Self {
63            bg_radix2_full: mk_bg(fft_gpu_radix2_full_kernel(device)),
64            bg_bit_reverse: mk_bg(fft_gpu_bit_reverse_kernel(device)),
65            bg_inner: mk_bg(fft_gpu_inner_kernel(device)),
66            bg_outer_r4: mk_bg(fft_gpu_outer_r4_kernel(device)),
67            bg_outer_r2: mk_bg(fft_gpu_outer_r2_kernel(device)),
68            bg_copy: copy_kernel(device).bind_two(device, arena, &copy_uniform),
69            uniform,
70            copy_uniform,
71        }
72    }
73}
74
75fn dispatch_with_bg(
76    pass: &mut wgpu::ComputePass<'_>,
77    pipeline: &wgpu::ComputePipeline,
78    bg: &wgpu::BindGroup,
79    gx: u32,
80    gy: u32,
81    gz: u32,
82) {
83    pass.set_pipeline(pipeline);
84    pass.set_bind_group(0, bg, &[]);
85    pass.dispatch_workgroups(gx, gy, gz);
86}
87
88/// Run FFT stages inside an existing compute pass (no extra submit/poll).
89pub fn dispatch_fft_gpu_in_pass(
90    device: &wgpu::Device,
91    queue: &wgpu::Queue,
92    pass: &mut wgpu::ComputePass<'_>,
93    res: &FftGpuResources,
94    src_off: u32,
95    dst_off: u32,
96    outer: u32,
97    n: u32,
98    inverse: bool,
99    norm_scale: f32,
100) {
101    if outer == 0 {
102        return;
103    }
104    let plan = rlx_ir::fft::FftGpuPlan::new(n as usize).expect("run_fft_gpu: n must be pow2");
105    let inv = if inverse { 1u32 } else { 0u32 };
106    let log2n = n.trailing_zeros();
107    if src_off != dst_off && !plan.single_inner_only() {
108        let count = outer * n * 2;
109        let cp = CopyParams {
110            n: count,
111            in_off: src_off,
112            out_off: dst_off,
113            _p0: 0,
114            _p1: 0,
115            _p2: 0,
116            _p3: 0,
117            _p4: 0,
118        };
119        queue.write_buffer(&res.copy_uniform, 0, bytemuck::bytes_of(&cp));
120        let (gx, gy, gz) = dispatch_dims(count, 64);
121        dispatch_with_bg(
122            pass,
123            &copy_kernel(device).pipeline,
124            &res.bg_copy,
125            gx,
126            gy,
127            gz,
128        );
129    }
130    let off = dst_off;
131
132    if plan.single_inner_only() {
133        let p = FftGpuParams {
134            off: src_off,
135            dst_off,
136            n,
137            log2n,
138            inverse: inv,
139            norm_scale,
140            outer,
141            tile: 0,
142            inner_stages: 0,
143            q_or_hs: 0,
144        };
145        queue.write_buffer(&res.uniform, 0, bytemuck::bytes_of(&p));
146        dispatch_with_bg(
147            pass,
148            &fft_gpu_radix2_full_kernel(device).pipeline,
149            &res.bg_radix2_full,
150            1,
151            outer,
152            1,
153        );
154        return;
155    }
156
157    let mut p = FftGpuParams {
158        off,
159        dst_off,
160        n,
161        log2n,
162        inverse: inv,
163        norm_scale: 1.0,
164        outer,
165        tile: 0,
166        inner_stages: 0,
167        q_or_hs: 0,
168    };
169
170    queue.write_buffer(&res.uniform, 0, bytemuck::bytes_of(&p));
171    dispatch_with_bg(
172        pass,
173        &fft_gpu_bit_reverse_kernel(device).pipeline,
174        &res.bg_bit_reverse,
175        grid_1d(n),
176        outer,
177        1,
178    );
179
180    let tile = rlx_ir::fft::FFT_TILE_SIZE.min(n as usize) as u32;
181    let inner_stages = plan.inner_stages as u32;
182    let num_tiles = (n / tile).max(1);
183    p.tile = tile;
184    p.inner_stages = inner_stages;
185    p.norm_scale = 1.0;
186    queue.write_buffer(&res.uniform, 0, bytemuck::bytes_of(&p));
187    dispatch_with_bg(
188        pass,
189        &fft_gpu_inner_kernel(device).pipeline,
190        &res.bg_inner,
191        num_tiles,
192        outer,
193        1,
194    );
195
196    let r4_count = plan.outer_rad4_q.len();
197    for (i, q) in plan.outer_rad4_q.iter().enumerate() {
198        p.q_or_hs = *q as u32;
199        p.norm_scale = if plan.outer_r2_hs.is_none() && i + 1 == r4_count {
200            norm_scale
201        } else {
202            1.0
203        };
204        queue.write_buffer(&res.uniform, 0, bytemuck::bytes_of(&p));
205        dispatch_with_bg(
206            pass,
207            &fft_gpu_outer_r4_kernel(device).pipeline,
208            &res.bg_outer_r4,
209            grid_1d((n / 4).max(1)),
210            outer,
211            1,
212        );
213    }
214
215    if let Some(hs) = plan.outer_r2_hs {
216        p.q_or_hs = hs as u32;
217        p.norm_scale = norm_scale;
218        queue.write_buffer(&res.uniform, 0, bytemuck::bytes_of(&p));
219        dispatch_with_bg(
220            pass,
221            &fft_gpu_outer_r2_kernel(device).pipeline,
222            &res.bg_outer_r2,
223            grid_1d(n / 2),
224            outer,
225            1,
226        );
227    }
228}
229
230/// Standalone FFT dispatch using compile-time cached resources.
231pub fn run_fft_gpu_cached(
232    device: &wgpu::Device,
233    queue: &wgpu::Queue,
234    _arena: &Arena,
235    res: &FftGpuResources,
236    src_off: u32,
237    dst_off: u32,
238    outer: u32,
239    n: u32,
240    inverse: bool,
241    norm_scale: f32,
242) {
243    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
244        label: Some("rlx-wgpu fft gpu"),
245    });
246    {
247        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
248            label: Some("rlx-wgpu fft gpu pass"),
249            timestamp_writes: None,
250        });
251        dispatch_fft_gpu_in_pass(
252            device, queue, &mut pass, res, src_off, dst_off, outer, n, inverse, norm_scale,
253        );
254    }
255    queue.submit(std::iter::once(encoder.finish()));
256}
257
258/// Standalone FFT dispatch (legacy callers).
259pub fn run_fft_gpu(
260    device: &wgpu::Device,
261    queue: &wgpu::Queue,
262    arena: &Arena,
263    src_off: u32,
264    dst_off: u32,
265    outer: u32,
266    n: u32,
267    inverse: bool,
268    norm_scale: f32,
269) {
270    let res = FftGpuResources::new(device, &arena.buffer);
271    run_fft_gpu_cached(
272        device, queue, arena, &res, src_off, dst_off, outer, n, inverse, norm_scale,
273    );
274}