1use ark_ff::PrimeField;
2use ark_relations::gr1cs::SynthesisError;
3use ark_std::ops::{BitOr, BitOrAssign};
4
5use super::{PrimUInt, UInt};
6
7impl<const N: usize, T: PrimUInt, F: PrimeField> UInt<N, T, F> {
8 fn _or(&self, other: &Self) -> Result<Self, SynthesisError> {
9 let mut result = self.clone();
10 result._or_in_place(other)?;
11 Ok(result)
12 }
13
14 fn _or_in_place(&mut self, other: &Self) -> Result<(), SynthesisError> {
15 for (a, b) in self.bits.iter_mut().zip(&other.bits) {
16 *a |= b;
17 }
18 self.value = self.value.and_then(|a| Some(a | other.value?));
19 Ok(())
20 }
21}
22
23impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr<Self> for &'a UInt<N, T, F> {
24 type Output = UInt<N, T, F>;
25
26 #[tracing::instrument(target = "gr1cs", skip(self, other))]
49 fn bitor(self, other: Self) -> Self::Output {
50 self._or(other).unwrap()
51 }
52}
53
54impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr<&'a Self> for UInt<N, T, F> {
55 type Output = UInt<N, T, F>;
56
57 #[tracing::instrument(target = "gr1cs", skip(self, other))]
58 fn bitor(mut self, other: &Self) -> Self::Output {
59 self._or_in_place(&other).unwrap();
60 self
61 }
62}
63
64impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr<UInt<N, T, F>> for &'a UInt<N, T, F> {
65 type Output = UInt<N, T, F>;
66
67 #[tracing::instrument(target = "gr1cs", skip(self, other))]
68 fn bitor(self, other: UInt<N, T, F>) -> Self::Output {
69 other | self
70 }
71}
72
73impl<const N: usize, T: PrimUInt, F: PrimeField> BitOr<Self> for UInt<N, T, F> {
74 type Output = Self;
75
76 #[tracing::instrument(target = "gr1cs", skip(self, other))]
77 fn bitor(self, other: Self) -> Self::Output {
78 self | &other
79 }
80}
81
82impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr<T> for UInt<N, T, F> {
83 type Output = UInt<N, T, F>;
84
85 #[tracing::instrument(target = "gr1cs", skip(self, other))]
86 fn bitor(self, other: T) -> Self::Output {
87 self | &UInt::constant(other)
88 }
89}
90
91impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr<&'a T> for UInt<N, T, F> {
92 type Output = UInt<N, T, F>;
93
94 #[tracing::instrument(target = "gr1cs", skip(self, other))]
95 fn bitor(self, other: &'a T) -> Self::Output {
96 self | &UInt::constant(*other)
97 }
98}
99
100impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr<&'a T> for &'a UInt<N, T, F> {
101 type Output = UInt<N, T, F>;
102
103 #[tracing::instrument(target = "gr1cs", skip(self, other))]
104 fn bitor(self, other: &'a T) -> Self::Output {
105 self | &UInt::constant(*other)
106 }
107}
108
109impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr<T> for &'a UInt<N, T, F> {
110 type Output = UInt<N, T, F>;
111
112 #[tracing::instrument(target = "gr1cs", skip(self, other))]
113 fn bitor(self, other: T) -> Self::Output {
114 self | &UInt::constant(other)
115 }
116}
117
118impl<const N: usize, T: PrimUInt, F: PrimeField> BitOrAssign<Self> for UInt<N, T, F> {
119 #[tracing::instrument(target = "gr1cs", skip(self, other))]
143 fn bitor_assign(&mut self, other: Self) {
144 self._or_in_place(&other).unwrap();
145 }
146}
147
148impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOrAssign<&'a Self> for UInt<N, T, F> {
149 #[tracing::instrument(target = "gr1cs", skip(self, other))]
150 fn bitor_assign(&mut self, other: &'a Self) {
151 self._or_in_place(other).unwrap();
152 }
153}
154
155impl<const N: usize, T: PrimUInt, F: PrimeField> BitOrAssign<T> for UInt<N, T, F> {
156 #[tracing::instrument(target = "gr1cs", skip(self, other))]
157 fn bitor_assign(&mut self, other: T) {
158 *self |= &UInt::constant(other);
159 }
160}
161
162impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOrAssign<&'a T> for UInt<N, T, F> {
163 #[tracing::instrument(target = "gr1cs", skip(self, other))]
164 fn bitor_assign(&mut self, other: &'a T) {
165 *self |= &UInt::constant(*other);
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::{AllocVar, AllocationMode},
        prelude::EqGadget,
        uint::test_utils::{run_binary_exhaustive_both, run_binary_random_both},
        GR1CSVar,
    };
    use ark_test_curves::bls12_381::Fr;

    /// Checks `&a | &b` for two gadget operands.
    ///
    /// A reference `UInt` holding the native OR of both values is allocated in
    /// the shared constraint system. It is a constant when both inputs are
    /// constant and a witness otherwise. The helper asserts that the values
    /// agree and enforces equality. When at least one input is not constant,
    /// it also checks that the system is satisfied.
    fn uint_or<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: UInt<N, T, F>,
    ) -> Result<(), SynthesisError> {
        let cs = a.cs().or(b.cs());
        let inputs_constant = a.is_constant() && b.is_constant();
        let result = &a | &b;
        let mode = if inputs_constant {
            AllocationMode::Constant
        } else {
            AllocationMode::Witness
        };
        let reference =
            UInt::<N, T, F>::new_variable(cs.clone(), || Ok(a.value()? | b.value()?), mode)?;
        assert_eq!(reference.value(), result.value());
        reference.enforce_equal(&result)?;
        if !inputs_constant {
            assert!(cs.is_satisfied().unwrap());
        }
        Ok(())
    }

    /// Checks `&a | &b` where `b` is a native integer.
    ///
    /// Same strategy as `uint_or`; the constancy of the reference value
    /// follows `a` alone.
    fn uint_or_native<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: T,
    ) -> Result<(), SynthesisError> {
        let cs = a.cs();
        let lhs_constant = a.is_constant();
        let result = &a | &b;
        let mode = if lhs_constant {
            AllocationMode::Constant
        } else {
            AllocationMode::Witness
        };
        let reference = UInt::<N, T, F>::new_variable(cs.clone(), || Ok(a.value()? | b), mode)?;
        assert_eq!(reference.value(), result.value());
        reference.enforce_equal(&result)?;
        if !lhs_constant {
            assert!(cs.is_satisfied().unwrap());
        }
        Ok(())
    }

    #[test]
    fn u8_or() {
        run_binary_exhaustive_both(uint_or::<u8, 8, Fr>, uint_or_native::<u8, 8, Fr>).unwrap()
    }

    #[test]
    fn u16_or() {
        run_binary_random_both::<1000, 16, _, _>(
            uint_or::<u16, 16, Fr>,
            uint_or_native::<u16, 16, Fr>,
        )
        .unwrap()
    }

    #[test]
    fn u32_or() {
        run_binary_random_both::<1000, 32, _, _>(
            uint_or::<u32, 32, Fr>,
            uint_or_native::<u32, 32, Fr>,
        )
        .unwrap()
    }

    #[test]
    fn u64_or() {
        run_binary_random_both::<1000, 64, _, _>(
            uint_or::<u64, 64, Fr>,
            uint_or_native::<u64, 64, Fr>,
        )
        .unwrap()
    }

    #[test]
    fn u128_or() {
        run_binary_random_both::<1000, 128, _, _>(
            uint_or::<u128, 128, Fr>,
            uint_or_native::<u128, 128, Fr>,
        )
        .unwrap()
    }
}