ark_linear_sumcheck/ml_sumcheck/protocol/
prover.rs1use crate::ml_sumcheck::data_structures::ListOfProductsOfPolynomials;
3use crate::ml_sumcheck::protocol::verifier::VerifierMsg;
4use crate::ml_sumcheck::protocol::IPForMLSumcheck;
5use ark_ff::Field;
6use ark_poly::{DenseMultilinearExtension, MultilinearExtension};
7use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
8use ark_std::{cfg_iter_mut, vec::Vec};
9#[cfg(feature = "parallel")]
10use rayon::prelude::*;
11
12#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)]
14pub struct ProverMsg<F: Field> {
15 pub(crate) evaluations: Vec<F>,
17}
18pub struct ProverState<F: Field> {
20 pub randomness: Vec<F>,
22 pub list_of_products: Vec<(F, Vec<usize>)>,
25 pub flattened_ml_extensions: Vec<DenseMultilinearExtension<F>>,
27 pub num_vars: usize,
29 pub max_multiplicands: usize,
31 pub round: usize,
33}
34
35impl<F: Field> IPForMLSumcheck<F> {
36 pub fn prover_init(polynomial: &ListOfProductsOfPolynomials<F>) -> ProverState<F> {
50 if polynomial.num_variables == 0 {
51 panic!("Attempt to prove a constant.")
52 }
53
54 let flattened_ml_extensions = polynomial
56 .flattened_ml_extensions
57 .iter()
58 .map(|x| x.as_ref().clone())
59 .collect();
60
61 ProverState {
62 randomness: Vec::with_capacity(polynomial.num_variables),
63 list_of_products: polynomial.products.clone(),
64 flattened_ml_extensions,
65 num_vars: polynomial.num_variables,
66 max_multiplicands: polynomial.max_multiplicands,
67 round: 0,
68 }
69 }
70
71 pub fn prove_round(
75 prover_state: &mut ProverState<F>,
76 v_msg: &Option<VerifierMsg<F>>,
77 ) -> ProverMsg<F> {
78 if let Some(msg) = v_msg {
79 if prover_state.round == 0 {
80 panic!("first round should be prover first.");
81 }
82 prover_state.randomness.push(msg.randomness);
83
84 let i = prover_state.round;
86 let r = prover_state.randomness[i - 1];
87 cfg_iter_mut!(prover_state.flattened_ml_extensions).for_each(|multiplicand| {
88 *multiplicand = multiplicand.fix_variables(&[r]);
89 });
90 } else if prover_state.round > 0 {
91 panic!("verifier message is empty");
92 }
93
94 prover_state.round += 1;
95
96 if prover_state.round > prover_state.num_vars {
97 panic!("Prover is not active");
98 }
99
100 let i = prover_state.round;
101 let nv = prover_state.num_vars;
102 let degree = prover_state.max_multiplicands; #[cfg(not(feature = "parallel"))]
105 let zeros = (vec![F::zero(); degree + 1], vec![F::zero(); degree + 1]);
106 #[cfg(feature = "parallel")]
107 let zeros = || (vec![F::zero(); degree + 1], vec![F::zero(); degree + 1]);
108
109 let fold_result = ark_std::cfg_into_iter!(0..1 << (nv - i), 1 << 10).fold(
111 zeros,
112 |(mut products_sum, mut product), b| {
113 for (coefficient, products) in &prover_state.list_of_products {
116 product.fill(*coefficient);
117 for &jth_product in products {
118 let table = &prover_state.flattened_ml_extensions[jth_product];
119 let mut start = table[b << 1];
120 let step = table[(b << 1) + 1] - start;
121 for p in product.iter_mut() {
122 *p *= start;
123 start += step;
124 }
125 }
126 for t in 0..degree + 1 {
127 products_sum[t] += product[t];
128 }
129 }
130 (products_sum, product)
131 },
132 );
133
134 #[cfg(not(feature = "parallel"))]
135 let products_sum = fold_result.0;
136
137 #[cfg(feature = "parallel")]
139 let products_sum = fold_result.map(|scratch| scratch.0).reduce(
140 || vec![F::zero(); degree + 1],
141 |mut overall_products_sum, sublist_sum| {
142 overall_products_sum
143 .iter_mut()
144 .zip(sublist_sum.iter())
145 .for_each(|(f, s)| *f += s);
146 overall_products_sum
147 },
148 );
149
150 ProverMsg {
151 evaluations: products_sum,
152 }
153 }
154}