numr 0.5.2

High-performance numerical computing with multi-backend acceleration (CPU, CUDA, WebGPU)
Documentation
// Stockham FFT shader for WebGPU
// Complex numbers as vec2<f32> (re, im)

// Pi as an f32 constant; the extra literal digits are rounded to the nearest f32.
const PI: f32 = 3.14159265358979323846;
// Threads per workgroup for every entry point in this file. stockham_fft_small
// also uses it as the stride of its per-thread element loops.
const WORKGROUP_SIZE: u32 = 256u;

// Per-dispatch parameters, bound as a uniform (see fft_params).
// Eight 4-byte fields = 32 bytes. The _pad fields are never read; presumably they
// pad the struct to match the host-side layout. Confirm against the host code
// before reordering or removing anything.
struct FftParams {
    // Transform length: elements per batch row in the FFT kernels, and the
    // element-count bound in scale_complex.
    n: u32,
    // stockham_fft_small: number of radix-2 stages to run (expected to be log2(n)).
    // stockham_fft_stage: reused by the host as the index of the single stage to run.
    log_n: u32,
    // Non-zero selects a +1 twiddle sign (inverse transform); zero selects -1.
    inverse: i32,
    // Multiplier applied to outputs by stockham_fft_small and scale_complex.
    scale: f32,
    // NOTE(review): not read by any kernel visible in this file.
    batch_size: u32,
    _pad1: u32,
    _pad2: u32,
    _pad3: u32,
}

// Complex input samples (x = real, y = imag). The FFT kernels treat it as
// consecutive batch rows of fft_params.n elements. It is declared read_write,
// but the kernels in this file only read from it.
@group(0) @binding(0) var<storage, read_write> fft_input: array<vec2<f32>>;
// Complex output buffer, indexed with the same row layout as fft_input.
@group(0) @binding(1) var<storage, read_write> fft_output: array<vec2<f32>>;
// Parameters for the current dispatch.
@group(0) @binding(2) var<uniform> fft_params: FftParams;

// Workgroup shared memory for ping-pong: stockham_fft_small reads one buffer and
// writes the other on each stage. The fixed 256-element size limits that kernel
// to n <= 256; larger n would index past the end of these arrays.
var<workgroup> smem_a: array<vec2<f32>, 256>;
var<workgroup> smem_b: array<vec2<f32>, 256>;

// Complex number helpers (vec2: x=real, y=imag)
// Complex product a * b, where x is the real part and y the imaginary part:
// (a.x + i a.y)(b.x + i b.y) = (a.x b.x - a.y b.y) + i (a.x b.y + a.y b.x)
fn cmul(a: vec2<f32>, b: vec2<f32>) -> vec2<f32> {
    let real_part = a.x * b.x - a.y * b.y;
    let imag_part = a.x * b.y + a.y * b.x;
    return vec2<f32>(real_part, imag_part);
}

// Complex sum a + b (real and imaginary parts add independently).
fn cadd(a: vec2<f32>, b: vec2<f32>) -> vec2<f32> {
    return vec2<f32>(a.x + b.x, a.y + b.y);
}

// Complex difference a - b (real and imaginary parts subtract independently).
fn csub(a: vec2<f32>, b: vec2<f32>) -> vec2<f32> {
    return vec2<f32>(a.x - b.x, a.y - b.y);
}

// Multiply both components of complex value a by the real scalar s.
fn cscale(a: vec2<f32>, s: f32) -> vec2<f32> {
    // vec2 * scalar is component-wise in WGSL.
    return a * s;
}

// Complex conjugate: keep the real part, negate the imaginary part.
fn cconj(a: vec2<f32>) -> vec2<f32> {
    var flipped = a;
    flipped.y = -flipped.y;
    return flipped;
}

// Unit complex number at angle theta (radians):
// e^(i*theta) = cos(theta) + i*sin(theta), returned as vec2(real, imag).
fn cexp_i(theta: f32) -> vec2<f32> {
    let re = cos(theta);
    let im = sin(theta);
    return vec2<f32>(re, im);
}

// Radix-2 Stockham FFT computed entirely inside one workgroup.
//
// Each workgroup transforms one batch row of fft_params.n complex values:
// - The row is selected by workgroup_id.x and starts at workgroup_id.x * n.
// - fft_params.log_n stages are run, ping-ponging between smem_a and smem_b.
// - Twiddles use exponent sign -1 when fft_params.inverse == 0, else +1.
// - Every output element is multiplied by fft_params.scale.
//
// Requires n <= 256 (the length of smem_a/smem_b). NOTE(review): nothing here
// enforces that. Larger n would index past the workgroup arrays and give wrong
// results. Presumably the host only selects this kernel for small n; confirm.
@compute @workgroup_size(WORKGROUP_SIZE)
fn stockham_fft_small(
    @builtin(workgroup_id) wg_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>
) {
    let batch_idx = wg_id.x;
    let tid = local_id.x;
    let n = fft_params.n;
    let log_n = fft_params.log_n;
    let inverse = fft_params.inverse;
    let scale_factor = fft_params.scale;

    // Sign for twiddle factor: -1.0 for forward, 1.0 when inverse != 0
    let sign = select(-1.0, 1.0, inverse != 0);

    // Load input to shared memory; thread tid copies elements tid, tid + WORKGROUP_SIZE, ...
    let base_offset = batch_idx * n;
    for (var i = tid; i < n; i = i + WORKGROUP_SIZE) {
        smem_a[i] = fft_input[base_offset + i];
    }
    // The whole row must be loaded before any stage reads from smem_a.
    workgroupBarrier();

    // Perform Stockham FFT stages
    // use_a == true means smem_a holds this stage's input and smem_b receives its output.
    var use_a = true;
    for (var stage: u32 = 0u; stage < log_n; stage = stage + 1u) {
        let m = 1u << (stage + 1u);
        let half_m = 1u << stage;

        // n/2 butterflies per stage, distributed across threads with stride WORKGROUP_SIZE.
        for (var i = tid; i < n / 2u; i = i + WORKGROUP_SIZE) {
            let group = i / half_m;
            let pair = i % half_m;

            // Stockham indexing: read at i and i + n/2 ...
            let even_idx = group * half_m + pair;
            let odd_idx = even_idx + n / 2u;

            // ... and write to group*m + pair and group*m + pair + half_m.
            let out_even_idx = group * m + pair;
            let out_odd_idx = out_even_idx + half_m;

            // Twiddle factor e^(i * sign * 2*pi * pair / m)
            let theta = sign * 2.0 * PI * f32(pair) / f32(m);
            let twiddle = cexp_i(theta);

            var even_val: vec2<f32>;
            var odd_val: vec2<f32>;

            if (use_a) {
                even_val = smem_a[even_idx];
                odd_val = cmul(smem_a[odd_idx], twiddle);
            } else {
                even_val = smem_b[even_idx];
                odd_val = cmul(smem_b[odd_idx], twiddle);
            }

            // Radix-2 butterfly
            let sum = cadd(even_val, odd_val);
            let diff = csub(even_val, odd_val);

            if (use_a) {
                smem_b[out_even_idx] = sum;
                smem_b[out_odd_idx] = diff;
            } else {
                smem_a[out_even_idx] = sum;
                smem_a[out_odd_idx] = diff;
            }
        }

        // This stage's writes must be visible before the next stage reads them,
        // and all reads must finish before the next stage overwrites this input.
        workgroupBarrier();
        use_a = !use_a;
    }

    // Write output with scaling; after the loop use_a names the buffer holding the last stage's result.
    for (var i = tid; i < n; i = i + WORKGROUP_SIZE) {
        var result: vec2<f32>;
        if (use_a) {
            result = smem_a[i];
        } else {
            result = smem_b[i];
        }
        fft_output[base_offset + i] = cscale(result, scale_factor);
    }
}

// One radix-2 Stockham stage, for FFTs too large for stockham_fft_small.
//
// The host is expected to dispatch this once per stage, with:
// - fft_params.log_n holding the index of the stage to run;
// - global_invocation_id.x selecting the butterfly (0 .. n/2 - 1);
// - global_invocation_id.y selecting the batch row (rows are n elements apart).
// Reads fft_input and writes fft_output. No scaling is applied here (see scale_complex).
@compute @workgroup_size(WORKGROUP_SIZE)
fn stockham_fft_stage(
    @builtin(global_invocation_id) gid: vec3<u32>
) {
    let fft_len = fft_params.n;
    let half_len = fft_len / 2u;
    let butterfly = gid.x;

    // Surplus invocations in the last workgroup do nothing.
    if (butterfly >= half_len) {
        return;
    }

    // log_n is repurposed to carry the stage index for this dispatch.
    let cur_stage = fft_params.log_n;
    let half_span = 1u << cur_stage;
    let full_span = 1u << (cur_stage + 1u);

    let grp_idx = butterfly / half_span;
    let k = butterfly % half_span;
    let batch_base = gid.y * fft_len;

    // Stockham addressing: read at i and i + n/2; write to grp*span + k and + span/2.
    let src_lo = batch_base + grp_idx * half_span + k;
    let src_hi = src_lo + half_len;
    let dst_lo = batch_base + grp_idx * full_span + k;
    let dst_hi = dst_lo + half_span;

    // Twiddle e^(i * dir * 2*pi * k / span); dir = -1 forward, +1 inverse.
    let dir_sign = select(-1.0, 1.0, fft_params.inverse != 0);
    let w = cexp_i(dir_sign * 2.0 * PI * f32(k) / f32(full_span));

    let lo = fft_input[src_lo];
    let hi = cmul(fft_input[src_hi], w);

    fft_output[dst_lo] = cadd(lo, hi);
    fft_output[dst_hi] = csub(lo, hi);
}

// Multiply each of the first fft_params.n complex values of fft_input by
// fft_params.scale, writing the result to the same index of fft_output.
// There is one invocation per element, selected by global_invocation_id.x.
// NOTE(review): gid.y is ignored, so a batched buffer is fully covered only if
// the host sets n to the total element count. Confirm against the dispatch code.
@compute @workgroup_size(WORKGROUP_SIZE)
fn scale_complex(
    @builtin(global_invocation_id) gid: vec3<u32>
) {
    let elem = gid.x;
    if (elem >= fft_params.n) {
        return;
    }
    // vec2 * scalar scales real and imaginary parts alike.
    fft_output[elem] = fft_input[elem] * fft_params.scale;
}