wgsl-fft 0.4.4

GPU-accelerated FFT using WebGPU compute shaders
Documentation
use std::any::Any;

use crate::error::Result;
use crate::{FftExecutor, GpuFft};
use num_complex::Complex;

/// WGSL compute shader for a batched Stockham FFT, invoked once per radix-2 "stage".
///
/// Uniform `U.x` is the transform length `n` and `U.y` is the current stage index.
/// NOTE(review): the odd-stage copy below implies the host dispatches stages
/// `0..log2(n)` and swaps `SRC`/`DST` between them; confirm against `GpuFft`.
///
/// * Odd stages copy `SRC` to `DST` unchanged, so the buffer parity matches a
///   radix-2 stage count.
/// * Even stages perform one radix-4 Stockham pass with span `p = 4^(stage / 2)`.
/// * When `log2(n)` is odd, the last (even) stage would need `4p > n`. It instead
///   performs one radix-2 Stockham pass with `p = n / 2`. Without it, odd
///   power-of-two sizes (including `n == 2`) left `DST` wrong or unwritten.
///
/// Complex values are interleaved `(re, im)` `f32` pairs, and batch `gid.y` starts at
/// float offset `gid.y * n * 2`. The twiddle index stride `n / (4p)` implies
/// `TWIDDLE[m]` holds the m-th power of the n-th root of unity. The sign convention
/// comes from the host-built table; verify there.
const RADIX4_WGSL: &str = r#"
@group(0) @binding(0)
var<uniform> U: vec4<u32>;
@group(0) @binding(1)
var<storage, read_write> SRC: array<f32>;
@group(0) @binding(2)
var<storage, read_write> DST: array<f32>;
@group(0) @binding(3)
var<storage, read> TWIDDLE: array<f32>;

@compute @workgroup_size(256, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let tid = gid.x;
    let batch_id = gid.y;
    let n = U.x;
    let stage = U.y;
    let batch_offset = batch_id * n * 2u;

    // Odd stages are identity passes: copy SRC to DST so the ping-pong parity
    // matches the radix-2 stage count. Each thread copies elements tid and tid + n/2.
    if stage % 2u != 0u {
        let idx1 = tid;
        let idx2 = tid + (n >> 1u);

        if idx1 < n {
            let s1 = batch_offset + 2u * idx1;
            DST[s1] = SRC[s1];
            DST[s1 + 1u] = SRC[s1 + 1u];
        }
        if idx2 < n {
            let s2 = batch_offset + 2u * idx2;
            DST[s2] = SRC[s2];
            DST[s2 + 1u] = SRC[s2 + 1u];
        }
        return;
    }

    let r4_stage = stage / 2u;
    let p = 1u << (r4_stage + r4_stage);
    let four_p = p << 2u;

    // Final stage when log2(n) is odd: a radix-4 pass would need 4p > n.
    // Finish with one radix-2 Stockham pass instead, where p == n / 2
    // (so j == 0 and k == tid). Outputs land at the same positions as inputs.
    if four_p > n {
        let half_n = n >> 1u;
        if tid >= half_n {
            return;
        }

        let s_lo = batch_offset + 2u * tid;
        let s_hi = batch_offset + 2u * (tid + half_n);

        let lo_r = SRC[s_lo];
        let lo_i = SRC[s_lo + 1u];
        let hi_r = SRC[s_hi];
        let hi_i = SRC[s_hi + 1u];

        // Twiddle for span 2p == n is table entry k == tid.
        let wr = TWIDDLE[2u * tid];
        let wi = TWIDDLE[2u * tid + 1u];
        let tr = wr * hi_r - wi * hi_i;
        let tim = wr * hi_i + wi * hi_r;

        DST[s_lo] = lo_r + tr;
        DST[s_lo + 1u] = lo_i + tim;
        DST[s_hi] = lo_r - tr;
        DST[s_hi + 1u] = lo_i - tim;
        return;
    }

    let quarter_n = n >> 2u;
    if tid >= quarter_n {
        return;
    }

    let k = tid % p;
    let j = tid / p;

    let i0 = j * p + k;
    let i1 = i0 + quarter_n;
    let i2 = i0 + quarter_n + quarter_n;
    let i3 = i2 + quarter_n;

    let s0 = batch_offset + 2u * i0;
    let s1 = batch_offset + 2u * i1;
    let s2 = batch_offset + 2u * i2;
    let s3 = batch_offset + 2u * i3;

    let x0r = SRC[s0];
    let x0i = SRC[s0 + 1u];
    let x1r = SRC[s1];
    let x1i = SRC[s1 + 1u];
    let x2r = SRC[s2];
    let x2i = SRC[s2 + 1u];
    let x3r = SRC[s3];
    let x3i = SRC[s3 + 1u];

    // Twiddle table stride for span 4p is n / (4p).
    let stride = quarter_n >> (r4_stage + r4_stage);
    let tw1 = k * stride;
    let tw2 = tw1 * 2u;
    let tw3 = tw1 * 3u;

    let wr1 = TWIDDLE[2u * tw1];
    let wi1 = TWIDDLE[2u * tw1 + 1u];
    let wr2 = TWIDDLE[2u * tw2];
    let wi2 = TWIDDLE[2u * tw2 + 1u];
    let wr3 = TWIDDLE[2u * tw3];
    let wi3 = TWIDDLE[2u * tw3 + 1u];

    let br = wr1 * x1r - wi1 * x1i;
    let bi = wr1 * x1i + wi1 * x1r;
    let cr = wr2 * x2r - wi2 * x2i;
    let ci = wr2 * x2i + wi2 * x2r;
    let dr = wr3 * x3r - wi3 * x3i;
    let di = wr3 * x3i + wi3 * x3r;

    let o0 = j * four_p + k;
    let o1 = o0 + p;
    let o2 = o0 + p + p;
    let o3 = o2 + p;

    let d0 = batch_offset + 2u * o0;
    let d1 = batch_offset + 2u * o1;
    let d2 = batch_offset + 2u * o2;
    let d3 = batch_offset + 2u * o3;

    DST[d0] = x0r + br + cr + dr;
    DST[d0 + 1u] = x0i + bi + ci + di;
    DST[d1] = x0r + bi - cr - di;
    DST[d1 + 1u] = x0i - br - ci + dr;
    DST[d2] = x0r - br + cr - dr;
    DST[d2 + 1u] = x0i - bi + ci - di;
    DST[d3] = x0r - bi - cr + di;
    DST[d3 + 1u] = x0i + br - ci - dr;
}
"#;

pub struct Radix4Rival(pub GpuFft);

impl Radix4Rival {
    pub fn new() -> Self {
        Self(GpuFft::with_shader(RADIX4_WGSL.to_string(), "radix4_rival").unwrap())
    }
}

impl FftExecutor for Radix4Rival {
    /// Display label for this executor.
    fn name(&self) -> &str {
        "Radix-4 Rival"
    }

    /// Forward transform of `inputs`, delegated to the wrapped [`GpuFft`].
    fn fft(&self, inputs: &[Vec<Complex<f32>>]) -> Result<Vec<Vec<Complex<f32>>>> {
        let Self(gpu) = self;
        gpu.fft(inputs)
    }

    /// Inverse transform of `inputs`, delegated to the wrapped [`GpuFft`].
    fn ifft(&self, inputs: &[Vec<Complex<f32>>]) -> Result<Vec<Vec<Complex<f32>>>> {
        let Self(gpu) = self;
        gpu.ifft(inputs)
    }

    /// Exposes `self` as `&dyn Any` so callers can downcast to [`Radix4Rival`].
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }
}

/// Every GPU-level hook forwards unchanged to the wrapped [`GpuFft`].
impl crate::GpuFftTrait for Radix4Rival {
    /// Times GPU-side execution via the wrapped [`GpuFft`], passing all arguments through.
    fn benchmark_gpu_only(
        &self,
        sc: &crate::SizeCache,
        batch_size: u32,
        n: usize,
        warmup_iters: usize,
        bench_iters: usize,
    ) -> Result<f64> {
        let Self(gpu) = self;
        gpu.benchmark_gpu_only(sc, batch_size, n, warmup_iters, bench_iters)
    }

    /// Returns the wrapped [`GpuFft`]'s size cache for length `n` (`log_n` passed through).
    fn get_or_build_size_cache(&self, n: usize, log_n: u32) -> crate::SizeCache {
        let Self(gpu) = self;
        gpu.get_or_build_size_cache(n, log_n)
    }

    /// Converts `input` to the flat `f32` layout produced by the wrapped [`GpuFft`].
    fn prepare_input_data(&self, input: &[Complex<f32>], inverse: bool) -> Vec<f32> {
        let Self(gpu) = self;
        gpu.prepare_input_data(input, inverse)
    }

    /// Borrows the command queue owned by the wrapped [`GpuFft`].
    fn queue(&self) -> &wgpu::Queue {
        let Self(gpu) = self;
        gpu.queue()
    }
}