1use ark_ff::PrimeField;
2use ark_relations::gr1cs::SynthesisError;
3use ark_std::ops::{Shr, ShrAssign};
4
5use crate::boolean::Boolean;
6
7use super::{PrimUInt, UInt};
8
9impl<const N: usize, T: PrimUInt, F: PrimeField> UInt<N, T, F> {
10 fn _shr_u128(&self, other: u128) -> Result<Self, SynthesisError> {
11 if other < N as u128 {
12 let mut bits = [Boolean::FALSE; N];
13 for (a, b) in bits.iter_mut().zip(&self.bits[other as usize..]) {
14 *a = b.clone();
15 }
16
17 let value = self.value.map(|a| a >> other);
18 Ok(Self { bits, value })
19 } else {
20 panic!("attempt to shift right with overflow")
21 }
22 }
23}
24
25impl<const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> Shr<T2> for UInt<N, T, F> {
26 type Output = Self;
27
28 #[tracing::instrument(target = "gr1cs", skip(self, other))]
51 fn shr(self, other: T2) -> Self::Output {
52 self._shr_u128(other.into()).unwrap()
53 }
54}
55
56impl<const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> Shr<T2> for &UInt<N, T, F> {
57 type Output = UInt<N, T, F>;
58
59 #[tracing::instrument(target = "gr1cs", skip(self, other))]
60 fn shr(self, other: T2) -> Self::Output {
61 self._shr_u128(other.into()).unwrap()
62 }
63}
64
65impl<const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> ShrAssign<T2> for UInt<N, T, F> {
66 #[tracing::instrument(target = "gr1cs", skip(self, other))]
90 fn shr_assign(&mut self, other: T2) {
91 let result = self._shr_u128(other.into()).unwrap();
92 *self = result;
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::{AllocVar, AllocationMode},
        prelude::EqGadget,
        uint::test_utils::{run_binary_exhaustive_native_only, run_binary_random_native_only},
        GR1CSVar,
    };
    use ark_test_curves::bls12_381::Fr;

    /// Checks `a >> (b mod N)` against an independently allocated expected
    /// value, and checks that the by-reference, by-value and `>>=` forms of
    /// the shift all agree.
    fn uint_shr<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: T,
    ) -> Result<(), SynthesisError> {
        let cs = a.cs();
        // Reduce the shift amount so it is always in range; out-of-range
        // shifts are covered by the `*_overflow_panics` tests below.
        let b = b.into() % (N as u128);
        let computed = &a >> b;

        // The owned and compound-assignment impls must match the borrowed one.
        let by_value = a.clone() >> b;
        let mut assigned = a.clone();
        assigned >>= b;
        assert_eq!(computed.value(), by_value.value());
        assert_eq!(computed.value(), assigned.value());

        let expected_mode = if a.is_constant() {
            AllocationMode::Constant
        } else {
            AllocationMode::Witness
        };
        let expected =
            UInt::<N, T, F>::new_variable(cs.clone(), || Ok(a.value()? >> b), expected_mode)?;
        assert_eq!(expected.value(), computed.value());
        expected.enforce_equal(&computed)?;
        if !a.is_constant() {
            assert!(cs.is_satisfied().unwrap());
        }
        Ok(())
    }

    #[test]
    fn u8_shr() {
        run_binary_exhaustive_native_only(uint_shr::<u8, 8, Fr>).unwrap()
    }

    #[test]
    fn u16_shr() {
        run_binary_random_native_only::<1000, 16, _, _>(uint_shr::<u16, 16, Fr>).unwrap()
    }

    #[test]
    fn u32_shr() {
        run_binary_random_native_only::<1000, 32, _, _>(uint_shr::<u32, 32, Fr>).unwrap()
    }

    #[test]
    fn u64_shr() {
        run_binary_random_native_only::<1000, 64, _, _>(uint_shr::<u64, 64, Fr>).unwrap()
    }

    #[test]
    fn u128_shr() {
        run_binary_random_native_only::<1000, 128, _, _>(uint_shr::<u128, 128, Fr>).unwrap()
    }

    /// Shifting by exactly the bit width must panic (by-value `Shr`).
    #[test]
    #[should_panic(expected = "attempt to shift right with overflow")]
    fn u8_shr_overflow_panics() {
        let a = UInt::<8, u8, Fr>::constant(0b1000_0000);
        let _ = a >> 8u8;
    }

    /// Shifting by exactly the bit width must panic (`ShrAssign`).
    #[test]
    #[should_panic(expected = "attempt to shift right with overflow")]
    fn u32_shr_assign_overflow_panics() {
        let mut a = UInt::<32, u32, Fr>::constant(u32::MAX);
        a >>= 32u32;
    }
}