//! Parallel Sumcheck Polynomial Computation for Binary Extension Fields
//!
//! This shader computes sumcheck polynomial contributions in parallel.
//! Each workgroup processes one opened row and produces its local basis polynomial.
//!
//! Algorithm (per row i):
//! 1. Compute tensorized dot product: dot = ⟨row, L(v_challenges)⟩
//! 2. Scale by alpha^i: contribution = dot * alpha^i
//! 3. Evaluate scaled basis at query point
//! 4. Output local_basis[i] for later reduction
//!
//! This replaces the CPU loop that processes 148+ rows sequentially.
// Import binary field operations from binary_field.wgsl
// (Concatenated at compile time by Rust)
//
// Sumcheck Parameters
//
struct SumcheckParams {
n: u32, // log size of basis polynomial (e.g., 10 for 2^10 = 1024)
num_queries: u32, // Number of opened rows (typically 148)
k: u32, // Number of v_challenges (row width in log space)
row_size: u32, // Actual row size = 2^k
}
//
// Input Buffers (read-only)
//
// opened_rows[query][element]: The opened merkle rows (num_queries x row_size)
@group(0) @binding(0) var<storage, read> opened_rows: array<vec4<u32>>;
// v_challenges[i]: The verifier challenges (k elements)
@group(0) @binding(1) var<storage, read> v_challenges: array<vec4<u32>>;
// alpha_pows[i]: Precomputed powers of alpha (alpha^0, alpha^1, ..., alpha^(n-1))
@group(0) @binding(2) var<storage, read> alpha_pows: array<vec4<u32>>;
// DEBUG: Store raw dot products before scaling by alpha (reusing binding 3)
@group(0) @binding(3) var<storage, read_write> debug_dots: array<vec4<u32>>;
// sorted_queries[i]: Precomputed basis array indices (NOT query values!)
// These are computed on CPU by searching for F::from_bits(idx) == query
@group(0) @binding(4) var<storage, read> sorted_queries: array<u32>;
//
// Output Buffers (write)
//
// local_basis[query][coeff]: Per-query local basis polynomials (num_queries x 2^n)
@group(0) @binding(5) var<storage, read_write> local_basis: array<vec4<u32>>;
// contributions[query]: Dot product contributions (num_queries)
@group(0) @binding(6) var<storage, read_write> contributions: array<vec4<u32>>;
// Uniform buffer
@group(0) @binding(7) var<uniform> params: SumcheckParams;
// basis_poly_output[coeff]: Final reduced basis polynomial (2^n elements)
@group(0) @binding(8) var<storage, read_write> basis_poly_output: array<vec4<u32>>;
//
// Tensorized Dot Product
//
// Computes ⟨row, L(v_challenges)⟩ where L is the Lagrange basis.
// This is the inner loop hotspot - optimized for GPU.
//
// Algorithm: Fold the row vector by each challenge in reverse order:
// for each challenge r (from last to first):
// current[i] = (1-r) * current[2i] + r * current[2i+1]
// (in binary fields: 1-r = 1+r since subtraction = addition)
//
fn tensorized_dot_product(
row_offset: u32,
row_size: u32,
num_challenges: u32
) -> vec4<u32> {
// Three-tier strategy based on row size:
// - Small (≤128): Use local buffer, 2KB private memory (acceptable for most GPUs)
// - Medium (>128, ≤512): Use ping-pong between local buffers
// - Large (>512): Use global scratch buffer (requires additional binding)
if (row_size <= 128u) {
// Fast path: entire row fits in local memory (2KB max)
// Modern GPUs can handle 2KB private memory without occupancy issues
var buffer: array<vec4<u32>, 128>;
// IMPORTANT: Zero initialize buffer to avoid undefined behavior!
// WGSL doesn't guarantee zero-init for local arrays
for (var i = 0u; i < 128u; i++) {
buffer[i] = ZERO;
}
// Load initial row
for (var i = 0u; i < row_size; i++) {
buffer[i] = opened_rows[row_offset + i];
}
var current_size = row_size;
// Fold from last challenge to first
for (var c = 0u; c < num_challenges; c++) {
let challenge_idx = num_challenges - 1u - c;
let r = v_challenges[challenge_idx];
let one_minus_r = gf128_add(ONE, r);
let half_size = current_size / 2u;
// Fold in-place (write back to same buffer at lower indices)
for (var i = 0u; i < half_size; i++) {
let left_val = buffer[2u * i];
let right_val = buffer[2u * i + 1u];
let left = gf128_mul(one_minus_r, left_val);
let right = gf128_mul(r, right_val);
buffer[i] = gf128_add(left, right);
}
current_size = half_size;
}
return buffer[0];
} else if (row_size <= 512u) {
// Medium path: Use double buffering with 64-element chunks
// Process the row in chunks, gradually reducing size
var buffer_a: array<vec4<u32>, 64>;
var buffer_b: array<vec4<u32>, 64>;
// Zero initialize both buffers
for (var i = 0u; i < 64u; i++) {
buffer_a[i] = ZERO;
buffer_b[i] = ZERO;
}
var current_size = row_size;
// Fold from last challenge to first
for (var c = 0u; c < num_challenges; c++) {
let challenge_idx = num_challenges - 1u - c;
let r = v_challenges[challenge_idx];
let one_minus_r = gf128_add(ONE, r);
let half_size = current_size / 2u;
// Once data fits in single buffer, switch to fast in-place folding
if (current_size <= 64u) {
// Load remaining data into buffer_a
for (var i = 0u; i < current_size; i++) {
buffer_a[i] = opened_rows[row_offset + i];
}
// Finish remaining folds in local memory
var size = current_size;
for (var cc = c; cc < num_challenges; cc++) {
let cidx = num_challenges - 1u - cc;
let rr = v_challenges[cidx];
let omr = gf128_add(ONE, rr);
let hs = size / 2u;
for (var i = 0u; i < hs; i++) {
let left = gf128_mul(omr, buffer_a[2u * i]);
let right = gf128_mul(rr, buffer_a[2u * i + 1u]);
buffer_a[i] = gf128_add(left, right);
}
size = hs;
}
return buffer_a[0];
}
// Process in 64-element chunks, folding and writing back
// This is still a placeholder - proper implementation would require
// temporary storage in global memory or a scratch buffer
// For now, limit to row_size <= 128 for correctness
current_size = half_size;
}
return ZERO; // Should not reach here for row_size <= 512
} else {
// Large path: Not yet implemented - requires global scratch buffer
// For row_size > 512 (k >= 10), we need a dedicated scratch buffer
// binding to store intermediate folding results
// CRITICAL: This path is not implemented yet!
// Returning ZERO will cause incorrect results
// TODO: Implement global memory folding with scratch buffer
return ZERO;
}
}
//
// Evaluate Scaled Basis
//
// Sets basis[query] = contribution, all others to zero.
// This matches the CPU implementation which just sets one index.
//
fn evaluate_scaled_basis(
query: u32,
contribution: vec4<u32>,
basis_size: u32,
output_offset: u32
) {
// SECURITY: Check both bounds and overflow before writing
// 1. Check query is within basis_size
// 2. Check that output_offset + query doesn't overflow
// 3. Combined check prevents out-of-bounds writes
if (query < basis_size && query <= (0xFFFFFFFFu - output_offset)) {
local_basis[output_offset + query] = contribution;
}
}
//
// Main Sumcheck Kernel
//
// Each workgroup processes one query (opened row).
// This achieves massive parallelism: 148+ queries computed simultaneously!
//
@compute @workgroup_size(1) // One thread per query (for now)
fn sumcheck_contribution(@builtin(global_invocation_id) id: vec3<u32>) {
let query_idx = id.x;
if (query_idx >= params.num_queries) {
return;
}
// SECURITY: Check for integer overflow in row_offset calculation
// This prevents out-of-bounds access if query_idx * row_size overflows
let max_row_offset = 0xFFFFFFFFu / params.row_size;
if (query_idx > max_row_offset) {
// Overflow would occur - abort this thread
return;
}
// 1. Compute tensorized dot product
let row_offset = query_idx * params.row_size;
let dot = tensorized_dot_product(row_offset, params.row_size, params.k);
// 2. Scale by alpha^i
let alpha_pow = alpha_pows[query_idx];
let contribution = gf128_mul(dot, alpha_pow);
// DEBUG: Store alpha power (for debugging, we can compare CPU vs GPU alpha pows)
debug_dots[query_idx] = alpha_pow;
// Store contribution for final reduction
contributions[query_idx] = contribution;
// 3. Evaluate scaled basis at precomputed basis index
// sorted_queries[query_idx] is the precomputed basis array index (NOT the query value!)
let basis_idx = sorted_queries[query_idx];
let basis_size = 1u << params.n;
// SECURITY: Check for overflow in output_offset calculation
let max_query_for_basis = 0xFFFFFFFFu / basis_size;
if (query_idx > max_query_for_basis) {
// Overflow would occur - abort
return;
}
let output_offset = query_idx * basis_size;
evaluate_scaled_basis(basis_idx, contribution, basis_size, output_offset);
}
//
// Reduction Kernel
//
// Sums all local_basis polynomials into final basis_poly.
// Each thread handles one coefficient across all queries.
//
@compute @workgroup_size(256)
fn reduce_basis(@builtin(global_invocation_id) id: vec3<u32>) {
let coeff_idx = id.x;
let basis_size = 1u << params.n;
if (coeff_idx >= basis_size) {
return;
}
// SECURITY: Check for overflow in offset calculation
let max_query_safe = 0xFFFFFFFFu / basis_size;
var sum = ZERO;
// Sum across all queries
for (var query = 0u; query < params.num_queries; query++) {
// Skip this iteration if overflow would occur
if (query > max_query_safe) {
continue;
}
let offset = query * basis_size + coeff_idx;
sum = gf128_add(sum, local_basis[offset]);
}
// Write to separate output buffer (NO race condition!)
basis_poly_output[coeff_idx] = sum;
}
//
// Sum Contributions
//
// Final reduction of all contribution values into enforced_sum.
//
@compute @workgroup_size(256)
fn reduce_contributions(@builtin(global_invocation_id) id: vec3<u32>) {
// Use parallel reduction in shared memory
// For simplicity, just do sequential reduction in one thread
if (id.x != 0u) {
return;
}
var sum = ZERO;
for (var i = 0u; i < params.num_queries; i++) {
sum = gf128_add(sum, contributions[i]);
}
// Store final sum in contributions[0]
contributions[0] = sum;
}