ark_linear_sumcheck/ml_sumcheck/protocol/
prover.rs

1//! Prover
2use crate::ml_sumcheck::data_structures::ListOfProductsOfPolynomials;
3use crate::ml_sumcheck::protocol::verifier::VerifierMsg;
4use crate::ml_sumcheck::protocol::IPForMLSumcheck;
5use ark_ff::Field;
6use ark_poly::{DenseMultilinearExtension, MultilinearExtension};
7use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
8use ark_std::{cfg_iter_mut, vec::Vec};
9#[cfg(feature = "parallel")]
10use rayon::prelude::*;
11
12/// Prover Message
13#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)]
14pub struct ProverMsg<F: Field> {
15    /// evaluations on P(0), P(1), P(2), ...
16    pub(crate) evaluations: Vec<F>,
17}
18/// Prover State
19pub struct ProverState<F: Field> {
20    /// sampled randomness given by the verifier
21    pub randomness: Vec<F>,
22    /// Stores the list of products that is meant to be added together. Each multiplicand is represented by
23    /// the index in flattened_ml_extensions
24    pub list_of_products: Vec<(F, Vec<usize>)>,
25    /// Stores a list of multilinear extensions in which `self.list_of_products` points to
26    pub flattened_ml_extensions: Vec<DenseMultilinearExtension<F>>,
27    /// Number of variables
28    pub num_vars: usize,
29    /// Max number of multiplicands in a product
30    pub max_multiplicands: usize,
31    /// The current round number
32    pub round: usize,
33}
34
35impl<F: Field> IPForMLSumcheck<F> {
36    /// initialize the prover to argue for the sum of polynomial over {0,1}^`num_vars`
37    ///
38    /// The polynomial is represented by a list of products of polynomials along with its coefficient that is meant to be added together.
39    ///
40    /// This data structure of the polynomial is a list of list of `(coefficient, DenseMultilinearExtension)`.
41    /// * Number of products n = `polynomial.products.len()`,
42    /// * Number of multiplicands of ith product m_i = `polynomial.products[i].1.len()`,
43    /// * Coefficient of ith product c_i = `polynomial.products[i].0`
44    ///
45    /// The resulting polynomial is
46    ///
47    /// $$\sum_{i=0}^{n}C_i\cdot\prod_{j=0}^{m_i}P_{ij}$$
48    ///
49    pub fn prover_init(polynomial: &ListOfProductsOfPolynomials<F>) -> ProverState<F> {
50        if polynomial.num_variables == 0 {
51            panic!("Attempt to prove a constant.")
52        }
53
54        // create a deep copy of all unique MLExtensions
55        let flattened_ml_extensions = polynomial
56            .flattened_ml_extensions
57            .iter()
58            .map(|x| x.as_ref().clone())
59            .collect();
60
61        ProverState {
62            randomness: Vec::with_capacity(polynomial.num_variables),
63            list_of_products: polynomial.products.clone(),
64            flattened_ml_extensions,
65            num_vars: polynomial.num_variables,
66            max_multiplicands: polynomial.max_multiplicands,
67            round: 0,
68        }
69    }
70
71    /// receive message from verifier, generate prover message, and proceed to next round
72    ///
73    /// Main algorithm used is from section 3.2 of [XZZPS19](https://eprint.iacr.org/2019/317.pdf#subsection.3.2).
74    pub fn prove_round(
75        prover_state: &mut ProverState<F>,
76        v_msg: &Option<VerifierMsg<F>>,
77    ) -> ProverMsg<F> {
78        if let Some(msg) = v_msg {
79            if prover_state.round == 0 {
80                panic!("first round should be prover first.");
81            }
82            prover_state.randomness.push(msg.randomness);
83
84            // fix argument
85            let i = prover_state.round;
86            let r = prover_state.randomness[i - 1];
87            cfg_iter_mut!(prover_state.flattened_ml_extensions).for_each(|multiplicand| {
88                *multiplicand = multiplicand.fix_variables(&[r]);
89            });
90        } else if prover_state.round > 0 {
91            panic!("verifier message is empty");
92        }
93
94        prover_state.round += 1;
95
96        if prover_state.round > prover_state.num_vars {
97            panic!("Prover is not active");
98        }
99
100        let i = prover_state.round;
101        let nv = prover_state.num_vars;
102        let degree = prover_state.max_multiplicands; // the degree of univariate polynomial sent by prover at this round
103
104        #[cfg(not(feature = "parallel"))]
105        let zeros = (vec![F::zero(); degree + 1], vec![F::zero(); degree + 1]);
106        #[cfg(feature = "parallel")]
107        let zeros = || (vec![F::zero(); degree + 1], vec![F::zero(); degree + 1]);
108
109        // generate sum
110        let fold_result = ark_std::cfg_into_iter!(0..1 << (nv - i), 1 << 10).fold(
111            zeros,
112            |(mut products_sum, mut product), b| {
113                // In effect, this fold is essentially doing simply:
114                // for b in 0..1 << (nv - i) {
115                for (coefficient, products) in &prover_state.list_of_products {
116                    product.fill(*coefficient);
117                    for &jth_product in products {
118                        let table = &prover_state.flattened_ml_extensions[jth_product];
119                        let mut start = table[b << 1];
120                        let step = table[(b << 1) + 1] - start;
121                        for p in product.iter_mut() {
122                            *p *= start;
123                            start += step;
124                        }
125                    }
126                    for t in 0..degree + 1 {
127                        products_sum[t] += product[t];
128                    }
129                }
130                (products_sum, product)
131            },
132        );
133
134        #[cfg(not(feature = "parallel"))]
135        let products_sum = fold_result.0;
136
137        // When rayon is used, the `fold` operation results in a iterator of `Vec<F>` rather than a single `Vec<F>`. In this case, we simply need to sum them.
138        #[cfg(feature = "parallel")]
139        let products_sum = fold_result.map(|scratch| scratch.0).reduce(
140            || vec![F::zero(); degree + 1],
141            |mut overall_products_sum, sublist_sum| {
142                overall_products_sum
143                    .iter_mut()
144                    .zip(sublist_sum.iter())
145                    .for_each(|(f, s)| *f += s);
146                overall_products_sum
147            },
148        );
149
150        ProverMsg {
151            evaluations: products_sum,
152        }
153    }
154}