1use std::cell::RefCell;
8
9use wgpu::util::DeviceExt;
10use wgpu::{
11 BindGroupLayout, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType,
12 BufferBindingType, ComputePipeline, Device, ShaderStages,
13};
14
15use crate::error::Result;
16use crate::shaders;
17
/// Direction of the transform requested from [`FftPipelines::encode_fft`].
///
/// The discriminant is cast to `u32` and written into the per-stage uniform
/// parameters, and it is also part of the bind-group cache key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FftDirection {
    /// Forward transform; passed to the shader as `0`.
    Forward = 0,
    /// Inverse transform; passed to the shader as `1`.
    Inverse = 1,
}
26
27fn fft_storage_entry(binding: u32, read_only: bool) -> BindGroupLayoutEntry {
30 BindGroupLayoutEntry {
31 binding,
32 visibility: ShaderStages::COMPUTE,
33 ty: BindingType::Buffer {
34 ty: BufferBindingType::Storage { read_only },
35 has_dynamic_offset: false,
36 min_binding_size: None,
37 },
38 count: None,
39 }
40}
41
42fn fft_uniform_entry(binding: u32) -> BindGroupLayoutEntry {
43 BindGroupLayoutEntry {
44 binding,
45 visibility: ShaderStages::COMPUTE,
46 ty: BindingType::Buffer {
47 ty: BufferBindingType::Uniform,
48 has_dynamic_offset: false,
49 min_binding_size: None,
50 },
51 count: None,
52 }
53}
54
55fn fft_make_pipeline(
56 device: &Device,
57 label: &str,
58 bgl: &BindGroupLayout,
59 src: &str,
60) -> ComputePipeline {
61 let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
62 label: Some(&format!("{label}_shader")),
63 source: wgpu::ShaderSource::Wgsl(src.into()),
64 });
65 let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
66 label: Some(&format!("{label}_layout")),
67 bind_group_layouts: &[Some(bgl)],
68 immediate_size: 0,
69 });
70 device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
71 label: Some(label),
72 layout: Some(&layout),
73 module: &shader,
74 entry_point: Some("main"),
75 compilation_options: Default::default(),
76 cache: None,
77 })
78}
79
/// GPU objects cached for one `encode_fft` call signature.
///
/// Built by `FftPipelines::build_fft_cache` and stored in
/// `FftPipelines::call_cache`.
struct FftCallCache {
    /// Uniform parameter buffers, in the same order as `bind_groups`.
    // Never read after construction (hence the `dead_code` allowance); stored
    // next to the bind groups that reference them.
    #[allow(dead_code)]
    params: Vec<wgpu::Buffer>,
    /// Index 0 is the bit-reversal bind group; index `1 + stage` is the
    /// butterfly bind group for that stage.
    bind_groups: Vec<wgpu::BindGroup>,
}
89
/// GPU objects cached for one `encode_normalize` call signature.
struct FftNormCache {
    /// Uniform buffer holding `[n, 0, 0, 0]` as `u32` values.
    // Never read after construction (hence the `dead_code` allowance); stored
    // next to the bind group that references it.
    #[allow(dead_code)]
    params: wgpu::Buffer,
    /// Bind group pairing the target buffer (binding 0) with `params` (binding 1).
    bind_group: wgpu::BindGroup,
}
95
/// Compute pipelines and per-call caches used to record FFT passes on a wgpu device.
///
/// The caches use `RefCell`, so recording methods take `&self` while still
/// mutating cached state.
pub struct FftPipelines {
    /// Device on which all pipelines, buffers and bind groups are created.
    device: Device,
    /// Queue paired with `device`; also reachable through [`FftPipelines::queue`].
    pub queue: wgpu::Queue,
    /// Pipeline built from `shaders::COOLEY_TUKEY_R2_WGSL`; run once per stage.
    pipeline_butterfly: ComputePipeline,
    /// Pipeline built from `shaders::BIT_REVERSAL_WGSL`; run once before the stages.
    pipeline_bit_reverse: ComputePipeline,
    /// Pipeline built from `shaders::NORMALIZE_VEC2_WGSL`.
    pipeline_normalize: ComputePipeline,
    /// Layout shared by the bit-reversal and butterfly pipelines:
    /// read-only storage (0), read-write storage (1), uniform (2).
    bgl: BindGroupLayout,
    /// Layout of the normalize pipeline: read-write storage (0), uniform (1).
    bgl_norm: BindGroupLayout,
    /// Scratch buffers keyed by transform length `n`.
    scratch: RefCell<std::collections::HashMap<usize, wgpu::Buffer>>,
    /// FFT bind groups keyed by `(n, direction as u32, address of the input
    /// buffer handle, address of the output buffer handle)`.
    call_cache: RefCell<std::collections::HashMap<(usize, u32, usize, usize), FftCallCache>>,
    /// Normalize bind groups keyed by `(n, address of the buffer handle)`.
    norm_cache: RefCell<std::collections::HashMap<(usize, usize), FftNormCache>>,
}
136
137impl FftPipelines {
138 fn get_buffer_pair_for_mode<'a>(
140 log2_n: u32,
141 output_buf: &'a wgpu::Buffer,
142 scratch_buf: &'a wgpu::Buffer,
143 ) -> (&'a wgpu::Buffer, &'a wgpu::Buffer) {
144 if log2_n % 2 == 0 {
145 return (output_buf, scratch_buf);
146 }
147 (scratch_buf, output_buf)
148 }
149
150 pub fn new() -> Result<Self> {
152 let instance = wgpu::Instance::default();
153 let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
154 power_preference: wgpu::PowerPreference::HighPerformance,
155 compatible_surface: None,
156 force_fallback_adapter: false,
157 }))
158 .or_else(|_| {
159 pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
160 power_preference: wgpu::PowerPreference::HighPerformance,
161 compatible_surface: None,
162 force_fallback_adapter: true,
163 }))
164 })?;
165 let (device, queue) =
166 pollster::block_on(adapter.request_device(&wgpu::DeviceDescriptor {
167 ..Default::default()
168 }))?;
169 Ok(Self::from_device_queue(device, queue))
170 }
171
172 pub fn from_device_queue(device: Device, queue: wgpu::Queue) -> Self {
177 let bgl = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
178 label: Some("fft_pipelines_bgl"),
179 entries: &[
180 fft_storage_entry(0, true),
181 fft_storage_entry(1, false),
182 fft_uniform_entry(2),
183 ],
184 });
185 let bgl_norm = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
186 label: Some("fft_pipelines_norm_bgl"),
187 entries: &[fft_storage_entry(0, false), fft_uniform_entry(1)],
188 });
189 let pipeline_butterfly = fft_make_pipeline(
190 &device,
191 "fft_butterfly",
192 &bgl,
193 shaders::COOLEY_TUKEY_R2_WGSL,
194 );
195 let pipeline_bit_reverse =
196 fft_make_pipeline(&device, "fft_bit_reverse", &bgl, shaders::BIT_REVERSAL_WGSL);
197 let pipeline_normalize = fft_make_pipeline(
198 &device,
199 "fft_normalize",
200 &bgl_norm,
201 shaders::NORMALIZE_VEC2_WGSL,
202 );
203 Self {
204 device,
205 queue,
206 pipeline_butterfly,
207 pipeline_bit_reverse,
208 pipeline_normalize,
209 bgl,
210 bgl_norm,
211 scratch: RefCell::new(std::collections::HashMap::new()),
212 call_cache: RefCell::new(std::collections::HashMap::new()),
213 norm_cache: RefCell::new(std::collections::HashMap::new()),
214 }
215 }
216
    /// Returns the device on which this instance creates its pipelines and buffers.
    pub fn device(&self) -> &Device {
        &self.device
    }
221
    /// Returns the queue paired with [`FftPipelines::device`].
    pub fn queue(&self) -> &wgpu::Queue {
        &self.queue
    }
226
227 pub fn encode_fft(
233 &self,
234 encoder: &mut wgpu::CommandEncoder,
235 n: usize,
236 batch_size: u32,
237 direction: FftDirection,
238 input_buf: &wgpu::Buffer,
239 output_buf: &wgpu::Buffer,
240 ) {
241 let log2_n = n.trailing_zeros();
242
243 {
245 let byte_size = (n * 8 * batch_size as usize) as u64;
246 let mut map = self.scratch.borrow_mut();
247 let buf = map.entry(n).or_insert_with(|| {
248 self.device.create_buffer(&wgpu::BufferDescriptor {
249 label: Some("fft_scratch"),
250 size: byte_size,
251 usage: wgpu::BufferUsages::STORAGE
252 | wgpu::BufferUsages::COPY_SRC
253 | wgpu::BufferUsages::COPY_DST,
254 mapped_at_creation: false,
255 })
256 });
257 if buf.size() < byte_size {
258 *buf = self.device.create_buffer(&wgpu::BufferDescriptor {
259 label: Some("fft_scratch"),
260 size: byte_size,
261 usage: wgpu::BufferUsages::STORAGE
262 | wgpu::BufferUsages::COPY_SRC
263 | wgpu::BufferUsages::COPY_DST,
264 mapped_at_creation: false,
265 });
266 }
267 }
268
269 let key = (
270 n,
271 direction as u32,
272 input_buf as *const _ as usize,
273 output_buf as *const _ as usize,
274 );
275
276 {
278 let scratch_guard = self.scratch.borrow();
279 let scratch_buf = scratch_guard.get(&n).unwrap();
280 let mut cache = self.call_cache.borrow_mut();
281 if !cache.contains_key(&key) {
282 let entry = Self::build_fft_cache(
283 &self.device,
284 &self.bgl,
285 n,
286 direction,
287 input_buf,
288 output_buf,
289 scratch_buf,
290 );
291 cache.insert(key, entry);
292 }
293 }
294
295 let cache_guard = self.call_cache.borrow();
297 let cached = cache_guard.get(&key).unwrap();
298
299 {
300 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
301 label: Some("bit_reversal_pass"),
302 timestamp_writes: None,
303 });
304 pass.set_pipeline(&self.pipeline_bit_reverse);
305 pass.set_bind_group(0, &cached.bind_groups[0], &[]);
306 pass.dispatch_workgroups((n as u32).div_ceil(256), batch_size, 1);
307 }
308
309 for stage in 0..log2_n as usize {
310 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
311 label: Some("fft_butterfly_pass"),
312 timestamp_writes: None,
313 });
314 pass.set_pipeline(&self.pipeline_butterfly);
315 pass.set_bind_group(0, &cached.bind_groups[1 + stage], &[]);
316 pass.dispatch_workgroups(((n / 2) as u32).div_ceil(256), batch_size, 1);
317 }
318 }
319
320 fn build_fft_cache(
321 device: &Device,
322 bgl: &BindGroupLayout,
323 n: usize,
324 direction: FftDirection,
325 input_buf: &wgpu::Buffer,
326 output_buf: &wgpu::Buffer,
327 scratch_buf: &wgpu::Buffer,
328 ) -> FftCallCache {
329 let log2_n = n.trailing_zeros();
330 let dir = direction as u32;
331
332 let (buf0, buf1) = Self::get_buffer_pair_for_mode(log2_n, output_buf, scratch_buf);
333
334 let mut params = Vec::with_capacity(1 + log2_n as usize);
335 let mut bind_groups = Vec::with_capacity(1 + log2_n as usize);
336
337 let br_params = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
339 label: Some("bit_rev_params"),
340 contents: bytemuck::cast_slice(&[n as u32, log2_n, 0u32, 0u32]),
341 usage: wgpu::BufferUsages::UNIFORM,
342 });
343 let br_bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
344 label: Some("fft_bit_rev_bg"),
345 layout: bgl,
346 entries: &[
347 wgpu::BindGroupEntry {
348 binding: 0,
349 resource: input_buf.as_entire_binding(),
350 },
351 wgpu::BindGroupEntry {
352 binding: 1,
353 resource: buf0.as_entire_binding(),
354 },
355 wgpu::BindGroupEntry {
356 binding: 2,
357 resource: br_params.as_entire_binding(),
358 },
359 ],
360 });
361 params.push(br_params);
362 bind_groups.push(br_bg);
363
364 let bufs = [buf0, buf1];
366 for stage in 0..log2_n {
367 let src = bufs[stage as usize % 2];
368 let dst = bufs[(stage as usize + 1) % 2];
369 let stage_params = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
370 label: Some(&format!("fft_stage{stage}_params")),
371 contents: bytemuck::cast_slice(&[n as u32, stage, dir, 0u32]),
372 usage: wgpu::BufferUsages::UNIFORM,
373 });
374 let stage_bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
375 label: Some(&format!("fft_butterfly_bg_stage{stage}")),
376 layout: bgl,
377 entries: &[
378 wgpu::BindGroupEntry {
379 binding: 0,
380 resource: src.as_entire_binding(),
381 },
382 wgpu::BindGroupEntry {
383 binding: 1,
384 resource: dst.as_entire_binding(),
385 },
386 wgpu::BindGroupEntry {
387 binding: 2,
388 resource: stage_params.as_entire_binding(),
389 },
390 ],
391 });
392 params.push(stage_params);
393 bind_groups.push(stage_bg);
394 }
395
396 FftCallCache {
397 params,
398 bind_groups,
399 }
400 }
401
402 pub fn encode_normalize(
406 &self,
407 encoder: &mut wgpu::CommandEncoder,
408 n: usize,
409 batch_size: u32,
410 buf: &wgpu::Buffer,
411 ) {
412 let key = (n, buf as *const _ as usize);
413 {
414 let mut cache = self.norm_cache.borrow_mut();
415 if !cache.contains_key(&key) {
416 let params = self
417 .device
418 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
419 label: Some("normalize_params"),
420 contents: bytemuck::cast_slice(&[n as u32, 0u32, 0u32, 0u32]),
421 usage: wgpu::BufferUsages::UNIFORM,
422 });
423 let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
424 label: Some("normalize_bg"),
425 layout: &self.bgl_norm,
426 entries: &[
427 wgpu::BindGroupEntry {
428 binding: 0,
429 resource: buf.as_entire_binding(),
430 },
431 wgpu::BindGroupEntry {
432 binding: 1,
433 resource: params.as_entire_binding(),
434 },
435 ],
436 });
437 cache.insert(
438 key,
439 FftNormCache {
440 params,
441 bind_group: bg,
442 },
443 );
444 }
445 }
446 let cache_guard = self.norm_cache.borrow();
447 let cached = cache_guard.get(&key).unwrap();
448 {
449 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
450 label: Some("normalize_pass"),
451 timestamp_writes: None,
452 });
453 pass.set_pipeline(&self.pipeline_normalize);
454 pass.set_bind_group(0, &cached.bind_group, &[]);
455 pass.dispatch_workgroups((n as u32).div_ceil(256), batch_size, 1);
456 }
457 }
458}