1use crate::prelude::*;
2use ark_ff::{Field, PrimeField};
3use ark_relations::gr1cs::SynthesisError;
4use ark_std::vec::Vec;
5
6pub trait EqGadget<F: Field> {
9 fn is_eq(&self, other: &Self) -> Result<Boolean<F>, SynthesisError>;
12
13 fn is_neq(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
18 Ok(!self.is_eq(other)?)
19 }
20
21 #[tracing::instrument(target = "gr1cs", skip(self, other))]
31 fn conditional_enforce_equal(
32 &self,
33 other: &Self,
34 should_enforce: &Boolean<F>,
35 ) -> Result<(), SynthesisError> {
36 self.is_eq(&other)?
37 .conditional_enforce_equal(&Boolean::TRUE, should_enforce)
38 }
39
40 #[tracing::instrument(target = "gr1cs", skip(self, other))]
49 fn enforce_equal(&self, other: &Self) -> Result<(), SynthesisError> {
50 self.conditional_enforce_equal(other, &Boolean::TRUE)
51 }
52
53 #[tracing::instrument(target = "gr1cs", skip(self, other))]
63 fn conditional_enforce_not_equal(
64 &self,
65 other: &Self,
66 should_enforce: &Boolean<F>,
67 ) -> Result<(), SynthesisError> {
68 self.is_neq(&other)?
69 .conditional_enforce_equal(&Boolean::TRUE, should_enforce)
70 }
71
72 #[tracing::instrument(target = "gr1cs", skip(self, other))]
81 fn enforce_not_equal(&self, other: &Self) -> Result<(), SynthesisError> {
82 self.conditional_enforce_not_equal(other, &Boolean::TRUE)
83 }
84}
85
86impl<T: EqGadget<F> + GR1CSVar<F>, F: PrimeField> EqGadget<F> for [T] {
87 #[tracing::instrument(target = "gr1cs", skip(self, other))]
88 fn is_eq(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
89 assert_eq!(self.len(), other.len());
90 if self.is_empty() & other.is_empty() {
91 Ok(Boolean::TRUE)
92 } else {
93 let mut results = Vec::with_capacity(self.len());
94 for (a, b) in self.iter().zip(other) {
95 results.push(a.is_eq(b)?);
96 }
97 Boolean::kary_and(&results)
98 }
99 }
100
101 #[tracing::instrument(target = "gr1cs", skip(self, other))]
102 fn conditional_enforce_equal(
103 &self,
104 other: &Self,
105 condition: &Boolean<F>,
106 ) -> Result<(), SynthesisError> {
107 assert_eq!(self.len(), other.len());
108 for (a, b) in self.iter().zip(other) {
109 a.conditional_enforce_equal(b, condition)?;
110 }
111 Ok(())
112 }
113
114 #[tracing::instrument(target = "gr1cs", skip(self, other))]
115 fn conditional_enforce_not_equal(
116 &self,
117 other: &Self,
118 should_enforce: &Boolean<F>,
119 ) -> Result<(), SynthesisError> {
120 assert_eq!(self.len(), other.len());
121 let some_are_different = self.is_neq(other)?;
122 if [&some_are_different, should_enforce].is_constant() {
123 assert!(some_are_different.value()?);
124 Ok(())
125 } else {
126 let cs = [&some_are_different, should_enforce].cs();
127 cs.enforce_r1cs_constraint(
128 || some_are_different.lc(),
129 || should_enforce.variable().into(),
130 || should_enforce.variable().into(),
131 )
132 }
133 }
134}
135
136impl<T: EqGadget<F> + GR1CSVar<F>, F: PrimeField> EqGadget<F> for Vec<T> {
138 #[tracing::instrument(target = "gr1cs", skip(self, other))]
139 fn is_eq(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
140 self.as_slice().is_eq(other.as_slice())
141 }
142
143 #[tracing::instrument(target = "gr1cs", skip(self, other))]
144 fn conditional_enforce_equal(
145 &self,
146 other: &Self,
147 condition: &Boolean<F>,
148 ) -> Result<(), SynthesisError> {
149 self.as_slice()
150 .conditional_enforce_equal(other.as_slice(), condition)
151 }
152
153 #[tracing::instrument(target = "gr1cs", skip(self, other))]
154 fn conditional_enforce_not_equal(
155 &self,
156 other: &Self,
157 should_enforce: &Boolean<F>,
158 ) -> Result<(), SynthesisError> {
159 self.as_slice()
160 .conditional_enforce_not_equal(other.as_slice(), should_enforce)
161 }
162}
163
164impl<F: Field> EqGadget<F> for () {
166 #[inline]
169 fn is_eq(&self, _other: &Self) -> Result<Boolean<F>, SynthesisError> {
170 Ok(Boolean::TRUE)
171 }
172
173 #[tracing::instrument(target = "gr1cs", skip(self, _other))]
178 fn conditional_enforce_equal(
179 &self,
180 _other: &Self,
181 _should_enforce: &Boolean<F>,
182 ) -> Result<(), SynthesisError> {
183 Ok(())
184 }
185
186 #[tracing::instrument(target = "gr1cs", skip(self, _other))]
191 fn enforce_equal(&self, _other: &Self) -> Result<(), SynthesisError> {
192 Ok(())
193 }
194}
195
196impl<T: EqGadget<F> + GR1CSVar<F>, F: PrimeField, const N: usize> EqGadget<F> for [T; N] {
198 #[tracing::instrument(target = "gr1cs", skip(self, other))]
199 fn is_eq(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
200 self.as_slice().is_eq(other.as_slice())
201 }
202
203 #[tracing::instrument(target = "gr1cs", skip(self, other))]
204 fn conditional_enforce_equal(
205 &self,
206 other: &Self,
207 condition: &Boolean<F>,
208 ) -> Result<(), SynthesisError> {
209 self.as_slice()
210 .conditional_enforce_equal(other.as_slice(), condition)
211 }
212
213 #[tracing::instrument(target = "gr1cs", skip(self, other))]
214 fn conditional_enforce_not_equal(
215 &self,
216 other: &Self,
217 should_enforce: &Boolean<F>,
218 ) -> Result<(), SynthesisError> {
219 self.as_slice()
220 .conditional_enforce_not_equal(other.as_slice(), should_enforce)
221 }
222}