webgpu-groth16 0.1.1

Groth16 GPU prover aimed primarily at browser environments
Documentation
// src/shader/bls12_381/poly_ops.wgsl

// Subtraction helper for F_r: returns `a - b`, adding the modulus back when
// `a` is numerically smaller than `b`.
//
// The operands are compared limb by limb from index 7 down to index 0; the
// first pair of limbs that differ decides whether a borrow is needed. The raw
// difference comes from `sub_u256`, and `R_MODULUS` is added with `add_u256`
// only in the borrow case.
// NOTE(review): assumes both inputs are already reduced below the modulus and
// that sub_u256 / add_u256 wrap modulo 2^256 — confirm against their definitions.
fn sub_fr(a: U256, b: U256) -> U256 {
    var borrow = false;
    for (var step = 0u; step < 8u; step = step + 1u) {
        // Walk from limb 7 towards limb 0.
        let limb = 7u - step;
        let lhs = a.limbs[limb];
        let rhs = b.limbs[limb];
        if lhs != rhs {
            borrow = lhs < rhs;
            break;
        }
    }

    let raw = sub_u256(a, b);
    if !borrow {
        return raw;
    }
    return add_u256(raw, U256(R_MODULUS));
}

// ============================================================================
// COSET SHIFT PIPELINE
// ============================================================================
// Buffer rewritten in place by `coset_shift`: element i is replaced with
// mul_montgomery_u256(shift_data[i], shift_factors[i]).
@group(0) @binding(0) var<storage, read_write> shift_data: array<U256>;
// Read-only multipliers; element i is paired with shift_data[i].
@group(0) @binding(1) var<storage, read> shift_factors: array<U256>;

// Entry point: multiplies each element of `shift_data` in place by the element
// of `shift_factors` at the same index, using `mul_montgomery_u256`.
//
// One invocation handles one element (index = global_invocation_id.x).
// Invocations past the end of EITHER buffer return without touching memory:
// an out-of-bounds storage write is not guaranteed to be discarded and may be
// redirected to an in-bounds element, so the index is checked against both
// buffers rather than against `shift_factors` alone.
@compute @workgroup_size(256)
fn coset_shift(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;
    let count = min(arrayLength(&shift_data), arrayLength(&shift_factors));
    if i >= count { return; }

    shift_data[i] = mul_montgomery_u256(shift_data[i], shift_factors[i]);
}

// ============================================================================
// POINTWISE POLYNOMIAL MATH PIPELINE
// H[i] = (A[i] * B[i] - C[i]) * Z_invs[0]   (Z_invs[0] presumably holds 1 / Z_H on the coset)
// ============================================================================
// Resources for the `pointwise_poly` entry point. The group 0 slots used here
// are also declared for other entry points in this file, so each pipeline
// presumably binds only the resources its own entry point references.
// A, B and C are read-only inputs indexed by the invocation id.
@group(0) @binding(0) var<storage, read> A: array<U256>;
@group(0) @binding(1) var<storage, read> B: array<U256>;
@group(0) @binding(2) var<storage, read> C: array<U256>;
// Output buffer; element i is written by invocation i.
@group(0) @binding(3) var<storage, read_write> H: array<U256>;
// Only element 0 is read by `pointwise_poly`.
@group(0) @binding(4) var<storage, read> Z_invs: array<U256>; // [z_inv_on_coset]

// Entry point: for each index i computes
//     H[i] = mul_montgomery(sub_fr(mul_montgomery(A[i], B[i]), C[i]), Z_invs[0])
// i.e. (A[i] * B[i] - C[i]) scaled by the single value stored in Z_invs[0].
//
// One invocation handles one element (index = global_invocation_id.x).
// The index is checked against A, B, C and H together: all four are indexed
// with the same `i`, and an out-of-bounds storage write to H is not guaranteed
// to be discarded (it may be redirected to an in-bounds element), so guarding
// on A alone is not sufficient when the buffers differ in length.
@compute @workgroup_size(256)
fn pointwise_poly(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;
    let input_count = min(arrayLength(&A), min(arrayLength(&B), arrayLength(&C)));
    let count = min(input_count, arrayLength(&H));
    if i >= count { return; }

    let a = A[i];
    let b = B[i];
    let c = C[i];

    // Numerator: A[i] * B[i] - C[i] in F_r.
    let ab = mul_montgomery_u256(a, b);
    let res = sub_fr(ab, c);

    // Every element is scaled by the same factor, Z_invs[0].
    H[i] = mul_montgomery_u256(res, Z_invs[0]);
}

// ============================================================================
// MONTGOMERY DOMAIN BRIDGES
// ============================================================================

// R^2 mod r for the BLS12-381 scalar field (F_r).
// Required to convert Standard Form scalars into Montgomery Form:
// `to_montgomery_u256` multiplies its input by this constant.
// Limbs are listed least-significant first (limb 0 first), consistent with how
// this file builds the value 1 and compares limbs from index 7 downward.
// NOTE(review): R is presumably 2^256 (eight 32-bit limbs) — confirm against
// mul_montgomery_u256.
const R2_MOD_R = U256(array<u32, 8>(
    0xf3f29c6du, 0xc999e990u, 0x87925c23u, 0x2b6cedcbu,
    0x7254398fu, 0x05d31496u, 0x9f59ff11u, 0x0748d9d9u
));

// Converts `a` into Montgomery form by taking its Montgomery product with
// R2_MOD_R (R^2 mod r).
fn to_montgomery_u256(a: U256) -> U256 {
    let lifted = mul_montgomery_u256(a, R2_MOD_R);
    return lifted;
}

// Converts `a` out of Montgomery form by taking its Montgomery product with
// the plain integer 1 (limb 0 set to 1, all other limbs zero).
fn from_montgomery_u256(a: U256) -> U256 {
    return mul_montgomery_u256(
        a,
        U256(array<u32, 8>(1u, 0u, 0u, 0u, 0u, 0u, 0u, 0u))
    );
}

// Buffer converted in place, element by element, by `to_montgomery_array` and
// `from_montgomery_array`.
@group(0) @binding(0) var<storage, read_write> mont_buf: array<U256>;

// Entry point: replaces every element of `mont_buf` with its Montgomery-form
// equivalent via `to_montgomery_u256`. One invocation per element; invocations
// whose index is at or beyond the buffer length do nothing.
@compute @workgroup_size(256)
fn to_montgomery_array(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if idx < arrayLength(&mont_buf) {
        let standard = mont_buf[idx];
        mont_buf[idx] = to_montgomery_u256(standard);
    }
}

// Entry point: replaces every element of `mont_buf` with the result of
// `from_montgomery_u256`, i.e. converts it out of Montgomery form. One
// invocation per element; invocations whose index is at or beyond the buffer
// length do nothing.
@compute @workgroup_size(256)
fn from_montgomery_array(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if idx < arrayLength(&mont_buf) {
        let encoded = mont_buf[idx];
        mont_buf[idx] = from_montgomery_u256(encoded);
    }
}