Skip to main content

ark_relations/gr1cs/predicate/
mod.rs

//! This module contains the implementation of a general predicate, as defined in <https://eprint.iacr.org/2024/1245>.
//! A predicate is a function from `t` (the arity) variables to a boolean.
//! A predicate can be as simple as `f(a, b, c) = a * b - c == 0` or as
//! complex as a lookup table.
5
6pub mod polynomial_constraint;
7
8use super::{Constraint, ConstraintSystem, Matrix};
9use crate::{gr1cs::Variable, utils::error::SynthesisError::ArityMismatch};
10use ark_ff::Field;
11use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError};
12use ark_std::{io::Write, vec::Vec};
13use polynomial_constraint::PolynomialPredicate;
14
15/// GR1CS can potentially support different types of predicates
16/// For now, we only support polynomial predicates
17/// In the future, we can add other types of predicates, e.g. lookup table
18#[derive(Debug, Clone)]
19#[non_exhaustive]
20pub enum Predicate<F: Field> {
21    /// A polynomial local predicate. This is the most common predicate that
22    /// captures high-degree custom gates
23    Polynomial(PolynomialPredicate<F>),
24    // Add other predicates in the future, e.g. lookup table
25}
26
27impl<F: Field> ark_serialize::Valid for Predicate<F> {
28    fn check(&self) -> Result<(), SerializationError> {
29        match self {
30            Predicate::Polynomial(p) => p.check(),
31        }
32    }
33}
34impl<F: Field> CanonicalDeserialize for Predicate<F> {
35    fn deserialize_with_mode<R: ark_serialize::Read>(
36        reader: R,
37        compress: Compress,
38        should_validate: ark_serialize::Validate,
39    ) -> Result<Self, SerializationError> {
40        let predicate_type =
41            PolynomialPredicate::<F>::deserialize_with_mode(reader, compress, should_validate)?;
42        Ok(Predicate::Polynomial(predicate_type))
43    }
44}
45
46impl<F: Field> CanonicalSerialize for Predicate<F> {
47    fn serialize_with_mode<W: Write>(
48        &self,
49        writer: W,
50        compress: Compress,
51    ) -> Result<(), SerializationError> {
52        match self {
53            Predicate::Polynomial(p) => p.serialize_with_mode(writer, compress),
54        }
55    }
56    fn serialized_size(&self, compress: Compress) -> usize {
57        match self {
58            Predicate::Polynomial(p) => p.serialized_size(compress),
59        }
60    }
61}
62
63impl<F: Field> Predicate<F> {
64    fn is_satisfied(&self, variables: &[F]) -> bool {
65        match self {
66            Predicate::Polynomial(p) => p.is_satisfied(variables),
67            // TODO: Add other predicates in the future, e.g. lookup table
68        }
69    }
70
71    fn arity(&self) -> usize {
72        match self {
73            Predicate::Polynomial(p) => p.arity(),
74            // TODO: Add other predicates in the future, e.g. lookup table
75        }
76    }
77}
78
79/// A constraint system that enforces a predicate
80#[derive(Debug, Clone)]
81pub struct PredicateConstraintSystem<F: Field> {
82    /// The inputs to the predicates.
83    /// The length of this list is equal to the arity of the predicate.
84    /// That is, `argument_lcs[i]` is the list of inputs to the `i`-th
85    /// argument of the predicate.
86    /// For each `i`, `argument_lcs[i]` has size equal to `self.num_constraints`.
87    argument_lcs: Vec<Vec<Variable>>,
88
89    /// The number of constraints enforced by this predicate.
90    num_constraints: usize,
91
92    /// The type of the predicate enforced by this constraint system.  
93    predicate: Predicate<F>,
94}
95
96impl<F: Field> PredicateConstraintSystem<F> {
97    /// Create a new predicate constraint system with a specific predicate
98    fn new(predicate: Predicate<F>) -> Self {
99        Self {
100            argument_lcs: vec![Vec::new(); predicate.arity()],
101            predicate,
102            num_constraints: 0,
103        }
104    }
105
106    /// Create new polynomial predicate constraint system
107    pub fn new_polynomial_predicate_cs(arity: usize, terms: Vec<(F, Vec<(usize, usize)>)>) -> Self {
108        Self::new(Predicate::Polynomial(PolynomialPredicate::new(
109            arity, terms,
110        )))
111    }
112
113    /// creates an R1CS predicate which is a special kind of polynomial
114    /// predicate
115    pub fn new_r1cs() -> crate::utils::Result<Self> {
116        Ok(Self::new_polynomial_predicate_cs(
117            3,
118            vec![(F::ONE, vec![(0, 1), (1, 1)]), (-F::ONE, vec![(2, 1)])],
119        ))
120    }
121
122    /// Creates a SquareR1CS predicate.
123    pub fn new_sr1cs_predicate() -> crate::utils::Result<Self> {
124        Ok(Self::new_polynomial_predicate_cs(
125            2,
126            vec![(F::ONE, vec![(0, 2)]), (-F::ONE, vec![(1, 1)])],
127        ))
128    }
129
130    /// Get the arity of the predicate of this [`PredicateConstraintSystem`].
131    pub fn get_arity(&self) -> usize {
132        self.predicate.arity()
133    }
134
135    /// Get the number of constraints enforced by this predicate.
136    pub fn num_constraints(&self) -> usize {
137        self.num_constraints
138    }
139
140    /// Get a list of constraints enforced in this [`PredicateConstraintSystem`].
141    /// Each constraint is a list of linear combinations with size equal to the
142    /// arity.
143    pub fn get_constraints(&self) -> &Vec<Constraint> {
144        &self.argument_lcs
145    }
146
147    /// Get a reference to the  predicate in this predicate constraint
148    /// system
149    pub fn get_predicate(&self) -> &Predicate<F> {
150        &self.predicate
151    }
152
153    /// Enforce a constraint in this [`PredicateConstraintSystem`].
154    /// The constraint is a list of linear combinations with size equal to the
155    /// arity.
156    pub fn enforce_constraint(
157        &mut self,
158        constraint: impl IntoIterator<Item = Variable>,
159    ) -> crate::utils::Result<()> {
160        let mut arity = 0;
161        constraint
162            .into_iter()
163            .zip(&mut self.argument_lcs)
164            .for_each(|(lc_index, arg_lc)| {
165                arity += 1;
166                arg_lc.push(lc_index);
167            });
168        if arity != self.get_arity() {
169            return Err(ArityMismatch);
170        }
171
172        self.num_constraints += 1;
173        Ok(())
174    }
175
176    fn iter_constraints(&self) -> impl Iterator<Item = Constraint> + '_ {
177        // Transpose the `argument_lcs` to iterate over constraints.
178        let num_constraints = self.num_constraints;
179
180        (0..num_constraints).map(move |i| self.argument_lcs.iter().map(|lc_s| lc_s[i]).collect())
181    }
182
183    /// Check if the constraints enforced by this predicate are satisfied
184    /// i.e. `L(x_1, x_2, ..., x_n) == 0`.
185    pub fn which_constraint_is_unsatisfied(&self, cs: &ConstraintSystem<F>) -> Option<usize> {
186        let panic_msg = |v| panic!("Variable {v:?} is not assigned; did you run `cs.finalize()`?");
187        for (i, constraint) in self.iter_constraints().enumerate() {
188            let variables: Vec<F> = constraint
189                .into_iter()
190                .map(|v| {
191                    cs.assigned_value(v).unwrap_or_else(|| {
192                        cs.get_lc(v)
193                            .iter()
194                            .map(|&(c, v)| c * cs.assigned_value(v).unwrap_or_else(|| panic_msg(v)))
195                            .sum()
196                    })
197                })
198                .collect();
199            if !self.predicate.is_satisfied(&variables) {
200                return Some(i);
201            }
202        }
203        None
204    }
205
206    /// Create the set of matrices for this [`PredicateConstraintSystem`].
207    pub fn to_matrices(&self, cs: &ConstraintSystem<F>) -> Vec<Matrix<F>> {
208        let mut matrices: Vec<Matrix<F>> = vec![Vec::new(); self.get_arity()];
209        for constraint in self.iter_constraints() {
210            for (matrix_ind, lc_index) in constraint.iter().enumerate() {
211                let lc = cs.get_lc(*lc_index);
212                let row = cs.make_row(lc);
213                matrices[matrix_ind].push(row);
214            }
215        }
216        matrices
217    }
218}