// provekit_common/utils/sumcheck.rs

1use {
2    crate::{
3        sparse_matrix::SparseMatrix,
4        utils::{unzip_double_array, workload_size},
5        FieldElement, R1CS,
6    },
7    ark_std::{One, Zero},
8    std::array,
9    tracing::instrument,
10};
11
12/// Compute the sum of a vector valued function over the boolean hypercube in
13/// the leading variable.
14pub fn sumcheck_fold_map_reduce<const N: usize, const M: usize>(
15    mles: [&mut [FieldElement]; N],
16    fold: Option<FieldElement>,
17    map: impl Fn([(FieldElement, FieldElement); N]) -> [FieldElement; M] + Send + Sync + Copy,
18) -> [FieldElement; M] {
19    let size = mles[0].len();
20    assert!(size.is_power_of_two());
21    assert!(size >= 2);
22    assert!(mles.iter().all(|mle| mle.len() == size));
23
24    if let Some(fold) = fold {
25        assert!(size >= 4);
26        let slices = mles.map(|mle| {
27            let (p0, tail) = mle.split_at_mut(size / 4);
28            let (p1, tail) = tail.split_at_mut(size / 4);
29            let (p2, p3) = tail.split_at_mut(size / 4);
30            [p0, p1, p2, p3]
31        });
32        sumcheck_fold_map_reduce_inner::<N, M>(slices, fold, map)
33    } else {
34        let slices = mles.map(|mle| mle.split_at(size / 2));
35        sumcheck_map_reduce_inner::<N, M>(slices, map)
36    }
37}
38
39fn sumcheck_map_reduce_inner<const N: usize, const M: usize>(
40    mles: [(&[FieldElement], &[FieldElement]); N],
41    map: impl Fn([(FieldElement, FieldElement); N]) -> [FieldElement; M] + Send + Sync + Copy,
42) -> [FieldElement; M] {
43    let size = mles[0].0.len();
44    if size * N * 2 > workload_size::<FieldElement>() {
45        // Split slices
46        let pairs = mles.map(|(p0, p1)| (p0.split_at(size / 2), p1.split_at(size / 2)));
47        let left = pairs.map(|((l0, _), (l1, _))| (l0, l1));
48        let right = pairs.map(|((_, r0), (_, r1))| (r0, r1));
49
50        // Parallel recurse
51        let (l, r) = rayon::join(
52            || sumcheck_map_reduce_inner(left, map),
53            || sumcheck_map_reduce_inner(right, map),
54        );
55
56        // Combine results
57        array::from_fn(|i| l[i] + r[i])
58    } else {
59        let mut result = [FieldElement::zero(); M];
60        for i in 0..size {
61            let e = mles.map(|(p0, p1)| (p0[i], p1[i]));
62            let local = map(e);
63            result.iter_mut().zip(local).for_each(|(r, l)| *r += l);
64        }
65        result
66    }
67}
68
69fn sumcheck_fold_map_reduce_inner<const N: usize, const M: usize>(
70    mut mles: [[&mut [FieldElement]; 4]; N],
71    fold: FieldElement,
72    map: impl Fn([(FieldElement, FieldElement); N]) -> [FieldElement; M] + Send + Sync + Copy,
73) -> [FieldElement; M] {
74    let size = mles[0][0].len();
75    if size * N * 4 > workload_size::<FieldElement>() {
76        // Split slices
77        let pairs = mles.map(|mles| mles.map(|p| p.split_at_mut(size / 2)));
78        let (left, right) = unzip_double_array(pairs);
79
80        // Parallel recurse
81        let (l, r) = rayon::join(
82            || sumcheck_fold_map_reduce_inner(left, fold, map),
83            || sumcheck_fold_map_reduce_inner(right, fold, map),
84        );
85
86        // Combine results
87        array::from_fn(|i| l[i] + r[i])
88    } else {
89        let mut result = [FieldElement::zero(); M];
90        for i in 0..size {
91            let e = array::from_fn(|j| {
92                let mle = &mut mles[j];
93                mle[0][i] += fold * (mle[2][i] - mle[0][i]);
94                mle[1][i] += fold * (mle[3][i] - mle[1][i]);
95                (mle[0][i], mle[1][i])
96            });
97            let local = map(e);
98            result.iter_mut().zip(local).for_each(|(r, l)| *r += l);
99        }
100        result
101    }
102}

// TODO: Add unit tests for calculate_evaluations_over_boolean_hypercube_for_eq,
// eval_eq, calculate_eq, and the transposed matrix multiplication helpers.

107/// List of evaluations for eq(r, x) over the boolean hypercube, truncated to
108/// `num_entries` elements. When `num_entries < 2^r.len()`, avoids allocating
109/// the full hypercube.
110#[instrument(skip_all)]
111pub fn calculate_evaluations_over_boolean_hypercube_for_eq(
112    r: &[FieldElement],
113    num_entries: usize,
114) -> Vec<FieldElement> {
115    let full_size = 1usize << r.len();
116    assert!(
117        num_entries <= full_size,
118        "num_entries ({num_entries}) exceeds 2^{} = {full_size}",
119        r.len()
120    );
121    let mut result = vec![FieldElement::zero(); num_entries];
122    eval_eq(r, &mut result, FieldElement::one(), full_size);
123    result
124}
125
126/// Evaluates the equality polynomial recursively. `subtree_size` tracks the
127/// logical size of this recursion level so that truncated output buffers are
128/// split correctly.
129fn eval_eq(
130    eval: &[FieldElement],
131    out: &mut [FieldElement],
132    scalar: FieldElement,
133    subtree_size: usize,
134) {
135    debug_assert!(out.len() <= subtree_size);
136    if let Some((&x, tail)) = eval.split_first() {
137        let half = subtree_size / 2;
138        let left_len = out.len().min(half);
139        let right_len = out.len().saturating_sub(half);
140        let (o0, o1) = out.split_at_mut(left_len);
141        let s1 = scalar * x;
142        let s0 = scalar - s1;
143        if right_len == 0 {
144            eval_eq(tail, o0, s0, half);
145        } else if subtree_size > workload_size::<FieldElement>() {
146            rayon::join(
147                || eval_eq(tail, o0, s0, half),
148                || eval_eq(tail, o1, s1, half),
149            );
150        } else {
151            eval_eq(tail, o0, s0, half);
152            eval_eq(tail, o1, s1, half);
153        }
154    } else {
155        out[0] += scalar;
156    }
157}
158
159/// Evaluates a cubic polynomial on a value
160pub fn eval_cubic_poly(poly: [FieldElement; 4], point: FieldElement) -> FieldElement {
161    poly[0] + point * (poly[1] + point * (poly[2] + point * poly[3]))
162}
163
164/// Given a path to JSON file with sparce matrices and a witness, calculates
165/// matrix-vector multiplication and returns them
166#[instrument(skip_all)]
167pub fn calculate_witness_bounds(
168    r1cs: &R1CS,
169    witness: &[FieldElement],
170) -> (Vec<FieldElement>, Vec<FieldElement>, Vec<FieldElement>) {
171    let (a, b) = rayon::join(|| r1cs.a() * witness, || r1cs.b() * witness);
172
173    let target_len = a.len().next_power_of_two();
174    let mut c = Vec::with_capacity(target_len);
175    c.extend(a.iter().zip(b.iter()).map(|(a, b)| *a * *b));
176    c.resize(target_len, FieldElement::zero());
177
178    let mut a = a;
179    let mut b = b;
180    a.resize(target_len, FieldElement::zero());
181    b.resize(target_len, FieldElement::zero());
182    (a, b, c)
183}
184
185/// Calculates eq(r, alpha)
186pub fn calculate_eq(r: &[FieldElement], alpha: &[FieldElement]) -> FieldElement {
187    r.iter()
188        .zip(alpha.iter())
189        .fold(FieldElement::from(1), |acc, (&r, &alpha)| {
190            acc * (r * alpha + (FieldElement::from(1) - r) * (FieldElement::from(1) - alpha))
191        })
192}
193
194/// Transpose all three R1CS matrices in parallel.
195///
196/// This depends only on the R1CS structure (from the verifier key), not on any
197/// proof-specific data, so it can run concurrently with sumcheck verification.
198#[instrument(skip_all)]
199pub fn transpose_r1cs_matrices(r1cs: &R1CS) -> (SparseMatrix, SparseMatrix, SparseMatrix) {
200    let ((at, bt), ct) = rayon::join(
201        || rayon::join(|| r1cs.a.transpose(), || r1cs.b.transpose()),
202        || r1cs.c.transpose(),
203    );
204    (at, bt, ct)
205}
206
207/// Multiply pre-transposed R1CS matrices by eq(alpha, ·) to compute the
208/// external row.
209#[instrument(skip_all)]
210pub fn multiply_transposed_by_eq_alpha(
211    at: &SparseMatrix,
212    bt: &SparseMatrix,
213    ct: &SparseMatrix,
214    alpha: &[FieldElement],
215    r1cs: &R1CS,
216) -> [Vec<FieldElement>; 3] {
217    let eq_alpha =
218        calculate_evaluations_over_boolean_hypercube_for_eq(alpha, r1cs.num_constraints());
219    let interner = &r1cs.interner;
220    let ((a, b), c) = rayon::join(
221        || {
222            rayon::join(
223                || at.hydrate(interner) * eq_alpha.as_slice(),
224                || bt.hydrate(interner) * eq_alpha.as_slice(),
225            )
226        },
227        || ct.hydrate(interner) * eq_alpha.as_slice(),
228    );
229    [a, b, c]
230}
231
232/// Calculates a random row of R1CS matrix extension. Made possible due to
233/// sparseness.
234///
235/// Computes `eq(alpha, ·) * [A, B, C]` using transposed matrices for
236/// parallel right-multiplication instead of sequential left-multiplication.
237#[instrument(skip_all)]
238pub fn calculate_external_row_of_r1cs_matrices(
239    alpha: &[FieldElement],
240    r1cs: &R1CS,
241) -> [Vec<FieldElement>; 3] {
242    let (at, bt, ct) = transpose_r1cs_matrices(r1cs);
243    multiply_transposed_by_eq_alpha(&at, &bt, &ct, alpha, r1cs)
244}