Skip to main content

ark_r1cs_std/boolean/
eq.rs

1use ark_relations::gr1cs::SynthesisError;
2
3use crate::{boolean::Boolean, eq::EqGadget};
4
5use super::*;
6
7impl<F: Field> EqGadget<F> for Boolean<F> {
8    #[tracing::instrument(target = "gr1cs")]
9    fn is_eq(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
10        // self | other | XNOR(self, other) | self == other
11        // -----|-------|-------------------|--------------
12        //   0  |   0   |         1         |      1
13        //   0  |   1   |         0         |      0
14        //   1  |   0   |         0         |      0
15        //   1  |   1   |         1         |      1
16        Ok(!(self ^ other))
17    }
18
19    #[tracing::instrument(target = "gr1cs")]
20    fn conditional_enforce_equal(
21        &self,
22        other: &Self,
23        condition: &Boolean<F>,
24    ) -> Result<(), SynthesisError> {
25        use Boolean::*;
26        let one = Variable::One;
27        // We will use the following trick: a == b <=> a - b == 0
28        // This works because a - b == 0 if and only if a = 0 and b = 0, or a = 1 and b
29        // = 1, which is exactly the definition of a == b.
30
31        if condition != &Constant(false) {
32            let cs = self.cs().or(other.cs()).or(condition.cs());
33            match (self, other) {
34                // 1 == 1; 0 == 0
35                (Constant(true), Constant(true)) | (Constant(false), Constant(false)) => {
36                    return Ok(())
37                },
38                // false != true
39                (Constant(_), Constant(_)) => return Err(SynthesisError::Unsatisfiable),
40                // handled below
41                (_, _) => (),
42            };
43            let difference = || match (self, other) {
44                // 1 - a
45                (Constant(true), Var(a)) | (Var(a), Constant(true)) => {
46                    lc_diff![one, a.variable()]
47                },
48                // a - 0 = a
49                (Constant(false), Var(a)) | (Var(a), Constant(false)) => a.variable().into(),
50                // b - a,
51                (Var(a), Var(b)) => lc_diff![b.variable(), a.variable()],
52                // handled above
53                (_, _) => unreachable!(),
54            };
55            cs.enforce_r1cs_constraint(difference, || condition.lc(), || lc!())?;
56        }
57        Ok(())
58    }
59
60    #[tracing::instrument(target = "gr1cs")]
61    fn conditional_enforce_not_equal(
62        &self,
63        other: &Self,
64        should_enforce: &Boolean<F>,
65    ) -> Result<(), SynthesisError> {
66        use Boolean::*;
67        let one = Variable::One;
68
69        if should_enforce != &Constant(false) {
70            let cs = self.cs().or(other.cs()).or(should_enforce.cs());
71            // We will use the following trick: a != b <=> a + b == 1
72            // This works because a + b == 1 if and only if a = 0 and b = 1, or a = 1 and b
73            // = 0, which is exactly the definition of a != b.
74            match (self, other) {
75                // 1 != 0; 0 != 1
76                (Constant(true), Constant(false)) | (Constant(false), Constant(true)) => {
77                    return Ok(())
78                },
79                // false == false and true == true
80                (Constant(_), Constant(_)) => return Err(SynthesisError::Unsatisfiable),
81                (_, _) => (),
82            }
83            let sum = || match (self, other) {
84                // 1 + a
85                (Constant(true), Var(a)) | (Var(a), Constant(true)) => {
86                    lc![one, a.variable()]
87                },
88                // a + 0 = a
89                (Constant(false), Var(a)) | (Var(a), Constant(false)) => a.variable().into(),
90                // b + a,
91                (Var(a), Var(b)) => lc![b.variable(), a.variable()],
92                // handled above
93                (_, _) => unreachable!(),
94            };
95            cs.enforce_r1cs_constraint(sum, || should_enforce.lc(), || one.into())?;
96        }
97        Ok(())
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::{AllocVar, AllocationMode},
        boolean::test_utils::{run_binary_exhaustive, run_unary_exhaustive},
        prelude::EqGadget,
        GR1CSVar,
    };
    use ark_test_curves::bls12_381::Fr;

    /// Allocation mode for a reference ("expected") value.
    ///
    /// When every input is a constant there is no constraint system to allocate into,
    /// so the expected value must be a constant too; otherwise it is a witness.
    fn expected_mode(both_constant: bool) -> AllocationMode {
        if both_constant {
            AllocationMode::Constant
        } else {
            AllocationMode::Witness
        }
    }

    /// `is_eq` agrees with native `==` and its output is correctly constrained.
    #[test]
    fn eq() {
        run_binary_exhaustive::<Fr>(|a, b| {
            let cs = a.cs().or(b.cs());
            let both_constant = a.is_constant() && b.is_constant();
            let computed = a.is_eq(&b)?;
            let expected = Boolean::new_variable(
                cs.clone(),
                || Ok(a.value()? == b.value()?),
                expected_mode(both_constant),
            )?;
            assert_eq!(expected.value(), computed.value());
            expected.enforce_equal(&computed)?;
            if !both_constant {
                assert!(cs.is_satisfied().unwrap());
            }
            Ok(())
        })
        .unwrap()
    }

    /// `is_neq` agrees with native `!=` and its output is correctly constrained.
    #[test]
    fn neq() {
        run_binary_exhaustive::<Fr>(|a, b| {
            let cs = a.cs().or(b.cs());
            let both_constant = a.is_constant() && b.is_constant();
            let computed = a.is_neq(&b)?;
            let expected = Boolean::new_variable(
                cs.clone(),
                || Ok(a.value()? != b.value()?),
                expected_mode(both_constant),
            )?;
            assert_eq!(expected.value(), computed.value());
            expected.enforce_equal(&computed)?;
            if !both_constant {
                assert!(cs.is_satisfied().unwrap());
            }
            Ok(())
        })
        .unwrap()
    }

    /// `is_neq` is exactly the negation of `is_eq`, both natively and in-circuit.
    #[test]
    fn neq_and_eq_consistency() {
        run_binary_exhaustive::<Fr>(|a, b| {
            let cs = a.cs().or(b.cs());
            let both_constant = a.is_constant() && b.is_constant();
            let is_neq = a.is_neq(&b)?;
            let is_eq = a.is_eq(&b)?;
            let expected_is_neq = Boolean::new_variable(
                cs.clone(),
                || Ok(a.value()? != b.value()?),
                expected_mode(both_constant),
            )?;
            assert_eq!(expected_is_neq.value(), is_neq.value());
            assert_ne!(expected_is_neq.value(), is_eq.value());
            expected_is_neq.enforce_equal(&is_neq)?;
            expected_is_neq.enforce_equal(&!&is_eq)?;
            expected_is_neq.enforce_not_equal(&is_eq)?;
            if !both_constant {
                assert!(cs.is_satisfied().unwrap());
            }
            Ok(())
        })
        .unwrap()
    }

    /// `enforce_equal` / `enforce_not_equal` accept a value, its copy and its negation.
    #[test]
    fn enforce_eq_and_enforce_neq_consistency() {
        run_unary_exhaustive::<Fr>(|a| {
            let cs = a.cs();
            let not_a = !&a;
            a.enforce_equal(&a)?;
            not_a.enforce_equal(&not_a)?;
            a.enforce_not_equal(&not_a)?;
            not_a.enforce_not_equal(&a)?;
            if !a.is_constant() {
                assert!(cs.is_satisfied().unwrap());
            }
            Ok(())
        })
        .unwrap()
    }

    /// The output of `is_eq` differs from the (deliberately inverted) `!=` reference.
    #[test]
    fn eq_soundness() {
        run_binary_exhaustive::<Fr>(|a, b| {
            let cs = a.cs().or(b.cs());
            let both_constant = a.is_constant() && b.is_constant();
            let computed = a.is_eq(&b)?;
            let expected = Boolean::new_variable(
                cs.clone(),
                || Ok(a.value()? != b.value()?),
                expected_mode(both_constant),
            )?;
            assert_ne!(expected.value(), computed.value());
            expected.enforce_not_equal(&computed)?;
            if !both_constant {
                assert!(cs.is_satisfied().unwrap());
            }
            Ok(())
        })
        .unwrap()
    }

    /// The output of `is_neq` differs from the (deliberately inverted) `==` reference.
    #[test]
    fn neq_soundness() {
        run_binary_exhaustive::<Fr>(|a, b| {
            let cs = a.cs().or(b.cs());
            let both_constant = a.is_constant() && b.is_constant();
            let computed = a.is_neq(&b)?;
            let expected = Boolean::new_variable(
                cs.clone(),
                || Ok(a.value()? == b.value()?),
                expected_mode(both_constant),
            )?;
            assert_ne!(expected.value(), computed.value());
            expected.enforce_not_equal(&computed)?;
            if !both_constant {
                assert!(cs.is_satisfied().unwrap());
            }
            Ok(())
        })
        .unwrap()
    }
}