Skip to main content

wgsl_fft/
pipelines.rs

1//! Pre-compiled Cooley-Tukey Radix-2 FFT pipelines for embedding in a larger GPU pipeline.
2//!
3//! `FftPipelines` owns its wgpu device and queue. Use [`FftPipelines::device`] and
4//! [`FftPipelines::queue`] to share them with the rest of your pipeline so all GPU
5//! resources live on a single device.
6
7use std::cell::RefCell;
8
9use wgpu::util::DeviceExt;
10use wgpu::{
11    BindGroupLayout, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType,
12    BufferBindingType, ComputePipeline, Device, ShaderStages,
13};
14
15use crate::error::Result;
16use crate::shaders;
17
/// Direction of an FFT transform.
///
/// The discriminants are stable (`Forward = 0`, `Inverse = 1`) because they are passed
/// to the GPU butterfly shader as a `u32` uniform via `direction as u32`. `Hash` is
/// derived so the direction can be used directly as (part of) a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FftDirection {
    /// Forward FFT
    Forward = 0,
    /// Inverse FFT
    Inverse = 1,
}
26
27// Private helper functions
28
29fn fft_storage_entry(binding: u32, read_only: bool) -> BindGroupLayoutEntry {
30    BindGroupLayoutEntry {
31        binding,
32        visibility: ShaderStages::COMPUTE,
33        ty: BindingType::Buffer {
34            ty: BufferBindingType::Storage { read_only },
35            has_dynamic_offset: false,
36            min_binding_size: None,
37        },
38        count: None,
39    }
40}
41
42fn fft_uniform_entry(binding: u32) -> BindGroupLayoutEntry {
43    BindGroupLayoutEntry {
44        binding,
45        visibility: ShaderStages::COMPUTE,
46        ty: BindingType::Buffer {
47            ty: BufferBindingType::Uniform,
48            has_dynamic_offset: false,
49            min_binding_size: None,
50        },
51        count: None,
52    }
53}
54
55fn fft_make_pipeline(
56    device: &Device,
57    label: &str,
58    bgl: &BindGroupLayout,
59    src: &str,
60) -> ComputePipeline {
61    let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
62        label: Some(&format!("{label}_shader")),
63        source: wgpu::ShaderSource::Wgsl(src.into()),
64    });
65    let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
66        label: Some(&format!("{label}_layout")),
67        bind_group_layouts: &[Some(bgl)],
68        immediate_size: 0,
69    });
70    device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
71        label: Some(label),
72        layout: Some(&layout),
73        module: &shader,
74        entry_point: Some("main"),
75        compilation_options: Default::default(),
76        cache: None,
77    })
78}
79
80// Pre-baked resources for a single encode_fft call (specific n, direction, input/output buffers).
81// bind_groups[0] = bit-reversal pass, bind_groups[1..] = butterfly stages.
82// params holds the uniform buffers alive (wgpu BindGroups ref-count their resources,
83// but keeping them here makes the ownership explicit).
84struct FftCallCache {
85    #[allow(dead_code)]
86    params: Vec<wgpu::Buffer>,
87    bind_groups: Vec<wgpu::BindGroup>,
88}
89
90struct FftNormCache {
91    #[allow(dead_code)]
92    params: wgpu::Buffer,
93    bind_group: wgpu::BindGroup,
94}
95
96/// Pre-compiled Cooley-Tukey Radix-2 FFT pipelines for embedding in a larger GPU pipeline.
97///
98/// `FftPipelines` owns its wgpu device and queue. Use [`FftPipelines::device`] and
99/// [`FftPipelines::queue`] to share them with the rest of your pipeline so all GPU
100/// resources live on a single device.
101///
102/// # Example
103///
104/// ```no_run
105/// use wgsl_fft::{FftPipelines, FftDirection};
106///
107/// let fft = FftPipelines::new().expect("GPU required");
108/// let device = fft.device();
109/// let queue  = fft.queue();
110///
111/// let n: usize = 1024;
112/// let batch_size: u32 = 1;
113/// // allocate input_buf and output_buf as STORAGE buffers of size n * 8 * batch_size bytes (scratch is managed internally)
114/// # let input_buf: wgpu::Buffer = unimplemented!();
115/// # let output_buf: wgpu::Buffer = unimplemented!();
116/// let mut encoder = device.create_command_encoder(&Default::default());
117/// fft.encode_fft(&mut encoder, n, batch_size, FftDirection::Forward, &input_buf, &output_buf);
118/// fft.encode_normalize(&mut encoder, n, batch_size, &output_buf);
119/// queue.submit(std::iter::once(encoder.finish()));
120/// ```
121pub struct FftPipelines {
122    device: Device,
123    /// The wgpu queue for submitting command encoders.
124    pub queue: wgpu::Queue,
125    pipeline_butterfly: ComputePipeline,
126    pipeline_bit_reverse: ComputePipeline,
127    pipeline_normalize: ComputePipeline,
128    bgl: BindGroupLayout,
129    bgl_norm: BindGroupLayout,
130    scratch: RefCell<std::collections::HashMap<usize, wgpu::Buffer>>,
131    // Keyed by (n, direction_as_u32, input_ptr, output_ptr)
132    call_cache: RefCell<std::collections::HashMap<(usize, u32, usize, usize), FftCallCache>>,
133    // Keyed by (n, buf_ptr)
134    norm_cache: RefCell<std::collections::HashMap<(usize, usize), FftNormCache>>,
135}
136
137impl FftPipelines {
138    /// Get buffer pair based on log2_n parity.
139    fn get_buffer_pair_for_mode<'a>(
140        log2_n: u32,
141        output_buf: &'a wgpu::Buffer,
142        scratch_buf: &'a wgpu::Buffer,
143    ) -> (&'a wgpu::Buffer, &'a wgpu::Buffer) {
144        if log2_n % 2 == 0 {
145            return (output_buf, scratch_buf);
146        }
147        (scratch_buf, output_buf)
148    }
149
150    /// Initialize GPU and compile all three FFT compute pipelines.
151    pub fn new() -> Result<Self> {
152        let instance = wgpu::Instance::default();
153        let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
154            power_preference: wgpu::PowerPreference::HighPerformance,
155            compatible_surface: None,
156            force_fallback_adapter: false,
157        }))
158        .or_else(|_| {
159            pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
160                power_preference: wgpu::PowerPreference::HighPerformance,
161                compatible_surface: None,
162                force_fallback_adapter: true,
163            }))
164        })?;
165        let (device, queue) =
166            pollster::block_on(adapter.request_device(&wgpu::DeviceDescriptor {
167                ..Default::default()
168            }))?;
169        Ok(Self::from_device_queue(device, queue))
170    }
171
172    /// Build pipelines from an existing device and queue.
173    ///
174    /// Use this when you already have a `wgpu::Device` (e.g. from a window surface)
175    /// and want to avoid creating a second GPU context.
176    pub fn from_device_queue(device: Device, queue: wgpu::Queue) -> Self {
177        let bgl = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
178            label: Some("fft_pipelines_bgl"),
179            entries: &[
180                fft_storage_entry(0, true),
181                fft_storage_entry(1, false),
182                fft_uniform_entry(2),
183            ],
184        });
185        let bgl_norm = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
186            label: Some("fft_pipelines_norm_bgl"),
187            entries: &[fft_storage_entry(0, false), fft_uniform_entry(1)],
188        });
189        let pipeline_butterfly = fft_make_pipeline(
190            &device,
191            "fft_butterfly",
192            &bgl,
193            shaders::COOLEY_TUKEY_R2_WGSL,
194        );
195        let pipeline_bit_reverse =
196            fft_make_pipeline(&device, "fft_bit_reverse", &bgl, shaders::BIT_REVERSAL_WGSL);
197        let pipeline_normalize = fft_make_pipeline(
198            &device,
199            "fft_normalize",
200            &bgl_norm,
201            shaders::NORMALIZE_VEC2_WGSL,
202        );
203        Self {
204            device,
205            queue,
206            pipeline_butterfly,
207            pipeline_bit_reverse,
208            pipeline_normalize,
209            bgl,
210            bgl_norm,
211            scratch: RefCell::new(std::collections::HashMap::new()),
212            call_cache: RefCell::new(std::collections::HashMap::new()),
213            norm_cache: RefCell::new(std::collections::HashMap::new()),
214        }
215    }
216
217    /// The wgpu device that owns all GPU resources in this instance.
218    pub fn device(&self) -> &Device {
219        &self.device
220    }
221
222    /// The wgpu queue for submitting command encoders.
223    pub fn queue(&self) -> &wgpu::Queue {
224        &self.queue
225    }
226
227    /// Encode one or more FFTs or IFFTs into `encoder`. The result is written to `output_buf`.
228    ///
229    /// All bind groups and uniform buffers are cached after the first call for a given
230    /// (n, direction, input_buf, output_buf) combination — subsequent calls encode with
231    /// zero allocations.
232    pub fn encode_fft(
233        &self,
234        encoder: &mut wgpu::CommandEncoder,
235        n: usize,
236        batch_size: u32,
237        direction: FftDirection,
238        input_buf: &wgpu::Buffer,
239        output_buf: &wgpu::Buffer,
240    ) {
241        let log2_n = n.trailing_zeros();
242
243        // Ensure scratch buffer exists for this n and batch_size
244        {
245            let byte_size = (n * 8 * batch_size as usize) as u64;
246            let mut map = self.scratch.borrow_mut();
247            let buf = map.entry(n).or_insert_with(|| {
248                self.device.create_buffer(&wgpu::BufferDescriptor {
249                    label: Some("fft_scratch"),
250                    size: byte_size,
251                    usage: wgpu::BufferUsages::STORAGE
252                        | wgpu::BufferUsages::COPY_SRC
253                        | wgpu::BufferUsages::COPY_DST,
254                    mapped_at_creation: false,
255                })
256            });
257            if buf.size() < byte_size {
258                *buf = self.device.create_buffer(&wgpu::BufferDescriptor {
259                    label: Some("fft_scratch"),
260                    size: byte_size,
261                    usage: wgpu::BufferUsages::STORAGE
262                        | wgpu::BufferUsages::COPY_SRC
263                        | wgpu::BufferUsages::COPY_DST,
264                    mapped_at_creation: false,
265                });
266            }
267        }
268
269        let key = (
270            n,
271            direction as u32,
272            input_buf as *const _ as usize,
273            output_buf as *const _ as usize,
274        );
275
276        // Build and cache on first call for this key
277        {
278            let scratch_guard = self.scratch.borrow();
279            let scratch_buf = scratch_guard.get(&n).unwrap();
280            let mut cache = self.call_cache.borrow_mut();
281            if !cache.contains_key(&key) {
282                let entry = Self::build_fft_cache(
283                    &self.device,
284                    &self.bgl,
285                    n,
286                    direction,
287                    input_buf,
288                    output_buf,
289                    scratch_buf,
290                );
291                cache.insert(key, entry);
292            }
293        }
294
295        // Encode using cached bind groups — zero allocations
296        let cache_guard = self.call_cache.borrow();
297        let cached = cache_guard.get(&key).unwrap();
298
299        {
300            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
301                label: Some("bit_reversal_pass"),
302                timestamp_writes: None,
303            });
304            pass.set_pipeline(&self.pipeline_bit_reverse);
305            pass.set_bind_group(0, &cached.bind_groups[0], &[]);
306            pass.dispatch_workgroups((n as u32).div_ceil(256), batch_size, 1);
307        }
308
309        for stage in 0..log2_n as usize {
310            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
311                label: Some("fft_butterfly_pass"),
312                timestamp_writes: None,
313            });
314            pass.set_pipeline(&self.pipeline_butterfly);
315            pass.set_bind_group(0, &cached.bind_groups[1 + stage], &[]);
316            pass.dispatch_workgroups(((n / 2) as u32).div_ceil(256), batch_size, 1);
317        }
318    }
319
320    fn build_fft_cache(
321        device: &Device,
322        bgl: &BindGroupLayout,
323        n: usize,
324        direction: FftDirection,
325        input_buf: &wgpu::Buffer,
326        output_buf: &wgpu::Buffer,
327        scratch_buf: &wgpu::Buffer,
328    ) -> FftCallCache {
329        let log2_n = n.trailing_zeros();
330        let dir = direction as u32;
331
332        let (buf0, buf1) = Self::get_buffer_pair_for_mode(log2_n, output_buf, scratch_buf);
333
334        let mut params = Vec::with_capacity(1 + log2_n as usize);
335        let mut bind_groups = Vec::with_capacity(1 + log2_n as usize);
336
337        // Bit-reversal pass
338        let br_params = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
339            label: Some("bit_rev_params"),
340            contents: bytemuck::cast_slice(&[n as u32, log2_n, 0u32, 0u32]),
341            usage: wgpu::BufferUsages::UNIFORM,
342        });
343        let br_bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
344            label: Some("fft_bit_rev_bg"),
345            layout: bgl,
346            entries: &[
347                wgpu::BindGroupEntry {
348                    binding: 0,
349                    resource: input_buf.as_entire_binding(),
350                },
351                wgpu::BindGroupEntry {
352                    binding: 1,
353                    resource: buf0.as_entire_binding(),
354                },
355                wgpu::BindGroupEntry {
356                    binding: 2,
357                    resource: br_params.as_entire_binding(),
358                },
359            ],
360        });
361        params.push(br_params);
362        bind_groups.push(br_bg);
363
364        // Butterfly passes
365        let bufs = [buf0, buf1];
366        for stage in 0..log2_n {
367            let src = bufs[stage as usize % 2];
368            let dst = bufs[(stage as usize + 1) % 2];
369            let stage_params = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
370                label: Some(&format!("fft_stage{stage}_params")),
371                contents: bytemuck::cast_slice(&[n as u32, stage, dir, 0u32]),
372                usage: wgpu::BufferUsages::UNIFORM,
373            });
374            let stage_bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
375                label: Some(&format!("fft_butterfly_bg_stage{stage}")),
376                layout: bgl,
377                entries: &[
378                    wgpu::BindGroupEntry {
379                        binding: 0,
380                        resource: src.as_entire_binding(),
381                    },
382                    wgpu::BindGroupEntry {
383                        binding: 1,
384                        resource: dst.as_entire_binding(),
385                    },
386                    wgpu::BindGroupEntry {
387                        binding: 2,
388                        resource: stage_params.as_entire_binding(),
389                    },
390                ],
391            });
392            params.push(stage_params);
393            bind_groups.push(stage_bg);
394        }
395
396        FftCallCache {
397            params,
398            bind_groups,
399        }
400    }
401
402    /// Encode an in-place divide-by-N pass on `buf` (IFFT normalization).
403    ///
404    /// Bind group and uniform buffer are cached after the first call for a given (n, buf).
405    pub fn encode_normalize(
406        &self,
407        encoder: &mut wgpu::CommandEncoder,
408        n: usize,
409        batch_size: u32,
410        buf: &wgpu::Buffer,
411    ) {
412        let key = (n, buf as *const _ as usize);
413        {
414            let mut cache = self.norm_cache.borrow_mut();
415            if !cache.contains_key(&key) {
416                let params = self
417                    .device
418                    .create_buffer_init(&wgpu::util::BufferInitDescriptor {
419                        label: Some("normalize_params"),
420                        contents: bytemuck::cast_slice(&[n as u32, 0u32, 0u32, 0u32]),
421                        usage: wgpu::BufferUsages::UNIFORM,
422                    });
423                let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
424                    label: Some("normalize_bg"),
425                    layout: &self.bgl_norm,
426                    entries: &[
427                        wgpu::BindGroupEntry {
428                            binding: 0,
429                            resource: buf.as_entire_binding(),
430                        },
431                        wgpu::BindGroupEntry {
432                            binding: 1,
433                            resource: params.as_entire_binding(),
434                        },
435                    ],
436                });
437                cache.insert(
438                    key,
439                    FftNormCache {
440                        params,
441                        bind_group: bg,
442                    },
443                );
444            }
445        }
446        let cache_guard = self.norm_cache.borrow();
447        let cached = cache_guard.get(&key).unwrap();
448        {
449            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
450                label: Some("normalize_pass"),
451                timestamp_writes: None,
452            });
453            pass.set_pipeline(&self.pipeline_normalize);
454            pass.set_bind_group(0, &cached.bind_group, &[]);
455            pass.dispatch_workgroups((n as u32).div_ceil(256), batch_size, 1);
456        }
457    }
458}