scirs2-fft 0.4.3

Fast Fourier Transform module for SciRS2 (scirs2-fft)
Documentation
// Cooley-Tukey radix-2 DIT FFT compute shader.
//
// One thread handles one butterfly pair.  The host dispatches this shader
// once per stage (0 to log2(n)−1), updating the `params.stage` uniform
// between dispatches.
//
// Complex numbers are stored as vec2<f32> in a storage buffer:
//   data[i].x = real part
//   data[i].y = imaginary part
//
// Twiddle factor for butterfly k in stage s:
//   W = exp(sign * 2π * i * k / n)
//   where sign = +1 for inverse FFT, −1 for forward FFT.
//
// Workgroup size: 64 threads. Host must dispatch ceil(n/2 / 64) workgroups.

// Per-dispatch parameters for one butterfly stage.
//
// Four u32 fields = 16 bytes.  The field order and types form the uniform
// buffer layout, so the host-side struct must match it exactly.  Per the
// shader header, the host updates `stage` between dispatches.
struct FFTParams {
    /// Total number of complex samples (must be a power of two).
    /// `main` derives the butterfly count as n / 2.
    n: u32,
    /// Current butterfly stage index (0 = first stage, log2(n)−1 = last).
    /// `main` uses it as a shift amount: stride = 1 << stage.
    stage: u32,
    /// 0 → forward FFT (twiddle sign = −1)
    /// 1 → inverse FFT (twiddle sign = +1)
    /// `main` treats any non-zero value as "inverse".
    inverse: u32,
    /// Padding to satisfy 16-byte alignment of the uniform buffer.
    /// Never read by the shader.
    _pad: u32,
}

// Complex sample buffer, transformed in place: each element is (re, im) packed
// as vec2<f32>.  `main` both reads and writes it, hence `read_write`.
@group(0) @binding(0) var<storage, read_write> data: array<vec2<f32>>;
// Stage parameters (see FFTParams); read-only from the shader's point of view.
@group(0) @binding(1) var<uniform> params: FFTParams;

// Entry point: performs ONE radix-2 decimation-in-time butterfly for the stage
// selected by `params.stage`.
//
// Thread mapping: invocation `gid.x` = k handles butterfly k of the n/2
// butterflies in the current stage.  With stride = 2^stage, the butterflies
// are organised in groups of `stride` butterflies; each group spans
// 2 * stride consecutive elements of `data`:
//
//   group = k / stride          pos = k % stride
//   i     = group * 2*stride + pos      (upper element)
//   j     = i + stride                  (lower element)
//
// Twiddle factor: W = exp(sign * 2πi * pos / (2 * stride)), i.e. it depends on
// the position INSIDE the group, not on the group index.  Every group in a
// stage therefore uses the same set of twiddles, and stage 0 (stride = 1,
// pos = 0) uses W = 1 for all butterflies.
//
// NOTE(review): this stage ordering (stride 1, 2, 4, …) yields a natural-order
// DFT only if the input was permuted into bit-reversed order before stage 0 —
// presumably done by the host; confirm against the dispatch code.
// NOTE(review): no 1/n scaling is applied for the inverse transform here;
// presumably the host normalises — confirm.
//
// Distinct invocations touch disjoint (i, j) pairs within a stage, so the
// in-place update needs no synchronisation inside a single dispatch.
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let k = gid.x;
    let n = params.n;
    let half_n = n >> 1u;

    // Each thread handles exactly one butterfly pair.  The dispatch is rounded
    // up to a multiple of the workgroup size, so surplus threads exit here.
    if k >= half_n {
        return;
    }

    let stage  = params.stage;
    let stride = 1u << stage;           // distance between the two elements of a butterfly
    let span   = stride << 1u;          // number of elements covered by one butterfly group
    let group  = k / stride;            // which butterfly group this thread belongs to
    let pos    = k % stride;            // position within the group (selects the twiddle)

    // Indices of the butterfly pair.
    let i = group * span + pos;
    let j = i + stride;

    // Defensive guard: for a valid stage (stride <= n/2) j is always < n.  If
    // the host passes a stage that is too large, skip instead of touching
    // elements outside the logical transform.
    if j >= n {
        return;
    }

    // Twiddle factor W = exp(sign * 2πi * pos / span):
    //   forward:  sign = −1  (negative exponent)
    //   inverse:  sign = +1  (positive exponent)
    // Dividing by `span` rather than scaling the index up to n keeps the
    // integer part small and avoids an extra multiply.
    let sign  = select(-1.0, 1.0, params.inverse != 0u);
    let angle = sign * 6.283185307179586 * f32(pos) / f32(span);

    let tw = vec2<f32>(cos(angle), sin(angle));

    let a = data[i];
    let b = data[j];

    // Complex multiply: bt = b * tw
    let bt = vec2<f32>(
        b.x * tw.x - b.y * tw.y,
        b.x * tw.y + b.y * tw.x,
    );

    // Butterfly output:
    //   data[i] = a + bt
    //   data[j] = a - bt
    data[i] = a + bt;
    data[j] = a - bt;
}