webgpu-groth16 0.1.1

Groth16 GPU prover aimed primarily at browser environments
Documentation
// src/shader/bls12_381/ntt.wgsl

// ============================================================================
// CONSTANTS & WORKGROUP SIZE
// ============================================================================
const THREADS_PER_WORKGROUP: u32 = 256u;
const ELEMENTS_PER_TILE: u32 = 512u;

// ============================================================================
// GPU PIPELINE BUFFERS
// ============================================================================

@group(0) @binding(0)
var<storage, read_write> data: array<U256>;

@group(0) @binding(1)
var<storage, read> twiddles: array<U256>;

struct NttParams {
    n: u32,
    half_len: u32,
    log_n: u32,
    _pad: u32,
}

@group(0) @binding(2)
var<uniform> params: NttParams;

@group(2) @binding(0)
var<storage, read> pointwise_a: array<U256>;

@group(2) @binding(1)
var<storage, read> pointwise_b: array<U256>;

@group(2) @binding(2)
var<storage, read> pointwise_c: array<U256>;

@group(2) @binding(3)
var<storage, read> pointwise_z_inv: array<U256>;

// Fast workgroup-shared memory arrays
var<workgroup> shared_data: array<U256, ELEMENTS_PER_TILE>;
var<workgroup> shared_twiddles: array<U256, THREADS_PER_WORKGROUP>;

fn add_fr(a: U256, b: U256) -> U256 {
    var sum = add_u256(a, b);
    var is_gte = true;
    for (var i: u32 = 7u; i < 8u; i = i - 1u) {
        if sum.limbs[i] > R_MODULUS[i] { break; }
        if sum.limbs[i] < R_MODULUS[i] { is_gte = false; break; }
        if i == 0u { break; }
    }
    if is_gte {
        sum = sub_u256(sum, U256(R_MODULUS));
    }
    return sum;
}

fn sub_fr(a: U256, b: U256) -> U256 {
    var is_less = false;
    for (var i: u32 = 7u; i < 8u; i = i - 1u) {
        if a.limbs[i] < b.limbs[i] { is_less = true; break; }
        if a.limbs[i] > b.limbs[i] { break; }
        if i == 0u { break; }
    }

    var diff = sub_u256(a, b);
    if is_less {
        diff = add_u256(diff, U256(R_MODULUS));
    }
    return diff;
}

// ============================================================================
// HELPER: BIT REVERSAL
// ============================================================================

// Reverses the `bit_len` lower bits of `n`.
fn reverse_bits(n: u32, bit_len: u32) -> u32 {
    return reverseBits(n) >> (32u - bit_len);
}

fn compute_pointwise_h(i: u32) -> U256 {
    let a = pointwise_a[i];
    let b = pointwise_b[i];
    let c = pointwise_c[i];
    let ab = mul_montgomery_u256(a, b);
    let ab_c = sub_fr(ab, c);
    return mul_montgomery_u256(ab_c, pointwise_z_inv[0u]);
}

@compute @workgroup_size(256)
fn bitreverse_inplace(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;
    let n = params.n;
    if i >= n {
        return;
    }

    let j = reverse_bits(i, params.log_n);
    if i < j {
        let tmp = data[i];
        data[i] = data[j];
        data[j] = tmp;
    }
}

@compute @workgroup_size(256)
fn bitreverse_fused_pointwise(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;
    let n = params.n;
    if i >= n {
        return;
    }

    let j = reverse_bits(i, params.log_n);
    data[j] = compute_pointwise_h(i);
}

@compute @workgroup_size(256)
fn ntt_global_stage(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let tid = global_id.x;
    let n = params.n;
    let half_len = params.half_len;
    if half_len == 0u {
        return;
    }

    let butterflies = n / 2u;
    if tid >= butterflies {
        return;
    }

    let len = half_len * 2u;
    let k = tid % half_len;
    let pos = (tid / half_len) * len + k;
    let twiddle_stride = n / len;
    let twiddle = twiddles[k * twiddle_stride];

    let u = data[pos];
    let v = data[pos + half_len];
    let v_omega = mul_montgomery_u256(v, twiddle);

    data[pos] = add_fr(u, v_omega);
    data[pos + half_len] = sub_fr(u, v_omega);
}

@compute @workgroup_size(256)
fn ntt_global_stage_dif(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let tid = global_id.x;
    let n = params.n;
    let half_len = params.half_len;
    if half_len == 0u {
        return;
    }

    let butterflies = n / 2u;
    if tid >= butterflies {
        return;
    }

    let len = half_len * 2u;
    let k = tid % half_len;
    let pos = (tid / half_len) * len + k;
    let twiddle_stride = n / len;
    let twiddle = twiddles[k * twiddle_stride];

    let u = data[pos];
    let v = data[pos + half_len];
    let sum = add_fr(u, v);
    let diff = sub_fr(u, v);

    data[pos] = sum;
    data[pos + half_len] = mul_montgomery_u256(diff, twiddle);
}

@compute @workgroup_size(256)
fn ntt_global_stage_dif_fused_pointwise(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let tid = global_id.x;
    let n = params.n;
    let half_len = params.half_len;
    if half_len == 0u {
        return;
    }

    let butterflies = n / 2u;
    if tid >= butterflies {
        return;
    }

    let len = half_len * 2u;
    let k = tid % half_len;
    let pos = (tid / half_len) * len + k;
    let twiddle_stride = n / len;
    let twiddle = twiddles[k * twiddle_stride];

    let u = compute_pointwise_h(pos);
    let v = compute_pointwise_h(pos + half_len);
    let sum = add_fr(u, v);
    let diff = sub_fr(u, v);

    data[pos] = sum;
    data[pos + half_len] = mul_montgomery_u256(diff, twiddle);
}

@compute @workgroup_size(256)
fn ntt_global_stage_radix4(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let tid = global_id.x;
    let n = params.n;
    let half_len = params.half_len;
    if half_len == 0u {
        return;
    }

    let butterflies = n / 4u;
    if tid >= butterflies {
        return;
    }

    let len4 = half_len * 4u;
    let k = tid % half_len;
    let pos = (tid / half_len) * len4 + k;

    let tw1_stride = n / (half_len * 2u);
    let tw2_stride = n / len4;

    let w1 = twiddles[k * tw1_stride];
    let w2 = twiddles[k * tw2_stride];
    let w3 = twiddles[(k + half_len) * tw2_stride];

    let x0 = data[pos];
    let x1 = data[pos + half_len];
    let x2 = data[pos + (2u * half_len)];
    let x3 = data[pos + (3u * half_len)];

    let x1w = mul_montgomery_u256(x1, w1);
    let x3w = mul_montgomery_u256(x3, w1);

    let a0 = add_fr(x0, x1w);
    let a1 = sub_fr(x0, x1w);
    let a2 = add_fr(x2, x3w);
    let a3 = sub_fr(x2, x3w);

    let a2w = mul_montgomery_u256(a2, w2);
    let a3w = mul_montgomery_u256(a3, w3);

    data[pos] = add_fr(a0, a2w);
    data[pos + half_len] = add_fr(a1, a3w);
    data[pos + (2u * half_len)] = sub_fr(a0, a2w);
    data[pos + (3u * half_len)] = sub_fr(a1, a3w);
}

// ============================================================================
// NTT TILE EXECUTION (Cooley-Tukey Radix-2)
// ============================================================================

@compute @workgroup_size(THREADS_PER_WORKGROUP)
fn ntt_tile(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>
) {
    let tile_offset = group_id.x * ELEMENTS_PER_TILE;
    let n_total = arrayLength(&twiddles);
    let n = min(ELEMENTS_PER_TILE, n_total - tile_offset);
    if n == 0u { return; }
    if n > ELEMENTS_PER_TILE { return; }

    var log2_elements = 0u;
    var m = n;
    while m > 1u {
        m = m >> 1u;
        log2_elements = log2_elements + 1u;
    }

    let local_idx_1 = local_id.x;
    let local_idx_2 = local_id.x + THREADS_PER_WORKGROUP;

    var load_idx_1 = local_idx_1;
    var load_idx_2 = local_idx_2;
    if n_total <= ELEMENTS_PER_TILE {
        load_idx_1 = reverse_bits(local_idx_1, log2_elements);
        load_idx_2 = reverse_bits(local_idx_2, log2_elements);
    }

    if local_idx_1 < n {
        shared_data[load_idx_1] = data[tile_offset + local_idx_1];
    }
    if local_idx_2 < n {
        shared_data[load_idx_2] = data[tile_offset + local_idx_2];
    }

    // The twiddles needed for all stages in this tile are subsets of the 
    // twiddles needed for the final stage. We load them exactly once.
    let twiddle_base_stride = n_total / n;
    if local_id.x < n / 2u {
        shared_twiddles[local_id.x] = twiddles[local_id.x * twiddle_base_stride];
    }

    workgroupBarrier();

    var half_len: u32 = 1u;
    for (var stage: u32 = 0u; stage < log2_elements; stage = stage + 1u) {
        let len = half_len * 2u;

        let butterfly_count = n / 2u;
        if local_id.x < butterfly_count {
            let k = local_id.x % half_len;
            let pos = (local_id.x / half_len) * len + k;

            // Fetch from shared memory using a localized stride!
            let shared_twiddle_stride = n / len;
            let twiddle = shared_twiddles[k * shared_twiddle_stride];

            let u = shared_data[pos];
            let v = shared_data[pos + half_len];

            let v_omega = mul_montgomery_u256(v, twiddle);
            shared_data[pos] = add_fr(u, v_omega);
            shared_data[pos + half_len] = sub_fr(u, v_omega);
        }

        half_len = len;
        workgroupBarrier();
    }

    if local_idx_1 < n {
        data[tile_offset + local_idx_1] = shared_data[local_idx_1];
    }
    if local_idx_2 < n {
        data[tile_offset + local_idx_2] = shared_data[local_idx_2];
    }
}