1use ark_ff::PrimeField;
2use ark_relations::r1cs::SynthesisError;
3use ark_std::{ops::BitOr, ops::BitOrAssign};
4
5use super::{PrimUInt, UInt};
6
7impl<const N: usize, T: PrimUInt, F: PrimeField> UInt<N, T, F> {
8 fn _or(&self, other: &Self) -> Result<Self, SynthesisError> {
9 let mut result = self.clone();
10 result._or_in_place(other)?;
11 Ok(result)
12 }
13
14 fn _or_in_place(&mut self, other: &Self) -> Result<(), SynthesisError> {
15 for (a, b) in self.bits.iter_mut().zip(&other.bits) {
16 *a |= b;
17 }
18 self.value = self.value.and_then(|a| Some(a | other.value?));
19 Ok(())
20 }
21}
22
23impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr<Self> for &'a UInt<N, T, F> {
24 type Output = UInt<N, T, F>;
25
26 #[tracing::instrument(target = "r1cs", skip(self, other))]
49 fn bitor(self, other: Self) -> Self::Output {
50 self._or(other).unwrap()
51 }
52}
53
54impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr<&'a Self> for UInt<N, T, F> {
55 type Output = UInt<N, T, F>;
56
57 #[tracing::instrument(target = "r1cs", skip(self, other))]
58 fn bitor(mut self, other: &Self) -> Self::Output {
59 self._or_in_place(&other).unwrap();
60 self
61 }
62}
63
64impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr<UInt<N, T, F>> for &'a UInt<N, T, F> {
65 type Output = UInt<N, T, F>;
66
67 #[tracing::instrument(target = "r1cs", skip(self, other))]
68 fn bitor(self, other: UInt<N, T, F>) -> Self::Output {
69 other | self
70 }
71}
72
73impl<const N: usize, T: PrimUInt, F: PrimeField> BitOr<Self> for UInt<N, T, F> {
74 type Output = Self;
75
76 #[tracing::instrument(target = "r1cs", skip(self, other))]
77 fn bitor(self, other: Self) -> Self::Output {
78 self | &other
79 }
80}
81
82impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr<T> for UInt<N, T, F> {
83 type Output = UInt<N, T, F>;
84
85 #[tracing::instrument(target = "r1cs", skip(self, other))]
86 fn bitor(self, other: T) -> Self::Output {
87 self | &UInt::constant(other)
88 }
89}
90
91impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr<&'a T> for UInt<N, T, F> {
92 type Output = UInt<N, T, F>;
93
94 #[tracing::instrument(target = "r1cs", skip(self, other))]
95 fn bitor(self, other: &'a T) -> Self::Output {
96 self | &UInt::constant(*other)
97 }
98}
99
100impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr<&'a T> for &'a UInt<N, T, F> {
101 type Output = UInt<N, T, F>;
102
103 #[tracing::instrument(target = "r1cs", skip(self, other))]
104 fn bitor(self, other: &'a T) -> Self::Output {
105 self | &UInt::constant(*other)
106 }
107}
108
109impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr<T> for &'a UInt<N, T, F> {
110 type Output = UInt<N, T, F>;
111
112 #[tracing::instrument(target = "r1cs", skip(self, other))]
113 fn bitor(self, other: T) -> Self::Output {
114 self | &UInt::constant(other)
115 }
116}
117
118impl<const N: usize, T: PrimUInt, F: PrimeField> BitOrAssign<Self> for UInt<N, T, F> {
119 #[tracing::instrument(target = "r1cs", skip(self, other))]
143 fn bitor_assign(&mut self, other: Self) {
144 self._or_in_place(&other).unwrap();
145 }
146}
147
148impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOrAssign<&'a Self> for UInt<N, T, F> {
149 #[tracing::instrument(target = "r1cs", skip(self, other))]
150 fn bitor_assign(&mut self, other: &'a Self) {
151 self._or_in_place(other).unwrap();
152 }
153}
154
155impl<const N: usize, T: PrimUInt, F: PrimeField> BitOrAssign<T> for UInt<N, T, F> {
156 #[tracing::instrument(target = "r1cs", skip(self, other))]
157 fn bitor_assign(&mut self, other: T) {
158 *self |= &UInt::constant(other);
159 }
160}
161
162impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOrAssign<&'a T> for UInt<N, T, F> {
163 #[tracing::instrument(target = "r1cs", skip(self, other))]
164 fn bitor_assign(&mut self, other: &'a T) {
165 *self |= &UInt::constant(*other);
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::{AllocVar, AllocationMode},
        prelude::EqGadget,
        uint::test_utils::{run_binary_exhaustive_both, run_binary_random_both},
        R1CSVar,
    };
    use ark_ff::PrimeField;
    use ark_test_curves::bls12_381::Fr;

    /// Allocation mode for the reference result: a constant when every input
    /// is constant, a witness otherwise.
    fn reference_mode(all_inputs_constant: bool) -> AllocationMode {
        if all_inputs_constant {
            AllocationMode::Constant
        } else {
            AllocationMode::Witness
        }
    }

    /// Checks `&a | &b` against a freshly allocated `UInt` holding the
    /// natively computed `a | b`. The values must match and the equality
    /// constraint must be enforceable. Unless both inputs are constant, the
    /// constraint system must also be satisfied.
    fn uint_or<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: UInt<N, T, F>,
    ) -> Result<(), SynthesisError> {
        let all_constant = a.is_constant() && b.is_constant();
        let cs = a.cs().or(b.cs());
        let gadget_result = &a | &b;
        let reference = UInt::<N, T, F>::new_variable(
            cs.clone(),
            || Ok(a.value()? | b.value()?),
            reference_mode(all_constant),
        )?;
        assert_eq!(reference.value(), gadget_result.value());
        reference.enforce_equal(&gadget_result)?;
        if !all_constant {
            let satisfied = cs.is_satisfied().unwrap();
            assert!(satisfied);
        }
        Ok(())
    }

    /// Same check as `uint_or`, but the right operand is a native integer.
    fn uint_or_native<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: T,
    ) -> Result<(), SynthesisError> {
        let a_is_constant = a.is_constant();
        let cs = a.cs();
        let gadget_result = &a | &b;
        let reference = UInt::<N, T, F>::new_variable(
            cs.clone(),
            || Ok(a.value()? | b),
            reference_mode(a_is_constant),
        )?;
        assert_eq!(reference.value(), gadget_result.value());
        reference.enforce_equal(&gadget_result)?;
        if !a_is_constant {
            let satisfied = cs.is_satisfied().unwrap();
            assert!(satisfied);
        }
        Ok(())
    }

    #[test]
    fn u8_or() {
        let outcome =
            run_binary_exhaustive_both(uint_or::<u8, 8, Fr>, uint_or_native::<u8, 8, Fr>);
        outcome.unwrap()
    }

    #[test]
    fn u16_or() {
        let outcome = run_binary_random_both::<1000, 16, _, _>(
            uint_or::<u16, 16, Fr>,
            uint_or_native::<u16, 16, Fr>,
        );
        outcome.unwrap()
    }

    #[test]
    fn u32_or() {
        let outcome = run_binary_random_both::<1000, 32, _, _>(
            uint_or::<u32, 32, Fr>,
            uint_or_native::<u32, 32, Fr>,
        );
        outcome.unwrap()
    }

    #[test]
    fn u64_or() {
        let outcome = run_binary_random_both::<1000, 64, _, _>(
            uint_or::<u64, 64, Fr>,
            uint_or_native::<u64, 64, Fr>,
        );
        outcome.unwrap()
    }

    #[test]
    fn u128_or() {
        let outcome = run_binary_random_both::<1000, 128, _, _>(
            uint_or::<u128, 128, Fr>,
            uint_or_native::<u128, 128, Fr>,
        );
        outcome.unwrap()
    }
}