// ark_marlin/ahp/mod.rs

1use crate::{String, ToString, Vec};
2use ark_ff::{Field, PrimeField};
3use ark_poly::univariate::DensePolynomial;
4use ark_poly::{EvaluationDomain, GeneralEvaluationDomain};
5use ark_poly_commit::{LCTerm, LinearCombination};
6use ark_relations::r1cs::SynthesisError;
7use ark_std::{borrow::Borrow, cfg_iter_mut, format, marker::PhantomData, vec};
8
9#[cfg(feature = "parallel")]
10use rayon::prelude::*;
11
/// Constraint-system helpers used by the AHP, such as the matrix-dimension
/// padding and public-input formatting called from this module.
pub(crate) mod constraint_systems;
/// Describes data structures and the algorithms used by the AHP indexer.
pub mod indexer;
/// Describes data structures and the algorithms used by the AHP prover.
pub mod prover;
/// Describes data structures and the algorithms used by the AHP verifier.
pub mod verifier;
19
/// A labeled DensePolynomial with coefficients over `F`.
///
/// Shorthand for `ark_poly_commit::LabeledPolynomial` specialised to univariate
/// dense polynomials.
pub type LabeledPolynomial<F> = ark_poly_commit::LabeledPolynomial<F, DensePolynomial<F>>;
22
/// The algebraic holographic proof defined in [CHMMVW19](https://eprint.iacr.org/2019/1047).
/// Currently, this AHP only supports inputs of size one
/// less than a power of 2 (i.e., of the form 2^n - 1).
///
/// Equivalently, the *formatted* public input, which also counts the "one"
/// variable, must have length 2^n; this is what
/// `formatted_public_input_is_admissible` checks.
pub struct AHPForR1CS<F: Field> {
    // Zero-sized marker tying the type to the field `F`; it carries no data.
    field: PhantomData<F>,
}
29
30impl<F: PrimeField> AHPForR1CS<F> {
31    /// The labels for the polynomials output by the AHP indexer.
32    #[rustfmt::skip]
33    pub const INDEXER_POLYNOMIALS: [&'static str; 12] = [
34        // Polynomials for A
35        "a_row", "a_col", "a_val", "a_row_col",
36        // Polynomials for B
37        "b_row", "b_col", "b_val", "b_row_col",
38        // Polynomials for C
39        "c_row", "c_col", "c_val", "c_row_col",
40    ];
41
42    /// The labels for the polynomials output by the AHP prover.
43    #[rustfmt::skip]
44    pub const PROVER_POLYNOMIALS: [&'static str; 9] = [
45        // First sumcheck
46        "w", "z_a", "z_b", "mask_poly", "t", "g_1", "h_1",
47        // Second sumcheck
48        "g_2", "h_2",
49    ];
50
51    /// THe linear combinations that are statically known to evaluate to zero.
52    pub const LC_WITH_ZERO_EVAL: [&'static str; 2] = ["inner_sumcheck", "outer_sumcheck"];
53
54    pub(crate) fn polynomial_labels() -> impl Iterator<Item = String> {
55        Self::INDEXER_POLYNOMIALS
56            .iter()
57            .chain(&Self::PROVER_POLYNOMIALS)
58            .map(|s| s.to_string())
59    }
60
61    /// Check that the (formatted) public input is of the form 2^n for some integer n.
62    pub fn num_formatted_public_inputs_is_admissible(num_inputs: usize) -> bool {
63        num_inputs.count_ones() == 1
64    }
65
66    /// Check that the (formatted) public input is of the form 2^n for some integer n.
67    pub fn formatted_public_input_is_admissible(input: &[F]) -> bool {
68        Self::num_formatted_public_inputs_is_admissible(input.len())
69    }
70
71    /// The maximum degree of polynomials produced by the indexer and prover
72    /// of this protocol.
73    /// The number of the variables must include the "one" variable. That is, it
74    /// must be with respect to the number of formatted public inputs.
75    pub fn max_degree(
76        num_constraints: usize,
77        num_variables: usize,
78        num_non_zero: usize,
79    ) -> Result<usize, Error> {
80        let padded_matrix_dim =
81            constraint_systems::padded_matrix_dim(num_variables, num_constraints);
82        let zk_bound = 1;
83        let domain_h_size = GeneralEvaluationDomain::<F>::compute_size_of_domain(padded_matrix_dim)
84            .ok_or(SynthesisError::PolynomialDegreeTooLarge)?;
85        let domain_k_size = GeneralEvaluationDomain::<F>::compute_size_of_domain(num_non_zero)
86            .ok_or(SynthesisError::PolynomialDegreeTooLarge)?;
87        Ok(*[
88            2 * domain_h_size + zk_bound - 2,
89            3 * domain_h_size + 2 * zk_bound - 3, //  mask_poly
90            domain_h_size,
91            domain_h_size,
92            3 * domain_k_size - 3,
93        ]
94        .iter()
95        .max()
96        .unwrap())
97    }
98
99    /// Get all the strict degree bounds enforced in the AHP.
100    pub fn get_degree_bounds(info: &indexer::IndexInfo<F>) -> [usize; 2] {
101        let mut degree_bounds = [0usize; 2];
102        let num_constraints = info.num_constraints;
103        let num_non_zero = info.num_non_zero;
104        let h_size = GeneralEvaluationDomain::<F>::compute_size_of_domain(num_constraints).unwrap();
105        let k_size = GeneralEvaluationDomain::<F>::compute_size_of_domain(num_non_zero).unwrap();
106
107        degree_bounds[0] = h_size - 2;
108        degree_bounds[1] = k_size - 2;
109        degree_bounds
110    }
111
112    /// Construct the linear combinations that are checked by the AHP.
113    #[allow(non_snake_case)]
114    pub fn construct_linear_combinations<E>(
115        public_input: &[F],
116        evals: &E,
117        state: &verifier::VerifierState<F>,
118    ) -> Result<Vec<LinearCombination<F>>, Error>
119    where
120        E: EvaluationsProvider<F>,
121    {
122        let domain_h = state.domain_h;
123        let domain_k = state.domain_k;
124        let k_size = domain_k.size_as_field_element();
125
126        let public_input = constraint_systems::format_public_input(public_input);
127        if !Self::formatted_public_input_is_admissible(&public_input) {
128            return Err(Error::InvalidPublicInputLength);
129        }
130        let x_domain = GeneralEvaluationDomain::new(public_input.len())
131            .ok_or(SynthesisError::PolynomialDegreeTooLarge)?;
132
133        let first_round_msg = state.first_round_msg.unwrap();
134        let alpha = first_round_msg.alpha;
135        let eta_a = first_round_msg.eta_a;
136        let eta_b = first_round_msg.eta_b;
137        let eta_c = first_round_msg.eta_c;
138
139        let beta = state.second_round_msg.unwrap().beta;
140        let gamma = state.gamma.unwrap();
141
142        let mut linear_combinations = Vec::new();
143
144        // Outer sumcheck:
145        let z_b = LinearCombination::new("z_b", vec![(F::one(), "z_b")]);
146        let g_1 = LinearCombination::new("g_1", vec![(F::one(), "g_1")]);
147        let t = LinearCombination::new("t", vec![(F::one(), "t")]);
148
149        let r_alpha_at_beta = domain_h.eval_unnormalized_bivariate_lagrange_poly(alpha, beta);
150        let v_H_at_alpha = domain_h.evaluate_vanishing_polynomial(alpha);
151        let v_H_at_beta = domain_h.evaluate_vanishing_polynomial(beta);
152        let v_X_at_beta = x_domain.evaluate_vanishing_polynomial(beta);
153
154        let z_b_at_beta = evals.get_lc_eval(&z_b, beta)?;
155        let t_at_beta = evals.get_lc_eval(&t, beta)?;
156        let g_1_at_beta = evals.get_lc_eval(&g_1, beta)?;
157
158        let x_at_beta = x_domain
159            .evaluate_all_lagrange_coefficients(beta)
160            .into_iter()
161            .zip(public_input)
162            .map(|(l, x)| l * &x)
163            .fold(F::zero(), |x, y| x + &y);
164
165        #[rustfmt::skip]
166        let outer_sumcheck = LinearCombination::new(
167            "outer_sumcheck",
168            vec![
169                (F::one(), "mask_poly".into()),
170
171                (r_alpha_at_beta * (eta_a + eta_c * z_b_at_beta), "z_a".into()),
172                (r_alpha_at_beta * eta_b * z_b_at_beta, LCTerm::One),
173
174                (-t_at_beta * v_X_at_beta, "w".into()),
175                (-t_at_beta * x_at_beta, LCTerm::One),
176
177                (-v_H_at_beta, "h_1".into()),
178                (-beta * g_1_at_beta, LCTerm::One),
179            ],
180        );
181        debug_assert!(evals.get_lc_eval(&outer_sumcheck, beta)?.is_zero());
182
183        linear_combinations.push(z_b);
184        linear_combinations.push(g_1);
185        linear_combinations.push(t);
186        linear_combinations.push(outer_sumcheck);
187
188        //  Inner sumcheck:
189        let beta_alpha = beta * alpha;
190        let g_2 = LinearCombination::new("g_2", vec![(F::one(), "g_2")]);
191
192        let a_denom = LinearCombination::new(
193            "a_denom",
194            vec![
195                (beta_alpha, LCTerm::One),
196                (-alpha, "a_row".into()),
197                (-beta, "a_col".into()),
198                (F::one(), "a_row_col".into()),
199            ],
200        );
201
202        let b_denom = LinearCombination::new(
203            "b_denom",
204            vec![
205                (beta_alpha, LCTerm::One),
206                (-alpha, "b_row".into()),
207                (-beta, "b_col".into()),
208                (F::one(), "b_row_col".into()),
209            ],
210        );
211
212        let c_denom = LinearCombination::new(
213            "c_denom",
214            vec![
215                (beta_alpha, LCTerm::One),
216                (-alpha, "c_row".into()),
217                (-beta, "c_col".into()),
218                (F::one(), "c_row_col".into()),
219            ],
220        );
221
222        let a_denom_at_gamma = evals.get_lc_eval(&a_denom, gamma)?;
223        let b_denom_at_gamma = evals.get_lc_eval(&b_denom, gamma)?;
224        let c_denom_at_gamma = evals.get_lc_eval(&c_denom, gamma)?;
225        let g_2_at_gamma = evals.get_lc_eval(&g_2, gamma)?;
226
227        let v_K_at_gamma = domain_k.evaluate_vanishing_polynomial(gamma);
228
229        let mut a = LinearCombination::new(
230            "a_poly",
231            vec![
232                (eta_a * b_denom_at_gamma * c_denom_at_gamma, "a_val"),
233                (eta_b * a_denom_at_gamma * c_denom_at_gamma, "b_val"),
234                (eta_c * b_denom_at_gamma * a_denom_at_gamma, "c_val"),
235            ],
236        );
237
238        a *= v_H_at_alpha * v_H_at_beta;
239        let b_at_gamma = a_denom_at_gamma * b_denom_at_gamma * c_denom_at_gamma;
240        let b_expr_at_gamma = b_at_gamma * (gamma * g_2_at_gamma + &(t_at_beta / &k_size));
241
242        a -= &LinearCombination::new("b_expr", vec![(b_expr_at_gamma, LCTerm::One)]);
243        a -= &LinearCombination::new("h_2", vec![(v_K_at_gamma, "h_2")]);
244
245        a.label = "inner_sumcheck".into();
246        let inner_sumcheck = a;
247        debug_assert!(evals.get_lc_eval(&inner_sumcheck, gamma)?.is_zero());
248
249        linear_combinations.push(g_2);
250        linear_combinations.push(a_denom);
251        linear_combinations.push(b_denom);
252        linear_combinations.push(c_denom);
253        linear_combinations.push(inner_sumcheck);
254
255        linear_combinations.sort_by(|a, b| a.label.cmp(&b.label));
256        Ok(linear_combinations)
257    }
258}
259
/// Abstraction that provides evaluations of (linear combinations of) polynomials.
///
/// Intended to provide a common interface for both the prover and the verifier
/// when constructing linear combinations via `AHPForR1CS::construct_linear_combinations`.
///
/// This module implements it for `ark_poly_commit::Evaluations<F, F>`, which
/// looks up precomputed values, and for `Vec<T>` of labeled polynomials, which
/// evaluates them directly.
pub trait EvaluationsProvider<F: Field> {
    /// Get the evaluation of linear combination `lc` at `point`.
    ///
    /// Implementations in this module return `Error::MissingEval` when the
    /// required evaluation or polynomial is not available.
    fn get_lc_eval(&self, lc: &LinearCombination<F>, point: F) -> Result<F, Error>;
}
268
269impl<'a, F: Field> EvaluationsProvider<F> for ark_poly_commit::Evaluations<F, F> {
270    fn get_lc_eval(&self, lc: &LinearCombination<F>, point: F) -> Result<F, Error> {
271        let key = (lc.label.clone(), point);
272        self.get(&key)
273            .map(|v| *v)
274            .ok_or(Error::MissingEval(lc.label.clone()))
275    }
276}
277
278impl<F: Field, T: Borrow<LabeledPolynomial<F>>> EvaluationsProvider<F> for Vec<T> {
279    fn get_lc_eval(&self, lc: &LinearCombination<F>, point: F) -> Result<F, Error> {
280        let mut eval = F::zero();
281        for (coeff, term) in lc.iter() {
282            let value = if let LCTerm::PolyLabel(label) = term {
283                self.iter()
284                    .find(|p| {
285                        let p: &LabeledPolynomial<F> = (*p).borrow();
286                        p.label() == label
287                    })
288                    .ok_or(Error::MissingEval(format!(
289                        "Missing {} for {}",
290                        label, lc.label
291                    )))?
292                    .borrow()
293                    .evaluate(&point)
294            } else {
295                assert!(term.is_one());
296                F::one()
297            };
298            eval += *coeff * value
299        }
300        Ok(eval)
301    }
302}
303
/// Describes the failure modes of the AHP scheme.
#[derive(Debug)]
pub enum Error {
    /// A required evaluation (or polynomial) is missing.
    ///
    /// The payload identifies what was missing. It is either the label of a
    /// linear combination or a "Missing <poly> for <lc>" message, depending on
    /// the `EvaluationsProvider` implementation that produced it.
    MissingEval(String),
    /// The number of public inputs is incorrect.
    ///
    /// `construct_linear_combinations` returns this when the formatted public
    /// input does not have a power-of-two length.
    InvalidPublicInputLength,
    /// The instance generated during proving does not match that in the index.
    InstanceDoesNotMatchIndex,
    /// Currently we only support square constraint matrices.
    NonSquareMatrix,
    /// An error occurred during constraint generation.
    ///
    /// Produced from a `SynthesisError` via `From`, e.g. by the `?` operator.
    ConstraintSystemError(SynthesisError),
}
318
319impl From<SynthesisError> for Error {
320    fn from(other: SynthesisError) -> Self {
321        Error::ConstraintSystemError(other)
322    }
323}
324
325/// The derivative of the vanishing polynomial
326pub trait UnnormalizedBivariateLagrangePoly<F: ark_ff::FftField> {
327    /// Evaluate the polynomial
328    fn eval_unnormalized_bivariate_lagrange_poly(&self, x: F, y: F) -> F;
329
330    /// Evaluate over a batch of inputs
331    fn batch_eval_unnormalized_bivariate_lagrange_poly_with_diff_inputs(&self, x: F) -> Vec<F>;
332
333    /// Evaluate the magic polynomial over `self`
334    fn batch_eval_unnormalized_bivariate_lagrange_poly_with_same_inputs(&self) -> Vec<F>;
335}
336
337impl<F: PrimeField> UnnormalizedBivariateLagrangePoly<F> for GeneralEvaluationDomain<F> {
338    fn eval_unnormalized_bivariate_lagrange_poly(&self, x: F, y: F) -> F {
339        if x != y {
340            (self.evaluate_vanishing_polynomial(x) - self.evaluate_vanishing_polynomial(y))
341                / (x - y)
342        } else {
343            self.size_as_field_element() * x.pow(&[(self.size() - 1) as u64])
344        }
345    }
346
347    fn batch_eval_unnormalized_bivariate_lagrange_poly_with_diff_inputs(&self, x: F) -> Vec<F> {
348        let vanish_x = self.evaluate_vanishing_polynomial(x);
349        let mut inverses: Vec<F> = self.elements().map(|y| x - y).collect();
350        ark_ff::batch_inversion(&mut inverses);
351
352        cfg_iter_mut!(inverses).for_each(|denominator| *denominator *= vanish_x);
353        inverses
354    }
355
356    fn batch_eval_unnormalized_bivariate_lagrange_poly_with_same_inputs(&self) -> Vec<F> {
357        let mut elems: Vec<F> = self
358            .elements()
359            .map(|e| e * self.size_as_field_element())
360            .collect();
361        elems[1..].reverse();
362        elems
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use ark_bls12_381::Fr;
    use ark_ff::{One, UniformRand, Zero};
    use ark_poly::{
        univariate::{DenseOrSparsePolynomial, DensePolynomial},
        Polynomial, UVPolynomial,
    };

    /// The fast diagonal batch evaluation must agree with evaluating
    /// `u_H(h, h)` point by point, for domains of size 2^1 through 2^9.
    #[test]
    fn domain_unnormalized_bivariate_lagrange_poly() {
        for domain_size in 1..10 {
            let domain = GeneralEvaluationDomain::<Fr>::new(1 << domain_size).unwrap();
            let manual: Vec<_> = domain
                .elements()
                .map(|elem| domain.eval_unnormalized_bivariate_lagrange_poly(elem, elem))
                .collect();
            let fast = domain.batch_eval_unnormalized_bivariate_lagrange_poly_with_same_inputs();
            assert_eq!(fast, manual);
        }
    }

    /// The fast batch evaluation of `u_H(x, h)` for a random `x` must agree
    /// with evaluating each `h` individually, for domains of size 2^1 through 2^9.
    #[test]
    fn domain_unnormalized_bivariate_lagrange_poly_diff_inputs() {
        let rng = &mut ark_std::test_rng();
        for domain_size in 1..10 {
            let domain = GeneralEvaluationDomain::<Fr>::new(1 << domain_size).unwrap();
            let x = Fr::rand(rng);
            let manual: Vec<_> = domain
                .elements()
                .map(|y| domain.eval_unnormalized_bivariate_lagrange_poly(x, y))
                .collect();
            let fast = domain.batch_eval_unnormalized_bivariate_lagrange_poly_with_diff_inputs(x);
            assert_eq!(fast, manual);
        }
    }

    /// Summing a random polynomial over a domain of size 16 must equal
    /// `|H| * (a_0 + a_last)`.
    #[test]
    fn test_summation() {
        let rng = &mut ark_std::test_rng();
        let size = 1 << 4;
        let domain = GeneralEvaluationDomain::<Fr>::new(1 << 4).unwrap();
        let size_as_fe = domain.size_as_field_element();
        // NOTE(review): the assertion below relies on `rand(size, _)` producing a
        // polynomial of degree `size` (i.e. `size + 1` coefficients), so that
        // `coeffs.last()` is the X^16 coefficient. Confirm against ark_poly.
        let poly = DensePolynomial::rand(size, rng);

        let mut sum: Fr = Fr::zero();
        for eval in domain.elements().map(|e| poly.evaluate(&e)) {
            sum += eval;
        }
        // Assuming H is a multiplicative subgroup: sum_h h^i is |H| when |H|
        // divides i and 0 otherwise. So only a_0 and a_16 contribute.
        let first = poly.coeffs[0] * size_as_fe;
        let last = *poly.coeffs.last().unwrap() * size_as_fe;
        println!("sum: {:?}", sum);
        println!("a_0: {:?}", first);
        println!("a_n: {:?}", last);
        println!("first + last: {:?}\n", first + last);
        assert_eq!(sum, first + last);
    }

    /// Exploratory test with H (size 8) and K (size 16). It builds the
    /// "alternator" polynomial, which is 1 on the elements of H and 0 on the
    /// rest of K. It checks that v_H divides v_K and that p = v_K / v_H divides
    /// the alternator. Intermediate values are printed rather than asserted.
    #[test]
    fn test_alternator_polynomial() {
        use ark_poly::Evaluations;
        let domain_k = GeneralEvaluationDomain::<Fr>::new(1 << 4).unwrap();
        let domain_h = GeneralEvaluationDomain::<Fr>::new(1 << 3).unwrap();
        let domain_h_elems = domain_h
            .elements()
            .collect::<std::collections::HashSet<_>>();
        // Evaluations over K: 1 where the point also lies in H, else 0.
        let alternator_poly_evals = domain_k
            .elements()
            .map(|e| {
                if domain_h_elems.contains(&e) {
                    Fr::one()
                } else {
                    Fr::zero()
                }
            })
            .collect();
        let v_k: DenseOrSparsePolynomial<_> = domain_k.vanishing_polynomial().into();
        let v_h: DenseOrSparsePolynomial<_> = domain_h.vanishing_polynomial().into();
        // p = v_K / v_H; the zero remainder confirms v_H divides v_K exactly.
        let (divisor, remainder) = v_k.divide_with_q_and_r(&v_h).unwrap();
        assert!(remainder.is_zero());
        println!("Divisor: {:?}", divisor);
        println!(
            "{:#?}",
            divisor
                .coeffs
                .iter()
                .filter_map(|f| if !f.is_zero() {
                    Some(f.into_repr())
                } else {
                    None
                })
                .collect::<Vec<_>>()
        );

        // Print p(h) for every h in H (expected to be constant; see below).
        for e in domain_h.elements() {
            println!("{:?}", divisor.evaluate(&e));
        }
        // Let p = v_K / v_H;
        // The alternator polynomial is p * t, where t is defined as
        // the LDE of p(h)^{-1} for all h in H.
        //
        // Because for each h in H, p(h) equals a constant c, we have that t
        // is the constant polynomial c^{-1}.
        //
        // Q: what is the constant c? Why is p(h) constant? What is the easiest
        // way to calculate c?
        //
        // A (derivation, assuming v_K = X^|K| - 1 and v_H = X^|H| - 1, as for
        // multiplicative subgroups with H inside K): with m = |K| / |H|,
        // p = (X^(m|H|) - 1) / (X^|H| - 1) = sum_{i=0}^{m-1} X^(i|H|).
        // Every h in H satisfies h^|H| = 1, so each of the m terms equals 1 and
        // p(h) = m = |K| / |H| for all h in H. Here m = 16 / 8 = 2 and p = X^8 + 1.
        // So c = |K| / |H| can be computed without any polynomial arithmetic.
        let alternator_poly =
            Evaluations::from_vec_and_domain(alternator_poly_evals, domain_k).interpolate();
        // A zero remainder confirms that p divides the alternator polynomial.
        let (quotient, remainder) = DenseOrSparsePolynomial::from(alternator_poly.clone())
            .divide_with_q_and_r(&DenseOrSparsePolynomial::from(divisor))
            .unwrap();
        assert!(remainder.is_zero());
        println!("quotient: {:?}", quotient);
        println!(
            "{:#?}",
            quotient
                .coeffs
                .iter()
                .filter_map(|f| if !f.is_zero() {
                    Some(f.into_repr())
                } else {
                    None
                })
                .collect::<Vec<_>>()
        );

        println!("{:?}", alternator_poly);
    }
}