use super::pipeline::{LayoutKey, PipelineCache, workgroup_count};
use crate::error::Result;
use wgpu::{Buffer, Queue};
pub const MAX_WORKGROUP_FFT_SIZE: usize = 256;
const STOCKHAM_FFT_SHADER: &str = include_str!("stockham_fft.wgsl");
const FFTSHIFT_SHADER: &str = include_str!("fftshift.wgsl");
const RFFT_PACK_SHADER: &str = include_str!("rfft_pack.wgsl");
const IRFFT_UNPACK_SHADER: &str = include_str!("irfft_unpack.wgsl");
const HERMITIAN_EXTEND_SHADER: &str = include_str!("hermitian_extend.wgsl");
const RFFT_TRUNCATE_SHADER: &str = include_str!("rfft_truncate.wgsl");
const COPY_COMPLEX_SHADER: &str = include_str!("copy_complex.wgsl");
pub fn launch_stockham_fft_batched(
pipeline_cache: &PipelineCache,
queue: &Queue,
input: &Buffer,
output: &Buffer,
params: &Buffer,
n: usize,
batch_size: usize,
) -> Result<()> {
if n > MAX_WORKGROUP_FFT_SIZE {
return Err(crate::error::Error::Internal(format!(
"FFT size {} exceeds max workgroup FFT size {}",
n, MAX_WORKGROUP_FFT_SIZE
)));
}
let module = pipeline_cache.get_or_create_module("stockham_fft", STOCKHAM_FFT_SHADER);
let layout = pipeline_cache.get_or_create_layout(LayoutKey {
num_storage_buffers: 2,
num_uniform_buffers: 1,
num_readonly_storage: 0,
});
let pipeline = pipeline_cache.get_or_create_pipeline(
"stockham_fft",
"stockham_fft_small",
&module,
&layout,
);
let bind_group = pipeline_cache.create_bind_group(&layout, &[input, output, params]);
let mut encoder =
pipeline_cache
.device()
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("fft_encoder"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("stockham_fft_pass"),
timestamp_writes: None,
});
pass.set_pipeline(&pipeline);
pass.set_bind_group(0, Some(&bind_group), &[]);
pass.dispatch_workgroups(batch_size as u32, 1, 1);
}
queue.submit(std::iter::once(encoder.finish()));
Ok(())
}
pub fn launch_stockham_fft_stage(
pipeline_cache: &PipelineCache,
queue: &Queue,
input: &Buffer,
output: &Buffer,
params: &Buffer,
n: usize,
batch_size: usize,
) -> Result<()> {
let module = pipeline_cache.get_or_create_module("stockham_fft", STOCKHAM_FFT_SHADER);
let layout = pipeline_cache.get_or_create_layout(LayoutKey {
num_storage_buffers: 2,
num_uniform_buffers: 1,
num_readonly_storage: 0,
});
let pipeline = pipeline_cache.get_or_create_pipeline(
"stockham_fft",
"stockham_fft_stage",
&module,
&layout,
);
let bind_group = pipeline_cache.create_bind_group(&layout, &[input, output, params]);
let mut encoder =
pipeline_cache
.device()
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("fft_stage_encoder"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("fft_stage_pass"),
timestamp_writes: None,
});
pass.set_pipeline(&pipeline);
pass.set_bind_group(0, Some(&bind_group), &[]);
let butterflies = n / 2;
pass.dispatch_workgroups(workgroup_count(butterflies), batch_size as u32, 1);
}
queue.submit(std::iter::once(encoder.finish()));
Ok(())
}
/// Runs the `scale_complex` entry point of the Stockham FFT shader over `n`
/// elements (single batch along x only).
///
/// Binds `input`, `output`, and `params` (uniform) to bind group 0 and submits
/// one compute pass to `queue`.
pub fn launch_scale_complex(
    pipeline_cache: &PipelineCache,
    queue: &Queue,
    input: &Buffer,
    output: &Buffer,
    params: &Buffer,
    n: usize,
) -> Result<()> {
    // Two read-write storage buffers plus one uniform params buffer.
    let layout = pipeline_cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 2,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let shader = pipeline_cache.get_or_create_module("stockham_fft", STOCKHAM_FFT_SHADER);
    let pipeline =
        pipeline_cache.get_or_create_pipeline("stockham_fft", "scale_complex", &shader, &layout);
    let bindings = pipeline_cache.create_bind_group(&layout, &[input, output, params]);
    let device = pipeline_cache.device();
    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("scale_encoder"),
    });
    {
        let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("scale_pass"),
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&pipeline);
        compute_pass.set_bind_group(0, Some(&bindings), &[]);
        compute_pass.dispatch_workgroups(workgroup_count(n), 1, 1);
    }
    queue.submit([encoder.finish()]);
    Ok(())
}
/// Runs the `fftshift` entry point over `n` elements per batch, with
/// `batch_size` batches along the y dispatch dimension.
///
/// Binds `input`, `output`, and `params` (uniform) to bind group 0 and submits
/// one compute pass to `queue`.
pub fn launch_fftshift(
    pipeline_cache: &PipelineCache,
    queue: &Queue,
    input: &Buffer,
    output: &Buffer,
    params: &Buffer,
    n: usize,
    batch_size: usize,
) -> Result<()> {
    // Two read-write storage buffers plus one uniform params buffer.
    let layout = pipeline_cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 2,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let shader = pipeline_cache.get_or_create_module("fftshift", FFTSHIFT_SHADER);
    let pipeline = pipeline_cache.get_or_create_pipeline("fftshift", "fftshift", &shader, &layout);
    let bindings = pipeline_cache.create_bind_group(&layout, &[input, output, params]);
    let device = pipeline_cache.device();
    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("fftshift_encoder"),
    });
    {
        let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("fftshift_pass"),
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&pipeline);
        compute_pass.set_bind_group(0, Some(&bindings), &[]);
        // x covers elements, y covers the batch dimension.
        compute_pass.dispatch_workgroups(workgroup_count(n), batch_size as u32, 1);
    }
    queue.submit([encoder.finish()]);
    Ok(())
}
/// Runs the `ifftshift` entry point (inverse shift) over `n` elements per
/// batch, with `batch_size` batches along the y dispatch dimension.
///
/// Binds `input`, `output`, and `params` (uniform) to bind group 0 and submits
/// one compute pass to `queue`.
pub fn launch_ifftshift(
    pipeline_cache: &PipelineCache,
    queue: &Queue,
    input: &Buffer,
    output: &Buffer,
    params: &Buffer,
    n: usize,
    batch_size: usize,
) -> Result<()> {
    // Two read-write storage buffers plus one uniform params buffer.
    let layout = pipeline_cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 2,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    // Same shader module as `launch_fftshift`, different entry point.
    let shader = pipeline_cache.get_or_create_module("fftshift", FFTSHIFT_SHADER);
    let pipeline = pipeline_cache.get_or_create_pipeline("fftshift", "ifftshift", &shader, &layout);
    let bindings = pipeline_cache.create_bind_group(&layout, &[input, output, params]);
    let device = pipeline_cache.device();
    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("ifftshift_encoder"),
    });
    {
        let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("ifftshift_pass"),
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&pipeline);
        compute_pass.set_bind_group(0, Some(&bindings), &[]);
        // x covers elements, y covers the batch dimension.
        compute_pass.dispatch_workgroups(workgroup_count(n), batch_size as u32, 1);
    }
    queue.submit([encoder.finish()]);
    Ok(())
}
/// Runs the `rfft_pack` entry point over `n` elements per batch, with
/// `batch_size` batches along the y dispatch dimension.
///
/// Binds `input`, `output`, and `params` (uniform) to bind group 0 and submits
/// one compute pass to `queue`.
pub fn launch_rfft_pack(
    pipeline_cache: &PipelineCache,
    queue: &Queue,
    input: &Buffer,
    output: &Buffer,
    params: &Buffer,
    n: usize,
    batch_size: usize,
) -> Result<()> {
    // Two read-write storage buffers plus one uniform params buffer.
    let layout = pipeline_cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 2,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let shader = pipeline_cache.get_or_create_module("rfft_pack", RFFT_PACK_SHADER);
    let pipeline =
        pipeline_cache.get_or_create_pipeline("rfft_pack", "rfft_pack", &shader, &layout);
    let bindings = pipeline_cache.create_bind_group(&layout, &[input, output, params]);
    let device = pipeline_cache.device();
    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("rfft_pack_encoder"),
    });
    {
        let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("rfft_pack_pass"),
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&pipeline);
        compute_pass.set_bind_group(0, Some(&bindings), &[]);
        // x covers elements, y covers the batch dimension.
        compute_pass.dispatch_workgroups(workgroup_count(n), batch_size as u32, 1);
    }
    queue.submit([encoder.finish()]);
    Ok(())
}
/// Runs the `irfft_unpack` entry point over `n` elements per batch, with
/// `batch_size` batches along the y dispatch dimension.
///
/// Binds `input`, `output`, and `params` (uniform) to bind group 0 and submits
/// one compute pass to `queue`.
pub fn launch_irfft_unpack(
    pipeline_cache: &PipelineCache,
    queue: &Queue,
    input: &Buffer,
    output: &Buffer,
    params: &Buffer,
    n: usize,
    batch_size: usize,
) -> Result<()> {
    // Two read-write storage buffers plus one uniform params buffer.
    let layout = pipeline_cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 2,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let shader = pipeline_cache.get_or_create_module("irfft_unpack", IRFFT_UNPACK_SHADER);
    let pipeline =
        pipeline_cache.get_or_create_pipeline("irfft_unpack", "irfft_unpack", &shader, &layout);
    let bindings = pipeline_cache.create_bind_group(&layout, &[input, output, params]);
    let device = pipeline_cache.device();
    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("irfft_unpack_encoder"),
    });
    {
        let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("irfft_unpack_pass"),
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&pipeline);
        compute_pass.set_bind_group(0, Some(&bindings), &[]);
        // x covers elements, y covers the batch dimension.
        compute_pass.dispatch_workgroups(workgroup_count(n), batch_size as u32, 1);
    }
    queue.submit([encoder.finish()]);
    Ok(())
}
/// Runs the `hermitian_extend` entry point over `n` elements per batch, with
/// `batch_size` batches along the y dispatch dimension.
///
/// Binds `input`, `output`, and `params` (uniform) to bind group 0 and submits
/// one compute pass to `queue`.
pub fn launch_hermitian_extend(
    pipeline_cache: &PipelineCache,
    queue: &Queue,
    input: &Buffer,
    output: &Buffer,
    params: &Buffer,
    n: usize,
    batch_size: usize,
) -> Result<()> {
    // Two read-write storage buffers plus one uniform params buffer.
    let layout = pipeline_cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 2,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let shader = pipeline_cache.get_or_create_module("hermitian_extend", HERMITIAN_EXTEND_SHADER);
    let pipeline = pipeline_cache.get_or_create_pipeline(
        "hermitian_extend",
        "hermitian_extend",
        &shader,
        &layout,
    );
    let bindings = pipeline_cache.create_bind_group(&layout, &[input, output, params]);
    let device = pipeline_cache.device();
    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("hermitian_extend_encoder"),
    });
    {
        let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("hermitian_extend_pass"),
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&pipeline);
        compute_pass.set_bind_group(0, Some(&bindings), &[]);
        // x covers elements, y covers the batch dimension.
        compute_pass.dispatch_workgroups(workgroup_count(n), batch_size as u32, 1);
    }
    queue.submit([encoder.finish()]);
    Ok(())
}
/// Runs the `rfft_truncate` entry point over `half_n` elements per batch
/// (the truncated half-spectrum length), with `batch_size` batches along the
/// y dispatch dimension.
///
/// Binds `input`, `output`, and `params` (uniform) to bind group 0 and submits
/// one compute pass to `queue`.
pub fn launch_rfft_truncate(
    pipeline_cache: &PipelineCache,
    queue: &Queue,
    input: &Buffer,
    output: &Buffer,
    params: &Buffer,
    half_n: usize,
    batch_size: usize,
) -> Result<()> {
    // Two read-write storage buffers plus one uniform params buffer.
    let layout = pipeline_cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 2,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let shader = pipeline_cache.get_or_create_module("rfft_truncate", RFFT_TRUNCATE_SHADER);
    let pipeline =
        pipeline_cache.get_or_create_pipeline("rfft_truncate", "rfft_truncate", &shader, &layout);
    let bindings = pipeline_cache.create_bind_group(&layout, &[input, output, params]);
    let device = pipeline_cache.device();
    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("rfft_truncate_encoder"),
    });
    {
        let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("rfft_truncate_pass"),
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&pipeline);
        compute_pass.set_bind_group(0, Some(&bindings), &[]);
        // x covers the truncated output, y covers the batch dimension.
        compute_pass.dispatch_workgroups(workgroup_count(half_n), batch_size as u32, 1);
    }
    queue.submit([encoder.finish()]);
    Ok(())
}
/// Runs the `copy_complex` entry point over `n` elements (single batch along
/// x only).
///
/// Binds `input`, `output`, and `params` (uniform) to bind group 0 and submits
/// one compute pass to `queue`.
pub fn launch_copy_complex(
    pipeline_cache: &PipelineCache,
    queue: &Queue,
    input: &Buffer,
    output: &Buffer,
    params: &Buffer,
    n: usize,
) -> Result<()> {
    // Two read-write storage buffers plus one uniform params buffer.
    let layout = pipeline_cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 2,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let shader = pipeline_cache.get_or_create_module("copy_complex", COPY_COMPLEX_SHADER);
    let pipeline =
        pipeline_cache.get_or_create_pipeline("copy_complex", "copy_complex", &shader, &layout);
    let bindings = pipeline_cache.create_bind_group(&layout, &[input, output, params]);
    let device = pipeline_cache.device();
    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("copy_complex_encoder"),
    });
    {
        let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("copy_complex_pass"),
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&pipeline);
        compute_pass.set_bind_group(0, Some(&bindings), &[]);
        compute_pass.dispatch_workgroups(workgroup_count(n), 1, 1);
    }
    queue.submit([encoder.finish()]);
    Ok(())
}