webgpu-groth16 0.1.1

Groth16 GPU prover aimed primarily at browser environments
Documentation
// src/shader/bls12_381/msm_g2_agg.wgsl

const U384_ZERO = U384(array<u32, 30>(
    0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u,
    0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u,
    0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u
));

const FQ2_ZERO = Fq2(U384_ZERO, U384_ZERO);
const G2_INFINITY = PointG2(FQ2_ZERO, FQ2_ZERO, FQ2_ZERO);

const FQ2_ONE_MONT = Fq2(MONT_ONE, U384_ZERO);
const G2_PROJ_IDENTITY = PointG2(FQ2_ZERO, FQ2_ONE_MONT, FQ2_ZERO);

const R2_MOD_Q = U384(array<u32, 30>(
    0x070fu, 0x0880u, 0x10d1u, 0x0c83u, 0x1aecu, 0x1121u,
    0x004cu, 0x1874u, 0x066eu, 0x1b75u, 0x01ebu, 0x1beau,
    0x07b1u, 0x1f70u, 0x117bu, 0x0362u, 0x0ed2u, 0x090fu,
    0x110au, 0x1482u, 0x0f70u, 0x1699u, 0x05dcu, 0x1200u,
    0x0c97u, 0x0c8cu, 0x12b3u, 0x1dc0u, 0x1696u, 0x0007u
));

const MONT_ONE = U384(array<u32, 30>(
    0x1f2eu, 0x068fu, 0x0000u, 0x0c00u, 0x0467u, 0x0056u,
    0x0d20u, 0x06f3u, 0x1803u, 0x0425u, 0x10c7u, 0x1104u,
    0x1e0eu, 0x0cd3u, 0x0037u, 0x1b9fu, 0x1683u, 0x1685u,
    0x1b09u, 0x1d84u, 0x0a5eu, 0x11e2u, 0x15d9u, 0x1e28u,
    0x0b29u, 0x1402u, 0x1fcfu, 0x132cu, 0x15deu, 0x0000u
));

fn to_montgomery_u384(a: U384) -> U384 {
    return mul_montgomery_u384(a, R2_MOD_Q);
}

fn to_montgomery_fp2(a: Fq2) -> Fq2 { return Fq2(to_montgomery_u384(a.c0), to_montgomery_u384(a.c1)); }

fn is_inf_g2(p: PointG2) -> bool {
    for (var i = 0u; i < 30u; i = i + 1u) {
        if p.z.c0.limbs[i] != 0u || p.z.c1.limbs[i] != 0u { return false; }
    }
    return true;
}

fn add_g2_safe(p1: PointG2, p2: PointG2) -> PointG2 {
    let p1_inf = is_inf_g2(p1); let p2_inf = is_inf_g2(p2);
    if p1_inf && p2_inf { return G2_INFINITY; }
    if p1_inf { return p2; } if p2_inf { return p1; }
    return add_g2(p1, p2);
}

fn add_g2_mixed_safe(p1: PointG2, p2_affine: PointG2) -> PointG2 {
    if is_inf_g2(p1) { return p2_affine; }
    return add_g2_mixed(p1, p2_affine);
}

@group(0) @binding(0) var<storage, read_write> bases_preconv_g2: array<PointG2>;

@compute @workgroup_size(64)
fn to_montgomery_bases_g2(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;
    if i >= arrayLength(&bases_preconv_g2) { return; }
    let p = bases_preconv_g2[i];
    if is_inf_g2(p) { return; }
    bases_preconv_g2[i] = PointG2(
        to_montgomery_fp2(p.x),
        to_montgomery_fp2(p.y),
        to_montgomery_fp2(p.z)
    );
}

fn load_g2_mont(p: PointG2) -> PointG2 {
    if is_inf_g2(p) { return G2_INFINITY; }
    return p;
}

@group(0) @binding(0) var<storage, read> bases_g2: array<PointG2>;
@group(0) @binding(1) var<storage, read> base_indices_g2: array<u32>;
@group(0) @binding(2) var<storage, read> bucket_pointers_g2: array<u32>;
@group(0) @binding(3) var<storage, read> bucket_sizes_g2: array<u32>;
@group(0) @binding(4) var<storage, read_write> aggregated_buckets_g2: array<PointG2>;
@group(0) @binding(5) var<storage, read> bucket_values_agg_g2: array<u32>;

fn negate_g2(p: PointG2) -> PointG2 {
    return PointG2(
        p.x,
        Fq2(sub_u384(U384(Q_MODULUS), p.y.c0), sub_u384(U384(Q_MODULUS), p.y.c1)),
        p.z
    );
}

@compute @workgroup_size(64)
fn aggregate_buckets_g2(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let bucket_idx = global_id.x;
    if bucket_idx >= arrayLength(&bucket_pointers_g2) { return; }

    let start = bucket_pointers_g2[bucket_idx];
    let size = bucket_sizes_g2[bucket_idx];
    var sum = G2_INFINITY;
    for (var i = 0u; i < size; i = i + 1u) {
        let raw = base_indices_g2[start + i];
        let point_idx = raw & 0x7FFFFFFFu;
        let is_neg = (raw >> 31u) != 0u;
        var base = load_g2_mont(bases_g2[point_idx]);
        if is_neg {
            base = negate_g2(base);
        }
        sum = add_g2_mixed_safe(sum, base);
    }
    aggregated_buckets_g2[bucket_idx] = sum;
}

@group(0) @binding(0) var<storage, read> reduce_input_g2: array<PointG2>;
@group(0) @binding(1) var<storage, read> reduce_starts_g2: array<u32>;
@group(0) @binding(2) var<storage, read> reduce_counts_g2: array<u32>;
@group(0) @binding(3) var<storage, read_write> reduce_output_g2: array<PointG2>;

@compute @workgroup_size(64)
fn reduce_sub_buckets_g2(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let parent_idx = global_id.x;
    if parent_idx >= arrayLength(&reduce_starts_g2) { return; }

    let start = reduce_starts_g2[parent_idx];
    let count = reduce_counts_g2[parent_idx];

    if count == 1u {
        reduce_output_g2[parent_idx] = reduce_input_g2[start];
        return;
    }

    var sum = reduce_input_g2[start];
    for (var i = 1u; i < count; i = i + 1u) {
        sum = add_g2_safe(sum, reduce_input_g2[start + i]);
    }
    reduce_output_g2[parent_idx] = sum;
}

fn jacobian_to_proj_g2(p: PointG2) -> PointG2 {
    if is_inf_g2(p) { return G2_PROJ_IDENTITY; }
    let x_proj = mul_fp2(p.x, p.z);
    let z_sq = sqr_fp2(p.z);
    let z_proj = mul_fp2(z_sq, p.z);
    return PointG2(x_proj, p.y, z_proj);
}

fn scalar_mul_g2(p: PointG2, k: u32) -> PointG2 {
    if k == 0u { return G2_PROJ_IDENTITY; }
    if k == 1u { return p; }
    var result = G2_PROJ_IDENTITY;
    var base = p;
    var scalar = k;
    for (var bit = 0u; bit < 16u; bit = bit + 1u) {
        if scalar == 0u { break; }
        if (scalar & 1u) != 0u {
            result = add_g2_complete(result, base);
        }
        base = add_g2_complete(base, base);
        scalar = scalar >> 1u;
    }
    return result;
}

@group(0) @binding(0) var<storage, read_write> weight_buckets_g2_data: array<PointG2>;
@group(0) @binding(1) var<storage, read> weight_bucket_values_g2: array<u32>;

@compute @workgroup_size(64)
fn weight_buckets_g2(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;
    if i >= arrayLength(&weight_buckets_g2_data) { return; }
    let p = jacobian_to_proj_g2(weight_buckets_g2_data[i]);
    weight_buckets_g2_data[i] = scalar_mul_g2(p, weight_bucket_values_g2[i]);
}