1use {
2 crate::{
3 sparse_matrix::SparseMatrix,
4 utils::{unzip_double_array, workload_size},
5 FieldElement, R1CS,
6 },
7 ark_std::{One, Zero},
8 std::array,
9 tracing::instrument,
10};
11
12pub fn sumcheck_fold_map_reduce<const N: usize, const M: usize>(
15 mles: [&mut [FieldElement]; N],
16 fold: Option<FieldElement>,
17 map: impl Fn([(FieldElement, FieldElement); N]) -> [FieldElement; M] + Send + Sync + Copy,
18) -> [FieldElement; M] {
19 let size = mles[0].len();
20 assert!(size.is_power_of_two());
21 assert!(size >= 2);
22 assert!(mles.iter().all(|mle| mle.len() == size));
23
24 if let Some(fold) = fold {
25 assert!(size >= 4);
26 let slices = mles.map(|mle| {
27 let (p0, tail) = mle.split_at_mut(size / 4);
28 let (p1, tail) = tail.split_at_mut(size / 4);
29 let (p2, p3) = tail.split_at_mut(size / 4);
30 [p0, p1, p2, p3]
31 });
32 sumcheck_fold_map_reduce_inner::<N, M>(slices, fold, map)
33 } else {
34 let slices = mles.map(|mle| mle.split_at(size / 2));
35 sumcheck_map_reduce_inner::<N, M>(slices, map)
36 }
37}
38
39fn sumcheck_map_reduce_inner<const N: usize, const M: usize>(
40 mles: [(&[FieldElement], &[FieldElement]); N],
41 map: impl Fn([(FieldElement, FieldElement); N]) -> [FieldElement; M] + Send + Sync + Copy,
42) -> [FieldElement; M] {
43 let size = mles[0].0.len();
44 if size * N * 2 > workload_size::<FieldElement>() {
45 let pairs = mles.map(|(p0, p1)| (p0.split_at(size / 2), p1.split_at(size / 2)));
47 let left = pairs.map(|((l0, _), (l1, _))| (l0, l1));
48 let right = pairs.map(|((_, r0), (_, r1))| (r0, r1));
49
50 let (l, r) = rayon::join(
52 || sumcheck_map_reduce_inner(left, map),
53 || sumcheck_map_reduce_inner(right, map),
54 );
55
56 array::from_fn(|i| l[i] + r[i])
58 } else {
59 let mut result = [FieldElement::zero(); M];
60 for i in 0..size {
61 let e = mles.map(|(p0, p1)| (p0[i], p1[i]));
62 let local = map(e);
63 result.iter_mut().zip(local).for_each(|(r, l)| *r += l);
64 }
65 result
66 }
67}
68
69fn sumcheck_fold_map_reduce_inner<const N: usize, const M: usize>(
70 mut mles: [[&mut [FieldElement]; 4]; N],
71 fold: FieldElement,
72 map: impl Fn([(FieldElement, FieldElement); N]) -> [FieldElement; M] + Send + Sync + Copy,
73) -> [FieldElement; M] {
74 let size = mles[0][0].len();
75 if size * N * 4 > workload_size::<FieldElement>() {
76 let pairs = mles.map(|mles| mles.map(|p| p.split_at_mut(size / 2)));
78 let (left, right) = unzip_double_array(pairs);
79
80 let (l, r) = rayon::join(
82 || sumcheck_fold_map_reduce_inner(left, fold, map),
83 || sumcheck_fold_map_reduce_inner(right, fold, map),
84 );
85
86 array::from_fn(|i| l[i] + r[i])
88 } else {
89 let mut result = [FieldElement::zero(); M];
90 for i in 0..size {
91 let e = array::from_fn(|j| {
92 let mle = &mut mles[j];
93 mle[0][i] += fold * (mle[2][i] - mle[0][i]);
94 mle[1][i] += fold * (mle[3][i] - mle[1][i]);
95 (mle[0][i], mle[1][i])
96 });
97 let local = map(e);
98 result.iter_mut().zip(local).for_each(|(r, l)| *r += l);
99 }
100 result
101 }
102}
103
104#[instrument(skip_all)]
111pub fn calculate_evaluations_over_boolean_hypercube_for_eq(
112 r: &[FieldElement],
113 num_entries: usize,
114) -> Vec<FieldElement> {
115 let full_size = 1usize << r.len();
116 assert!(
117 num_entries <= full_size,
118 "num_entries ({num_entries}) exceeds 2^{} = {full_size}",
119 r.len()
120 );
121 let mut result = vec![FieldElement::zero(); num_entries];
122 eval_eq(r, &mut result, FieldElement::one(), full_size);
123 result
124}
125
126fn eval_eq(
130 eval: &[FieldElement],
131 out: &mut [FieldElement],
132 scalar: FieldElement,
133 subtree_size: usize,
134) {
135 debug_assert!(out.len() <= subtree_size);
136 if let Some((&x, tail)) = eval.split_first() {
137 let half = subtree_size / 2;
138 let left_len = out.len().min(half);
139 let right_len = out.len().saturating_sub(half);
140 let (o0, o1) = out.split_at_mut(left_len);
141 let s1 = scalar * x;
142 let s0 = scalar - s1;
143 if right_len == 0 {
144 eval_eq(tail, o0, s0, half);
145 } else if subtree_size > workload_size::<FieldElement>() {
146 rayon::join(
147 || eval_eq(tail, o0, s0, half),
148 || eval_eq(tail, o1, s1, half),
149 );
150 } else {
151 eval_eq(tail, o0, s0, half);
152 eval_eq(tail, o1, s1, half);
153 }
154 } else {
155 out[0] += scalar;
156 }
157}
158
159pub fn eval_cubic_poly(poly: [FieldElement; 4], point: FieldElement) -> FieldElement {
161 poly[0] + point * (poly[1] + point * (poly[2] + point * poly[3]))
162}
163
164#[instrument(skip_all)]
167pub fn calculate_witness_bounds(
168 r1cs: &R1CS,
169 witness: &[FieldElement],
170) -> (Vec<FieldElement>, Vec<FieldElement>, Vec<FieldElement>) {
171 let (a, b) = rayon::join(|| r1cs.a() * witness, || r1cs.b() * witness);
172
173 let target_len = a.len().next_power_of_two();
174 let mut c = Vec::with_capacity(target_len);
175 c.extend(a.iter().zip(b.iter()).map(|(a, b)| *a * *b));
176 c.resize(target_len, FieldElement::zero());
177
178 let mut a = a;
179 let mut b = b;
180 a.resize(target_len, FieldElement::zero());
181 b.resize(target_len, FieldElement::zero());
182 (a, b, c)
183}
184
185pub fn calculate_eq(r: &[FieldElement], alpha: &[FieldElement]) -> FieldElement {
187 r.iter()
188 .zip(alpha.iter())
189 .fold(FieldElement::from(1), |acc, (&r, &alpha)| {
190 acc * (r * alpha + (FieldElement::from(1) - r) * (FieldElement::from(1) - alpha))
191 })
192}
193
194#[instrument(skip_all)]
199pub fn transpose_r1cs_matrices(r1cs: &R1CS) -> (SparseMatrix, SparseMatrix, SparseMatrix) {
200 let ((at, bt), ct) = rayon::join(
201 || rayon::join(|| r1cs.a.transpose(), || r1cs.b.transpose()),
202 || r1cs.c.transpose(),
203 );
204 (at, bt, ct)
205}
206
207#[instrument(skip_all)]
210pub fn multiply_transposed_by_eq_alpha(
211 at: &SparseMatrix,
212 bt: &SparseMatrix,
213 ct: &SparseMatrix,
214 alpha: &[FieldElement],
215 r1cs: &R1CS,
216) -> [Vec<FieldElement>; 3] {
217 let eq_alpha =
218 calculate_evaluations_over_boolean_hypercube_for_eq(alpha, r1cs.num_constraints());
219 let interner = &r1cs.interner;
220 let ((a, b), c) = rayon::join(
221 || {
222 rayon::join(
223 || at.hydrate(interner) * eq_alpha.as_slice(),
224 || bt.hydrate(interner) * eq_alpha.as_slice(),
225 )
226 },
227 || ct.hydrate(interner) * eq_alpha.as_slice(),
228 );
229 [a, b, c]
230}
231
232#[instrument(skip_all)]
238pub fn calculate_external_row_of_r1cs_matrices(
239 alpha: &[FieldElement],
240 r1cs: &R1CS,
241) -> [Vec<FieldElement>; 3] {
242 let (at, bt, ct) = transpose_r1cs_matrices(r1cs);
243 multiply_transposed_by_eq_alpha(&at, &bt, &ct, alpha, r1cs)
244}