1use ark_ff::Field;
2use ark_relations::r1cs::SynthesisError;
3use ark_std::{ops::BitXor, ops::BitXorAssign};
4
5use super::*;
6
7impl<const N: usize, T: PrimUInt, F: Field> UInt<N, T, F> {
8 fn _xor(&self, other: &Self) -> Result<Self, SynthesisError> {
9 let mut result = self.clone();
10 result._xor_in_place(other)?;
11 Ok(result)
12 }
13
14 fn _xor_in_place(&mut self, other: &Self) -> Result<(), SynthesisError> {
15 for (a, b) in self.bits.iter_mut().zip(&other.bits) {
16 *a ^= b;
17 }
18 self.value = self.value.and_then(|a| Some(a ^ other.value?));
19 Ok(())
20 }
21}
22
23impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor<Self> for &'a UInt<N, T, F> {
24 type Output = UInt<N, T, F>;
25 #[tracing::instrument(target = "r1cs", skip(self, other))]
48 fn bitxor(self, other: Self) -> Self::Output {
49 self._xor(other).unwrap()
50 }
51}
52
53impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor<&'a Self> for UInt<N, T, F> {
54 type Output = UInt<N, T, F>;
55
56 #[tracing::instrument(target = "r1cs", skip(self, other))]
57 fn bitxor(mut self, other: &Self) -> Self::Output {
58 self._xor_in_place(&other).unwrap();
59 self
60 }
61}
62
63impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor<UInt<N, T, F>> for &'a UInt<N, T, F> {
64 type Output = UInt<N, T, F>;
65
66 #[tracing::instrument(target = "r1cs", skip(self, other))]
67 fn bitxor(self, other: UInt<N, T, F>) -> Self::Output {
68 other ^ self
69 }
70}
71
72impl<const N: usize, T: PrimUInt, F: Field> BitXor<Self> for UInt<N, T, F> {
73 type Output = Self;
74
75 #[tracing::instrument(target = "r1cs", skip(self, other))]
76 fn bitxor(self, other: Self) -> Self::Output {
77 self ^ &other
78 }
79}
80
81impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor<T> for UInt<N, T, F> {
82 type Output = UInt<N, T, F>;
83
84 #[tracing::instrument(target = "r1cs", skip(self, other))]
85 fn bitxor(self, other: T) -> Self::Output {
86 self ^ &UInt::constant(other)
87 }
88}
89
90impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor<&'a T> for UInt<N, T, F> {
91 type Output = UInt<N, T, F>;
92
93 #[tracing::instrument(target = "r1cs", skip(self, other))]
94 fn bitxor(self, other: &'a T) -> Self::Output {
95 self ^ &UInt::constant(*other)
96 }
97}
98
99impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor<&'a T> for &'a UInt<N, T, F> {
100 type Output = UInt<N, T, F>;
101
102 #[tracing::instrument(target = "r1cs", skip(self, other))]
103 fn bitxor(self, other: &'a T) -> Self::Output {
104 self ^ UInt::constant(*other)
105 }
106}
107
108impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor<T> for &'a UInt<N, T, F> {
109 type Output = UInt<N, T, F>;
110
111 #[tracing::instrument(target = "r1cs", skip(self, other))]
112 fn bitxor(self, other: T) -> Self::Output {
113 self ^ UInt::constant(other)
114 }
115}
116
117impl<const N: usize, T: PrimUInt, F: Field> BitXorAssign<Self> for UInt<N, T, F> {
118 #[tracing::instrument(target = "r1cs", skip(self, other))]
142 fn bitxor_assign(&mut self, other: Self) {
143 self._xor_in_place(&other).unwrap();
144 }
145}
146
147impl<'a, const N: usize, T: PrimUInt, F: Field> BitXorAssign<&'a Self> for UInt<N, T, F> {
148 #[tracing::instrument(target = "r1cs", skip(self, other))]
149 fn bitxor_assign(&mut self, other: &'a Self) {
150 self._xor_in_place(other).unwrap();
151 }
152}
153
154impl<const N: usize, T: PrimUInt, F: Field> BitXorAssign<T> for UInt<N, T, F> {
155 #[tracing::instrument(target = "r1cs", skip(self, other))]
156 fn bitxor_assign(&mut self, other: T) {
157 *self ^= Self::constant(other);
158 }
159}
160
161impl<'a, const N: usize, T: PrimUInt, F: Field> BitXorAssign<&'a T> for UInt<N, T, F> {
162 #[tracing::instrument(target = "r1cs", skip(self, other))]
163 fn bitxor_assign(&mut self, other: &'a T) {
164 *self ^= Self::constant(*other);
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::{AllocVar, AllocationMode},
        prelude::EqGadget,
        uint::test_utils::{run_binary_exhaustive_both, run_binary_random_both},
        R1CSVar,
    };
    use ark_ff::PrimeField;
    use ark_test_curves::bls12_381::Fr;

    /// Chooses how the reference result is allocated. It is a constant only
    /// when every input was constant; otherwise it is a witness.
    fn reference_mode(all_constant: bool) -> AllocationMode {
        if all_constant {
            AllocationMode::Constant
        } else {
            AllocationMode::Witness
        }
    }

    /// Checks `&a ^ &b` against a freshly allocated reference that holds the
    /// natively computed XOR. The check compares values and also calls
    /// `enforce_equal`.
    fn uint_xor<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: UInt<N, T, F>,
    ) -> Result<(), SynthesisError> {
        let cs = a.cs().or(b.cs());
        let all_constant = a.is_constant() && b.is_constant();
        let actual = &a ^ &b;
        let reference = UInt::<N, T, F>::new_variable(
            cs.clone(),
            || Ok(a.value()? ^ b.value()?),
            reference_mode(all_constant),
        )?;
        assert_eq!(reference.value(), actual.value());
        reference.enforce_equal(&actual)?;
        // Satisfiability is only checked when at least one input is a variable.
        if !all_constant {
            assert!(cs.is_satisfied().unwrap());
        }
        Ok(())
    }

    /// Checks `&a ^ &b` for a native right-hand side against a freshly
    /// allocated reference that holds the natively computed XOR.
    fn uint_xor_native<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: T,
    ) -> Result<(), SynthesisError> {
        let cs = a.cs();
        let actual = &a ^ &b;
        let a_is_constant = a.is_constant();
        let reference = UInt::<N, T, F>::new_variable(
            cs.clone(),
            || Ok(a.value()? ^ b),
            reference_mode(a_is_constant),
        )?;
        assert_eq!(reference.value(), actual.value());
        reference.enforce_equal(&actual)?;
        // Satisfiability is only checked when `a` is a variable.
        if !a_is_constant {
            assert!(cs.is_satisfied().unwrap());
        }
        Ok(())
    }

    #[test]
    fn u8_xor() {
        run_binary_exhaustive_both(uint_xor::<u8, 8, Fr>, uint_xor_native::<u8, 8, Fr>).unwrap()
    }

    #[test]
    fn u16_xor() {
        run_binary_random_both::<1000, 16, _, _>(
            uint_xor::<u16, 16, Fr>,
            uint_xor_native::<u16, 16, Fr>,
        )
        .unwrap()
    }

    #[test]
    fn u32_xor() {
        run_binary_random_both::<1000, 32, _, _>(
            uint_xor::<u32, 32, Fr>,
            uint_xor_native::<u32, 32, Fr>,
        )
        .unwrap()
    }

    #[test]
    fn u64_xor() {
        run_binary_random_both::<1000, 64, _, _>(
            uint_xor::<u64, 64, Fr>,
            uint_xor_native::<u64, 64, Fr>,
        )
        .unwrap()
    }

    #[test]
    fn u128_xor() {
        run_binary_random_both::<1000, 128, _, _>(
            uint_xor::<u128, 128, Fr>,
            uint_xor_native::<u128, 128, Fr>,
        )
        .unwrap()
    }
}