1use crate::{prelude::*, Vec};
2use ark_ff::{Field, PrimeField};
3use ark_relations::r1cs::SynthesisError;
4
5pub trait EqGadget<F: Field> {
8 fn is_eq(&self, other: &Self) -> Result<Boolean<F>, SynthesisError>;
11
12 fn is_neq(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
17 Ok(!self.is_eq(other)?)
18 }
19
20 #[tracing::instrument(target = "r1cs", skip(self, other))]
30 fn conditional_enforce_equal(
31 &self,
32 other: &Self,
33 should_enforce: &Boolean<F>,
34 ) -> Result<(), SynthesisError> {
35 self.is_eq(&other)?
36 .conditional_enforce_equal(&Boolean::TRUE, should_enforce)
37 }
38
39 #[tracing::instrument(target = "r1cs", skip(self, other))]
48 fn enforce_equal(&self, other: &Self) -> Result<(), SynthesisError> {
49 self.conditional_enforce_equal(other, &Boolean::TRUE)
50 }
51
52 #[tracing::instrument(target = "r1cs", skip(self, other))]
62 fn conditional_enforce_not_equal(
63 &self,
64 other: &Self,
65 should_enforce: &Boolean<F>,
66 ) -> Result<(), SynthesisError> {
67 self.is_neq(&other)?
68 .conditional_enforce_equal(&Boolean::TRUE, should_enforce)
69 }
70
71 #[tracing::instrument(target = "r1cs", skip(self, other))]
80 fn enforce_not_equal(&self, other: &Self) -> Result<(), SynthesisError> {
81 self.conditional_enforce_not_equal(other, &Boolean::TRUE)
82 }
83}
84
85impl<T: EqGadget<F> + R1CSVar<F>, F: PrimeField> EqGadget<F> for [T] {
86 #[tracing::instrument(target = "r1cs", skip(self, other))]
87 fn is_eq(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
88 assert_eq!(self.len(), other.len());
89 if self.is_empty() & other.is_empty() {
90 Ok(Boolean::TRUE)
91 } else {
92 let mut results = Vec::with_capacity(self.len());
93 for (a, b) in self.iter().zip(other) {
94 results.push(a.is_eq(b)?);
95 }
96 Boolean::kary_and(&results)
97 }
98 }
99
100 #[tracing::instrument(target = "r1cs", skip(self, other))]
101 fn conditional_enforce_equal(
102 &self,
103 other: &Self,
104 condition: &Boolean<F>,
105 ) -> Result<(), SynthesisError> {
106 assert_eq!(self.len(), other.len());
107 for (a, b) in self.iter().zip(other) {
108 a.conditional_enforce_equal(b, condition)?;
109 }
110 Ok(())
111 }
112
113 #[tracing::instrument(target = "r1cs", skip(self, other))]
114 fn conditional_enforce_not_equal(
115 &self,
116 other: &Self,
117 should_enforce: &Boolean<F>,
118 ) -> Result<(), SynthesisError> {
119 assert_eq!(self.len(), other.len());
120 let some_are_different = self.is_neq(other)?;
121 if [&some_are_different, should_enforce].is_constant() {
122 assert!(some_are_different.value()?);
123 Ok(())
124 } else {
125 let cs = [&some_are_different, should_enforce].cs();
126 cs.enforce_constraint(
127 some_are_different.lc(),
128 should_enforce.lc(),
129 should_enforce.lc(),
130 )
131 }
132 }
133}
134
135impl<T: EqGadget<F> + R1CSVar<F>, F: PrimeField> EqGadget<F> for Vec<T> {
137 #[tracing::instrument(target = "r1cs", skip(self, other))]
138 fn is_eq(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
139 self.as_slice().is_eq(other.as_slice())
140 }
141
142 #[tracing::instrument(target = "r1cs", skip(self, other))]
143 fn conditional_enforce_equal(
144 &self,
145 other: &Self,
146 condition: &Boolean<F>,
147 ) -> Result<(), SynthesisError> {
148 self.as_slice()
149 .conditional_enforce_equal(other.as_slice(), condition)
150 }
151
152 #[tracing::instrument(target = "r1cs", skip(self, other))]
153 fn conditional_enforce_not_equal(
154 &self,
155 other: &Self,
156 should_enforce: &Boolean<F>,
157 ) -> Result<(), SynthesisError> {
158 self.as_slice()
159 .conditional_enforce_not_equal(other.as_slice(), should_enforce)
160 }
161}
162
163impl<F: Field> EqGadget<F> for () {
165 #[inline]
168 fn is_eq(&self, _other: &Self) -> Result<Boolean<F>, SynthesisError> {
169 Ok(Boolean::TRUE)
170 }
171
172 #[tracing::instrument(target = "r1cs", skip(self, _other))]
177 fn conditional_enforce_equal(
178 &self,
179 _other: &Self,
180 _should_enforce: &Boolean<F>,
181 ) -> Result<(), SynthesisError> {
182 Ok(())
183 }
184
185 #[tracing::instrument(target = "r1cs", skip(self, _other))]
190 fn enforce_equal(&self, _other: &Self) -> Result<(), SynthesisError> {
191 Ok(())
192 }
193}
194
195impl<T: EqGadget<F> + R1CSVar<F>, F: PrimeField, const N: usize> EqGadget<F> for [T; N] {
197 #[tracing::instrument(target = "r1cs", skip(self, other))]
198 fn is_eq(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
199 self.as_slice().is_eq(other.as_slice())
200 }
201
202 #[tracing::instrument(target = "r1cs", skip(self, other))]
203 fn conditional_enforce_equal(
204 &self,
205 other: &Self,
206 condition: &Boolean<F>,
207 ) -> Result<(), SynthesisError> {
208 self.as_slice()
209 .conditional_enforce_equal(other.as_slice(), condition)
210 }
211
212 #[tracing::instrument(target = "r1cs", skip(self, other))]
213 fn conditional_enforce_not_equal(
214 &self,
215 other: &Self,
216 should_enforce: &Boolean<F>,
217 ) -> Result<(), SynthesisError> {
218 self.as_slice()
219 .conditional_enforce_not_equal(other.as_slice(), should_enforce)
220 }
221}