ark_linear_sumcheck/ml_sumcheck/protocol/verifier.rs

//! Verifier for the multilinear sumcheck protocol.
2use crate::ml_sumcheck::data_structures::PolynomialInfo;
3use crate::ml_sumcheck::protocol::prover::ProverMsg;
4use crate::ml_sumcheck::protocol::IPForMLSumcheck;
5use ark_ff::Field;
6use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
7use ark_std::rand::RngCore;
8use ark_std::vec::Vec;
9
10#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)]
11/// Verifier Message
12pub struct VerifierMsg<F: Field> {
13    /// randomness sampled by verifier
14    pub randomness: F,
15}
16
17/// Verifier State
18pub struct VerifierState<F: Field> {
19    round: usize,
20    nv: usize,
21    max_multiplicands: usize,
22    finished: bool,
23    /// a list storing the univariate polynomial in evaluation form sent by the prover at each round
24    polynomials_received: Vec<Vec<F>>,
25    /// a list storing the randomness sampled by the verifier at each round
26    randomness: Vec<F>,
27}
28/// Subclaim when verifier is convinced
29pub struct SubClaim<F: Field> {
30    /// the multi-dimensional point that this multilinear extension is evaluated to
31    pub point: Vec<F>,
32    /// the expected evaluation
33    pub expected_evaluation: F,
34}
35
36impl<F: Field> IPForMLSumcheck<F> {
37    /// initialize the verifier
38    pub fn verifier_init(index_info: &PolynomialInfo) -> VerifierState<F> {
39        VerifierState {
40            round: 1,
41            nv: index_info.num_variables,
42            max_multiplicands: index_info.max_multiplicands,
43            finished: false,
44            polynomials_received: Vec::with_capacity(index_info.num_variables),
45            randomness: Vec::with_capacity(index_info.num_variables),
46        }
47    }
48
49    /// Run verifier at current round, given prover message
50    ///
51    /// Normally, this function should perform actual verification. Instead, `verify_round` only samples
52    /// and stores randomness and perform verifications altogether in `check_and_generate_subclaim` at
53    /// the last step.
54    pub fn verify_round<R: RngCore>(
55        prover_msg: ProverMsg<F>,
56        verifier_state: &mut VerifierState<F>,
57        rng: &mut R,
58    ) -> Option<VerifierMsg<F>> {
59        if verifier_state.finished {
60            panic!("Incorrect verifier state: Verifier is already finished.");
61        }
62
63        // Now, verifier should check if the received P(0) + P(1) = expected. The check is moved to
64        // `check_and_generate_subclaim`, and will be done after the last round.
65
66        let msg = Self::sample_round(rng);
67        verifier_state.randomness.push(msg.randomness);
68        verifier_state
69            .polynomials_received
70            .push(prover_msg.evaluations);
71
72        // Now, verifier should set `expected` to P(r).
73        // This operation is also moved to `check_and_generate_subclaim`,
74        // and will be done after the last round.
75
76        if verifier_state.round == verifier_state.nv {
77            // accept and close
78            verifier_state.finished = true;
79        } else {
80            verifier_state.round += 1;
81        }
82        Some(msg)
83    }
84
85    /// verify the sumcheck phase, and generate the subclaim
86    ///
87    /// If the asserted sum is correct, then the multilinear polynomial evaluated at `subclaim.point`
88    /// is `subclaim.expected_evaluation`. Otherwise, it is highly unlikely that those two will be equal.
89    /// Larger field size guarantees smaller soundness error.
90    pub fn check_and_generate_subclaim(
91        verifier_state: VerifierState<F>,
92        asserted_sum: F,
93    ) -> Result<SubClaim<F>, crate::Error> {
94        if !verifier_state.finished {
95            panic!("Verifier has not finished.");
96        }
97
98        let mut expected = asserted_sum;
99        if verifier_state.polynomials_received.len() != verifier_state.nv {
100            panic!("insufficient rounds");
101        }
102        for i in 0..verifier_state.nv {
103            let evaluations = &verifier_state.polynomials_received[i];
104            if evaluations.len() != verifier_state.max_multiplicands + 1 {
105                panic!("incorrect number of evaluations");
106            }
107            let p0 = evaluations[0];
108            let p1 = evaluations[1];
109            if p0 + p1 != expected {
110                return Err(crate::Error::Reject(Some(
111                    "Prover message is not consistent with the claim.".into(),
112                )));
113            }
114            expected = interpolate_uni_poly(evaluations, verifier_state.randomness[i]);
115        }
116
117        Ok(SubClaim {
118            point: verifier_state.randomness,
119            expected_evaluation: expected,
120        })
121    }
122
123    /// simulate a verifier message without doing verification
124    ///
125    /// Given the same calling context, `random_oracle_round` output exactly the same message as
126    /// `verify_round`
127    #[inline]
128    pub fn sample_round<R: RngCore>(rng: &mut R) -> VerifierMsg<F> {
129        VerifierMsg {
130            randomness: F::rand(rng),
131        }
132    }
133}
134
135/// interpolate the *unique* univariate polynomial of degree *at most*
136/// p_i.len()-1 passing through the y-values in p_i at x = 0,..., p_i.len()-1
137/// and evaluate this  polynomial at `eval_at`. In other words, efficiently compute
138///  \sum_{i=0}^{len p_i - 1} p_i[i] * (\prod_{j!=i} (eval_at - j)/(i-j))
139pub(crate) fn interpolate_uni_poly<F: Field>(p_i: &[F], eval_at: F) -> F {
140    let len = p_i.len();
141
142    let mut evals = vec![];
143
144    let mut prod = eval_at;
145    evals.push(eval_at);
146
147    //`prod = \prod_{j} (eval_at - j)`
148    // we return early if 0 <= eval_at <  len, i.e. if the desired value has been passed
149    let mut check = F::zero();
150    for i in 1..len {
151        if eval_at == check {
152            return p_i[i - 1];
153        }
154        check += F::one();
155
156        let tmp = eval_at - check;
157        evals.push(tmp);
158        prod *= tmp;
159    }
160
161    if eval_at == check {
162        return p_i[len - 1];
163    }
164
165    let mut res = F::zero();
166    // we want to compute \prod (j!=i) (i-j) for a given i
167    //
168    // we start from the last step, which is
169    //  denom[len-1] = (len-1) * (len-2) *... * 2 * 1
170    // the step before that is
171    //  denom[len-2] = (len-2) * (len-3) * ... * 2 * 1 * -1
172    // and the step before that is
173    //  denom[len-3] = (len-3) * (len-4) * ... * 2 * 1 * -1 * -2
174    //
175    // i.e., for any i, the one before this will be derived from
176    //  denom[i-1] = - denom[i] * (len-i) / i
177    //
178    // that is, we only need to store
179    // - the last denom for i = len-1, and
180    // - the ratio between the current step and the last step, which is the
181    //   product of -(len-i) / i from all previous steps and we store
182    //   this product as a fraction number to reduce field divisions.
183
184    // We know
185    //  - 2^61 < factorial(20) < 2^62
186    //  - 2^122 < factorial(33) < 2^123
187    // so we will be able to compute the ratio
188    //  - for len <= 20 with i64
189    //  - for len <= 33 with i128
190    //  - for len >  33 with BigInt
191    if p_i.len() <= 20 {
192        let last_denom = F::from(u64_factorial(len - 1));
193        let mut ratio_numerator = 1i64;
194        let mut ratio_enumerator = 1u64;
195
196        for i in (0..len).rev() {
197            let ratio_numerator_f = if ratio_numerator < 0 {
198                -F::from((-ratio_numerator) as u64)
199            } else {
200                F::from(ratio_numerator as u64)
201            };
202
203            res += p_i[i] * prod * F::from(ratio_enumerator)
204                / (last_denom * ratio_numerator_f * evals[i]);
205
206            // compute ratio for the next step which is current_ratio * -(len-i)/i
207            if i != 0 {
208                ratio_numerator *= -(len as i64 - i as i64);
209                ratio_enumerator *= i as u64;
210            }
211        }
212    } else if p_i.len() <= 33 {
213        let last_denom = F::from(u128_factorial(len - 1));
214        let mut ratio_numerator = 1i128;
215        let mut ratio_enumerator = 1u128;
216
217        for i in (0..len).rev() {
218            let ratio_numerator_f = if ratio_numerator < 0 {
219                -F::from((-ratio_numerator) as u128)
220            } else {
221                F::from(ratio_numerator as u128)
222            };
223
224            res += p_i[i] * prod * F::from(ratio_enumerator)
225                / (last_denom * ratio_numerator_f * evals[i]);
226
227            // compute ratio for the next step which is current_ratio * -(len-i)/i
228            if i != 0 {
229                ratio_numerator *= -(len as i128 - i as i128);
230                ratio_enumerator *= i as u128;
231            }
232        }
233    } else {
234        // since we are using field operations, we can merge
235        // `last_denom` and `ratio_numerator` into a single field element.
236        let mut denom_up = field_factorial::<F>(len - 1);
237        let mut denom_down = F::one();
238
239        for i in (0..len).rev() {
240            res += p_i[i] * prod * denom_down / (denom_up * evals[i]);
241
242            // compute denom for the next step is -current_denom * (len-i)/i
243            if i != 0 {
244                denom_up *= -F::from((len - i) as u64);
245                denom_down *= F::from(i as u64);
246            }
247        }
248    }
249
250    res
251}
252
253/// compute the factorial(a) = 1 * 2 * ... * a
254#[inline]
255fn field_factorial<F: Field>(a: usize) -> F {
256    let mut res = F::one();
257    for i in 1..=a {
258        res *= F::from(i as u64);
259    }
260    res
261}
262
/// Compute `a!` = 1 * 2 * ... * a in `u128` arithmetic (with `0! = 1`).
///
/// The result is exact only while it fits in a `u128`, i.e. for `a <= 34`.
#[inline]
fn u128_factorial(a: usize) -> u128 {
    (1..=a as u128).product()
}
272
/// Compute `a!` = 1 * 2 * ... * a in `u64` arithmetic (with `0! = 1`).
///
/// The result is exact only while it fits in a `u64`, i.e. for `a <= 20`.
#[inline]
fn u64_factorial(a: usize) -> u64 {
    (1..=a as u64).product()
}
282
#[cfg(test)]
mod test {
    use crate::ml_sumcheck::protocol::verifier::interpolate_uni_poly;
    use ark_poly::univariate::DensePolynomial;
    use ark_poly::DenseUVPolynomial;
    use ark_poly::Polynomial;
    use ark_std::rand::Rng;
    use ark_std::vec::Vec;
    use ark_std::UniformRand;

    type F = ark_test_curves::bls12_381::Fr;

    /// Sample a random polynomial of degree `num_points - 1`, evaluate it at
    /// x = 0, ..., num_points - 1, and check that interpolating those evaluations
    /// reproduces the polynomial's value at a random query point.
    fn assert_interpolates_random_poly<R: Rng>(num_points: usize, rng: &mut R) {
        let poly = DensePolynomial::<F>::rand(num_points - 1, rng);
        let evals = (0..num_points)
            .map(|i| poly.evaluate(&F::from(i as u64)))
            .collect::<Vec<F>>();
        let query = F::rand(rng);

        assert_eq!(poly.evaluate(&query), interpolate_uni_poly(&evals, query));
    }

    #[test]
    fn test_interpolation() {
        let mut prng = ark_std::test_rng();

        // `interpolate_uni_poly` switches strategy at 20 and 33 points (i64, i128, then
        // field arithmetic), so cover both sides of each boundary as well as tiny inputs
        // and a larger polynomial with 64 known points (degree 63).
        for &num_points in &[1usize, 2, 20, 21, 33, 34, 64] {
            assert_interpolates_random_poly(num_points, &mut prng);
        }

        // test interpolation when we ask for the value at an x-coordinate
        // we are already passing, i.e. in the range 0 <= x <= len(values) - 1
        let evals = vec![0, 1, 4, 9]
            .into_iter()
            .map(|i| F::from(i))
            .collect::<Vec<F>>();
        for (x, expected) in evals.iter().enumerate() {
            assert_eq!(interpolate_uni_poly(&evals, F::from(x as u64)), *expected);
        }
        assert_eq!(interpolate_uni_poly(&evals, F::from(3)), F::from(9));
    }
}