ark_r1cs_std/uint/
xor.rs

1use ark_ff::Field;
2use ark_relations::r1cs::SynthesisError;
3use ark_std::{ops::BitXor, ops::BitXorAssign};
4
5use super::*;
6
7impl<const N: usize, T: PrimUInt, F: Field> UInt<N, T, F> {
8    fn _xor(&self, other: &Self) -> Result<Self, SynthesisError> {
9        let mut result = self.clone();
10        result._xor_in_place(other)?;
11        Ok(result)
12    }
13
14    fn _xor_in_place(&mut self, other: &Self) -> Result<(), SynthesisError> {
15        for (a, b) in self.bits.iter_mut().zip(&other.bits) {
16            *a ^= b;
17        }
18        self.value = self.value.and_then(|a| Some(a ^ other.value?));
19        Ok(())
20    }
21}
22
23impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor<Self> for &'a UInt<N, T, F> {
24    type Output = UInt<N, T, F>;
25    /// Outputs `self ^ other`.
26    ///
27    /// If at least one of `self` and `other` are constants, then this method
28    /// *does not* create any constraints or variables.
29    ///
30    /// ```
31    /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
32    /// // We'll use the BLS12-381 scalar field for our constraints.
33    /// use ark_test_curves::bls12_381::Fr;
34    /// use ark_relations::r1cs::*;
35    /// use ark_r1cs_std::prelude::*;
36    ///
37    /// let cs = ConstraintSystem::<Fr>::new_ref();
38    /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?;
39    /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?;
40    /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?;
41    ///
42    /// (a ^ &b).enforce_equal(&c)?;
43    /// assert!(cs.is_satisfied().unwrap());
44    /// # Ok(())
45    /// # }
46    /// ```
47    #[tracing::instrument(target = "r1cs", skip(self, other))]
48    fn bitxor(self, other: Self) -> Self::Output {
49        self._xor(other).unwrap()
50    }
51}
52
53impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor<&'a Self> for UInt<N, T, F> {
54    type Output = UInt<N, T, F>;
55
56    #[tracing::instrument(target = "r1cs", skip(self, other))]
57    fn bitxor(mut self, other: &Self) -> Self::Output {
58        self._xor_in_place(&other).unwrap();
59        self
60    }
61}
62
63impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor<UInt<N, T, F>> for &'a UInt<N, T, F> {
64    type Output = UInt<N, T, F>;
65
66    #[tracing::instrument(target = "r1cs", skip(self, other))]
67    fn bitxor(self, other: UInt<N, T, F>) -> Self::Output {
68        other ^ self
69    }
70}
71
72impl<const N: usize, T: PrimUInt, F: Field> BitXor<Self> for UInt<N, T, F> {
73    type Output = Self;
74
75    #[tracing::instrument(target = "r1cs", skip(self, other))]
76    fn bitxor(self, other: Self) -> Self::Output {
77        self ^ &other
78    }
79}
80
81impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor<T> for UInt<N, T, F> {
82    type Output = UInt<N, T, F>;
83
84    #[tracing::instrument(target = "r1cs", skip(self, other))]
85    fn bitxor(self, other: T) -> Self::Output {
86        self ^ &UInt::constant(other)
87    }
88}
89
90impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor<&'a T> for UInt<N, T, F> {
91    type Output = UInt<N, T, F>;
92
93    #[tracing::instrument(target = "r1cs", skip(self, other))]
94    fn bitxor(self, other: &'a T) -> Self::Output {
95        self ^ &UInt::constant(*other)
96    }
97}
98
99impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor<&'a T> for &'a UInt<N, T, F> {
100    type Output = UInt<N, T, F>;
101
102    #[tracing::instrument(target = "r1cs", skip(self, other))]
103    fn bitxor(self, other: &'a T) -> Self::Output {
104        self ^ UInt::constant(*other)
105    }
106}
107
108impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor<T> for &'a UInt<N, T, F> {
109    type Output = UInt<N, T, F>;
110
111    #[tracing::instrument(target = "r1cs", skip(self, other))]
112    fn bitxor(self, other: T) -> Self::Output {
113        self ^ UInt::constant(other)
114    }
115}
116
117impl<const N: usize, T: PrimUInt, F: Field> BitXorAssign<Self> for UInt<N, T, F> {
118    /// Sets `self = self ^ other`.
119    ///
120    /// If at least one of `self` and `other` are constants, then this method
121    /// *does not* create any constraints or variables.
122    ///
123    /// ```
124    /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
125    /// // We'll use the BLS12-381 scalar field for our constraints.
126    /// use ark_test_curves::bls12_381::Fr;
127    /// use ark_relations::r1cs::*;
128    /// use ark_r1cs_std::prelude::*;
129    ///
130    /// let cs = ConstraintSystem::<Fr>::new_ref();
131    /// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?;
132    /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?;
133    /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?;
134    ///
135    /// a ^= b;
136    /// a.enforce_equal(&c)?;
137    /// assert!(cs.is_satisfied().unwrap());
138    /// # Ok(())
139    /// # }
140    /// ```
141    #[tracing::instrument(target = "r1cs", skip(self, other))]
142    fn bitxor_assign(&mut self, other: Self) {
143        self._xor_in_place(&other).unwrap();
144    }
145}
146
147impl<'a, const N: usize, T: PrimUInt, F: Field> BitXorAssign<&'a Self> for UInt<N, T, F> {
148    #[tracing::instrument(target = "r1cs", skip(self, other))]
149    fn bitxor_assign(&mut self, other: &'a Self) {
150        self._xor_in_place(other).unwrap();
151    }
152}
153
154impl<const N: usize, T: PrimUInt, F: Field> BitXorAssign<T> for UInt<N, T, F> {
155    #[tracing::instrument(target = "r1cs", skip(self, other))]
156    fn bitxor_assign(&mut self, other: T) {
157        *self ^= Self::constant(other);
158    }
159}
160
161impl<'a, const N: usize, T: PrimUInt, F: Field> BitXorAssign<&'a T> for UInt<N, T, F> {
162    #[tracing::instrument(target = "r1cs", skip(self, other))]
163    fn bitxor_assign(&mut self, other: &'a T) {
164        *self ^= Self::constant(*other);
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::{AllocVar, AllocationMode},
        prelude::EqGadget,
        uint::test_utils::{run_binary_exhaustive_both, run_binary_random_both},
        R1CSVar,
    };
    use ark_ff::PrimeField;
    use ark_test_curves::bls12_381::Fr;

    /// Checks `&a ^ &b` against a reference `UInt` allocated from the native XOR.
    ///
    /// The reference is a constant when both inputs are constants and a witness
    /// otherwise. Whenever a witness is involved, the constraint system must be
    /// satisfied after enforcing equality.
    fn uint_xor<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: UInt<N, T, F>,
    ) -> Result<(), SynthesisError> {
        let cs = a.cs().or(b.cs());
        let all_constant = a.is_constant() && b.is_constant();
        let result = &a ^ &b;
        let mode = if all_constant {
            AllocationMode::Constant
        } else {
            AllocationMode::Witness
        };
        let reference =
            UInt::<N, T, F>::new_variable(cs.clone(), || Ok(a.value()? ^ b.value()?), mode)?;
        assert_eq!(reference.value(), result.value());
        reference.enforce_equal(&result)?;
        if all_constant {
            return Ok(());
        }
        assert!(cs.is_satisfied().unwrap());
        Ok(())
    }

    /// Checks `&a ^ &b` for a native right-hand operand `b` against a reference
    /// `UInt` allocated from the native XOR.
    fn uint_xor_native<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: T,
    ) -> Result<(), SynthesisError> {
        let cs = a.cs();
        let a_is_constant = a.is_constant();
        let result = &a ^ &b;
        let mode = if a_is_constant {
            AllocationMode::Constant
        } else {
            AllocationMode::Witness
        };
        let reference = UInt::<N, T, F>::new_variable(cs.clone(), || Ok(a.value()? ^ b), mode)?;
        assert_eq!(reference.value(), result.value());
        reference.enforce_equal(&result)?;
        if !a_is_constant {
            assert!(cs.is_satisfied().unwrap());
        }
        Ok(())
    }

    #[test]
    fn u8_xor() {
        run_binary_exhaustive_both(uint_xor::<u8, 8, Fr>, uint_xor_native::<u8, 8, Fr>).unwrap()
    }

    /// Generates a randomized XOR test (1000 rounds) for one unsigned width.
    macro_rules! random_xor_test {
        ($name:ident, $ty:ty, $bits:tt) => {
            #[test]
            fn $name() {
                run_binary_random_both::<1000, $bits, _, _>(
                    uint_xor::<$ty, $bits, Fr>,
                    uint_xor_native::<$ty, $bits, Fr>,
                )
                .unwrap()
            }
        };
    }

    random_xor_test!(u16_xor, u16, 16);
    random_xor_test!(u32_xor, u32, 32);
    random_xor_test!(u64_xor, u64, 64);
    random_xor_test!(u128_xor, u128, 128);
}