use std::cell::RefCell;
use wgpu::util::DeviceExt;
use wgpu::{
BindGroupLayout, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType,
BufferBindingType, ComputePipeline, Device, ShaderStages,
};
use crate::shaders;
/// Direction of the transform requested from [`FftPipelines::encode_fft`].
///
/// The discriminants are explicit because the value is cast to `u32` and
/// written into the uniform parameters of every butterfly stage, so the
/// numeric values are part of the contract with the shader.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FftDirection {
    /// Forward transform; uploaded to the shader as `0`.
    Forward = 0,
    /// Inverse transform; uploaded to the shader as `1`.
    Inverse = 1,
}
/// Builds a bind-group-layout entry describing a storage buffer at slot
/// `binding`, visible to the compute stage only.
///
/// `read_only` is forwarded unchanged to [`BufferBindingType::Storage`]. The
/// entry uses no dynamic offset, no minimum binding size and no array count.
fn fft_storage_entry(binding: u32, read_only: bool) -> BindGroupLayoutEntry {
    let storage_buffer = BindingType::Buffer {
        ty: BufferBindingType::Storage { read_only },
        has_dynamic_offset: false,
        min_binding_size: None,
    };
    BindGroupLayoutEntry {
        binding,
        visibility: ShaderStages::COMPUTE,
        ty: storage_buffer,
        count: None,
    }
}
/// Builds a bind-group-layout entry describing a uniform buffer at slot
/// `binding`, visible to the compute stage only.
///
/// The entry uses no dynamic offset, no minimum binding size and no array
/// count.
fn fft_uniform_entry(binding: u32) -> BindGroupLayoutEntry {
    let uniform_buffer = BindingType::Buffer {
        ty: BufferBindingType::Uniform,
        has_dynamic_offset: false,
        min_binding_size: None,
    };
    BindGroupLayoutEntry {
        binding,
        visibility: ShaderStages::COMPUTE,
        ty: uniform_buffer,
        count: None,
    }
}
/// Compiles the WGSL source `src` and wraps it in a compute pipeline whose
/// only bind group (group 0) uses the layout `bgl`.
///
/// `label` names the pipeline itself; the shader module and the pipeline
/// layout are labelled `"{label}_shader"` and `"{label}_layout"`. The shader
/// entry point must be called `main`.
fn fft_make_pipeline(
    device: &Device,
    label: &str,
    bgl: &BindGroupLayout,
    src: &str,
) -> ComputePipeline {
    let shader_label = format!("{label}_shader");
    let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some(shader_label.as_str()),
        source: wgpu::ShaderSource::Wgsl(src.into()),
    });

    let layout_label = format!("{label}_layout");
    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some(layout_label.as_str()),
        bind_group_layouts: &[Some(bgl)],
        immediate_size: 0,
    });

    let descriptor = wgpu::ComputePipelineDescriptor {
        label: Some(label),
        layout: Some(&pipeline_layout),
        module: &module,
        entry_point: Some("main"),
        compilation_options: Default::default(),
        cache: None,
    };
    device.create_compute_pipeline(&descriptor)
}
/// Owns a wgpu device/queue pair together with the compute pipelines and
/// bind-group layouts used to record FFT and normalization passes.
///
/// The scratch-buffer cache lives in a `RefCell`, so buffers can be created
/// lazily from `&self` methods.
pub struct FftPipelines {
/// Device used to create every buffer, bind group and pipeline.
device: Device,
/// Queue paired with `device`; also reachable through the `queue()` getter.
pub queue: wgpu::Queue,
/// Pipeline recorded once per butterfly stage by `encode_fft`.
pipeline_butterfly: ComputePipeline,
/// Pipeline for the bit-reversal pass that `encode_fft` records first.
pipeline_bit_reverse: ComputePipeline,
/// Pipeline recorded by `encode_normalize`.
pipeline_normalize: ComputePipeline,
/// Layout shared by the bit-reversal and butterfly pipelines:
/// read-only storage (0), read-write storage (1), uniform (2).
bgl: BindGroupLayout,
/// Layout of the normalize pipeline: read-write storage (0), uniform (1).
bgl_norm: BindGroupLayout,
/// Scratch buffers keyed by FFT length `n`, created on first use by
/// `encode_fft` and kept for later calls with the same length.
scratch: RefCell<std::collections::HashMap<usize, wgpu::Buffer>>,
}
impl FftPipelines {
/// Creates a default wgpu instance, picks an adapter, requests a device
/// with the default descriptor, and builds the pipelines on it.
///
/// A high-performance adapter is requested first. If that request fails,
/// it is repeated once with `force_fallback_adapter` set.
///
/// Blocks the calling thread (via `pollster`) while the adapter and device
/// requests resolve.
///
/// # Errors
///
/// Returns the error of the fallback adapter request when both adapter
/// requests fail, or the error of the device request.
pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
    let instance = wgpu::Instance::default();

    // Both attempts differ only in whether the fallback adapter is forced.
    let request_adapter = |force_fallback_adapter: bool| {
        pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::HighPerformance,
            compatible_surface: None,
            force_fallback_adapter,
        }))
    };

    let adapter = match request_adapter(false) {
        Ok(found) => found,
        // The first error is intentionally dropped; only the retry's
        // outcome is reported to the caller.
        Err(_) => request_adapter(true)?,
    };

    let device_descriptor = wgpu::DeviceDescriptor::default();
    let (device, queue) = pollster::block_on(adapter.request_device(&device_descriptor))?;
    Ok(Self::from_device_queue(device, queue))
}
pub fn from_device_queue(device: Device, queue: wgpu::Queue) -> Self {
let bgl = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some("fft_pipelines_bgl"),
entries: &[
fft_storage_entry(0, true),
fft_storage_entry(1, false),
fft_uniform_entry(2),
],
});
let bgl_norm = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some("fft_pipelines_norm_bgl"),
entries: &[fft_storage_entry(0, false), fft_uniform_entry(1)],
});
let pipeline_butterfly = fft_make_pipeline(
&device,
"fft_butterfly",
&bgl,
shaders::COOLEY_TUKEY_R2_WGSL,
);
let pipeline_bit_reverse =
fft_make_pipeline(&device, "fft_bit_reverse", &bgl, shaders::BIT_REVERSAL_WGSL);
let pipeline_normalize = fft_make_pipeline(
&device,
"fft_normalize",
&bgl_norm,
shaders::NORMALIZE_VEC2_WGSL,
);
Self {
device,
queue,
pipeline_butterfly,
pipeline_bit_reverse,
pipeline_normalize,
bgl,
bgl_norm,
scratch: RefCell::new(std::collections::HashMap::new()),
}
}
/// Returns the device that owns every pipeline and buffer created by this
/// object.
pub fn device(&self) -> &Device {
&self.device
}
/// Returns the queue paired with [`Self::device`]; the same value is also
/// reachable through the public `queue` field.
pub fn queue(&self) -> &wgpu::Queue {
&self.queue
}
/// Records the compute passes of a length-`n` FFT into `encoder`.
///
/// One bit-reversal pass reads `input_buf` and is followed by `log2(n)`
/// butterfly passes that ping-pong between `output_buf` and an internal
/// scratch buffer. The buffer order is chosen so that the last pass always
/// writes to `output_buf`. Nothing is submitted here; the caller finishes
/// and submits `encoder`.
///
/// * `n` — number of elements; each element is taken to occupy 8 bytes
///   (presumably one `vec2<f32>` complex sample — confirm against the
///   shaders).
/// * `direction` — cast to `u32` and uploaded with every butterfly stage.
/// * `input_buf` — bound as read-only storage for the bit-reversal pass.
/// * `output_buf` — receives the result; bound as read-write storage.
///
/// The scratch buffer is cached per `n` and reused by every later call with
/// the same length.
///
/// # Panics
///
/// Panics if `n` is not a non-zero power of two or does not fit in a `u32`.
/// Without this check, `trailing_zeros` would not equal `log2(n)` and the
/// recorded passes would silently compute a wrong result.
pub fn encode_fft(
    &self,
    encoder: &mut wgpu::CommandEncoder,
    n: usize,
    direction: FftDirection,
    input_buf: &wgpu::Buffer,
    output_buf: &wgpu::Buffer,
) {
    assert!(
        n.is_power_of_two() && n <= u32::MAX as usize,
        "encode_fft: n must be a non-zero power of two that fits in u32, got {}",
        n
    );
    // Lossless: range-checked by the assertion above.
    let n_u32 = n as u32;
    // Exact log2 because `n` is a power of two.
    let log2_n = n.trailing_zeros();
    let dir = direction as u32;
    // Widen before multiplying so the product cannot wrap on 32-bit targets.
    let byte_size = n as u64 * 8;

    // Create the scratch buffer for this length on first use. The mutable
    // borrow is scoped so the shared borrow below does not conflict with it.
    {
        let mut cache = self.scratch.borrow_mut();
        cache.entry(n).or_insert_with(|| {
            self.device.create_buffer(&wgpu::BufferDescriptor {
                label: Some("fft_scratch"),
                size: byte_size,
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_SRC
                    | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            })
        });
    }
    let cache = self.scratch.borrow();
    let scratch_buf = cache
        .get(&n)
        .expect("scratch buffer for this length was inserted just above");

    // Bit reversal writes into `buf0`; stage `s` reads `bufs[s % 2]` and
    // writes `bufs[(s + 1) % 2]`, so after `log2_n` stages the data sits in
    // `bufs[log2_n % 2]`. Pick the order that makes that slot `output_buf`.
    let (buf0, buf1): (&wgpu::Buffer, &wgpu::Buffer) = if log2_n % 2 == 0 {
        (output_buf, scratch_buf)
    } else {
        (scratch_buf, output_buf)
    };

    let br_params = self
        .device
        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("bit_rev_params"),
            contents: bytemuck::cast_slice(&[n_u32, log2_n, 0u32, 0u32]),
            usage: wgpu::BufferUsages::UNIFORM,
        });
    {
        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("fft_bit_rev_bg"),
            layout: &self.bgl,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: input_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: buf0.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: br_params.as_entire_binding(),
                },
            ],
        });
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("bit_reversal_pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&self.pipeline_bit_reverse);
        pass.set_bind_group(0, &bg, &[]);
        // One invocation per element, 256 invocations per workgroup.
        pass.dispatch_workgroups(n_u32.div_ceil(256), 1, 1);
    }

    let bufs = [buf0, buf1];
    // Each stage runs n/2 butterflies; the count is the same for every stage.
    let butterfly_groups = (n_u32 / 2).div_ceil(256);
    for stage in 0..log2_n {
        let src = bufs[stage as usize % 2];
        let dst = bufs[(stage as usize + 1) % 2];
        // Every stage gets its own uniform buffer because the passes only
        // execute after the encoder is submitted.
        let fft_params = self
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some(&format!("fft_stage{stage}_params")),
                contents: bytemuck::cast_slice(&[n_u32, stage, dir, 0u32]),
                usage: wgpu::BufferUsages::UNIFORM,
            });
        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some(&format!("fft_butterfly_bg_stage{stage}")),
            layout: &self.bgl,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: src.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: dst.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: fft_params.as_entire_binding(),
                },
            ],
        });
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some(&format!("fft_butterfly_stage{stage}")),
            timestamp_writes: None,
        });
        pass.set_pipeline(&self.pipeline_butterfly);
        pass.set_bind_group(0, &bg, &[]);
        pass.dispatch_workgroups(butterfly_groups, 1, 1);
    }
}
/// Records a single normalization compute pass over `buf` into `encoder`.
///
/// `n` is uploaded to the shader as the first of four `u32` uniform words
/// (the other three are zero), and one invocation per element is dispatched
/// in workgroups of 256. `buf` is bound as read-write storage at binding 0.
/// Nothing is submitted here; the caller finishes and submits `encoder`.
pub fn encode_normalize(
    &self,
    encoder: &mut wgpu::CommandEncoder,
    n: usize,
    buf: &wgpu::Buffer,
) {
    let element_count = n as u32;
    let uniform_words: [u32; 4] = [element_count, 0, 0, 0];
    let params = self
        .device
        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("normalize_params"),
            contents: bytemuck::cast_slice(&uniform_words),
            usage: wgpu::BufferUsages::UNIFORM,
        });

    let bindings = [
        wgpu::BindGroupEntry {
            binding: 0,
            resource: buf.as_entire_binding(),
        },
        wgpu::BindGroupEntry {
            binding: 1,
            resource: params.as_entire_binding(),
        },
    ];
    let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("normalize_bg"),
        layout: &self.bgl_norm,
        entries: &bindings,
    });

    let pass_descriptor = wgpu::ComputePassDescriptor {
        label: Some("normalize_pass"),
        timestamp_writes: None,
    };
    let mut pass = encoder.begin_compute_pass(&pass_descriptor);
    pass.set_pipeline(&self.pipeline_normalize);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(element_count.div_ceil(256), 1, 1);
}
}